diff --git a/tests/optimizers_test.py b/tests/optimizers_test.py index ed3f621..99d19cf 100644 --- a/tests/optimizers_test.py +++ b/tests/optimizers_test.py @@ -818,6 +818,101 @@ def control_limits(u): self.assertLess( np.linalg.norm(X[-1] - goal, ord=np.inf), constraints_threshold) + def testRandomShooting1(self): + """ + test_CEM1 + Description: + Attempts to use the Cross Entropy Method to solve the acrobot problem from "testAcrobotSolve" + """ + + T = 50 + goal = np.array([np.pi, 0.0, 0.0, 0.0]) + dynamics = euler(acrobot, dt=0.1) + + def cost(x, u, t, params): + delta = x - goal + terminal_cost = 0.5 * params[0] * np.dot(delta, delta) + stagewise_cost = 0.5 * params[1] * np.dot( + delta, delta) + 0.5 * params[2] * np.dot(u, u) + return np.where(t == T, terminal_cost, stagewise_cost) + + x0 = np.zeros(4) + U = np.zeros((T, 1)) + params = np.array([1000.0, 0.1, 0.01]) + zero_input_obj = 4959.476212 + self.assertLess( + np.abs( + optimizers.objective( + functools.partial(cost, params=params), dynamics, U, x0) - + zero_input_obj), 1e-6) + + optimal_obj = 51.0 + cem_hyperparams = frozendict({ + 'sampling_smoothing': 0.2, + 'evolution_smoothing': 0.1, + 'elite_portion': 0.1, + 'max_iter': 100, + 'num_samples': 20_000 + }) + X_opt, U_opt, obj = optimizers.random_shooting( + functools.partial(cost, params=params), + dynamics, + x0, + U, + np.array([-5.0]), np.array([5.0]), + hyperparams=cem_hyperparams, + ) + self.assertLessEqual(obj, zero_input_obj) + self.assertLessEqual(obj, 10*optimal_obj) + # Approximately 234 + + def testCEM1(self): + """ + test_CEM1 + Description: + Attempts to use the Cross Entropy Method to solve the acrobot problem from "testAcrobotSolve" + """ + + T = 50 + goal = np.array([np.pi, 0.0, 0.0, 0.0]) + dynamics = euler(acrobot, dt=0.1) + + def cost(x, u, t, params): + delta = x - goal + terminal_cost = 0.5 * params[0] * np.dot(delta, delta) + stagewise_cost = 0.5 * params[1] * np.dot( + delta, delta) + 0.5 * params[2] * np.dot(u, u) + return np.where(t == T, terminal_cost, stagewise_cost) + + x0 = np.zeros(4) + U = np.zeros((T, 1)) + params = np.array([1000.0, 0.1, 0.01]) + zero_input_obj = 4959.476212 + self.assertLess( + np.abs( + optimizers.objective( + functools.partial(cost, params=params), dynamics, U, x0) - + zero_input_obj), 1e-6) + + optimal_obj = 51.0 + cem_hyperparams = frozendict({ + 'sampling_smoothing': 0.2, + 'evolution_smoothing': 0.1, + 'elite_portion': 0.1, + 'max_iter': 100, + 'num_samples': 20_000 + }) + X_opt, U_opt, obj = optimizers.cem( + functools.partial(cost, params=params), + dynamics, + x0, + U, + np.array([-5.0]), np.array([5.0]), + hyperparams=cem_hyperparams, + ) + self.assertLessEqual(obj, zero_input_obj) + self.assertLessEqual(obj, 10*optimal_obj) + # Objective is Around 171 if __name__ == '__main__': absltest.main() diff --git a/trajax/optimizers.py b/trajax/optimizers.py index 23fc40a..5f210d2 100644 --- a/trajax/optimizers.py +++ b/trajax/optimizers.py @@ -50,6 +50,7 @@ from functools import partial # pylint: disable=g-importing-member import jax +from frozendict import frozendict from jax import custom_derivatives from jax import device_get from jax import hessian @@ -985,13 +986,13 @@ def hess_vec_prod(u, v): def default_cem_hyperparams(): - return { + return frozendict({ 'sampling_smoothing': 0., 'evolution_smoothing': 0.1, 'elite_portion': 0.1, 'max_iter': 10, 'num_samples': 400 - } + }) @partial(jit, static_argnums=(4,)) @@ -1051,7 +1052,7 @@ def body_fun(t, noises): return samples -@partial(jit, static_argnums=(0, 1)) +@partial(jit, static_argnums=(0, 1, 6, 7)) def cem(cost, dynamics, init_state, @@ -1116,7 +1117,7 @@ def loop_body(_, args): return X, mean, obj -@partial(jit, static_argnums=(0, 1)) +@partial(jit, static_argnums=(0, 1, 6, 7)) def random_shooting(cost, dynamics, init_state, @@ -1160,12 +1161,12 @@ def random_shooting(cost, obj_fn = partial(_objective, cost, dynamics) controls = gaussian_samples(random_key, mean, stdev, control_low, control_high, hyperparams) - costs = vmap(obj_fn, in_axes=(0, None))(controls, init_state) + costs = vmap(obj_fn, in_axes=(0, None, None, None))(controls, init_state, {}, {}) best_idx = np.argmin(costs) U = controls[best_idx] - X = rollout(dynamics, mean, init_state) - obj = objective(cost, dynamics, mean, init_state) + X = rollout(dynamics, U, init_state) + obj = objective(cost, dynamics, U, init_state) return X, U, obj @@ -1380,4 +1381,4 @@ def continuation_criteria(inputs): continuation_criteria, body, (X, U, dual_equality, dual_inequality, penalty, equality_constraints, inequality_constraints, np.inf, np.inf, - np.full(U.shape, np.inf), 0, 0)) + np.full(U.shape, np.inf), 0, 0)) \ No newline at end of file