diff --git a/crema/models/base.py b/crema/models/base.py index 7be91e7..cac54d3 100644 --- a/crema/models/base.py +++ b/crema/models/base.py @@ -12,7 +12,30 @@ from .. import layers +CORE_CUSTOM_OBJECTS = {k: layers.__dict__[k] for k in layers.__all__} + + class CremaModel(object): + name = None + models_dir = None + custom_objects = {} + + def __init__(self, name=None, models_dir=None): + self.name = name or self.name + self.models_dir = models_dir or self.models_dir + if self.name: + self._instantiate(self.name) + + def __str__(self): + if not hasattr(self, 'model'): # not instantiated + return super().__str__() + return ( + '<{} version={} resources={}\n' + '--------\n' + '* {}\n* {}\n' + '--------->').format( + self.__class__.__name__, self.version, + self.resource_file(), self.pump, self.model) def predict(self, filename=None, y=None, sr=None, outputs=None): '''Predict annotations @@ -89,30 +112,27 @@ def transform(self, filename=None, y=None, sr=None): '''Feature transformation''' raise NotImplementedError + def resource_file(self, f='', *fname): + return ( + os.path.join(self.models_dir, f, *fname) + if self.models_dir is not None else + resource_filename(__name__, os.path.join(f, *fname))) + def _instantiate(self, rsc): # First, load the pump - with open(resource_filename(__name__, - os.path.join(rsc, 'pump.pkl')), - 'rb') as fd: + with open(self.resource_file(rsc, 'pump.pkl'), 'rb') as fd: self.pump = pickle.load(fd) # Now load the model - with open(resource_filename(__name__, - os.path.join(rsc, 'model_spec.pkl')), - 'rb') as fd: + custom_objects = dict(CORE_CUSTOM_OBJECTS, **self.custom_objects) + with open(self.resource_file(rsc, 'model_spec.pkl'), 'rb') as fd: spec = pickle.load(fd) - self.model = model_from_config(spec, - custom_objects={k: layers.__dict__[k] - for k in layers.__all__}) + self.model = model_from_config(spec, custom_objects=custom_objects) # And the model weights - self.model.load_weights(resource_filename(__name__, - os.path.join(rsc, - 'model.h5'))) + self.model.load_weights(self.resource_file(rsc, 'model.h5')) # And the version number - with open(resource_filename(__name__, - os.path.join(rsc, 'version.txt')), - 'r') as fd: + with open(self.resource_file(rsc, 'version.txt'), 'r') as fd: self.version = fd.read().strip()