diff --git a/src/qinfer/parallel.py b/src/qinfer/parallel.py index 7ccf820..2f5c269 100644 --- a/src/qinfer/parallel.py +++ b/src/qinfer/parallel.py @@ -68,6 +68,10 @@ def expparams_dtype(self): @property def is_n_outcomes_constant(self): return self._serial_model.is_n_outcomes_constant + + @property + def n_engines(self): + return len(self._dv) ## METHODS ## @@ -80,12 +84,24 @@ def n_outcomes(self, expparams): def likelihood(self, outcomes, modelparams, expparams): # By calling the superclass implementation, we can consolidate # call counting there. - super(DirectViewParallelizedModel, self).likelihood(outcomes, modelparams, expparams) - + super(DirectViewParallelizedModel, self).likelihood(outcomes, modelparams, expparams) + + # Need to decorate with interactive to overcome namespace issues with + # remote engines. + @IPython.parallel.interactive + def serial_likelihood(mps, sm, os, eps): + return sm.likelihood(os, mps, eps) + + # TODO: check whether there's a better way to pass the extra parameters + # that doesn't use so much memory. + # The trick is that serial_likelihood will be pickled, so we need to be + # careful about closures. L = self._dv.map_sync( - lambda mps, sm=self._serial_model, os=outcomes, eps=expparams: - sm.likelihood(os, mps, eps), - np.array_split(modelparams, len(self._dv), axis=0) + serial_likelihood, + np.array_split(modelparams, self.n_engines, axis=0), + [self._serial_model] * self.n_engines, + [outcomes] * self.n_engines, + [expparams] * self.n_engines, ) return np.concatenate(L, axis=1)