Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ jobs:

- name: Install Python dependencies
run: |
pip install numpy
pip install torch

- name: Install xmake
Expand Down
36 changes: 17 additions & 19 deletions operatorspy/tests/random_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):
else:
end = topk



sum_s = 0
for i in range(end):
sum_s += dataNp[i]
Expand All @@ -78,12 +76,14 @@ def random_sample(data, random_val, topp, topk, voc, temperature, torch_device):

def random_sample_0(data):
return torch.argmax(data)

def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_dtype=torch.float16):
print(
f"Testing RandomSample on {torch_device} with voc:{voc} dtype:{x_dtype}"
)

data = torch.rand((voc), dtype=x_dtype).to(torch_device)
data = torch.arange(voc).float() * 0.0001
_perm = torch.randperm(voc)
data = data[_perm].to(x_dtype).to(torch_device)
if(topp > 0 and topk > 1):
ans = random_sample(data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu")
else:
Expand Down Expand Up @@ -130,12 +130,9 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_
if torch_device == "npu":
torch.npu.synchronize()

assert indices[0].type(ans.dtype) == ans or abs(data[indices[0]] - data[ans]) == 0.0, "compute error"



assert indices[0].type(ans.dtype) == ans or data[ans] == data[indices[0]]
check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor))

print("Test passed!")

def test_cpu(lib, test_cases):
device = DeviceEnum.DEVICE_CPU
Expand Down Expand Up @@ -176,15 +173,16 @@ def test_ascend(lib, test_cases):
if __name__ == "__main__":
test_cases = [
# voc, random_val, topp, topk, temperature
(512, 0.92, 0.8, 3, 0.5),
(4096, 0.95, 0.9, 5, 1.0),
(16384, 0.85, 0.85, 10, 2.0),
(512, 0.92, 0, 3, 0.5),
(4096, 0.95, 0.9, 1, 1.0),
(16384, 0.85, 0, 1, 2.0),
(16384, 0.85, 0, 1, 2.0),
(32000, 0.8, 0.8, 50, 1.0),
(32000, 0.8, 1.0, 25, 1.0),
(512, 0.8, 0.8, 3, 0.5),
(4096, 0.05, 0.9, 5, 1.0),
(16384, 0.15, 0.85, 10, 2.0),
(512, 0.08, 0, 3, 0.5),
(4096, 0.5, 0.9, 1, 1.0),
(16384, 0.15, 0, 1, 2.0),
(16384, 0.15, 0, 1, 2.0),
(32000, 0.08, 0.8, 50, 1.0),
(32000, 0.08, 1.0, 25, 1.0),
# (119696, 0.01, 1.0, 100, 1.0),
]

args = get_args()
Expand Down Expand Up @@ -228,4 +226,4 @@ def test_ascend(lib, test_cases):
test_ascend(lib, test_cases)
if not (args.cpu or args.cuda or args.bang or args.ascend):
test_cpu(lib, test_cases)
print("Test passed!")
print("\033[92mTest passed!\033[0m")
Loading