# 用Python和TensorFlow训练AI玩贪吃蛇：从游戏逻辑到DQN算法实战

贪吃蛇这个经典游戏几乎每个人都玩过。但你是否想过，让AI来玩这个游戏会是什么样子？本文将带你从零开始，用Python和TensorFlow构建一个能够自主玩贪吃蛇的AI系统。不同于简单的规则式AI，我们将使用深度强化学习中的DQN算法，让AI真正学会如何玩这个游戏。

## 1. 项目准备与环境搭建

在开始编码之前，我们需要准备好开发环境。这个项目需要以下几个主要组件：

- Python 3.7或更高版本
- Pygame：用于游戏界面
- TensorFlow 2.x：用于构建和训练神经网络
- NumPy：用于数值计算

安装这些依赖非常简单，只需在命令行中执行以下命令：

```bash
pip install pygame tensorflow numpy
```

对于硬件要求，虽然可以在CPU上运行，但如果有NVIDIA显卡并安装了CUDA，训练速度会显著提升。建议至少4GB内存，因为神经网络训练过程会比较消耗资源。

项目目录结构建议如下：

```
/snake_ai
    /game
        __init__.py
        snake.py      # 游戏逻辑
        render.py     # 游戏渲染
    /rl
        __init__.py
        dqn.py        # DQN算法实现
        memory.py     # 经验回放缓冲区
    config.py         # 配置文件
    train.py          # 训练脚本
    play.py           # 人类游玩脚本
```

## 2. 贪吃蛇游戏逻辑实现

首先，我们需要构建贪吃蛇游戏的基本框架。使用Pygame可以方便地创建游戏窗口和处理用户输入。

### 2.1 游戏核心类设计

我们创建三个主要类：Snake、Food和Game。下面是Snake类的核心代码：

```python
class Snake:
    def __init__(self, block_size=20, width=800, height=600):
        self.length = 3
        self.positions = [(width // 2, height // 2)]
        self.direction = random.choice([(0, 1), (0, -1), (1, 0), (-1, 0)])
        self.block_size = block_size
        self.width = width
        self.height = height
        self.color = (0, 255, 0)  # 绿色

    def get_head_position(self):
        return self.positions[0]

    def turn(self, new_direction):
        # 防止180度转弯
        if (new_direction[0] * -1, new_direction[1] * -1) != self.direction:
            self.direction = new_direction

    def move(self):
        head = self.get_head_position()
        x, y = self.direction
        new_x = (head[0] + (x * self.block_size)) % self.width
        new_y = (head[1] + (y * self.block_size)) % self.height
        new_position = (new_x, new_y)
        self.positions.insert(0, new_position)
        if len(self.positions) > self.length:
            self.positions.pop()

    def reset(self):
        self.length = 3
        self.positions = [(self.width // 2, self.height // 2)]
        self.direction = random.choice([(0, 1), (0, -1), (1, 0), (-1, 0)])

    def draw(self, surface):
        for p in self.positions:
            rect = pygame.Rect((p[0], p[1]), (self.block_size, self.block_size))
            pygame.draw.rect(surface, self.color, rect)
            pygame.draw.rect(surface, (0, 0, 0), rect, 1)
```

### 2.2 游戏主循环

游戏主循环负责处理输入、更新游戏状态和渲染画面：

```python
class Game:
    def __init__(self, width=800, height=600, block_size=20):
        pygame.init()
        self.screen = pygame.display.set_mode((width, height))
        self.clock = pygame.time.Clock()
        self.snake = Snake(block_size, width, height)
        self.food = Food(block_size, width, height)
        self.width = width
        self.height = height
        self.block_size = block_size
        self.score = 0

    def run(self):
        running = True
        while running:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                elif event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_UP:
                        self.snake.turn((0, -1))
                    elif event.key == pygame.K_DOWN:
                        self.snake.turn((0, 1))
                    elif event.key == pygame.K_LEFT:
                        self.snake.turn((-1, 0))
                    elif event.key == pygame.K_RIGHT:
                        self.snake.turn((1, 0))

            self.snake.move()

            # 检测是否吃到食物
            if self.snake.get_head_position() == self.food.position:
                self.snake.length += 1
                self.score += 1
                self.food = Food(self.block_size, self.width, self.height)

            # 检测碰撞
            if self.snake.get_head_position() in self.snake.positions[1:]:
                print(f'Game Over! Score: {self.score}')
                self.snake.reset()
                self.score = 0

            # 渲染
            self.screen.fill((255, 255, 255))
            self.snake.draw(self.screen)
            self.food.draw(self.screen)
            pygame.display.update()
            self.clock.tick(10)  # 控制游戏速度

        pygame.quit()
```

