
Commit 4fec94f

Bump to 0.2.0, API changes

- Modifications to let loaders have the same behavior
- Keyword argument on the dataloader wrapper, to be compatible with all DataLoader arguments
- Several bug fixes
1 parent 18622cb commit 4fec94f

File tree

6 files changed: +56 additions, −106 deletions


data/sample.fa

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+>chr19
+ACGTNNNATTCGG

seqchromloader/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
 from .loader import SeqChromDatasetByBed, SeqChromDatasetByWds, SeqChromDataModule
-from .writer import get_data_webdataset
+from .writer import dump_data_webdataset
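
With the rename, downstream imports change accordingly; after upgrading to 0.2.0, callers use the new name:

    from seqchromloader import dump_data_webdataset  # formerly get_data_webdataset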

seqchromloader/loader.py

Lines changed: 45 additions & 31 deletions
@@ -4,6 +4,7 @@
 Given bed file, return sequence and chromatin info
 """
 
+import logging
 import torch
 import random
 import pyfasta
@@ -27,12 +28,12 @@ class SeqChromLoader():
     def __init__(self, SeqChromDataset):
         self.SeqChromDataset = SeqChromDataset
 
-    def __call__(self, *args, batch_size=512, num_workers=1, shuffle=False, worker_init_fn=worker_init_fn, **kwargs):
+    def __call__(self, *args, worker_init_fn=worker_init_fn, dataloader_kws:dict=None, **kwargs):
+        # default dataloader kws
+        wif = dataloader_kws.pop("worker_init_fn", worker_init_fn) if dataloader_kws is not None else worker_init_fn
+
         return DataLoader(self.SeqChromDataset(*args, **kwargs),
-                          batch_size=batch_size,
-                          num_workers=num_workers,
-                          shuffle=shuffle,
-                          worker_init_fn=worker_init_fn)
+                          worker_init_fn=wif, **dataloader_kws)
 
 def seqChromLoaderCurry(SeqChromDataset):
@@ -52,25 +53,24 @@ def initialize(self):
 
     def __iter__(self):
         worker_info = torch.utils.data.get_worker_info()
+        pipeline = [
+            wds.SimpleShardList(self.wds),
+            wds.tarfile_to_samples(),
+            wds.split_by_worker,
+            wds.decode(),
+            wds.to_tuple("seq.npy", "chrom.npy", "target.npy", "label.npy"),
+        ]
         if worker_info is None:
-            ds = wds.DataPipeline(
-                wds.SimpleShardList(self.wds),
-                wds.shuffle(100, rng=random.Random(1)),
-                wds.tarfile_to_samples(),
-                wds.shuffle(1000, rng=random.Random(1)),
-                wds.decode(),
-                wds.to_tuple("seq.npy", "chrom.npy", "target.npy", "label.npy"),
-            )
-        else:
-            ds = wds.DataPipeline(
-                wds.SimpleShardList(self.wds),
-                wds.shuffle(100, rng=random.Random(1)),
-                wds.split_by_worker,
-                wds.tarfile_to_samples(),
-                wds.shuffle(1000, rng=random.Random(1)),
-                wds.decode(),
-                wds.to_tuple("seq.npy", "chrom.npy", "target.npy", "label.npy"),
-            )
+            logging.info("Worker info not found, won't split dataset across subprocesses, are you using custom dataloader?")
+            logging.info("Ignore the message if you are not using multiprocessing on data loading")
+            del pipeline[2]
+
+        # transform
+        if self.seq_transform is not None: pipeline.extend(self.seq_transform)
+        if self.chrom_transform is not None: pipeline.extend(self.chrom_transform)
+        if self.target_transform is not None: pipeline.extend(self.target_transform)
+
+        ds = wds.DataPipeline(*pipeline)
 
         return iter(ds)
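
Since the new transform hooks are added with pipeline.extend, each hook is expected to be a list of webdataset stages applied after to_tuple. A hedged sketch of what a chrom_transform could look like (the log1p scaling is illustrative, not part of the library):

    import numpy as np
    import webdataset as wds

    # One wds.map_tuple stage: rescale the chromatin field, pass the
    # other three fields (seq, target, label) through unchanged.
    chrom_transform = [
        wds.map_tuple(
            lambda seq: seq,
            lambda chrom: np.log1p(chrom),
            lambda target: target,
            lambda label: label,
        )
    ]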

@@ -107,8 +107,6 @@ def __getitem__(self, idx):
             for idx, bigwig in enumerate(self.bigwigs):
                 m = (np.nan_to_num(bigwig.values(entry.chrom, entry.start, entry.end))).astype(np.float32)
                 if entry.strand == "-": m = m[::-1] # reverse if needed
-                if self.scaler_mean and self.scaler_var:
-                    m = (m - self.scaler_mean[idx])/sqrt(self.scaler_var[idx])
                 ms.append(m)
         except RuntimeError as e:
             print(e)
@@ -204,8 +202,7 @@ def setup(self, stage=None):
         self.batch_size_per_rank = int(self.batch_size/world_size)
 
         if stage in ["fit", "validate", "test"] or stage is None:
-
-            self.train_loader = wds.DataPipeline(
+            train_pipeline = [
                 wds.SimpleShardList(self.train_wds),
                 wds.shuffle(100, rng=random.Random(1)),
                 split_by_node(global_rank, world_size),
@@ -214,24 +211,41 @@
                 wds.shuffle(1000, rng=random.Random(1)),
                 wds.decode(),
                 wds.to_tuple("seq.npy", "chrom.npy", "target.npy", "label.npy"),
-            )
+            ]
 
-            self.val_loader = wds.DataPipeline(
+            val_pipeline = [
                 wds.SimpleShardList(self.val_wds),
                 split_by_node(global_rank, world_size),
                 wds.split_by_worker,
                 wds.tarfile_to_samples(),
                 wds.decode(),
                 wds.to_tuple("seq.npy", "chrom.npy", "target.npy", "label.npy"),
-            )
+            ]
 
-            self.test_loader = wds.DataPipeline(
+            test_pipeline = [
                 wds.SimpleShardList(self.test_wds),
                 split_by_node(global_rank, world_size),
                 wds.split_by_worker,
                 wds.tarfile_to_samples(),
                 wds.decode(),
                 wds.to_tuple("seq.npy", "chrom.npy", "target.npy", "label.npy"),
+            ]
+
+            if self.transform is not None:
+                train_pipeline.extend(self.transform)
+                val_pipeline.extend(self.transform)
+                test_pipeline.extend(self.transform)
+
+            self.train_loader = wds.DataPipeline(
+                *train_pipeline
+            )
+
+            self.val_loader = wds.DataPipeline(
+                *val_pipeline
+            )
+
+            self.test_loader = wds.DataPipeline(
+                *test_pipeline
             )
 
     def train_dataloader(self):
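
The setup refactor uses the same list-then-splat pattern as __iter__ above: stages accumulate in a plain Python list so the shared transform can be appended before the pipeline is materialized. A standalone sketch of the pattern, with placeholder shard URLs:

    import random
    import webdataset as wds

    stages = [
        wds.SimpleShardList("shards/seqchrom_{0..3}.tar"),  # placeholder shards
        wds.shuffle(100, rng=random.Random(1)),
        wds.split_by_worker,
        wds.tarfile_to_samples(),
        wds.decode(),
        wds.to_tuple("seq.npy", "chrom.npy", "target.npy", "label.npy"),
    ]
    extra_stages = []  # e.g. the DataModule's `transform` list
    stages.extend(extra_stages)
    pipeline = wds.DataPipeline(*stages)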

seqchromloader/writer.py

Lines changed: 7 additions & 11 deletions
@@ -19,9 +19,8 @@
 
 from seqchromloader import utils
 
-def get_data_webdataset(coords, genome_fasta, chromatin_tracks,
+def dump_data_webdataset(coords, genome_fasta, chromatin_tracks,
                         tf_bam=None,
-                        nbins=None,
                         outdir="dataset/", outprefix="seqchrom",
                         reverse=False, compress=False,
                         numProcessors=1, chroms_scaler=None):
@@ -37,16 +36,16 @@ def get_data_webdataset(coords, genome_fasta, chromatin_tracks,
     # freeze the common parameters
     ## create a scaler to get statistics for normalizing chromatin marks input
     ## also create a multiprocessing lock
-    get_data_worker_freeze = functools.partial(get_data_webdataset_worker,
-                                               fasta=genome_fasta, nbins=nbins,
+    dump_data_worker_freeze = functools.partial(dump_data_webdataset_worker,
+                                                fasta=genome_fasta,
                                                 bigwig_files=chromatin_tracks,
                                                 tf_bam=tf_bam,
                                                 reverse=reverse,
                                                 compress=compress,
                                                 outdir=outdir)
 
     pool = Pool(numProcessors)
-    res = pool.starmap_async(get_data_worker_freeze, zip(chunks, [outprefix + "_" + str(i) for i in range(num_chunks)]))
+    res = pool.starmap_async(dump_data_worker_freeze, zip(chunks, [outprefix + "_" + str(i) for i in range(num_chunks)]))
     res = res.get()
 
     # fit the scaler if provided
@@ -58,10 +57,10 @@ def get_data_webdataset(coords, genome_fasta, chromatin_tracks,
 
     return files
 
-def get_data_webdataset_worker(coords, outprefix, fasta, bigwig_files,
+def dump_data_webdataset_worker(coords, outprefix, fasta, bigwig_files,
                                tf_bam=None,
                                outdir="dataset/",
-                               nbins=None, reverse=False, compress=True):
+                               reverse=False, compress=True):
     # get handlers
     genome_pyfasta = pyfasta.Fasta(fasta)
     bigwigs = [pyBigWig.open(bw) for bw in bigwig_files]
@@ -87,10 +86,7 @@ def get_data_webdataset_worker(coords, outprefix, fasta, bigwig_files,
         ms = []
         try:
             for idx, bigwig in enumerate(bigwigs):
-                m = (np.nan_to_num(bigwig.values(item.chrom, item.start, item.end)))
-                if nbins:
-                    m = (m.reshape((nbins, -1))
-                         .mean(axis=1, dtype=np.float32))
+                m = (np.nan_to_num(bigwig.values(item.chrom, item.start, item.end))).astype(np.float32)
                 if reverse:
                     m = m[::-1]
                 ms.append(m)
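
Putting the writer changes together: the entry point is renamed and the nbins binning option is gone, so chromatin tracks are now written at base-pair resolution as float32. A hedged usage sketch; the file names are placeholders and coords is assumed to be a BED-like table of regions:

    import pandas as pd
    from seqchromloader import dump_data_webdataset

    # placeholder inputs
    coords = pd.read_csv("regions.bed", sep="\t",
                         names=["chrom", "start", "end", "name", "score", "strand"])
    shards = dump_data_webdataset(
        coords,
        genome_fasta="genome.fa",
        chromatin_tracks=["h3k27ac.bw"],
        outdir="dataset/",
        outprefix="seqchrom",
        numProcessors=4,
    )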

setup.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
 # eg: 1.0.0, 1.0.1, 3.0.2, 5.0-beta, etc.
 # You CANNOT upload two versions of your package with the same version number
 # This field is REQUIRED
-version="0.1.1",
+version="0.2.0",
 
 # The packages that constitute your project.
 # For my project, I have only one - "pydash".

tests/unittest.py

Lines changed: 0 additions & 62 deletions
This file was deleted.
