From 260ef562605fbb1fd49397d4967cdda24d09a480 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Sun, 26 Feb 2023 16:16:23 -0500 Subject: [PATCH] [Fix][TIR] SampleCategorical apply-to-schedule This PR is another way to fix the issue described in #14118. Since we do not have a standard for json file on the format of float numbers (for example, we cannot require a json file producer to print the "integer" float numbers with at least one decimal), and the json parser is not responsible for determining if an integer in a json file should be parsed to a float or an int, the most convenient way of fixing the SampleCategorical issue will be allowing both FloatImms and IntImms as input, and converting all IntImms to FloatImms accordingly. This PR fixes the issue in this way. --- src/tir/schedule/primitive/sampling.cc | 17 ++++++++++-- .../unittest/test_tir_schedule_trace.py | 27 +++++++++++++++++-- 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index ec12b045d3f0..e84e171811eb 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -391,9 +391,22 @@ struct SampleCategoricalTraits : public UnpackedInstTraits candidates, // - Array probs, // + Array probs, // Optional decision) { - return sch->SampleCategorical(candidates, probs, decision); + Array probs_float = probs.Map([](const ObjectRef& prob) { + const auto* prob_float = prob.as(); + if (prob_float != nullptr) { + return GetRef(prob_float); + } + const auto* prob_int = prob.as(); + if (prob_int != nullptr) { + return FloatImm(DataType::Float(32), static_cast(prob_int->value)); + } + LOG(FATAL) + << "SampleCategorical does not accept probability with type other than float or int."; + throw; + }); + return sch->SampleCategorical(candidates, probs_float, decision); } static String UnpackedAsPython(Array outputs, // diff --git a/tests/python/unittest/test_tir_schedule_trace.py b/tests/python/unittest/test_tir_schedule_trace.py index 916db184e09b..a87fd4ed5b56 100644 --- a/tests/python/unittest/test_tir_schedule_trace.py +++ b/tests/python/unittest/test_tir_schedule_trace.py @@ -316,6 +316,30 @@ def test_apply_json_to_schedule_1(): tvm.ir.assert_structural_equal(elementwise_inlined, sch.mod["main"]) +def test_apply_json_to_schedule_sample_categorical(): + var = tir.Var("v", "int32") + trace1 = Trace( + insts=[ + Instruction( + kind=InstructionKind.get("SampleCategorical"), + inputs=[], + attrs=[[tvm.tir.IntImm("int32", 3)], [tvm.tir.FloatImm("float32", 1.0)]], + outputs=[var], + ) + ], + decisions={}, + ) + json = trace1.as_json() + assert json == [[["SampleCategorical", [], [[3], [1]], ["v0"]]], []] + + sch = tir.Schedule(elementwise, debug_mask="all") + # As long as the application does not fail, it is fine. + Trace.apply_json_to_schedule(json, sch) + python_str = sch.trace.as_python() + assert len(python_str) == 1 + assert python_str[0] == "v0 = sch.sample_categorical(candidates=[3], probs=[1], decision=0)" + + def _test_apply_annotation_trace_from_json(annotation: str): """Test applying an annotation works without crashing. @@ -367,5 +391,4 @@ def test_apply_annotation_from_json(): if __name__ == "__main__": - test_trace_simplified_2() - # tvm.testing.main() + tvm.testing.main()