From 5b7fcdc71f2c254eb09ea1bedaa931d0997331cd Mon Sep 17 00:00:00 2001
From: Jeff Rasley
Date: Tue, 19 Oct 2021 19:55:20 +0000
Subject: [PATCH] change device to local_rank and simplify fire

---
 Intro101/train_bert.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/Intro101/train_bert.py b/Intro101/train_bert.py
index 349a6ad4d..67c07c60b 100644
--- a/Intro101/train_bert.py
+++ b/Intro101/train_bert.py
@@ -576,7 +576,7 @@ def train(
     num_iterations: int = 10000,
     checkpoint_every: int = 1000,
     log_every: int = 10,
-    device: int = -1,
+    local_rank: int = -1,
 ) -> pathlib.Path:
     """Trains a [Bert style](https://arxiv.org/pdf/1810.04805.pdf)
     (transformer encoder only) model for MLM Task
@@ -625,7 +625,7 @@
             Defaults to 1000.
         log_every (int, optional):
             Print logs after these many steps. Defaults to 10.
-        device (int, optional):
+        local_rank (int, optional):
             Which GPU to run on (-1 for CPU). Defaults to -1.
 
     Returns:
@@ -633,8 +633,8 @@
     """
 
     device = (
-        torch.device("cuda", device)
-        if (device > -1) and torch.cuda.is_available()
+        torch.device("cuda", local_rank)
+        if (local_rank > -1) and torch.cuda.is_available()
         else torch.device("cpu")
     )
     ################################
@@ -787,4 +787,4 @@ def train(
 
 
 if __name__ == "__main__":
-    fire.Fire({"train": train, "data": create_data_iterator, "model": create_model})
+    fire.Fire(train)
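
For context on the "simplify fire" part of this change: with fire.Fire({...}), the CLI requires an explicit component name (train, data, or model) before any flags, whereas fire.Fire(train) exposes train's parameters directly as top-level flags. The following is a minimal sketch, not the repo's code; the toy train signature below is an assumption that stands in for the much larger train_bert.train.

# Minimal sketch of the python-fire behavior this patch relies on.
# The toy `train` below is a stand-in; the real train_bert.train takes
# many more parameters.
import fire


def train(checkpoint_dir: str, local_rank: int = -1) -> None:
    """Toy stand-in for train_bert.train."""
    print(f"checkpoint_dir={checkpoint_dir}, local_rank={local_rank}")


if __name__ == "__main__":
    # Before the patch (command group), the caller must name a sub-command:
    #   python train_bert.py train --checkpoint_dir ./runs --local_rank 0
    # fire.Fire({"train": train})
    #
    # After the patch, train's parameters become top-level flags, so a
    # launcher can pass e.g. --local_rank directly with no sub-command:
    #   python train_bert.py --checkpoint_dir ./runs --local_rank 0
    fire.Fire(train)

The device -> local_rank rename likely serves the same purpose: distributed launchers such as deepspeed and torch.distributed.launch pass a --local_rank argument to each worker process, so naming the parameter local_rank lets that flag flow straight into train through fire.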