diff --git a/operatorspy/tests/random_sample.py b/operatorspy/tests/random_sample.py index e3e13a82..f10b8f8d 100644 --- a/operatorspy/tests/random_sample.py +++ b/operatorspy/tests/random_sample.py @@ -84,7 +84,7 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_ ) data = torch.rand((voc), dtype=x_dtype).to(torch_device) - if(topp > 0 and topk > 0): + if(topp > 0 and topk > 1): ans = random_sample(data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu") else: ans = random_sample_0(data) @@ -169,6 +169,7 @@ def test_bang(lib, test_cases): (512, 0.92, 0, 3, 0.5), (4096, 0.95, 0.9, 0, 1.0), (16384, 0.85, 0, 0, 2.0), + (16384, 0.85, 0, 1, 2.0), ] args = get_args() diff --git a/src/ops/random_sample/bang/random_sample_bang.mlu b/src/ops/random_sample/bang/random_sample_bang.mlu index 7024658b..5b6a0751 100644 --- a/src/ops/random_sample/bang/random_sample_bang.mlu +++ b/src/ops/random_sample/bang/random_sample_bang.mlu @@ -464,7 +464,7 @@ void random_sampleUnion(cnrtQueue_t queue, void *workspace, void const *source, k_type = CNRT_FUNC_TYPE_UNION1; int taskNum = k_dim.x * k_dim.y * k_dim.z; - if(topp > 0 && topk > 0){ + if(topp > 0 && topk > 1){ const int maxNum = SRC_MAX_SIZE/sizeof(T); char *origin = reinterpret_cast(workspace); char *indTmp = origin + taskNum * topk * sizeof(uint64_t); diff --git a/src/ops/random_sample/cpu/random_sample.cc b/src/ops/random_sample/cpu/random_sample.cc index 63b27508..3706e1ea 100644 --- a/src/ops/random_sample/cpu/random_sample.cc +++ b/src/ops/random_sample/cpu/random_sample.cc @@ -163,7 +163,7 @@ infiniopStatus_t cpuRandomSample(RandomSampleCpuDescriptor_t desc, float temperature, void *stream) { if (dtype_eq(desc->dtype, F16)) { - if (topp > 0 && topk > 0) { + if (topp > 0 && topk > 1) { random_sample_cpu_f16(desc, workspace, result, diff --git a/src/ops/random_sample/cuda/random_sample.cu b/src/ops/random_sample/cuda/random_sample.cu index d29bec27..40761e89 100644 --- a/src/ops/random_sample/cuda/random_sample.cu +++ b/src/ops/random_sample/cuda/random_sample.cu @@ -133,7 +133,7 @@ void random_sample_nv_gpu_f16(RandomSampleCudaDescriptor_t desc, void *workspace key_in, key_out, voc, (cudaStream_t) stream);//该函数会把排序结果和对应索引保存在val_out和key_out上 //排序结束,然后开始做softmax变换 - if (topp > 0 && topk > 0) { + if (topp > 0 && topk > 1) { int BLOCK_DIM = 1024; int num_blocks = (voc + BLOCK_DIM - 1) / BLOCK_DIM; softmax<<>>(val_out, topk,