import collections
import random

import numpy as np
import torch
import torch.nn.functional as F


class ReplayBuffer:
    # Fixed-capacity FIFO buffer of (state, action, reward, next_state, done) transitions
    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        # Draw a uniform random mini-batch and unzip it into per-field tuples
        transitions = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = zip(*transitions)
        return np.array(state), action, reward, np.array(next_state), done

    def size(self):
        return len(self.buffer)
class DDPG:
    def __init__(self, PolicyNet, QValueNet, sigma, actor_lr, critic_lr,
                 tau, gamma, action_dim, device):
        self.actor = PolicyNet().to(device)
        self.critic = QValueNet().to(device)
        # Target networks start as exact copies of the online networks
        self.target_actor = PolicyNet().to(device)
        self.target_critic = QValueNet().to(device)
        self.target_critic.load_state_dict(self.critic.state_dict())
        self.target_actor.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.gamma = gamma            # discount factor
        self.sigma = sigma            # std of the Gaussian exploration noise
        self.tau = tau                # soft-update coefficient for the target networks
        self.action_dim = action_dim
        self.device = device

    def sample(self, state, pre=False):
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        action = self.actor(state).item()
        # During training (pre=False) add Gaussian noise for exploration;
        # with pre=True return the deterministic policy action.
        action = action if pre else action + self.sigma * np.random.randn(self.action_dim)
        return action

    def soft_update(self, net, target_net):
        # Polyak averaging: target <- (1 - tau) * target + tau * online
        for param_target, param in zip(target_net.parameters(), net.parameters()):
            param_target.data.copy_(param_target.data * (1.0 - self.tau) + param.data * self.tau)

    def update(self, s, a, r, s_t, dones):
        states = torch.tensor(s, dtype=torch.float).to(self.device)
        actions = torch.tensor(a, dtype=torch.float).view(-1, 1).to(self.device)
        rewards = torch.tensor(r, dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(s_t, dtype=torch.float).to(self.device)
        dones = torch.tensor(dones, dtype=torch.float).view(-1, 1).to(self.device)

        # TD target: r + gamma * Q'(s', mu'(s')), cut off at terminal states
        next_q_values = self.target_critic(next_states, self.target_actor(next_states))
        q_targets = rewards + self.gamma * next_q_values * (1 - dones)

        # Critic update: regress Q(s, a) toward the TD target
        critic_loss = torch.mean(F.mse_loss(self.critic(states, actions), q_targets))
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Actor update: maximize Q(s, mu(s)) by minimizing its negative
        actor_loss = -torch.mean(self.critic(states, self.actor(states)))
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Slowly track the online networks with the targets
        self.soft_update(self.actor, self.target_actor)
        self.soft_update(self.critic, self.target_critic)
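Below is a minimal training-loop sketch showing how the two classes fit together. Nothing in it comes from the listing above: the Actor and Critic network definitions, the Pendulum-v1 environment, and all hyperparameters (sigma, learning rates, tau, gamma, buffer and batch sizes) are illustrative assumptions, not part of the original implementation.

import gym
import torch
import torch.nn as nn

state_dim, hidden_dim, action_dim = 3, 64, 1  # assumed sizes for Pendulum-v1


class Actor(nn.Module):
    # Hypothetical deterministic policy network mu(s)
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(state_dim, hidden_dim), nn.ReLU(),
                                 nn.Linear(hidden_dim, action_dim), nn.Tanh())

    def forward(self, s):
        return 2.0 * self.net(s)  # scale tanh output to Pendulum's action range [-2, 2]


class Critic(nn.Module):
    # Hypothetical Q network over concatenated (state, action) inputs
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(state_dim + action_dim, hidden_dim), nn.ReLU(),
                                 nn.Linear(hidden_dim, 1))

    def forward(self, s, a):
        return self.net(torch.cat([s, a], dim=1))


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
buffer = ReplayBuffer(capacity=10000)
agent = DDPG(Actor, Critic, sigma=0.1, actor_lr=3e-4, critic_lr=3e-3,
             tau=0.005, gamma=0.98, action_dim=action_dim, device=device)

env = gym.make('Pendulum-v1')
for episode in range(200):
    state, done = env.reset(), False        # classic gym API; newer gym returns (obs, info)
    while not done:
        action = agent.sample(state)         # noisy action for exploration
        next_state, reward, done, _ = env.step(action)
        buffer.add(state, action, reward, next_state, done)
        state = next_state
        if buffer.size() > 1000:             # warm up the buffer before updating
            s, a, r, s_t, d = buffer.sample(batch_size=64)
            agent.update(s, a, r, s_t, d)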