用 Java 训练出一只“不死鸟” (2)

我们采用了采用了 3 个卷积层,4 个 relu 激活函数以及 2 个全连接层的神经网络架构。

layer input shape output shape
conv2d   (batchSize, 4, 80, 80)   (batchSize,4,20,20)  
conv2d   (batchSize, 4, 20 ,20)   (batchSize, 32, 9, 9)  
conv2d   (batchSize, 32, 9, 9)   (batchSize, 64, 7, 7)  
linear   (batchSize, 3136)   (batchSize, 512)  
linear   (batchSize, 512)   (batchSize, 2)  
训练过程

DJL 的 RL 库中提供了非常方便的用于实现强化学习的接口:(RlEnv, RlAgent, ReplayBuffer)。

实现 RlAgent 接口即可构建一个可以进行训练的智能体。

在现有的游戏环境中实现 RlEnv 接口即可生成训练所需的数据。

创建 ReplayBuffer 可以存储并动态更新训练数据。

在实现这些接口后,只需要调用 step 方法:

RlEnv.step(action, training);

这个方法会将 RlAgent 决策出的动作输入到游戏环境中获得反馈。我们可以在 RlEnv 中提供的 runEnviroment 方法中调用 step 方法,然后只需要重复执行 runEnvironment 方法,即可不断地生成用于训练的数据。

public Step[] runEnvironment(RlAgent agent, boolean training) { // run the game NDList action = agent.chooseAction(this, training); step(action, training); if (training) { batchSteps = this.getBatch(); } return batchSteps; }

我们将 ReplayBuffer 可存储的 step 数量设置为 50000,在 observe 周期我们会先向 replayBuffer 中存储 1000 个使用随机动作生成的 step,这样可以使智能体更快地从随机动作中学习。

在 explore 和 training 周期,神经网络会随机从 replayBuffer 中生成训练集并将它们输入到模型中训练。我们使用 Adam 优化器和 MSE 损失函数迭代神经网络。

神经网络输入预处理

首先将图像大小 resize 成 80x80 并转为灰度图,这有助于在不丢失信息的情况下提高训练速度。

public static NDArray imgPreprocess(BufferedImage observation) { return NDImageUtils.toTensor( NDImageUtils.resize( ImageFactory.getInstance().fromImage(observation) .toNDArray(NDManager.newBaseManager(), Image.Flag.GRAYSCALE) ,80,80)); }

然后我们把连续的四帧图像作为一个输入,为了获得连续四帧的连续图像,我们维护了一个全局的图像队列保存游戏线程中的图像,每一次动作后替换掉最旧的一帧,然后把队列里的图像 stack 成一个单独的 NDArray。

public NDList createObservation(BufferedImage currentImg) { NDArray observation = GameUtil.imgPreprocess(currentImg); if (imgQueue.isEmpty()) { for (int i = 0; i < 4; i++) { imgQueue.offer(observation); } return new NDList(NDArrays.stack(new NDList(observation, observation, observation, observation), 1)); } else { imgQueue.remove(); imgQueue.offer(observation); NDArray[] buf = new NDArray[4]; int i = 0; for (NDArray nd : imgQueue) { buf[i++] = nd; } return new NDList(NDArrays.stack(new NDList(buf[0], buf[1], buf[2], buf[3]), 1)); } }

一旦以上部分完成,我们就可以开始训练了。训练优化为了获得最佳的训练性能,我们关闭了 GUI 以加快样本生成速度。并使用 Java 多线程将训练循环和样本生成循环分别在不同的线程中运行。

List<Callable<Object>> callables = new ArrayList<>(numOfThreads); callables.add(new GeneratorCallable(game, agent, training)); if(training) { callables.add(new TrainerCallable(model, agent)); } 总结

这个模型在 NVIDIA T4 GPU 训练了大概 4 个小时,更新了 300 万步。训练后的小鸟已经可以完全自主控制动作灵活穿梭与管道之间。训练后的模型也同样上传到了仓库中供您测试。在此项目中 DJL 提供了强大的训练 API 以及模型库支持,使得在 Java 开发过程中得心应手。

内容版权声明:除非注明,否则皆为本站原创文章。

转载注明出处:https://www.heiqu.com/zwsdyf.html