匈牙利算法解决加权二分图问题
匈牙利方法是一种组合优化算法,它在多项式时间内解决了赋值问题,广泛应用于多目标跟踪的关联问题中。
图1:(a)二分图,(b)边权重矩阵,(c)边成本的替代表示形式
动机:分配问题
假设有 辆卡车每辆可装载一种产品以及 家商店,这些商店愿意以矩阵 代表的不同价格购买 种不同的产品。分配问题:若商店 提出以 美元从卡车 购买产品,我们如何指派每辆卡车 去商店 ,以便在所有可能的任务中最大化合并利润?
-
一般问题:给定 ,,矩阵 中 是将 分配给 的权重,找到将每个 分配给某个 的匹配,使得总权重最大化。
-
朴素的算法为:遍历所有 种可能的分配,选择得分最高的。
-
假设:。
-
可以将问题视为完全加权的二分图 :
-
分配是完美匹配:问题简化为最大化权重,找到完美匹配。
匈牙利算法的基础(Kuhn-Munkres)
定义
-
图 的标记是函数 ,这样:
-
相等子图为子图 ,固定标签函数 ,则
Kuhn-Munkres 定理
定理2.1:给定标记 ,如果 是 上的完美匹配,则 是 的最大权重匹配。
- 假设 是 中的任意完美匹配。通过定义标记函数,由于 是完美的,
- 这意味着: 是 的任何完美匹配 的上界。
- 现在看看匹配 的权重:
- 通过1. 和3. 中等式可得 的所有完美匹配 满足:
关键点:通过 Kuhn-Munkres 定理,找到最大权重分配的问题简化为找到正确的标记函数和相应的相等子图上的任何完美匹配。
增加匹配
给定标记 ,, 上的一些匹配 ,未匹配的顶点 。
- 如果路径在 和 之间交替,并且路径的第一个和最后一个顶点在 中未匹配,则该路径是 在 上的增广路径。记录从 开始的“近似”增广路径。
- 如果我们能找到一个不匹配的顶点 ,那么我们创建从 到 的增广路径 。
- 通过将 中的边替换为增广路径中 中的边来翻转匹配。
- 由于起始和终止顶点未匹配,这增加了匹配的大小
改进标签
- , 且 、 “几乎”代表当前匹配 与外部其他边 之间的增广交替路径。
- Let Nl(S) be the neighbors to each node in S along El. Nl(S) = {v|∀u ∈ S : (u, v) ∈ El}
设 为沿 的 中每个节点的邻居。
- 如果 我们不能增加交替路径并增广,所以我们必须改进标签!
- 计算:
- 改进 :
- 声明: 是有效标签且 。
- 通过检验元素 所有可能的情况证明。
Kuhn-Munkres 算法
- 从一些匹配 开始,并且有效标记
- 执行以下操作直到 完美匹配:
(a) 寻找增广路径
(b) 如果不存在增广路径,则改进 并转到步骤 (a)。
复杂度
- 每个步骤(a)或(b)每轮增加1个匹配的边,因此总轮数为 。
- 增加 :找到正确的顶点(如果存在)需要 ,翻转匹配为 。
- 改进标签:找到 并更新标签的计算量为 。如果没有找到增广路径,则改进标签可能发生 次。所以在一轮中总共为 。
- 轮次且每次执行 ,总运行时间为 。
max_cost_assignment
C++SORT 使用 Dlib 中的 max_cost_assignment 函数完成关联匹配。
输入参数cost_
是只读的,所以复制一份。
检查矩阵元素是否为整型。
const_temp_matrix<EXP> cost(cost_);
typedef typename EXP::type type;
// This algorithm only works if the elements of the cost matrix can be reliably
// compared using operator==. However, comparing for equality with floating point
// numbers is not a stable operation. So you need to use an integer cost matrix.
COMPILE_TIME_ASSERT(std::numeric_limits<type>::is_integer);
DLIB_ASSERT(cost.nr() == cost.nc(),
"\t std::vector<long> max_cost_assignment(cost)"
<< "\n\t cost.nr(): " << cost.nr()
<< "\n\t cost.nc(): " << cost.nc()
);
参考链接已打不开,算法复杂度为 。
using namespace dlib::impl;
/*
I based the implementation of this algorithm on the description of the
Hungarian algorithm on the following websites:
http://www.math.uwo.ca/~mdawes/courses/344/kuhn-munkres.pdf
http://www.topcoder.com/tc?module=Static&d1=tutorials&d2=hungarianAlgorithm
Note that this is the fast O(n^3) version of the algorithm.
*/
if (cost.size() == 0)
return std::vector<long>();
lx
即 ,ly
即 。xy
和yx
表示 的完美匹配,即匹配 。, 表示匹配 和其他边 之间的当前“近似”增广交替路径。
std::vector<type> lx, ly;
std::vector<long> xy;
std::vector<long> yx;
std::vector<char> S, T;
std::vector<type> slack;
std::vector<long> slackx;
std::vector<long> aug_path;
初始时 。
// Initially, nothing is matched.
xy.assign(cost.nc(), -1);
yx.assign(cost.nc(), -1);
/*
We maintain the following invariant:
Vertex x is matched to vertex xy[x] and
vertex y is matched to vertex yx[y].
A value of -1 means a vertex isn't matched to anything. Moreover,
x corresponds to rows of the cost matrix and y corresponds to the
columns of the cost matrix. So we are matching X to Y.
*/
// Create an initial feasible labeling. Moreover, in the following
// code we will always have:
// for all valid x and y: lx[x] + ly[y] >= cost(x,y)
lx.resize(cost.nc());
ly.assign(cost.nc(),0);
for (long x = 0; x < cost.nr(); ++x)
lx[x] = max(rowm(cost,x));
逐顶点搜索。队列q
用于存储 BFS 的变量。
每次搜索前重置 ,以及slack
、slackx
。slack
存储顶点可行的标签。S
和T
记录是否选中顶点。
// Now grow the match set by picking edges from the equality subgraph until
// we have a complete matching.
for (long match_size = 0; match_size < cost.nc(); ++match_size)
{
std::deque<long> q;
// Empty out the S and T sets
S.assign(cost.nc(), false);
T.assign(cost.nc(), false);
// clear out old slack values
slack.assign(cost.nc(), std::numeric_limits<type>::max());
slackx.resize(cost.nc());
/*
slack and slackx are maintained such that we always
have the following (once they get initialized by compute_slack() below):
- for all y:
- let x == slackx[y]
- slack[y] == lx[x] + ly[y] - cost(x,y)
*/
aug_path.assign(cost.nc(), -1);
遍历列,尝试找到一个未匹配的顶点x
,调用 compute_slack 函数。这里是不是应该使用cost.nr()
?
for (long x = 0; x < cost.nc(); ++x)
{
// If x is not matched to anything
if (xy[x] == -1)
{
q.push_back(x);
S[x] = true;
compute_slack(x, slack, slackx, cost, lx, ly);
break;
}
}
从q
中取出一个x
,尝试找到一个未匹配的y
。
long x_start = 0;
long y_start = 0;
// Find an augmenting path.
bool found_augmenting_path = false;
while (!found_augmenting_path)
{
while (q.size() > 0 && !found_augmenting_path)
{
const long x = q.front();
q.pop_front();
for (long y = 0; y < cost.nc(); ++y)
{
if (cost(x,y) == lx[x] + ly[y] && !T[y])
{
// if vertex y isn't matched with anything
if (yx[y] == -1)
{
y_start = y;
x_start = x;
found_augmenting_path = true;
break;
}
对于已匹配的y
,沿其路继续搜寻。
T[y] = true;
q.push_back(yx[y]);
aug_path[yx[y]] = x;
S[yx[y]] = true;
compute_slack(yx[y], slack, slackx, cost, lx, ly);
}
}
}
if (found_augmenting_path)
break;
未找到增广路径,改进 :
// Since we didn't find an augmenting path we need to improve the
// feasible labeling stored in lx and ly. We also need to keep the
// slack updated accordingly.
type delta = std::numeric_limits<type>::max();
for (unsigned long i = 0; i < T.size(); ++i)
{
if (!T[i])
delta = std::min(delta, slack[i]);
}
for (unsigned long i = 0; i < T.size(); ++i)
{
if (S[i])
lx[i] -= delta;
if (T[i])
ly[i] += delta;
else
slack[i] -= delta;
}
改进标签后,如果松弛量为0的边未匹配,则找到增广路径;否则,继续沿其搜索。
q.clear();
for (long y = 0; y < cost.nc(); ++y)
{
if (!T[y] && slack[y] == 0)
{
// if vertex y isn't matched with anything
if (yx[y] == -1)
{
x_start = slackx[y];
y_start = y;
found_augmenting_path = true;
break;
}
else
{
T[y] = true;
if (!S[yx[y]])
{
q.push_back(yx[y]);
aug_path[yx[y]] = slackx[y];
S[yx[y]] = true;
compute_slack(yx[y], slack, slackx, cost, lx, ly);
}
}
}
}
} // end while (!found_augmenting_path)
沿着增广路径反转边。
// Flip the edges along the augmenting path. This means we will add one more
// item to our matching.
for (long cx = x_start, cy = y_start, ty;
cx != -1;
cx = aug_path[cx], cy = ty)
{
ty = xy[cx];
yx[cy] = cx;
xy[cx] = cy;
}
}
return xy;
compute_slack
compute_slack 函数对slack
进行更新,计算顶点x
的松弛量并记录对应的y
。
for (long y = 0; y < cost.nc(); ++y)
{
if (lx[x] + ly[y] - cost(x,y) < slack[y])
{
slack[y] = lx[x] + ly[y] - cost(x,y);
slackx[y] = x;
}
}