22import logging
33import os
44import time
5- from functools import (
6- partial ,
7- )
8- from multiprocessing import (
5+
6+ from multiprocessing .dummy import (
97 Pool ,
108)
119from queue import (
@@ -57,13 +55,6 @@ def setup_seed(seed) -> None:
5755 torch .backends .cudnn .deterministic = True
5856
5957
60- def construct_dataset (system , type_map ):
61- return DeepmdDataSetForLoader (
62- system = system ,
63- type_map = type_map ,
64- )
65-
66-
6758class DpLoaderSet (Dataset ):
6859 """A dataset for storing DataLoaders to multiple Systems.
6960
@@ -99,7 +90,11 @@ def __init__(
9990 if len (systems ) >= 100 :
10091 log .info (f"Constructing DataLoaders from { len (systems )} systems" )
10192
102- construct_dataset_systems = partial (construct_dataset , type_map = type_map )
93+ def construct_dataset (system ):
94+ return DeepmdDataSetForLoader (
95+ system = system ,
96+ type_map = type_map ,
97+ )
10398
10499 with Pool (
105100 os .cpu_count ()
@@ -109,7 +104,7 @@ def __init__(
109104 else 1
110105 )
111106 ) as pool :
112- self .systems = pool .map (construct_dataset_systems , systems )
107+ self .systems = pool .map (construct_dataset , systems )
113108
114109 self .sampler_list : list [DistributedSampler ] = []
115110 self .index = []
@@ -235,13 +230,12 @@ def __init__(self, iterable) -> None:
235230 self ._consumer = BackgroundConsumer (self ._queue , self ._iterable )
236231 self ._consumer .start ()
237232 self .last_warning_time = time .time ()
238- self .len = len (iterable )
239233
240234 def __iter__ (self ):
241235 return self
242236
243237 def __len__ (self ) -> int :
244- return self .len
238+ return len ( self ._iterable )
245239
246240 def __next__ (self ):
247241 start_wait = time .time ()
0 commit comments