Skip to content

Commit ed32b9c

Browse files
committed
Bug fixes
Add unit tests and fix the bugs that cause unexpected bahaviors during test
1 parent 9cc67dd commit ed32b9c

File tree

6 files changed

+44
-31
lines changed

6 files changed

+44
-31
lines changed

seqchromloader/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from .loader import SeqChromDatasetByDataFrame, SeqChromDatasetByBed, SeqChromDatasetByWds, SeqChromDataModule
22
from .writer import dump_data_webdataset, convert_data_webdataset
3-
from .utils import filter_chromosomes, make_random_shift, make_flank, chop_genome, dna2OneHot, rev_comp
3+
from .utils import filter_chromosomes, make_random_shift, make_flank, chop_genome, dna2OneHot, rev_comp, get_genome_sizes, random_coords

seqchromloader/loader.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import torch
99
import random
1010
import pysam
11-
import pyfasta
11+
import pyfaidx
1212
import pyBigWig
1313
import numpy as np
1414
import pandas as pd
@@ -120,7 +120,7 @@ def __init__(self,
120120

121121
self.dataframe = dataframe
122122
self.genome_fasta = genome_fasta
123-
self.genome_pyfasta = None
123+
self.genome_pyfaidx = None
124124
self.bigwig_filelist = bigwig_filelist
125125
self.bigwigs = None
126126
self.target_bam = target_bam
@@ -133,7 +133,7 @@ def __init__(self,
133133
def initialize(self):
134134
# create the stream handler after child processes spawned to enable parallel reading
135135
# this function will be called by worker_init_function in DataLoader
136-
self.genome_pyfasta = pyfasta.Fasta(self.genome_fasta)
136+
self.genome_pyfaidx = pyfaidx.Fasta(self.genome_fasta)
137137
self.bigwigs = [pyBigWig.open(bw) for bw in self.bigwig_filelist]
138138
if self.target_bam is not None:
139139
self.target_pysam = pysam.AlignmentFile(self.target_bam)
@@ -149,7 +149,7 @@ def __getitem__(self, idx):
149149
item.start,
150150
item.end,
151151
item.label,
152-
genome_pyfasta=self.genome_pyfasta,
152+
genome_pyfaidx=self.genome_pyfaidx,
153153
bigwigs=self.bigwigs,
154154
target_bam=self.target_pysam,
155155
strand=item.strand,

seqchromloader/utils.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ def get_genome_sizes(gs=None, genome=None, to_filter=None, to_keep=None):
2828
elif genome:
2929
genome_sizes = (pd.DataFrame(chromsizes(genome))
3030
.T
31-
.rename(columns={0:"chrom", 1:"len"}))
31+
.reset_index()
32+
.rename(columns={"index":"chrom", 0:"start", 1:"end"})
33+
.assign(length=lambda x: x["end"] - x["start"]))[["chrom", "length"]]
3234
else:
3335
raise Exception("Either gs or genome should be provided!")
3436

@@ -74,7 +76,7 @@ def filter_chromosomes(coords, to_filter=None, to_keep=None):
7476
corods_out = coords
7577
return corods_out
7678

77-
def make_random_shift(coords, L, buffer=25):
79+
def make_random_shift(coords, L, buffer=0):
7880
"""
7981
This function takes as input a set of bed coordinates dataframe
8082
It finds the mid-point for each record or Interval in the bed file,
@@ -152,8 +154,8 @@ def random_coords(gs:str=None, genome:str=None, incl:BedTool=None, excl:BedTool=
152154
else:
153155
raise Exception("Either gs or genome should be provided!")
154156

155-
if incl: shuffle_kwargs.update({"incl": incl})
156-
if excl: shuffle_kwargs.update({"excl": excl})
157+
if incl: shuffle_kwargs.update({"incl": incl.fn})
158+
if excl: shuffle_kwargs.update({"excl": excl.fn})
157159

158160
return (BedTool()
159161
.random(l=l, n=n, **random_kwargs)
@@ -194,7 +196,7 @@ def intervals_loop(chrom, start, stride, l, size):
194196

195197
genome_sizes = get_genome_sizes(gs=gs, genome=genome, to_keep=chroms)
196198

197-
genome_chops = pd.concat([intervals_loop(i.Index, 0, stride, l, i.len)
199+
genome_chops = pd.concat([intervals_loop(i.chrom, 0, stride, l, i.length)
198200
for i in genome_sizes.itertuples()])
199201
genome_chops_bdt = BedTool.from_dataframe(genome_chops)
200202

@@ -282,8 +284,8 @@ def __init__(self, chrom, start, end, *args):
282284
def __str__(self) -> str:
283285
return f'Chromatin Info Inaccessible in region {self.chrom}:{self.start}-{self.end}'
284286

285-
def extract_info(chrom, start, end, label, genome_pyfasta, bigwigs, target_bam, strand="+", transforms:dict=None):
286-
seq = genome_pyfasta[chrom][int(start):int(end)]
287+
def extract_info(chrom, start, end, label, genome_pyfaidx, bigwigs, target_bam, strand="+", transforms:dict=None):
288+
seq = genome_pyfaidx[chrom][int(start):int(end)].seq
287289
if strand=="-":
288290
seq = rev_comp(seq)
289291
seq_array = dna2OneHot(seq)

seqchromloader/writer.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from collections import defaultdict
1313
from multiprocessing import Pool
1414

15-
import pyfasta
15+
import pyfaidx
1616
import pysam
1717
import pyBigWig
1818
import webdataset as wds
@@ -53,7 +53,7 @@ def dump_data_webdataset(coords, genome_fasta, bigwig_filelist,
5353
compress=True,
5454
numProcessors=1,
5555
transforms=None,
56-
braceexpand=True,
56+
braceexpand=False,
5757
DALI=False):
5858
"""
5959
Given coordinates dataframe, extract the sequence and chromatin signal, save in webdataset format
@@ -85,7 +85,7 @@ def dump_data_webdataset(coords, genome_fasta, bigwig_filelist,
8585
# split coordinates and assign chunks to workers
8686
num_chunks = math.ceil(len(coords) / 7000)
8787
chunks = np.array_split(coords, num_chunks)
88-
88+
8989
# freeze the common parameters
9090
## create a scaler to get statistics for normalizing chromatin marks input
9191
## also create a multiprocessing lock
@@ -99,18 +99,19 @@ def dump_data_webdataset(coords, genome_fasta, bigwig_filelist,
9999
DALI=DALI)
100100

101101
count_of_digits = 0
102-
while num_chunks > 0:
103-
num_chunks = int(num_chunks/10)
102+
nc = num_chunks
103+
while nc > 0:
104+
nc = int(nc/10)
104105
count_of_digits += 1
105106

106107
pool = Pool(numProcessors)
107108
res = pool.starmap_async(dump_data_worker_freeze, zip(chunks, [outprefix + "_" + format(i, f'0{count_of_digits}d') for i in range(num_chunks)]))
108109
files = res.get()
109110

110111
if braceexpand:
111-
begin = f'0{count_of_digits}d'.format(0)
112-
end = f'0{count_of_digits}d'.format(range(num_chunks)[-1])
113-
return f"outprefix_{{{begin}...{end}}}.tar.gz" if compress else f"outprefix_{{{begin}...{end}}}.tar"
112+
begin = format(0, f'0{count_of_digits}d')
113+
end = format(range(num_chunks)[-1], f'0{count_of_digits}d')
114+
return os.path.join(outdir, f"{outprefix}_{{{begin}..{end}}}.tar.gz" if compress else f"{outprefix}_{{{begin}...{end}}}.tar")
114115
else:
115116
return files
116117

@@ -124,7 +125,7 @@ def dump_data_webdataset_worker(coords,
124125
transforms=None,
125126
DALI=False):
126127
# get handlers
127-
genome_pyfasta = pyfasta.Fasta(fasta)
128+
genome_pyfaidx = pyfaidx.Fasta(fasta)
128129
bigwigs = [pyBigWig.open(bw) for bw in bigwig_files]
129130
target_pysam = pysam.AlignmentFile(target_bam) if target_bam is not None else None
130131

@@ -141,7 +142,7 @@ def dump_data_webdataset_worker(coords,
141142
item.start,
142143
item.end,
143144
item.label,
144-
genome_pyfasta=genome_pyfasta,
145+
genome_pyfaidx=genome_pyfaidx,
145146
bigwigs=bigwigs,
146147
target_bam=target_pysam,
147148
strand=item.strand,

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
# eg: 1.0.0, 1.0.1, 3.0.2, 5.0-beta, etc.
2121
# You CANNOT upload two versions of your package with the same version number
2222
# This field is REQUIRED
23-
version="0.5.1",
23+
version="0.5.2",
2424

2525
# The packages that constitute your project.
2626
# For my project, I have only one - "pydash".
@@ -37,7 +37,7 @@
3737
'numpy',
3838
'pandas',
3939
'webdataset>=0.2.0',
40-
'pyfasta>=0.5.0',
40+
'pyfaidx>=0.7.0',
4141
'pybedtools>=0.9.0',
4242
'pysam>=0.19.0',
4343
'pybigwig>=0.3.0',

tests/test_writer_loader.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import os
22
import sys
3-
sys.path.insert(0, "./")
43
import numpy as np
54
import pandas as pd
65
from seqchromloader import SeqChromDatasetByDataFrame, SeqChromDatasetByBed, SeqChromDatasetByWds, SeqChromDataModule
@@ -68,7 +67,7 @@ def test_random_coords(self):
6867
coords_incl = random_coords(genome="mm10", incl=interval)
6968
coords_excl = random_coords(genome="mm10", excl=interval)
7069

71-
self.assertTrue(BedTool().from_dataframe(coords_incl).intersect(interval).count()==coords_incl.size)
70+
self.assertTrue(BedTool().from_dataframe(coords_incl).intersect(interval).count()==len(coords_incl))
7271
self.assertTrue(BedTool().from_dataframe(coords_excl).intersect(interval).count()==0)
7372

7473
def test_chop_genome(self):
@@ -77,9 +76,10 @@ def test_chop_genome(self):
7776
'end': [50000, 20000]}))
7877
coords_incl = chop_genome(chroms=["chr2", "chr12"], genome="mm10", stride=1000, l=500, incl=interval)
7978
coords_excl = chop_genome(chroms=["chr2", "chr12"], genome="mm10", stride=1000, l=500, excl=interval)
80-
self.assertTrue(np.all([coords_incl.start.iloc[i] - coords_incl.start.iloc[i-1] for i in range(1, len(coords_incl))]==1000))
81-
self.assertTrue(np.all([coords_excl.start.iloc[i] - coords_excl.start.iloc[i-1] for i in range(1, len(coords_excl))]==1000))
82-
self.assertTrue(BedTool().from_dataframe(coords_incl).intersect(interval).count()==coords_incl.size)
79+
for c in ['chr2', 'chr12']:
80+
df = coords_incl[coords_incl.chrom==c]
81+
self.assertTrue(np.all([df.start.iloc[i] - df.start.iloc[i-1] == 1000 for i in range(1, len(df))]))
82+
self.assertTrue(BedTool().from_dataframe(coords_incl).intersect(interval).count()==len(coords_incl))
8383
self.assertTrue(BedTool().from_dataframe(coords_excl).intersect(interval).count()==0)
8484

8585
def test_writer(self):
@@ -99,8 +99,18 @@ def test_writer(self):
9999
outdir=self.tempdir,
100100
outprefix='test',
101101
compress=True,
102-
numProcessors=5)
102+
numProcessors=2)
103103
self.assertIsFile(os.path.join(self.tempdir, "test_0.tar.gz"))
104+
wds_files = dump_data_webdataset(huge_coords,
105+
genome_fasta='data/sample.fa',
106+
bigwig_filelist=['data/sample.bw'],
107+
target_bam='data/sample.bam',
108+
outdir=self.tempdir,
109+
outprefix='test',
110+
compress=True,
111+
numProcessors=2,
112+
braceexpand=True)
113+
self.assertTrue(wds_files == os.path.join(self.tempdir, "test_{0..1}.tar.gz"))
104114

105115
ds = wds.DataPipeline(
106116
wds.SimpleShardList([os.path.join(self.tempdir, "test_0.tar.gz")]),
@@ -251,4 +261,4 @@ def test_target_transform(target):
251261
return target * 3
252262

253263
if __name__ == "__main__":
254-
unittest.main()
264+
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)