Actor-Critic算法实战用PyTorch搞定CartPole平衡问题附完整代码在强化学习领域Actor-Critic算法因其独特的双网络结构和稳定的训练特性而备受关注。本文将带您从零开始通过PyTorch框架实现一个完整的Actor-Critic模型解决经典的CartPole平衡问题。不同于简单的算法介绍我们将重点关注工程实现中的关键细节和常见陷阱。1. 环境准备与算法基础CartPole是OpenAI Gym中最经典的测试环境之一目标是通过控制小车的左右移动来保持杆子竖直不倒。这个看似简单的任务实际上包含了强化学习的核心挑战如何在连续的状态空间中做出离散的动作决策。Actor-Critic算法的核心思想是将策略学习和价值评估分离Actor演员负责根据当前状态选择动作Critic评论家评估当前状态的价值指导Actor的更新与纯策略梯度方法相比Actor-Critic的主要优势在于通过Critic的引导减少方差加速收敛可以处理连续动作空间适用于部分可观测环境import gym import torch import torch.nn as nn import torch.nn.functional as F import numpy as np # 初始化环境 env gym.make(CartPole-v1) state_dim env.observation_space.shape[0] action_dim env.action_space.n2. 网络架构设计2.1 Actor网络实现Actor网络输出动作的概率分布我们使用简单的两层全连接网络class Actor(nn.Module): def __init__(self, state_dim, action_dim, hidden_size64): super(Actor, self).__init__() self.fc1 nn.Linear(state_dim, hidden_size) self.fc2 nn.Linear(hidden_size, action_dim) def forward(self, state): x F.relu(self.fc1(state)) return F.softmax(self.fc2(x), dim-1)注意输出层使用softmax确保动作概率和为12.2 Critic网络实现Critic网络评估状态价值输出单个标量值class Critic(nn.Module): def __init__(self, state_dim, hidden_size64): super(Critic, self).__init__() self.fc1 nn.Linear(state_dim, hidden_size) self.fc2 nn.Linear(hidden_size, 1) def forward(self, state): x F.relu(self.fc1(state)) return self.fc2(x)3. 核心算法实现3.1 TD误差计算时间差分(TD)误差是Actor-Critic算法的核心计算公式为TD_error r γ * V(s) - V(s)其中r即时奖励γ折扣因子(0.9-0.99)V(s)下一状态价值V(s)当前状态价值def compute_td_error(reward, state_value, next_state_value, gamma0.99, doneFalse): if done: td_error reward - state_value else: td_error reward gamma * next_state_value - state_value return td_error3.2 策略梯度更新Actor的更新使用策略梯度定理∇J(θ) E[∇logπ(a|s) * A(s,a)]其中A(s,a)是优势函数这里用TD误差近似。def update_actor(actor, critic, optimizer, state, action, td_error): optimizer.zero_grad() action_probs actor(state) selected_action_prob action_probs.gather(1, action.unsqueeze(1)) loss -torch.log(selected_action_prob) * td_error.detach() loss.mean().backward() optimizer.step()关键点必须对td_error使用detach()避免Critic梯度影响Actor更新4. 完整训练流程4.1 超参数设置参数推荐值说明LR_ACTOR1e-3Actor学习率LR_CRITIC1e-2Critic学习率GAMMA0.99折扣因子MAX_EPISODES1000最大训练回合数MAX_STEPS500每回合最大步数4.2 训练主循环def train(): actor Actor(state_dim, action_dim).to(device) critic Critic(state_dim).to(device) actor_optimizer torch.optim.Adam(actor.parameters(), lrLR_ACTOR) critic_optimizer torch.optim.Adam(critic.parameters(), lrLR_CRITIC) for episode in range(MAX_EPISODES): state env.reset() episode_reward 0 for step in range(MAX_STEPS): # 选择动作 state_tensor torch.FloatTensor(state).unsqueeze(0).to(device) action_probs actor(state_tensor) action torch.multinomial(action_probs, 1).item() # 执行动作 next_state, reward, done, _ env.step(action) episode_reward reward # 转换为tensor next_state_tensor torch.FloatTensor(next_state).unsqueeze(0).to(device) reward_tensor torch.FloatTensor([reward]).to(device) # 计算TD误差 state_value critic(state_tensor) next_state_value critic(next_state_tensor) td_error compute_td_error(reward_tensor, state_value, next_state_value, GAMMA, done) # 更新Critic critic_loss td_error.pow(2).mean() critic_optimizer.zero_grad() critic_loss.backward() critic_optimizer.step() # 更新Actor action_tensor torch.LongTensor([action]).to(device) update_actor(actor, critic, actor_optimizer, state_tensor, action_tensor, td_error) state next_state if done: break5. 调试技巧与性能优化在实际项目中Actor-Critic算法可能会遇到以下问题训练不稳定尝试降低学习率使用学习率衰减策略增加批处理大小收敛速度慢调整网络结构增加层数或神经元数量尝试不同的激活函数使用正交初始化实现中的常见错误忘记对td_error使用detach()Critic输出范围不合适没有正确处理回合结束状态# 正交初始化示例 def weights_init(m): if isinstance(m, nn.Linear): nn.init.orthogonal_(m.weight.data) nn.init.constant_(m.bias.data, 0.01) actor.apply(weights_init) critic.apply(weights_init)在CartPole环境中一个训练良好的Actor-Critic模型通常能在200-300回合内达到最大奖励500分。如果您的模型性能不佳可以尝试以下调整将Critic的学习率设为Actor的10倍使用更深的网络结构如3层引入经验回放机制添加熵正则化项防止过早收敛
Actor-Critic算法实战:用PyTorch搞定CartPole平衡问题(附完整代码)
Actor-Critic算法实战用PyTorch搞定CartPole平衡问题附完整代码在强化学习领域Actor-Critic算法因其独特的双网络结构和稳定的训练特性而备受关注。本文将带您从零开始通过PyTorch框架实现一个完整的Actor-Critic模型解决经典的CartPole平衡问题。不同于简单的算法介绍我们将重点关注工程实现中的关键细节和常见陷阱。1. 环境准备与算法基础CartPole是OpenAI Gym中最经典的测试环境之一目标是通过控制小车的左右移动来保持杆子竖直不倒。这个看似简单的任务实际上包含了强化学习的核心挑战如何在连续的状态空间中做出离散的动作决策。Actor-Critic算法的核心思想是将策略学习和价值评估分离Actor演员负责根据当前状态选择动作Critic评论家评估当前状态的价值指导Actor的更新与纯策略梯度方法相比Actor-Critic的主要优势在于通过Critic的引导减少方差加速收敛可以处理连续动作空间适用于部分可观测环境import gym import torch import torch.nn as nn import torch.nn.functional as F import numpy as np # 初始化环境 env gym.make(CartPole-v1) state_dim env.observation_space.shape[0] action_dim env.action_space.n2. 网络架构设计2.1 Actor网络实现Actor网络输出动作的概率分布我们使用简单的两层全连接网络class Actor(nn.Module): def __init__(self, state_dim, action_dim, hidden_size64): super(Actor, self).__init__() self.fc1 nn.Linear(state_dim, hidden_size) self.fc2 nn.Linear(hidden_size, action_dim) def forward(self, state): x F.relu(self.fc1(state)) return F.softmax(self.fc2(x), dim-1)注意输出层使用softmax确保动作概率和为12.2 Critic网络实现Critic网络评估状态价值输出单个标量值class Critic(nn.Module): def __init__(self, state_dim, hidden_size64): super(Critic, self).__init__() self.fc1 nn.Linear(state_dim, hidden_size) self.fc2 nn.Linear(hidden_size, 1) def forward(self, state): x F.relu(self.fc1(state)) return self.fc2(x)3. 核心算法实现3.1 TD误差计算时间差分(TD)误差是Actor-Critic算法的核心计算公式为TD_error r γ * V(s) - V(s)其中r即时奖励γ折扣因子(0.9-0.99)V(s)下一状态价值V(s)当前状态价值def compute_td_error(reward, state_value, next_state_value, gamma0.99, doneFalse): if done: td_error reward - state_value else: td_error reward gamma * next_state_value - state_value return td_error3.2 策略梯度更新Actor的更新使用策略梯度定理∇J(θ) E[∇logπ(a|s) * A(s,a)]其中A(s,a)是优势函数这里用TD误差近似。def update_actor(actor, critic, optimizer, state, action, td_error): optimizer.zero_grad() action_probs actor(state) selected_action_prob action_probs.gather(1, action.unsqueeze(1)) loss -torch.log(selected_action_prob) * td_error.detach() loss.mean().backward() optimizer.step()关键点必须对td_error使用detach()避免Critic梯度影响Actor更新4. 完整训练流程4.1 超参数设置参数推荐值说明LR_ACTOR1e-3Actor学习率LR_CRITIC1e-2Critic学习率GAMMA0.99折扣因子MAX_EPISODES1000最大训练回合数MAX_STEPS500每回合最大步数4.2 训练主循环def train(): actor Actor(state_dim, action_dim).to(device) critic Critic(state_dim).to(device) actor_optimizer torch.optim.Adam(actor.parameters(), lrLR_ACTOR) critic_optimizer torch.optim.Adam(critic.parameters(), lrLR_CRITIC) for episode in range(MAX_EPISODES): state env.reset() episode_reward 0 for step in range(MAX_STEPS): # 选择动作 state_tensor torch.FloatTensor(state).unsqueeze(0).to(device) action_probs actor(state_tensor) action torch.multinomial(action_probs, 1).item() # 执行动作 next_state, reward, done, _ env.step(action) episode_reward reward # 转换为tensor next_state_tensor torch.FloatTensor(next_state).unsqueeze(0).to(device) reward_tensor torch.FloatTensor([reward]).to(device) # 计算TD误差 state_value critic(state_tensor) next_state_value critic(next_state_tensor) td_error compute_td_error(reward_tensor, state_value, next_state_value, GAMMA, done) # 更新Critic critic_loss td_error.pow(2).mean() critic_optimizer.zero_grad() critic_loss.backward() critic_optimizer.step() # 更新Actor action_tensor torch.LongTensor([action]).to(device) update_actor(actor, critic, actor_optimizer, state_tensor, action_tensor, td_error) state next_state if done: break5. 调试技巧与性能优化在实际项目中Actor-Critic算法可能会遇到以下问题训练不稳定尝试降低学习率使用学习率衰减策略增加批处理大小收敛速度慢调整网络结构增加层数或神经元数量尝试不同的激活函数使用正交初始化实现中的常见错误忘记对td_error使用detach()Critic输出范围不合适没有正确处理回合结束状态# 正交初始化示例 def weights_init(m): if isinstance(m, nn.Linear): nn.init.orthogonal_(m.weight.data) nn.init.constant_(m.bias.data, 0.01) actor.apply(weights_init) critic.apply(weights_init)在CartPole环境中一个训练良好的Actor-Critic模型通常能在200-300回合内达到最大奖励500分。如果您的模型性能不佳可以尝试以下调整将Critic的学习率设为Actor的10倍使用更深的网络结构如3层引入经验回放机制添加熵正则化项防止过早收敛