[Feature Request]: Allow specification of a custom model inference method for a RunInference ModelHandler #22572

@agvdndor

What would you like to happen?

The current implementation of RunInference provides model handlers for PyTorch and Sklearn models. These handlers assume that the method to call for inference is fixed:

  • PyTorch: do a forward pass by calling the model's __call__ method -> output = torch_model(input)
  • Sklearn: call the model's predict method -> output = sklearn_model.predict(input)

However, in some cases we want RunInference to call a different method.
Two examples:

  1. A number of pretrained models loaded with the Huggingface transformers library recommend using the generate() method. From the Huggingface docs on the T5 model:

    At inference time, it is recommended to use generate(). This method takes care of encoding the input and feeding the encoded hidden states via cross-attention layers to the decoder and auto-regressively generates the decoder output.

    
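    from transformers import T5ForConditionalGeneration, T5Tokenizer

    # Load the pretrained tokenizer and model.
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")

    # Tokenize the prompt, then generate and decode the translation.
    input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
    outputs = model.generate(input_ids)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
    # Das Haus ist wunderbar.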
    
  2. OpenAI's CLIP model is implemented as a torch model. We might not want to execute the normal forward pass, which encodes both images and text (image_embedding, text_embedding = clip_model(image, text)), but instead compute only the image embeddings (image_embedding = clip_model.encode_image(image)).

Solution: allowing the user to specify an inference_fn when creating a ModelHandler would enable both of these use cases.
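
To make the proposal concrete, here is a minimal sketch of how an optional inference_fn could be threaded through a handler. Everything in it is an assumption for illustration: SketchModelHandler, default_inference_fn, FakeClip and the (batch, model) callable signature are hypothetical names, not Beam's actual ModelHandler API, and the toy model stands in for a real torch model so the snippet runs without any dependencies.

```python
from typing import Any, Callable, Iterable, List, Optional, Sequence

# Hypothetical signature for a user-supplied function: (batch, model) -> predictions.
InferenceFn = Callable[[Sequence[Any], Any], Iterable[Any]]


def default_inference_fn(batch: Sequence[Any], model: Any) -> Iterable[Any]:
    """Mimics the fixed behaviour described above: call the model directly."""
    return model(batch)


class SketchModelHandler:
    """Illustrative stand-in for a model handler; NOT Beam's ModelHandler."""

    def __init__(self, model: Any, inference_fn: Optional[InferenceFn] = None):
        self._model = model
        # Fall back to the plain forward pass when no custom function is given.
        self._inference_fn = inference_fn or default_inference_fn

    def run_inference(self, batch: Sequence[Any]) -> List[Any]:
        return list(self._inference_fn(batch, self._model))


class FakeClip:
    """Toy model with a forward pass and a separate encode_image method."""

    def __call__(self, batch):
        # The "normal" forward pass returns both embeddings per element.
        return [("image_emb", "text_emb") for _ in batch]

    def encode_image(self, batch):
        # The alternative method returns image embeddings only.
        return [f"emb({x})" for x in batch]


def encode_image_fn(batch, model):
    """Custom inference function that routes to encode_image."""
    return model.encode_image(batch)


default_handler = SketchModelHandler(FakeClip())
custom_handler = SketchModelHandler(FakeClip(), inference_fn=encode_image_fn)

forward_out = default_handler.run_inference(["img0", "img1"])
image_out = custom_handler.run_inference(["img0", "img1"])
```

Passing the model into the callable, rather than binding a method up front, would let the same function be reused with whichever model instance the handler loads on each worker; the generate() case could be covered the same way with a function that calls model.generate on the batch.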

Issue Priority

Priority: 2

Issue Component

Component: sdk-py-core
