105 changes: 87 additions & 18 deletions olive/passes/onnx/insert_beam_search.py
@@ -21,7 +21,10 @@


class InsertBeamSearch(Pass):
"""Insert Beam Search Op."""
"""Insert Beam Search Op. Only used for whisper models.

Uses WhisperBeamSearch contrib op if ORT version >= 1.17.1, else uses BeamSearch contrib op.
"""

_accepts_composite_model = True

@@ -65,6 +68,13 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassConfigParam]:
" 1.16.0"
),
),
"use_temperature": PassConfigParam(
type_=bool,
default_value=False,
description=(
"Use temperature as an extra graph input to the beam search op. Only supported in ORT >= 1.17.1"
),
),
"fp16": PassConfigParam(
type_=bool,
default_value=False,
@@ -87,6 +97,8 @@ def chain_model(

# version check
version_1_16 = version.parse(OrtVersion) >= version.parse("1.16.0")
version_1_17_1 = version.parse(OrtVersion) >= version.parse("1.17.1")
# NOTE: will ignore cross qk related options for now
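# (the cross QK inputs/outputs expose cross-attention weights, which can be used for word-level timestamps)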

# Chain two models (model_A and model_B) by inserting a beam search op in between.
model_A.graph.name = f"{model_A_name} subgraph"
@@ -107,8 +119,18 @@
if version_1_16:
beam_inputs.extend(["decoder_input_ids" if options["use_forced_decoder_ids"] else ""])
beam_inputs.extend(["logits_processor" if options["use_logits_processor"] else ""])
if version_1_17_1:
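# the two "" entries are placeholders for the unused optional (cross QK related) inputs,
# keeping temperature in its expected input slot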
beam_inputs.extend(["", ""])
beam_inputs.extend(
[("temperature_fp16" if options["fp16"] else "temperature") if options["use_temperature"] else ""]
)
# remove trailing empty strings from beam_inputs
# otherwise, the model raises an error when the last input is empty
while beam_inputs[-1] == "":
beam_inputs.pop()
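# e.g. [..., "logits_processor", "", ""] is trimmed to [..., "logits_processor"]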

input_features_cast_node, len_pen_cast_node, rep_pen_cast_node = None, None, None
# Cast input features to fp16 if required
graph_nodes = []
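# collects the optional Cast nodes created below; the beam search node itself is appended at the end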
if options["fp16"]:
input_features_cast_node = helper.make_node(
"Cast",
@@ -131,23 +153,64 @@
name="CastRepetitionPenaltyToFp16",
to=TensorProto.FLOAT16,
)
graph_nodes.extend([input_features_cast_node, len_pen_cast_node, rep_pen_cast_node])

if version_1_17_1 and options["use_temperature"]:
temperature_cast_node = helper.make_node(
"Cast",
inputs=["temperature"],
outputs=["temperature_fp16"],
name="CastTemperatureToFp16",
to=TensorProto.FLOAT16,
)
graph_nodes.append(temperature_cast_node)
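# "temperature_fp16" matches the beam input name selected above when fp16 is enabled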

beam_outputs = ["sequences"]

node = helper.make_node("BeamSearch", inputs=beam_inputs, outputs=beam_outputs, name="BeamSearch_node")
node.domain = "com.microsoft"
node.attribute.extend(
[
helper.make_attribute("eos_token_id", model_config["eos_token_id"]),
helper.make_attribute("pad_token_id", model_config["pad_token_id"]),
helper.make_attribute("decoder_start_token_id", model_config["decoder_start_token_id"]),
helper.make_attribute("no_repeat_ngram_size", options["no_repeat_ngram_size"]),
helper.make_attribute("early_stopping", True),
helper.make_attribute("model_type", 2),
]
# beam search op attributes
beam_search_attrs = [
helper.make_attribute("eos_token_id", model_config["eos_token_id"]),
helper.make_attribute("pad_token_id", model_config["pad_token_id"]),
helper.make_attribute("decoder_start_token_id", model_config["decoder_start_token_id"]),
helper.make_attribute("no_repeat_ngram_size", options["no_repeat_ngram_size"]),
helper.make_attribute("early_stopping", True),
helper.make_attribute("model_type", 2),
]
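# WhisperBeamSearch (ORT >= 1.17.1) additionally expects Whisper's special token ids as attributes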
if version_1_17_1:
from transformers import AutoTokenizer

# get the tokenizer; the base model name is available in the model config
tokenizer = AutoTokenizer.from_pretrained(model_config["_name_or_path"])

beam_search_attrs.extend(
[
helper.make_attribute("translate_token_id", tokenizer.convert_tokens_to_ids(["<|translate|>"])[0]),
helper.make_attribute(
"transcribe_token_id", tokenizer.convert_tokens_to_ids(["<|transcribe|>"])[0]
),
helper.make_attribute(
"start_of_lm_token_id", tokenizer.convert_tokens_to_ids(["<|startoflm|>"])[0]
),
helper.make_attribute(
"no_timestamps_token_id", tokenizer.convert_tokens_to_ids(["<|notimestamps|>"])[0]
),
helper.make_attribute(
"beginning_timestamp_token_id", tokenizer.convert_tokens_to_ids(["<|0.00|>"])[0]
),
]
)

node = helper.make_node(
"WhisperBeamSearch" if version_1_17_1 else "BeamSearch",
inputs=beam_inputs,
outputs=beam_outputs,
name="BeamSearch_node",
domain="com.microsoft",
)
node.attribute.extend(beam_search_attrs)

# beam graph inputs
# Graph inputs
input_features = helper.make_tensor_value_info(
"input_features", TensorProto.FLOAT, ["batch_size", "feature_size", "sequence_length"]
)
@@ -195,12 +258,17 @@ def chain_model(
logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1])
graph_inputs.append(logits_processor)

# graph outputs
if version_1_17_1 and options["use_temperature"]:
temperature = helper.make_tensor_value_info("temperature", TensorProto.FLOAT, [1])
graph_inputs.append(temperature)
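# temperature is a graph input rather than a node attribute, so it can be set at inference time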

# Graph outputs
sequences = helper.make_tensor_value_info(
"sequences", TensorProto.INT32, ["batch_size", "num_return_sequences", "max_length"]
)
graph_outputs = [sequences]

# Replace MultiHeadAttention with DecoderMaskedMultiHeadAttention for CUDA EP inference
if options["use_gpu"] and version_1_16:
from onnxruntime.transformers.convert_generation import (
update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha as update_decoder_with_ort,
Expand All @@ -220,14 +288,15 @@ def chain_model(
helper.make_attribute("encoder", model_A.graph),
]
)

opset_import = [
helper.make_opsetid(domain="com.microsoft", version=1),
helper.make_opsetid(domain="", version=17),
]
graph_nodes = (
[input_features_cast_node, len_pen_cast_node, rep_pen_cast_node, node] if options["fp16"] else [node]
)

graph_nodes.append(node)
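# appended last so the beam search node consumes the fp16 Cast outputs when fp16 is enabled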

# Make graph with BeamSearch/WhisperBeamSearch op
beam_graph = helper.make_graph(graph_nodes, "beam-search-test", graph_inputs, graph_outputs, initializers)
assert model_A.ir_version == model_B.ir_version
logger.debug("Using IR version %s for chained model", model_A.ir_version)