From e9cc3ded28e1736e439d8f23a74638c82a7c696e Mon Sep 17 00:00:00 2001
From: V-E-D
Date: Wed, 28 May 2025 19:38:07 +0530
Subject: [PATCH 1/3] doc fix signature for 8-bit optim

---
 bitsandbytes/optim/adam.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py
index 1a8800843..335434a32 100644
--- a/bitsandbytes/optim/adam.py
+++ b/bitsandbytes/optim/adam.py
@@ -76,8 +76,6 @@ def __init__(
         betas=(0.9, 0.999),
         eps=1e-8,
         weight_decay=0,
-        amsgrad=False,
-        optim_bits=32,
         args=None,
         min_8bit_size=4096,
         percentile_clipping=100,
@@ -98,10 +96,6 @@ def __init__(
                 The epsilon value prevents division by zero in the optimizer.
             weight_decay (`float`, defaults to 0.0):
                 The weight decay value for the optimizer.
-            amsgrad (`bool`, defaults to `False`):
-                Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
-            optim_bits (`int`, defaults to 32):
-                The number of bits of the optimizer state.
             args (`object`, defaults to `None`):
                 An object with additional arguments.
             min_8bit_size (`int`, defaults to 4096):

From 9edeb0fff65615fd3e4a768caf13a31994f8a48b Mon Sep 17 00:00:00 2001
From: V-E-D
Date: Wed, 28 May 2025 23:29:24 +0530
Subject: [PATCH 2/3] required changes

---
 bitsandbytes/optim/adam.py  | 34 +++++++++++++++++++++++++++++++---
 bitsandbytes/optim/adamw.py | 27 +++++++++++++++++++++++++--
 2 files changed, 56 insertions(+), 5 deletions(-)

diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py
index 335434a32..5ae459e0c 100644
--- a/bitsandbytes/optim/adam.py
+++ b/bitsandbytes/optim/adam.py
@@ -3,7 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-
+import warnings
 from bitsandbytes.optim.optimizer import Optimizer2State
@@ -76,6 +76,8 @@ def __init__(
         betas=(0.9, 0.999),
         eps=1e-8,
         weight_decay=0,
+        amsgrad=False,
+        optim_bits=32,
         args=None,
         min_8bit_size=4096,
         percentile_clipping=100,
@@ -98,6 +98,12 @@ def __init__(
                 The epsilon value prevents division by zero in the optimizer.
             weight_decay (`float`, defaults to 0.0):
                 The weight decay value for the optimizer.
+            amsgrad (`bool`, defaults to `False`):
+                Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
+                Note: This parameter is not supported in Adam8bit and must be False.
+            optim_bits (`int`, defaults to 32):
+                The number of bits of the optimizer state.
+                Note: This parameter is not used in Adam8bit as it always uses 8-bit optimization.
             args (`object`, defaults to `None`):
                 An object with additional arguments.
             min_8bit_size (`int`, defaults to 4096):
@@ -107,6 +115,15 @@ def __init__(
             is_paged (`bool`, defaults to `False`):
                 Whether the optimizer is a paged optimizer or not.
         """
+        # Validate unsupported parameters
+        if amsgrad:
+            raise ValueError("Adam8bit does not support amsgrad=True")
+
+        if optim_bits != 32:
+            # We allow the default value of 32 to maintain compatibility with the function signature,
+            # but any other value is invalid since Adam8bit always uses 8-bit optimization
+            raise ValueError("Adam8bit only supports optim_bits=32 (default value for compatibility)")
+
         super().__init__(
             "adam",
             params,
@@ -114,7 +131,7 @@ def __init__(
             betas,
             eps,
             weight_decay,
-            8,
+            8,  # Hardcoded to 8 bits
             args,
             min_8bit_size,
             percentile_clipping,
@@ -277,8 +294,10 @@ def __init__(
                 The weight decay value for the optimizer.
             amsgrad (`bool`, defaults to `False`):
                 Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
+                Note: This parameter is not supported in PagedAdam8bit and must be False.
             optim_bits (`int`, defaults to 32):
                 The number of bits of the optimizer state.
+                Note: This parameter is not used in PagedAdam8bit as it always uses 8-bit optimization.
             args (`object`, defaults to `None`):
                 An object with additional arguments.
             min_8bit_size (`int`, defaults to 4096):
@@ -290,6 +309,15 @@ def __init__(
             is_paged (`bool`, defaults to `False`):
                 Whether the optimizer is a paged optimizer or not.
         """
+        # Validate unsupported parameters
+        if amsgrad:
+            raise ValueError("PagedAdam8bit does not support amsgrad=True")
+
+        if optim_bits != 32:
+            # We allow the default value of 32 to maintain compatibility with the function signature,
+            # but any other value is invalid since PagedAdam8bit always uses 8-bit optimization
+            raise ValueError("PagedAdam8bit only supports optim_bits=32 (default value for compatibility)")
+
         super().__init__(
             "adam",
             params,
@@ -297,7 +325,7 @@ def __init__(
             betas,
             eps,
             weight_decay,
-            8,
+            8,  # Hardcoded to 8 bits
             args,
             min_8bit_size,
             percentile_clipping,
diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py
index 4bf3f6436..afe3de872 100644
--- a/bitsandbytes/optim/adamw.py
+++ b/bitsandbytes/optim/adamw.py
@@ -3,6 +3,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 from bitsandbytes.optim.optimizer import Optimizer2State
+import warnings
 
 
 class AdamW(Optimizer2State):
@@ -98,8 +99,10 @@ def __init__(
                 The weight decay value for the optimizer.
             amsgrad (`bool`, defaults to `False`):
                 Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
+                Note: This parameter is not supported in AdamW8bit and must be False.
             optim_bits (`int`, defaults to 32):
                 The number of bits of the optimizer state.
+                Note: This parameter is not used in AdamW8bit as it always uses 8-bit optimization.
             args (`object`, defaults to `None`):
                 An object with additional arguments.
             min_8bit_size (`int`, defaults to 4096):
@@ -111,6 +114,15 @@ def __init__(
             is_paged (`bool`, defaults to `False`):
                 Whether the optimizer is a paged optimizer or not.
         """
+        # Validate unsupported parameters
+        if amsgrad:
+            raise ValueError("AdamW8bit does not support amsgrad=True")
+
+        if optim_bits != 32:
+            # We allow the default value of 32 to maintain compatibility with the function signature,
+            # but any other value is invalid since AdamW8bit always uses 8-bit optimization
+            raise ValueError("AdamW8bit only supports optim_bits=32 (default value for compatibility)")
+
         super().__init__(
             "adam",
             params,
@@ -118,7 +130,7 @@ def __init__(
             betas,
             eps,
             weight_decay,
-            8,
+            8,  # Hardcoded to 8 bits
             args,
             min_8bit_size,
             percentile_clipping,
@@ -279,8 +291,10 @@ def __init__(
                 The weight decay value for the optimizer.
             amsgrad (`bool`, defaults to `False`):
                 Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
+                Note: This parameter is not supported in PagedAdamW8bit and must be False.
             optim_bits (`int`, defaults to 32):
                 The number of bits of the optimizer state.
+                Note: This parameter is not used in PagedAdamW8bit as it always uses 8-bit optimization.
             args (`object`, defaults to `None`):
                 An object with additional arguments.
             min_8bit_size (`int`, defaults to 4096):
@@ -292,6 +306,15 @@ def __init__(
             is_paged (`bool`, defaults to `False`):
                 Whether the optimizer is a paged optimizer or not.
         """
+        # Validate unsupported parameters
+        if amsgrad:
+            raise ValueError("PagedAdamW8bit does not support amsgrad=True")
+
+        if optim_bits != 32:
+            # We allow the default value of 32 to maintain compatibility with the function signature,
+            # but any other value is invalid since PagedAdamW8bit always uses 8-bit optimization
+            raise ValueError("PagedAdamW8bit only supports optim_bits=32 (default value for compatibility)")
+
         super().__init__(
             "adam",
             params,
@@ -299,7 +322,7 @@ def __init__(
             betas,
             eps,
             weight_decay,
-            8,
+            8,  # Hardcoded to 8 bits
             args,
             min_8bit_size,
             percentile_clipping,

From b4de8c57ef8a43c5a12d7099774436fa7a3825ae Mon Sep 17 00:00:00 2001
From: V-E-D
Date: Wed, 28 May 2025 23:30:15 +0530
Subject: [PATCH 3/3] precommit

---
 bitsandbytes/optim/adam.py  | 1 -
 bitsandbytes/optim/adamw.py | 2 +-
 2 files changed, 1 insertion(+), 2 deletions(-)

diff --git a/bitsandbytes/optim/adam.py b/bitsandbytes/optim/adam.py
index 5ae459e0c..22a217c3b 100644
--- a/bitsandbytes/optim/adam.py
+++ b/bitsandbytes/optim/adam.py
@@ -3,7 +3,6 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-import warnings
 from bitsandbytes.optim.optimizer import Optimizer2State
diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py
index afe3de872..a32394bd5 100644
--- a/bitsandbytes/optim/adamw.py
+++ b/bitsandbytes/optim/adamw.py
@@ -2,8 +2,8 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+
 from bitsandbytes.optim.optimizer import Optimizer2State
-import warnings
 
 
 class AdamW(Optimizer2State):
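
Note for reviewers: a minimal usage sketch of the behavior PATCH 2/3 introduces, assuming a
working bitsandbytes install; the toy nn.Linear model and the hyperparameter values below are
illustrative placeholders, not part of the patch. Construction with the default arguments
(amsgrad=False, optim_bits=32) is unchanged, while explicitly requesting an unsupported
configuration now fails fast with a ValueError instead of being silently ignored.

    import torch.nn as nn
    import bitsandbytes as bnb

    model = nn.Linear(4096, 4096)

    # Defaults still construct normally; the optimizer state is always kept in 8 bits.
    optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=1e-4, weight_decay=0.01)

    # With the patch applied, unsupported options are rejected at construction time.
    try:
        bnb.optim.AdamW8bit(model.parameters(), lr=1e-4, amsgrad=True)
    except ValueError as err:
        print(err)  # "AdamW8bit does not support amsgrad=True"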