## 3. DQN算法原理与实现

深度Q网络（DQN）是强化学习中的一种重要算法，它结合了Q-learning和深度神经网络的优点。

### 3.1 DQN核心概念

DQN的核心思想是使用神经网络来近似Q函数，即状态-动作值函数。Q函数表示在某个状态下采取某个动作所能获得的预期回报。

DQN有几个关键组件：

- **经验回放（Experience Replay）**：将智能体的经验（状态, 动作, 奖励, 新状态）存储在记忆库中，训练时从中随机采样，以打破数据间的相关性。
- **目标网络（Target Network）**：使用一个独立的网络来计算目标Q值，提高训练稳定性。
- **ε-贪婪策略（ε-Greedy Policy）**：在探索和利用之间取得平衡，开始时更多探索，之后逐渐增加利用。

训练时，被执行动作的目标值为 $y = r + \gamma \max_{a'} Q_{\text{target}}(s', a')$；若该步导致游戏结束，则 $y = r$。这正是下面`replay`方法中`targets`一行的计算方式。

### 3.2 DQN实现代码

下面是DQN的核心实现：

```python
import numpy as np
import tensorflow as tf
from collections import deque
import random

class DQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000)
        self.gamma = 0.95  # 折扣因子
        self.epsilon = 1.0  # 探索率
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.learning_rate = 0.001
        self.model = self._build_model()
        self.target_model = self._build_model()
        self.update_target_model()

    def _build_model(self):
        model = tf.keras.Sequential()
        model.add(tf.keras.layers.Dense(24, input_dim=self.state_size, activation='relu'))
        model.add(tf.keras.layers.Dense(24, activation='relu'))
        model.add(tf.keras.layers.Dense(self.action_size, activation='linear'))
        model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(learning_rate=self.learning_rate))
        return model

    def update_target_model(self):
        self.target_model.set_weights(self.model.get_weights())

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        act_values = self.model.predict(state)
        return np.argmax(act_values[0])

    def replay(self, batch_size):
        if len(self.memory) < batch_size:
            return
        minibatch = random.sample(self.memory, batch_size)
        states = np.array([i[0] for i in minibatch])
        actions = np.array([i[1] for i in minibatch])
        rewards = np.array([i[2] for i in minibatch])
        next_states = np.array([i[3] for i in minibatch])
        dones = np.array([i[4] for i in minibatch])

        states = np.squeeze(states)
        next_states = np.squeeze(next_states)

        targets = rewards + self.gamma * (np.amax(self.target_model.predict_on_batch(next_states), axis=1)) * (1 - dones)
        targets_full = self.model.predict_on_batch(states)

        ind = np.array([i for i in range(batch_size)])
        targets_full[[ind], [actions]] = targets

        self.model.fit(states, targets_full, epochs=1, verbose=0)
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def load(self, name):
        self.model.load_weights(name)

    def save(self, name):
        self.model.save_weights(name)
```

> 注：在较新的TensorFlow版本（使用Keras 3）中，`Adam`已不再接受旧参数名`lr`，`save_weights`也要求文件名以`.weights.h5`结尾。因此本文代码统一使用`learning_rate`参数和`.weights.h5`后缀，这两种写法在旧版TensorFlow 2.x中同样可用。

## 4. 训练AI玩贪吃蛇

现在，我们将游戏环境和DQN算法结合起来，训练AI玩贪吃蛇。

### 4.1 状态表示

我们需要定义如何将游戏状态表示为神经网络可以理解的输入。对于贪吃蛇游戏，状态可以包括：

- 相对蛇的当前方向，直行、右转、左转三个方向上是否有障碍（蛇身或墙壁）
- 食物相对于蛇头的位置（左/右/上/下）
- 蛇当前的移动方向

```python
def get_state(self):
    head = self.snake.get_head_position()
    food = self.food.position

    # 计算四个方向的点
    point_l = (head[0] - self.block_size, head[1])
    point_r = (head[0] + self.block_size, head[1])
    point_u = (head[0], head[1] - self.block_size)
    point_d = (head[0], head[1] + self.block_size)

    # 当前移动方向
    dir_l = self.snake.direction == (-1, 0)
    dir_r = self.snake.direction == (1, 0)
    dir_u = self.snake.direction == (0, -1)
    dir_d = self.snake.direction == (0, 1)

    state = [
        # 危险：直行
        (dir_r and self.is_collision(point_r)) or
        (dir_l and self.is_collision(point_l)) or
        (dir_u and self.is_collision(point_u)) or
        (dir_d and self.is_collision(point_d)),

        # 危险：右转
        (dir_u and self.is_collision(point_r)) or
        (dir_d and self.is_collision(point_l)) or
        (dir_l and self.is_collision(point_u)) or
        (dir_r and self.is_collision(point_d)),

        # 危险：左转
        (dir_d and self.is_collision(point_r)) or
        (dir_u and self.is_collision(point_l)) or
        (dir_r and self.is_collision(point_u)) or
        (dir_l and self.is_collision(point_d)),

        # 移动方向
        dir_l,
        dir_r,
        dir_u,
        dir_d,

        # 食物位置
        food[0] < head[0],  # 食物在左
        food[0] > head[0],  # 食物在右
        food[1] < head[1],  # 食物在上
        food[1] > head[1]   # 食物在下
    ]

    return np.array(state, dtype=int)
```

### 4.2 奖励函数设计

奖励函数是强化学习中最关键的部分之一，它告诉AI什么是好的行为、什么是坏的行为。对于贪吃蛇游戏，我们可以设计如下奖励：

- 吃到食物：+10
- 撞到自己或墙壁：-10
- 靠近食物：+1
- 远离食物：-1
- 每移动一步：-0.1（鼓励高效）

```python
def get_reward(self, snake, food, done):
    if done:
        return -10

    if snake.get_head_position() == food.position:
        return 10

    # 计算与食物的距离
    head = snake.get_head_position()
    food_pos = food.position
    new_dist = abs(head[0] - food_pos[0]) + abs(head[1] - food_pos[1])

    # 如果距离减小，给予正奖励；否则给予负奖励
    if new_dist < self.prev_distance:
        reward = 1
    else:
        reward = -1

    self.prev_distance = new_dist

    # 每步的小惩罚
    reward -= 0.1

    return reward
```

注意：上面的代码依赖`self.prev_distance`。需要在每局开始时（例如在`Game.reset()`中），以及生成新食物之后，将其设为蛇头与食物之间的曼哈顿距离。

### 4.3 训练过程

训练过程主要包括以下步骤：

1. 初始化环境和智能体
2. 获取当前状态
3. 智能体选择动作
4. 执行动作，获取新状态和奖励
5. 存储经验到记忆库
6. 训练智能体
7. 定期更新目标网络

```python
def train():
    pygame.init()
    width, height, block_size = 800, 600, 20
    game = Game(width, height, block_size)
    agent = DQNAgent(state_size=11, action_size=3)  # 3个动作：直行、右转、左转
    episodes = 1000
    batch_size = 32

    for e in range(episodes):
        game.reset()
        state = game.get_state()
        state = np.reshape(state, [1, 11])
        total_reward = 0

        while True:
            action = agent.act(state)

            # 执行动作
            if action == 0:  # 直行
                pass
            elif action == 1:  # 右转
                if game.snake.direction == (0, -1):
                    game.snake.turn((1, 0))
                elif game.snake.direction == (1, 0):
                    game.snake.turn((0, 1))
                elif game.snake.direction == (0, 1):
                    game.snake.turn((-1, 0))
                elif game.snake.direction == (-1, 0):
                    game.snake.turn((0, -1))
            elif action == 2:  # 左转
                if game.snake.direction == (0, -1):
                    game.snake.turn((-1, 0))
                elif game.snake.direction == (-1, 0):
                    game.snake.turn((0, 1))
                elif game.snake.direction == (0, 1):
                    game.snake.turn((1, 0))
                elif game.snake.direction == (1, 0):
                    game.snake.turn((0, -1))

            game.snake.move()

            # 检查游戏状态
            done = False
            if game.snake.get_head_position() in game.snake.positions[1:]:
                done = True

            # 获取奖励：必须在生成新食物之前计算，
            # 否则蛇头已不在（新）食物位置上，吃到食物的+10奖励永远不会触发
            reward = game.get_reward(game.snake, game.food, done)
            total_reward += reward

            # 检查是否吃到食物
            if game.snake.get_head_position() == game.food.position:
                game.snake.length += 1
                game.food = Food(block_size, width, height)

            # 获取新状态
            next_state = game.get_state()
            next_state = np.reshape(next_state, [1, 11])

            # 存储经验
            agent.remember(state, action, reward, next_state, done)
            state = next_state

            if done:
                print(f'Episode: {e}/{episodes}, Score: {game.snake.length}, Total reward: {total_reward}, Epsilon: {agent.epsilon:.2f}')
                break

            if len(agent.memory) > batch_size:
                agent.replay(batch_size)

        # 定期更新目标网络
        if e % 10 == 0:
            agent.update_target_model()

        # 定期保存模型
        if e % 100 == 0:
            agent.save(f'snake_dqn_{e}.weights.h5')

    agent.save('snake_dqn_final.weights.h5')
```

注意：`Snake.move()`使用取模运算实现“穿墙”，训练循环中也只检测了撞到自身，因此按上述代码蛇不会因撞墙而死亡。如果希望撞墙结束游戏，需要去掉取模运算，并在`done`的判断中加入边界检测。

## 5. 调优与改进

训练过程中，你可能会遇到AI表现不佳的情况。以下是几个常见的调优方向。

### 5.1 奖励函数调整

奖励函数的设计对训练效果影响巨大。可以尝试以下调整：

- 增加对长时间存活的奖励
- 调整靠近/远离食物的奖励幅度
- 增加对形成循环移动的惩罚

### 5.2 网络结构优化

可以尝试更复杂的网络结构：

```python
def _build_model(self):
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, input_dim=self.state_size, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(self.action_size, activation='linear')
    ])
    model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(learning_rate=self.learning_rate))
    return model
```

### 5.3 训练参数调整

关键训练参数包括：

| 参数 | 建议值 | 说明 |
| --- | --- | --- |
| γ (gamma) | 0.9-0.99 | 折扣因子，越大表示越重视长期奖励 |
| ε (epsilon) | 1.0→0.01 | 探索率，初始高探索，逐渐降低 |
| ε衰减 | 0.995 | 控制探索率降低速度 |
| 学习率 | 0.0001-0.001 | 影响权重更新幅度 |
| 批次大小 | 32-64 | 每次训练的样本数量 |
| 记忆容量 | 1000-10000 | 经验回放缓冲区大小 |

### 5.4 高级技巧

- **双DQN（Double DQN）**：使用两个网络分别选择动作和评估动作，减少Q值被过高估计的问题。
- **优先级经验回放（Prioritized Experience Replay）**：给重要的经验样本更高的采样概率。
- **决斗网络架构（Dueling Network）**：将Q值分解为状态值和优势函数。

双DQN用在线网络选择下一步动作、用目标网络评估该动作：$y = r + \gamma\, Q_{\text{target}}\big(s', \arg\max_{a'} Q(s', a')\big)$。实现双DQN只需修改replay方法：

