Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 8 additions & 131 deletions deepmd/entrypoints/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import logging
import time
import os
from typing import Dict, TYPE_CHECKING, List, Optional, Any
from typing import Dict, List, Optional, Any

import numpy as np
from deepmd.common import data_requirement, expand_sys_str, j_loader, j_must_have
from deepmd.env import tf
from deepmd.env import reset_default_tf_session_config
from deepmd.infer.data_modifier import DipoleChargeModifier
from deepmd.train.run_options import BUILD, CITATION, WELCOME, RunOptions
from deepmd.train.trainer import DPTrainer
Expand All @@ -21,118 +21,11 @@
from deepmd.utils.sess import run_sess
from deepmd.utils.neighbor_stat import NeighborStat

if TYPE_CHECKING:
from deepmd.run_options import TFServerV1

__all__ = ["train"]

log = logging.getLogger(__name__)


def create_done_queue(
    cluster_spec: tf.train.ClusterSpec, task_index: int
) -> tf.FIFOQueue:
    """Create FIFO queue for distributed tasks.

    Parameters
    ----------
    cluster_spec : tf.train.ClusterSpec
        tf cluster specification object
    task_index : int
        identifying index of a task

    Returns
    -------
    tf.FIFOQueue
        tf distributed FIFO queue
    """
    # place the queue on the parameter-server task so all workers share it
    with tf.device(f"/job:ps/task:{task_index:d}"):
        # capacity equals the worker count: each worker enqueues exactly one
        # "done" token, which this ps task drains in wait_done_queue
        queue = tf.FIFOQueue(
            cluster_spec.num_tasks("worker"),
            tf.int32,
            shared_name=f"done_queue{task_index}",
        )
        return queue


def wait_done_queue(
    cluster_spec: tf.train.ClusterSpec,
    server: "TFServerV1",
    queue: tf.FIFOQueue,
    task_index: int,
):
    """Wait until all enqueued operations in tf distributed queue are finished.

    Blocks in a session on ``server.target`` until one "done" token has been
    dequeued for every worker task in the cluster.

    Parameters
    ----------
    cluster_spec : tf.train.ClusterSpec
        tf cluster specification object
    server : TFServerV1
        tf server specification object
    queue : tf.FIFOQueue
        tf distributed queue
    task_index : int
        identifying index of a task
    """
    with tf.Session(server.target) as sess:
        for i in range(cluster_spec.num_tasks("worker")):
            run_sess(sess, queue.dequeue())
            log.debug(f"ps:{task_index:d} received done from worker:{i:d}")
        # fix: task_index is an int; the original ":f" float spec rendered
        # e.g. "ps:0.000000" — use ":d" like every other log line here
        log.debug(f"ps:{task_index:d} quitting")


def connect_done_queue(
    cluster_spec: tf.train.ClusterSpec, task_index: int
) -> List[tf.Operation]:
    """Create tf FIFO queue filling operations.

    For every parameter-server task, open the shared "done" queue that lives
    on that task and build an enqueue op announcing this worker's index.

    Parameters
    ----------
    cluster_spec : tf.train.ClusterSpec
        tf cluster specification object
    task_index : int
        identifying index of a task

    Returns
    -------
    List[tf.Operation]
        list of tf operations that will populate the queue
    """
    n_workers = cluster_spec.num_tasks("worker")
    enqueue_ops = []
    for ps_idx in range(cluster_spec.num_tasks("ps")):
        with tf.device(f"/job:ps/task:{ps_idx:d}"):
            shared_queue = tf.FIFOQueue(
                n_workers, tf.int32, shared_name=f"done_queue{ps_idx}"
            )
            enqueue_ops.append(shared_queue.enqueue(task_index))
    return enqueue_ops


def fill_done_queue(
    cluster_spec: tf.train.ClusterSpec,
    server: "TFServerV1",
    done_ops: List[tf.Operation],
    task_index: int,
):
    """Run specified operations that will fill the tf distributed FIFO queue.

    Executes one enqueue op per parameter-server task inside a session on
    ``server.target``, signalling that this worker is done.

    Parameters
    ----------
    cluster_spec : tf.train.ClusterSpec
        tf cluster specification object
    server : TFServerV1
        tf server specification object
    done_ops : List[tf.Operation]
        a list of tf operations that will fill the queue
    task_index : int
        identifying index of a task
    """
    num_ps = cluster_spec.num_tasks("ps")
    with tf.Session(server.target) as sess:
        for ps_idx in range(num_ps):
            run_sess(sess, done_ops[ps_idx])
            log.debug(f"worker:{task_index:d} sending done to ps:{ps_idx:d}")


