diff --git a/classy_vision/configs/hmdb51/r3d34.json b/classy_vision/configs/hmdb51/r3d34.json index 9bd204aea8..b30b358c59 100644 --- a/classy_vision/configs/hmdb51/r3d34.json +++ b/classy_vision/configs/hmdb51/r3d34.json @@ -102,13 +102,13 @@ "schedulers": [ { "name": "linear", - "start_lr": 0.005, - "end_lr": 0.04 + "start_value": 0.005, + "end_value": 0.04 }, { "name": "cosine", - "start_lr": 0.04, - "end_lr": 0.00004 + "start_value": 0.04, + "end_value": 0.00004 } ], "update_interval": "epoch", diff --git a/classy_vision/configs/kinetics400/postactivated_i3d50.json b/classy_vision/configs/kinetics400/postactivated_i3d50.json index 54b497d2ab..7abb93d9f6 100644 --- a/classy_vision/configs/kinetics400/postactivated_i3d50.json +++ b/classy_vision/configs/kinetics400/postactivated_i3d50.json @@ -102,8 +102,8 @@ "param_schedulers": { "lr": { "name": "cosine", - "start_lr": 0.1, - "end_lr": 0.0001 + "start_value": 0.1, + "end_value": 0.0001 } }, "weight_decay": 0.0001, diff --git a/classy_vision/configs/kinetics400/preactivated_i3d50.json b/classy_vision/configs/kinetics400/preactivated_i3d50.json index 777cbff2a3..d98022163d 100644 --- a/classy_vision/configs/kinetics400/preactivated_i3d50.json +++ b/classy_vision/configs/kinetics400/preactivated_i3d50.json @@ -102,8 +102,8 @@ "param_schedulers": { "lr": { "name": "cosine", - "start_lr": 0.1, - "end_lr": 0.0001 + "start_value": 0.1, + "end_value": 0.0001 } }, "weight_decay": 0.0001, diff --git a/classy_vision/configs/ucf101/r3d34.json b/classy_vision/configs/ucf101/r3d34.json index 6f6b59e0c5..68533376e2 100644 --- a/classy_vision/configs/ucf101/r3d34.json +++ b/classy_vision/configs/ucf101/r3d34.json @@ -102,13 +102,13 @@ "schedulers": [ { "name": "linear", - "start_lr": 0.005, - "end_lr": 0.04 + "start_value": 0.005, + "end_value": 0.04 }, { "name": "cosine", - "start_lr": 0.04, - "end_lr": 0.00004 + "start_value": 0.04, + "end_value": 0.00004 } ], "lengths": [0.13, 0.87], diff --git 
a/classy_vision/optim/param_scheduler/__init__.py b/classy_vision/optim/param_scheduler/__init__.py index c9f242fd3a..3150d7d6e8 100644 --- a/classy_vision/optim/param_scheduler/__init__.py +++ b/classy_vision/optim/param_scheduler/__init__.py @@ -12,6 +12,7 @@ from .classy_vision_param_scheduler import ( # noqa F401 ClassyParamScheduler, UpdateInterval, + update_interval_from_config, ) diff --git a/classy_vision/optim/param_scheduler/classy_vision_param_scheduler.py b/classy_vision/optim/param_scheduler/classy_vision_param_scheduler.py index e64070708d..92cbc44d1b 100644 --- a/classy_vision/optim/param_scheduler/classy_vision_param_scheduler.py +++ b/classy_vision/optim/param_scheduler/classy_vision_param_scheduler.py @@ -21,6 +21,22 @@ class UpdateInterval(Enum): STEP = "step" +def update_interval_from_config( + config: Dict[str, Any], default: UpdateInterval +) -> UpdateInterval: + """Fetches the update interval from a config + + Args: + config: The config for the parameter scheduler + default: The value to use if the config doesn't specify an update interval + """ + if "update_interval" not in config: + return default + if config.get("update_interval") not in ["step", "epoch"]: + raise ValueError("Choices for update interval are 'step' or 'epoch'") + return UpdateInterval[config["update_interval"].upper()] + + class ClassyParamScheduler(object): """ Base class for Classy parameter schedulers. 
@@ -33,7 +49,7 @@ class ClassyParamScheduler(object): # To be used for comparisons with where WHERE_EPSILON = 1e-6 - def __init__(self, update_interval: UpdateInterval = UpdateInterval.EPOCH): + def __init__(self, update_interval: UpdateInterval): """ Constructor for ClassyParamScheduler diff --git a/classy_vision/optim/param_scheduler/composite_scheduler.py b/classy_vision/optim/param_scheduler/composite_scheduler.py index b168b4fc8b..490db13b84 100644 --- a/classy_vision/optim/param_scheduler/composite_scheduler.py +++ b/classy_vision/optim/param_scheduler/composite_scheduler.py @@ -12,6 +12,7 @@ UpdateInterval, build_param_scheduler, register_param_scheduler, + update_interval_from_config, ) @@ -41,7 +42,7 @@ class CompositeParamScheduler(ClassyParamScheduler): update_interval = "step" schedulers = [ {"name": "constant", "value": 0.42}, - {"name": "cosine_decay", "start_lr": 0.42, "end_lr": 0.0001} + {"name": "cosine_decay", "start_value": 0.42, "end_value": 0.0001} ] interval_scaling = ['rescaled', 'rescaled'], lengths = [0.3, 0.7] @@ -49,17 +50,17 @@ class CompositeParamScheduler(ClassyParamScheduler): The parameter value will be 0.42 for the first [0%, 30%) of steps, and then will cosine decay from 0.42 to 0.0001 for [30%, 100%) of training. + The schedule is updated after every train step by default. 
""" def __init__( self, schedulers: Sequence[ClassyParamScheduler], lengths: Sequence[float], - update_interval: UpdateInterval, interval_scaling: Sequence[IntervalScaling], + update_interval: UpdateInterval = UpdateInterval.STEP, ): - super().__init__() - self.update_interval = update_interval + super().__init__(update_interval=update_interval) self._lengths = lengths self._schedulers = schedulers self._interval_scaling = interval_scaling @@ -89,13 +90,6 @@ def from_config(cls, config: Dict[str, Any]) -> "CompositeParamScheduler": ), "The sum of all values in lengths must be 1" if sum(config["lengths"]) != 1.0: config["lengths"][-1] = 1.0 - sum(config["lengths"][:-1]) - update_interval = UpdateInterval.STEP - if "update_interval" in config: - assert config["update_interval"] in { - "step", - "epoch", - }, "Choices for update interval are 'step' or 'epoch'" - update_interval = UpdateInterval[config["update_interval"].upper()] interval_scaling = [] if "interval_scaling" in config: assert len(config["schedulers"]) == len( @@ -119,7 +113,7 @@ def from_config(cls, config: Dict[str, Any]) -> "CompositeParamScheduler": build_param_scheduler(scheduler) for scheduler in config["schedulers"] ], lengths=config["lengths"], - update_interval=update_interval, + update_interval=update_interval_from_config(config, UpdateInterval.STEP), interval_scaling=interval_scaling, ) diff --git a/classy_vision/optim/param_scheduler/constant_scheduler.py b/classy_vision/optim/param_scheduler/constant_scheduler.py index 6f973bc81b..879aba48d5 100644 --- a/classy_vision/optim/param_scheduler/constant_scheduler.py +++ b/classy_vision/optim/param_scheduler/constant_scheduler.py @@ -6,7 +6,7 @@ from typing import Any, Dict -from . import ClassyParamScheduler, register_param_scheduler +from . 
import ClassyParamScheduler, UpdateInterval, register_param_scheduler @register_param_scheduler("constant") @@ -16,7 +16,7 @@ class ConstantParamScheduler(ClassyParamScheduler): """ def __init__(self, value: float): - super().__init__() + super().__init__(update_interval=UpdateInterval.EPOCH) self._value = value @classmethod diff --git a/classy_vision/optim/param_scheduler/cosine_scheduler.py b/classy_vision/optim/param_scheduler/cosine_scheduler.py index 25cff0cb9e..a1254ab0ca 100644 --- a/classy_vision/optim/param_scheduler/cosine_scheduler.py +++ b/classy_vision/optim/param_scheduler/cosine_scheduler.py @@ -4,11 +4,15 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import logging import math from typing import Any, Dict -from . import ClassyParamScheduler, register_param_scheduler +from . import ( + ClassyParamScheduler, + UpdateInterval, + register_param_scheduler, + update_interval_from_config, +) @register_param_scheduler("cosine") @@ -18,19 +22,25 @@ class CosineParamScheduler(ClassyParamScheduler): //arxiv.org/pdf/1608.03983.pdf>`_. Can be used for either cosine decay or cosine warmup schedules based on start and end values. + The schedule is updated after every train step by default. Example: .. code-block:: python - start_lr: 0.1 - end_lr: 0.0001 + start_value: 0.1 + end_value: 0.0001 """ - def __init__(self, start_lr: float, end_lr: float): - super().__init__() - self._start_lr = start_lr - self._end_lr = end_lr + def __init__( + self, + start_value: float, + end_value: float, + update_interval: UpdateInterval = UpdateInterval.STEP, + ): + super().__init__(update_interval=update_interval) + self._start_value = start_value + self._end_value = end_value @classmethod def from_config(cls, config: Dict[str, Any]) -> "CosineParamScheduler": @@ -44,12 +54,16 @@ def from_config(cls, config: Dict[str, Any]) -> "CosineParamScheduler": A CosineParamScheduler instance. 
""" assert ( - "start_lr" in config and "end_lr" in config - ), "Cosine scheduler requires a start_lr and a end_lr" + "start_value" in config and "end_value" in config + ), "Cosine scheduler requires a start_value and a end_value" - return cls(start_lr=config["start_lr"], end_lr=config["end_lr"]) + return cls( + start_value=config["start_value"], + end_value=config["end_value"], + update_interval=update_interval_from_config(config, UpdateInterval.STEP), + ) def __call__(self, where: float): - return self._end_lr + 0.5 * (self._start_lr - self._end_lr) * ( + return self._end_value + 0.5 * (self._start_value - self._end_value) * ( 1 + math.cos(math.pi * where) ) diff --git a/classy_vision/optim/param_scheduler/linear_scheduler.py b/classy_vision/optim/param_scheduler/linear_scheduler.py index 3fc6ab1bfc..3114db6596 100644 --- a/classy_vision/optim/param_scheduler/linear_scheduler.py +++ b/classy_vision/optim/param_scheduler/linear_scheduler.py @@ -6,28 +6,39 @@ from typing import Any, Dict -from . import ClassyParamScheduler, register_param_scheduler +from . import ( + ClassyParamScheduler, + UpdateInterval, + register_param_scheduler, + update_interval_from_config, +) @register_param_scheduler("linear") class LinearParamScheduler(ClassyParamScheduler): """ - Linearly interpolates parameter between ``start_lr`` and ``end_lr``. + Linearly interpolates parameter between ``start_value`` and ``end_value``. Can be used for either warmup or decay based on start and end values. + The schedule is updated after every train step by default. Example: .. 
code-block:: python - start_lr: 0.0001 - end_lr: 0.01 + start_value: 0.0001 + end_value: 0.01 Corresponds to a linear increasing schedule with values in [0.0001, 0.01) """ - def __init__(self, start_lr: float, end_lr: float): - super().__init__() - self._start_lr = start_lr - self._end_lr = end_lr + def __init__( + self, + start_value: float, + end_value: float, + update_interval: UpdateInterval = UpdateInterval.STEP, + ): + super().__init__(update_interval=update_interval) + self._start_value = start_value + self._end_value = end_value @classmethod def from_config(cls, config: Dict[str, Any]) -> "LinearParamScheduler": @@ -41,10 +52,15 @@ def from_config(cls, config: Dict[str, Any]) -> "LinearParamScheduler": A LinearParamScheduler instance. """ assert ( - "start_lr" in config and "end_lr" in config + "start_value" in config and "end_value" in config ), "Linear scheduler requires a start and a end" - return cls(start_lr=config["start_lr"], end_lr=config["end_lr"]) + + return cls( + start_value=config["start_value"], + end_value=config["end_value"], + update_interval=update_interval_from_config(config, UpdateInterval.STEP), + ) def __call__(self, where: float): # interpolate between start and end values - return self._end_lr * where + self._start_lr * (1 - where) + return self._end_value * where + self._start_value * (1 - where) diff --git a/classy_vision/optim/param_scheduler/multi_step_scheduler.py b/classy_vision/optim/param_scheduler/multi_step_scheduler.py index a1c22de84c..3ac1498d0f 100644 --- a/classy_vision/optim/param_scheduler/multi_step_scheduler.py +++ b/classy_vision/optim/param_scheduler/multi_step_scheduler.py @@ -10,7 +10,12 @@ from classy_vision.generic.util import is_pos_int -from . import ClassyParamScheduler, UpdateInterval, register_param_scheduler +from . 
import ( + ClassyParamScheduler, + UpdateInterval, + register_param_scheduler, + update_interval_from_config, +) @register_param_scheduler("multistep") @@ -18,6 +23,7 @@ class MultiStepParamScheduler(ClassyParamScheduler): """ Takes a predefined schedule for a param value, and a list of epochs which stand for the upper boundary (excluded) of each range. + The schedule is updated after every train epoch by default. Example: @@ -37,10 +43,10 @@ def __init__( self, values, num_epochs: int, - update_interval: UpdateInterval, milestones: Optional[List[int]] = None, + update_interval: UpdateInterval = UpdateInterval.EPOCH, ): - super().__init__(update_interval) + super().__init__(update_interval=update_interval) self._param_schedule = values self._num_epochs = num_epochs self._milestones = milestones @@ -96,11 +102,12 @@ def from_config(cls, config: Dict[str, Any]) -> "MultiStepParamScheduler": "Non-Equi Step scheduler requires a list of %d epochs" % (len(config["values"]) - 1) ) + return cls( values=config["values"], num_epochs=config["num_epochs"], - update_interval=UpdateInterval(config.get("update_interval", "epoch")), milestones=milestones, + update_interval=update_interval_from_config(config, UpdateInterval.EPOCH), ) def __call__(self, where: float): diff --git a/classy_vision/optim/param_scheduler/polynomial_decay_scheduler.py b/classy_vision/optim/param_scheduler/polynomial_decay_scheduler.py index db91502be9..d37f540f2b 100644 --- a/classy_vision/optim/param_scheduler/polynomial_decay_scheduler.py +++ b/classy_vision/optim/param_scheduler/polynomial_decay_scheduler.py @@ -6,7 +6,12 @@ from typing import Any, Dict -from . import ClassyParamScheduler, register_param_scheduler +from . 
import ( + ClassyParamScheduler, + UpdateInterval, + register_param_scheduler, + update_interval_from_config, +) @register_param_scheduler("polynomial") @@ -14,22 +19,28 @@ class PolynomialDecayParamScheduler(ClassyParamScheduler): """ Decays the param value after every epoch according to a polynomial function with a fixed power. + The schedule is updated after every train step by default. Example: .. code-block:: python - base_lr: 0.1 + base_value: 0.1 power: 0.9 Then the param value will be 0.1 for epoch 0, 0.099 for epoch 1, and so on. """ - def __init__(self, base_lr, power): - super().__init__() + def __init__( + self, + base_value: float, + power: float, + update_interval: UpdateInterval = UpdateInterval.STEP, + ): + super().__init__(update_interval=update_interval) - self._base_lr = base_lr + self._base_value = base_value self._power = power @classmethod @@ -44,9 +55,13 @@ def from_config(cls, config: Dict[str, Any]) -> "PolynomialDecayParamScheduler": A PolynomialDecayParamScheduler instance. """ assert ( - "base_lr" in config and "power" in config + "base_value" in config and "power" in config ), "Polynomial decay scheduler requires a base lr and a power of decay" - return cls(base_lr=config["base_lr"], power=config["power"]) + return cls( + base_value=config["base_value"], + power=config["power"], + update_interval=update_interval_from_config(config, UpdateInterval.STEP), + ) def __call__(self, where: float): - return self._base_lr * (1 - where) ** self._power + return self._base_value * (1 - where) ** self._power diff --git a/classy_vision/optim/param_scheduler/step_scheduler.py b/classy_vision/optim/param_scheduler/step_scheduler.py index 5dbb46fb45..80529694a4 100644 --- a/classy_vision/optim/param_scheduler/step_scheduler.py +++ b/classy_vision/optim/param_scheduler/step_scheduler.py @@ -6,7 +6,12 @@ from typing import Any, Dict, List, NamedTuple, Optional, Union -from . import ClassyParamScheduler, register_param_scheduler +from . 
import ( + ClassyParamScheduler, + UpdateInterval, + register_param_scheduler, + update_interval_from_config, +) @register_param_scheduler("step") @@ -15,6 +20,7 @@ class StepParamScheduler(ClassyParamScheduler): Takes a fixed schedule for a param value. If the length of the fixed schedule is less than the number of epochs, then the epochs are divided evenly among the param schedule. + The schedule is updated after every train epoch by default. Example: @@ -27,8 +33,13 @@ class StepParamScheduler(ClassyParamScheduler): epochs 30-59, 0.001 for epoch 60-89, 0.0001 for epochs 90-119. """ - def __init__(self, num_epochs: Union[int, float], values: List[float]): - super().__init__() + def __init__( + self, + num_epochs: Union[int, float], + values: List[float], + update_interval: UpdateInterval = UpdateInterval.EPOCH, + ): + super().__init__(update_interval=update_interval) self._param_schedule = values @@ -50,7 +61,11 @@ def from_config(cls, config: Dict[str, Any]) -> "StepParamScheduler": ), "Step scheduler requires a list of at least one param value" assert config["num_epochs"] > 0, "Num epochs must be greater than 0" - return cls(num_epochs=config["num_epochs"], values=config["values"]) + return cls( + num_epochs=config["num_epochs"], + values=config["values"], + update_interval=update_interval_from_config(config, UpdateInterval.EPOCH), + ) def __call__(self, where: float): ind = int((where + self.WHERE_EPSILON) * len(self._param_schedule)) diff --git a/classy_vision/optim/param_scheduler/step_with_fixed_gamma_scheduler.py b/classy_vision/optim/param_scheduler/step_with_fixed_gamma_scheduler.py index 7db70f4b35..d8f0053585 100644 --- a/classy_vision/optim/param_scheduler/step_with_fixed_gamma_scheduler.py +++ b/classy_vision/optim/param_scheduler/step_with_fixed_gamma_scheduler.py @@ -6,7 +6,12 @@ from typing import Any, Dict -from . import ClassyParamScheduler, UpdateInterval, register_param_scheduler +from . 
import (
+    ClassyParamScheduler,
+    UpdateInterval,
+    register_param_scheduler,
+    update_interval_from_config,
+)
 from .step_scheduler import StepParamScheduler
 
 
@@ -15,12 +20,13 @@ class StepWithFixedGammaParamScheduler(ClassyParamScheduler):
     """
     Decays the param value by gamma at equal number of steps so as to have the
     specified total number of decays.
+    The schedule is updated after every train step by default.
 
     Example:
 
         .. code-block:: python
 
-          base_lr: 0.1
+          base_value: 0.1
           gamma: 0.1
           num_decays: 3
           num_epochs: 120
@@ -29,6 +35,31 @@ class StepWithFixedGammaParamScheduler(ClassyParamScheduler):
         epochs 30-59, 0.001 for epoch 60-89, 0.0001 for epochs 90-119.
     """
 
+    def __init__(
+        self,
+        base_value: float,
+        num_decays: int,
+        gamma: float,
+        num_epochs: int,
+        update_interval: UpdateInterval = UpdateInterval.STEP,
+    ):
+        super().__init__(update_interval=update_interval)
+
+        self.base_value = base_value
+        self.num_decays = num_decays
+        self.gamma = gamma
+        self.num_epochs = num_epochs
+        values = [base_value]
+        for _ in range(num_decays):
+            values.append(values[-1] * gamma)
+
+        self._step_param_scheduler = StepParamScheduler(
+            num_epochs=num_epochs, values=values
+        )
+
+        # NOTE(review): update_interval is already set by super().__init__;
+        # re-assigning STEP here would silently discard a configured "epoch"
+
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "StepWithFixedGammaParamScheduler":
         """Instantiates a StepWithFixedGammaParamScheduler from a configuration.
@@ -40,9 +71,9 @@ def from_config(cls, config: Dict[str, Any]) -> "StepWithFixedGammaParamSchedule
 
         Returns:
             A StepWithFixedGammaParamScheduler instance.
""" - for key in ["base_lr", "gamma", "num_decays", "num_epochs"]: + for key in ["base_value", "gamma", "num_decays", "num_epochs"]: assert key in config, f"Step with fixed decay scheduler requires: {key}" - for key in ["base_lr", "gamma"]: + for key in ["base_value", "gamma"]: assert ( isinstance(config[key], (int, float)) and config[key] > 0 ), f"{key} must be a positive number" @@ -52,29 +83,12 @@ def from_config(cls, config: Dict[str, Any]) -> "StepWithFixedGammaParamSchedule ), f"{key} must be a positive integer" return cls( - base_lr=config["base_lr"], + base_value=config["base_value"], num_decays=config["num_decays"], gamma=config["gamma"], num_epochs=config["num_epochs"], + update_interval=update_interval_from_config(config, UpdateInterval.STEP), ) - def __init__(self, base_lr, num_decays, gamma, num_epochs): - super().__init__() - - self.base_lr = base_lr - self.num_decays = num_decays - self.gamma = gamma - self.num_epochs = num_epochs - values = [base_lr] - for _ in range(num_decays): - values.append(values[-1] * gamma) - - self._step_param_scheduler = StepParamScheduler( - num_epochs=num_epochs, values=values - ) - - # make this a STEP scheduler - self.update_interval = UpdateInterval.STEP - def __call__(self, where: float) -> float: return self._step_param_scheduler(where) diff --git a/test/api_test.py b/test/api_test.py index 2021c37f84..f6a4e4a27c 100644 --- a/test/api_test.py +++ b/test/api_test.py @@ -83,7 +83,7 @@ def test_one(self): optimizer = SGD(momentum=0.9, weight_decay=1e-4, nesterov=True) optimizer.set_param_schedulers( - {"lr": LinearParamScheduler(start_lr=0.01, end_lr=0.009)} + {"lr": LinearParamScheduler(start_value=0.01, end_value=0.009)} ) task = ( diff --git a/test/generic/optim_test_util.py b/test/generic/optim_test_util.py index a821420503..60ca92a6cd 100644 --- a/test/generic/optim_test_util.py +++ b/test/generic/optim_test_util.py @@ -212,7 +212,7 @@ def _test_lr_schedule(optimizer, num_epochs, epochs, targets): config["lr"] = { 
"name": "composite", "schedulers": [ - {"name": "linear", "start_lr": init_lr, "end_lr": 0.1}, + {"name": "linear", "start_value": init_lr, "end_value": 0.1}, {"name": "step", "values": [0.1, 0.01, 0.001]}, ], "update_interval": "epoch", diff --git a/test/optim_param_scheduler_composite_test.py b/test/optim_param_scheduler_composite_test.py index 9f968c5b6a..cbfd41ebb6 100644 --- a/test/optim_param_scheduler_composite_test.py +++ b/test/optim_param_scheduler_composite_test.py @@ -45,7 +45,7 @@ def _get_valid_mixed_config(self): "name": "composite", "schedulers": [ {"name": "step", "values": [0.1, 0.2, 0.3, 0.4, 0.5], "num_epochs": 10}, - {"name": "cosine", "start_lr": 0.42, "end_lr": 0.0001}, + {"name": "cosine", "start_value": 0.42, "end_value": 0.0001}, ], "lengths": [0.5, 0.5], } @@ -54,8 +54,8 @@ def _get_valid_linear_config(self): return { "name": "composite", "schedulers": [ - {"name": "linear", "start_lr": 0.0, "end_lr": 0.5}, - {"name": "linear", "start_lr": 0.5, "end_lr": 1.0}, + {"name": "linear", "start_value": 0.0, "end_value": 0.5}, + {"name": "linear", "start_value": 0.5, "end_value": 1.0}, ], "lengths": [0.5, 0.5], "interval_scaling": ["rescaled", "rescaled"], diff --git a/test/optim_param_scheduler_cosine_test.py b/test/optim_param_scheduler_cosine_test.py index a93028174a..3aa0cb81e4 100644 --- a/test/optim_param_scheduler_cosine_test.py +++ b/test/optim_param_scheduler_cosine_test.py @@ -15,7 +15,7 @@ class TestCosineScheduler(unittest.TestCase): _num_epochs = 10 def _get_valid_decay_config(self): - return {"name": "cosine", "start_lr": 0.1, "end_lr": 0} + return {"name": "cosine", "start_value": 0.1, "end_value": 0} def _get_valid_decay_config_intermediate_values(self): return [0.0976, 0.0905, 0.0794, 0.0655, 0.05, 0.0345, 0.0206, 0.0095, 0.0024] @@ -26,13 +26,13 @@ def test_invalid_config(self): bad_config = copy.deepcopy(config) # Invalid Base lr - del bad_config["start_lr"] + del bad_config["start_value"] with 
self.assertRaises(AssertionError): CosineParamScheduler.from_config(bad_config) - # Invalid end_lr - bad_config["start_lr"] = config["start_lr"] - del bad_config["end_lr"] + # Invalid end_value + bad_config["start_value"] = config["start_value"] + del bad_config["end_value"] with self.assertRaises(AssertionError): CosineParamScheduler.from_config(bad_config) @@ -45,7 +45,7 @@ def test_scheduler_as_decay(self): for epoch_num in range(self._num_epochs) ] expected_schedule = [ - config["start_lr"] + config["start_value"] ] + self._get_valid_decay_config_intermediate_values() self.assertEqual(schedule, expected_schedule) @@ -53,9 +53,9 @@ def test_scheduler_as_decay(self): def test_scheduler_as_warmup(self): config = self._get_valid_decay_config() # Swap start and end lr to change to warmup - tmp = config["start_lr"] - config["start_lr"] = config["end_lr"] - config["end_lr"] = tmp + tmp = config["start_value"] + config["start_value"] = config["end_value"] + config["end_value"] = tmp scheduler = CosineParamScheduler.from_config(config) schedule = [ @@ -63,7 +63,7 @@ def test_scheduler_as_warmup(self): for epoch_num in range(self._num_epochs) ] # Schedule should be decay reversed - expected_schedule = [config["start_lr"]] + list( + expected_schedule = [config["start_value"]] + list( reversed(self._get_valid_decay_config_intermediate_values()) ) @@ -75,9 +75,9 @@ def test_scheduler_warmup_decay_match(self): warmup_config = copy.deepcopy(decay_config) # Swap start and end lr to change to warmup - tmp = warmup_config["start_lr"] - warmup_config["start_lr"] = warmup_config["end_lr"] - warmup_config["end_lr"] = tmp + tmp = warmup_config["start_value"] + warmup_config["start_value"] = warmup_config["end_value"] + warmup_config["end_value"] = tmp warmup_scheduler = CosineParamScheduler.from_config(warmup_config) decay_schedule = [ diff --git a/test/optim_param_scheduler_linear_test.py b/test/optim_param_scheduler_linear_test.py index 9c38a9a667..dd844b0127 100644 --- 
a/test/optim_param_scheduler_linear_test.py +++ b/test/optim_param_scheduler_linear_test.py @@ -18,20 +18,20 @@ def _get_valid_intermediate(self): return [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09] def _get_valid_config(self): - return {"name": "linear", "start_lr": 0.0, "end_lr": 0.1} + return {"name": "linear", "start_value": 0.0, "end_value": 0.1} def test_invalid_config(self): config = self._get_valid_config() bad_config = copy.deepcopy(config) # No start lr - del bad_config["start_lr"] + del bad_config["start_value"] with self.assertRaises(AssertionError): LinearParamScheduler.from_config(bad_config) # No end lr - bad_config["start_lr"] = config["start_lr"] - del bad_config["end_lr"] + bad_config["start_value"] = config["start_value"] + del bad_config["end_value"] with self.assertRaises(AssertionError): LinearParamScheduler.from_config(bad_config) @@ -44,19 +44,19 @@ def test_scheduler(self): round(scheduler(epoch_num / self._num_epochs), 4) for epoch_num in range(self._num_epochs) ] - expected_schedule = [config["start_lr"]] + self._get_valid_intermediate() + expected_schedule = [config["start_value"]] + self._get_valid_intermediate() self.assertEqual(schedule, expected_schedule) # Check as decay - tmp = config["start_lr"] - config["start_lr"] = config["end_lr"] - config["end_lr"] = tmp + tmp = config["start_value"] + config["start_value"] = config["end_value"] + config["end_value"] = tmp scheduler = LinearParamScheduler.from_config(config) schedule = [ round(scheduler(epoch_num / self._num_epochs), 4) for epoch_num in range(self._num_epochs) ] - expected_schedule = [config["start_lr"]] + list( + expected_schedule = [config["start_value"]] + list( reversed(self._get_valid_intermediate()) ) self.assertEqual(schedule, expected_schedule) diff --git a/test/optim_param_scheduler_polynomial_test.py b/test/optim_param_scheduler_polynomial_test.py index 70e8971334..63c86309bd 100644 --- a/test/optim_param_scheduler_polynomial_test.py +++ 
b/test/optim_param_scheduler_polynomial_test.py @@ -20,7 +20,7 @@ def _get_valid_config(self): return { "name": "polynomial", "num_epochs": self._num_epochs, - "base_lr": 0.1, + "base_value": 0.1, "power": 1, } @@ -30,7 +30,7 @@ def test_invalid_config(self): # Invalid Base lr bad_config = copy.deepcopy(config) - del bad_config["base_lr"] + del bad_config["base_value"] with self.assertRaises(AssertionError): PolynomialDecayParamScheduler.from_config(bad_config) diff --git a/test/optim_param_scheduler_step_with_fixed_gamma_test.py b/test/optim_param_scheduler_step_with_fixed_gamma_test.py index 41eb44681e..b6e611f030 100644 --- a/test/optim_param_scheduler_step_with_fixed_gamma_test.py +++ b/test/optim_param_scheduler_step_with_fixed_gamma_test.py @@ -19,7 +19,7 @@ class TestStepWithFixedGammaScheduler(unittest.TestCase): def _get_valid_config(self): return { "name": "step_with_fixed_gamma", - "base_lr": 1, + "base_value": 1, "gamma": 0.1, "num_decays": 3, "num_epochs": self._num_epochs, @@ -39,9 +39,9 @@ def test_invalid_config(self): with self.assertRaises(AssertionError): StepWithFixedGammaParamScheduler.from_config(bad_config) - # Invalid base_lr + # Invalid base_value bad_config = copy.deepcopy(config) - bad_config["base_lr"] = -0.01 + bad_config["base_value"] = -0.01 with self.assertRaises(AssertionError): StepWithFixedGammaParamScheduler.from_config(bad_config) diff --git a/test/optim_param_scheduler_test.py b/test/optim_param_scheduler_test.py index 08aca43ff9..5e218ce98d 100644 --- a/test/optim_param_scheduler_test.py +++ b/test/optim_param_scheduler_test.py @@ -16,6 +16,7 @@ ClassyParamScheduler, UpdateInterval, register_param_scheduler, + update_interval_from_config, ) from classy_vision.tasks import ClassificationTask, ClassyTask from classy_vision.trainer import LocalTrainer @@ -250,3 +251,21 @@ def scheduler_mock(where): # the weight decay scheduler uses an epoch update interval self.assertEqual(weight_decay_list, [0 / 6, 0 / 6, 4 / 6, 4 / 6, 8 / 6, 8 
/ 6]) self.assertEqual(momentum_list, [0.9, 0.9, 0.9, 0.9, 0.9, 0.9]) + + def test_update_interval_from_config(self): + # test a config which specifies an update interval + config = {"update_interval": "epoch"} + self.assertEqual( + update_interval_from_config(config, UpdateInterval.STEP), + UpdateInterval.EPOCH, + ) + # test a config which doesn't specify an update interval + config = {} + self.assertEqual( + update_interval_from_config(config, UpdateInterval.STEP), + UpdateInterval.STEP, + ) + # test a config with an invalid update interval + config = {"update_interval": "invalid"} + with self.assertRaises(Exception): + update_interval_from_config(config, UpdateInterval.EPOCH)