回溯算法的理解

定义

在程序设计中,有相当一类求一组解,或求全部解或求最优解的问题,例如读者熟悉的八皇后问题,不是根据某种特定的计算法则,而是利用试探和回溯的搜索技术求解。回溯法也是设计递归过程的一种重要方法,它的求解过程实质上是一个先序遍历一棵"状态树"的过程,只是这棵树不是遍历前预先建立的,而是隐含在遍历过程中。

---数据结构》(严蔚敏)

怎么理解这段话呢?

首先,某种问题的解我们很难去找规律计算出来,没有公式可循,只能列出所有可能的解,然后一个个检查每个解是否符合我们要找的条件,也就是通常说的遍历。而解空间很多是树型的,就是树的遍历。

其次,树的先序遍历,也就是根是先被检查的,二叉树的先序遍历是根,左子树,右子树的顺序被输出。如果把树看做一种特殊的图的话,DFS就是先序遍历。所以,回溯和DFS是联系非常紧密的,可以认为回溯是DFS的一种应用场景。另外,DFS有个好处,它只存储深度,不存储广度。所以空间复杂度较小,而时间复杂度较大。

最后,某些解空间是非常大的,可以认为是一个非常庞大的树,此时完全遍历的时间复杂度是难以忍受的。此时可以在遍历的同时检查一些条件,当遍历某分支的时候,若发现条件不满足,则退回到根节点进入下一个分支的遍历。这就是“回溯”这个词的来源。而根据条件有选择的遍历,叫做剪枝或分枝定界。

DFS

首先看DFS,下面是算法导论上DFS的伪代码,值得一行行的去品味。需要注意染色的过程,因为图有可能是有环的,所以需要记录那些节点被访问过了,那些没有,而树的遍历是没有染色过程的。而且它用 π[m]来记录m的父节点,也就可以记录DFS时的路径。

DFS(G)
1  for each vertex u ∈ V [G]
2       do color[u] ← WHITE
3          π[u] ← NIL
4  time ← 0
5  for each vertex u ∈ V [G]
6       do if color[u] = WHITE
7             then DFS-VISIT(u)
DFS-VISIT(u)
1  color[u] ← GRAY
2  time ← time +1
3  d[u] <-time
4  for each v ∈ Adj[u]
5       do if color[v] = WHITE
6             then π[v] ← u
7                        DFS-VISIT(v)
8  color[u] <-BLACK

例子

例一求幂集问题,就是返回一个集合所有的子集。为什么叫幂集呢?因为一个集合有n个元素,那么它的所有的子集数是2^n个。比如[1,2,3]的子集是[],[1],[2],[3],[1,2],[1,3],[2,3],[1,2,3]。

也就是下面这棵树的叶子节点:

回溯算法的理解

那问题就变成了如何输出一棵树的叶子节点。那就需要知道现在到底遍历到哪一层了。方法有很多,可以用全局变量记录,也可以用递归函数的参数记录。

A)这里是用全局变量记录,在进入函数的时候level++,退出函数的时候level--

[cpp] view plain copy
 print?
  1. int level=0;  
  2. vector<vector<int> > result;  
  3. vector<int> temp;  
  4. void dfs(vector<int>& S){  
  5.     level++;  
  6.     if(level>S.size()){  
  7.         result.push_back(temp);  
  8.         level--;  
  9.         return;  
  10.     }  
  11.     temp.push_back(S[level-1]);  
  12.     dfs(S);  
  13.     temp.pop_back();  
  14.     dfs(S);  
  15.     level--;  
  16.     return;  
  17. }  
  18. vector<vector<int> > subsets(vector<int>& S){  
  19.     sort(S.begin(),S.end());  
  20.     dfs(S);  
  21.     reverse(result.begin(),result.end());  
  22.     return result;  
  23. }  

B)这里记录层数用的是函数参数

[cpp] view plain copy
 print?
  1. vector<vector<int> > result;  
  2. vector<int> temp;  
  3. void dfs(vector<int>& S, int i){  
  4.     if(i==S.size()){  
  5.         result.push_back(temp);  
  6.         return;  
  7.     }  
  8.     temp.push_back(S[i]);  
  9.     dfs(S,i+1);  
  10.     temp.pop_back();  
  11.     dfs(S,i+1);  
  12.     return;  
  13. }  
  14. vector<vector<int> > subsets(vector<int>& S){  
  15.     dfs(S,0);  
  16.     reverse(result.begin(),result.end());  
  17.     return result;  
  18. }  

总结一下,伪代码就是:

