Skip to content
This repository was archived by the owner on Jan 24, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ ENV PORT="80"
ENV MODEL_REVISION=""
ENV MODEL_CACHE_DIR="/models"
ENV MODEL_LOAD_IN_8BIT="false"
ENV MODEL_LOAD_IN_4BIT="false"
ENV MODEL_LOCAL_FILES_ONLY="false"
ENV MODEL_TRUST_REMOTE_CODE="false"
ENV MODEL_HALF_PRECISION="false"
Expand Down
1 change: 1 addition & 0 deletions basaran/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def is_true(value):
MODEL_REVISION = os.getenv("MODEL_REVISION", "")
MODEL_CACHE_DIR = os.getenv("MODEL_CACHE_DIR", "models")
MODEL_LOAD_IN_8BIT = is_true(os.getenv("MODEL_LOAD_IN_8BIT", ""))
MODEL_LOAD_IN_4BIT = is_true(os.getenv("MODEL_LOAD_IN_4BIT", ""))
MODEL_LOCAL_FILES_ONLY = is_true(os.getenv("MODEL_LOCAL_FILES_ONLY", ""))
MODEL_TRUST_REMOTE_CODE = is_true(os.getenv("MODEL_TRUST_REMOTE_CODE", ""))
MODEL_HALF_PRECISION = is_true(os.getenv("MODEL_HALF_PRECISION", ""))
Expand Down
2 changes: 2 additions & 0 deletions basaran/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from . import MODEL_REVISION
from . import MODEL_CACHE_DIR
from . import MODEL_LOAD_IN_8BIT
from . import MODEL_LOAD_IN_4BIT
from . import MODEL_LOCAL_FILES_ONLY
from . import MODEL_TRUST_REMOTE_CODE
from . import MODEL_HALF_PRECISION
Expand All @@ -42,6 +43,7 @@
revision=MODEL_REVISION,
cache_dir=MODEL_CACHE_DIR,
load_in_8bit=MODEL_LOAD_IN_8BIT,
load_in_4bit=MODEL_LOAD_IN_4BIT,
local_files_only=MODEL_LOCAL_FILES_ONLY,
trust_remote_code=MODEL_TRUST_REMOTE_CODE,
half_precision=MODEL_HALF_PRECISION,
Expand Down
4 changes: 3 additions & 1 deletion basaran/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ def load_model(
revision=None,
cache_dir=None,
load_in_8bit=False,
load_in_4bit=False,
local_files_only=False,
trust_remote_code=False,
half_precision=False,
Expand All @@ -322,9 +323,10 @@ def load_model(
kwargs = kwargs.copy()
kwargs["device_map"] = "auto"
kwargs["load_in_8bit"] = load_in_8bit
kwargs["load_in_4bit"] = load_in_4bit

# Cast all parameters to float16 if quantization is enabled.
if half_precision or load_in_8bit:
if half_precision or load_in_8bit or load_in_4bit:
kwargs["torch_dtype"] = torch.float16

# Support both decoder-only and encoder-decoder models.
Expand Down