diff --git a/Dockerfile b/Dockerfile
index 5bc76281..a18077f0 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -26,6 +26,7 @@ ENV PORT="80"
 ENV MODEL_REVISION=""
 ENV MODEL_CACHE_DIR="/models"
 ENV MODEL_LOAD_IN_8BIT="false"
+ENV MODEL_LOAD_IN_4BIT="false"
 ENV MODEL_LOCAL_FILES_ONLY="false"
 ENV MODEL_TRUST_REMOTE_CODE="false"
 ENV MODEL_HALF_PRECISION="false"
diff --git a/basaran/__init__.py b/basaran/__init__.py
index 28281577..28f20bd3 100644
--- a/basaran/__init__.py
+++ b/basaran/__init__.py
@@ -21,6 +21,7 @@ def is_true(value):
 MODEL_REVISION = os.getenv("MODEL_REVISION", "")
 MODEL_CACHE_DIR = os.getenv("MODEL_CACHE_DIR", "models")
 MODEL_LOAD_IN_8BIT = is_true(os.getenv("MODEL_LOAD_IN_8BIT", ""))
+MODEL_LOAD_IN_4BIT = is_true(os.getenv("MODEL_LOAD_IN_4BIT", ""))
 MODEL_LOCAL_FILES_ONLY = is_true(os.getenv("MODEL_LOCAL_FILES_ONLY", ""))
 MODEL_TRUST_REMOTE_CODE = is_true(os.getenv("MODEL_TRUST_REMOTE_CODE", ""))
 MODEL_HALF_PRECISION = is_true(os.getenv("MODEL_HALF_PRECISION", ""))
diff --git a/basaran/__main__.py b/basaran/__main__.py
index 570d48a3..bcd123b2 100644
--- a/basaran/__main__.py
+++ b/basaran/__main__.py
@@ -20,6 +20,7 @@
 from . import MODEL_REVISION
 from . import MODEL_CACHE_DIR
 from . import MODEL_LOAD_IN_8BIT
+from . import MODEL_LOAD_IN_4BIT
 from . import MODEL_LOCAL_FILES_ONLY
 from . import MODEL_TRUST_REMOTE_CODE
 from . import MODEL_HALF_PRECISION
@@ -42,6 +43,7 @@
     revision=MODEL_REVISION,
     cache_dir=MODEL_CACHE_DIR,
     load_in_8bit=MODEL_LOAD_IN_8BIT,
+    load_in_4bit=MODEL_LOAD_IN_4BIT,
     local_files_only=MODEL_LOCAL_FILES_ONLY,
     trust_remote_code=MODEL_TRUST_REMOTE_CODE,
     half_precision=MODEL_HALF_PRECISION,
diff --git a/basaran/model.py b/basaran/model.py
index b681cf30..f5d35406 100644
--- a/basaran/model.py
+++ b/basaran/model.py
@@ -302,6 +302,7 @@ def load_model(
     revision=None,
     cache_dir=None,
     load_in_8bit=False,
+    load_in_4bit=False,
     local_files_only=False,
     trust_remote_code=False,
     half_precision=False,
@@ -322,9 +323,10 @@
     kwargs = kwargs.copy()
     kwargs["device_map"] = "auto"
     kwargs["load_in_8bit"] = load_in_8bit
+    kwargs["load_in_4bit"] = load_in_4bit

     # Cast all parameters to float16 if quantization is enabled.
-    if half_precision or load_in_8bit:
+    if half_precision or load_in_8bit or load_in_4bit:
         kwargs["torch_dtype"] = torch.float16

     # Support both decoder-only and encoder-decoder models.
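
Usage sketch (illustrative, not part of the patch): with this change applied, 4-bit loading can be requested either by setting MODEL_LOAD_IN_4BIT=true in the environment or by passing load_in_4bit=True to load_model() directly. The snippet below assumes a CUDA device and a bitsandbytes build with 4-bit support; the model name is only a placeholder.

    # Hypothetical example of the new flag; assumes bitsandbytes with 4-bit
    # support is installed and a GPU is available.
    from basaran.model import load_model

    model = load_model(
        "bigscience/bloomz-560m",  # placeholder model name
        load_in_4bit=True,         # equivalent to MODEL_LOAD_IN_4BIT=true
    )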