```python
def replay(self, batch_size):
    if len(self.memory) < batch_size:
        return
    minibatch = random.sample(self.memory, batch_size)
    states = np.array([i[0] for i in minibatch])
    actions = np.array([i[1] for i in minibatch])
    rewards = np.array([i[2] for i in minibatch])
    next_states = np.array([i[3] for i in minibatch])
    dones = np.array([i[4] for i in minibatch])

    states = np.squeeze(states)
    next_states = np.squeeze(next_states)

    # 双DQN修改部分
    next_actions = np.argmax(self.model.predict_on_batch(next_states), axis=1)
    q_values_next = self.target_model.predict_on_batch(next_states)
    targets = rewards + self.gamma * q_values_next[np.arange(batch_size), next_actions] * (1 - dones)

    targets_full = self.model.predict_on_batch(states)
    ind = np.array([i for i in range(batch_size)])
    targets_full[[ind], [actions]] = targets

    self.model.fit(states, targets_full, epochs=1, verbose=0)
    if self.epsilon > self.epsilon_min:
        self.epsilon *= self.epsilon_decay
```
用Python和TensorFlow训练AI玩贪吃蛇:从游戏逻辑到DQN算法实战(附完整代码)
# 用Python和TensorFlow训练AI玩贪吃蛇：从游戏逻辑到DQN算法实战

