diff --git a/olive/passes/onnx/insert_beam_search.py b/olive/passes/onnx/insert_beam_search.py
index b05093c6e2..b6733b6cde 100644
--- a/olive/passes/onnx/insert_beam_search.py
+++ b/olive/passes/onnx/insert_beam_search.py
@@ -21,7 +21,10 @@ class InsertBeamSearch(Pass):
-    """Insert Beam Search Op."""
+    """Insert Beam Search Op. Only used for whisper models.
+
+    Uses WhisperBeamSearch contrib op if ORT version >= 1.17.1, else uses BeamSearch contrib op.
+    """
 
     _accepts_composite_model = True
 
@@ -65,6 +68,13 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
                     " 1.16.0"
                 ),
             ),
+            "use_temperature": PassConfigParam(
+                type_=bool,
+                default_value=False,
+                description=(
+                    "Use temperature as an extra graph input to the beam search op. Only supported in ORT >= 1.17.1"
+                ),
+            ),
             "fp16": PassConfigParam(
                 type_=bool,
                 default_value=False,
@@ -87,6 +97,8 @@ def chain_model(
         # version check
         version_1_16 = version.parse(OrtVersion) >= version.parse("1.16.0")
+        version_1_17_1 = version.parse(OrtVersion) >= version.parse("1.17.1")
+        # NOTE: will ignore cross qk related options for now
 
         # Chain two models (model_A and model_B) by inserting beam search op in between.
         model_A.graph.name = f"{model_A_name} subgraph"
@@ -107,8 +119,18 @@ def chain_model(
         if version_1_16:
             beam_inputs.extend(["decoder_input_ids" if options["use_forced_decoder_ids"] else ""])
             beam_inputs.extend(["logits_processor" if options["use_logits_processor"] else ""])
+        if version_1_17_1:
+            beam_inputs.extend(["", ""])
+            beam_inputs.extend(
+                [("temperature_fp16" if options["fp16"] else "temperature") if options["use_temperature"] else ""]
+            )
+        # remove empty string from the end of beam_inputs
+        # otherwise, the model gives an error when the last input is empty
+        while beam_inputs[-1] == "":
+            beam_inputs.pop()
 
-        input_features_cast_node, len_pen_cast_node, rep_pen_cast_node = None, None, None
+        # Cast input features to fp16 if required
+        graph_nodes = []
         if options["fp16"]:
             input_features_cast_node = helper.make_node(
                 "Cast",
@@ -131,23 +153,64 @@ def chain_model(
                 name="CastRepetitionPenaltyToFp16",
                 to=TensorProto.FLOAT16,
             )
+            graph_nodes.extend([input_features_cast_node, len_pen_cast_node, rep_pen_cast_node])
+
+            if version_1_17_1 and options["use_temperature"]:
+                temperature_cast_node = helper.make_node(
+                    "Cast",
+                    inputs=["temperature"],
+                    outputs=["temperature_fp16"],
+                    name="CastTemperatureToFp16",
+                    to=TensorProto.FLOAT16,
+                )
+                graph_nodes.append(temperature_cast_node)
 
         beam_outputs = ["sequences"]
 
-        node = helper.make_node("BeamSearch", inputs=beam_inputs, outputs=beam_outputs, name="BeamSearch_node")
-        node.domain = "com.microsoft"
-        node.attribute.extend(
-            [
-                helper.make_attribute("eos_token_id", model_config["eos_token_id"]),
-                helper.make_attribute("pad_token_id", model_config["pad_token_id"]),
-                helper.make_attribute("decoder_start_token_id", model_config["decoder_start_token_id"]),
-                helper.make_attribute("no_repeat_ngram_size", options["no_repeat_ngram_size"]),
-                helper.make_attribute("early_stopping", True),
-                helper.make_attribute("model_type", 2),
-            ]
+        # beam search op attributes
+        beam_search_attrs = [
+            helper.make_attribute("eos_token_id", model_config["eos_token_id"]),
+            helper.make_attribute("pad_token_id", model_config["pad_token_id"]),
+            helper.make_attribute("decoder_start_token_id", model_config["decoder_start_token_id"]),
+            helper.make_attribute("no_repeat_ngram_size", options["no_repeat_ngram_size"]),
+            helper.make_attribute("early_stopping", True),
+            helper.make_attribute("model_type", 2),
+        ]
+        if version_1_17_1:
+            from transformers import AutoTokenizer
+
+            # get tokenizer
+            # can get the base name of the model from the config
+            tokenizer = AutoTokenizer.from_pretrained(model_config["_name_or_path"])
+
+            beam_search_attrs.extend(
+                [
+                    helper.make_attribute("translate_token_id", tokenizer.convert_tokens_to_ids(["<|translate|>"])[0]),
+                    helper.make_attribute(
+                        "transcribe_token_id", tokenizer.convert_tokens_to_ids(["<|transcribe|>"])[0]
+                    ),
+                    helper.make_attribute(
+                        "start_of_lm_token_id", tokenizer.convert_tokens_to_ids(["<|startoflm|>"])[0]
+                    ),
+                    helper.make_attribute(
+                        "no_timestamps_token_id", tokenizer.convert_tokens_to_ids(["<|notimestamps|>"])[0]
+                    ),
+                    helper.make_attribute(
+                        "beginning_timestamp_token_id", tokenizer.convert_tokens_to_ids(["<|0.00|>"])[0]
+                    ),
+                ]
+            )
+
+        node = helper.make_node(
+            "WhisperBeamSearch" if version_1_17_1 else "BeamSearch",
+            inputs=beam_inputs,
+            outputs=beam_outputs,
+            name="BeamSearch_node",
+            domain="com.microsoft",
         )
+        node.attribute.extend(beam_search_attrs)
 
-        # beam graph inputs
+        # Graph inputs
         input_features = helper.make_tensor_value_info(
             "input_features", TensorProto.FLOAT, ["batch_size", "feature_size", "sequence_length"]
         )
@@ -195,12 +258,17 @@ def chain_model(
             logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1])
             graph_inputs.append(logits_processor)
 
-        # graph outputs
+        if version_1_17_1 and options["use_temperature"]:
+            temperature = helper.make_tensor_value_info("temperature", TensorProto.FLOAT, [1])
+            graph_inputs.append(temperature)
+
+        # Graph outputs
         sequences = helper.make_tensor_value_info(
             "sequences", TensorProto.INT32, ["batch_size", "num_return_sequences", "max_length"]
         )
         graph_outputs = [sequences]
 
+        # Replace MultiHeadAttention with DecoderMaskedMultiHeadAttention for CUDA EP inference
         if options["use_gpu"] and version_1_16:
             from onnxruntime.transformers.convert_generation import (
                 update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha as update_decoder_with_ort,
@@ -220,14 +288,15 @@ def chain_model(
                 helper.make_attribute("encoder", model_A.graph),
             ]
         )
+
         opset_import = [
             helper.make_opsetid(domain="com.microsoft", version=1),
             helper.make_opsetid(domain="", version=17),
         ]
 
-        graph_nodes = (
-            [input_features_cast_node, len_pen_cast_node, rep_pen_cast_node, node] if options["fp16"] else [node]
-        )
+        graph_nodes.append(node)
+
+        # Make graph with BeamSearch/WhisperBeamSearch op
         beam_graph = helper.make_graph(graph_nodes, "beam-search-test", graph_inputs, graph_outputs, initializers)
         assert model_A.ir_version == model_B.ir_version
         logger.debug("Using IR version %s for chained model", model_A.ir_version)