之前介绍的回溯法常用于 解空间的搜索 问题,即 找到一个或者所有满足约束条件的解,它通常是将解空间组织成树或者图,然后进行DFS(深度优先遍历)并注意在搜索的时候进行 剪枝操作。
但是状态空间搜索则是需要 找到一条从起始状态到终止状态的路径,其一般需要考虑一下问题:
下面用经典的搜索问题 八数码(九宫格问题) 举例,参考算法竞赛经典入门 第二版。
OJ例题:
题中所说要找到移动步数最少的路径,即我们需要从起始状态到目标状态进行BDS(广度优先搜索),而我们需要考虑以下问题:
如何表示状态
显然,最直接的方式是直接用一个 3X3 的二维矩阵,简化为一个 1X9 的一维数组。
怎么进行状态压缩和记忆化搜索?
我们可以直接申请一个9维数组vis
,然后根据vis[s0][s1][s2][s3][s4][s5][s6][s7][s8]
是否等于1来判重,需要的数组大小为 9 9 = 387 , 420 , 489 9 ^9 = 387,420,489 99=387,420,489 项,太多了,而且实际上最多的结点数也只有 0~8 的全排列 9 ! = 362 , 880 9! = 362,880 9!=362,880 项而已。
所以,如何进行状态压缩,常见右3种思路:
将每一种状态与一个整数编码一一对应起来,然后只开一个一维数组来判重。
而本题就是将 0 ~ 8 的排列数与 0 ~ 对应起来,常见的方式是康拓展开,即 我们将一个排列数与其在所有排列中的字典序一一对应起来,例如 0 <–> 0 , <–> 。
这种方法时间效率高,但是当状态空间的结点总数非常大时,编码也会很大,因为是一一对应的。
这种方法也是将状态映射成整数,但是不必是一一对应的。他可以映射到一个 [ 0 , M − 1 ] [0,M-1] [0,M−1] 范围内的整数,然后开一个 M 大小的数组来存储,相同哈希值的存放在一起,例如使用 链表 连在一起,称为一个 桶(bucket)。这种方法注意三点:
我们可以用一个STL中的<set>
集合来存储访问过的排列数来进行判重,但是,STL底层是基于红黑树的,其插入和查找的复杂度都在 O ( l o g n ) O(logn) O(logn) 而编码和哈希在最好情况下(哈希的冲突为0)是数组的直接索引,复杂度在 O ( 1 ) O(1) O(1)。当然,使用STL的代码比较简洁,我们可以先用它来实现判重,然后再验证程序其他部分的正确性,然后转化为编码或者哈希表。
先给出这个问题的BFS的大致框架:
typedef int State[9]; // 定义状态,九宫格 const int maxn = 0x7fffff; // 最多的可能状态 State st[maxn]; // 存储状态 State goal; // 目标状态 int fa[maxn]; // 存储状态的前一状态 char pre[maxn]; // 存储前一状态变化到当前状态所用的操作 const char op[5] = "udlr"; const int dx[] = {
-1,1,0,0 }; const int dy[] = {
0,0,-1,1 }; // 四个方向,上,下,左,右 void init_lookup_table(); int try_to_insert(int s); void printState(State& s); int bfs() {
// 若成功,则返回目标状态在状态数组中的位置 init_lookup_table(); // 初始化查找表 int front = 1, rear = 2; while (front < rear) {
State& s = st[front]; // 使用“引用”指向同一片内存,节省赋值操作 //printState(s); if (memcmp(s, goal, sizeof(s)) == 0) return front; // 成功 int z; // 0 的位置,即空格 for (z = 0; s[z] != 0; z++); int x = z / 3, y = z % 3; for (int i = 0; i < 4; i++) {
int newx = x + dx[i]; int newy = y + dy[i]; int newz = newx * 3 + newy; if (newx >= 0 && newx < 3 && newy >= 0 && newy < 3) {
State& t = st[rear]; // 新状态 memcpy(t, s, sizeof(s)); t[newz] = s[z]; t[z] = s[newz]; fa[rear] = front; pre[rear] = op[i]; if (try_to_insert(rear)) rear++; // 此状态没有出现过 }// if }// for front++; } return 0; }
其中,init_look_table() 和 try_insert()
就是我们的判重操作,即初始化查找表和判断该状态是否已经搜索过。也就是我们上面所说的3种判重方式:
集合判重
set<int> vis; void init_lookup_table() {
vis.clear(); } int try_to_insert(int s) {
// 试图插入一个状态 State& ma = st[s]; int num = 0; // 转换为一个9位数 for (int i = 0; i < 9; i++) num = num * 10 + ma[i]; if (vis.count(num)) return 0; else {
vis.insert(num); return 1; } }
简单但是效率低。
哈希表
const int hashsize = 1e+6 + 3; // 哈希表的大小 int head[hashsize], Next[maxn]; // 哈希链表 void init_lookup_table() {
memset(head, 0, sizeof(head)); } int hashfunc(State& s) {
// 一个状态的哈希函数 int num = 0; for (int i = 0; i < 9; i++) num = num * 10 + s[i]; return num % hashsize; } int try_to_insert(int s) {
// 试图插入一个状态 State& ma = st[s]; int h = hashfunc(ma); int u = head[h]; // 查找状态 while (u) {
if (memcmp(ma, st[u], sizeof(ma)) == 0) return 0; // 已经存在了 u = Next[u]; } // 头插法插入结点 Next[s] = head[h]; head[h] = s; return 1; } void printState(State& s) {
for (int i = 0; i < 9; i++) {
if(s[i]) printf("%d", s[i]); else printf("X"); if ((i + 1) % 3 == 0) printf("\n"); else printf(" "); } }
编码解码
int vis[], fact[9]; // 判重数组和阶乘 void init_lookup_table() {
memset(vis, 0, sizeof(vis)); fact[0] = 1; for (int i = 1; i < 9; i++) fact[i] = fact[i - 1] * i; } int canto(State& s) {
// 将一个状态转成康拓编码 int code = 0; for (int i = 0; i < 9; i++) {
int cnt = 0; // 计算逆序数 for (int j = i + 1; j < 9; j++) if (s[j] < s[i]) cnt++; code += cnt * fact[8 - i]; } return code; } int try_to_insert(int s) {
// 试图插入一个状态 int code = canto(st[s]); if (vis[code]) return 0; else return vis[code] = 1; }
示例AC代码
/* 八数码问题 BFS中状态空间搜索 */ #include<cstring> #include<cstdio> #include<vector> #include<set> #include<iostream> using namespace std; typedef int State[9]; // 定义状态,九宫格 const int maxn = 0x7fffff; // 最多的可能状态 const int hashsize = 1e+6 + 3; // 哈希表的大小 int head[hashsize], Next[maxn]; // 哈希链表 State st[maxn]; // 存储状态 State goal; // 目标状态 int fa[maxn]; // 存储状态的前一状态 char pre[maxn]; // 存储前一状态变化到当前状态所用的操作 const char op[5] = "udlr"; const int dx[] = {
-1,1,0,0 }; const int dy[] = {
0,0,-1,1 }; // 四个方向,上,下,左,右 void init_lookup_table() {
memset(head, 0, sizeof(head)); } int hashfunc(State& s) {
// 一个状态的哈希函数 int num = 0; for (int i = 0; i < 9; i++) num = num * 10 + s[i]; return num % hashsize; } int try_to_insert(int s) {
// 试图插入一个状态 State& ma = st[s]; int h = hashfunc(ma); int u = head[h]; // 查找状态 while (u) {
if (memcmp(ma, st[u], sizeof(ma)) == 0) return 0; // 已经存在了 u = Next[u]; } // 头插法插入结点 Next[s] = head[h]; head[h] = s; return 1; } void printState(State& s) {
for (int i = 0; i < 9; i++) {
if(s[i]) printf("%d", s[i]); else printf("X"); if ((i + 1) % 3 == 0) printf("\n"); else printf(" "); } } int bfs() {
// 若成功,则返回目标状态在状态数组中的位置 init_lookup_table(); // 初始化查找表 int front = 1, rear = 2; while (front < rear) {
State& s = st[front]; // 使用“引用”指向同一片内存,节省赋值操作 //printState(s); if (memcmp(s, goal, sizeof(s)) == 0) return front; // 成功 int z; // 0 的位置,即空格 for (z = 0; s[z] != 0; z++); int x = z / 3, y = z % 3; for (int i = 0; i < 4; i++) {
int newx = x + dx[i]; int newy = y + dy[i]; int newz = newx * 3 + newy; if (newx >= 0 && newx < 3 && newy >= 0 && newy < 3) {
State& t = st[rear]; // 新状态 memcpy(t, s, sizeof(s)); t[newz] = s[z]; t[z] = s[newz]; fa[rear] = front; pre[rear] = op[i]; if (try_to_insert(rear)) rear++; // 此状态没有出现过 }// if }// for front++; } return 0; } void printPath(int s) {
// 打印路径 if (s == 1) return; printPath(fa[s]); printf("%c", pre[s]); } int main() {
char c; for (int i = 0; i < 9; i++) {
// 初始状态 cin >> c; if (c == 'x') st[1][i] = 0; else st[1][i] = c - '0'; } for (int i = 0; i < 9; i++) goal[i] = i + 1; goal[8] = 0; // 目标状态 int ans = bfs(); if (ans == 0) printf("unsolvable\n"); else {
printPath(ans); printf("\n"); } return 0; } /*2 3 4 1 5 x 7 6 8*/