贪吃蛇这个经典游戏几乎每个人都玩过。但你是否想过，让AI来玩这个游戏会是什么样子？本文将带你从零开始，用Python和TensorFlow构建一个能够自主玩贪吃蛇的AI系统。不同于简单的规则式AI，我们将使用深度强化学习中的DQN算法，让AI真正学会如何玩这个游戏。

## 1. 项目准备与环境搭建

在开始编码之前，我们需要准备好开发环境。这个项目需要以下几个主要组件：

- Python 3.7或更高版本
- Pygame：用于游戏界面
- TensorFlow 2.x：用于构建和训练神经网络
- NumPy：用于数值计算

安装这些依赖非常简单，只需在命令行中执行以下命令：

```bash
pip install pygame tensorflow numpy
```

对于硬件要求，虽然可以在CPU上运行，但如果有NVIDIA显卡并安装了CUDA，训练速度会显著提升。建议至少4GB内存，因为神经网络训练过程会比较消耗资源。

项目目录结构建议如下：

```
/snake_ai
    /game
        __init__.py
        snake.py      # 游戏逻辑
        render.py     # 游戏渲染
    /rl
        __init__.py
        dqn.py        # DQN算法实现
        memory.py     # 经验回放缓冲区
    config.py         # 配置文件
    train.py          # 训练脚本
    play.py           # 人类游玩脚本
```

