diff --git a/deepmd/entrypoints/train.py b/deepmd/entrypoints/train.py index cffac25947..04d3d7b08d 100755 --- a/deepmd/entrypoints/train.py +++ b/deepmd/entrypoints/train.py @@ -7,11 +7,11 @@ import logging import time import os -from typing import Dict, TYPE_CHECKING, List, Optional, Any +from typing import Dict, List, Optional, Any import numpy as np from deepmd.common import data_requirement, expand_sys_str, j_loader, j_must_have -from deepmd.env import tf +from deepmd.env import reset_default_tf_session_config from deepmd.infer.data_modifier import DipoleChargeModifier from deepmd.train.run_options import BUILD, CITATION, WELCOME, RunOptions from deepmd.train.trainer import DPTrainer @@ -21,118 +21,11 @@ from deepmd.utils.sess import run_sess from deepmd.utils.neighbor_stat import NeighborStat -if TYPE_CHECKING: - from deepmd.run_options import TFServerV1 - __all__ = ["train"] log = logging.getLogger(__name__) -def create_done_queue( - cluster_spec: tf.train.ClusterSpec, task_index: int -) -> tf.FIFOQueue: - """Create FIFO queue for distributed tasks. - - Parameters - ---------- - cluster_spec : tf.train.ClusterSpec - tf cluster specification object - task_index : int - identifying index of a task - - Returns - ------- - tf.FIFOQueue - tf distributed FIFI queue - """ - with tf.device(f"/job:ps/task:{task_index:d}"): - queue = tf.FIFOQueue( - cluster_spec.num_tasks("worker"), - tf.int32, - shared_name=f"done_queue{task_index}", - ) - return queue - - -def wait_done_queue( - cluster_spec: tf.train.ClusterSpec, - server: "TFServerV1", - queue: tf.FIFOQueue, - task_index: int, -): - """Wait until all enqued operation in tf distributed queue are finished. 
- - Parameters - ---------- - cluster_spec : tf.train.ClusterSpec - tf cluster specification object - server : TFServerV1 - tf server specification object - queue : tf.FIFOQueue - tf distributed queue - task_index : int - identifying index of a task - """ - with tf.Session(server.target) as sess: - for i in range(cluster_spec.num_tasks("worker")): - run_sess(sess, queue.dequeue()) - log.debug(f"ps:{task_index:d} received done from worker:{i:d}") - log.debug(f"ps:{task_index:f} quitting") - - -def connect_done_queue( - cluster_spec: tf.train.ClusterSpec, task_index: int -) -> List[tf.Operation]: - """Create tf FIFO queue filling operations. - - Parameters - ---------- - cluster_spec : tf.train.ClusterSpec - tf cluster specification object - task_index : int - identifying index of a task - - Returns - ------- - List[tf.Operation] - list of tf operations that will populate the queue - """ - done_ops = [] - for i in range(cluster_spec.num_tasks("ps")): - with tf.device(f"/job:ps/task:{i:d}"): - queue = tf.FIFOQueue( - cluster_spec.num_tasks("worker"), tf.int32, shared_name=f"done_queue{i}" - ) - done_ops.append(queue.enqueue(task_index)) - return done_ops - - -def fill_done_queue( - cluster_spec: tf.train.ClusterSpec, - server: "TFServerV1", - done_ops: List[tf.Operation], - task_index: int, -): - """Run specified operations that will fill the tf distributed FIFO queue. 
- - Parameters - ---------- - cluster_spec : tf.train.ClusterSpec - tf cluster specification object - server : TFServerV1 - tf server specification object - done_ops : List[tf.Operation] - a list of tf operations that will fill the queue - task_index : int - identifying index of a task - """ - with tf.Session(server.target) as sess: - for i in range(cluster_spec.num_tasks("ps")): - run_sess(sess, done_ops[i]) - log.debug(f"worker:{task_index:d} sending done to ps:{i:d}") - - def train( *, INPUT: str, @@ -186,34 +79,14 @@ def train( restart=restart, log_path=log_path, log_level=log_level, - mpi_log=mpi_log, - try_distrib=jdata.get("with_distrib", False), + mpi_log=mpi_log ) for message in WELCOME + CITATION + BUILD: log.info(message) run_opt.print_resource_summary() - - if run_opt.is_distrib: - # distributed training - if run_opt.my_job_name == "ps": - queue = create_done_queue(run_opt.cluster_spec, run_opt.my_task_index) - wait_done_queue( - run_opt.cluster_spec, run_opt.server, queue, run_opt.my_task_index - ) - # server.join() - elif run_opt.my_job_name == "worker": - done_ops = connect_done_queue(run_opt.cluster_spec, run_opt.my_task_index) - _do_work(jdata, run_opt) - fill_done_queue( - run_opt.cluster_spec, run_opt.server, done_ops, run_opt.my_task_index - ) - else: - raise RuntimeError("unknown job name") - else: - # serial training - _do_work(jdata, run_opt) + _do_work(jdata, run_opt) def _do_work(jdata: Dict[str, Any], run_opt: RunOptions): @@ -234,6 +107,10 @@ def _do_work(jdata: Dict[str, Any], run_opt: RunOptions): # make necessary checks assert "training" in jdata + # avoid conflict of visible gpus among multipe tf sessions in one process + if run_opt.is_distrib and len(run_opt.gpus or []) > 1: + reset_default_tf_session_config(cpu_only=True) + # init the model model = DPTrainer(jdata, run_opt=run_opt) rcut = model.model.get_rcut() diff --git a/deepmd/env.py b/deepmd/env.py index 4e03aa4f0b..925976a09b 100644 --- a/deepmd/env.py +++ b/deepmd/env.py @@ 
-117,14 +117,31 @@ def get_tf_session_config() -> Any: """ set_tf_default_nthreads() intra, inter = get_tf_default_nthreads() - return tf.ConfigProto( + config = tf.ConfigProto( intra_op_parallelism_threads=intra, inter_op_parallelism_threads=inter ) + return config default_tf_session_config = get_tf_session_config() +def reset_default_tf_session_config(cpu_only: bool): + """Limit tensorflow session to CPU or not. + + Parameters + ---------- + cpu_only : bool + If enabled, no GPU device is visible to the TensorFlow Session. + """ + global default_tf_session_config + if cpu_only: + default_tf_session_config.device_count['GPU'] = 0 + else: + if 'GPU' in default_tf_session_config.device_count: + del default_tf_session_config.device_count['GPU'] + + def get_module(module_name: str) -> "ModuleType": """Load force module. diff --git a/deepmd/loggers/loggers.py b/deepmd/loggers/loggers.py index f787ff1e1a..3bb9e9fa4c 100644 --- a/deepmd/loggers/loggers.py +++ b/deepmd/loggers/loggers.py @@ -137,8 +137,7 @@ def setStream(self, stream): def set_log_handles( level: int, log_path: Optional["Path"] = None, - mpi_log: Optional[str] = None, - MPI: Optional["MPI"] = None, + mpi_log: Optional[str] = None ): """Set desired level for package loggers and add file handlers. @@ -154,16 +153,13 @@ def set_log_handles( only from rank==0. `collect` will write messages from all ranks to one file opened under rank==0 and to console. `workers` will open one log file for each worker designated by its rank, console behaviour is the same as for `collect`. - If this argument is specified than also `MPI` object must be passed in. - by default None - MPI : Optional[MPI, optional] - `MPI` communicator object, must be specified if `mpi_log` is specified, + If this argument is specified, package 'mpi4py' must be already installed. 
by default None Raises ------ RuntimeError - if only one of the arguments `mpi_log`, `MPI` is specified + If the argument `mpi_log` is specified, package `mpi4py` is not installed. References ---------- @@ -204,8 +200,12 @@ def set_log_handles( root_log.removeHandler(hdlr) # check if arguments are present - if (mpi_log and not MPI) or (not mpi_log and MPI): - raise RuntimeError("You cannot specify only one of 'mpi_log', 'MPI' arguments") + MPI = None + if mpi_log: + try: + from mpi4py import MPI + except ImportError as e: + raise RuntimeError("You cannot specify 'mpi_log' when mpi4py not installed") from e # * add console handler ************************************************************ ch = logging.StreamHandler() diff --git a/deepmd/train/run_options.py b/deepmd/train/run_options.py index 25029c4308..1a1145817a 100644 --- a/deepmd/train/run_options.py +++ b/deepmd/train/run_options.py @@ -11,18 +11,7 @@ from deepmd.loggers import set_log_handles if TYPE_CHECKING: - from mpi4py import MPI - - try: - from typing import Protocol # python >=3.8 - except ImportError: - from typing_extensions import Protocol # type: ignore - - class TFServerV1(Protocol): - """Prococol mimicking parser object.""" - - server_def: tf.train.ServerDef - target: str + import horovod.tensorflow as HVD __all__ = [ @@ -63,80 +52,47 @@ class TFServerV1(Protocol): ) -def _is_distributed(MPI: "MPI") -> bool: +def _is_distributed(HVD: "HVD") -> bool: """Check if there are more than one MPI processes. 
Parameters ---------- - MPI : MPI - MPI object + HVD : HVD + Horovod object Returns ------- bool True if we have more than 1 MPI process """ - return MPI.COMM_WORLD.Get_size() > 1 + return HVD.size() > 1 def _distributed_task_config( - MPI: "MPI", - node_name: str, - node_list_: List[str], - gpu_list: Optional[List[int]] = None, - default_port: int = 2222, -) -> Tuple[Dict[str, List[str]], str, int, str, str]: + HVD: "HVD", + gpu_list: Optional[List[int]] = None +) -> Tuple[int, int, str]: """Create configuration for distributed tensorflow session. Parameters ---------- - MPI : mpi4py.MPI - MPI module - node_name : str - the name of current node - node_list_ : List[str] - the list of nodes of the current mpirun + HVD : horovod.tensorflow + Horovod TensorFlow module gpu_list : Optional[List[int]], optional the list of GPUs on each node, by default None - default_port : int, optional - the default port for socket communication, by default 2222 Returns ------- - Tuple[Dict[str, List[str]], str, int, str, str] - cluster specification, job name of this task, index of this task, - hostname:port socket of this task, the device for this task + Tuple[int, int, str] + task count, index of this task, the device for this task """ - # setup cluster - node_list = list(set(node_list_)) - node_list.sort() - node_color = node_list.index(node_name) - world_idx = MPI.COMM_WORLD.Get_rank() - node_comm = MPI.COMM_WORLD.Split(node_color, world_idx) - node_task_idx = node_comm.Get_rank() - node_numb_task = node_comm.Get_size() - - socket_list = [] - for ii in node_list: - for jj in range(node_numb_task): - socket_list.append(f"{ii}:{default_port + jj}") - ps_map = socket_list[0:1] - worker_map = socket_list[1:] - - if node_color == 0 and node_task_idx == 0: - my_job = "ps" - my_socket = ps_map[0] - my_task_idx = ps_map.index(my_socket) - else: - my_job = "worker" - my_socket = f"{node_name}:{default_port - node_task_idx}" - assert my_socket in worker_map - my_task_idx = 
worker_map.index(my_socket) + my_rank = HVD.rank() + world_size = HVD.size() # setup gpu/cpu devices if gpu_list is not None: numb_gpu = len(gpu_list) - gpu_idx = node_numb_task - node_task_idx - 1 + gpu_idx = HVD.local_rank() if gpu_idx >= numb_gpu: my_device = "cpu:0" # "cpu:%d" % node_task_idx else: @@ -144,8 +100,7 @@ def _distributed_task_config( else: my_device = "cpu:0" # "cpu:%d" % node_task_idx - cluster = {"worker": worker_map, "ps": ps_map} - return cluster, my_job, my_task_idx, my_socket, my_device + return world_size, my_rank, my_device class RunOptions: @@ -153,47 +108,31 @@ class RunOptions: Attributes ---------- - cluster: Optional[Dict[str, List[str]]] - cluster informations as dict - cluster_spec: Optional[tf.train.ClusterSpec] - `tf.train.ClusterSpec` or None if training is serial gpus: Optional[List[int]] list of GPUs if any are present else None is_chief: bool in distribured training it is true for tha main MPI process in serail it is always true - my_job_name: str - name of the training job - my_socket: Optional[str] - communication socket for distributed training - my_task_index: int + world_size: int + total worker count + my_rank: int index of the MPI task nodename: str name of the node - num_ps: Optional[int] - number of ps - num_workers: Optional[int] - number of workers - server: Optional[tf.train.Server] - `tf.train.Server` or `None` for serial training + node_list_ : List[str] + the list of nodes of the current mpirun my_device: str deviice type - gpu or cpu """ - cluster: Optional[Dict[str, List[str]]] - cluster_spec: Optional[tf.train.ClusterSpec] gpus: Optional[List[int]] - is_chief: bool - my_job_name: str - my_socket: Optional[str] - my_task_index: int + world_size: int + my_rank: int nodename: str - num_ps: Optional[int] - num_workers: Optional[int] - server: Optional["TFServerV1"] + nodelist: List[int] my_device: str - _MPI: Optional["MPI"] + _HVD: Optional["HVD"] _log_handles_already_set: bool = False def __init__( @@ -202,15 
+141,9 @@ def __init__( restart: Optional[str] = None, log_path: Optional[str] = None, log_level: int = 0, - mpi_log: str = "master", - try_distrib: bool = False + mpi_log: str = "master" ): - # distributed tasks - if try_distrib: - self._try_init_mpi() - else: - self.is_distrib = False - self._init_serial() + self._try_init_distrib() if all((init_model, restart)): raise RuntimeError( @@ -231,16 +164,20 @@ def __init__( self._setup_logger(Path(log_path) if log_path else None, log_level, mpi_log) + @property + def is_chief(self): + """Whether my rank is 0.""" + return self.my_rank == 0 + def print_resource_summary(self): """Print build and current running cluster configuration summary.""" log.info("---Summary of the training---------------------------------------") if self.is_distrib: log.info("distributed") - log.info(f"ps list: {self.cluster['ps']}") - log.info(f"worker list: {self.cluster['worker']}") - log.info(f"chief on: {self.nodename}") - else: - log.info(f"running on: {self.nodename}") + log.info(f"world size: {self.world_size}") + log.info(f"my rank: {self.my_rank}") + log.info(f"node list: {self.nodelist}") + log.info(f"running on: {self.nodename}") if self.gpus is None: log.info(f"CUDA_VISIBLE_DEVICES: unset") else: @@ -270,84 +207,68 @@ def _setup_logger( console only from rank==0. `collect` will write messages from all ranks to one file opened under rank==0 and to console. `workers` will open one log file for each worker designated by its rank, console behaviour is the same - as for `collect`. If this argument is specified than also `MPI` object must - be passed in. by default None + as for `collect`. 
""" if not self._log_handles_already_set: - if not self._MPI: + if not self._HVD: mpi_log = None - set_log_handles(log_level, log_path, mpi_log=mpi_log, MPI=self._MPI) + set_log_handles(log_level, log_path, mpi_log=mpi_log) self._log_handles_already_set = True log.debug("Log handles were successfully set") else: log.warning( f"Log handles have already been set. It is not advisable to " - f"reset them{', especially when runnig with MPI!' if self._MPI else ''}" + f"reset them{', especially when runnig with MPI!' if self._HVD else ''}" ) - def _try_init_mpi(self): + def _try_init_distrib(self): try: - from mpi4py import MPI + import horovod.tensorflow as HVD + HVD.init() + self.is_distrib = _is_distributed(HVD) except ImportError: - raise RuntimeError( - "cannot import mpi4py module, cannot do distributed simulation" - ) + log.warning("Switch to serial execution due to lack of horovod module.") + self.is_distrib = False + + # Do real intialization + if self.is_distrib: + self._init_distributed(HVD) + self._HVD = HVD else: - self.is_distrib = _is_distributed(MPI) - if self.is_distrib: - self._init_distributed(MPI) - self._MPI = MPI - else: - self._init_serial() - self._MPI = None - - def _init_distributed(self, MPI: "MPI"): + self._init_serial() + self._HVD = None + + def _init_distributed(self, HVD: "HVD"): """Initialize settings for distributed training. 
Parameters ---------- - MPI : MPI - MPI object + HVD : HVD + horovod object """ nodename, nodelist, gpus = get_resource() self.nodename = nodename + self.nodelist = nodelist self.gpus = gpus ( - self.cluster, - self.my_job_name, - self.my_task_index, - self.my_socket, + self.world_size, + self.my_rank, self.my_device, - ) = _distributed_task_config(MPI, nodename, nodelist, gpus) - self.is_chief = self.my_job_name == "worker" and self.my_task_index == 0 - self.num_ps = len(self.cluster["ps"]) - self.num_workers = len(self.cluster["worker"]) - self.cluster_spec = tf.train.ClusterSpec(self.cluster) - self.server = tf.train.Server( - server_or_cluster_def=self.cluster_spec, - job_name=self.my_job_name, - task_index=self.my_task_index, - ) + ) = _distributed_task_config(HVD, gpus) def _init_serial(self): """Initialize setting for serial training.""" nodename, _, gpus = get_resource() - self.cluster = None - self.cluster_spec = None self.gpus = gpus - self.is_chief = True - self.my_job_name = nodename - self.my_socket = None - self.my_task_index = 0 + self.world_size = 1 + self.my_rank = 0 self.nodename = nodename - self.num_ps = None - self.num_workers = None - self.server = None + self.nodelist = [nodename] if gpus is not None: - self.my_device = "gpu:" + str(gpus[0]) + self.my_device = "gpu:0" else: self.my_device = "cpu:0" - self._MPI = None + self._HVD = None diff --git a/deepmd/train/trainer.py b/deepmd/train/trainer.py index 8f283b61cf..4526c2d469 100644 --- a/deepmd/train/trainer.py +++ b/deepmd/train/trainer.py @@ -6,7 +6,7 @@ import google.protobuf.message import numpy as np from deepmd.env import tf -from deepmd.env import default_tf_session_config +from deepmd.env import get_tf_session_config from deepmd.env import GLOBAL_TF_FLOAT_PRECISION from deepmd.env import GLOBAL_ENER_FLOAT_PRECISION from deepmd.fit import EnerFitting, WFCFitting, PolarFittingLocFrame, PolarFittingSeA, GlobalPolarFittingSeA, DipoleFittingSeA @@ -261,9 +261,9 @@ def _init_param(self, 
jdata): self.save_ckpt = tr_data.get('save_ckpt', 'model.ckpt') self.display_in_training = tr_data.get('disp_training', True) self.timing_in_training = tr_data.get('time_training', True) - self.profiling = tr_data.get('profiling', False) + self.profiling = self.run_opt.is_chief and tr_data.get('profiling', False) self.profiling_file = tr_data.get('profiling_file', 'timeline.json') - self.tensorboard = tr_data.get('tensorboard', False) + self.tensorboard = self.run_opt.is_chief and tr_data.get('tensorboard', False) self.tensorboard_log_dir = tr_data.get('tensorboard_log_dir', 'log') # self.sys_probs = tr_data['sys_probs'] # self.auto_prob_style = tr_data['auto_prob'] @@ -308,15 +308,9 @@ def build (self, = self.neighbor_stat.get_stat(data) self.descrpt.enable_compression(self.min_nbor_dist, self.model_param['compress']['model_file'], self.model_param['compress']['table_config'][0], self.model_param['compress']['table_config'][1], self.model_param['compress']['table_config'][2], self.model_param['compress']['table_config'][3]) - worker_device = "/job:%s/task:%d/%s" % (self.run_opt.my_job_name, - self.run_opt.my_task_index, - self.run_opt.my_device) - - with tf.device(tf.train.replica_device_setter(worker_device = worker_device, - cluster = self.run_opt.cluster_spec)): - self._build_lr() - self._build_network(data) - self._build_training() + self._build_lr() + self._build_network(data) + self._build_training() def _build_lr(self): @@ -362,14 +356,11 @@ def _build_network(self, data): def _build_training(self): trainable_variables = tf.trainable_variables() - optimizer = tf.train.AdamOptimizer(learning_rate = self.learning_rate) - if self.run_opt.is_distrib : - optimizer = tf.train.SyncReplicasOptimizer( - optimizer, - replicas_to_aggregate = self.run_opt.cluster_spec.num_tasks("worker"), - total_num_replicas = self.run_opt.cluster_spec.num_tasks("worker"), - name = "sync_replicas") - self.sync_replicas_hook = optimizer.make_session_run_hook(self.run_opt.is_chief) + if 
self.run_opt.is_distrib: + optimizer = tf.train.AdamOptimizer(learning_rate = self.learning_rate*self.run_opt.world_size) + optimizer = self.run_opt._HVD.DistributedOptimizer(optimizer) + else: + optimizer = tf.train.AdamOptimizer(learning_rate = self.learning_rate) grads = tf.gradients(self.l2_l, trainable_variables) apply_op = optimizer.apply_gradients (zip (grads, trainable_variables), global_step=self.global_step, @@ -378,76 +369,48 @@ def _build_training(self): self.train_op = tf.group(*train_ops) log.info("built training") - def _init_sess_serial(self) : - self.sess = tf.Session(config=default_tf_session_config) - self.saver = tf.train.Saver() - saver = self.saver - if self.run_opt.init_mode == 'init_from_scratch' : - log.info("initialize model from scratch") - init_op = tf.global_variables_initializer() - run_sess(self.sess, init_op) - fp = open(self.disp_file, "w") - fp.close () - elif self.run_opt.init_mode == 'init_from_model' : - log.info("initialize from model %s" % self.run_opt.init_model) - init_op = tf.global_variables_initializer() - run_sess(self.sess, init_op) - saver.restore (self.sess, self.run_opt.init_model) - run_sess(self.sess, self.global_step.assign(0)) - fp = open(self.disp_file, "w") - fp.close () - elif self.run_opt.init_mode == 'restart' : - log.info("restart from model %s" % self.run_opt.restart) - init_op = tf.global_variables_initializer() - run_sess(self.sess, init_op) - saver.restore (self.sess, self.run_opt.restart) - else : - raise RuntimeError ("unkown init mode") - - def _init_sess_distrib(self): - ckpt_dir = os.path.join(os.getcwd(), self.save_ckpt) - assert(_is_subdir(ckpt_dir, os.getcwd())), "the checkpoint dir must be a subdir of the current dir" - if self.run_opt.init_mode == 'init_from_scratch' : - log.info("initialize model from scratch") - if self.run_opt.is_chief : - if os.path.exists(ckpt_dir): - shutil.rmtree(ckpt_dir) - if not os.path.exists(ckpt_dir) : - os.makedirs(ckpt_dir) + def _init_session(self): + config = 
get_tf_session_config() + device, idx = self.run_opt.my_device.split(":", 1) + if device == "gpu": + config.gpu_options.allow_growth = True + config.gpu_options.visible_device_list = idx + self.sess = tf.Session(config=config) + + # Initializes or restore global variables + init_op = tf.global_variables_initializer() + if self.run_opt.is_chief: + self.saver = tf.train.Saver() + if self.run_opt.init_mode == 'init_from_scratch' : + log.info("initialize model from scratch") + run_sess(self.sess, init_op) fp = open(self.disp_file, "w") fp.close () - elif self.run_opt.init_mode == 'init_from_model' : - raise RuntimeError("distributed training does not support %s" % self.run_opt.init_mode) - elif self.run_opt.init_mode == 'restart' : - log.info("restart from model %s" % ckpt_dir) - if self.run_opt.is_chief : - assert(os.path.isdir(ckpt_dir)), "the checkpoint dir %s should exists" % ckpt_dir - else : - raise RuntimeError ("unkown init mode") - - saver = tf.train.Saver(max_to_keep = 1) - self.saver = None - # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5) - # config = tf.ConfigProto(allow_soft_placement=True, - # gpu_options = gpu_options, - # intra_op_parallelism_threads=self.run_opt.num_intra_threads, - # inter_op_parallelism_threads=self.run_opt.num_inter_threads) - config = tf.ConfigProto(intra_op_parallelism_threads=self.run_opt.num_intra_threads, - inter_op_parallelism_threads=self.run_opt.num_inter_threads) - # The stop_hook handles stopping after running given steps - # stop_hook = tf.train.StopAtStepHook(last_step = stop_batch) - # hooks = [self.sync_replicas_hook, stop_hook] - hooks = [self.sync_replicas_hook] - scaffold = tf.train.Scaffold(saver=saver) - # Use monitor session for distributed computation - self.sess = tf.train.MonitoredTrainingSession(master = self.run_opt.server.target, - is_chief = self.run_opt.is_chief, - config = config, - hooks = hooks, - scaffold = scaffold, - checkpoint_dir = ckpt_dir) - # , - # save_checkpoint_steps = 
self.save_freq) + elif self.run_opt.init_mode == 'init_from_model' : + log.info("initialize from model %s" % self.run_opt.init_model) + run_sess(self.sess, init_op) + self.saver.restore (self.sess, self.run_opt.init_model) + run_sess(self.sess, self.global_step.assign(0)) + fp = open(self.disp_file, "w") + fp.close () + elif self.run_opt.init_mode == 'restart' : + log.info("restart from model %s" % self.run_opt.restart) + run_sess(self.sess, init_op) + self.saver.restore (self.sess, self.run_opt.restart) + else : + raise RuntimeError ("unkown init mode") + else: + run_sess(self.sess, init_op) + self.saver = None + + # Ensure variable consistency among tasks when training starts + if self.run_opt.is_distrib: + bcast_op = self.run_opt._HVD.broadcast_global_variables(0) + if self.run_opt.is_chief: + log.info('broadcast global variables to other tasks') + else: + log.info('receive global variables from task#0') + run_sess(self.sess, bcast_op) def train (self, train_data, valid_data=None) : @@ -455,11 +418,9 @@ def train (self, train_data, valid_data=None) : # valid_data = train_data # using training set as validation set. 
stop_batch = self.stop_batch - if self.run_opt.is_distrib : - self._init_sess_distrib() - else : - self._init_sess_serial() + self._init_session() + # Before data shard is enabled, only cheif do evaluation and record it # self.print_head() fp = None if self.run_opt.is_chief : @@ -478,12 +439,12 @@ def train (self, train_data, valid_data=None) : prf_options = None prf_run_metadata = None - if self.profiling : + if self.profiling: prf_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) prf_run_metadata = tf.RunMetadata() # set tensorboard execution environment - if self.tensorboard : + if self.tensorboard: summary_merged_op = tf.summary.merge_all() # Remove TB old logging directory from previous run try: @@ -510,8 +471,9 @@ def train (self, train_data, valid_data=None) : # first round validation: train_batch = train_data.get_batch() if self.display_in_training and is_first_step: - valid_batches = [valid_data.get_batch() for ii in range(self.valid_numb_batch)] if valid_data is not None else None - self.valid_on_the_fly(fp, [train_batch], valid_batches, print_header=True) + if self.run_opt.is_chief: + valid_batches = [valid_data.get_batch() for ii in range(self.valid_numb_batch)] if valid_data is not None else None + self.valid_on_the_fly(fp, [train_batch], valid_batches, print_header=True) is_first_step = False if self.timing_in_training: tic = time.time() @@ -534,25 +496,25 @@ def train (self, train_data, valid_data=None) : if self.display_in_training and (cur_batch % self.disp_freq == 0): if self.timing_in_training: tic = time.time() - valid_batches = [valid_data.get_batch() for ii in range(self.valid_numb_batch)] if valid_data is not None else None - self.valid_on_the_fly(fp, [train_batch], valid_batches) + if self.run_opt.is_chief: + valid_batches = [valid_data.get_batch() for ii in range(self.valid_numb_batch)] if valid_data is not None else None + self.valid_on_the_fly(fp, [train_batch], valid_batches) if self.timing_in_training: toc = time.time() 
test_time = toc - tic log.info("batch %7d training time %.2f s, testing time %.2f s" % (cur_batch, train_time, test_time)) train_time = 0 - if self.save_freq > 0 and cur_batch % self.save_freq == 0 and self.run_opt.is_chief : - if self.saver is not None : - try: - self.saver.save (self.sess, os.getcwd() + "/" + self.save_ckpt) - except google.protobuf.message.DecodeError as e: - raise GraphTooLargeError( - "The graph size exceeds 2 GB, the hard limitation of protobuf." - " Then a DecodeError was raised by protobuf. You should " - "reduce the size of your model." - ) from e - log.info("saved checkpoint %s" % self.save_ckpt) + if self.save_freq > 0 and cur_batch % self.save_freq == 0 and self.saver is not None: + try: + self.saver.save (self.sess, os.getcwd() + "/" + self.save_ckpt) + except google.protobuf.message.DecodeError as e: + raise GraphTooLargeError( + "The graph size exceeds 2 GB, the hard limitation of protobuf." + " Then a DecodeError was raised by protobuf. You should " + "reduce the size of your model." + ) from e + log.info("saved checkpoint %s" % self.save_ckpt) if self.run_opt.is_chief: fp.close () if self.profiling and self.run_opt.is_chief : diff --git a/deepmd/utils/compat.py b/deepmd/utils/compat.py index 861a00439c..e3fd0c3177 100644 --- a/deepmd/utils/compat.py +++ b/deepmd/utils/compat.py @@ -27,8 +27,6 @@ def convert_input_v0_v1( """ output = {} - if "with_distrib" in jdata: - output["with_distrib"] = jdata["with_distrib"] output["model"] = _model(jdata, jdata["use_smooth"]) output["learning_rate"] = _learning_rate(jdata) output["loss"] = _loss(jdata) diff --git a/doc/getting-started.md b/doc/getting-started.md index f167818fdd..7b028d7165 100644 --- a/doc/getting-started.md +++ b/doc/getting-started.md @@ -5,6 +5,7 @@ In this text, we will call the deep neural network that is used to represent the 2. 
[Train a model](#train-a-model) - [Write the input script](#write-the-input-script) - [Training](#training) + - [Parallel training](#parallel-training) - [Training analysis with Tensorboard](#training-analysis-with-tensorboard) 3. [Freeze a model](#freeze-a-model) 4. [Test a model](#test-a-model) @@ -140,6 +141,51 @@ One can set other environmental variables: | --------------------- | ---------------------- | ------------- | -------------------------- | | DP_INTERFACE_PREC | `high`, `low` | `high` | Control high (double) or low (float) precision of training. | + +### Parallel training + +Currently, parallel training is enabled in a synchronized way with the help of [Horovod](https://github.com/horovod/horovod). DeePMD-kit decides whether to train in parallel according to the MPI context. Thus, there is no difference in your json/yaml input file. + +Testing `examples/water/se_e2_a` on an 8-GPU host, linear acceleration can be observed with increasing number of cards. +| Num of GPU cards | Seconds every 100 samples | Samples per second | Speed up | +| -- | -- | -- | -- | +| 1 | 1.6116 | 62.05 | 1.00 | +| 2 | 1.6310 | 61.31 | 1.98 | +| 4 | 1.6168 | 61.85 | 3.99 | +| 8 | 1.6212 | 61.68 | 7.95 | + +To experience this powerful feature, please install Horovod and [mpi4py](https://github.com/mpi4py/mpi4py) first. For better performance on GPU, please follow tuning steps in [Horovod on GPU](https://github.com/horovod/horovod/blob/master/docs/gpus.rst). +```bash +# By default, MPI is used as communicator. +HOROVOD_WITHOUT_GLOO=1 HOROVOD_WITH_TENSORFLOW=1 pip install horovod mpi4py +``` + +Horovod works in the data-parallel mode, resulting in a larger global batch size. For example, the real batch size is 8 when `batch_size` is set to 2 in the input file and you launch 4 workers. Thus, `learning_rate` is automatically scaled by the number of workers for better convergence. 
Technical details of this heuristic rule are discussed at [Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour](https://arxiv.org/abs/1706.02677). + +With dependencies installed, have a quick try! +```bash +# Launch 4 processes on the same host +CUDA_VISIBLE_DEVICES=4,5,6,7 horovodrun -np 4 \ + dp train --mpi-log=workers input.json +``` + +Note that the environment variable `CUDA_VISIBLE_DEVICES` must be set to control parallelism on the occupied host where one process is bound to one GPU card. + +What's more, 2 command-line arguments are defined to control the logging behavior. +``` +optional arguments: + -l LOG_PATH, --log-path LOG_PATH + set log file to log messages to disk, if not + specified, the logs will only be output to console + (default: None) + -m {master,collect,workers}, --mpi-log {master,collect,workers} + Set the manner of logging when running with MPI. + 'master' logs only on main process, 'collect' + broadcasts logs from workers to master and 'workers' + means each process will output its own log (default: + master) +``` + ### Training analysis with Tensorboard If enbled in json/yaml input file DeePMD-kit will create log files which can be diff --git a/source/tests/compat_inputs/water_v0.json b/source/tests/compat_inputs/water_v0.json index 88f868ff47..70eedcf72b 100644 --- a/source/tests/compat_inputs/water_v0.json +++ b/source/tests/compat_inputs/water_v0.json @@ -1,5 +1,4 @@ { - "with_distrib": false, "_comment": " model parameters", "use_smooth": false, "sel_a": [16, 32], diff --git a/source/tests/compat_inputs/water_v1.json b/source/tests/compat_inputs/water_v1.json index e5f2032ea2..e8b1d8a196 100644 --- a/source/tests/compat_inputs/water_v1.json +++ b/source/tests/compat_inputs/water_v1.json @@ -1,5 +1,4 @@ { - "with_distrib": false, "model":{ "descriptor": { "type": "loc_frame", diff --git a/source/tests/compat_inputs/water_v2.json b/source/tests/compat_inputs/water_v2.json index e49add4467..0bb1281f55 100644 --- 
a/source/tests/compat_inputs/water_v2.json +++ b/source/tests/compat_inputs/water_v2.json @@ -1,5 +1,4 @@ { - "with_distrib": false, "model":{ "descriptor": { "type": "loc_frame", diff --git a/source/tests/data_modifier/dipole.json b/source/tests/data_modifier/dipole.json index 9e968ba98c..5bd8b505f4 100644 --- a/source/tests/data_modifier/dipole.json +++ b/source/tests/data_modifier/dipole.json @@ -1,5 +1,4 @@ { - "with_distrib": false, "_comment": " model parameters", "model":{ "type_map": ["O", "H"], diff --git a/source/tests/polar_se_a.json b/source/tests/polar_se_a.json index 5e831e19d8..7b3362dbe7 100644 --- a/source/tests/polar_se_a.json +++ b/source/tests/polar_se_a.json @@ -1,5 +1,4 @@ { - "with_distrib": false, "_comment": " model parameters", "model":{ "type": "polar", diff --git a/source/tests/test_data_modifier.py b/source/tests/test_data_modifier.py index 829a589d7e..977df9a2b6 100644 --- a/source/tests/test_data_modifier.py +++ b/source/tests/test_data_modifier.py @@ -44,8 +44,7 @@ def _setUp(self): init_model=None, log_path=None, log_level=30, - mpi_log="master", - try_distrib=False + mpi_log="master" ) jdata = j_loader(INPUT) diff --git a/source/tests/test_data_modifier_shuffle.py b/source/tests/test_data_modifier_shuffle.py index bd4ab58132..c14b6dd105 100644 --- a/source/tests/test_data_modifier_shuffle.py +++ b/source/tests/test_data_modifier_shuffle.py @@ -49,8 +49,7 @@ def _setUp(self): init_model=None, log_path=None, log_level=30, - mpi_log="master", - try_distrib=False + mpi_log="master" ) jdata = self._setUp_jdata() self._setUp_data() diff --git a/source/tests/water.json b/source/tests/water.json index b4817fecf0..f4909a0971 100644 --- a/source/tests/water.json +++ b/source/tests/water.json @@ -1,5 +1,4 @@ { - "with_distrib": false, "_comment": " model parameters", "model" :{ "descriptor":{ diff --git a/source/tests/wfc.json b/source/tests/wfc.json index 556ef2a992..ab2ba7fc99 100644 --- a/source/tests/wfc.json +++ b/source/tests/wfc.json 
@@ -1,5 +1,4 @@ { - "with_distrib": false, "_comment": " model parameters", "model":{ "type": "polar", diff --git a/source/tests/yaml_inputs/water_v1.json b/source/tests/yaml_inputs/water_v1.json index e5f2032ea2..e8b1d8a196 100644 --- a/source/tests/yaml_inputs/water_v1.json +++ b/source/tests/yaml_inputs/water_v1.json @@ -1,5 +1,4 @@ { - "with_distrib": false, "model":{ "descriptor": { "type": "loc_frame", diff --git a/source/tests/yaml_inputs/water_v1.yaml b/source/tests/yaml_inputs/water_v1.yaml index 5121a961b0..9ddbb89f9c 100644 --- a/source/tests/yaml_inputs/water_v1.yaml +++ b/source/tests/yaml_inputs/water_v1.yaml @@ -1,4 +1,3 @@ -with_distrib: false model: descriptor: type: loc_frame