From 0e413b22edc293deb9b3e9f8a5feec9f61c400d1 Mon Sep 17 00:00:00 2001 From: Kwesi Rutledge Date: Tue, 18 Jul 2023 19:19:41 -0400 Subject: [PATCH 1/2] Incorporated Small Changes (Using FrozenDict + Providing Default empty dictionaries to costs vmap) That Seem To Make Random Shooting Happy --- tests/optimizers_test.py | 45 ++++++++++++++++++++++++++++++++++++++++ trajax/optimizers.py | 11 +++++----- 2 files changed, 51 insertions(+), 5 deletions(-) diff --git a/tests/optimizers_test.py b/tests/optimizers_test.py index ed3f621..6fa0ed3 100644 --- a/tests/optimizers_test.py +++ b/tests/optimizers_test.py @@ -818,6 +818,51 @@ def control_limits(u): self.assertLess( np.linalg.norm(X[-1] - goal, ord=np.inf), constraints_threshold) + def testRandomShooting1(self): + """ + test_CEM1 + Description: + Attempts to use the Cross Entropy Method to solve the acrobot problem from "testAcrobotSolve" + """ + + T = 50 + goal = np.array([np.pi, 0.0, 0.0, 0.0]) + dynamics = euler(acrobot, dt=0.1) + + def cost(x, u, t, params): + delta = x - goal + terminal_cost = 0.5 * params[0] * np.dot(delta, delta) + stagewise_cost = 0.5 * params[1] * np.dot( + delta, delta) + 0.5 * params[2] * np.dot(u, u) + return np.where(t == T, terminal_cost, stagewise_cost) + + x0 = np.zeros(4) + U = np.zeros((T, 1)) + params = np.array([1000.0, 0.1, 0.01]) + true_obj = 4959.476212 + self.assertLess( + np.abs( + optimizers.objective( + functools.partial(cost, params=params), dynamics, U, x0) - + true_obj), 1e-6) + + # optimal_obj = 51.0 + cem_hyperparams = frozendict({ + 'sampling_smoothing': 0.2, + 'evolution_smoothing': 0.1, + 'elite_portion': 0.1, + 'max_iter': 100, + 'num_samples': 20_000 + }) + X_opt, U_opt, obj = optimizers.random_shooting( + functools.partial(cost, params=params), + dynamics, + x0, + U, + np.array([-10.0]), np.array([10.0]), + hyperparams=cem_hyperparams, + ) + self.assertAlmostEqual(obj, true_obj, places=4) if __name__ == '__main__': absltest.main() diff --git a/trajax/optimizers.py b/trajax/optimizers.py index 23fc40a..04ea8b8 100644 --- a/trajax/optimizers.py +++ b/trajax/optimizers.py @@ -50,6 +50,7 @@ from functools import partial # pylint: disable=g-importing-member import jax +from frozendict import frozendict from jax import custom_derivatives from jax import device_get from jax import hessian @@ -985,13 +986,13 @@ def hess_vec_prod(u, v): def default_cem_hyperparams(): - return { + return frozendict({ 'sampling_smoothing': 0., 'evolution_smoothing': 0.1, 'elite_portion': 0.1, 'max_iter': 10, 'num_samples': 400 - } + }) @partial(jit, static_argnums=(4,)) @@ -1116,7 +1117,7 @@ def loop_body(_, args): return X, mean, obj -@partial(jit, static_argnums=(0, 1)) +@partial(jit, static_argnums=(0, 1, 6, 7)) def random_shooting(cost, dynamics, init_state, @@ -1160,7 +1161,7 @@ def random_shooting(cost, obj_fn = partial(_objective, cost, dynamics) controls = gaussian_samples(random_key, mean, stdev, control_low, control_high, hyperparams) - costs = vmap(obj_fn, in_axes=(0, None))(controls, init_state) + costs = vmap(obj_fn, in_axes=(0, None, None, None))(controls, init_state, {}, {}) best_idx = np.argmin(costs) U = controls[best_idx] @@ -1380,4 +1381,4 @@ def continuation_criteria(inputs): continuation_criteria, body, (X, U, dual_equality, dual_inequality, penalty, equality_constraints, inequality_constraints, np.inf, np.inf, - np.full(U.shape, np.inf), 0, 0)) + np.full(U.shape, np.inf), 0, 0)) \ No newline at end of file From 72826417ba84103cf311f763b0bd983c8f7d8649 Mon Sep 17 00:00:00 2001 From: Kwesi Rutledge Date: Tue, 18 Jul 2023 19:38:35 -0400 Subject: [PATCH 2/2] Added Small Changes Which Make CEM Run + Fixed a Bug in random_shooting --- tests/optimizers_test.py | 60 ++++++++++++++++++++++++++++++++++++---- trajax/optimizers.py | 6 ++-- 2 files changed, 58 insertions(+), 8 deletions(-) diff --git a/tests/optimizers_test.py b/tests/optimizers_test.py index 6fa0ed3..99d19cf 100644 --- a/tests/optimizers_test.py +++ b/tests/optimizers_test.py @@ -839,14 +839,14 @@ def cost(x, u, t, params): x0 = np.zeros(4) U = np.zeros((T, 1)) params = np.array([1000.0, 0.1, 0.01]) - true_obj = 4959.476212 + zero_input_obj = 4959.476212 self.assertLess( np.abs( optimizers.objective( functools.partial(cost, params=params), dynamics, U, x0) - - true_obj), 1e-6) + zero_input_obj), 1e-6) - # optimal_obj = 51.0 + optimal_obj = 51.0 cem_hyperparams = frozendict({ 'sampling_smoothing': 0.2, 'evolution_smoothing': 0.1, @@ -859,10 +859,60 @@ def cost(x, u, t, params): dynamics, x0, U, - np.array([-10.0]), np.array([10.0]), + np.array([-5.0]), np.array([5.0]), + hyperparams=cem_hyperparams, + ) + self.assertLessEqual(obj, zero_input_obj) + self.assertLessEqual(obj, 10*optimal_obj) + # Approximately 234 + + def testCEM1(self): + """ + test_CEM1 + Description: + Attempts to use the Cross Entropy Method to solve the acrobot problem from "testAcrobotSolve" + """ + + T = 50 + goal = np.array([np.pi, 0.0, 0.0, 0.0]) + dynamics = euler(acrobot, dt=0.1) + + def cost(x, u, t, params): + delta = x - goal + terminal_cost = 0.5 * params[0] * np.dot(delta, delta) + stagewise_cost = 0.5 * params[1] * np.dot( + delta, delta) + 0.5 * params[2] * np.dot(u, u) + return np.where(t == T, terminal_cost, stagewise_cost) + + x0 = np.zeros(4) + U = np.zeros((T, 1)) + params = np.array([1000.0, 0.1, 0.01]) + zero_input_obj = 4959.476212 + self.assertLess( + np.abs( + optimizers.objective( + functools.partial(cost, params=params), dynamics, U, x0) - + zero_input_obj), 1e-6) + + optimal_obj = 51.0 + cem_hyperparams = frozendict({ + 'sampling_smoothing': 0.2, + 'evolution_smoothing': 0.1, + 'elite_portion': 0.1, + 'max_iter': 100, + 'num_samples': 20_000 + }) + X_opt, U_opt, obj = optimizers.cem( + functools.partial(cost, params=params), + dynamics, + x0, + U, + np.array([-5.0]), np.array([5.0]), hyperparams=cem_hyperparams, ) - self.assertAlmostEqual(obj, true_obj, places=4) + self.assertLessEqual(obj, zero_input_obj) + self.assertLessEqual(obj, 10*optimal_obj) + # Objective is Around 171 if __name__ == '__main__': absltest.main() diff --git a/trajax/optimizers.py b/trajax/optimizers.py index 04ea8b8..5f210d2 100644 --- a/trajax/optimizers.py +++ b/trajax/optimizers.py @@ -1052,7 +1052,7 @@ def body_fun(t, noises): return samples -@partial(jit, static_argnums=(0, 1)) +@partial(jit, static_argnums=(0, 1, 6, 7)) def cem(cost, dynamics, init_state, @@ -1165,8 +1165,8 @@ def random_shooting(cost, best_idx = np.argmin(costs) U = controls[best_idx] - X = rollout(dynamics, mean, init_state) - obj = objective(cost, dynamics, mean, init_state) + X = rollout(dynamics, U, init_state) + obj = objective(cost, dynamics, U, init_state) return X, U, obj