状态空间搜索法是谁发明的_记忆化搜索和动态规划

(46) 2024-08-27 17:01:01

文章目录

        • 1. 前言
        • 2. 问题举例(九宫格问题)
        • 3. 问题分析
          • 3.1 状态编码与解码
          • 3.2 哈希映射
          • 3.3 集合判重
        • 4. 问题实现
        • 推荐阅读
1. 前言

之前介绍的回溯法常用于 解空间的搜索 问题,即 找到一个或者所有满足约束条件的解,它通常是将解空间组织成树或者图,然后进行DFS(深度优先遍历)并注意在搜索的时候进行 剪枝操作。

但是状态空间搜索则是需要 找到一条从起始状态到终止状态的路径,其一般需要考虑一下问题:

  • 状态的表示,即我们怎样表示一个状态。
  • 状态的转移,即通过研究初始状态和目标状态的 差别 ,我们定义怎样的操作来进行状态的转移。
  • 状态的压缩和记忆化搜索,即我们如何压缩一个状态的表示,使得 我们能够存储已经搜索过的状态的结果。这样能够避免大量重复状态的搜索。
2. 问题举例(九宫格问题)

下面用经典的搜索问题 八数码(九宫格问题) 举例,参考算法竞赛经典入门 第二版。
状态空间搜索法是谁发明的_记忆化搜索和动态规划 (https://mushiming.com/)  第1张
OJ例题:

  • POJ 1077 Eight
  • HDUOJ 1043 需要离线打表
3. 问题分析

题中所说要找到移动步数最少的路径,即我们需要从起始状态到目标状态进行BDS(广度优先搜索),而我们需要考虑以下问题:

如何表示状态
显然,最直接的方式是直接用一个 3X3 的二维矩阵,简化为一个 1X9 的一维数组。

怎么进行状态压缩和记忆化搜索?
我们可以直接申请一个9维数组vis,然后根据vis[s0][s1][s2][s3][s4][s5][s6][s7][s8] 是否等于1来判重,需要的数组大小为 9 9 = 387 , 420 , 489 9 ^9 = 387,420,489 99=387,420,489 项,太多了,而且实际上最多的结点数也只有 0~8 的全排列 9 ! = 362 , 880 9! = 362,880 9!=362,880 项而已。

所以,如何进行状态压缩,常见右3种思路:

3.1 状态编码与解码

将每一种状态与一个整数编码一一对应起来,然后只开一个一维数组来判重。
而本题就是将 0 ~ 8 的排列数与 0 ~ 对应起来,常见的方式是康拓展开,即 我们将一个排列数与其在所有排列中的字典序一一对应起来,例如 0 <–> 0 , <–> 。

这种方法时间效率高,但是当状态空间的结点总数非常大时,编码也会很大,因为是一一对应的。

3.2 哈希映射

这种方法也是将状态映射成整数,但是不必是一一对应的。他可以映射到一个 [ 0 , M − 1 ] [0,M-1] [0,M1] 范围内的整数,然后开一个 M 大小的数组来存储,相同哈希值的存放在一起,例如使用 链表 连在一起,称为一个 桶(bucket)。这种方法注意三点:

  • 哈希表的大小M设置为多少,一般M越大,冲突的概率会比较小。
  • 哈希函数怎么设置,即如何将状态映射成整数。在 哈希表中,哈希函数的作用很关键,一个设置良好的哈希函数应该保证哈希值的冲突尽可能少。这里,我们的哈希函数可以直接将状态映射成一个9位的排列数,然后对M取余。
  • 冲突怎么解决,我们可以将相同哈希值的元素用链表连起来,也可以设置一个规则,如果冲突了,则向后或者向前移动几位等等。
3.3 集合判重

我们可以用一个STL中的<set>集合来存储访问过的排列数来进行判重,但是,STL底层是基于红黑树的,其插入和查找的复杂度都在 O ( l o g n ) O(logn) O(logn) 而编码和哈希在最好情况下(哈希的冲突为0)是数组的直接索引,复杂度在 O ( 1 ) O(1) O(1)。当然,使用STL的代码比较简洁,我们可以先用它来实现判重,然后再验证程序其他部分的正确性,然后转化为编码或者哈希表。

4. 问题实现

先给出这个问题的BFS的大致框架:

typedef int State[9]; // 定义状态,九宫格 const int maxn = 0x7fffff; // 最多的可能状态 State st[maxn]; // 存储状态 State goal; // 目标状态 int fa[maxn]; // 存储状态的前一状态 char pre[maxn]; // 存储前一状态变化到当前状态所用的操作 const char op[5] = "udlr"; const int dx[] = { 
    -1,1,0,0 }; const int dy[] = { 
    0,0,-1,1 }; // 四个方向,上,下,左,右 void init_lookup_table(); int try_to_insert(int s); void printState(State& s); int bfs() { 
    // 若成功,则返回目标状态在状态数组中的位置 init_lookup_table(); // 初始化查找表 int front = 1, rear = 2; while (front < rear) { 
    State& s = st[front]; // 使用“引用”指向同一片内存,节省赋值操作 //printState(s); if (memcmp(s, goal, sizeof(s)) == 0) return front; // 成功 int z; // 0 的位置,即空格 for (z = 0; s[z] != 0; z++); int x = z / 3, y = z % 3; for (int i = 0; i < 4; i++) { 
    int newx = x + dx[i]; int newy = y + dy[i]; int newz = newx * 3 + newy; if (newx >= 0 && newx < 3 && newy >= 0 && newy < 3) { 
    State& t = st[rear]; // 新状态 memcpy(t, s, sizeof(s)); t[newz] = s[z]; t[z] = s[newz]; fa[rear] = front; pre[rear] = op[i]; if (try_to_insert(rear)) rear++; // 此状态没有出现过 }// if }// for front++; } return 0; } 

其中,init_look_table() 和 try_insert() 就是我们的判重操作,即初始化查找表和判断该状态是否已经搜索过。也就是我们上面所说的3种判重方式:

集合判重

set<int> vis; void init_lookup_table() { 
    vis.clear(); } int try_to_insert(int s) { 
    // 试图插入一个状态 State& ma = st[s]; int num = 0; // 转换为一个9位数 for (int i = 0; i < 9; i++) num = num * 10 + ma[i]; if (vis.count(num)) return 0; else { 
    vis.insert(num); return 1; } } 

