diff --git a/pysoundfile.py b/pysoundfile.py index a4a2aa5..105f548 100644 --- a/pysoundfile.py +++ b/pysoundfile.py @@ -11,14 +11,10 @@ __version__ = "0.5.0" import numpy as _np +import os as _os from cffi import FFI as _FFI from os import SEEK_SET, SEEK_CUR, SEEK_END -try: - import builtins as _builtins -except ImportError: - import __builtin__ as _builtins # for Python < 3.0 - _ffi = _FFI() _ffi.cdef(""" enum @@ -563,16 +559,14 @@ def __init__(self, file, mode='r', samplerate=None, channels=None, ``seek()`` and ``tell()``). mode : {'r', 'r+', 'w', 'w+', 'x', 'x+'}, optional Open mode. Has to begin with one of these three characters: - ``'r'`` for reading, ``'w'`` for writing (truncates) or - ``'x'`` for writing (but fail if already existing). - Additionally, it may contain ``'+'`` to open a file for both - reading and writing. + ``'r'`` for reading, ``'w'`` for writing (truncates `file`) + or ``'x'`` for writing (raises an error if `file` already + exists). Additionally, it may contain ``'+'`` to open + `file` for both reading and writing. The character ``'b'`` for *binary mode* is implied because all sound files have to be opened in this mode. - - .. note:: The modes containing ``'x'`` are only available - since Python 3.3! - + If `file` is a file descriptor or a file-like object, + ``'w'`` doesn't truncate and ``'x'`` doesn't raise an error. samplerate : int The sample rate of the file. If `mode` contains ``'r'``, this is obtained from the file (except for ``'RAW'`` files). @@ -634,9 +628,9 @@ def __init__(self, file, mode='r', samplerate=None, channels=None, raise ValueError("mode must contain exactly one of 'xrw'") self._mode = mode - if '+' in mode: + if '+' in modes: mode_int = _snd.SFM_RDWR - elif 'r' in mode: + elif 'r' in modes: mode_int = _snd.SFM_READ else: mode_int = _snd.SFM_WRITE @@ -645,13 +639,13 @@ def __init__(self, file, mode='r', samplerate=None, channels=None, self._name = getattr(file, 'name', file) if format is None: format = str(self.name).rsplit('.', 1)[-1].upper() - if format not in _formats and 'r' not in mode: + if format not in _formats and 'r' not in modes: raise TypeError( "No format specified and unable to get format from " "file extension: %s" % repr(self.name)) self._info = _ffi.new("SF_INFO*") - if 'r' not in mode or str(format).upper() == 'RAW': + if 'r' not in modes or str(format).upper() == 'RAW': if samplerate is None: raise TypeError("samplerate must be specified") self._info.samplerate = samplerate @@ -670,11 +664,14 @@ def __init__(self, file, mode='r', samplerate=None, channels=None, raise ValueError("closefd=False only allowed for file descriptors") if isinstance(file, str): - if 'b' not in mode: - mode += 'b' - file = self._filestream = _builtins.open(file, mode, buffering=0) - - if isinstance(file, int): + if _os.path.isfile(file): + if 'x' in modes: + raise OSError("File exists: %s" % repr(file)) + elif modes.issuperset('w+'): + # truncate the file, because SFM_RDWR doesn't: + _os.close(_os.open(file, _os.O_WRONLY | _os.O_TRUNC)) + self._file = _snd.sf_open(file.encode(), mode_int, self._info) + elif isinstance(file, int): self._file = _snd.sf_open_fd(file, mode_int, self._info, closefd) elif all(hasattr(file, a) for a in ('seek', 'read', 'write', 'tell')): self._file = _snd.sf_open_virtual( @@ -719,7 +716,6 @@ def __init__(self, file, mode='r', samplerate=None, channels=None, # avoid confusion if something goes wrong before assigning self._file: _file = None - _filestream = None def __del__(self): self.close() @@ -1007,8 +1003,6 @@ def close(self): self.flush() err = _snd.sf_close(self._file) self._file = None - if self._filestream: - self._filestream.close() self._handle_error_number(err) def _init_virtual_io(self, file): diff --git a/tests/test_pysoundfile.py b/tests/test_pysoundfile.py index bdf302e..1fa6972 100644 --- a/tests/test_pysoundfile.py +++ b/tests/test_pysoundfile.py @@ -3,7 +3,6 @@ import os import shutil import pytest -import sys data_stereo = np.array([[1.0, -1.0], [0.75, -0.75], @@ -397,13 +396,28 @@ def test_open_with_mode_is_none(): assert f.mode == 'rb' -@pytest.mark.skipif(sys.version_info < (3, 3), - reason="mode='x' not supported in Python < 3.3") def test_open_with_mode_is_x(): - with pytest.raises(FileExistsError): + with pytest.raises(OSError) as excinfo: sf.SoundFile(filename_stereo, 'x', 44100, 2) - with pytest.raises(FileExistsError): + assert "exists" in str(excinfo.value) + with pytest.raises(OSError) as excinfo: sf.SoundFile(filename_stereo, 'x+', 44100, 2) + assert "exists" in str(excinfo.value) + + +@pytest.mark.parametrize("mode", ['w', 'w+']) +def test_if_open_with_mode_w_truncates(file_stereo_rplus, mode): + with sf.SoundFile(file_stereo_rplus, mode, 48000, 6, format='AIFF') as f: + pass + with sf.SoundFile(filename_new) as f: + if isinstance(file_stereo_rplus, str): + assert f.samplerate == 48000 + assert f.channels == 6 + assert f.format == 'AIFF' + assert len(f) == 0 + else: + # This doesn't really work for file descriptors and file objects + pass # -----------------------------------------------------------------------------