## 2. 贪吃蛇游戏逻辑实现

首先，我们需要构建贪吃蛇游戏的基本框架。使用Pygame可以方便地创建游戏窗口和处理用户输入。

### 2.1 游戏核心类设计

我们创建三个主要类：Snake、Food和Game。下面是Snake类的核心代码：

```python
class Snake:
    def __init__(self, block_size=20, width=800, height=600):
        self.length = 3
        self.positions = [(width // 2, height // 2)]
        self.direction = random.choice([(0, 1), (0, -1), (1, 0), (-1, 0)])
        self.block_size = block_size
        self.width = width
        self.height = height
        self.color = (0, 255, 0)  # 绿色

    def get_head_position(self):
        return self.positions[0]

    def turn(self, new_direction):
        # 防止180度转弯
        if (new_direction[0] * -1, new_direction[1] * -1) != self.direction:
            self.direction = new_direction

    def move(self):
        head = self.get_head_position()
        x, y = self.direction
        new_x = (head[0] + (x * self.block_size)) % self.width
        new_y = (head[1] + (y * self.block_size)) % self.height
        new_position = (new_x, new_y)
        self.positions.insert(0, new_position)
        if len(self.positions) > self.length:
            self.positions.pop()

    def reset(self):
        self.length = 3
        self.positions = [(self.width // 2, self.height // 2)]
        self.direction = random.choice([(0, 1), (0, -1), (1, 0), (-1, 0)])

    def draw(self, surface):
        for p in self.positions:
            rect = pygame.Rect((p[0], p[1]), (self.block_size, self.block_size))
            pygame.draw.rect(surface, self.color, rect)
            pygame.draw.rect(surface, (0, 0, 0), rect, 1)
```

### 2.2 游戏主循环

游戏主循环负责处理输入、更新游戏状态和渲染画面：

```python
class Game:
    def __init__(self, width=800, height=600, block_size=20):
        pygame.init()
        self.screen = pygame.display.set_mode((width, height))
        self.clock = pygame.time.Clock()
        self.snake = Snake(block_size, width, height)
        self.food = Food(block_size, width, height)
        self.width = width
        self.height = height
        self.block_size = block_size
        self.score = 0

    def run(self):
        running = True
        while running:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                elif event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_UP:
                        self.snake.turn((0, -1))
                    elif event.key == pygame.K_DOWN:
                        self.snake.turn((0, 1))
                    elif event.key == pygame.K_LEFT:
                        self.snake.turn((-1, 0))
                    elif event.key == pygame.K_RIGHT:
                        self.snake.turn((1, 0))

            self.snake.move()

            # 检测是否吃到食物
            if self.snake.get_head_position() == self.food.position:
                self.snake.length += 1
                self.score += 1
                self.food = Food(self.block_size, self.width, self.height)

            # 检测碰撞
            if self.snake.get_head_position() in self.snake.positions[1:]:
                print(f'Game Over! Score: {self.score}')
                self.snake.reset()
                self.score = 0

            # 渲染
            self.screen.fill((255, 255, 255))
            self.snake.draw(self.screen)
            self.food.draw(self.screen)
            pygame.display.update()
            self.clock.tick(10)  # 控制游戏速度

        pygame.quit()
```

