operatorspy/tests/random_sample.py (3 changes: 2 additions & 1 deletion)

@@ -84,7 +84,7 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_
     )

     data = torch.rand((voc), dtype=x_dtype).to(torch_device)
-    if(topp > 0 and topk > 0):
+    if(topp > 0 and topk > 1):
         ans = random_sample(data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu")
     else:
         ans = random_sample_0(data)
@@ -169,6 +169,7 @@ def test_bang(lib, test_cases):
        (512, 0.92, 0, 3, 0.5),
        (4096, 0.95, 0.9, 0, 1.0),
        (16384, 0.85, 0, 0, 2.0),
+       (16384, 0.85, 0, 1, 2.0),
    ]

    args = get_args()
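Note on the new guard: with topk == 1, top-k sampling always returns the single most probable token, so the test routes that case to the argmax reference (random_sample_0) instead of the full sampling reference. Below is a minimal standalone sketch of that equivalence in plain PyTorch; the helper names are illustrative only and are not the repository's functions.

import torch

def argmax_reference(logits: torch.Tensor) -> int:
    # Greedy sampling: index of the largest logit.
    return int(torch.argmax(logits).item())

def topk_sample(logits: torch.Tensor, topk: int, random_val: float) -> int:
    # Keep the topk largest logits, renormalize, then pick the first
    # index whose cumulative probability reaches random_val.
    vals, idx = torch.topk(logits, topk)
    probs = torch.softmax(vals, dim=0)
    cum = torch.cumsum(probs, dim=0)
    pick = min(int(torch.searchsorted(cum, random_val).item()), topk - 1)
    return int(idx[pick].item())

logits = torch.rand(512)
# With topk == 1 the cumulative distribution is a single 1.0, so any
# random_val in [0, 1) selects the same token as plain argmax.
assert topk_sample(logits, 1, 0.85) == argmax_reference(logits)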
src/ops/random_sample/bang/random_sample_bang.mlu (2 changes: 1 addition & 1 deletion)

@@ -464,7 +464,7 @@ void random_sampleUnion(cnrtQueue_t queue, void *workspace, void const *source,
     k_type = CNRT_FUNC_TYPE_UNION1;

     int taskNum = k_dim.x * k_dim.y * k_dim.z;
-    if(topp > 0 && topk > 0){
+    if(topp > 0 && topk > 1){
         const int maxNum = SRC_MAX_SIZE/sizeof(T);
         char *origin = reinterpret_cast<char *>(workspace);
         char *indTmp = origin + taskNum * topk * sizeof(uint64_t);
src/ops/random_sample/cpu/random_sample.cc (2 changes: 1 addition & 1 deletion)

@@ -163,7 +163,7 @@ infiniopStatus_t cpuRandomSample(RandomSampleCpuDescriptor_t desc,
                                  float temperature,
                                  void *stream) {
     if (dtype_eq(desc->dtype, F16)) {
-        if (topp > 0 && topk > 0) {
+        if (topp > 0 && topk > 1) {
             random_sample_cpu_f16(desc,
                                   workspace,
                                   result,
src/ops/random_sample/cuda/random_sample.cu (2 changes: 1 addition & 1 deletion)

@@ -133,7 +133,7 @@ void random_sample_nv_gpu_f16(RandomSampleCudaDescriptor_t desc, void *workspace
                         key_in, key_out,
                         voc, (cudaStream_t) stream); // this call stores the sorted values and their corresponding indices in val_out and key_out
     // sorting is done; next apply the softmax transform
-    if (topp > 0 && topk > 0) {
+    if (topp > 0 && topk > 1) {
         int BLOCK_DIM = 1024;
         int num_blocks = (voc + BLOCK_DIM - 1) / BLOCK_DIM;
         softmax<half, 1024><<<num_blocks, BLOCK_DIM, 0, (cudaStream_t) stream>>>(val_out, topk,
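For context, the branch being guarded here follows a sort, then a temperature softmax over the top-k values, then a top-p cutoff and a draw with random_val. The sketch below is a rough CPU-side illustration of that kind of sampling in PyTorch; the exact ordering of the cutoff and renormalization is an assumption for illustration, not taken from these kernels, and the function name is hypothetical.

import torch

def topk_topp_sample(logits: torch.Tensor, topk: int, topp: float,
                     temperature: float, random_val: float) -> int:
    # Sort descending and keep the original indices, mirroring the
    # device-side sort into val_out / key_out.
    vals, idx = torch.sort(logits, descending=True)
    # Temperature softmax over the top-k slice only.
    probs = torch.softmax(vals[:topk] / temperature, dim=0)
    # Top-p cutoff: keep the smallest prefix whose cumulative mass reaches topp.
    cum = torch.cumsum(probs, dim=0)
    cutoff = min(int(torch.searchsorted(cum, topp).item()) + 1, topk)
    probs = probs[:cutoff] / cum[cutoff - 1]
    # Draw: first index whose renormalized cumulative mass reaches random_val.
    cum = torch.cumsum(probs, dim=0)
    pick = min(int(torch.searchsorted(cum, random_val).item()), cutoff - 1)
    return int(idx[pick].item())

logits = torch.rand(4096)
token = topk_topp_sample(logits, topk=3, topp=0.9, temperature=0.5, random_val=0.95)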