Skip to content

Commit 5a454d5

Browse files
committed
revert changes on using process for no significant improvement
1 parent 68dd750 commit 5a454d5

File tree

1 file changed

+9
-15
lines changed

1 file changed

+9
-15
lines changed

deepmd/pt/utils/dataloader.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,8 @@
22
import logging
33
import os
44
import time
5-
from functools import (
6-
partial,
7-
)
8-
from multiprocessing import (
5+
6+
from multiprocessing.dummy import (
97
Pool,
108
)
119
from queue import (
@@ -57,13 +55,6 @@ def setup_seed(seed) -> None:
5755
torch.backends.cudnn.deterministic = True
5856

5957

60-
def construct_dataset(system, type_map):
61-
return DeepmdDataSetForLoader(
62-
system=system,
63-
type_map=type_map,
64-
)
65-
66-
6758
class DpLoaderSet(Dataset):
6859
"""A dataset for storing DataLoaders to multiple Systems.
6960
@@ -99,7 +90,11 @@ def __init__(
9990
if len(systems) >= 100:
10091
log.info(f"Constructing DataLoaders from {len(systems)} systems")
10192

102-
construct_dataset_systems = partial(construct_dataset, type_map=type_map)
93+
def construct_dataset(system):
94+
return DeepmdDataSetForLoader(
95+
system=system,
96+
type_map=type_map,
97+
)
10398

10499
with Pool(
105100
os.cpu_count()
@@ -109,7 +104,7 @@ def __init__(
109104
else 1
110105
)
111106
) as pool:
112-
self.systems = pool.map(construct_dataset_systems, systems)
107+
self.systems = pool.map(construct_dataset, systems)
113108

114109
self.sampler_list: list[DistributedSampler] = []
115110
self.index = []
@@ -235,13 +230,12 @@ def __init__(self, iterable) -> None:
235230
self._consumer = BackgroundConsumer(self._queue, self._iterable)
236231
self._consumer.start()
237232
self.last_warning_time = time.time()
238-
self.len = len(iterable)
239233

240234
def __iter__(self):
241235
return self
242236

243237
def __len__(self) -> int:
244-
return self.len
238+
return len(self._iterable)
245239

246240
def __next__(self):
247241
start_wait = time.time()

0 commit comments

Comments
 (0)