20231110 A题 多树 题解

 

题面(简化)

不用简化了,这题题面本来就很简短

给定 $k$ 个有 $n$ 个节点的树,对于每个点对 $(i,j),i,j\in[1,n]$ ,请输出在每棵树上的路径经过的点(含端点)的交集大小。

即 $\forall i,j\in [1,n],please~print:[k\in [1,n],k\in \forall i\to j]$

范围:$n,k≤500,n,k\in N^*;1≤u,v≤n$

题解

在做这题时,想到了11月7号做的题目(Link)。虽然这两道题看似风马牛不相及,但是却使用了一个相同的结论:

$x$ 在 $i$ 到 $j$ 的最短路径上的充要条件是:$dis(x,i)+dis(x,j)=dis(i,j)$(结论1)

这个结论非常显然,但却是下面一切的开端。根据它,我们可以推出:

$\forall x\in [1,n],dis(x,i)+dis(x,j)≥dis(i,j)$ (结论2),当且仅当 $x$ 处在 $i\to j$ 的路径上时取等;

这个结论更是简明,因为如果 $x$ 不在由 $i$ 到 $j$ 的路径上,那么路径 $i\to x\to j$ 肯定比从 $i$ 到 $j$ 的最短路径长。

好,现在我们对于 $k$ 棵树将结论2累加,又能推出什么?

$\forall x\in [1,n],\sum\limits_{s=1}^{k}diss+\sum\limits_{s=1}^{k}diss=\sum\limits_{s=1}^{k}diss$ (结论3),其中 $diss$ 表示在第 $s$ 棵树中节点 $m,n$ 之间的最短路径长度,当且仅当 $x$ 一直处在 $i\to j$ 的路径上时取等。

这个结论是本题的突破口。

对于每张图,我们可以 bfs 求出 $dis$ 数组,然后对这个数组求和得到 $powd$ 数组($powd(i,j)=\sum\limits_{s=1}^{k}diss$)。

然后,我们用 $n^2$ 的时间复杂度暴力枚举每对 $(i,j)$ ,再用 $n$ 的时间复杂度暴力枚举每个点,看看这个点是否满足结论3中的取等条件。若满足,则 $ans++$ 。最后输出 $ans$ 即可。

时间复杂度分析如下:

  1. $powd$ 数组的预处理为 $O(n^2k)$

  2. 最后的枚举复杂度为 $O(n^2*n)=O(n^3)$

所以总复杂度为 $O(n^2(n+k))$,可以通过。

其实赛场上写的时候忘记 bfs 全源最短路怎么写了,所以写了个 dijstra ,这样复杂度多了个 $logn$ ,但是卡几发波动也过去了…

注意:$dis$ 数组不能开三维,但其实第三维可以省略,具体见代码。

完整代码(赛后的bfs版)

#include<bits/stdc++.h>
using namespace std;
inline int read()
{
	int s=0,w=1;
	char ch=getchar();
	while(ch<'0'||ch>'9')
	{
		if(ch=='-') w=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9') s=(s<<3)+(s<<1)+ch-'0',ch=getchar();
	return s*w;
}
const int N=505;
int n,kkk,d[N][N],powd[N][N];//d数组就是题解部分说的dis数组
bool vis[N][N];
vector<int> g[N][N];
queue<int> q;
void bfs(int s,int k)
{
        memset(vis[k],0,sizeof vis[k]);
	for(int i=1;i<=n;i++) d[s][i]=INT_MAX;
	d[s][s]=0; 
        q.push(s);
	while(q.size()) 
        {
		int now=q.front();
                 q.pop();
		if(vis[k][now]) continue;
		vis[k][now]=true;
		for(auto v:g[k][now]) if(d[s][v]>d[s][now]+1) d[s][v]=d[s][now]+1,q.push(v);
	}
}
int main()
{
    freopen("trees.in","r",stdin);
    freopen("trees.out","w",stdout);
    n=read(),kkk=read();
    for(int i=1;i<=kkk;i++) for(int j=1,x,y;j<n;j++) x=read(),y=read(),g[i][x].emplace_back(y),g[i][y].emplace_back(x);
    for(int i=1;i<=kkk;i++)
    {
        for(int j=1;j<=n;j++) bfs(j,i);
        for(int j=1;j<=n;j++) for(int k=1;k<=n;k++) powd[j][k]+=d[j][k];
    } 
    for(int i=1;i<=n;i++)
    {
        for(int j=1;j<=n;j++)
        {
            int ans=0;
            for(int k=1;k<=n;k++)  if(powd[k][i]+powd[k][j]==powd[i][j]) ans++;
            printf("%d ",ans);
        }
        putchar('\n');
    }
}