def train(
*,
INPUT: str,
Expand Down Expand Up @@ -186,34 +79,14 @@ def train(
restart=restart,
log_path=log_path,
log_level=log_level,
mpi_log=mpi_log,
try_distrib=jdata.get("with_distrib", False),
mpi_log=mpi_log
)

for message in WELCOME + CITATION + BUILD:
log.info(message)

run_opt.print_resource_summary()

if run_opt.is_distrib:
# distributed training
if run_opt.my_job_name == "ps":
queue = create_done_queue(run_opt.cluster_spec, run_opt.my_task_index)
wait_done_queue(
run_opt.cluster_spec, run_opt.server, queue, run_opt.my_task_index
)
# server.join()
elif run_opt.my_job_name == "worker":
done_ops = connect_done_queue(run_opt.cluster_spec, run_opt.my_task_index)
_do_work(jdata, run_opt)
fill_done_queue(
run_opt.cluster_spec, run_opt.server, done_ops, run_opt.my_task_index
)
else:
raise RuntimeError("unknown job name")
else:
# serial training
_do_work(jdata, run_opt)
_do_work(jdata, run_opt)


def _do_work(jdata: Dict[str, Any], run_opt: RunOptions):
Expand All @@ -234,6 +107,10 @@ def _do_work(jdata: Dict[str, Any], run_opt: RunOptions):
# make necessary checks
assert "training" in jdata

    # avoid conflict of visible gpus among multiple tf sessions in one process
if run_opt.is_distrib and len(run_opt.gpus or []) > 1:
reset_default_tf_session_config(cpu_only=True)

# init the model
model = DPTrainer(jdata, run_opt=run_opt)
rcut = model.model.get_rcut()
Expand Down
19 changes: 18 additions & 1 deletion deepmd/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,31 @@ def get_tf_session_config() -> Any:
"""
set_tf_default_nthreads()
intra, inter = get_tf_default_nthreads()
return tf.ConfigProto(
config = tf.ConfigProto(
intra_op_parallelism_threads=intra, inter_op_parallelism_threads=inter
)
return config


default_tf_session_config = get_tf_session_config()


def reset_default_tf_session_config(cpu_only: bool):
    """Limit tensorflow session to CPU or not.

    Parameters
    ----------
    cpu_only : bool
        If enabled, no GPU device is visible to the TensorFlow Session.
    """
    global default_tf_session_config
    device_count = default_tf_session_config.device_count
    if cpu_only:
        # cap visible GPU devices at zero for sessions built with this config
        device_count['GPU'] = 0
    elif 'GPU' in device_count:
        # drop the override so TF falls back to its default GPU visibility
        del device_count['GPU']


def get_module(module_name: str) -> "ModuleType":
"""Load force module.

Expand Down
18 changes: 9 additions & 9 deletions deepmd/loggers/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,7 @@ def setStream(self, stream):
def set_log_handles(
level: int,
log_path: Optional["Path"] = None,
mpi_log: Optional[str] = None,
MPI: Optional["MPI"] = None,
mpi_log: Optional[str] = None
):
"""Set desired level for package loggers and add file handlers.

Expand All @@ -154,16 +153,13 @@ def set_log_handles(
only from rank==0. `collect` will write messages from all ranks to one file
opened under rank==0 and to console. `workers` will open one log file for each
worker designated by its rank, console behaviour is the same as for `collect`.
If this argument is specified than also `MPI` object must be passed in.
by default None
MPI : Optional[MPI, optional]
`MPI` communicator object, must be specified if `mpi_log` is specified,
If this argument is specified, package 'mpi4py' must be already installed.
by default None

Raises
------
RuntimeError
if only one of the arguments `mpi_log`, `MPI` is specified
        If the argument `mpi_log` is specified but the package `mpi4py` is not installed.

References
----------
Expand Down Expand Up @@ -204,8 +200,12 @@ def set_log_handles(
root_log.removeHandler(hdlr)

# check if arguments are present
if (mpi_log and not MPI) or (not mpi_log and MPI):
raise RuntimeError("You cannot specify only one of 'mpi_log', 'MPI' arguments")
MPI = None
if mpi_log:
try:
from mpi4py import MPI
except ImportError as e:
raise RuntimeError("You cannot specify 'mpi_log' when mpi4py not installed") from e

# * add console handler ************************************************************
ch = logging.StreamHandler()
Expand Down
Loading