硬着学AI之搜索算法
什么是搜索算法
如果不知道搜索算法,可以先看下面的视频
为什么学AI要学搜索算法
什么? 不生成漂亮小姐姐的也配叫AI?不讲ChatGPT, Midjourney, Stable Diffusion也配叫AI? 是的,目前以大语言模型、文生图/视频为代表的生成式AI正在风口上,但如果你跟我一样好奇,想要弄明白这些技术背后的基础知识,想要系统的学习AI,想要知道神经网络、深度学习、大模型、机器怎么推理、怎么产生智能等等等等,这篇文章是我的一个开始,也可能是你的。
应用场景
搜索算法被用在AI的各个子领域中,比如搜索算法让曾经红极一时的Alpha狗能够下棋。
在自动驾驶领域,机器自动寻找起点和终点间的最优的路径。
当然在大语言模型中也会运用到,比如GPT在预测下一个token的时候,会采用束搜索(Beam Search)算法。
引言
这篇文章中,我们将通过编写一个拼图游戏,和一个简单的下棋游戏,通过动手练习的方式学习搜索算法。当然没有办法在文章中涵盖所有的搜索算法,但你会有一个基本认识,以及会为你实际编程时,如何建模提供一个基本的思维框架。
拼图游戏
我们要做什么?
我们有这样一个拼图游戏:
先随机打乱100次,变成了这样:
我们要做的就是让机器通过搜索算法找到复原这个拼图的步骤。当然不能把拼图块抠出来哈,只能每次移动空白附近的拼图块。
往下看之前可以自己先尝试实现一下,至少先思考一下要怎么实现。
基本概念
状态(State)
我们很容易想到,可以将一个拼图游戏采用二维数组进行编码:
比如用如下代码来表示:
[
[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24]
]
其中24
代表空白,在拼图游戏中,我们通过二维数组来编码,编码的这个东西有个更通用的名字,叫做 「状态」。对于不同的问题,状态的编码可能不一样,但是对于所有搜索问题,都可以去定义状态,比如下棋,迷宫游戏等状态的形式都不一样。
随机打乱后,是一个混乱的状态,比如有可能是这样:
[
[0, 6, 1, 3, 4],
[10, 5, 2, 7, 8],
[15, 11, 12, 9, 13],
[24, 16, 17, 18, 14],
[20, 21, 22, 23, 19]
]
我们的目标就是要从一个打乱的状态,让其恢复成其初始状态。
在Typescript中通过如下的方式来定义一个State
export default class State {
COL = 5;
ROW = 5; //定义拼图的行数和列数
perfect: number[][] = [
[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24],
] //复原的拼图的状态
constructor(inital?: State) {
if (inital) {
this._copy(inital._state); //通过另外一个状态初始化的时候复制一份
} else {
this._copy(this.perfect);
}
}
private _copy(target: number[][]) {
for (let i = 0; i < this.ROW; i++) {
for (let j = 0; j < this.COL; j++) {
this._state[i][j] = target[i][j];
}
}
}
private _state: number[][] = [
[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9],
[10, 11, 12, 13, 14],
[15, 16, 17, 18, 19],
[20, 21, 22, 23, 24],
]; //实际表示状态的数组
}
动作(Action)
状态通过行为发生改变,比如拼图游戏中,我们有 上、下、左、右移动空白四种动作,每次移动就是把空白和对应的非空白图块交换,每次交换过后就会产生一个新的状态。
为了体现State之间的关系,我们将State包装在一个叫做Nodee(多了个e是因为跟Node冲突了,其实就是节点)的结构中,Node作为一个辅助类,记录上一个状态通过什么action到了目前的状态:
export class Nodee {
parent?: Nodee; //当前节点的上一个节点
state: State; //当前节点的状态
action?: (node: Nodee) => Nodee; //到达当前节点的Action
constructor(state: State, parent?: Nodee, action?: (node: Nodee) => Nodee) {
this.state = state;
this.parent = parent;
this.action = action;
}
matches(node: Nodee): boolean {
return this.state.identity() == node.state.identity();
} //用于判断两个节点是否是同一个节点
}
所以我们的四个Action就对应着如下四个函数:
export const up = (node: Nodee): Nodee => {
const s = new State(node.state);
s.move(-1, 0); //将节点中空白上移
return new Nodee(s, node, up); //返回新的状态
}
export const down = (node: Nodee): Nodee => {
const s = new State(node.state);
s.move(1, 0); //将节点中空白下移
return new Nodee(s, node, down); //返回新的状态
}
export const left = (node: Nodee): Nodee => {
const s = new State(node.state);
s.move(0, -1); //将节点中空白左移
return new Nodee(s, node, left); //返回新的状态
}
export const right = (node: Nodee): Nodee => {
const s = new State(node.state);
s.move(0, 1); //将节点中空白右移
return new Nodee(s, node, right); //返回新的状态
}
State中的move函数如下:
move(x: number, y: number) {
const { i, j } = this.spacePosition(); //找到当前空白位置在哪
this._move(i, j, x, y); //将空白位置移动
}
spacePosition(): { i: number, j: number } {
for (let i = 0; i < this.ROW; i++) {
for (let j = 0; j < this.COL; j++) {
if (this._state[i][j] == 24) {
return { i, j };
}
}
}
return { i: this.COL - 1, j: this.ROW - 1 };
}
_move(i: number, j: number, x: number, y: number) {
const newI = i + x;
const newJ = j + y;
if (newI >= this.ROW || newJ >= this.COL) {
throw Error("unable to move outside of the puzzle");
}
//交换新旧位置的值
const val = this._state[i][j];
this._state[i][j] = this._state[newI][newJ];
this._state[newI][newJ] = val;
}
我们可以编写一个actions函数来判断一个具体的节点可以采取的action, 比如边缘的地方只有3个action,角落的地方只有两个。
export const actions = (node: Nodee): ((node: Nodee) => Nodee)[] => {
const { i, j } = node.state.spacePosition();
const actions: ((node: Nodee) => Nodee)[] = [];
if (i != 0) {
actions.push(up);
}
if (i != node.state.COL - 1) {
actions.push(down);
}
if (j != 0) {
actions.push(left);
}
if (j != node.state.ROW - 1) {
actions.push(right);
}
return actions;
}
针对每个特定的状态,通过调用actions方法得到最多4个action,根据这些action,能够衍生出最多4个新的状态,因此存在一颗巨大的状态树(下图是其中一部分):
因此我们的目标就是搜索这棵树,从一个状态出发,沿着所有的分支,找到一个最终状态(Terminal State),这个最终状态就是打乱前的状态,也就复原了拼图。
深度优先搜索(Depth-first Search)
几乎所有程序员都知道深度优先搜索,沿着子树,一直搜索到最深,一直到叶子节点或者找到最终状态为止,然后再搜索下一棵子树。
通过一个栈(先进后出的结构),我们可以实现深度优先搜索。具体流程如下:
注意这里把遍历过的记录在了一个叫做Explored的节点中,主要为了防止遍历的时候走回头路。
广度优先搜索(Breadth-first Search)
广度优先搜索,按照当前节点的所有子节点进行遍历,然后再遍历对应子节点的子节点。
采用队列(先进先出的结构),我们可以实现广度优先搜索。具体流程如下:
算法优化
一个5x5的拼图,总共有1.55 * 1025种状态,如果采用朴素的深度优先搜索或者广度优先搜索,运气不好的话,可能需要遍历所有状态才能找得到最终状态,明显不可取。
那应该如何优化呢?不管是深度优先还是广度优先,搜索的方向是盲目的,但其实没有必要盲目搜索,如果在遍历的时候,我们能知道接下来要遍历的状态的一些信息,借助这些信息我们可以选择方向,而不是盲目的向前。
启发函数
启发函数是一种有效的提供信息的手段,算法优化的程度很多时候也取决于启发函数的选择,在拼图游戏中,我们可以选择一个状态到最终状态的曼哈顿距离(Manhattan Distance) 作为启发函数。比如下图中,我们知道24代表的空白在最终状态中,应该在二维数组的最后一个元素上,那这个单一的图块的曼哈顿距离就是二维坐标中的横轴加竖轴的距离:
整个状态到最终状态的曼哈顿距离就是当前状态状态中所有位置的曼哈顿距离之和。在State类中,我们可以通过如下函数返回一个状态的曼哈顿距离:
manhattanDistance(): number {
let distance = 0;
for (let i = 0; i < this.ROW; i++) {
for (let j = 0; j < this.COL; j++) {
if (this._state[i][j] !== this.perfect[i][j]) { //perfect为最终状态
const val = this._state[i][j];
const correctI = val / this.ROW;
const correctJ = val % this.COL;
distance += Math.abs(correctI - i) + Math.abs(correctJ - j);
}
}
}
return distance;
}
贪婪最佳优先算法(Greddy best-first Search)
贪婪最佳优先算法就是在之前深度优先或者广度优先搜索的时候,采用贪心策略,每次都朝着启发函数最小的方向搜索,如果我们采用的曼哈顿距离作为启发函数,那就是每次都朝曼哈顿距离最小的状态搜索,因为曼哈顿距离越小,意味着更接近于最终状态,离最终状态越近的概率也就越高。
实现贪婪优先最佳,我们可以使用优先队列(有序的队列)来实现,其具体工作流程如下:
为了方便代码实现,我们先文字描述版算法(也可用伪代码表述):
- 初始状态:将起点节点放入到优先队列(启发函数值越小的在队首)中。
- 如果队列不为空,从队首取一个状态,为空则返回结束搜索。
- 检查是否是最终状态,如果是则返回,如果不是则到步骤4。
- 把当前状态的所有子状态放到队列中。
- 重复2-4。
代码实现如下:
const expored: string[] = []; //用于存放已经遍历过的,防止走回头路
const frontier: Nodee[] = []; //用于优先队列
frontier.push(node); //初始的时候放入当前节点(打乱后的节点)
await resolve(frontier, expored); //调用算法流程
const resolve = async (frontier: Nodee[], expored: string[]) => {
while (frontier.length > 0) { //确保队列不为空
frontier.sort((n1, n2) => {
return n2.hueristic() - n1.hueristic();
}); //按启发函数排序,模拟优先队列
const fNode = frontier.pop(); //取除队首节点
if (fNode == undefined) {
return;
}
if (fNode.state.isTerminal()) { //判断是否是最终状态
await delay(10); //为了界面显示可以看到动画,与算法无关
setNode(fNode); //为了界面显示可以看到动画,记录结果node,也可以返回。
return;
}
expored.push(fNode.state.identity()); //放入explored集合中,仅存id
await delay(10); //为了界面显示可以看到动画,与算法无关
setNode(fNode); //为了界面显示可以看到动画,与算法无关
const nodes = actions(fNode).map(action => {
return action(fNode);
}); //获取当前节点的子节点,当前节点的所有action调用当前节点后生成的新状态
nodes.forEach(n => {
if (!expored.includes(n.state.identity()) //没有遍历过
//且当前不在队列中
&& !frontier.map(f => { return f.state.identity() }).includes(n.state.identity())) {
frontier.push(n); //放入队列
}
});
}
}
A星算法(A* Search)
采用贪婪最佳优先算法并不能保证在所有情况下都能有最佳的效果,有时候,需要引入另外一个评估函数来参与到启发的过程中,比如评估函数可以是一个状态距离当前搜索状态在树中的距离,具体实现可以在放入队列的时候将距离递增进行记录。通过评估函数和启发函数一起参与计算,能够进一步提高搜索的方向的准确性。
小结
找到最终状态的节点后,通过节点中记录的parent,我们可以进行爬树,从而确定一条最优路径。
下棋
拼图游戏中,只有一个参与方,像下棋这样的对弈游戏里面,通常有多个参与方,一个参与方尝试打败另一个参与方,这种情况下,搜索算法怎么帮助计算机下棋呢?
我们通过一个简单的下棋游戏来学习,这个游戏的名字叫做TicTakToe,游戏规则很简单,两个选手,交替出棋,谁棋连成三个谁就赢。
先看一段VCR:
我们要做什么
我们通过制作一个简易的TicTakToe游戏,人机博弈,计算机通过搜索算法来跟我们下棋,从而学习另外一种搜索算法,叫做Minimax算法,用于这种有两个参与方,而且互相影响的情况下。
学会了拼图游戏后,对搜索算法有了基本认识了,进一步往下阅读前,自己尝试先利用上面的思维框架,尝试自己编写一下,至少尝试去定义其中的相关概念,比如状态、动作等。
基本概念定义
状态(State)
同样的,我们需要将棋盘的状态编码表示出来(我们定义一个选手为1=MAX, 另一个选手为-1=MIN,未落子的地方为0):
所以一个棋盘也可以通过二维数组来表示:
export const MAX = 1;
export const MIN = -1;
export const ROWS = 3;
export const COLS = 3;
export default class State {
private board: number[][] = []; //表示棋盘的二维数组
constructor() {
//初始化为0
for (let i = 0; i < ROWS; i++) {
const arr = [];
for (let j = 0; j < COLS; j++) {
arr.push(0);
}
this.board.push(arr);
}
}
}
拼图游戏中,我们有一个最终状态,来决定搜索是否结束,同样的,我们需要定义一个最终状态,或者叫做结束状态,这里结束状态就是任意一个选手连成了3个,或者棋盘被填满了打成平局:
terminal(): boolean {
//检查行是否连成一条线
for (let i = 0; i < ROWS; i++) {
let res = this.allNonZeroEq(this.board[i]); //判断是否全相等(且不为0)
if (res) {
return res;
}
}
//检查列是连成一条线
for (let j = 0; j < COLS; j++) {
const a = [];
for (let i = 0; i < ROWS; i++) {
a.push(this.board[i][j]);
}
let res = this.allNonZeroEq(a);
if (res) {
return res;
}
}
//检查斜线(左)是否连成一条线
const leftCross = [];
for (let i = 0, j = 0; i < ROWS && j < COLS; i++, j++) {
leftCross.push(this.board[i][j]);
}
if (this.allNonZeroEq(leftCross)) {
return true;
}
//检查斜线(右)是否连成一条线
const rightCross = [];
for (let i = 0, j = COLS - 1; i < ROWS && j >= 0; i++, j--) {
rightCross.push(this.board[i][j]);
}
if (this.allNonZeroEq(rightCross)) {
return true;
}
//检查棋盘是否填满
return this.count(0) == 0;
}
//判断是否全相等(且不为0)
allNonZeroEq(arr: number[]): boolean {
let res = arr[0] == arr[1];
for (let i = 2; i < arr.length; i++) {
res = res && arr[i] == arr[i - 1];
}
if (!res) {
return res;
}
return res && arr[0] != 0;
}
知道最终状态后,我们还需要知道谁赢了,所以我们需要给状态打分,如果先连成一条线的是1=MAX,则状态为1分,如果-1=MIN先连成一条线,则状态为-1分,如果平局,则状态为0分。对于非最终状态的额分数如何计算,稍后我们会涉及。
所以MAX选手的目的就是让棋盘的分数最大化,MIN选手的目标就是让棋盘分数最小化,这也是Minimax算法名字的由来。
计算分数的函数如下:
score(): number {
for (let i = 0; i < ROWS; i++) {
let res = this.allNonZeroEq(this.board[i]);
if (res) {
return this.board[i][0]; //返回连续行的第一个值
}
}
for (let j = 0; j < COLS; j++) {
const a = [];
for (let i = 0; i < ROWS; i++) {
a.push(this.board[i][j]);
}
let res = this.allNonZeroEq(a);
if (res) {
return a[0]; //返回连续列的第一个值
}
}
const leftCross = [];
for (let i = 0, j = 0; i < ROWS && j < COLS; i++, j++) {
leftCross.push(this.board[i][j]);
}
if (this.allNonZeroEq(leftCross)) {
return leftCross[0]; //返回连续斜线(左)的第一个值
}
const rightCross = [];
for (let i = 0, j = COLS - 1; i < ROWS && j >= 0; i++, j--) {
rightCross.push(this.board[i][j]);
}
if (this.allNonZeroEq(rightCross)) {
return rightCross[0]; //返回连续斜线(右)的第一个值
}
return 0; //其他情况分数为0
}
动作(Action)
棋盘中选手落子的动作会产生新的状态:
所以动作的定义就是落子动作,落子有两个关键因素,位置和谁落的子(选手代表的值,1还是-1)。所以可以用如下结构:
export default class Action {
i: number;
j: number; //i,j代表位置
player: number; //player代表1还是-1
constructor(x: number, y: number, p: number) {
this.i = x;
this.j = y;
this.player = p;
}
}
类似的,给定一个状态,我们可以通过一个actions函数获取到所有可能得状态:
export function actions(state: State): Action[] {
const p = player(state); //判断下一步该谁出棋
return state.empties() //获取所有空白地方
.map(e => new Action(e[0], e[1], p)); //每个空白地方都可以是下一步的Action
}
export function player(state: State): number {
return state.whosTurn(); //获取谁是下一步
}
State {
empties(): number[][] {
let es = [];
for (let i = 0; i < ROWS; i++) {
for (let j = 0; j < COLS; j++) {
if (this.board[i][j] == 0) {
es.push([i, j]);
}
}
}
return es;
}
whosTurn(): number {
let countMin = 0;
let countMax = 0;
for (let i = 0; i < ROWS; i++) {
for (let j = 0; j < COLS; j++) {
if (this.board[i][j] == MIN) {
countMin++;
}
if (this.board[i][j] == MAX) {
countMax++;
}
}
}
return countMin < countMax ? MIN : MAX;
}
}
Minimax算法
我们采用一种稍微不一样的方式来描绘状态树:
用向上的三角形表示当前是MaxPlayer, 向下的箭头表示当前是Min Player,叶子节点(最终状态)上会有每个状态的分数,对于下三角形(Min Player)则选择最小分数作为自己的分数,上三角形(Max Player)则选择最大分数作为自己的分数,这样对于每一个未结束的状态也有一个分数。
因此,在每一步的时候,不同选手就有了不同的搜索方向,也就知道了下一步落子的位置(保证自己赢的概率最大)。
对于棋盘比较大的时候,一个棋盘的状态树也是天文数字,所以不可能无限制的搜索到叶子层,因此需要限制树的深度:
限制深度带来的问题是,我们没法通过叶子节点递归算出非叶子节点的分数,所以这时候我们需要引入类似A星算法的评估函数,评估函数评估当前谁更接近连成一条线来确定分数。
//State中
hueristicScore(): number {
const hMax = this.maxCountsPattern(MAX); //计算最大连续的MAX个数
const hMin = this.maxCountsPattern(MIN); //计算最大连续的MIN个数
if (hMax == hMin) {
return 0;
}
return hMax > hMin ? 1 : -1; //看谁大
}
maxCountsPattern(target: number) {
let max = 0;
for (let i = 0; i < ROWS; i++) {
let res = this.countsIn(this.board[i], target);
if (res > max) {
max = res;
}
}
for (let j = 0; j < COLS; j++) {
const a = [];
for (let i = 0; i < ROWS; i++) {
a.push(this.board[i][j]);
}
let res = this.countsIn(a, target);
if (res > max) {
max = res;
}
}
const leftCross = [];
for (let i = 0, j = 0; i < ROWS && j < COLS; i++, j++) {
leftCross.push(this.board[i][j]);
}
const mL = this.countsIn(leftCross, target);
if (mL > max) {
max = mL;
}
const rightCross = [];
for (let i = 0, j = COLS - 1; i < ROWS && j >= 0; i++, j--) {
rightCross.push(this.board[i][j]);
}
const mR = this.countsIn(rightCross, target);
if (mR > max) {
max = mR;
}
return max;
}
所以Minimax的实现如下:
export function min(state: State): Action { //用于min player
const acts = actions(state); //获取当前棋盘能走的所有Action
const depth = 4; //指定一个搜索深度
let act = acts[0];
const s = state.clone();
let v = maxValue(result(s, act), depth); //下一个是MAX,选他的最大值里面的最小值,result函数根据当前状态,应员工一个action,生成一个新的状态
for (let i = 1; i < acts.length; i++) {
const s = state.clone();
const m = maxValue(result(s, acts[i]), depth);
if (m < v) { //寻找最小值
v = m;
act = acts[i];
}
}
return act;
}
export function max(state: State): Action { //用于max player
const acts = actions(state); //获取当前棋盘能走的所有Action
const depth = 4; //指定一个搜索深度
let act = acts[0];
const s = state.clone();
let v = minValue(result(s, act), depth); //下一个是MIN,选他的最小值里面的最大值,result函数根据当前状态,应员工一个action,生成一个新的状态
for (let i = 1; i < acts.length; i++) {
const s = state.clone();
const m = minValue(result(s, acts[i]), depth);
if (m > v) { //寻找最大值
v = m;
act = acts[i];
}
}
return act;
}
export function maxValue(state: State, depth: number): number {
if (state.terminal()) {
return state.score();
}
if (depth == 0) {
return state.hueristicScore();
}
let v = -100;
const acts = actions(state);
for (const act of acts) {
const s = state.clone();
const m = minValue(result(s, act), depth - 1);
if (m > v) {
v = m;
}
}
return v;
}
export function minValue(state: State, depth: number): number {
if (state.terminal()) {
return state.score();
}
if (depth == 0) {
return state.hueristicScore();
}
let v = 100;
const acts = actions(state);
for (const act of acts) {
const s = state.clone();
const m = maxValue(result(s, act), depth - 1);
if (m < v) {
v = m;
}
}
return v;
}
export function result(state: State, action: Action): State {
const newState = state.clone();
newState.play(action.i, action.j, action.player);
return newState;
}
总结
通过两个简单的游戏实现,我们学习了搜索算法的基础,搜索算法被大量应用在AI的各个子领域,当然不是AI的地方搜索算法也有用武之地。
文中出于对算法工作过程理解的目的,没有考虑其他方面的优化,比如空间复杂度等等。
下一篇文章我们按照这样的方式学习机器如何产生进行逻辑推理。(先给自己立个flag),文中不足之处,欢迎在评论区指正和探讨。
转载自:https://juejin.cn/post/7357739240753086504