匈牙利算法解决加权二分图问题

匈牙利方法是一种组合优化算法,它在多项式时间内解决了赋值问题,广泛应用于多目标跟踪的关联问题中。
匈牙利算法解决加权二分图问题
图1:(a)二分图,(b)边权重矩阵,(c)边成本的替代表示形式

动机:分配问题

假设有 nn 辆卡车每辆可装载一种产品以及 nn 家商店,这些商店愿意以矩阵 WW 代表的不同价格购买 nn 种不同的产品。分配问题:若商店 yiy_i 提出以 WijW_{ij} 美元从卡车 xix_i 购买产品,我们如何指派每辆卡车 xix_i 去商店 yjy_j,以便在所有可能的任务中最大化合并利润?

  1. 一般问题:给定 X=x1,,xnX = {x_1, \dots , x_n}Y=y1,,ynY = {y_1, \dots , y_n},矩阵 WWWij=weight(xi,yj)W_{ij} = \mathrm{weight}(x_i, y_j ) 是将 xix_i 分配给 yiy_i 的权重,找到将每个 xix_i 分配给某个 yjy_j 的匹配,使得总权重最大化。

  2. 朴素的算法为:遍历所有 n!n! 种可能的分配,选择得分最高的。

  3. 假设:i,j1,,n:Wij0\forall i, j \in {1, \dots , n} : Wij \geq 0

  4. 可以将问题视为完全加权的二分图 G=(V,E)G = (V, E)
    V=XYE=(xi,yj)xiX,yjYweight(xi,yj)=Wij \begin{aligned} &V = X \cup Y \\ &E = {(x_i, y_j )}_{x_i\in X,y_j\in Y} \\ &\mathrm{weight}(x_i, y_j) = W_{ij} \end{aligned}

  5. 分配是完美匹配:问题简化为最大化权重,找到完美匹配。

匈牙利算法的基础(Kuhn-Munkres)

定义

  1. G=(V,E)G = (V, E) 的标记是函数 l:VRl : V \rightarrow R,这样:
    (u,v)E:l(u)+l(v)weight((u,v)) \forall(u, v) \in E : l(u) + l(v) \geq \mathrm{weight}((u, v))

  2. 相等子图为子图 Gl=(V,El)G=(V,E)G_l = (V, E_l) \subseteq G = (V, E),固定标签函数 ll,则

El=(u,v)E:l(u)+l(v)=weight((u,v)) E_l = {(u, v) \in E : l(u) + l(v) = \mathrm{weight}((u, v))}

Kuhn-Munkres 定理

