【字符串】AC自动机(kmp+字典树)hdu2222

AC自动机:用公共前缀来减少查询时间,减少无谓的字符串比较时间,此处的公共前缀就是kmp的next数组构造。查询一个文本串中模式串出现的次数,需要构建字典树,而这里如何将两者联系到一起呢?就是将字典树上的next数组构造变成了给每个节点找fail指针。

关键点1:字典树构造

 

【字符串】AC自动机(kmp+字典树)hdu2222

比方说我们有五个模板串:she shr say he her依次插入,根节点没有字母,标为0,插入she,从根节点出发,每个节点代表一个字母,直到e,在e这个节点cnt加一表示这里是一个字符串结束的地方,再插入shr,还是从根节点出发,找找看有没有s,好,找到了s,那么顺着它走下去,再找h又有,那么再顺着往下走,接着发现没有r,那就新开一个节点r,shr走完,在r这个地方cnt加1。剩下的也是这样插入。

代码实现:

	void insert(char *S)
	{
		int n = strlen(S);
		int now = 0;//从根节点出发
		for (int i = 0;i < n;i++)
		{
			char c = S[i];
			if (!statetable[now].next[c - 'a'])
				statetable[now].next[c - 'a'] = size++;//当前节点下没有我们的下一个字符,
					//那么我们就新开一个,size为它在字典树上的标号
			now = statetable[now].next[c - 'a'];//如果有这个字符,那么直接顺着走,当前节点移到下一个
		}
		statetable[now].cnt++;//整个S插完后,在最后一个字符的地方cnt++
	}

关键点2:找fail指针

1.在kmp算法中,我们有一个next数组,当比较到有一个字符失配的时候,通过next数组找到下一个开始匹配的位置,当然这里是单模式的匹配,一个文本串,和仅有一个模式串。

2.在AC自动机中,当比较到有一个字符失配的时候,我们通过一个类似于next数组的叫做fail指针的东西来找到下一个开始匹配的位置,这里是多模式的匹配。一个文本串,和多个模式串。

假设当前节点为father,找child的fail指针时,我们首先要找到father的fail指针t,如果t下面有和child字符相同的字符时,那child的fail就指向这个相同的字符。如果没有那么再找father->fail->fail,直到找到为止,若是一直都没有找到,最后这个fail已经到了-1

那么就让fail指向root,即fail=0

 

【字符串】AC自动机(kmp+字典树)hdu2222

fail指针是用bfs来求的。首先我们让root的fail指向-1,因为没有标为-1这样的点。让root入队。

然后root出队,它下面节点h.s的fail指针直接指向root,让h  s入队。然后h出队,找它下面节点的fail指针,是e,但发现h的fail为root 所以e的fail也指向root,e入队。

s出队,找a情况和之前的e一样,指向root,a入队。找h发现s的fail指向root,而root下面有h,很好,这时我们就让h指向之前这个h,如图中的蓝线,h入队。

e出队,它下面的r也会指向root,r入队。

a出队,y的fail也是指向root,y入队。

h出队,找e的fail,发现h有fail指向一个h,这个h下面恰好也有e,那么就让e指向之前这个e,如图中绿线,e入队。r也是指向root,r入队。

后面的也是如此。

代码如下:

	void build()
	{
		statetable[0].fail = -1;
		que.push(0);//root的fail为-1,root入队

		while (que.size())
		{
			int u = que.front();
			que.pop();
			for (int i = 0;i < 26;i++)
			{
				if (statetable[u].next[i])
				{
					if (u == 0)
						statetable[statetable[u].next[i]].fail = 0;
					else
					{
						int v = statetable[u].fail;
						while (v != -1)
						{
							if (statetable[v].next[i])
							{
								statetable[statetable[u].next[i]].fail = statetable[v].next[i];
								break;
							}
							v = statetable[v].fail;//如果没有这个字符那就一直找father->fail->fail
						}
						if (v == -1)//一直没找到已经到了root,那么这个字符的fail指向root
							statetable[statetable[u].next[i]].fail = 0;
					}
					que.push(statetable[u].next[i]);
				}
			}
		}
	}

关键点3: 查询

现在我们要查询一个S串,那么我们就遍历这个串开始一个一个匹配,分两种情况。

