From df0637a59fe1a50d73d0a38346c2b8aabaa68382 Mon Sep 17 00:00:00 2001 From: Nick Stenning Date: Thu, 11 May 2023 17:05:04 +0200 Subject: [PATCH] Don't persist or send cookies to Replicate API Replicate's API does not use cookies and even if we return cookies the client should not save and replay them. Co-authored-by: Mattt Zmuda Signed-off-by: Nick Stenning Signed-off-by: Mattt Zmuda --- replicate/client.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/replicate/client.py b/replicate/client.py index 4196d6aa..95c6e0a3 100644 --- a/replicate/client.py +++ b/replicate/client.py @@ -5,6 +5,7 @@ import requests from requests.adapters import HTTPAdapter, Retry +from requests.cookies import RequestsCookieJar from replicate.__about__ import __version__ from replicate.exceptions import ModelError, ReplicateError @@ -25,7 +26,7 @@ def __init__(self, api_token: Optional[str] = None) -> None: self.poll_interval = float(os.environ.get("REPLICATE_POLL_INTERVAL", "0.5")) # TODO: make thread safe - self.read_session = requests.Session() + self.read_session = _create_session() read_retries = Retry( total=5, backoff_factor=2, @@ -50,7 +51,7 @@ def __init__(self, api_token: Optional[str] = None) -> None: self.read_session.mount("http://", HTTPAdapter(max_retries=read_retries)) self.read_session.mount("https://", HTTPAdapter(max_retries=read_retries)) - self.write_session = requests.Session() + self.write_session = _create_session() write_retries = Retry( total=5, backoff_factor=2, @@ -138,3 +139,21 @@ def run(self, model_version: str, **kwargs) -> Union[Any, Iterator[Any]]: if prediction.status == "failed": raise ModelError(prediction.error) return prediction.output + + +class _NonpersistentCookieJar(RequestsCookieJar): + """ + A cookie jar that doesn't persist cookies between requests. + """ + + def set(self, name, value, **kwargs) -> None: + return + + def set_cookie(self, cookie, *args, **kwargs) -> None: + return + + +def _create_session() -> requests.Session: + s = requests.Session() + s.cookies = _NonpersistentCookieJar() + return s