# main_ddpg.py
import os
import random
import time

import gym
import numpy as np
import torch

from agent_ddpg import DDPGAgent
# Initialize the environment
env = gym.make(id='Pendulum-v1', render_mode="human")
STATE_DIM = env.observation_space.shape[0]
ACTION_DIM = env.action_space.shape[0]
agent = DDPGAgent(STATE_DIM, ACTION_DIM)
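# Note: agent_ddpg.py is not shown here. Based on how the agent is used below,
# DDPGAgent is assumed to expose get_action(state), a replay_buffer with an
# add_memo(state, action, reward, next_state, done) method, an update() method
# that performs one learning step, and actor/critic networks (nn.Module) whose
# state_dicts are saved at the end of training.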
# Hyperparameters
NUM_EPISODE = 100
NUM_STEP = 200
EPSILON_START = 1.0
EPSILON_END = 0.02
EPSILON_DECAY = 10000
REWARD_BUFFER = np.empty(shape=NUM_EPISODE)

for episode_i in range(NUM_EPISODE):
    state, _ = env.reset()
    episode_reward = 0

    for step_i in range(NUM_STEP):
        # Linearly anneal epsilon from EPSILON_START to EPSILON_END over EPSILON_DECAY steps
        epsilon = np.interp(x=episode_i * NUM_STEP + step_i,
                            xp=[0, EPSILON_DECAY],
                            fp=[EPSILON_START, EPSILON_END])
        random_sample = random.random()
        if random_sample <= epsilon:
            # Explore: random action within Pendulum's torque range [-2, 2]
            action = np.random.uniform(low=-2, high=2, size=ACTION_DIM)
        else:
            # Exploit: deterministic action from the actor network
            action = agent.get_action(state)

        next_state, reward, done, truncation, info = env.step(action)
        agent.replay_buffer.add_memo(state, action, reward, next_state, done)
        state = next_state
        episode_reward += reward
        agent.update()  # Update actor and critic at every time step

        if done:
            break

    REWARD_BUFFER[episode_i] = episode_reward
    print(f"Episode: {episode_i+1}, Reward: {round(episode_reward, 2)}")

# Model save directory
current_path = os.path.dirname(os.path.realpath(__file__))
model_dir = os.path.join(current_path, 'models')
os.makedirs(model_dir, exist_ok=True)  # create the directory if it does not exist
timestamp = time.strftime("%Y%m%d%H%M%S")

# Save the actor and critic weights
torch.save(agent.actor.state_dict(), os.path.join(model_dir, f"ddpg_actor_{timestamp}.pth"))
torch.save(agent.critic.state_dict(), os.path.join(model_dir, f"ddpg_critic_{timestamp}.pth"))

env.close()
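
# The checkpoints above contain only the network weights. A minimal sketch of
# how they could be restored for evaluation later, assuming the same DDPGAgent
# constructor; the checkpoint filename below is hypothetical:
#
#   eval_agent = DDPGAgent(STATE_DIM, ACTION_DIM)
#   eval_agent.actor.load_state_dict(
#       torch.load(os.path.join(model_dir, "ddpg_actor_20240101000000.pth")))
#   eval_agent.actor.eval()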