Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 44 additions & 33 deletions pysoundfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,7 +615,9 @@ def seek(self, frames, whence=SEEK_SET):

"""
self._check_if_closed()
return _snd.sf_seek(self._file, frames, whence)
position = _snd.sf_seek(self._file, frames, whence)
self._handle_error()
return position

def _check_array(self, array):
# Do some error checking
Expand Down Expand Up @@ -655,6 +657,17 @@ def _read_or_write(self, funcname, array, frames):
self.seek(curr + frames, SEEK_SET) # Update read & write position
return frames

def _check_frames(self, frames, fill_value):
# Check if frames is larger than the remaining frames in the file
if self.seekable():
remaining_frames = self.frames - self.seek(0, SEEK_CUR)
if frames < 0 or (frames > remaining_frames
and fill_value is None):
frames = remaining_frames
elif frames < 0:
raise ValueError("frames must be specified for non-seekable files")
return frames

def read(self, frames=-1, dtype='float64', always_2d=True,
fill_value=None, out=None):
"""Read a number of frames from the file.
Expand Down Expand Up @@ -682,10 +695,7 @@ def read(self, frames=-1, dtype='float64', always_2d=True,

"""
if out is None:
remaining_frames = self.frames - self.seek(0, SEEK_CUR)
if frames < 0 or (frames > remaining_frames and
fill_value is None):
frames = remaining_frames
frames = self._check_frames(frames, fill_value)
out = self._create_empty_array(frames, always_2d, dtype)
else:
if frames < 0 or frames > len(out):
Expand Down Expand Up @@ -722,9 +732,12 @@ def write(self, data):
written = self._read_or_write('sf_writef_', data, len(data))
assert written == len(data)

curr = self.seek(0, SEEK_CUR)
self._info.frames = self.seek(0, SEEK_END)
self.seek(curr, SEEK_SET)
if self.seekable():
curr = self.seek(0, SEEK_CUR)
self._info.frames = self.seek(0, SEEK_END)
self.seek(curr, SEEK_SET)
else:
self._info.frames += written

def blocks(self, blocksize=None, overlap=0, frames=-1, dtype='float64',
always_2d=True, fill_value=None, out=None):
Expand All @@ -747,6 +760,9 @@ def blocks(self, blocksize=None, overlap=0, frames=-1, dtype='float64',
if 'r' not in self.mode and '+' not in self.mode:
raise RuntimeError("blocks() is not allowed in write-only mode")

if overlap != 0 and not self.seekable():
raise ValueError("overlap is only allowed for seekable files")

if out is None:
if blocksize is None:
raise TypeError("One of {blocksize, out} must be specified")
Expand All @@ -756,22 +772,35 @@ def blocks(self, blocksize=None, overlap=0, frames=-1, dtype='float64',
"Only one of {blocksize, out} may be specified")
blocksize = len(out)

remaining_frames = self.frames - self.seek(0, SEEK_CUR)
if frames < 0 or (fill_value is None and frames > remaining_frames):
frames = remaining_frames

frames = self._check_frames(frames, fill_value)
while frames > 0:
if frames < blocksize:
if fill_value is not None and out is None:
out = self._create_empty_array(blocksize, always_2d, dtype)
blocksize = frames
block = self.read(blocksize, dtype, always_2d, fill_value, out)
frames -= blocksize
if frames > 0:
if frames > 0 and self.seekable():
self.seek(-overlap, SEEK_CUR)
frames += overlap
yield block

def _prepare_read(self, start, stop, frames):
# Seek to start frame and calculate length
if start != 0 and not self.seekable():
raise ValueError("start is only allowed for seekable files")
if frames >= 0 and stop is not None:
raise TypeError("Only one of {frames, stop} may be used")

start, stop, _ = slice(start, stop).indices(self.frames)
if stop < start:
stop = start
if frames < 0:
frames = stop - start
if self.seekable():
self.seek(start, SEEK_SET)
return frames


def open(file, mode='r', samplerate=None, channels=None,
subtype=None, endian=None, format=None, closefd=True):
Expand Down Expand Up @@ -813,13 +842,9 @@ def read(file, samplerate=None, channels=None, subtype=None, endian=None,
endian are only needed for 'RAW' files. See open() for details.

"""
if frames >= 0 and stop is not None:
raise TypeError("Only one of {frames, stop} may be used")

with SoundFile(file, 'r', samplerate, channels,
subtype, endian, format, closefd) as f:
start, frames = _get_read_range(start, stop, frames, f.frames)
f.seek(start, SEEK_SET)
frames = f._prepare_read(start, stop, frames)
data = f.read(frames, dtype, always_2d, fill_value, out)
return data, f.samplerate

Expand Down Expand Up @@ -876,28 +901,14 @@ def blocks(file, samplerate=None, channels=None,
the generator's close() method can be called.

"""
if frames >= 0 and stop is not None:
raise TypeError("Only one of {frames, stop} may be used")

with open(file, 'r', samplerate, channels,
subtype, endian, format, closefd) as f:
start, frames = _get_read_range(start, stop, frames, f.frames)
f.seek(start, SEEK_SET)
frames = f._prepare_read(start, stop, frames)
for block in f.blocks(blocksize, overlap, frames,
dtype, always_2d, fill_value, out):
yield block


def _get_read_range(start, stop, frames, total_frames):
# Calculate start frame and length
start, stop, _ = slice(start, stop).indices(total_frames)
if stop < start:
stop = start
if frames < 0:
frames = stop - start
return start, frames


def default_subtype(format):
"""Return default subtype for given format."""
return _default_subtypes.get(str(format).upper())
Expand Down
77 changes: 61 additions & 16 deletions tests/test_pysoundfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
filename_stereo = 'tests/stereo.wav'
filename_mono = 'tests/mono.wav'
filename_raw = 'tests/mono.raw'
filename_new = 'tests/new.wav'
tempfilename = 'tests/delme.please'
filename_new = 'tests/delme.please'


open_variants = 'name', 'fd', 'obj'
Expand Down Expand Up @@ -48,9 +47,9 @@ def _file_new(request, fdarg, objarg=None):


def _file_copy(request, filename, fdarg, objarg=None):
shutil.copy(filename, tempfilename)
request.addfinalizer(lambda: os.remove(tempfilename))
return _file_existing(request, tempfilename, fdarg, objarg)
shutil.copy(filename, filename_new)
request.addfinalizer(lambda: os.remove(filename_new))
return _file_existing(request, filename_new, fdarg, objarg)


@pytest.fixture(params=open_variants)
Expand Down Expand Up @@ -340,19 +339,19 @@ def test_open_with_invalid_mode():

def test_open_with_more_invalid_arguments():
with pytest.raises(TypeError) as excinfo:
sf.open(filename_new, 'w', samplerate=3.1415, channels=2)
sf.open(filename_new, 'w', 3.1415, 2, format='WAV')
assert "integer" in str(excinfo.value)
with pytest.raises(TypeError) as excinfo:
sf.open(filename_new, 'w', samplerate=44100, channels=3.1415)
sf.open(filename_new, 'w', 44100, 3.1415, format='WAV')
assert "integer" in str(excinfo.value)
with pytest.raises(ValueError) as excinfo:
sf.open(filename_new, 'w', 44100, 2, format='WAF')
assert "Invalid format string" in str(excinfo.value)
with pytest.raises(ValueError) as excinfo:
sf.open(filename_new, 'w', 44100, 2, subtype='PCM16')
sf.open(filename_new, 'w', 44100, 2, 'PCM16', format='WAV')
assert "Invalid subtype string" in str(excinfo.value)
with pytest.raises(ValueError) as excinfo:
sf.open(filename_new, 'w', 44100, 2, endian='BOTH')
sf.open(filename_new, 'w', 44100, 2, endian='BOTH', format='WAV')
assert "Invalid endian-ness" in str(excinfo.value)
with pytest.raises(ValueError) as excinfo:
sf.open(filename_stereo, closefd=False)
Expand All @@ -379,16 +378,15 @@ def test_open_r_and_rplus_with_too_many_arguments():


def test_open_w_and_wplus_with_too_few_arguments():
filename = 'not_existing.xyz'
for mode in 'w', 'w+':
with pytest.raises(TypeError) as excinfo:
sf.open(filename, mode, samplerate=44100, channels=2)
sf.open(filename_new, mode, samplerate=44100, channels=2)
assert "No format specified" in str(excinfo.value)
with pytest.raises(TypeError) as excinfo:
sf.open(filename, mode, samplerate=44100, format='WAV')
sf.open(filename_new, mode, samplerate=44100, format='WAV')
assert "channels" in str(excinfo.value)
with pytest.raises(TypeError) as excinfo:
sf.open(filename, mode, channels=2, format='WAV')
sf.open(filename_new, mode, channels=2, format='WAV')
assert "samplerate" in str(excinfo.value)


Expand Down Expand Up @@ -455,8 +453,10 @@ def test_seek_in_read_mode(sf_stereo_r):
assert sf_stereo_r.seek(2) == 2
assert sf_stereo_r.seek(2, sf.SEEK_CUR) == 4
assert sf_stereo_r.seek(-2, sf.SEEK_END) == len(data_stereo) - 2
assert sf_stereo_r.seek(666) == -1
assert sf_stereo_r.seek(-666) == -1
with pytest.raises(RuntimeError):
sf_stereo_r.seek(666)
with pytest.raises(RuntimeError):
sf_stereo_r.seek(-666)


def test_seek_in_write_mode(sf_stereo_w):
Expand Down Expand Up @@ -586,7 +586,7 @@ def test_rplus_append_data(sf_stereo_rplus):
sf_stereo_rplus.seek(0, sf.SEEK_END)
sf_stereo_rplus.write(data_stereo / 2)
sf_stereo_rplus.close()
data, fs = sf.read(tempfilename)
data, fs = sf.read(filename_new)
assert np.all(data[:len(data_stereo)] == data_stereo)
assert np.all(data[len(data_stereo):] == data_stereo / 2)

Expand Down Expand Up @@ -645,3 +645,48 @@ def test_read_raw_files_with_too_few_arguments_should_fail():
sf.open(filename_raw, samplerate=44100, subtype='PCM_16')
with pytest.raises(TypeError): # missing samplerate
sf.open(filename_raw, channels=2, subtype='PCM_16')


# -----------------------------------------------------------------------------
# Test non-seekable files
# -----------------------------------------------------------------------------


def test_write_non_seekable_file():
with sf.open(filename_new, 'w', 44100, 1, format='XI') as f:
assert not f.seekable()
assert f.frames == 0
f.write(data_mono)
assert f.frames == len(data_mono)

with pytest.raises(RuntimeError) as excinfo:
f.seek(2)
assert "unseekable" in str(excinfo.value)

with sf.open(filename_new) as f:
assert not f.seekable()
assert f.frames == len(data_mono)
data = f.read(3, dtype='int16')
assert np.all(data == data_mono[:3])
data = f.read(666, dtype='int16')
assert np.all(data == data_mono[3:])

with pytest.raises(RuntimeError) as excinfo:
f.seek(2)
assert "unseekable" in str(excinfo.value)

with pytest.raises(ValueError) as excinfo:
f.read()
assert "frames" in str(excinfo.value)

with pytest.raises(ValueError) as excinfo:
list(f.blocks(blocksize=3, overlap=1))
assert "overlap" in str(excinfo.value)

data, fs = sf.read(filename_new, dtype='int16')
assert np.all(data == data_mono)
assert fs == 44100

with pytest.raises(ValueError) as excinfo:
sf.read(filename_new, start=3)
assert "start is only allowed for seekable files" in str(excinfo.value)