From 315285e178b878a91fe98f3b4bb31d67590743b3 Mon Sep 17 00:00:00 2001 From: Bea Steers Date: Fri, 5 Jun 2020 14:48:04 -0400 Subject: [PATCH 1/5] better support for externally defined models --- crema/models/base.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/crema/models/base.py b/crema/models/base.py index 7be91e7..e78a75f 100644 --- a/crema/models/base.py +++ b/crema/models/base.py @@ -12,7 +12,12 @@ from .. import layers +CORE_CUSTOM_OBJECTS = {k: layers.__dict__[k] for k in layers.__all__} + + class CremaModel(object): + model_root = None + custom_objects = {} def predict(self, filename=None, y=None, sr=None, outputs=None): '''Predict annotations @@ -89,30 +94,26 @@ def transform(self, filename=None, y=None, sr=None): '''Feature transformation''' raise NotImplementedError + def _get_resource(self, *fname): + return ( + os.path.join(self.model_root, *fname) if self.model_root + else resource_filename(__name__, os.path.join(*fname))) + def _instantiate(self, rsc): # First, load the pump - with open(resource_filename(__name__, - os.path.join(rsc, 'pump.pkl')), - 'rb') as fd: + with open(self._get_resource(rsc, 'pump.pkl'), 'rb') as fd: self.pump = pickle.load(fd) # Now load the model - with open(resource_filename(__name__, - os.path.join(rsc, 'model_spec.pkl')), - 'rb') as fd: + custom_objects = dict(CORE_CUSTOM_OBJECTS, **self.custom_objects) + with open(self._get_resource(rsc, 'model_spec.pkl'), 'rb') as fd: spec = pickle.load(fd) - self.model = model_from_config(spec, - custom_objects={k: layers.__dict__[k] - for k in layers.__all__}) + self.model = model_from_config(spec, custom_objects=custom_objects) # And the model weights - self.model.load_weights(resource_filename(__name__, - os.path.join(rsc, - 'model.h5'))) + self.model.load_weights(self._get_resource(rsc, 'model.h5')) # And the version number - with open(resource_filename(__name__, - os.path.join(rsc, 'version.txt')), - 'r') as fd: + with open(self._get_resource(rsc, 'version.txt'), 'r') as fd: self.version = fd.read().strip() From 3c708bec952a5b3e3a9ca7cfb92c3da2d220b0ac Mon Sep 17 00:00:00 2001 From: Bea Steers Date: Fri, 5 Jun 2020 14:58:17 -0400 Subject: [PATCH 2/5] instantiate if name attr is set --- crema/models/base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/crema/models/base.py b/crema/models/base.py index e78a75f..c7513d8 100644 --- a/crema/models/base.py +++ b/crema/models/base.py @@ -16,9 +16,14 @@ class CremaModel(object): + name = None model_root = None custom_objects = {} + def __init__(self): + if self.name: + self._instantiate(self.name) + def predict(self, filename=None, y=None, sr=None, outputs=None): '''Predict annotations From c12b97643bff85bed62d9387a2e41b437d9ed8a5 Mon Sep 17 00:00:00 2001 From: Bea Steers Date: Fri, 5 Jun 2020 16:55:04 -0400 Subject: [PATCH 3/5] rename models_dir and add cls file util --- crema/models/base.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/crema/models/base.py b/crema/models/base.py index c7513d8..ea6baf3 100644 --- a/crema/models/base.py +++ b/crema/models/base.py @@ -17,7 +17,7 @@ class CremaModel(object): name = None - model_root = None + models_dir = None custom_objects = {} def __init__(self): @@ -99,26 +99,27 @@ def transform(self, filename=None, y=None, sr=None): '''Feature transformation''' raise NotImplementedError - def _get_resource(self, *fname): + @classmethod + def resource_file(cls, *fname): return ( - os.path.join(self.model_root, *fname) if self.model_root + os.path.join(cls.models_dir, *fname) if cls.models_dir is not None else resource_filename(__name__, os.path.join(*fname))) def _instantiate(self, rsc): # First, load the pump - with open(self._get_resource(rsc, 'pump.pkl'), 'rb') as fd: + with open(self.resource_file(rsc, 'pump.pkl'), 'rb') as fd: self.pump = pickle.load(fd) # Now load the model custom_objects = dict(CORE_CUSTOM_OBJECTS, **self.custom_objects) - with open(self._get_resource(rsc, 'model_spec.pkl'), 'rb') as fd: + with open(self.resource_file(rsc, 'model_spec.pkl'), 'rb') as fd: spec = pickle.load(fd) self.model = model_from_config(spec, custom_objects=custom_objects) # And the model weights - self.model.load_weights(self._get_resource(rsc, 'model.h5')) + self.model.load_weights(self.resource_file(rsc, 'model.h5')) # And the version number - with open(self._get_resource(rsc, 'version.txt'), 'r') as fd: + with open(self.resource_file(rsc, 'version.txt'), 'r') as fd: self.version = fd.read().strip() From ef5a2aeb682593acfb3ff361b254fc6da494269f Mon Sep 17 00:00:00 2001 From: Bea Steers Date: Mon, 20 Jul 2020 18:46:01 -0400 Subject: [PATCH 4/5] handle init args --- crema/models/base.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/crema/models/base.py b/crema/models/base.py index ea6baf3..cf935c5 100644 --- a/crema/models/base.py +++ b/crema/models/base.py @@ -20,7 +20,9 @@ class CremaModel(object): models_dir = None custom_objects = {} - def __init__(self): + def __init__(self, name=None, models_dir=None): + self.name = name or self.name + self.models_dir = models_dir or self.models_dir if self.name: self._instantiate(self.name) @@ -99,11 +101,11 @@ def transform(self, filename=None, y=None, sr=None): '''Feature transformation''' raise NotImplementedError - @classmethod - def resource_file(cls, *fname): + def resource_file(self, f='', *fname): return ( - os.path.join(cls.models_dir, *fname) if cls.models_dir is not None - else resource_filename(__name__, os.path.join(*fname))) + os.path.join(self.models_dir, f, *fname) + if self.models_dir is not None else + resource_filename(__name__, os.path.join(f, *fname))) def _instantiate(self, rsc): From 492d50357df6c4e96392ee5094b60a8469265f84 Mon Sep 17 00:00:00 2001 From: Bea Steers Date: Mon, 20 Jul 2020 18:46:13 -0400 Subject: [PATCH 5/5] add __str__ --- crema/models/base.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/crema/models/base.py b/crema/models/base.py index cf935c5..cac54d3 100644 --- a/crema/models/base.py +++ b/crema/models/base.py @@ -26,6 +26,17 @@ def __init__(self, name=None, models_dir=None): if self.name: self._instantiate(self.name) + def __str__(self): + if not hasattr(self, 'model'): # not instantiated + return super().__str__() + return ( + '<{} version={} resources={}\n' + '--------\n' + '* {}\n* {}\n' + '--------->').format( + self.__class__.__name__, self.version, + self.resource_file(), self.pump, self.model) + def predict(self, filename=None, y=None, sr=None, outputs=None): '''Predict annotations