void dfs(层数){

if(条件){

    输出;

}

else{

    左子树的处理;

    dfs(层数+1);

    右子树的处理;

    dfs(层数+1);

}

}

例二:皇后问题,比如8*8的棋盘,能摆放多少个皇后呢?国际象棋规则,皇后在同一行,同一列,同一斜线均可互相攻击。

伪代码如下:

int a[n];
void try(int i)
{
    if(i==n){
        输出结果;
         }
         else
         {
                   for(j = 下界; j <= 上界; j=j+1)  // 枚举i所有可能的路径
                   {
                            if(fun(j))                // 满足限界函数和约束条件
                            {
                                     a[i] = 1;
                                     ...                        // 其他操作
                                     try(i+1);
                                     a[j] = 0;
                            }
                   }
         }
 }

根据伪代码,写出最关键的一段代码如下。其中vector<vector<int> > m是全局变量,用来记录遍历轨迹,遍历前设上值,遍历后去掉。每一次调到output的时候,所有压入栈中的函数返回,都会调到m[level][i]=0;

[cpp] view plain copy
 print?
  1. void dfs(int level){    
  2.     if(level==N){    
  3.         output();    
  4.     }    
  5.     else{    
  6.         for(int i=0;i<N;i++){    
  7.             if(check(level+1,i+1)){    
  8.                 m[level][i]=1;    
  9.                 dfs(level+1);    
  10.                 m[level][i]=0;    
  11.             }    
  12.         }    
  13.     }    
  14. }    

完整代码:

[cpp] view plain copy
 print?
  1. int N;    
  2. vector<vector<int> > m;    
  3. vector<vector<string> > result;    
  4. bool check(int row,int column){    
  5.             if(row==1) return true;    
  6.             int i,j;    
  7.             for(i=0;i<=row-2;i++){    
  8.                 if(m[i][column-1]==1) return false;    
  9.             }    
  10.             i = row-2;    
  11.             j = i-(row-column);    
  12.             while(i>=0&&j>=0){    
  13.                 if(m[i][j]==1) return false;    
  14.                 i--;    
  15.                 j--;    
  16.             }    
  17.             i = row-2;    
  18.             j = row+column-i-2;    
  19.             while(i>=0&&j<=N-1){    
  20.                 if(m[i][j]==1) return false;    
  21.                 i--;    
  22.                 j++;    
  23.             }    
  24.             return true;    
  25.         }    
  26. void output()    
  27. {    
  28.     vector<string> vec;    
  29.     for(int i=0;i<N;i++){    
  30.         string s;    
  31.         for(int j=0;j<N;j++){    
  32.             if(m[i][j]==1)    
  33.                 s.push_back('Q');    
  34.             else    
  35.                 s.push_back('.');    
  36.         }    
  37.         vec.push_back(s);    
  38.     }    
  39.     result.push_back(vec);    
  40. }    
  41. void dfs(int level){    
  42.     if(level==N){    
  43.         output();    
  44.     }    
  45.     else{    
  46.         for(int i=0;i<N;i++){    
  47.             if(check(level+1,i+1)){    
  48.                 m[level][i]=1;    
  49.                 dfs(level+1);    
  50.                 m[level][i]=0;    
  51.             }    
  52.         }    
  53.     }    
  54. }    
  55. vector<vector<string> > solveNQueens(int n) {    
  56.     N=n;    
  57.     for(int i=0;i<n;i++){    
  58.         vector<int> a(n,0);    
  59.         m.push_back(a);    
  60.     }    
  61.     dfs(0);    
  62.     return result;    
  63. }  

例三:数独问题,就是给出一个数独,解决它。

比如给出:

回溯算法的理解

求解:

回溯算法的理解

解空间是这样的:

回溯算法的理解

由于数独都是9*9的,所以解空间有81层,每层有9个分支,我们做的就是遍历这个解空间。

如果只求一个解,那我们可以在得到解之后返回,而标记是否得到解可以用全局变量或返回值来做,

用全局变量的话,代码如下:

[cpp] view plain copy
 print?
  1. bool flag= false;  
  2. bool check(int k, vector<vector<char> > &board){  
  3.         int x=k/9;  
  4.         int y=k%9;  
  5.         for (int i = 0; i < 9; i++)  
  6.             if (i != x && board[i][y] == board[x][y])  
  7.                 return false;  
  8.         for (int j = 0; j < 9; j++)  
  9.             if (j != y && board[x][j] == board[x][y])  
  10.                 return false;  
  11.         for (int i = 3 * (x / 3); i < 3 * (x / 3 + 1); i++)  
  12.             for (int j = 3 * (y / 3); j < 3 * (y / 3 + 1); j++)  
  13.                 if (i != x && j != y && board[i][j] == board[x][y])  
  14.                     return false;  
  15.         return true;  
  16.     }  
  17. void dfs(int num,vector<vector<char> > &board){  
  18.     if(num==81){  
  19.         flag=true;  
  20.         return;  
  21.     }  
  22.     else{  
  23.         int x=num/9;  
  24.         int y=num%9;  
  25.         if(board[x][y]=='.'){  
  26.             for(int i=1;i<=9;i++){  
  27.                 board[x][y]=i+'0';  
  28.                 if(check(num,board)){  
  29.                     dfs(num+1,board);  
  30.                     if(flag)  
  31.                         return;  
  32.                 }  
  33.             }  
  34.             board[x][y]='.';  
  35.         }  
  36.         else{  
  37.             dfs(num+1,board);  
  38.         }  
  39.     }  
  40. }  
  41. void solveSudoku(vector<vector<char> > &board) {  
  42.     dfs(0,board);  
  43. }  
用返回值的话,关键部分做一下修改就可以了:

[cpp] view plain copy
 print?
  1. bool f(int i, vector<vector<char> > &board){  
  2.        if(i==n*m)  
  3.            return true;  
  4.        if(board[i/n][i%m]=='.'){  
  5.            for(int k=1;k<=9;k++){  
  6.                board[i/n][i%m]=k+'0';  
  7.                    if(check(i,board) && f(i+1,board))  
  8.                            return true;  
  9.            }  
  10.            board[i/n][i%m]='.';  
  11.            return false;  
  12.        }  
  13.        else  
  14.            return f(i+1,board);  
  15.    }  

要求得到所有解的话,可以在解出现的时候存下来:

[cpp] view plain copy
 print?
  1. vector<vector<vector<char> >> sum;  
  2. bool check(int k, vector<vector<char> > &board){  
  3.         int x=k/9;  
  4.         int y=k%9;  
  5.         for (int i = 0; i < 9; i++)  
  6.             if (i != x && board[i][y] == board[x][y])  
  7.                 return false;  
  8.         for (int j = 0; j < 9; j++)  
  9.             if (j != y && board[x][j] == board[x][y])  
  10.                 return false;  
  11.         for (int i = 3 * (x / 3); i < 3 * (x / 3 + 1); i++)  
  12.             for (int j = 3 * (y / 3); j < 3 * (y / 3 + 1); j++)  
  13.                 if (i != x && j != y && board[i][j] == board[x][y])  
  14.                     return false;  
  15.         return true;  
  16.     }  
  17. void dfs(int num,vector<vector<char> > &board){  
  18.     if(num==81){  
  19.         sum.push_back(board);  
  20.         return;  
  21.     }  
  22.     else{  
  23.         int x=num/9;  
  24.         int y=num%9;  
  25.         if(board[x][y]=='.'){  
  26.             for(int i=1;i<=9;i++){  
  27.                 board[x][y]=i+'0';  
  28.                 if(check(num,board)){  
  29.                     dfs(num+1,board);  
  30.                     //if(flag)  
  31.                       //  return;  
  32.                 }  
  33.             }  
  34.             board[x][y]='.';  
  35.         }  
  36.         else{  
  37.             dfs(num+1,board);  
  38.         }  
  39.     }  
  40. }  
  41. void solveSudoku(vector<vector<char> > &board) {  
  42.     dfs(0,board);  
  43. }  
  44. int main()  
  45. {  
  46.     vector<string> myboard({"...748...","7........",".2.1.9...","..7...24.",".64.1.59.",".98...3..","...8.3.2.","........6","...2759.."});  
  47.     vector<char> temp(9,'.');  
  48.     vector<vector<char> > board(9,temp);  
  49.     for(int i=0;i<myboard.size();i++){  
  50.         for(int j=0;j<myboard[i].length();j++){  
  51.             board[i][j]=myboard[i][j];  
  52.         }  
  53.     }  
  54.     solveSudoku(board);  
  55.     for(int k=0;k<sum.size();k++){  
  56.     for(int i=0;i<sum[k].size();i++){  
  57.         for(int j=0;j<sum[k][i].size();j++){  
  58.             cout<<sum[k][i][j]<<" ";  
  59.         }  
  60.         cout<<endl;  
  61.     }  
  62.     cout<<"######"<<endl;  
  63.     }  
  64.     cout<<"sum is "<<sum.size()<<endl;  
  65.     cout << "Hello world!" << endl;  
  66.     return 0;  
  67. }  

最终,我们得到了8个解。

wiki上有一张图片形象的表达了这个回溯的过程:

回溯算法的理解