From 7fdcffcad4bbdfb0fdf8966b7e0a806a4e45241f Mon Sep 17 00:00:00 2001
From: Vinicius Reis
Date: Wed, 8 Apr 2020 08:01:40 -0700
Subject: [PATCH] Move use_gpu from ClassyTrainer to ClassificationTask (#468)

Summary:
Pull Request resolved: https://github.com/facebookresearch/ClassyVision/pull/468

This is the first in a series of diffs to eliminate the ClassyTrainer
abstraction. The only reason Trainer existed was to support elastic training,
but PET v0.2 does not require changing our training loop. The plan is to move
all attributes from ClassyTrainer into ClassificationTask. Start by moving
use_gpu to the task.

Reviewed By: mannatsingh

Differential Revision: D20801017

fbshipit-source-id: d56bb330737a98a7ea7545d33da0ad0f21a0b6a1
---
 .circleci/config.yml | 2 +-
 classy_train.py | 7 +--
 classy_vision/generic/opts.py | 13 ------
 classy_vision/tasks/classification_task.py | 43 +++++++++++--------
 classy_vision/tasks/classy_task.py | 23 +++-------
 classy_vision/tasks/fine_tuning_task.py | 10 +----
 classy_vision/trainer/classy_trainer.py | 12 +-----
 classy_vision/trainer/distributed_trainer.py | 30 +++----------
 classy_vision/trainer/local_trainer.py | 28 ++----------
 test/generic_util_test.py | 7 ++-
 test/hooks_checkpoint_hook_test.py | 6 +--
 .../tasks_classification_task_amp_test.py | 3 +-
 test/tasks_classification_task_test.py | 36 ++++++++--------
 test/trainer_distributed_trainer_test.py | 4 --
 14 files changed, 70 insertions(+), 154 deletions(-)

diff --git a/.circleci/config.yml b/.circleci/config.yml
index b011910643..9aa8e5e91f 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -106,7 +106,7 @@ jobs:
 pip install .
 classy-project my-project
 pushd my-project
- ./classy_train.py --device cpu --config configs/template_config.json
+ ./classy_train.py --config configs/template_config.json
 popd
 rm -rf my-project
diff --git a/classy_train.py b/classy_train.py
index 852627be2c..9a148870c7 100755
--- a/classy_train.py
+++ b/classy_train.py
@@ -93,18 +93,13 @@ def main(args, config):
 # Configure hooks to do tensorboard logging, checkpoints and so on
 task.set_hooks(configure_hooks(args, config))
- use_gpu = None
- if args.device is not None:
- use_gpu = args.device == "gpu"
- assert torch.cuda.is_available() or not use_gpu, "CUDA is unavailable"
-
 # LocalTrainer is used for a single node. DistributedTrainer will setup
 # training to use PyTorch's DistributedDataParallel.
 trainer_class = {"none": LocalTrainer, "ddp": DistributedTrainer}[
 args.distributed_backend
 ]
- trainer = trainer_class(use_gpu=use_gpu, num_dataloader_workers=args.num_workers)
+ trainer = trainer_class(num_dataloader_workers=args.num_workers)
 logging.info(
 f"Starting training on rank {get_rank()} worker. "
diff --git a/classy_vision/generic/opts.py b/classy_vision/generic/opts.py
index 5a0cd3083d..0520766f45 100644
--- a/classy_vision/generic/opts.py
+++ b/classy_vision/generic/opts.py
@@ -18,12 +18,6 @@ def add_generic_args(parser):
 parser.add_argument(
 "--config_file", type=str, help="path to config file for model", required=True
 )
- parser.add_argument(
- "--device",
- default=None,
- type=str,
- help="device to use: either 'cpu' or 'gpu'. 
If unspecified, will use GPU when available and CPU otherwise.", - ) parser.add_argument( "--num_workers", default=4, @@ -145,13 +139,6 @@ def check_generic_args(args): # check types and values: assert is_pos_int(args.num_workers), "incorrect number of workers" assert is_pos_int(args.visdom_port), "incorrect visdom port" - assert ( - args.device is None or args.device == "cpu" or args.device == "gpu" - ), "unknown device" - - # check that CUDA is available: - if args.device == "gpu": - assert torch.cuda.is_available(), "CUDA required to train on GPUs" # create checkpoint folder if it does not exist: if args.checkpoint_folder != "" and not os.path.exists(args.checkpoint_folder): diff --git a/classy_vision/tasks/classification_task.py b/classy_vision/tasks/classification_task.py index 8c8c089aa7..bd546f0257 100644 --- a/classy_vision/tasks/classification_task.py +++ b/classy_vision/tasks/classification_task.py @@ -142,6 +142,16 @@ def __init__(self): self.perf_log = [] self.last_batch = None self.batch_norm_sync_mode = BatchNormSyncMode.DISABLED + self.use_gpu = torch.cuda.is_available() + + def set_use_gpu(self, use_gpu: bool): + self.use_gpu = use_gpu + + assert ( + not self.use_gpu or torch.cuda.is_available() + ), "CUDA required to train on GPUs" + + return self def set_checkpoint(self, checkpoint): """Sets checkpoint on task. @@ -359,6 +369,10 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask": .set_hooks(hooks) ) + use_gpu = config.get("use_gpu") + if use_gpu is not None: + task.set_use_gpu(use_gpu) + for phase_type in phase_types: task.set_dataset(datasets[phase_type], phase_type) @@ -508,24 +522,19 @@ def build_dataloaders( for phase_type in self.datasets.keys() } - def prepare( - self, - num_dataloader_workers=0, - pin_memory=False, - use_gpu=False, - dataloader_mp_context=None, - ): + def prepare(self, num_dataloader_workers=0, dataloader_mp_context=None): """Prepares task for training, populates all derived attributes Args: num_dataloader_workers: Number of dataloading processes. If 0, dataloading is done on main process - pin_memory: if true pin memory on GPU - use_gpu: if true, load model, optimizer, loss, etc on GPU dataloader_mp_context: Determines how processes are spawned. Value must be one of None, "spawn", "fork", "forkserver". If None, then context is inherited from parent process """ + + pin_memory = self.use_gpu and torch.cuda.device_count() > 1 + self.phases = self._build_phases() self.dataloaders = self.build_dataloaders( num_workers=num_dataloader_workers, @@ -539,7 +548,7 @@ def prepare( self.base_model = apex.parallel.convert_syncbn_model(self.base_model) # move the model and loss to the right device - if use_gpu: + if self.use_gpu: self.base_model, self.loss = copy_model_to_gpu(self.base_model, self.loss) else: self.loss.cpu() @@ -686,7 +695,7 @@ def set_classy_state(self, state): # Set up pytorch module in train vs eval mode, update optimizer. 
self._set_model_train_mode() - def eval_step(self, use_gpu): + def eval_step(self): self.last_batch = None # Process next sample @@ -699,7 +708,7 @@ def eval_step(self, use_gpu): # Copy sample to GPU target = sample["target"] - if use_gpu: + if self.use_gpu: for key, value in sample.items(): sample[key] = recursive_copy_to_gpu(value, non_blocking=True) @@ -726,12 +735,8 @@ def check_inf_nan(self, loss): if loss == float("inf") or loss == float("-inf") or loss != loss: raise FloatingPointError(f"Loss is infinity or NaN: {loss}") - def train_step(self, use_gpu): - """Train step to be executed in train loop - - Args: - use_gpu: if true, execute training on GPU - """ + def train_step(self): + """Train step to be executed in train loop.""" self.last_batch = None @@ -745,7 +750,7 @@ def train_step(self, use_gpu): # Copy sample to GPU target = sample["target"] - if use_gpu: + if self.use_gpu: for key, value in sample.items(): sample[key] = recursive_copy_to_gpu(value, non_blocking=True) diff --git a/classy_vision/tasks/classy_task.py b/classy_vision/tasks/classy_task.py index d0d2fe462a..23517198e3 100644 --- a/classy_vision/tasks/classy_task.py +++ b/classy_vision/tasks/classy_task.py @@ -86,11 +86,7 @@ def set_classy_state(self, state): @abstractmethod def prepare( - self, - num_dataloader_workers=0, - pin_memory=False, - use_gpu=False, - dataloader_mp_context=None, + self, num_dataloader_workers=0, dataloader_mp_context=None ) -> None: """ Prepares the task for training. @@ -102,19 +98,15 @@ def prepare( num_dataloader_workers: Number of workers to create for the dataloaders pin_memory: Whether the dataloaders should copy the Tensors into CUDA pinned memory (default False) - use_gpu: True if training on GPUs, False otherwise """ pass @abstractmethod - def train_step(self, use_gpu) -> None: + def train_step(self) -> None: """ Run a train step. This corresponds to training over one batch of data from the dataloaders. - - Args: - use_gpu: True if training on GPUs, False otherwise """ pass @@ -155,24 +147,21 @@ def on_end(self): pass @abstractmethod - def eval_step(self, use_gpu) -> None: + def eval_step(self) -> None: """ Run an evaluation step. This corresponds to evaluating the model over one batch of data. 
- - Args: - use_gpu: True if training on GPUs, False otherwise """ pass - def step(self, use_gpu) -> None: + def step(self) -> None: from classy_vision.hooks import ClassyHookFunctions if self.train: - self.train_step(use_gpu) + self.train_step() else: - self.eval_step(use_gpu) + self.eval_step() for hook in self.hooks: hook.on_step(self) diff --git a/classy_vision/tasks/fine_tuning_task.py b/classy_vision/tasks/fine_tuning_task.py index 93a795fcba..5da2b382ce 100644 --- a/classy_vision/tasks/fine_tuning_task.py +++ b/classy_vision/tasks/fine_tuning_task.py @@ -67,18 +67,12 @@ def _set_model_train_mode(self): self.base_model.train(phase["train"]) def prepare( - self, - num_dataloader_workers: int = 0, - pin_memory: bool = False, - use_gpu: bool = False, - dataloader_mp_context=None, + self, num_dataloader_workers: int = 0, dataloader_mp_context=None ) -> None: assert ( self.pretrained_checkpoint is not None ), "Need a pretrained checkpoint for fine tuning" - super().prepare( - num_dataloader_workers, pin_memory, use_gpu, dataloader_mp_context - ) + super().prepare(num_dataloader_workers, dataloader_mp_context) if self.checkpoint is None: # no checkpoint exists, load the model's state from the pretrained # checkpoint diff --git a/classy_vision/trainer/classy_trainer.py b/classy_vision/trainer/classy_trainer.py index 043ba1b361..bca12b1464 100644 --- a/classy_vision/trainer/classy_trainer.py +++ b/classy_vision/trainer/classy_trainer.py @@ -27,25 +27,18 @@ class ClassyTrainer: def __init__( self, - use_gpu: Optional[bool] = None, num_dataloader_workers: int = 0, dataloader_mp_context: Optional[str] = None, ): """Constructor for ClassyTrainer. Args: - use_gpu: If true, then use GPUs for training. - If None, then check if we have GPUs available, if we do - then use GPU for training. num_dataloader_workers: Number of CPU processes doing dataloading per GPU. If 0, then dataloading is done on main thread. dataloader_mp_context: Determines how to launch new processes for dataloading. Must be one of "fork", "forkserver", "spawn". If None, process launching is inherited from parent. """ - if use_gpu is None: - use_gpu = torch.cuda.is_available() - self.use_gpu = use_gpu self.num_dataloader_workers = num_dataloader_workers self.dataloader_mp_context = dataloader_mp_context @@ -57,11 +50,8 @@ def train(self, task: ClassyTask): everything that is needed for training """ - pin_memory = self.use_gpu and torch.cuda.device_count() > 1 task.prepare( num_dataloader_workers=self.num_dataloader_workers, - pin_memory=pin_memory, - use_gpu=self.use_gpu, dataloader_mp_context=self.dataloader_mp_context, ) assert isinstance(task, ClassyTask) @@ -75,7 +65,7 @@ def train(self, task: ClassyTask): task.on_phase_start() while True: try: - task.step(self.use_gpu) + task.step() except StopIteration: break task.on_phase_end() diff --git a/classy_vision/trainer/distributed_trainer.py b/classy_vision/trainer/distributed_trainer.py index e4d02098e3..339fa3171f 100644 --- a/classy_vision/trainer/distributed_trainer.py +++ b/classy_vision/trainer/distributed_trainer.py @@ -56,39 +56,19 @@ class DistributedTrainer(ClassyTrainer): """Distributed trainer for using multiple training processes """ - def __init__( - self, - use_gpu: Optional[bool] = None, - num_dataloader_workers: int = 0, - dataloader_mp_context: Optional[str] = None, - ): - """Constructor for DistributedTrainer. - - Args: - use_gpu: If true, then use GPU 0 for training. - If None, then check if we have GPUs available, if we do - then use GPU for training. 
- num_dataloader_workers: Number of CPU processes doing dataloading - per GPU. If 0, then dataloading is done on main thread. - dataloader_mp_context: Determines how to launch - new processes for dataloading. Must be one of "fork", "forkserver", - "spawn". If None, process launching is inherited from parent. - """ - super().__init__( - use_gpu=use_gpu, - num_dataloader_workers=num_dataloader_workers, - dataloader_mp_context=dataloader_mp_context, - ) + def train(self, task): _init_env_vars() - _init_distributed(self.use_gpu) + _init_distributed(task.use_gpu) logging.info( f"Done setting up distributed process_group with rank {get_rank()}" + f", world_size {get_world_size()}" ) local_rank = int(os.environ["LOCAL_RANK"]) - if self.use_gpu: + if task.use_gpu: logging.info("Using GPU, CUDA device index: {}".format(local_rank)) set_cuda_device_index(local_rank) else: logging.info("Using CPU") set_cpu_device() + + super().train(task) diff --git a/classy_vision/trainer/local_trainer.py b/classy_vision/trainer/local_trainer.py index 4435381ffb..05be158943 100644 --- a/classy_vision/trainer/local_trainer.py +++ b/classy_vision/trainer/local_trainer.py @@ -16,32 +16,12 @@ class LocalTrainer(ClassyTrainer): """Trainer to be used if you want want use only a single training process. """ - def __init__( - self, - use_gpu: Optional[bool] = None, - num_dataloader_workers: int = 0, - dataloader_mp_context: Optional[str] = None, - ): - """Constructor for LocalTrainer. - - Args: - use_gpu: If true, then use GPU 0 for training. - If None, then check if we have GPUs available, if we do - then use GPU for training. - num_dataloader_workers: Number of CPU processes doing dataloading - per GPU. If 0, then dataloading is done on main thread. - dataloader_mp_context: Determines how to launch - new processes for dataloading. Must be one of "fork", "forkserver", - "spawn". If None, process launching is inherited from parent. 
- """ - super().__init__( - use_gpu=use_gpu, - num_dataloader_workers=num_dataloader_workers, - dataloader_mp_context=dataloader_mp_context, - ) - if self.use_gpu: + def train(self, task): + if task.use_gpu: logging.info("Using GPU, CUDA device index: {}".format(0)) set_cuda_device_index(0) else: logging.info("Using CPU") set_cpu_device() + + super().train(task) diff --git a/test/generic_util_test.py b/test/generic_util_test.py index dca4d76e75..a43d6b41ea 100644 --- a/test/generic_util_test.py +++ b/test/generic_util_test.py @@ -437,7 +437,7 @@ def test_update_classy_state(self): task = build_task(config) task_2 = build_task(config) task_2.prepare() - trainer = LocalTrainer(use_gpu=False) + trainer = LocalTrainer() trainer.train(task) update_classy_state(task_2, task.get_classy_state(deep_copy=True)) self._compare_states(task.get_classy_state(), task_2.get_classy_state()) @@ -449,13 +449,12 @@ def test_update_classy_model(self): """ config = get_fast_test_task_config() task = build_task(config) - use_gpu = torch.cuda.is_available() - trainer = LocalTrainer(use_gpu=use_gpu) + trainer = LocalTrainer() trainer.train(task) for reset_heads in [False, True]: task_2 = build_task(config) # prepare task_2 for the right device - task_2.prepare(use_gpu=use_gpu) + task_2.prepare() update_classy_model( task_2.model, task.model.get_classy_state(deep_copy=True), reset_heads ) diff --git a/test/hooks_checkpoint_hook_test.py b/test/hooks_checkpoint_hook_test.py index 828bd998ef..7996f2a221 100644 --- a/test/hooks_checkpoint_hook_test.py +++ b/test/hooks_checkpoint_hook_test.py @@ -155,7 +155,7 @@ def test_checkpointing(self): cuda_available = torch.cuda.is_available() task = build_task(config) - task.prepare(use_gpu=cuda_available) + task.prepare() # create a checkpoint hook checkpoint_hook = CheckpointHook(checkpoint_folder, {}, phase_types=["train"]) @@ -175,8 +175,8 @@ def test_checkpointing(self): # set the checkpoint task.set_checkpoint(checkpoint) - task.prepare(use_gpu=use_gpu) + task.set_use_gpu(use_gpu) # we should be able to run the trainer using the checkpoint - trainer = LocalTrainer(use_gpu=use_gpu) + trainer = LocalTrainer() trainer.train(task) diff --git a/test/manual/tasks_classification_task_amp_test.py b/test/manual/tasks_classification_task_amp_test.py index 80edc14b8b..ea245e8a6e 100644 --- a/test/manual/tasks_classification_task_amp_test.py +++ b/test/manual/tasks_classification_task_amp_test.py @@ -32,5 +32,6 @@ def test_training(self): config = get_fast_test_task_config() config["amp_args"] = {"opt_level": "O2"} task = build_task(config) - trainer = LocalTrainer(use_gpu=True) + task.set_use_gpu(True) + trainer = LocalTrainer() trainer.train(task) diff --git a/test/tasks_classification_task_test.py b/test/tasks_classification_task_test.py index 7580189e5f..8a90b366d9 100644 --- a/test/tasks_classification_task_test.py +++ b/test/tasks_classification_task_test.py @@ -65,10 +65,10 @@ def test_get_state(self): dataset = build_dataset(config["dataset"][phase_type]) task.set_dataset(dataset, phase_type) - task.prepare(num_dataloader_workers=1, pin_memory=False) + task.prepare(num_dataloader_workers=1) task = build_task(config) - task.prepare(num_dataloader_workers=1, pin_memory=False) + task.prepare(num_dataloader_workers=1) def test_checkpointing(self): """ @@ -79,10 +79,10 @@ def test_checkpointing(self): task = build_task(config).set_hooks([LossLrMeterLoggingHook()]) task_2 = build_task(config).set_hooks([LossLrMeterLoggingHook()]) - use_gpu = torch.cuda.is_available() + 
task.set_use_gpu(torch.cuda.is_available()) # prepare the tasks for the right device - task.prepare(use_gpu=use_gpu) + task.prepare() # test in both train and test mode for _ in range(2): @@ -90,7 +90,7 @@ def test_checkpointing(self): # set task's state as task_2's checkpoint task_2.set_checkpoint(get_checkpoint_dict(task, {}, deep_copy=True)) - task_2.prepare(use_gpu=use_gpu) + task_2.prepare() # task 2 should have the same state self._compare_states(task.get_classy_state(), task_2.get_classy_state()) @@ -102,8 +102,8 @@ def test_checkpointing(self): # test that the train step runs the same way on both states # and the loss remains the same - task.train_step(use_gpu) - task_2.train_step(use_gpu) + task.train_step() + task_2.train_step() self._compare_states(task.get_classy_state(), task_2.get_classy_state()) def test_final_train_checkpoint(self): @@ -115,9 +115,9 @@ def test_final_train_checkpoint(self): ) task_2 = build_task(config) - use_gpu = torch.cuda.is_available() + task.set_use_gpu(torch.cuda.is_available()) - trainer = LocalTrainer(use_gpu=use_gpu) + trainer = LocalTrainer() trainer.train(task) # load the final train checkpoint @@ -130,7 +130,7 @@ def test_final_train_checkpoint(self): # set task_2's state as task's final train checkpoint task_2.set_checkpoint(checkpoint) - task_2.prepare(use_gpu=use_gpu) + task_2.prepare() # we should be able to train the task trainer.train(task_2) @@ -148,20 +148,18 @@ def test_test_only_checkpointing(self): train_task = build_task(train_config).set_hooks([LossLrMeterLoggingHook()]) test_only_task = build_task(test_config).set_hooks([LossLrMeterLoggingHook()]) - use_gpu = torch.cuda.is_available() - # prepare the tasks for the right device - train_task.prepare(use_gpu=use_gpu) + train_task.prepare() # test in both train and test mode - trainer = LocalTrainer(use_gpu=use_gpu) + trainer = LocalTrainer() trainer.train(train_task) # set task's state as task_2's checkpoint test_only_task.set_checkpoint( get_checkpoint_dict(train_task, {}, deep_copy=True) ) - test_only_task.prepare(use_gpu=use_gpu) + test_only_task.prepare() test_state = test_only_task.get_classy_state() # We expect the phase idx to be different for a test only task @@ -177,7 +175,7 @@ def test_test_only_checkpointing(self): self.assertEqual(test_state["train_phase_idx"], -1) # Verify task will run - trainer = LocalTrainer(use_gpu=use_gpu) + trainer = LocalTrainer() trainer.train(test_only_task) @unittest.skipUnless(torch.cuda.is_available(), "This test needs a gpu to run") @@ -187,11 +185,13 @@ def test_checkpointing_different_device(self): task_2 = build_task(config) for use_gpu in [True, False]: - task.prepare(use_gpu=use_gpu) + task.set_use_gpu(use_gpu) + task.prepare() # set task's state as task_2's checkpoint task_2.set_checkpoint(get_checkpoint_dict(task, {}, deep_copy=True)) # we should be able to run the trainer using state from a different device - trainer = LocalTrainer(use_gpu=not use_gpu) + trainer = LocalTrainer() + task_2.set_use_gpu(not use_gpu) trainer.train(task_2) diff --git a/test/trainer_distributed_trainer_test.py b/test/trainer_distributed_trainer_test.py index 675a3c32e2..fb8727dfc2 100644 --- a/test/trainer_distributed_trainer_test.py +++ b/test/trainer_distributed_trainer_test.py @@ -44,7 +44,6 @@ def test_training(self): """Checks we can train a small MLP model.""" num_processes = 2 - device = "gpu" if torch.cuda.is_available() else "cpu" for config_key, expected_success in [ ("invalid_config", False), @@ -57,7 +56,6 @@ def test_training(self): 
--master_port=29500 \ --use_env \ {self.path}/../classy_train.py \ - --device={device} \ --config={self.config_files[config_key]} \ --num_workers=4 \ --log_freq=100 \ @@ -72,7 +70,6 @@ def test_sync_batch_norm(self): """Test that sync batch norm training doesn't hang.""" num_processes = 2 - device = "gpu" cmd = f"""{sys.executable} -m torch.distributed.launch \ --nnodes=1 \ @@ -81,7 +78,6 @@ def test_sync_batch_norm(self): --master_port=29500 \ --use_env \ {self.path}/../classy_train.py \ - --device={device} \ --config={self.config_files["sync_bn_config"]} \ --num_workers=4 \ --log_freq=100 \