diff --git a/example/reinforcement-learning/ddpg/strategies.py b/example/reinforcement-learning/ddpg/strategies.py index d73ad060cc87..eb22ddf59728 100644 --- a/example/reinforcement-learning/ddpg/strategies.py +++ b/example/reinforcement-learning/ddpg/strategies.py @@ -61,7 +61,7 @@ def reset(self): def get_action(self, obs, policy): # get_action accepts a 2D tensor with one row - obs = obs.reshape((1, -1)) + obs = obs.reshape((1, -1)) action = policy.get_action(obs) increment = self.evolve_state() @@ -94,5 +94,3 @@ def __init__(self): plt.plot(states) plt.show() - -