简单但是效率低。

哈希表

const int hashsize = 1e+6 + 3; // 哈希表的大小 int head[hashsize], Next[maxn]; // 哈希链表 void init_lookup_table() { 
    memset(head, 0, sizeof(head)); } int hashfunc(State& s) { 
    // 一个状态的哈希函数 int num = 0; for (int i = 0; i < 9; i++) num = num * 10 + s[i]; return num % hashsize; } int try_to_insert(int s) { 
    // 试图插入一个状态 State& ma = st[s]; int h = hashfunc(ma); int u = head[h]; // 查找状态 while (u) { 
    if (memcmp(ma, st[u], sizeof(ma)) == 0) return 0; // 已经存在了 u = Next[u]; } // 头插法插入结点 Next[s] = head[h]; head[h] = s; return 1; } void printState(State& s) { 
    for (int i = 0; i < 9; i++) { 
    if(s[i]) printf("%d", s[i]); else printf("X"); if ((i + 1) % 3 == 0) printf("\n"); else printf(" "); } } 

编码解码

int vis[], fact[9]; // 判重数组和阶乘 void init_lookup_table() { 
    memset(vis, 0, sizeof(vis)); fact[0] = 1; for (int i = 1; i < 9; i++) fact[i] = fact[i - 1] * i; } int canto(State& s) { 
    // 将一个状态转成康拓编码 int code = 0; for (int i = 0; i < 9; i++) { 
    int cnt = 0; // 计算逆序数 for (int j = i + 1; j < 9; j++) if (s[j] < s[i]) cnt++; code += cnt * fact[8 - i]; } return code; } int try_to_insert(int s) { 
    // 试图插入一个状态 int code = canto(st[s]); if (vis[code]) return 0; else return vis[code] = 1; } 

示例AC代码

/* 八数码问题 BFS中状态空间搜索 */ #include<cstring> #include<cstdio> #include<vector> #include<set> #include<iostream> using namespace std; typedef int State[9]; // 定义状态,九宫格 const int maxn = 0x7fffff; // 最多的可能状态 const int hashsize = 1e+6 + 3; // 哈希表的大小 int head[hashsize], Next[maxn]; // 哈希链表 State st[maxn]; // 存储状态 State goal; // 目标状态 int fa[maxn]; // 存储状态的前一状态 char pre[maxn]; // 存储前一状态变化到当前状态所用的操作 const char op[5] = "udlr"; const int dx[] = { 
    -1,1,0,0 }; const int dy[] = { 
    0,0,-1,1 }; // 四个方向,上,下,左,右 void init_lookup_table() { 
    memset(head, 0, sizeof(head)); } int hashfunc(State& s) { 
    // 一个状态的哈希函数 int num = 0; for (int i = 0; i < 9; i++) num = num * 10 + s[i]; return num % hashsize; } int try_to_insert(int s) { 
    // 试图插入一个状态 State& ma = st[s]; int h = hashfunc(ma); int u = head[h]; // 查找状态 while (u) { 
    if (memcmp(ma, st[u], sizeof(ma)) == 0) return 0; // 已经存在了 u = Next[u]; } // 头插法插入结点 Next[s] = head[h]; head[h] = s; return 1; } void printState(State& s) { 
    for (int i = 0; i < 9; i++) { 
    if(s[i]) printf("%d", s[i]); else printf("X"); if ((i + 1) % 3 == 0) printf("\n"); else printf(" "); } } int bfs() { 
    // 若成功,则返回目标状态在状态数组中的位置 init_lookup_table(); // 初始化查找表 int front = 1, rear = 2; while (front < rear) { 
    State& s = st[front]; // 使用“引用”指向同一片内存,节省赋值操作 //printState(s); if (memcmp(s, goal, sizeof(s)) == 0) return front; // 成功 int z; // 0 的位置,即空格 for (z = 0; s[z] != 0; z++); int x = z / 3, y = z % 3; for (int i = 0; i < 4; i++) { 
    int newx = x + dx[i]; int newy = y + dy[i]; int newz = newx * 3 + newy; if (newx >= 0 && newx < 3 && newy >= 0 && newy < 3) { 
    State& t = st[rear]; // 新状态 memcpy(t, s, sizeof(s)); t[newz] = s[z]; t[z] = s[newz]; fa[rear] = front; pre[rear] = op[i]; if (try_to_insert(rear)) rear++; // 此状态没有出现过 }// if }// for front++; } return 0; } void printPath(int s) { 
    // 打印路径 if (s == 1) return; printPath(fa[s]); printf("%c", pre[s]); } int main() { 
    char c; for (int i = 0; i < 9; i++) { 
   // 初始状态 cin >> c; if (c == 'x') st[1][i] = 0; else st[1][i] = c - '0'; } for (int i = 0; i < 9; i++) goal[i] = i + 1; goal[8] = 0; // 目标状态 int ans = bfs(); if (ans == 0) printf("unsolvable\n"); else { 
    printPath(ans); printf("\n"); } return 0; } /*2 3 4 1 5 x 7 6 8*/ 
推荐阅读
  • 《算法竞赛入门经典 》7.5 路径搜索问题
  • 八数码的八种境界
THE END

发表回复