1.当前字符匹配:那就顺着走,如果遇到某一个字符它是某一个模式串的结尾,那就找这个结尾的fail指针,一直遍历到根节点,把它们的cnt都加上。

2.当前字符不匹配:就找当前节点的fail指针,一直失配就一直遍历到根。

代码如下:

	int get(int u)//比方说现在有一个she已经在文本串上匹配成功,
	{           //那么再去找比she短但有公共后缀的模式串,像he肯定已经成功了,所以一直加就好。
		int res = 0;
		while (u)
		{
			res += statetable[u].cnt;
			statetable[u].cnt = 0;
			u = statetable[u].fail;
		}
		return res;
}
	int match(char*S)
	{
		int n = strlen(S);
		int res = 0, now = 0;
		for (int i = 0;i < n;i++)
		{
			char c = S[i];
			if (statetable[now].next[c - 'a'])
				now = statetable[now].next[c - 'a'];
			else
			{
				int p = statetable[now].fail;
				while (p != -1 && statetable[p].next[c - 'a'] == 0)
					p = statetable[p].fail;
				if (p == -1)
					now = 0;
				else
					now = statetable[p].next[c - 'a'];
			}
			if (statetable[now].cnt)
				res = res + get(now);
		}
		return res;
	}

整个AC自动机就是这样啦!一共就四个函数 init insert build match&get。 

完整代码:

#include<iostream>
#include<queue>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define MAX_N 1000006
#define MAX_tot 500005

using namespace std;
struct aho {
	struct state {
		int next[26];
		int fail, cnt;
	}statetable[MAX_tot];
	int size;
	std::queue<int>que;

	void init()
	{
		while (que.size()) que.pop();
		for (int i = 0;i < MAX_tot;i++)
		{
			memset(statetable[i].next, 0, sizeof(statetable[i].next));
			statetable[i].fail = 0;
			statetable[i].cnt = 0;
		}
		size = 1;
	}
	void insert(char*S)
	{
		int n = strlen(S);
		int now = 0;
		for (int i = 0;i < n;i++)
		{
			char c = S[i];
			if (!statetable[now].next[c - 'a'])
				statetable[now].next[c - 'a'] = size++;
			now = statetable[now].next[c - 'a'];
		}
		statetable[now].cnt++;
	}


	void build()
	{
		statetable[0].fail = -1;
		que.push(0);

		while (que.size())
		{
			int u = que.front();
			que.pop();
			for (int i = 0;i < 26;i++)
			{
				if (statetable[u].next[i])
				{
					if (u == 0)
						statetable[statetable[u].next[i]].fail = 0;
					else 
					{
						int v = statetable[u].fail;
						while (v != -1)
						{
							if (statetable[v].next[i])
							{
								statetable[statetable[u].next[i]].fail = statetable[v].next[i];
								break;
							}
							v = statetable[v].fail;
						}
						if (v == -1)
							statetable[statetable[u].next[i]].fail = 0;
					}
					que.push(statetable[u].next[i]);
				}
			}
		}
	}

	int get(int u)//比方说现在有一个she已经在文本串上匹配成功,
	{           //那么再去找比she短但有公共后缀的模式串,像he肯定已经成功了,所以一直加就好。
		int res = 0;
		while (u)
		{
			res += statetable[u].cnt;
			statetable[u].cnt = 0;
			u = statetable[u].fail;
		}
		return res;
}
	int match(char*S)
	{
		int n = strlen(S);
		int res = 0, now = 0;
		for (int i = 0;i < n;i++)
		{
			char c = S[i];
			if (statetable[now].next[c - 'a'])
				now = statetable[now].next[c - 'a'];
			else
			{
				int p = statetable[now].fail;
				while (p != -1 && statetable[p].next[c - 'a'] == 0)
					p = statetable[p].fail;
				if (p == -1)
					now = 0;
				else
					now = statetable[p].next[c - 'a'];
			}
			if (statetable[now].cnt)
				res = res + get(now);
		}
		return res;
	}
}aho;

int t;
int n;
char S[MAX_N];
int main()
{
	scanf("%d", &t);
	while (t--)
	{
		aho.init();
		scanf("%d", &n);
		for (int i = 0;i < n;i++)
		{
			scanf("%s", S);
			aho.insert(S);
		}
		aho.build();
		scanf("%s", S);
		printf("%d\n", aho.match(S));
	}
	return 0;
}