## 3. DQN算法原理与实现

深度Q网络（DQN）是强化学习中的一种重要算法，它结合了Q-learning和深度神经网络的优点。

### 3.1 DQN核心概念

DQN的核心思想是使用神经网络来近似Q函数，即状态-动作值函数。Q函数表示在某个状态下采取某个动作所能获得的预期回报。

DQN有几个关键组件：

- **经验回放（Experience Replay）**：将智能体的经验（状态, 动作, 奖励, 新状态）存储在记忆库中，训练时从中随机采样，以打破数据间的相关性。
- **目标网络（Target Network）**：使用一个独立的网络来计算目标Q值，提高训练稳定性。
- **ε-贪婪策略（ε-Greedy Policy）**：在探索和利用之间取得平衡，开始时更多探索，之后逐渐增加利用。

训练时，被执行动作的目标值为 $y = r + \gamma \max_{a'} Q_{\text{target}}(s', a')$；若该步导致游戏结束，则 $y = r$。这正是下面`replay`方法中`targets`一行的计算方式。

### 3.2 DQN实现代码

下面是DQN的核心实现：

```python
import numpy as np
import tensorflow as tf
from collections import deque
import random

class DQNAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000)
        self.gamma = 0.95  # 折扣因子
        self.epsilon = 1.0  # 探索率
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.learning_rate = 0.001
        self.model = self._build_model()
        self.target_model = self._build_model()
        self.update_target_model()

    def _build_model(self):
        model = tf.keras.Sequential()
        model.add(tf.keras.layers.Dense(24, input_dim=self.state_size, activation='relu'))
        model.add(tf.keras.layers.Dense(24, activation='relu'))
        model.add(tf.keras.layers.Dense(self.action_size, activation='linear'))
        model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(learning_rate=self.learning_rate))
        return model

    def update_target_model(self):
        self.target_model.set_weights(self.model.get_weights())

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        act_values = self.model.predict(state)
        return np.argmax(act_values[0])

    def replay(self, batch_size):
        if len(self.memory) < batch_size:
            return
        minibatch = random.sample(self.memory, batch_size)
        states = np.array([i[0] for i in minibatch])
        actions = np.array([i[1] for i in minibatch])
        rewards = np.array([i[2] for i in minibatch])
        next_states = np.array([i[3] for i in minibatch])
        dones = np.array([i[4] for i in minibatch])

        states = np.squeeze(states)
        next_states = np.squeeze(next_states)

        targets = rewards + self.gamma * (np.amax(self.target_model.predict_on_batch(next_states), axis=1)) * (1 - dones)
        targets_full = self.model.predict_on_batch(states)

        ind = np.array([i for i in range(batch_size)])
        targets_full[[ind], [actions]] = targets

        self.model.fit(states, targets_full, epochs=1, verbose=0)
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def load(self, name):
        self.model.load_weights(name)

    def save(self, name):
        self.model.save_weights(name)
```

> 注：在较新的TensorFlow版本（使用Keras 3）中，`Adam`已不再接受旧参数名`lr`，`save_weights`也要求文件名以`.weights.h5`结尾。因此本文代码统一使用`learning_rate`参数和`.weights.h5`后缀，这两种写法在旧版TensorFlow 2.x中同样可用。

## 4. 训练AI玩贪吃蛇

现在，我们将游戏环境和DQN算法结合起来，训练AI玩贪吃蛇。

### 4.1 状态表示

我们需要定义如何将游戏状态表示为神经网络可以理解的输入。对于贪吃蛇游戏，状态可以包括：

- 相对蛇的当前方向，直行、右转、左转三个方向上是否有障碍（蛇身或墙壁）
- 食物相对于蛇头的位置（左/右/上/下）
- 蛇当前的移动方向

```python
def get_state(self):
    head = self.snake.get_head_position()
    food = self.food.position

    # 计算四个方向的点
    point_l = (head[0] - self.block_size, head[1])
    point_r = (head[0] + self.block_size, head[1])
    point_u = (head[0], head[1] - self.block_size)
    point_d = (head[0], head[1] + self.block_size)

    # 当前移动方向
    dir_l = self.snake.direction == (-1, 0)
    dir_r = self.snake.direction == (1, 0)
    dir_u = self.snake.direction == (0, -1)
    dir_d = self.snake.direction == (0, 1)

    state = [
        # 危险：直行
        (dir_r and self.is_collision(point_r)) or
        (dir_l and self.is_collision(point_l)) or
        (dir_u and self.is_collision(point_u)) or
        (dir_d and self.is_collision(point_d)),

        # 危险：右转
        (dir_u and self.is_collision(point_r)) or
        (dir_d and self.is_collision(point_l)) or
        (dir_l and self.is_collision(point_u)) or
        (dir_r and self.is_collision(point_d)),

        # 危险：左转
        (dir_d and self.is_collision(point_r)) or
        (dir_u and self.is_collision(point_l)) or
        (dir_r and self.is_collision(point_u)) or
        (dir_l and self.is_collision(point_d)),

        # 移动方向
        dir_l,
        dir_r,
        dir_u,
        dir_d,

        # 食物位置
        food[0] < head[0],  # 食物在左
        food[0] > head[0],  # 食物在右
        food[1] < head[1],  # 食物在上
        food[1] > head[1]   # 食物在下
    ]

    return np.array(state, dtype=int)
```

### 4.2 奖励函数设计

奖励函数是强化学习中最关键的部分之一，它告诉AI什么是好的行为、什么是坏的行为。对于贪吃蛇游戏，我们可以设计如下奖励：

- 吃到食物：+10
- 撞到自己或墙壁：-10
- 靠近食物：+1
- 远离食物：-1
- 每移动一步：-0.1（鼓励高效）

```python
def get_reward(self, snake, food, done):
    if done:
        return -10

    if snake.get_head_position() == food.position:
        return 10

    # 计算与食物的距离
    head = snake.get_head_position()
    food_pos = food.position
    new_dist = abs(head[0] - food_pos[0]) + abs(head[1] - food_pos[1])

    # 如果距离减小，给予正奖励；否则给予负奖励
    if new_dist < self.prev_distance:
        reward = 1
    else:
        reward = -1

    self.prev_distance = new_dist

    # 每步的小惩罚
    reward -= 0.1

    return reward
```

注意：上面的代码依赖`self.prev_distance`。需要在每局开始时（例如在`Game.reset()`中），以及生成新食物之后，将其设为蛇头与食物之间的曼哈顿距离。

### 4.3 训练过程

训练过程主要包括以下步骤：

1. 初始化环境和智能体
2. 获取当前状态
3. 智能体选择动作
4. 执行动作，获取新状态和奖励
5. 存储经验到记忆库
6. 训练智能体
7. 定期更新目标网络

```python
def train():
    pygame.init()
    width, height, block_size = 800, 600, 20
    game = Game(width, height, block_size)
    agent = DQNAgent(state_size=11, action_size=3)  # 3个动作：直行、右转、左转
    episodes = 1000
    batch_size = 32

    for e in range(episodes):
        game.reset()
        state = game.get_state()
        state = np.reshape(state, [1, 11])
        total_reward = 0

        while True:
            action = agent.act(state)

            # 执行动作
            if action == 0:  # 直行
                pass
            elif action == 1:  # 右转
                if game.snake.direction == (0, -1):
                    game.snake.turn((1, 0))
                elif game.snake.direction == (1, 0):
                    game.snake.turn((0, 1))
                elif game.snake.direction == (0, 1):
                    game.snake.turn((-1, 0))
                elif game.snake.direction == (-1, 0):
                    game.snake.turn((0, -1))
            elif action == 2:  # 左转
                if game.snake.direction == (0, -1):
                    game.snake.turn((-1, 0))
                elif game.snake.direction == (-1, 0):
                    game.snake.turn((0, 1))
                elif game.snake.direction == (0, 1):
                    game.snake.turn((1, 0))
                elif game.snake.direction == (1, 0):
                    game.snake.turn((0, -1))

            game.snake.move()

            # 检查游戏状态
            done = False
            if game.snake.get_head_position() in game.snake.positions[1:]:
                done = True

            # 获取奖励：必须在生成新食物之前计算，
            # 否则蛇头已不在（新）食物位置上，吃到食物的+10奖励永远不会触发
            reward = game.get_reward(game.snake, game.food, done)
            total_reward += reward

            # 检查是否吃到食物
            if game.snake.get_head_position() == game.food.position:
                game.snake.length += 1
                game.food = Food(block_size, width, height)

            # 获取新状态
            next_state = game.get_state()
            next_state = np.reshape(next_state, [1, 11])

            # 存储经验
            agent.remember(state, action, reward, next_state, done)
            state = next_state

            if done:
                print(f'Episode: {e}/{episodes}, Score: {game.snake.length}, Total reward: {total_reward}, Epsilon: {agent.epsilon:.2f}')
                break

            if len(agent.memory) > batch_size:
                agent.replay(batch_size)

        # 定期更新目标网络
        if e % 10 == 0:
            agent.update_target_model()

        # 定期保存模型
        if e % 100 == 0:
            agent.save(f'snake_dqn_{e}.weights.h5')

    agent.save('snake_dqn_final.weights.h5')
```

注意：`Snake.move()`使用取模运算实现“穿墙”，训练循环中也只检测了撞到自身，因此按上述代码蛇不会因撞墙而死亡。如果希望撞墙结束游戏，需要去掉取模运算，并在`done`的判断中加入边界检测。

## 5. 调优与改进

训练过程中，你可能会遇到AI表现不佳的情况。以下是几个常见的调优方向。

### 5.1 奖励函数调整

奖励函数的设计对训练效果影响巨大。可以尝试以下调整：

- 增加对长时间存活的奖励
- 调整靠近/远离食物的奖励幅度
- 增加对形成循环移动的惩罚

### 5.2 网络结构优化

可以尝试更复杂的网络结构：

```python
def _build_model(self):
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, input_dim=self.state_size, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(self.action_size, activation='linear')
    ])
    model.compile(loss='mse', optimizer=tf.keras.optimizers.Adam(learning_rate=self.learning_rate))
    return model
```

### 5.3 训练参数调整

关键训练参数包括：

| 参数 | 建议值 | 说明 |
| --- | --- | --- |
| γ (gamma) | 0.9-0.99 | 折扣因子，越大表示越重视长期奖励 |
| ε (epsilon) | 1.0→0.01 | 探索率，初始高探索，逐渐降低 |
| ε衰减 | 0.995 | 控制探索率降低速度 |
| 学习率 | 0.0001-0.001 | 影响权重更新幅度 |
| 批次大小 | 32-64 | 每次训练的样本数量 |
| 记忆容量 | 1000-10000 | 经验回放缓冲区大小 |

### 5.4 高级技巧

- **双DQN（Double DQN）**：使用两个网络分别选择动作和评估动作，减少Q值被过高估计的问题。
- **优先级经验回放（Prioritized Experience Replay）**：给重要的经验样本更高的采样概率。
- **决斗网络架构（Dueling Network）**：将Q值分解为状态值和优势函数。

双DQN用在线网络选择下一步动作、用目标网络评估该动作：$y = r + \gamma\, Q_{\text{target}}\big(s', \arg\max_{a'} Q(s', a')\big)$。实现双DQN只需修改replay方法：

```python
def replay(self, batch_size):
    if len(self.memory) < batch_size:
        return
    minibatch = random.sample(self.memory, batch_size)
    states = np.array([i[0] for i in minibatch])
    actions = np.array([i[1] for i in minibatch])
    rewards = np.array([i[2] for i in minibatch])
    next_states = np.array([i[3] for i in minibatch])
    dones = np.array([i[4] for i in minibatch])

    states = np.squeeze(states)
    next_states = np.squeeze(next_states)

    # 双DQN修改部分
    next_actions = np.argmax(self.model.predict_on_batch(next_states), axis=1)
    q_values_next = self.target_model.predict_on_batch(next_states)
    targets = rewards + self.gamma * q_values_next[np.arange(batch_size), next_actions] * (1 - dones)

    targets_full = self.model.predict_on_batch(states)
    ind = np.array([i for i in range(batch_size)])
    targets_full[[ind], [actions]] = targets

    self.model.fit(states, targets_full, epochs=1, verbose=0)
    if self.epsilon > self.epsilon_min:
        self.epsilon *= self.epsilon_decay
```