diff --git a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py index 91450e4d8d9a..a75e1c3c9e1a 100644 --- a/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py +++ b/sdks/python/apache_beam/ml/transforms/embeddings/vertex_ai.py @@ -115,7 +115,7 @@ def __init__( def get_request( self, text_batch: Sequence[TextEmbeddingInput], - model: MultiModalEmbeddingModel, + model: TextEmbeddingModel, throttle_delay_secs: int): while self.throttler.throttle_request(time.time() * _MSEC_TO_SEC): LOGGER.info(