用 Java 训练出一只“不死鸟”

用 Java 训练出一只“不死鸟”

作者:Kingyu & Lanking

FlappyBird 是 2013 年推出的一款手机游戏,因其简单的玩法但极度困难的设定迅速走红全网。随着深度学习(DL)与增强学习(RL)等前沿算法的发展,我们可以使 Java 非常方便地训练出一个智能体来控制 Flappy Bird。

用 Java 训练出一只“不死鸟”

故事开始于《GitHub 上的大佬们打完招呼,会聊些什么?》,那么,今天我们就来一起看一下如何 Java 训练出一个不死鸟。游戏项目我们使用了一个仅用 Java 基本类库编写的 FlappyBird 游戏。在训练方面,我们使用 DeepJavaLibrary 一个基于 Java 的深度学习框架来构建增强学习训练网络并进行训练。经过了差不多 300 万步(四小时)的训练后,小鸟已经可以获得最高 8000 多分的成绩,灵活穿梭于水管之间。

用 Java 训练出一只“不死鸟”

在本文中,我们将从原理开始一步一步实现增强学习算法并用它对游戏进行训练。如果任何一个时刻不清楚如何继续进行下去,可以参阅项目的源码。

项目地址:https://github.com/kingyuluk/RL-FlappyBird

增强学习(RL)的架构

在这一节会介绍主要用到的算法以及神经网络,帮助你更好的了解如何进行训练。本项目与 DeepLearningFlappyBird 使用了类似的方法进行训练。算法整体的架构是 Q-Learning + 卷积神经网络(CNN),把游戏每一帧的状态存储起来,即小鸟采用的动作和采用动作之后的效果,这些将作为卷积神经网络的训练数据。

CNN 训练简述

CNN 的输入数据为连续的 4 帧图像,我们将这图像 stack 起来作为小鸟当前的“observation”,图像会转换成灰度图以减少所需的训练资源。图像存储的矩阵形式是 (batch size, 4 (frames), 80 (width), 80 (height)) 数组里的元素就是当前帧的像素值,这些数据将输入到 CNN 后将输出 (batch size, 2) 的矩阵,矩阵的第二个维度就是小鸟 (振翅不采取动作) 对应的收益。

训练数据

在小鸟采取动作后,我们会得到 preObservation and currentObservation 即是两组 4 帧的连续的图像表示小鸟动作前和动作后的状态。然后我们将 preObservation, currentObservation, action, reward, terminal 组成的五元组作为一个 step 存进 replayBuffer 中。它是一个有限大小的训练数据集,他会随着最新的操作动态更新内容。

public void step(NDList action, boolean training) { if (action.singletonOrThrow().getInt(1) == 1) { bird.birdFlap(); } stepFrame(); NDList preObservation = currentObservation; currentObservation = createObservation(currentImg); FlappyBirdStep step = new FlappyBirdStep(manager.newSubManager(), preObservation, currentObservation, action, currentReward, currentTerminal); if (training) { replayBuffer.addStep(step); } if (gameState == GAME_OVER) { restartGame(); } } 训练的三个周期

训练分为 3 个不同的周期以更好地生成训练数据:

Observe(观察) 周期:随机产生训练数据

Explore (探索) 周期:随机与推理动作结合更新训练数据

Training (训练) 周期:推理动作主导产生新数据

通过这种训练模式,我们可以更好的达到预期效果。

处于 Explore 周期时,我们会根据权重选取随机的动作或使用模型推理出的动作来作为小鸟的动作。训练前期,随机动作的权重会非常大,因为模型的决策十分不准确 (甚至不如随机)。在训练后期时,随着模型学习的动作逐步增加,我们会不断增加模型推理动作的权重并最终使它成为主导动作。调节随机动作的参数叫做 epsilon 它会随着训练的过程不断变化。

public NDList chooseAction(RlEnv env, boolean training) { if (training && RandomUtils.random() < exploreRate.getNewValue(counter++)) { return env.getActionSpace().randomAction(); } else return baseAgent.chooseAction(env, training); } 训练逻辑

首先,我们会从 replayBuffer 中随机抽取一批数据作为作为训练集。然后将 preObservation 输入到神经网络得到所有行为的 reward(Q)作为预测值:

NDList QReward = trainer.forward(preInput); NDList Q = new NDList(QReward.singletonOrThrow() .mul(actionInput.singletonOrThrow()) .sum(new int[]{1}));

postObservation 同样会输入到神经网络,根据马尔科夫决策过程以及贝尔曼价值函数计算出所有行为的 reward(targetQ)作为真实值:

// 将 postInput 输入到神经网络中得到 targetQReward 是 (batchsize,2) 的矩阵。根据 Q-learning 的算法,每一次的 targetQ 需要根据当前环境是否结束算出不同的值,因此需要将每一个 step 的 targetQ 单独算出后再将 targetQ 堆积成 NDList。 NDList targetQReward = trainer.forward(postInput); NDArray[] targetQValue = new NDArray[batchSteps.length]; for (int i = 0; i < batchSteps.length; i++) { if (batchSteps[i].isTerminal()) { targetQValue[i] = batchSteps[i].getReward(); } else { targetQValue[i] = targetQReward.singletonOrThrow().get(i) .max() .mul(rewardDiscount) .add(rewardInput.singletonOrThrow().get(i)); } } NDList targetQBatch = new NDList(); Arrays.stream(targetQValue).forEach(value -> targetQBatch.addAll(new NDList(value))); NDList targetQ = new NDList(NDArrays.stack(targetQBatch, 0));

在训练结束时,计算 Q 和 targetQ 的损失值,并在 CNN 中更新权重。

卷积神经网络模型(CNN)

内容版权声明:除非注明,否则皆为本站原创文章。

转载注明出处:https://www.heiqu.com/zwsdyf.html