[BZOJ4558] [JLoi2016]方
JLOI2016 方
题目描述
Description
上帝说,不要圆,要方,于是便有了这道题。由于我们应该方,而且最好能够尽量方,所以上帝派我们来找正方形
上帝把我们派到了一个有N行M列的方格图上,图上一共有(N+1)×(M+1)个格点,我们需要做的就是找出这些格点形
成了多少个正方形(换句话说,正方形的四个顶点都是格点)。但是这个问题对于我们来说太难了,因为点数太多
了,所以上帝删掉了这(N+1)×(M+1)中的K个点。既然点变少了,问题也就变简单了,那么这个时候这些格点组成
了多少个正方形呢?
Input
第一行三个整数 N, M, K, 代表棋盘的行数、 列数和不能选取的顶点个数。 保证 N, M >= 1, K <=(N + 1) ×
(M + 1)。约定每行的格点从上到下依次用整数 0 到 N 编号,每列的格点依次用 0到 M 编号。接下来 K 行,每
行两个整数 x,y 代表第 x 行第 y 列的格点被删掉了。保证 0 <=x <=N<=10^6, 0 <=y<=M<=10^6,K<=2*1000且不
会出现重复的格点。
Output
仅一行一个正整数, 代表正方形个数对 100000007( 10^8 + 7) 取模之后的值
Sample Input
2 2 4
1 0
1 2
0 1
2 1
Sample Output
1
题解
题意
N行M列的网格图上,被禁止使用了 K 个不重复的交叉点,可以使用其余交叉点,询问有多少种方案,选出四个点,可以构成正方形。
正方形的边可以不与网格线平行,即可以倾斜。
N,M <= 1e6, K <= 1e3
解法
容斥原理。
此处称边与坐标轴平行的正方形为直正方形;
边与坐标轴不平行的正方形为斜正方形。
记x为行号,y为列号
- 如图,无禁止点,所有点都可以使用时,直正方形个数有规律。
- 枚举正方形边长k,对于合法的左上角顶点(x,y),0<=x 且 x+k<=n,0<=y 且 y+k<=m。
则边长为k的直正方形个数为(n-k+1) * (m-k+1) -
如图,任意斜正方形,必被包含于一个大直正方形。
四个端点皆在大正方形上,记相邻端点坐标的差 x=|x1-x2|, y=|y1-y2|,
大正方形边长k=x+y;正方形种类(包含正、斜)共k种,不同正方形的端点不重合。 -
忽略禁止点的限制,可以得知使用了>=0个禁止交叉点的正方形方案数ans0,
ans0=sigma((n-k+1) * (m-k+1) * k ) (k=1…min(n,m)) -
可以根据容斥原理计算答案,
记ansk=使用了>=k个禁止交叉点的正方形方案数
使用0个禁止点的合法方案数 = ans0 - ans1 + ans2 - ans3 + ans4 -
计算ans1
如图,由于一个固定端点的直正方形中,边界上固定的一点,只属于一种唯一的斜(直)正方形,即一个直正方形贡献为1。所以,只需统计以红色点为正方形边界上一点的直正方形个数。
不妨分跨两个象限(绿),属于单个象限(蓝)两类。
可以分四面(边)、四角进行统计。
不好处理必过某条直线的限制,不妨忽略,最后减去未过直线的方案数,
即四角的方案数。
边:h<=l+r 计算等差数列
边:h>l 或 h>r 减去多余部分,计算等差数列
角:
合并答案,角必被统计2次 - 计算ans2,ans3,ans4
只要枚举两个禁止点,即可得知包含该两点的三种正方形:
AB边长2种
AB对角线1种
可以二分查找或做哈希,从而判定点是否存在于先前元素的集合中,统计正方形不可法端点数,从而更新ans3,ans4。
枚举禁止点时保证编号有序,防止重复。
由于3点间的3条边,ans3重复了3次
由于4点间的6条边,ans4重复了6次
代码
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<iostream>
#include<algorithm>
#include<queue>
#include<vector>
#include<map>
using namespace std;
#define rep(i,a,b) for(int i=a;i<=b;i++)
#define dow(i,a,b) for(int i=a;i>=b;i--)
#define tab(i,u) for(int i=head[u];i!=-1;i=e[i].next)
#define cls(a,x) memset(a,x,sizeof(a))
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const int K = 2e3 + 10, mod = 1e8 + 7;
int n,m,k;
struct abcd {
int x,y;
abcd(int _x=0,int _y=0):x(_x),y(_y){}
} a[K];
bool operator < (abcd a,abcd b) { if(a.x!=b.x) return a.x<b.x; else return a.y<b.y; }
bool operator == (abcd a,abcd b) { return a.x==b.x && a.y==b.y; }
void out(abcd p) { printf("%d %d\n",p.x,p.y); }
bool find(abcd key) {
int l=1, r=k, mid;
while(l<=r) {
mid=(l+r)>>1;
if(a[mid]==key) return true;
if(key<a[mid]) r=mid-1;
else l=mid+1;
}
return false;
}
ll ans0,ans1,ans2,ans3,ans4,ans;
void cal(int l,int r,int h) {
int z=min(l+r,h);
ans1+=(ll)(2 + z+1)*z/2; // 2 + 3 + ... + z+1
if(l<z) ans1-=(ll)(z-l)*(z-l+1)/2; // 1 + 2 + ... + z-l
if(r<z) ans1-=(ll)(z-r)*(z-r+1)/2; // 1 + 2 + ... + z-r
ans1%=mod;
}
void calc1(int u,int d,int l,int r) {
cal(u,d,l); cal(u,d,r);
cal(l,r,u); cal(l,r,d);
ans1-=min(u,l); ans1-=min(u,r);
ans1-=min(d,l); ans1-=min(d,r);
ans1%=mod;
}
bool inlaw(abcd p) {
return 0<=p.x && p.x<=n && 0<=p.y && p.y<=m;
}
void check(abcd p3, abcd p4) {
if(!inlaw(p3) || !inlaw(p4)) return;
int cnt=0;
if(find(p3)) ++cnt;
if(find(p4)) ++cnt;
++ans2;
if(cnt>0) ++ans3;
if(cnt>1) ++ans4, ++ans3;
}
void Solve() {
scanf("%d%d%d",&n,&m,&k);
rep(i,1,k) scanf("%d%d",&a[i].x,&a[i].y);
sort(a+1,a+1+k);
// rep(i,1,k) out(a[i]); puts("");
ans0=0;
int len=min(n,m);
rep(d,1,len) {
int nn=n-d+1;
int mm=m-d+1;
ans0 += (ll)nn * mm % mod * d % mod;
ans0 %= mod;
}
ans1=0;
rep(i,1,k) calc1(a[i].x,n-a[i].x,a[i].y,m-a[i].y);
ans2=0; ans3=0; ans4=0;
rep(i,1,k) {
rep(j,i+1,k) {
abcd p0,p1,p2,p3,p4;
int x,y;
p1=a[i]; p2=a[j];
x=p2.x-p1.x; y=p2.y-p1.y;
p3=abcd(p2.x+y, p2.y-x);
p4=abcd(p1.x+y, p1.y-x);
check(p3,p4);
// out(p1); out(p2); out(p3); out(p4); puts("");
// if(inlaw(p3) && inlaw(p4)) { out(p1); out(p2); out(p3); out(p4); puts(""); }
p1=a[j]; p2=a[i];
x=p2.x-p1.x; y=p2.y-p1.y;
p3=abcd(p2.x+y, p2.y-x);
p4=abcd(p1.x+y, p1.y-x);
check(p3,p4);
// out(p1); out(p2); out(p3); out(p4); puts("");
// if(inlaw(p3) && inlaw(p4)) { out(p1); out(p2); out(p3); out(p4); puts(""); }
p1=a[i]; p2=a[j];
int dx=p1.x-p2.x, dy=p1.y-p2.y;
if (!((abs(dx) + abs(dy)) & 1)) {
x = (dx - dy) >> 1, y = (dx + dy) >> 1;
p3=abcd(p1.x-x,p1.y-y);
p4=abcd(p2.x+x,p2.y+y);
check(p3,p4);
}
}
}
// printf("%lld %lld %lld %lld %lld\n",ans0,ans1,ans2,ans3/3,ans4/6);
ans=ans0-ans1+ans2-ans3/3+ans4/6;
ans%=mod; if(ans<0) ans+=mod;
printf("%lld\n",ans);
}
int main() {
freopen("square.in","r",stdin);
freopen("square.out","w",stdout);
Solve();
return 0;
}