diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index ec12b045d3f0..e84e171811eb 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -391,9 +391,22 @@ struct SampleCategoricalTraits : public UnpackedInstTraits candidates, // - Array probs, // + Array probs, // Optional decision) { - return sch->SampleCategorical(candidates, probs, decision); + Array probs_float = probs.Map([](const ObjectRef& prob) { + const auto* prob_float = prob.as(); + if (prob_float != nullptr) { + return GetRef(prob_float); + } + const auto* prob_int = prob.as(); + if (prob_int != nullptr) { + return FloatImm(DataType::Float(32), static_cast(prob_int->value)); + } + LOG(FATAL) + << "SampleCategorical does not accept probability with type other than float or int."; + throw; + }); + return sch->SampleCategorical(candidates, probs_float, decision); } static String UnpackedAsPython(Array outputs, // diff --git a/tests/python/unittest/test_tir_schedule_trace.py b/tests/python/unittest/test_tir_schedule_trace.py index 916db184e09b..a87fd4ed5b56 100644 --- a/tests/python/unittest/test_tir_schedule_trace.py +++ b/tests/python/unittest/test_tir_schedule_trace.py @@ -316,6 +316,30 @@ def test_apply_json_to_schedule_1(): tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) +def test_apply_json_to_schedule_sample_categorical(): + var = tir.Var("v", "int32") + trace1 = Trace( + insts=[ + Instruction( + kind=InstructionKind.get("SampleCategorical"), + inputs=[], + attrs=[[tvm.tir.IntImm("int32", 3)], [tvm.tir.FloatImm("float32", 1.0)]], + outputs=[var], + ) + ], + decisions={}, + ) + json = trace1.as_json() + assert json == [[["SampleCategorical", [], [[3], [1]], ["v0"]]], []] + + sch = tir.Schedule(elementwise, debug_mask="all") + # As long as the application does not fail, it is fine. + Trace.apply_json_to_schedule(json, sch) + python_str = sch.trace.as_python() + assert len(python_str) == 1 + assert python_str[0] == "v0 = sch.sample_categorical(candidates=[3], probs=[1], decision=0)" + + def _test_apply_annotation_trace_from_json(annotation: str): """Test applying an annotation works without crashing. @@ -367,5 +391,4 @@ def test_apply_annotation_from_json(): if __name__ == "__main__": - test_trace_simplified_2() - # tvm.testing.main() + tvm.testing.main()