
Support **kwargs for PyTorch models. #21453

@damccorm

Description

Some PyTorch models that subclass torch.nn.Module take extra parameters in their forward method. These extra parameters can be passed as a dict of keyword arguments or as positional arguments.

Example of PyTorch models supported by Hugging Face -> https://huggingface.co/bert-base-uncased

Some torch models on Hugging Face, e.g.: https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel


from transformers import BertModel

# Tensor1, Tensor2 and Tensor3 stand for torch.Tensor inputs.
inputs = {
    "input_ids": Tensor1,
    "attention_mask": Tensor2,
    "token_type_ids": Tensor3,
}

# BertModel is a subclass of torch.nn.Module.
model = BertModel.from_pretrained("bert-base-uncased")

# The model's forward method must accept the keys of inputs as named parameters.
outputs = model(**inputs)
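The calling convention above can be tried without downloading a pretrained model. The following is a minimal sketch, assuming a hypothetical `TinyEncoder` module (not part of any library) whose forward method takes the same three parameter names as the BERT example. It shows that unpacking the dict with `**` and passing the tensors positionally give the same result.

```python
import torch


class TinyEncoder(torch.nn.Module):
    """Hypothetical stand-in for a model whose forward takes several named tensors."""

    def __init__(self, vocab_size: int = 10, hidden: int = 4):
        super().__init__()
        self.word = torch.nn.Embedding(vocab_size, hidden)
        self.segment = torch.nn.Embedding(2, hidden)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        x = self.word(input_ids)
        if token_type_ids is not None:
            x = x + self.segment(token_type_ids)
        if attention_mask is not None:
            # Zero out the embeddings at padded positions.
            x = x * attention_mask.unsqueeze(-1).to(x.dtype)
        return x


model = TinyEncoder().eval()

inputs = {
    "input_ids": torch.tensor([[1, 2, 3, 0]]),
    "attention_mask": torch.tensor([[1, 1, 1, 0]]),
    "token_type_ids": torch.tensor([[0, 0, 1, 1]]),
}

with torch.no_grad():
    # Dict keys are matched to forward's parameter names.
    outputs = model(**inputs)
    # Equivalent positional call, in the order forward declares its parameters.
    same = model(
        inputs["input_ids"],
        inputs["attention_mask"],
        inputs["token_type_ids"],
    )
```

The keyword form does not depend on the order in which forward declares its parameters, which is why a dict of named tensors is a convenient input format for such models.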

 

Transformers integrated in PyTorch are supported by Hugging Face as well.
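One way a caller could support such models is to batch a list of keyed examples into a single dict and unpack it into the model call. The sketch below is only an illustration under that assumption. The helpers `stack_keyed_examples` and `run_keyed_inference` and the `ScaleModel` module are hypothetical names, not an existing API.

```python
from typing import Dict, Sequence

import torch


def stack_keyed_examples(
    examples: Sequence[Dict[str, torch.Tensor]]
) -> Dict[str, torch.Tensor]:
    """Stack each key across examples: [{k: t0}, {k: t1}] -> {k: stack([t0, t1])}."""
    keys = examples[0].keys()
    return {key: torch.stack([example[key] for example in examples]) for key in keys}


def run_keyed_inference(
    model: torch.nn.Module, examples: Sequence[Dict[str, torch.Tensor]]
) -> torch.Tensor:
    """Batch the keyed examples and pass them to the model as keyword arguments."""
    batched = stack_keyed_examples(examples)
    with torch.no_grad():
        return model(**batched)


class ScaleModel(torch.nn.Module):
    """Hypothetical model whose forward needs two named tensors."""

    def forward(self, x, scale):
        return x * scale


examples = [
    {"x": torch.tensor([1.0, 2.0]), "scale": torch.tensor([2.0, 2.0])},
    {"x": torch.tensor([3.0, 4.0]), "scale": torch.tensor([0.5, 0.5])},
]

predictions = run_keyed_inference(ScaleModel().eval(), examples)
```

This sketch assumes every example carries the same keys and that tensors under the same key share a shape, so that `torch.stack` can combine them.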

 

Imported from Jira BEAM-14337. Original Jira may contain additional context.
Reported by: Anand Inguva.
Subtask of issue #21435
