Skip to content
3 changes: 2 additions & 1 deletion interfaces/cython/cantera/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@

import os
import sys
from pathlib import Path
import warnings

warnings.filterwarnings('default', module='cantera')
add_directory(os.path.join(os.path.dirname(__file__), 'data'))
add_directory(Path(__file__).parent / "data")
add_directory('.') # Move current working directory to the front of the path

# Python interpreter used for converting mechanisms
Expand Down
3 changes: 3 additions & 0 deletions interfaces/cython/cantera/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# at https://cantera.org/license.txt for license and copyright information.

from collections import defaultdict as _defaultdict
from pathlib import PurePath

cdef class _SolutionBase:
def __cinit__(self, infile='', name='', adjacent=(), origin=None,
Expand Down Expand Up @@ -53,6 +54,8 @@ cdef class _SolutionBase:
self.transport = NULL

# Parse inputs
if isinstance(infile, PurePath):
infile = str(infile)
if infile.endswith('.yml') or infile.endswith('.yaml') or yaml:
self._init_yaml(infile, name, adjacent, yaml)
elif infile or source:
Expand Down
3 changes: 2 additions & 1 deletion interfaces/cython/cantera/ck2yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -1375,8 +1375,9 @@ def load_extra_file(self, path):
"""
Load YAML-formatted entries from ``path`` on disk.
"""
yaml_ = yaml.YAML()
with open(path, 'rt', encoding="utf-8") as stream:
yml = yaml.round_trip_load(stream)
yml = yaml_.load(stream)

# do not overwrite reserved field names
reserved = {'generator', 'input-files', 'cantera-version', 'date',
Expand Down
4 changes: 2 additions & 2 deletions interfaces/cython/cantera/onedim.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1411,7 +1411,7 @@ cdef class Sim1D:
... description='solution with energy eqn. disabled')

"""
self.sim.save(stringify(filename), stringify(name),
self.sim.save(stringify(str(filename)), stringify(name),
stringify(description), loglevel)

def restore(self, filename='soln.xml', name='solution', loglevel=2):
Expand All @@ -1427,7 +1427,7 @@ cdef class Sim1D:

>>> s.restore(filename='save.xml', name='energy_off')
"""
self.sim.restore(stringify(filename), stringify(name), loglevel)
self.sim.restore(stringify(str(filename)), stringify(name), loglevel)
self._initialized = True

def restore_time_stepping_solution(self):
Expand Down
6 changes: 3 additions & 3 deletions interfaces/cython/cantera/test/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import os
from pathlib import Path
import cantera

from .test_composite import *
Expand All @@ -15,5 +15,5 @@
from .test_transport import *
from .test_utils import *

cantera.add_directory(os.path.join(os.path.dirname(__file__), 'data'))
cantera.add_directory(os.path.join(os.path.dirname(__file__), '..', 'examples', 'surface_chemistry'))
cantera.add_directory(Path(__file__) / "data")
cantera.add_directory(Path(__file__).parents[1] / "examples" / "surface_chemistry")
53 changes: 21 additions & 32 deletions interfaces/cython/cantera/test/test_composite.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,7 @@
from os.path import join as pjoin
import os
import sys

import numpy as np
from collections import OrderedDict
import warnings

try:
import ruamel_yaml as yaml
except ImportError:
from ruamel import yaml


import cantera as ct
from cantera.composite import _h5py, _pandas
Expand All @@ -22,9 +13,8 @@ class TestModels(utilities.CanteraTest):
@classmethod
def setUpClass(cls):
utilities.CanteraTest.setUpClass()
cls.yml_file = pjoin(cls.test_data_dir, "thermo-models.yaml")
with open(cls.yml_file, 'rt', encoding="utf-8") as stream:
cls.yml = yaml.safe_load(stream)
cls.yml_file = cls.test_data_path / "thermo-models.yaml"
cls.yml = utilities.load_yaml(cls.yml_file)
Comment thread
ischoegl marked this conversation as resolved.

def test_load_thermo_models(self):
for ph in self.yml['phases']:
Expand Down Expand Up @@ -150,7 +140,7 @@ def test_write_csv(self):
states.TPX = np.linspace(300, 1000, 7), 2e5, 'H2:0.5, O2:0.4'
states.equilibrate('HP')

outfile = pjoin(self.test_work_dir, 'solutionarray.csv')
outfile = self.test_work_path / "solutionarray.csv"
Comment thread
ischoegl marked this conversation as resolved.
states.write_csv(outfile)

data = np.genfromtxt(outfile, names=True, delimiter=',')
Expand All @@ -167,7 +157,7 @@ def test_write_csv(self):
def test_write_csv_str_column(self):
states = ct.SolutionArray(self.gas, 3, extra={'spam': 'eggs'})

outfile = pjoin(self.test_work_dir, 'solutionarray.csv')
outfile = self.test_work_path / "solutionarray.csv"
states.write_csv(outfile)

b = ct.SolutionArray(self.gas, extra={'spam'})
Expand All @@ -177,7 +167,7 @@ def test_write_csv_str_column(self):
def test_write_csv_multidim_column(self):
states = ct.SolutionArray(self.gas, 3, extra={'spam': np.zeros((3, 5,))})

outfile = pjoin(self.test_work_dir, 'solutionarray.csv')
outfile = self.test_work_path / "solutionarray.csv"
with self.assertRaisesRegex(NotImplementedError, 'not supported'):
states.write_csv(outfile)

Expand All @@ -194,9 +184,10 @@ def test_to_pandas(self):

@utilities.unittest.skipIf(isinstance(_h5py, ImportError), "h5py is not installed")
def test_write_hdf(self):
outfile = pjoin(self.test_work_dir, 'solutionarray.h5')
if os.path.exists(outfile):
os.remove(outfile)
outfile = self.test_work_path / "solutionarray.h5"
# In Python >= 3.8, this can be replaced by the missing_ok argument
if outfile.is_file():
outfile.unlink()

extra = {'foo': range(7), 'bar': range(7)}
meta = {'spam': 'eggs', 'hello': 'world'}
Expand Down Expand Up @@ -239,9 +230,10 @@ def test_write_hdf(self):

@utilities.unittest.skipIf(isinstance(_h5py, ImportError), "h5py is not installed")
def test_write_hdf_str_column(self):
outfile = pjoin(self.test_work_dir, 'solutionarray.h5')
if os.path.exists(outfile):
os.remove(outfile)
outfile = self.test_work_path / "solutionarray.h5"
# In Python >= 3.8, this can be replaced by the missing_ok argument
if outfile.is_file():
outfile.unlink()

states = ct.SolutionArray(self.gas, 3, extra={'spam': 'eggs'})
states.write_hdf(outfile, mode='w')
Expand All @@ -252,9 +244,10 @@ def test_write_hdf_str_column(self):

@utilities.unittest.skipIf(isinstance(_h5py, ImportError), "h5py is not installed")
def test_write_hdf_multidim_column(self):
outfile = pjoin(self.test_work_dir, 'solutionarray.h5')
if os.path.exists(outfile):
os.remove(outfile)
outfile = self.test_work_path / "solutionarray.h5"
# In Python >= 3.8, this can be replaced by the missing_ok argument
if outfile.is_file():
outfile.unlink()

states = ct.SolutionArray(self.gas, 3, extra={'spam': [[1, 2], [3, 4], [5, 6]]})
states.write_hdf(outfile, mode='w')
Expand Down Expand Up @@ -478,8 +471,7 @@ def test_yaml_simple(self):
gas.equilibrate('HP')
gas.TP = 1500, ct.one_atm
gas.write_yaml('h2o2-generated.yaml')
with open('h2o2-generated.yaml', 'r') as infile:
generated = yaml.safe_load(infile)
generated = utilities.load_yaml("h2o2-generated.yaml")
for key in ('generator', 'date', 'phases', 'species', 'reactions'):
self.assertIn(key, generated)
self.assertEqual(generated['phases'][0]['transport'], 'mixture-averaged')
Expand All @@ -503,10 +495,8 @@ def test_yaml_outunits(self):
gas.TP = 1500, ct.one_atm
units = {'length': 'cm', 'quantity': 'mol', 'energy': 'cal'}
gas.write_yaml('h2o2-generated.yaml', units=units)
with open('h2o2-generated.yaml') as infile:
generated = yaml.safe_load(infile)
with open(pjoin(self.cantera_data, "h2o2.yaml")) as infile:
original = yaml.safe_load(infile)
generated = utilities.load_yaml("h2o2-generated.yaml")
original = utilities.load_yaml(self.cantera_data_path / "h2o2.yaml")
self.assertEqual(generated['units'], units)

for r1, r2 in zip(original['reactions'], generated['reactions']):
Expand All @@ -529,8 +519,7 @@ def test_yaml_surface(self):
surf.coverages = np.ones(surf.n_species)
surf.write_yaml('ptcombust-generated.yaml')

with open('ptcombust-generated.yaml') as infile:
generated = yaml.safe_load(infile)
generated = utilities.load_yaml("ptcombust-generated.yaml")
for key in ('phases', 'species', 'gas-reactions', 'Pt_surf-reactions'):
self.assertIn(key, generated)
self.assertEqual(len(generated['gas-reactions']), gas.n_reactions)
Expand Down
Loading