PyTorch Reinforcement Learning Beginner's Notes -- Hands-On -- Using Reinforcement Learning to Teach a Program to Play a Ping-Pong Mini-Game


Prerequisites

(Note: before starting the hands-on part, first make sure your base environment can run model training normally; if the environment is not set up properly, the code below cannot be verified. The article links below provide the relevant environment setup guides.)
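  • (A quick sanity check, my own addition rather than part of the original guide: if the snippet below runs and reports your PyTorch version, the basic training environment is in place; the GPU lines are optional)
# environment sanity check: confirm PyTorch imports and see whether a CUDA GPU is visible
import torch

print('PyTorch version:', torch.__version__)
print('CUDA available:', torch.cuda.is_available())
if torch.cuda.is_available():
    print('GPU:', torch.cuda.get_device_name(0))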

Preface

  • This article is not fully finished yet; I am still learning, and I will keep improving it as my studies progress
  • For now it only records the code I have written so far
  • (I study this as a personal hobby alongside work, so I often jot down temporary, unfinished notes; please bear with me if you come across them)
    • (Once this case is fully complete, I will open-source the code on GitHub and share it; feel free to follow)
  • (If you are interested in the final result of this article, stay tuned)

Case description

  • The simple ping-pong mini-game works like this: a ball drops from a random position at the top of the canvas, and a short bar sits near the bottom. You move this bottom bar left and right with the left/right arrow keys to catch the ball; if you miss it, the game is over and you start again.
  • We can use reinforcement learning to let a program learn to play this simple game on its own.

1. Preparing the mini-game: the HTML + JS + CSS + canvas ping-pong game code is here!

<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Ping Pong Game</title>
    <style>
        body {
            background-color: rgb(255, 255, 255);
        }
        canvas {
            border: 5px solid black;
            display: block;
            margin: 0 auto;
            background-color: rgb(255, 255, 255);
        }

        .modal {
            display: none;
            position: fixed;
            z-index: 1;
            left: 0;
            top: 0;
            width: 100%;
            height: 100%;
            overflow: auto;
            background-color: rgba(0, 0, 0, 0.4);
        }

        .modal-content {
            background-color: #fefefe;
            margin: 15% auto;
            padding: 20px;
            border: 2px solid black;
            width: 30%;
            text-align: center;
        }

        .instructions {
            margin-top: 20px;
            text-align: center;
        }
    </style>
</head>

<body>
    <canvas id="gameCanvas" width="500" height="500"></canvas>
    <div id="gameOverModal" class="modal">
        <div class="modal-content">
            <h2 id="gameOverMessage">Game Over</h2>
            <p id="finalScore">Final Score: 0</p>
            <button id="restartButton">Restart</button>
        </div>
    </div>
    <div id="startModal" class="modal">
        <div class="modal-content">
            <h2>Start Game</h2>
            <button id="startButton">Start</button>
        </div>
    </div>
    <div class="instructions">通过键盘 A/D 或 ←/→ 按键控制乒乓板左右移动</div>
    <div class="instructions">Enter 回车键 开始/重新开始 游戏</div>
    <script>
        const canvas = document.getElementById('gameCanvas');
        const ctx = canvas.getContext('2d');
        const restartButton = document.getElementById('restartButton');
        const startButton = document.getElementById('startButton');
        const gameOverModal = document.getElementById('gameOverModal');
        const startModal = document.getElementById('startModal');
        const gameOverMessage = document.getElementById('gameOverMessage');
        const finalScoreElement = document.getElementById('finalScore');

        let ballX, ballY, ballRadius = 8;
        let barX = canvas.width / 2 - 15, barY = canvas.height - 55, barWidth = 30, barHeight = 5;
        let score = 0;
        let ballSpeedY = 2;
        let barSpeed = 6;
        let isInitGame = false;
        let isGameOver = false;
        let isGameStarted = false;

        function drawBall() {
            ctx.beginPath();
            ctx.arc(ballX, ballY, ballRadius, 0, 2 * Math.PI);
            ctx.fillStyle = 'red';
            ctx.fill();
            ctx.closePath();
        }

        function drawBar() {
            ctx.fillStyle = 'blue';
            ctx.fillRect(barX, barY, barWidth, barHeight);
        }

        function drawScore() {
            ctx.font = '16px Arial';
            ctx.fillStyle = 'black';
            ctx.fillText(`Score: ${score}`, 10, 20);
        }

        function drawBorder() {
            ctx.strokeStyle = 'black';
            ctx.lineWidth = 5;
            ctx.strokeRect(0, 0, canvas.width, canvas.height);
        }

        function updateGame() {
            if (isGameStarted && !isGameOver) {
                ctx.clearRect(0, 0, canvas.width, canvas.height);
                drawBorder();
                ballY += ballSpeedY;
                drawBall();
                drawBar();
                drawScore();

                if (ballY + ballRadius >= barY && ballY - ballRadius <= barY + barHeight && ballX >= barX && ballX <= barX + barWidth) {
                    score += 10;
                    ballY = 0;
                    ballX = Math.floor(Math.random() * (canvas.width - 2 * ballRadius)) + ballRadius;
                } else if (ballY + ballRadius >= canvas.height) {
                    isGameOver = true;
                    gameOverModal.style.display = 'block';
                    gameOverMessage.textContent = 'Game Over';
                    finalScoreElement.textContent = `Final Score: ${score}`;
                }

                requestAnimationFrame(updateGame);
            }
        }

        document.addEventListener('keydown', (event) => {
            if (isGameStarted && !isGameOver) {
                if (['a','ArrowLeft'].includes(event.key) && barX > 0) {
                    barX -= barSpeed;
                } else if (['d','ArrowRight'].includes(event.key) && barX + barWidth < canvas.width) {
                    barX += barSpeed;
                }
            } else if (isInitGame && [' ', 'Enter'].includes(event.key)) { // event.key for the space bar is ' ', not 'Space'
                startGame();
            } else if (isGameOver && [' ', 'Enter'].includes(event.key)) {
                restartGame();
            }
        });

        function restartGame() {
            score = 0;
            isGameOver = false;
            ballX = Math.floor(Math.random() * (canvas.width - 2 * ballRadius)) + ballRadius;
            ballY = 0;
            barX = canvas.width / 2 - 15;
            gameOverModal.style.display = 'none';
            updateGame();
        }

        restartButton.addEventListener('click', restartGame);

        function startGame() {
            isInitGame = false;
            startModal.style.display = 'none';
            isGameStarted = true;
            ballX = Math.floor(Math.random() * (canvas.width - 2 * ballRadius)) + ballRadius;
            ballY = 0;
            updateGame();
        }

        startButton.addEventListener('click', startGame);

        function initGame() {
            isInitGame = true;
            startModal.style.display = 'block';
        }

        initGame();
    </script>
</body>

</html>
  • Game running effect (screenshot)
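  • To try the game yourself, save the markup above as an .html file and open it in a browser. If you want it reachable over HTTP (handy later for taking screenshots or driving the game from a script), Python's built-in static server is enough; the file name pingpong.html below is just an example:
# serve the current directory and open http://localhost:8000/pingpong.html
python -m http.server 8000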

2. Object detection model training: teaching the program to recognize the important targets in the mini-game

[Appendix] 2.1. A one-click script to move labelImg label files

  • For the object detection training, to make it convenient to move the label files produced by the labelImg tool and the original screenshot images into the data/ directory, I wrote this one-click transfer script
# run the one-click transfer script
python copy_labelimgfiles_to_datadir.py
  • copy_labelimgfiles_to_datadir.py
import os
import shutil
import random

cur_dir = os.path.dirname(__file__).replace('\\', '/')
print(cur_dir)

source_img_dir = cur_dir + '/source_img'
labelimg_output_dir = cur_dir + '/labelimg_output'

data_dir = cur_dir + '/data'
data_images_dir = data_dir + '/images'
data_labels_dir = data_dir + '/labels'
data_train_dir = data_dir + '/train'
data_train_images_dir = data_train_dir + '/images'
data_train_labels_dir = data_train_dir + '/labels'
data_val_dir = data_dir + '/val'
data_val_images_dir = data_val_dir + '/images'
data_val_labels_dir = data_val_dir + '/labels'
data_test_dir = data_dir + '/test'
data_test_images_dir = data_test_dir + '/images'
data_test_labels_dir = data_test_dir + '/labels'
data_yoloyaml_file_path = data_dir + '/yolov8nconfig.yaml'

def read_full_path_list(dir_path):
    full_path_list = []
    files = os.listdir(dir_path)
    for f in files:
        full_path_list.append(dir_path + '/' + f)
    return full_path_list

def create_dir(dir_path):
    if not os.path.exists(dir_path):
        os.mkdir(dir_path)

def read_labelimg_files():
    # check that the input directories exist
    if not os.path.exists(source_img_dir):
        print(f'【source_img_dir】: {source_img_dir} not exists')
        return
    
    if not os.path.exists(labelimg_output_dir):
        print(f'【labelimg_output_dir】: {labelimg_output_dir} not exists')
        return

    # read all files under source_img_dir
    source_img_files_fullpath_list = read_full_path_list(source_img_dir)
    # read all files under labelimg_output_dir
    labelimg_output_files_fullpath_list = read_full_path_list(labelimg_output_dir)
    labelimg_txt_files_fullpath_list = []
    labelimg_classestxt_file_fullpath = ''
    for f in labelimg_output_files_fullpath_list:
        # pick out the classes.txt file
        if f.endswith('classes.txt'):
            labelimg_classestxt_file_fullpath = f
        # pick out the label .txt files
        elif f.endswith('.txt'):
            labelimg_txt_files_fullpath_list.append(f)

    # print(f'source_img_files_fullpath_list: {source_img_files_fullpath_list}')
    # print(f'labelimg_txt_files_fullpath_list: {labelimg_txt_files_fullpath_list}')
    # print(f'labelimg_classestxt_file_fullpath: {labelimg_classestxt_file_fullpath}')
    
    return source_img_files_fullpath_list, labelimg_txt_files_fullpath_list, labelimg_classestxt_file_fullpath
    
def copy_labelimg_files_to_data_dir(source_img_files_fullpath_list, labelimg_txt_files_fullpath_list, labelimg_classestxt_file_fullpath):
    # read the class names line by line from labelimg_classestxt_file_fullpath
    label_classes = []
    with open(labelimg_classestxt_file_fullpath, 'r', encoding='utf-8') as f:
        for line in f.readlines():
            line = line.strip()
            if line != '':
                label_classes.append(line)
    # write the yolo yaml config file
    yoloymal_config_contents = [
        'train: ' + data_train_dir,
        '\n',
        'val: ' + data_val_dir,
        '\n',
        'test: ' + data_test_dir,
        '\n',
        '\n',
        'nc: ' + str(len(label_classes)),
        '\n',
        '\n',
        'names: ' + str(label_classes)
    ]
    # write the lines in yoloymal_config_contents to the data_yoloyaml_file_path file
    with open(data_yoloyaml_file_path, 'w', encoding='utf-8') as f:
        f.writelines(yoloymal_config_contents)


    # copy the files passed in as arguments to the matching sub-directories under data_dir
    for simg in source_img_files_fullpath_list:
        shutil.copy2(simg, data_images_dir)
        shutil.copy2(simg, data_train_images_dir)

    for ltxt in labelimg_txt_files_fullpath_list:
        shutil.copy2(ltxt, data_labels_dir)
        shutil.copy2(ltxt, data_train_labels_dir)

    # data/val and data/test do not need every file; a randomly chosen subset is enough
    simg_files_pre_parts = []
    ltxt_files_pre_parts = []
    simg_files_end_parts = []
    ltxt_files_end_parts = []
    # pick part of the file list and put it under data/val and data/test
    files_pre_parts_i_list = []
    files_end_parts_i_list = []
    files_len = len(source_img_files_fullpath_list)
    half_files_len = int(files_len/2)
    for i in range(half_files_len):
        files_pre_parts_i_list.append(random.randint(0, half_files_len))
        files_end_parts_i_list.append(random.randint(half_files_len+1, files_len-1))
    for i in files_pre_parts_i_list:
        simg_files_pre_parts.append(source_img_files_fullpath_list[i])
        ltxt_files_pre_parts.append(labelimg_txt_files_fullpath_list[i])
    for i in files_end_parts_i_list:
        simg_files_end_parts.append(source_img_files_fullpath_list[i])
        ltxt_files_end_parts.append(labelimg_txt_files_fullpath_list[i])
    # copy the randomly selected files to data/val and data/test respectively
    for f in simg_files_pre_parts:
        shutil.copy2(f, data_val_images_dir)
    for f in ltxt_files_pre_parts:
        shutil.copy2(f, data_val_labels_dir)
    for f in simg_files_end_parts:
        shutil.copy2(f, data_test_images_dir)
    for f in ltxt_files_end_parts:
        shutil.copy2(f, data_test_labels_dir)
           
def remove_data_dir():
    if os.path.exists(data_dir):
        # delete data_dir and all sub-directories and files under it
        shutil.rmtree(data_dir)

def create_data_dir():
    remove_data_dir()

    create_dir(data_dir)
    create_dir(data_images_dir)
    create_dir(data_labels_dir)
    create_dir(data_train_dir)
    create_dir(data_train_images_dir)    
    create_dir(data_train_labels_dir)
    create_dir(data_val_dir)
    create_dir(data_val_images_dir)
    create_dir(data_val_labels_dir)
    create_dir(data_test_dir)
    create_dir(data_test_images_dir)
    create_dir(data_test_labels_dir)

if __name__ == '__main__':
    source_img_files_fullpath_list, labelimg_txt_files_fullpath_list, labelimg_classestxt_file_fullpath = read_labelimg_files()
    create_data_dir()
    copy_labelimg_files_to_data_dir(source_img_files_fullpath_list, labelimg_txt_files_fullpath_list, labelimg_classestxt_file_fullpath)
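  • For reference, with the eight classes listed in the next subsection (spelled exactly as in my classes.txt, including 'final_socre'), the generated data/yolov8nconfig.yaml looks roughly like this; the actual paths are absolute paths derived from the script's directory, shown here with a <project_root> placeholder:
train: <project_root>/data/train
val: <project_root>/data/val
test: <project_root>/data/test

nc: 8

names: ['score', 'ball', 'bar', 'btn_start', 'dialog_title_start', 'btn_restart', 'final_socre', 'dialog_title_gameover']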

YOLOv8 object detection training results

  • After a few back-and-forth rounds of object detection training, I finally got the result below (all target objects can now be detected accurately); a sketch of the training command follows this list
  • The rough training procedure was:
    • (1) In round 1, provide a few key screenshots of the full game screen and annotate them with labels;
    • (2) From round 2 onward, add a few more screenshots each time that differ from the earlier ones, and annotate them with the same labels as above (the annotated images given to data/val and data/test should differ as much as possible between rounds; my copy_labelimgfiles_to_datadir.py script above already takes care of this by randomly picking a subset of the images for data/val and data/test each time);
    • (3) And so on. This game case is fairly simple, so after a few rounds of training the model was already able to detect every target in the game accurately
    • The annotation record is as follows
# score, ball, paddle, start button, title of the start dialog, restart button, final score in the game-over dialog, title of the game-over dialog
['score', 'ball', 'bar', 'btn_start', 'dialog_title_start', 'btn_restart', 'final_socre', 'dialog_title_gameover']
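  • The training itself can be launched from the ultralytics YOLOv8 Python API. The sketch below is my own, assuming the ultralytics package is installed and that data/yolov8nconfig.yaml was generated by the script above; the epoch count and image size are just example values:
# sketch: train YOLOv8n on the dataset config generated by copy_labelimgfiles_to_datadir.py
# assumes: pip install ultralytics
from ultralytics import YOLO

model = YOLO('yolov8n.pt')  # start from the pretrained nano model
model.train(data='data/yolov8nconfig.yaml', epochs=50, imgsz=640)  # example hyperparameters
results = model.predict('data/test/images')  # spot-check detections on the held-out test images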


3. The deep learning code is here!

  • The next goal is to let the program train itself from knowing nothing to mastering this little game
  • Below is an example implementation based on PyTorch and OpenAI Gym
    • (Note: this code is only a rough draft of the idea; I have not verified it in practice yet, so it certainly still contains problems. I will revise it after I have had time to study and test it)
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Define the game environment
class BallCatchingEnv(gym.Env):
    def __init__(self):
        super(BallCatchingEnv, self).__init__()
        self.action_space = gym.spaces.Discrete(3)  # 0 = move left, 1 = stay, 2 = move right
        self.observation_space = gym.spaces.Box(low=np.array([0, 0]), high=np.array([100, 100]), dtype=np.float32)
        self.bar_position = 50
        self.ball_position = np.random.randint(0, 100)
        self.ball_velocity = np.random.uniform(-5, 5)

    def step(self, action):
        # update the bar position according to the action
        if action == 0:
            self.bar_position = max(0, self.bar_position - 5)
        elif action == 1:
            pass
        else:
            self.bar_position = min(100, self.bar_position + 5)

        # update the ball position
        self.ball_position += self.ball_velocity
        if self.ball_position <= 0 or self.ball_position >= 100:
            self.ball_velocity = -self.ball_velocity

        # check whether the ball was caught
        reward = 0
        done = False
        if self.ball_position >= self.bar_position and self.ball_position <= self.bar_position + 10:
            reward = 1
        else:
            reward = -1
            done = True

        return np.array([self.bar_position, self.ball_position]), reward, done, {}

    def reset(self):
        self.bar_position = 50
        self.ball_position = np.random.randint(0, 100)
        self.ball_velocity = np.random.uniform(-5, 5)
        return np.array([self.bar_position, self.ball_position])

# Define the Q-network
class QNetwork(nn.Module):
    def __init__(self, input_size, output_size):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(input_size, 64)
        self.fc2 = nn.Linear(64, output_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Define the DQN agent
class DQNAgent:
    def __init__(self, env, lr=0.001, gamma=0.99, epsilon=1.0, epsilon_decay=0.995, epsilon_min=0.01):
        self.env = env
        self.q_network = QNetwork(env.observation_space.shape[0], env.action_space.n)
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min

    def select_action(self, state):
        if np.random.rand() < self.epsilon:
            return self.env.action_space.sample()
        else:
            state = torch.from_numpy(state).float().unsqueeze(0)
            q_values = self.q_network(state)
            return q_values.argmax().item()

    def train(self, num_episodes):
        for episode in range(num_episodes):
            state = self.env.reset()
            done = False
            while not done:
                action = self.select_action(state)
                next_state, reward, done, _ = self.env.step(action)
                target = (reward + self.gamma * self.q_network(torch.from_numpy(next_state).float().unsqueeze(0)).max()).detach()  # detach: no gradient through the bootstrap target
                loss = (self.q_network(torch.from_numpy(state).float().unsqueeze(0))[0][action] - target) ** 2
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                state = next_state
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)
            print(f"Episode {episode}, Epsilon: {self.epsilon:.2f}")

# Train the agent
env = BallCatchingEnv()
agent = DQNAgent(env)
agent.train(1000)
  • This example uses the Deep Q-Network (DQN) algorithm to learn to play the simple game. We first define the game environment BallCatchingEnv, then create a Q-network QNetwork and a DQN agent DQNAgent

During training, the agent explores the environment and learns which action to pick in order to catch the falling ball. It uses an epsilon-greedy policy: early in training it leans towards exploration (random actions), and later it leans towards exploitation (picking the action with the highest predicted value).

Running agent.train(1000) lets the agent learn the game over 1000 episodes. You can then watch how the agent performs and tune the training parameters as needed, such as the learning rate, the discount factor, and the epsilon decay rate; a small evaluation sketch follows below.
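For example, a minimal greedy-evaluation loop (my own sketch reusing the env and agent defined above, not part of the original draft) could look like this:
# sketch: run the trained agent greedily (epsilon = 0) to watch its behaviour
def evaluate(agent, env, num_episodes=10, max_steps=500):
    agent.epsilon = 0.0  # always pick the action with the highest predicted Q-value
    for episode in range(num_episodes):
        state = env.reset()
        total_reward = 0
        # cap the episode length, because the toy env above only terminates on a miss
        for _ in range(max_steps):
            action = agent.select_action(state)
            state, reward, done, _ = env.step(action)
            total_reward += reward
            if done:
                break
        print(f"Eval episode {episode}, total reward: {total_reward}")

evaluate(agent, env)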

Summary

  • This is only a simple example, but I feel it is well suited to beginners. I came up with this little case precisely so I could step into machine learning gradually: if I can master this small case now and then extend it step by step, learning becomes more fun, more motivating, and more sensible. Jumping straight at an overly ambitious goal, or plowing through a whole set of dry study materials or teaching videos, makes it easy to get lost or even fed up.
  • Later this can be gradually extended and optimized as needed, for example by adding more complex game rules, adjusting the network architecture, or trying other reinforcement learning algorithms; one possible extension is sketched right after this list. Through practice and exploration, you can slowly get a grip on how deep learning is applied to game AI
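  • As one example of such an extension, the DQN draft above updates on single transitions only; a common improvement is an experience replay buffer. A minimal sketch (my own addition, not yet wired into the training loop above) might be:
# sketch: a minimal experience replay buffer for the DQN agent
import random
from collections import deque

class ReplayBuffer:
    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)  # oldest transitions are dropped automatically

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size=32):
        # a random mini-batch of past transitions breaks the correlation between consecutive steps
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)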
Reposted from: https://juejin.cn/post/7355393008965058610