Skip to content
25 changes: 11 additions & 14 deletions deepmd/pt/utils/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,26 +87,23 @@ def __init__(
with h5py.File(systems) as file:
systems = [os.path.join(systems, item) for item in file.keys()]

self.systems: list[DeepmdDataSetForLoader] = []
if len(systems) >= 100:
log.info(f"Constructing DataLoaders from {len(systems)} systems")

def construct_dataset(system):
return DeepmdDataSetForLoader(
system=system,
type_map=type_map,
)

with Pool(
os.cpu_count()
// (
int(os.environ["LOCAL_WORLD_SIZE"])
if dist.is_available() and dist.is_initialized()
else 1
)
) as pool:
self.systems = pool.map(construct_dataset, systems)

self.systems: list[DeepmdDataSetForLoader] = []
global_rank = dist.get_rank() if dist.is_initialized() else 0
if global_rank == 0:
log.info(f"Constructing DataLoaders from {len(systems)} systems")
with Pool(max(1, env.NUM_WORKERS)) as pool:
self.systems = pool.map(construct_dataset, systems)
else:
self.systems = [None] * len(systems) # type: ignore
if dist.is_initialized():
dist.broadcast_object_list(self.systems)
assert self.systems[-1] is not None
self.sampler_list: list[DistributedSampler] = []
self.index = []
self.total_batch = 0
Expand Down