diff --git a/docs/examples/example_moving_eddies.py b/docs/examples/example_moving_eddies.py index 24d47f2b96..25fafa9501 100644 --- a/docs/examples/example_moving_eddies.py +++ b/docs/examples/example_moving_eddies.py @@ -97,7 +97,14 @@ def cosd(x): data = {'U': U, 'V': V, 'P': P} dimensions = {'lon': lon, 'lat': lat, 'time': time} - return FieldSet.from_data(data, dimensions, transpose=True, mesh=mesh) + + fieldset = FieldSet.from_data(data, dimensions, transpose=True, mesh=mesh) + + # setting some constants for AdvectionRK45 kernel + fieldset.RK45_min_dt = 1e-3 + fieldset.RK45_max_dt = 1e2 + fieldset.RK45_tol = 1e-5 + return fieldset def moving_eddies_example(fieldset, outfile, npart=2, mode='jit', verbose=False, diff --git a/parcels/application_kernels/advection.py b/parcels/application_kernels/advection.py index ea76311457..c661b4610d 100644 --- a/parcels/application_kernels/advection.py +++ b/parcels/application_kernels/advection.py @@ -64,10 +64,10 @@ def AdvectionRK45(particle, fieldset, time): 1e-5 * dt by default. Note that this kernel requires a Particle Class that has an extra Variable 'next_dt' + and a FieldSet with constants 'RK45_tol' (in meters), 'RK45_min_dt' (in seconds) + and 'RK45_max_dt' (in seconds). """ - particle.dt = particle.next_dt - rk45tol = 1e-5 - min_dt = 1e-3 + particle.dt = min(particle.next_dt, fieldset.RK45_max_dt) c = [1./4., 3./8., 12./13., 1., 1./2.] A = [[1./4., 0., 0., 0., 0.], [3./32., 9./32., 0., 0., 0.], @@ -99,11 +99,11 @@ def AdvectionRK45(particle, fieldset, time): lon_5th = (u1 * b5[0] + u2 * b5[1] + u3 * b5[2] + u4 * b5[3] + u5 * b5[4] + u6 * b5[5]) * particle.dt lat_5th = (v1 * b5[0] + v2 * b5[1] + v3 * b5[2] + v4 * b5[3] + v5 * b5[4] + v6 * b5[5]) * particle.dt - kappa2 = math.pow(lon_5th - lon_4th, 2) + math.pow(lat_5th - lat_4th, 2) - if kappa2 <= math.pow(math.fabs(particle.dt * rk45tol), 2) or particle.dt < min_dt: + kappa = math.sqrt(math.pow(lon_5th - lon_4th, 2) + math.pow(lat_5th - lat_4th, 2)) + if (kappa <= fieldset.RK45_tol) or (math.fabs(particle.dt) < math.fabs(fieldset.RK45_min_dt)): particle_dlon += lon_4th # noqa particle_dlat += lat_4th # noqa - if kappa2 <= math.pow(math.fabs(particle.dt * rk45tol / 10), 2): + if (kappa <= fieldset.RK45_tol) / 10 and (math.fabs(particle.dt*2) <= math.fabs(fieldset.RK45_max_dt)): particle.next_dt *= 2 else: particle.next_dt /= 2 diff --git a/parcels/compilation/codegenerator.py b/parcels/compilation/codegenerator.py index 45d5e2d8b9..19c602079d 100644 --- a/parcels/compilation/codegenerator.py +++ b/parcels/compilation/codegenerator.py @@ -925,13 +925,16 @@ def generate(self, funcname, field_args, const_args, kernel_ast, c_include): # ==== statement clusters use to compose 'body' variable and variables 'time_loop' and 'part_loop' ==== ## sign_dt = c.Assign("sign_dt", "dt > 0 ? 1 : -1") + # ==== check if next_dt is in the particle type ==== # + dtname = "next_dt" if "next_dt" in [v.name for v in self.ptype.variables] else "dt" + # ==== main computation body ==== # body = [] body += [c.Value("double", "pre_dt")] body += [c.Statement("pre_dt = particles->dt[pnum]")] body += [c.If("sign_dt*particles->time_nextloop[pnum] >= sign_dt*(endtime)", c.Statement("break"))] - body += [c.If("fabs(endtime - particles->time_nextloop[pnum]) < fabs(particles->dt[pnum])-1e-6", - c.Statement("particles->dt[pnum] = fabs(endtime - particles->time_nextloop[pnum]) * sign_dt"))] + body += [c.If(f"fabs(endtime - particles->time_nextloop[pnum]) < fabs(particles->{dtname}[pnum])-1e-6", + c.Statement(f"particles->{dtname}[pnum] = fabs(endtime - particles->time_nextloop[pnum]) * sign_dt"))] body += [c.Assign("particles->state[pnum]", f"{funcname}(particles, pnum, {fargs_str})")] body += [c.If("particles->state[pnum] == SUCCESS", c.Block([c.If("sign_dt*particles->time[pnum] < sign_dt*endtime", diff --git a/parcels/kernel.py b/parcels/kernel.py index 38b8dec6e5..1c9758440a 100644 --- a/parcels/kernel.py +++ b/parcels/kernel.py @@ -23,7 +23,11 @@ MPI = None import parcels.rng as ParcelsRandom # noqa -from parcels.application_kernels.advection import AdvectionAnalytical, AdvectionRK4_3D +from parcels.application_kernels.advection import ( + AdvectionAnalytical, + AdvectionRK4_3D, + AdvectionRK45, +) from parcels.compilation.codegenerator import KernelGenerator, LoopGenerator from parcels.field import Field, NestedField, VectorField from parcels.grid import GridCode @@ -321,6 +325,18 @@ def check_fieldsets_in_kernels(self, pyfunc): raise NotImplementedError('Analytical Advection only works with C-grids') if self._fieldset.U.grid.gtype not in [GridCode.CurvilinearZGrid, GridCode.RectilinearZGrid]: raise NotImplementedError('Analytical Advection only works with Z-grids in the vertical') + elif pyfunc is AdvectionRK45: + if not hasattr(self.fieldset, 'RK45_tol'): + logger.info("Setting RK45 tolerance to 10 m. Use fieldset.add_constant('RK45_tol', [distance]) to change.") + self.fieldset.add_constant('RK45_tol', 10) + if self.fieldset.U.grid.mesh == 'spherical': + self.fieldset.RK45_tol /= (1852 * 60) # TODO does not account for zonal variation in meter -> degree conversion + if not hasattr(self.fieldset, 'RK45_min_dt'): + logger.info("Setting RK45 minimum timestep to 1 s. Use fieldset.add_constant('RK45_min_dt', [timestep]) to change.") + self.fieldset.add_constant('RK45_min_dt', 1) + if not hasattr(self.fieldset, 'RK45_max_dt'): + logger.info("Setting RK45 maximum timestep to 1 day. Use fieldset.add_constant('RK45_max_dt', [timestep]) to change.") + self.fieldset.add_constant('RK45_max_dt', 60*60*24) def check_kernel_signature_on_version(self): numkernelargs = 0 @@ -623,8 +639,12 @@ def evaluate_particle(self, p, endtime): if sign_dt*p.time_nextloop >= sign_dt*endtime: return p - if abs(endtime - p.time_nextloop) < abs(p.dt)-1e-6: - p.dt = abs(endtime - p.time_nextloop) * sign_dt + try: # Use next_dt from AdvectionRK45 if it is set + if abs(endtime - p.time_nextloop) < abs(p.next_dt)-1e-6: + p.next_dt = abs(endtime - p.time_nextloop) * sign_dt + except KeyError: + if abs(endtime - p.time_nextloop) < abs(p.dt)-1e-6: + p.dt = abs(endtime - p.time_nextloop) * sign_dt res = self._pyfunc(p, self._fieldset, p.time_nextloop) if res is None: diff --git a/tests/test_advection.py b/tests/test_advection.py index 7e11338452..50b9cd3320 100644 --- a/tests/test_advection.py +++ b/tests/test_advection.py @@ -166,6 +166,26 @@ def SubmergeParticle(particle, fieldset, time): assert len(pset) == 0 +@pytest.mark.parametrize('mode', ['scipy', 'jit']) +@pytest.mark.parametrize('rk45_tol', [10, 100]) +def test_advection_RK45(lon, lat, mode, rk45_tol, npart=10): + data2D = {'U': np.ones((lon.size, lat.size), dtype=np.float32), + 'V': np.zeros((lon.size, lat.size), dtype=np.float32)} + dimensions = {'lon': lon, 'lat': lat} + fieldset = FieldSet.from_data(data2D, dimensions, mesh='spherical', transpose=True) + fieldset.add_constant('RK45_tol', rk45_tol) + + dt = delta(seconds=30).total_seconds() + RK45Particles = ptype[mode].add_variable('next_dt', dtype=np.float32, initial=dt) + pset = ParticleSet(fieldset, pclass=RK45Particles, + lon=np.zeros(npart) + 20., + lat=np.linspace(0, 80, npart)) + pset.execute(AdvectionRK45, runtime=delta(hours=2), dt=dt) + assert (np.diff(pset.lon) > 1.e-4).all() + assert np.isclose(fieldset.RK45_tol, rk45_tol/(1852*60)) + print(fieldset.RK45_tol) + + def periodicfields(xdim, ydim, uvel, vvel): dimensions = {'lon': np.linspace(0., 1., xdim+1, dtype=np.float32)[1:], # don't include both 0 and 1, for periodic b.c. 'lat': np.linspace(0., 1., ydim+1, dtype=np.float32)[1:]} @@ -279,7 +299,12 @@ def fieldset_stationary(xdim=100, ydim=100, maxtime=delta(hours=6)): 'time': time} data = {'U': np.ones((xdim, ydim, 1), dtype=np.float32) * u_0 * np.cos(f * time), 'V': np.ones((xdim, ydim, 1), dtype=np.float32) * -u_0 * np.sin(f * time)} - return FieldSet.from_data(data, dimensions, mesh='flat', transpose=True) + fieldset = FieldSet.from_data(data, dimensions, mesh='flat', transpose=True) + # setting some constants for AdvectionRK45 kernel + fieldset.RK45_min_dt = 1e-3 + fieldset.RK45_max_dt = 1e2 + fieldset.RK45_tol = 1e-5 + return fieldset @pytest.mark.parametrize('mode', ['scipy', 'jit'])