diff --git a/examples/tutorials/tuto_plot_custom_emissivity.py b/examples/tutorials/tuto_plot_custom_emissivity.py index dc8676037..2b63672da 100644 --- a/examples/tutorials/tuto_plot_custom_emissivity.py +++ b/examples/tutorials/tuto_plot_custom_emissivity.py @@ -31,6 +31,7 @@ # Now, we define an emissivity function that depends on r and z coordinates. # We can plot its profile in the (0, X, Z) plane. + def emissivity(pts, t=None, vect=None): """Custom emissivity as a function of geometry. diff --git a/tofu/geom/_core.py b/tofu/geom/_core.py index 03f76e79c..13357a8c9 100644 --- a/tofu/geom/_core.py +++ b/tofu/geom/_core.py @@ -5770,6 +5770,8 @@ def get_inspector(self, ff): return na, kw def check_ff(self, ff, t=None, ani=None): + # Initialization of function wrapper + wrapped_ff = ff # Define unique error message giving all info in a concise way # Optionnally add error-specific line afterwards @@ -5838,10 +5840,17 @@ def check_ff(self, ff, t=None, ani=None): msg += "\n\n => ff must take a ff(pts, t=t) !" raise Exception(msg) - if not (isinstance(out, np.ndarray) and out.shape == (nt, npts)): + if not (isinstance(out, np.ndarray) and (out.shape == (nt, npts) + or out.shape == (npts,))): msg += "\n\n => wrong output (always 2d np.ndarray) !" raise Exception(msg) - return is_ani + + if nt == 1 and out.shape == (npts,): + def wrapped_ff(*args, **kwargs): + res_ff = ff(*args, **kwargs) + return np.reshape(res_ff, (1, -1)) + + return is_ani, wrapped_ff def _calc_signal_preformat(self, ind=None, DL=None, t=None, out=object, Brightness=True): @@ -6049,7 +6058,7 @@ def calc_signal( # Launch # NB : find a way to exclude cases with DL[0,:]>=DL[1,:] !! # Exclude Rays not seeing the plasma if newcalc: - ani = self.check_ff(func, t=t, ani=ani) + ani, func = self.check_ff(func, t=t, ani=ani) s = _GG.LOS_calc_signal( func, Ds,