loj 6031. 「雅礼集训 2017 Day1」字符串

题意:

loj 6031. 「雅礼集训 2017 Day1」字符串

题解:

因为保证qk&lt;105qk&lt;10^5,所以对于q<k,即q&lt;300q&lt;300的点,直接sam+倍增O(mqlogn)O(mqlogn)然后就有90分了
对于k&lt;300k&lt;300,暴力枚举每个子串,然后在vector上二分,看下出现了多少次即可。
code

#include<map>
#include<vector>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
#define mp make_pair
#define LL long long
using namespace std;
struct SAM{
	int fail,max,a[26];
}sam[200010];int tot=1,root=1,tail=1;
int ri[200010];
LL ans;
map<pair<int,int>,int > t;
int cnt=0;
vector<int> vec[100010];
struct node{
	int y,next;
}a[200010];int len=0,last[200010];

void ins(int x,int y)
{
	a[++len].y=y;
	a[len].next=last[x];last[x]=len;
}
void addsam(int c)
{
	int p=tail,np=++tot;sam[np].max=sam[p].max+1;
	ri[np]=1;
	for(;p&&!sam[p].a[c];p=sam[p].fail) sam[p].a[c]=np;
	if(!p) sam[np].fail=root;
	else
	{
		int q=sam[p].a[c];
		if(sam[q].max==sam[p].max+1) sam[np].fail=q;
		else
		{
			int nq=++tot;sam[nq]=sam[q];
			sam[nq].max=sam[p].max+1;
			sam[np].fail=sam[q].fail=nq;
			for(;p&&sam[p].a[c]==q;p=sam[p].fail) sam[p].a[c]=nq;
		}
	}
	tail=np;
}
int n,m,q,k,fa[200010][20],dep[200010];
char s[100010];
void dfs(int x,int Fa)
{
	dep[x]=dep[Fa]+1;fa[x][0]=Fa;
	for(int i=1;(1<<i)<=dep[x];i++)
		fa[x][i]=fa[fa[x][i-1]][i-1];
	for(int i=last[x];i;i=a[i].next)
	{
		int y=a[i].y;
		dfs(y,x);ri[x]+=ri[y];
	}
}
int pre[100010],L[100010];
struct query{
	int l,r;
}op[100010];
void solve(int l,int r)
{
	int len=r-l+1,x=pre[r];
	//printf("ok %d %d\n",sam[x].max,len);
	if(L[r]<len) return;
	for(int i=18;i>=0;i--)
		if((1<<i)<=dep[x]&&sam[fa[x][i]].max>=len) x=fa[x][i];
	//printf("finish %d %d %d %d\n",l,r,ri[x],L[r]);
	ans+=ri[x];
}
int main()
{
	scanf("%d %d %d %d",&n,&m,&q,&k);
	scanf("%s",s+1);
	for(int i=1;i<=m;i++) scanf("%d %d",&op[i].l,&op[i].r),op[i].l++,op[i].r++;
	for(int i=1;i<=n;i++) addsam(s[i]-'a');
	for(int i=2;i<=tot;i++) ins(sam[i].fail,i);
	dep[0]=-1;dfs(1,0);
	if(q<=k)
	{
		while(q--)
		{
			scanf("%s",s+1);
			int l,r;scanf("%d %d",&l,&r);l++;r++;
			int x=root;
			for(int i=1;i<=k;i++)
			{
				int c=s[i]-'a';L[i]=L[i-1];
				while(x!=root&&!sam[x].a[c]) x=sam[x].fail,L[i]=sam[x].max;//printf("x:%d\n",x);
				if(sam[x].a[c]) x=sam[x].a[c],L[i]++;
				pre[i]=x;
			}
			//printf("L:");for(int i=1;i<=k;i++) printf("%d ",L[i]);printf("\n");
			ans=0;
			for(int i=l;i<=r;i++) solve(op[i].l,op[i].r);//printf("now:%d %lld\n",i,ans);
			printf("%lld\n",ans);
		}
	}
	else
	{
		for(int i=1;i<=m;i++)
		{
			int l=op[i].l,r=op[i].r;
			if(!t[mp(l,r)]) t[mp(l,r)]=++cnt;
			int x=t[mp(l,r)];
			vec[x].push_back(i);
		}
		while(q--)
		{
			scanf("%s",s+1);
			int l,r;scanf("%d %d",&l,&r);l++;r++;
			ans=0;
			for(int L=1;L<=k;L++)
				for(int R=L;R<=k;R++)
				{
					if(!t[mp(L,R)]) continue;
					int x=root;bool flag=true;
					for(int i=L;i<=R;i++)
					{
						int c=s[i]-'a';
						if(!sam[x].a[c]) {flag=false;break;}
						x=sam[x].a[c];
					}
					if(!flag) continue;
					int c=t[mp(L,R)];
					vector<int> :: iterator k1=lower_bound(vec[c].begin(),vec[c].end(),l);
					vector<int> :: iterator k2=upper_bound(vec[c].begin(),vec[c].end(),r);
					//printf("vec %d:",c);for(int i=0;i<vec[c].size();i++) printf("%d ",vec[c][i]);printf("\n");
					//printf("%d %d %d %d\n",L,R,ri[x],k2-k1);
					if(k1<k2) ans+=(LL)(k2-k1)*ri[x];
				}
			printf("%lld\n",ans);
		}
	}
}