[Luogu P3301] [BZOJ 3129] [SDOI2013]方程

洛谷传送门

BZOJ传送门

题目描述

给定方程

X1+X2+...+Xn=MX_1+X_2+... +X_n=M

我们对第1..n11..n_1个变量进行一些限制:

X1A1X_{1} \le A_{1}
X2A2X_{2} \le A_{2}

Xn1An1X_{n1} \le A_{n1}

我们对第n1+1..n1+n2n_1 + 1..n_1+n_2个变量进行一些限制:

Xn1+lAn1+1X_{n1+l} \ge A_{n1+1}
Xn1+2An1+2X_{n1+2} \ge A_{n1+2}

Xn1+n2An1+n2X_{n1+n2} \ge A_{n1+n2}

求:在满足这些限制的前提下,该方程正整数解的个数。答案可能很大,请输出对pp取模后的答案,也即答案除以pp的余数。

输入输出格式

输入格式:

输入含有多组数据。

第一行两个正整数TTppTT表示这个测试点内的数据组数,pp的含义见题目描述。

对于每组数据,第一行四个非负整数nnn1n_1n2n_2mm
第二行n1+n2n_1+n_2个正整数,表示A1...An1+n2A_1...A_{n_1+n_2}。请注意,如果n1+n2n_1+n_2等于00,那么这一行会成为一个空行。

输出格式:

TT行,每行一个正整数表示取模后的答案。

输入输出样例

输入样例#1:

3 10007
3 1 1 6
3 3
3 0 0 5
3 1 1 3
3 3

输出样例#1:

3
6
0

说明

【样例说明】 对于第一组数据,三组解为(1,3,2)(1,3,2)(1,4,1)(1,4,1),(2,3,1)(2,3,1) 对于第二组数据,六组解为(1,1,3)(1,1,3)(1,2,2)(1,2,2),(1,3,1)(1,3,1),(2,1,2)(2,1,2),(2,2,1)(2,2,1),(3,1,1)(3,1,1)

[Luogu P3301] [BZOJ 3129] [SDOI2013]方程

n109,n18,n28,m109,p437367875n \le 10^9 , n_1 \le 8 , n_2 \le 8 , m \le 10^9 ,p\le 437367875

对于100%的测试数据: T5,1A1..n1,n2m,n1+n2nT \le 5,1 \le A_1..n_1,n_2 \le m,n1+n2 \le n

解题分析

后面n2n_2个限制很好弄, 直接预先给AiA_i个就好了。

关键是前n1n_1个限制, 我们可以利用容斥原理, 先算出至少不满足第11个, 第22个…要求的, 然后算有两个要求不满足的… 最后得到不满足要求的方案数, 然后用总数去减就好了。

注意题目要求的是正整数解, 所以所有位置都要预留11, 然后再分配。

模数不为质数, 组合数取模需要用exLucasexLucas

代码如下:

#include <cstdio>
#include <cstring>
#include <cmath>
#include <cctype>
#include <cstdlib>
#include <algorithm>
#define R register
#define IN inline
#define W while
#define gc getchar()
#define ll long long
#define MX 1000500
int pri[36], prk[36];
ll seg[MX];
int tot, T, n, n1, n2, mod, m;
int lim1[10], lim2[10];
template <class T>
IN void in(T &x)
{
	x = 0; R char c = gc;
	for (; !isdigit(c); c = gc);
	for (;  isdigit(c); c = gc)
	x = (x << 1) + (x << 3) + c - 48;
}
/*IN ll fmul(ll a, ll b, ll MOD)
{
	static __int128 c, d;
	c = a, d = b;
	return (ll)(c * d % MOD);
}*/
IN ll fmul(ll a, ll b, ll MOD)
{
	ll ret = 0;
	if (a > b) std::swap(a, b);
	W (a)
	{
		if (a & 1) ret = (ret + b) % MOD;
		b = (b + b) % MOD, a >>= 1;
	}
	return ret;
}
IN ll fpow(ll base, R int tim, ll MOD)
{
	ll ret = 1;
	W (tim)
	{
		if (tim & 1) ret = fmul(ret, base, MOD);
		base = fmul(base, base, MOD), tim >>= 1;
	}
	return ret;
}
void exgcd(ll a, ll b, ll &x, ll &y)
{
	if (!b) return x = 1, y = 0, void();
	exgcd(b, a % b, x, y);
	ll buf = x; x = y, y = buf - a / b * y;
}
IN ll getinv(ll n, ll pk)
{
	ll x, y;
	exgcd(n, pk, x, y);
	return (x % pk + pk) % pk;
}
ll calc(ll n, ll p, ll pk)
{
	if (n <= 1) return 1;
	ll ret = 1;
	if (n >= pk) ret = fpow(seg[pk - 1], n / pk, pk);
	ret = ret * seg[n % pk] % pk;
	return ret * calc(n / p, p, pk) % pk;
}
IN ll C(ll n, ll m, R int p, R int pk)
{
	if (n < m) return 0;
	seg[0] = seg[1] = 1;
	for (R int i = 2; i < pk; ++i)
	{
		seg[i] = seg[i - 1];
		if (i % p) seg[i] = seg[i] * i % pk;
	}
	ll up = calc(n, p, pk);
	ll down1 = calc(m, p, pk), down2 = calc(n - m, p, pk);
	ll tim = 0;
	for (ll i = n / p; i; i /= p) tim += i;
	for (ll i = m / p; i; i /= p) tim -= i;
	for (ll i = (n - m) / p; i; i /= p) tim -= i;
	return up * getinv(down1, pk) % pk * getinv(down2, pk) % pk * fpow(p, tim, pk) % pk;
}
IN void init()
{
	ll tmp = mod, bd = std::sqrt(mod);
	for (R int i = 2; i <= bd; ++i)
	{
		if (!(tmp % i))
		{
			pri[++tot] = i;
			prk[tot] = 1;
			W (!(tmp % i)) prk[tot] *= i, tmp /= i;
		}
	}
	if (tmp > 1) pri[++tot] = tmp, prk[tot] = tmp;
}
IN ll Exlucas(ll n, ll m)
{
	ll ans = 0, mi;
	if (n < m) return 0;
	if (n == m || m == 0) return 1;
	for (R int i = 1; i <= tot; ++i)
	{
		mi = mod / prk[i];
		(ans += C(n, m, pri[i], prk[i]) * getinv(mi, prk[i]) % mod * mi % mod) %= mod;
	}
	return ans;
}
int main(void)
{
	int all, typ, need, tar, tmpm;
	ll ans;
	in(T), in(mod);
	init();
	W (T--)
	{
		in(n), in(n1), in(n2), in(m); ans = 0; tmpm = m;
		for (R int i = 1; i <= n1; ++i) in(lim1[i]);
		for (R int i = 1; i <= n2; ++i) in(lim2[i]), m -= lim2[i] - 1;
		m -= n;
		if (m < 0) {puts("0"); continue;}
		if (m == 0) {puts("1"); continue;}
		all = (1 << n1) - 1;
		for (R int i = 1; i <= all; ++i)
		{
			need = 0;
			typ = __builtin_popcount(i);
			for (R int j = 1; j <= n1; ++j)
			if ((1 << j - 1) & i) need += lim1[j];
			if (need > m) continue;
			tar = m - need;
			if (typ & 1) (ans += Exlucas(n + tar - 1, n - 1)) %= mod;
			else (ans -= Exlucas(n + tar - 1, n - 1)) %= mod;
		}
		printf("%lld\n", (Exlucas(n + m - 1, n - 1) - ans + mod) % mod);
	}
}