题面(简化)
不用简化了,这题题面本来就很简短
给定 $k$ 个有 $n$ 个节点的树,对于每个点对 $(i,j),i,j\in[1,n]$ ,请输出在每棵树上的路径经过的点(含端点)的交集大小。
即 $\forall i,j\in [1,n],please~print:[k\in [1,n],k\in \forall i\to j]$
范围:$n,k≤500,n,k\in N^*;1≤u,v≤n$
题解
在做这题时,想到了11月7号做的题目(Link)。虽然这两道题看似风马牛不相及,但是却使用了一个相同的结论:
$x$ 在 $i$ 到 $j$ 的最短路径上的充要条件是:$dis(x,i)+dis(x,j)=dis(i,j)$(结论1);
这个结论非常显然,但却是下面一切的开端。根据它,我们可以推出:
$\forall x\in [1,n],dis(x,i)+dis(x,j)≥dis(i,j)$ (结论2),当且仅当 $x$ 处在 $i\to j$ 的路径上时取等;
这个结论更是简明,因为如果 $x$ 不在由 $i$ 到 $j$ 的路径上,那么路径 $i\to x\to j$ 肯定比从 $i$ 到 $j$ 的最短路径长。
好,现在我们对于 $k$ 棵树将结论2累加,又能推出什么?
$\forall x\in [1,n],\sum\limits_{s=1}^{k}diss+\sum\limits_{s=1}^{k}diss=\sum\limits_{s=1}^{k}diss$ (结论3),其中 $diss$ 表示在第 $s$ 棵树中节点 $m,n$ 之间的最短路径长度,当且仅当 $x$ 一直处在 $i\to j$ 的路径上时取等。
这个结论是本题的突破口。
对于每张图,我们可以 bfs 求出 $dis$ 数组,然后对这个数组求和得到 $powd$ 数组($powd(i,j)=\sum\limits_{s=1}^{k}diss$)。
然后,我们用 $n^2$ 的时间复杂度暴力枚举每对 $(i,j)$ ,再用 $n$ 的时间复杂度暴力枚举每个点,看看这个点是否满足结论3中的取等条件。若满足,则 $ans++$ 。最后输出 $ans$ 即可。
时间复杂度分析如下:
-
$powd$ 数组的预处理为 $O(n^2k)$
-
最后的枚举复杂度为 $O(n^2*n)=O(n^3)$
所以总复杂度为 $O(n^2(n+k))$,可以通过。
其实赛场上写的时候忘记 bfs 全源最短路怎么写了,所以写了个 dijstra ,这样复杂度多了个 $logn$ ,但是卡几发波动也过去了…
注意:$dis$ 数组不能开三维,但其实第三维可以省略,具体见代码。
完整代码(赛后的bfs版)
#include<bits/stdc++.h>
using namespace std;
inline int read()
{
int s=0,w=1;
char ch=getchar();
while(ch<'0'||ch>'9')
{
if(ch=='-') w=-1;
ch=getchar();
}
while(ch>='0'&&ch<='9') s=(s<<3)+(s<<1)+ch-'0',ch=getchar();
return s*w;
}
const int N=505;
int n,kkk,d[N][N],powd[N][N];//d数组就是题解部分说的dis数组
bool vis[N][N];
vector<int> g[N][N];
queue<int> q;
void bfs(int s,int k)
{
memset(vis[k],0,sizeof vis[k]);
for(int i=1;i<=n;i++) d[s][i]=INT_MAX;
d[s][s]=0;
q.push(s);
while(q.size())
{
int now=q.front();
q.pop();
if(vis[k][now]) continue;
vis[k][now]=true;
for(auto v:g[k][now]) if(d[s][v]>d[s][now]+1) d[s][v]=d[s][now]+1,q.push(v);
}
}
int main()
{
freopen("trees.in","r",stdin);
freopen("trees.out","w",stdout);
n=read(),kkk=read();
for(int i=1;i<=kkk;i++) for(int j=1,x,y;j<n;j++) x=read(),y=read(),g[i][x].emplace_back(y),g[i][y].emplace_back(x);
for(int i=1;i<=kkk;i++)
{
for(int j=1;j<=n;j++) bfs(j,i);
for(int j=1;j<=n;j++) for(int k=1;k<=n;k++) powd[j][k]+=d[j][k];
}
for(int i=1;i<=n;i++)
{
for(int j=1;j<=n;j++)
{
int ans=0;
for(int k=1;k<=n;k++) if(powd[k][i]+powd[k][j]==powd[i][j]) ans++;
printf("%d ",ans);
}
putchar('\n');
}
}