@@ -4,6 +4,7 @@
 Given a BED file, return sequence and chromatin info
 """
 
+import logging
 import torch
 import random
 import pyfasta
@@ -27,12 +28,12 @@ class SeqChromLoader():
     def __init__(self, SeqChromDataset):
         self.SeqChromDataset = SeqChromDataset
 
-    def __call__(self, *args, batch_size=512, num_workers=1, shuffle=False, worker_init_fn=worker_init_fn, **kwargs):
+    def __call__(self, *args, worker_init_fn=worker_init_fn, dataloader_kws: dict = None, **kwargs):
+        # default DataLoader kwargs; fall back to an empty dict so **dataloader_kws below is safe
+        dataloader_kws = dict(dataloader_kws) if dataloader_kws is not None else {}
+        wif = dataloader_kws.pop("worker_init_fn", worker_init_fn)
         return DataLoader(self.SeqChromDataset(*args, **kwargs),
-                          batch_size=batch_size,
-                          num_workers=num_workers,
-                          shuffle=shuffle,
-                          worker_init_fn=worker_init_fn)
+                          worker_init_fn=wif, **dataloader_kws)
 
 def seqChromLoaderCurry(SeqChromDataset):
 
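For review context, a minimal usage sketch of the new signature. `MySeqChromDataset` and the BED path are illustrative placeholders, not part of this change; positional args construct the dataset, while `dataloader_kws` is forwarded verbatim to `DataLoader`:

# hypothetical usage of the reworked factory; MySeqChromDataset and the
# path are placeholders, and SeqChromLoader comes from this module
loader = SeqChromLoader(MySeqChromDataset)(
    "peaks.bed",                          # *args/**kwargs go to the dataset
    dataloader_kws={"batch_size": 512,    # everything here goes to DataLoader
                    "num_workers": 4,
                    "shuffle": True},
)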
@@ -52,25 +53,24 @@ def initialize(self):
 
     def __iter__(self):
         worker_info = torch.utils.data.get_worker_info()
+        pipeline = [
+            wds.SimpleShardList(self.wds),
+            wds.tarfile_to_samples(),
+            wds.split_by_worker,
+            wds.decode(),
+            wds.to_tuple("seq.npy", "chrom.npy", "target.npy", "label.npy"),
+        ]
         if worker_info is None:
-            ds = wds.DataPipeline(
-                wds.SimpleShardList(self.wds),
-                wds.shuffle(100, rng=random.Random(1)),
-                wds.tarfile_to_samples(),
-                wds.shuffle(1000, rng=random.Random(1)),
-                wds.decode(),
-                wds.to_tuple("seq.npy", "chrom.npy", "target.npy", "label.npy"),
-            )
-        else:
-            ds = wds.DataPipeline(
-                wds.SimpleShardList(self.wds),
-                wds.shuffle(100, rng=random.Random(1)),
-                wds.split_by_worker,
-                wds.tarfile_to_samples(),
-                wds.shuffle(1000, rng=random.Random(1)),
-                wds.decode(),
-                wds.to_tuple("seq.npy", "chrom.npy", "target.npy", "label.npy"),
-            )
+            logging.info("No worker info found; the dataset will not be split across subprocesses. Are you using a custom dataloader?")
+            logging.info("Ignore this message if you are not using multiprocessing for data loading.")
+            del pipeline[2]  # drop wds.split_by_worker in the single-process case
+
+        # append optional per-field transform stages to the pipeline
+        if self.seq_transform is not None: pipeline.extend(self.seq_transform)
+        if self.chrom_transform is not None: pipeline.extend(self.chrom_transform)
+        if self.target_transform is not None: pipeline.extend(self.target_transform)
+
+        ds = wds.DataPipeline(*pipeline)
 
         return iter(ds)
 
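Since the transform lists are spliced into the pipeline with `extend()`, each entry must itself be a complete webdataset stage. A sketch of what `seq_transform` could hold, assuming the decoded `seq.npy` is a one-hot array; the reverse-complement lambda is an illustrative choice, not this repo's API:

import numpy as np
import webdataset as wds

identity = lambda x: x

# one stage per list entry; wds.map_tuple applies one function per
# tuple field in (seq, chrom, target, label) order
seq_transform = [
    wds.map_tuple(lambda seq: np.ascontiguousarray(seq[::-1, ::-1]),  # reverse-complement a (4, L) one-hot
                  identity, identity, identity),
]

Because these stages run after `wds.to_tuple`, they see decoded numpy arrays, and the seq, chrom, and target lists are applied in that fixed order.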
@@ -107,8 +107,6 @@ def __getitem__(self, idx):
             for idx, bigwig in enumerate(self.bigwigs):
                 m = (np.nan_to_num(bigwig.values(entry.chrom, entry.start, entry.end))).astype(np.float32)
                 if entry.strand == "-": m = m[::-1]  # reverse if needed
-                if self.scaler_mean and self.scaler_var:
-                    m = (m - self.scaler_mean[idx]) / sqrt(self.scaler_var[idx])
                 ms.append(m)
         except RuntimeError as e:
             print(e)
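If the dropped per-track standardization is still needed, it can live outside `__getitem__`; a sketch that mirrors the removed `(m - mean) / sqrt(var)` arithmetic, with `scaler_mean`/`scaler_var` assumed to be per-bigwig arrays:

import numpy as np

def standardize_chrom(chrom, scaler_mean, scaler_var):
    # chrom: (n_bigwigs, length); standardize each track with its own stats
    mean = np.asarray(scaler_mean, dtype=np.float32)[:, None]
    std = np.sqrt(np.asarray(scaler_var, dtype=np.float32))[:, None]
    return (chrom - mean) / std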
@@ -204,8 +202,7 @@ def setup(self, stage=None):
             self.batch_size_per_rank = int(self.batch_size / world_size)
 
         if stage in ["fit", "validate", "test"] or stage is None:
-
-            self.train_loader = wds.DataPipeline(
+            train_pipeline = [
                 wds.SimpleShardList(self.train_wds),
                 wds.shuffle(100, rng=random.Random(1)),
                 split_by_node(global_rank, world_size),
@@ -214,24 +211,41 @@ def setup(self, stage=None):
                 wds.shuffle(1000, rng=random.Random(1)),
                 wds.decode(),
                 wds.to_tuple("seq.npy", "chrom.npy", "target.npy", "label.npy"),
-            )
+            ]
 
-            self.val_loader = wds.DataPipeline(
+            val_pipeline = [
                 wds.SimpleShardList(self.val_wds),
                 split_by_node(global_rank, world_size),
                 wds.split_by_worker,
                 wds.tarfile_to_samples(),
                 wds.decode(),
                 wds.to_tuple("seq.npy", "chrom.npy", "target.npy", "label.npy"),
-            )
+            ]
 
-            self.test_loader = wds.DataPipeline(
+            test_pipeline = [
                 wds.SimpleShardList(self.test_wds),
                 split_by_node(global_rank, world_size),
                 wds.split_by_worker,
                 wds.tarfile_to_samples(),
                 wds.decode(),
                 wds.to_tuple("seq.npy", "chrom.npy", "target.npy", "label.npy"),
+            ]
+
+            if self.transform is not None:
+                train_pipeline.extend(self.transform)
+                val_pipeline.extend(self.transform)
+                test_pipeline.extend(self.transform)
+
+            self.train_loader = wds.DataPipeline(
+                *train_pipeline
+            )
+
+            self.val_loader = wds.DataPipeline(
+                *val_pipeline
+            )
+
+            self.test_loader = wds.DataPipeline(
+                *test_pipeline
             )
 
     def train_dataloader(self):
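A sketch of a shared `transform` list for the DataModule's three pipelines; the `arcsinh` choice is an assumption for illustration, not part of this change:

import numpy as np
import webdataset as wds

identity = lambda x: x

# this list is extended onto train, val, and test alike, so only
# deterministic stages belong here; random augmentations would leak
# into evaluation
transform = [
    wds.map_tuple(identity,
                  lambda chrom: np.arcsinh(chrom),  # variance-stabilize coverage
                  identity,
                  identity),
]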