From f43df8473bfad06050a61205b819855f63fb11be Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Wed, 18 Dec 2024 11:33:16 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20random=20sample=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E4=BD=BF=E7=94=A8=E7=A1=AE=E5=AE=9A=E7=9A=84=E5=88=86=E5=B8=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/main.yaml | 1 + operatorspy/tests/random_sample.py | 36 ++++++++++++++---------------- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 65731dd1..84108c51 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -23,6 +23,7 @@ jobs: - name: Install Python dependencies run: | + pip install numpy pip install torch - name: Install xmake diff --git a/operatorspy/tests/random_sample.py b/operatorspy/tests/random_sample.py index 795c2c1a..ea680c57 100644 --- a/operatorspy/tests/random_sample.py +++ b/operatorspy/tests/random_sample.py @@ -63,8 +63,6 @@ def random_sample(data, random_val, topp, topk, voc, temperature, torch_device): else: end = topk - - sum_s = 0 for i in range(end): sum_s += dataNp[i] @@ -78,12 +76,14 @@ def random_sample(data, random_val, topp, topk, voc, temperature, torch_device): def random_sample_0(data): return torch.argmax(data) + def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_dtype=torch.float16): print( f"Testing RandomSample on {torch_device} with voc:{voc} dtype:{x_dtype}" ) - - data = torch.rand((voc), dtype=x_dtype).to(torch_device) + data = torch.arange(voc).float() * 0.0001 + _perm = torch.randperm(voc) + data = data[_perm].to(x_dtype).to(torch_device) if(topp > 0 and topk > 1): ans = random_sample(data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu") else: @@ -130,12 +130,9 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_ if torch_device == "npu": torch.npu.synchronize() - assert indices[0].type(ans.dtype) == ans or abs(data[indices[0]] - data[ans]) == 0.0, "compute error" - - - + assert indices[0].type(ans.dtype) == ans or data[ans] == data[indices[0]] check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor)) - + print("Test passed!") def test_cpu(lib, test_cases): device = DeviceEnum.DEVICE_CPU @@ -176,15 +173,16 @@ def test_ascend(lib, test_cases): if __name__ == "__main__": test_cases = [ # voc, random_val, topp, topk, temperature - (512, 0.92, 0.8, 3, 0.5), - (4096, 0.95, 0.9, 5, 1.0), - (16384, 0.85, 0.85, 10, 2.0), - (512, 0.92, 0, 3, 0.5), - (4096, 0.95, 0.9, 1, 1.0), - (16384, 0.85, 0, 1, 2.0), - (16384, 0.85, 0, 1, 2.0), - (32000, 0.8, 0.8, 50, 1.0), - (32000, 0.8, 1.0, 25, 1.0), + (512, 0.8, 0.8, 3, 0.5), + (4096, 0.05, 0.9, 5, 1.0), + (16384, 0.15, 0.85, 10, 2.0), + (512, 0.08, 0, 3, 0.5), + (4096, 0.5, 0.9, 1, 1.0), + (16384, 0.15, 0, 1, 2.0), + (16384, 0.15, 0, 1, 2.0), + (32000, 0.08, 0.8, 50, 1.0), + (32000, 0.08, 1.0, 25, 1.0), + # (119696, 0.01, 1.0, 100, 1.0), ] args = get_args() @@ -228,4 +226,4 @@ def test_ascend(lib, test_cases): test_ascend(lib, test_cases) if not (args.cpu or args.cuda or args.bang or args.ascend): test_cpu(lib, test_cases) - print("Test passed!") + print("\033[92mTest passed!\033[0m")