定理2.1:给定标记 ll,如果 MMGlG_l 上的完美匹配,则 MMGG 的最大权重匹配。

  1. 假设 M0M_0GG 中的任意完美匹配。通过定义标记函数,由于 M0M_0 是完美的,
    weight(M)=(u,v)Mweight((u,v))(u,v)Ml(u)+l(v)=vVl(v) \mathrm{weight}(M') =\sum_{(u,v)\in M'} \mathrm{weight}((u, v)) \leq \sum_{(u,v)\in M'}l(u) + l(v) = \sum_{v \in V}l(v)
  2. 这意味着:vVl(v)\sum_{v\in V}l(v)GG 的任何完美匹配 MM' 的上界。
  3. 现在看看匹配 MM 的权重:
    weight(M)=(u,v)Mweight((u,v))=(u,v)Ml(u)+l(v)=vVl(v) \mathrm{weight}(M) = \sum_{(u,v)\in M} \mathrm{weight}((u, v)) = \sum_{(u,v)\in M}l(u) + l(v)= \sum_{v \in V}l(v)
  4. 通过1. 和3. 中等式可得 GG 的所有完美匹配 MM' 满足:
    weight(M)weight(M) \mathrm{weight}(M) \geq \mathrm{weight}(M')

关键点:通过 Kuhn-Munkres 定理,找到最大权重分配的问题简化为找到正确的标记函数和相应的相等子图上的任何完美匹配。

增加匹配

给定标记 llGl=(V,El)G_l= (V,E_l)GlG_l 上的一些匹配 MM,未匹配的顶点 uV,uMu\in V, u\notin M

  1. 如果路径在 ElME_l-MMM 之间交替,并且路径的第一个和最后一个顶点在 MM 中未匹配,则该路径是 MMGlG_l 上的增广路径。记录从 uu 开始的“近似”增广路径。
  2. 如果我们能找到一个不匹配的顶点 vv,那么我们创建从 uuvv 的增广路径 α\alpha
  3. 通过将 MM 中的边替换为增广路径中 ElME_l-M 中的边来翻转匹配。
  4. 由于起始和终止顶点未匹配,这增加了匹配的大小     M>M\implies | M'|> | M |

改进标签

  1. SXS\subseteq XTYT\subseteq YSSTT “几乎”代表当前匹配 MM 与外部其他边 ElME_{l-M} 之间的增广交替路径。
  2. Let Nl(S) be the neighbors to each node in S along El. Nl(S) = {v|∀u ∈ S : (u, v) ∈ El}
    Nl(S)N_l(S) 为沿 ElE_lSS 中每个节点的邻居。
    Nl(S)={vuS:(u,v)El} N_l(S) = \{ v\mid \forall u \in S : (u, v) \in E_l\}
  3. 如果 Nl(S)=TN_l(S) = T 我们不能增加交替路径并增广,所以我们必须改进标签!
  4. 计算:δl=minuS,vT{l(u)+l(v)weight((u,v))}\delta_l = \min_{u\in S,v\notin T}\{l(u) + l(v) − \mathrm{weight}((u, v))\}
  5. 改进 lll\rightarrow l'
    l(r)={l(r)δlif rSl(r)+δlif rTl(r)otherwise  \begin{aligned} l'(r) = \begin{cases} l(r) − \delta_l &\text{if } r \in S \\ l(r) + \delta_l &\text{if } r \in T \\ l(r) &\text{otherwise } \end{cases} \end{aligned}
  6. 声明:ll' 是有效标签且 ElElE_l\subset E_{l'}
  7. 通过检验元素 uS,uS,vT,vTu \in S, u \notin S, v \in T, v \notin T 所有可能的情况证明。

Kuhn-Munkres 算法

  1. 从一些匹配 MM 开始,并且有效标记
    l::=xX,yY:l(y)=0,l(x)=maxyY(weight(x,y)) l ::= \forall x \in X, y \in Y : l(y) = 0, l(x) = \max_{y'\in Y} (\text{weight}(x, y'))
  2. 执行以下操作直到 MM 完美匹配:
    (a) 寻找增广路径
    (b) 如果不存在增广路径,则改进 lll\rightarrow l' 并转到步骤 (a)。

复杂度

  1. 每个步骤(a)或(b)每轮增加1个匹配的边,因此总轮数为 O(V=2n)=O(n)O(|V | = 2n) = O(n)
  2. 增加 MM:找到正确的顶点(如果存在)需要 O(V)O(|V |) ,翻转匹配为 O(V)O(|V |)
  3. 改进标签:找到 δl\delta_l 并更新标签的计算量为 O(V)O(|V |) 。如果没有找到增广路径,则改进标签可能发生 O(V)O(|V |) 次。所以在一轮中总共为 O(V2)O(|V |^2)
  4. O(V)O(|V|) 轮次且每次执行 O(V2)O(|V |^2),总运行时间为 O(V3)=O(n3)O(|V|^3) = O(n^3)

max_cost_assignment

HungarianAlgorithm
Augment the matching
Improve the labeling

C++SORT 使用 Dlib 中的 max_cost_assignment 函数完成关联匹配。

输入参数cost_是只读的,所以复制一份。
检查矩阵元素是否为整型。

        const_temp_matrix<EXP> cost(cost_);
        typedef typename EXP::type type;
        // This algorithm only works if the elements of the cost matrix can be reliably 
        // compared using operator==. However, comparing for equality with floating point
        // numbers is not a stable operation. So you need to use an integer cost matrix.
        COMPILE_TIME_ASSERT(std::numeric_limits<type>::is_integer);
        DLIB_ASSERT(cost.nr() == cost.nc(),
            "\t std::vector<long> max_cost_assignment(cost)"
            << "\n\t cost.nr(): " << cost.nr()
            << "\n\t cost.nc(): " << cost.nc()
            );

参考链接已打不开,算法复杂度为 O(n3)O(n^3)

        using namespace dlib::impl;
        /*
            I based the implementation of this algorithm on the description of the
            Hungarian algorithm on the following websites:
                http://www.math.uwo.ca/~mdawes/courses/344/kuhn-munkres.pdf
                http://www.topcoder.com/tc?module=Static&d1=tutorials&d2=hungarianAlgorithm
            Note that this is the fast O(n^3) version of the algorithm.
        */

        if (cost.size() == 0)
            return std::vector<long>();

lxl(x)l(x)lyl(y)l(y)xyyx表示 GlG_l 的完美匹配,即匹配 MMSX,TYS \subseteq X, T\subseteq YS,TS, T 表示匹配 MM 和其他边 ElME_l-M 之间的当前“近似”增广交替路径。

        std::vector<type> lx, ly;
        std::vector<long> xy;
        std::vector<long> yx;
        std::vector<char> S, T;
        std::vector<type> slack;
        std::vector<long> slackx;
        std::vector<long> aug_path;

初始时 M=M = \emptyset

        // Initially, nothing is matched. 
        xy.assign(cost.nc(), -1);
        yx.assign(cost.nc(), -1);
        /*
            We maintain the following invariant:
                Vertex x is matched to vertex xy[x] and
                vertex y is matched to vertex yx[y].
                A value of -1 means a vertex isn't matched to anything.  Moreover,
                x corresponds to rows of the cost matrix and y corresponds to the
                columns of the cost matrix.  So we are matching X to Y.
        */

weight(M)=(u,v)Mweight((u,v))(u,v)Ml(u)+l(v)=vVl(v) \mathrm{weight}(M&#x27;) =\sum_{(u,v)\in M&#x27;} \mathrm{weight}((u, v)) \leq \sum_{(u,v)\in M&#x27;}l(u) + l(v) = \sum_{v \in V}l(v)

        // Create an initial feasible labeling.  Moreover, in the following
        // code we will always have: 
        //     for all valid x and y:  lx[x] + ly[y] >= cost(x,y)
        lx.resize(cost.nc());
        ly.assign(cost.nc(),0);
        for (long x = 0; x < cost.nr(); ++x)
            lx[x] = max(rowm(cost,x));

逐顶点搜索。队列q用于存储 BFS 的变量。
每次搜索前重置 S,TS, T,以及slackslackx
slack存储顶点可行的标签。ST记录是否选中顶点。

        // Now grow the match set by picking edges from the equality subgraph until
        // we have a complete matching.
        for (long match_size = 0; match_size < cost.nc(); ++match_size)
        {
            std::deque<long> q;

            // Empty out the S and T sets
            S.assign(cost.nc(), false);
            T.assign(cost.nc(), false);

            // clear out old slack values
            slack.assign(cost.nc(), std::numeric_limits<type>::max());
            slackx.resize(cost.nc());
            /*
                slack and slackx are maintained such that we always
                have the following (once they get initialized by compute_slack() below):
                    - for all y:
                        - let x == slackx[y]
                        - slack[y] == lx[x] + ly[y] - cost(x,y)
            */

            aug_path.assign(cost.nc(), -1);

遍历列,尝试找到一个未匹配的顶点x,调用 compute_slack 函数。这里是不是应该使用cost.nr()

            for (long x = 0; x < cost.nc(); ++x)
            {
                // If x is not matched to anything
                if (xy[x] == -1)
                {
                    q.push_back(x);
                    S[x] = true;

                    compute_slack(x, slack, slackx, cost, lx, ly);
                    break;
                }
            }

q中取出一个x,尝试找到一个未匹配的y

            long x_start = 0;
            long y_start = 0;

            // Find an augmenting path.  
            bool found_augmenting_path = false;
            while (!found_augmenting_path)
            {
                while (q.size() > 0 && !found_augmenting_path)
                {
                    const long x = q.front();
                    q.pop_front();
                    for (long y = 0; y < cost.nc(); ++y)
                    {
                        if (cost(x,y) == lx[x] + ly[y] && !T[y])
                        {
                            // if vertex y isn't matched with anything
                            if (yx[y] == -1) 
                            {
                                y_start = y;
                                x_start = x;
                                found_augmenting_path = true;
                                break;
                            }

对于已匹配的y,沿其路继续搜寻。

                            T[y] = true;
                            q.push_back(yx[y]);

                            aug_path[yx[y]] = x;
                            S[yx[y]] = true;
                            compute_slack(yx[y], slack, slackx, cost, lx, ly);
                        }
                    }
                }

                if (found_augmenting_path)
                    break;

未找到增广路径,改进 lll\rightarrow l&#x27;
l(r)={l(r)δlif rSl(r)+δlif rTl(r)otherwise  \begin{aligned} l&#x27;(r) = \begin{cases} l(r) − \delta_l &amp;\text{if } r \in S \\ l(r) + \delta_l &amp;\text{if } r \in T \\ l(r) &amp;\text{otherwise } \end{cases} \end{aligned}

                // Since we didn't find an augmenting path we need to improve the 
                // feasible labeling stored in lx and ly.  We also need to keep the
                // slack updated accordingly.
                type delta = std::numeric_limits<type>::max();
                for (unsigned long i = 0; i < T.size(); ++i)
                {
                    if (!T[i])
                        delta = std::min(delta, slack[i]);
                }
                for (unsigned long i = 0; i < T.size(); ++i)
                {
                    if (S[i])
                        lx[i] -= delta;

                    if (T[i])
                        ly[i] += delta;
                    else
                        slack[i] -= delta;
                }

改进标签后,如果松弛量为0的边未匹配,则找到增广路径;否则,继续沿其搜索。

                q.clear();
                for (long y = 0; y < cost.nc(); ++y)
                {
                    if (!T[y] && slack[y] == 0)
                    {
                        // if vertex y isn't matched with anything
                        if (yx[y] == -1)
                        {
                            x_start = slackx[y];
                            y_start = y;
                            found_augmenting_path = true;
                            break;
                        }
                        else
                        {
                            T[y] = true;
                            if (!S[yx[y]])
                            {
                                q.push_back(yx[y]);

                                aug_path[yx[y]] = slackx[y];
                                S[yx[y]] = true;
                                compute_slack(yx[y], slack, slackx, cost, lx, ly);
                            }
                        }
                    }
                }
            } // end while (!found_augmenting_path)

沿着增广路径反转边。

            // Flip the edges along the augmenting path.  This means we will add one more
            // item to our matching.
            for (long cx = x_start, cy = y_start, ty; 
                 cx != -1; 
                 cx = aug_path[cx], cy = ty)
            {
                ty = xy[cx];
                yx[cy] = cx;
                xy[cx] = cy;
            }

        }


        return xy;

compute_slack

compute_slack 函数对slack进行更新,计算顶点x的松弛量并记录对应的y
δl=minuS,vTl(u)+l(v)weight((u,v)) \delta_l= \min_{u\in S,v \notin T }{l(u) +l(v)−weight((u,v))}

            for (long y = 0; y < cost.nc(); ++y)
            {
                if (lx[x] + ly[y] - cost(x,y) < slack[y])
                {
                    slack[y] = lx[x] + ly[y] - cost(x,y);
                    slackx[y] = x;
                }
            }

参考资料: