diff --git a/replicate/identifier.py b/replicate/identifier.py index 4cff010a..7ca194ac 100644 --- a/replicate/identifier.py +++ b/replicate/identifier.py @@ -1,28 +1,9 @@ import re -from typing import NamedTuple +from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union - -class ModelIdentifier(NamedTuple): - """ - A reference to a model in the format owner/name:version. - """ - - owner: str - name: str - - @classmethod - def parse(cls, ref: str) -> "ModelIdentifier": - """ - Split a reference in the format owner/name:version into its components. - """ - - match = re.match(r"^(?P[^/]+)/(?P[^:]+)$", ref) - if not match: - raise ValueError( - f"Invalid reference to model version: {ref}. Expected format: owner/name" - ) - - return cls(match.group("owner"), match.group("name")) +if TYPE_CHECKING: + from replicate.model import Model + from replicate.version import Version class ModelVersionIdentifier(NamedTuple): @@ -32,7 +13,7 @@ class ModelVersionIdentifier(NamedTuple): owner: str name: str - version: str + version: Optional[str] = None @classmethod def parse(cls, ref: str) -> "ModelVersionIdentifier": @@ -40,10 +21,30 @@ def parse(cls, ref: str) -> "ModelVersionIdentifier": Split a reference in the format owner/name:version into its components. """ - match = re.match(r"^(?P[^/]+)/(?P[^:]+):(?P.+)$", ref) + match = re.match(r"^(?P[^/]+)/(?P[^/:]+)(:(?P.+))?$", ref) if not match: raise ValueError( f"Invalid reference to model version: {ref}. Expected format: owner/name:version" ) return cls(match.group("owner"), match.group("name"), match.group("version")) + + +def _resolve( + ref: Union["Model", "Version", "ModelVersionIdentifier", str] +) -> Tuple[Optional["Version"], Optional[str], Optional[str], Optional[str]]: + from replicate.model import Model # pylint: disable=import-outside-toplevel + from replicate.version import Version # pylint: disable=import-outside-toplevel + + version = None + owner, name, version_id = None, None, None + if isinstance(ref, Model): + owner, name = ref.owner, ref.name + elif isinstance(ref, Version): + version = ref + version_id = ref.id + elif isinstance(ref, ModelVersionIdentifier): + owner, name, version_id = ref + elif isinstance(ref, str): + owner, name, version_id = ModelVersionIdentifier.parse(ref) + return version, owner, name, version_id diff --git a/replicate/model.py b/replicate/model.py index e557db6e..ab8e1406 100644 --- a/replicate/model.py +++ b/replicate/model.py @@ -3,7 +3,7 @@ from typing_extensions import NotRequired, TypedDict, Unpack, deprecated from replicate.exceptions import ReplicateException -from replicate.identifier import ModelIdentifier +from replicate.identifier import ModelVersionIdentifier from replicate.pagination import Page from replicate.prediction import ( Prediction, @@ -296,7 +296,7 @@ class ModelsPredictions(Namespace): def create( self, - model: Optional[Union[str, Tuple[str, str], "Model"]], + model: Union[str, Tuple[str, str], "Model"], input: Dict[str, Any], **params: Unpack["Predictions.CreatePredictionParams"], ) -> Prediction: @@ -317,7 +317,7 @@ def create( async def async_create( self, - model: Optional[Union[str, Tuple[str, str], "Model"]], + model: Union[str, Tuple[str, str], "Model"], input: Dict[str, Any], **params: Unpack["Predictions.CreatePredictionParams"], ) -> Prediction: @@ -391,7 +391,11 @@ def _create_prediction_url_from_model( elif isinstance(model, tuple): owner, name = model[0], model[1] elif isinstance(model, str): - owner, name = ModelIdentifier.parse(model) + owner, name, version_id = ModelVersionIdentifier.parse(model) + if version_id is not None: + raise ValueError( + f"Invalid reference to model version: {model}. Expected model or reference in the format owner/name" + ) if owner is None or name is None: raise ValueError( diff --git a/replicate/run.py b/replicate/run.py index b755fe7d..6bbab588 100644 --- a/replicate/run.py +++ b/replicate/run.py @@ -1,21 +1,23 @@ -import asyncio from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union from typing_extensions import Unpack +from replicate import identifier from replicate.exceptions import ModelError -from replicate.identifier import ModelVersionIdentifier +from replicate.model import Model +from replicate.prediction import Prediction from replicate.schema import make_schema_backwards_compatible -from replicate.version import Versions +from replicate.version import Version, Versions if TYPE_CHECKING: from replicate.client import Client + from replicate.identifier import ModelVersionIdentifier from replicate.prediction import Predictions def run( client: "Client", - ref: str, + ref: Union["Model", "Version", "ModelVersionIdentifier", str], input: Optional[Dict[str, Any]] = None, **params: Unpack["Predictions.CreatePredictionParams"], ) -> Union[Any, Iterator[Any]]: # noqa: ANN401 @@ -23,25 +25,26 @@ def run( Run a model and wait for its output. """ - owner, name, version_id = ModelVersionIdentifier.parse(ref) + version, owner, name, version_id = identifier._resolve(ref) - prediction = client.predictions.create( - version=version_id, input=input or {}, **params - ) + if version_id is not None: + prediction = client.predictions.create( + version=version_id, input=input or {}, **params + ) + elif owner and name: + prediction = client.models.predictions.create( + model=(owner, name), input=input or {}, **params + ) + else: + raise ValueError( + f"Invalid argument: {ref}. Expected model, version, or reference in the format owner/name or owner/name:version" + ) - if owner and name: + if not version and (owner and name and version_id): version = Versions(client, model=(owner, name)).get(version_id) - # Return an iterator of the output - schema = make_schema_backwards_compatible( - version.openapi_schema, version.cog_version - ) - output = schema["components"]["schemas"]["Output"] - if ( - output.get("type") == "array" - and output.get("x-cog-array-type") == "iterator" - ): - return prediction.output_iterator() + if version and (iterator := _make_output_iterator(version, prediction)): + return iterator prediction.wait() @@ -53,7 +56,7 @@ def run( async def async_run( client: "Client", - ref: str, + ref: Union["Model", "Version", "ModelVersionIdentifier", str], input: Optional[Dict[str, Any]] = None, **params: Unpack["Predictions.CreatePredictionParams"], ) -> Union[Any, Iterator[Any]]: # noqa: ANN401 @@ -61,29 +64,28 @@ async def async_run( Run a model and wait for its output asynchronously. """ - owner, name, version_id = ModelVersionIdentifier.parse(ref) + version, owner, name, version_id = identifier._resolve(ref) - prediction = await client.predictions.async_create( - version=version_id, input=input or {}, **params - ) + if version or version_id: + prediction = await client.predictions.async_create( + version=(version or version_id), input=input or {}, **params + ) + elif owner and name: + prediction = await client.models.predictions.async_create( + model=(owner, name), input=input or {}, **params + ) + else: + raise ValueError( + f"Invalid argument: {ref}. Expected model, version, or reference in the format owner/name or owner/name:version" + ) - if owner and name: - version = await Versions(client, model=(owner, name)).async_get(version_id) + if not version and (owner and name and version_id): + version = Versions(client, model=(owner, name)).get(version_id) - # Return an iterator of the output - schema = make_schema_backwards_compatible( - version.openapi_schema, version.cog_version - ) - output = schema["components"]["schemas"]["Output"] - if ( - output.get("type") == "array" - and output.get("x-cog-array-type") == "iterator" - ): - return prediction.output_iterator() + if version and (iterator := _make_output_iterator(version, prediction)): + return iterator - while prediction.status not in ["succeeded", "failed", "canceled"]: - await asyncio.sleep(client.poll_interval) - prediction = await client.predictions.async_get(prediction.id) + prediction.wait() if prediction.status == "failed": raise ModelError(prediction.error) @@ -91,4 +93,17 @@ async def async_run( return prediction.output +def _make_output_iterator( + version: Version, prediction: Prediction +) -> Optional[Iterator[Any]]: + schema = make_schema_backwards_compatible( + version.openapi_schema, version.cog_version + ) + output = schema["components"]["schemas"]["Output"] + if output.get("type") == "array" and output.get("x-cog-array-type") == "iterator": + return prediction.output_iterator() + + return None + + __all__: List = [] diff --git a/replicate/stream.py b/replicate/stream.py index be1827d2..22cea974 100644 --- a/replicate/stream.py +++ b/replicate/stream.py @@ -7,12 +7,13 @@ Iterator, List, Optional, + Union, ) from typing_extensions import Unpack +from replicate import identifier from replicate.exceptions import ReplicateError -from replicate.identifier import ModelVersionIdentifier try: from pydantic import v1 as pydantic # type: ignore @@ -24,7 +25,10 @@ import httpx from replicate.client import Client + from replicate.identifier import ModelVersionIdentifier + from replicate.model import Model from replicate.prediction import Predictions + from replicate.version import Version class ServerSentEvent(pydantic.BaseModel): # type: ignore @@ -157,7 +161,7 @@ async def __aiter__(self) -> AsyncIterator[ServerSentEvent]: def stream( client: "Client", - ref: str, + ref: Union["Model", "Version", "ModelVersionIdentifier", str], input: Optional[Dict[str, Any]] = None, **params: Unpack["Predictions.CreatePredictionParams"], ) -> Iterator[ServerSentEvent]: @@ -168,10 +172,20 @@ def stream( params = params or {} params["stream"] = True - _, _, version_id = ModelVersionIdentifier.parse(ref) - prediction = client.predictions.create( - version=version_id, input=input or {}, **params - ) + version, owner, name, version_id = identifier._resolve(ref) + + if version or version_id: + prediction = client.predictions.create( + version=(version or version_id), input=input or {}, **params + ) + elif owner and name: + prediction = client.models.predictions.create( + model=(owner, name), input=input or {}, **params + ) + else: + raise ValueError( + f"Invalid argument: {ref}. Expected model, version, or reference in the format owner/name or owner/name:version" + ) url = prediction.urls and prediction.urls.get("stream", None) if not url or not isinstance(url, str): @@ -187,7 +201,7 @@ def stream( async def async_stream( client: "Client", - ref: str, + ref: Union["Model", "Version", "ModelVersionIdentifier", str], input: Optional[Dict[str, Any]] = None, **params: Unpack["Predictions.CreatePredictionParams"], ) -> AsyncIterator[ServerSentEvent]: @@ -198,10 +212,20 @@ async def async_stream( params = params or {} params["stream"] = True - _, _, version_id = ModelVersionIdentifier.parse(ref) - prediction = await client.predictions.async_create( - version=version_id, input=input or {}, **params - ) + version, owner, name, version_id = identifier._resolve(ref) + + if version or version_id: + prediction = await client.predictions.async_create( + version=(version or version_id), input=input or {}, **params + ) + elif owner and name: + prediction = await client.models.predictions.async_create( + model=(owner, name), input=input or {}, **params + ) + else: + raise ValueError( + f"Invalid argument: {ref}. Expected model, version, or reference in the format owner/name or owner/name:version" + ) url = prediction.urls and prediction.urls.get("stream", None) if not url or not isinstance(url, str): @@ -214,3 +238,6 @@ async def async_stream( async with client._async_client.stream("GET", url, headers=headers) as response: async for event in EventSource(response): yield event + + +__all__ = ["ServerSentEvent"] diff --git a/replicate/training.py b/replicate/training.py index 619159a3..8cb74c64 100644 --- a/replicate/training.py +++ b/replicate/training.py @@ -14,7 +14,7 @@ from typing_extensions import NotRequired, Unpack from replicate.files import upload_file -from replicate.identifier import ModelIdentifier, ModelVersionIdentifier +from replicate.identifier import ModelVersionIdentifier from replicate.json import encode_json from replicate.model import Model from replicate.pagination import Page @@ -373,14 +373,14 @@ def _create_training_url_from_shorthand(ref: str) -> str: def _create_training_url_from_model_and_version( model: Union[str, Tuple[str, str], "Model"], - version: Union[str, Version], + version: Union[str, "Version"], ) -> str: if isinstance(model, Model): owner, name = model.owner, model.name elif isinstance(model, tuple): owner, name = model[0], model[1] elif isinstance(model, str): - owner, name = ModelIdentifier.parse(model) + owner, name, _ = ModelVersionIdentifier.parse(model) else: raise ValueError( "model must be a Model, a tuple of (owner, name), or a string in the format 'owner/name'" diff --git a/tests/test_identifier.py b/tests/test_identifier.py new file mode 100644 index 00000000..7e907fd9 --- /dev/null +++ b/tests/test_identifier.py @@ -0,0 +1,57 @@ +import pytest + +from replicate.identifier import ModelVersionIdentifier + + +@pytest.mark.parametrize( + "id, expected", + [ + ( + "meta/llama-2-70b-chat", + { + "owner": "meta", + "name": "llama-2-70b-chat", + "version": None, + "error": False, + }, + ), + ( + "mistralai/mistral-7b-instruct-v1.4", + { + "owner": "mistralai", + "name": "mistral-7b-instruct-v1.4", + "version": None, + "error": False, + }, + ), + ( + "nateraw/video-llava:a494250c04691c458f57f2f8ef5785f25bc851e0c91fd349995081d4362322dd", + { + "owner": "nateraw", + "name": "video-llava", + "version": "a494250c04691c458f57f2f8ef5785f25bc851e0c91fd349995081d4362322dd", + "error": False, + }, + ), + ( + "", + {"error": True}, + ), + ( + "invalid", + {"error": True}, + ), + ( + "invalid/id/format", + {"error": True}, + ), + ], +) +def test_parse_model_id(id, expected): + try: + result = ModelVersionIdentifier.parse(id) + assert result.owner == expected["owner"] + assert result.name == expected["name"] + assert result.version == expected["version"] + except ValueError: + assert expected["error"]