Description
What would you like to happen?
The current implementation of RunInference provides model handlers for PyTorch and Sklearn models. These handlers assume that the method to call for inference is fixed:
- PyTorch: do a forward pass by calling the `__call__` method -> `output = torch_model(input)`
- Sklearn: call the model's `predict` method -> `output = sklearn_model.predict(input)`
However, in some cases we want to provide a custom method for RunInference to call.
Two examples:
- A number of pretrained models loaded with the Huggingface transformers library recommend using the `generate()` method. From the Huggingface docs on the T5 model:

  > At inference time, it is recommended to use generate(). This method takes care of encoding the input and feeding the encoded hidden states via cross-attention layers to the decoder and auto-regressively generates the decoder output.

  ```python
  tokenizer = T5Tokenizer.from_pretrained("t5-small")
  model = T5ForConditionalGeneration.from_pretrained("t5-small")
  input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
  outputs = model.generate(input_ids)
  print(tokenizer.decode(outputs[0], skip_special_tokens=True))
  # Das Haus ist wunderbar.
  ```

- Using OpenAI's CLIP model, which is implemented as a torch model, we might not want to execute the normal forward pass that encodes both images and text (`image_embedding, text_embedding = clip_model(image, text)`), but instead only compute the image embeddings: `image_embedding = clip_model.encode_image(image)`.
Solution: Allowing the user to specify an `inference_fn` when creating a `ModelHandler` would enable this usage.
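A minimal sketch of what this could look like, assuming a new `inference_fn` keyword argument on the existing `PytorchModelHandlerTensor`; the handler and its other arguments follow the current Beam Python SDK, while `inference_fn`, the `generate_fn` helper, its signature, and the paths below are hypothetical:

```python
import torch
from transformers import T5ForConditionalGeneration
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor


def generate_fn(batch, model, device):
    # Hypothetical custom inference function: stack the batched input_ids and
    # call model.generate() instead of the default forward pass.
    return model.generate(torch.stack(batch).to(device))


model_handler = PytorchModelHandlerTensor(
    state_dict_path="gs://my-bucket/t5-small-state-dict.pth",  # hypothetical path
    model_class=T5ForConditionalGeneration,
    model_params={},
    inference_fn=generate_fn,  # proposed new parameter
)

# Usage inside a pipeline (sketch):
# predictions = input_ids_pcoll | RunInference(model_handler)
```

The same mechanism would cover the CLIP case above by passing an `inference_fn` that calls `clip_model.encode_image` instead of the forward pass.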
Issue Priority
Priority: 2
Issue Component
Component: sdk-py-core