20231110 A题 多树 题解

 

题面(简化)

不用简化了,这题题面本来就很简短

给定 k 个有 n 个节点的树,对于每个点对 (i,j),i,j[1,n] ,请输出在每棵树上的路径经过的点(含端点)的交集大小。

i,j[1,n],please print:[k[1,n],kij]

范围:n,k500,n,kN;1u,vn

题解

在做这题时,想到了11月7号做的题目(Link)。虽然这两道题看似风马牛不相及,但是却使用了一个相同的结论:

xij 的最短路径上的充要条件是:dis(x,i)+dis(x,j)=dis(i,j)(结论1)

这个结论非常显然,但却是下面一切的开端。根据它,我们可以推出:

x[1,n],dis(x,i)+dis(x,j)dis(i,j) (结论2),当且仅当 x 处在 ij 的路径上时取等;

这个结论更是简明,因为如果 x 不在由 ij 的路径上,那么路径 ixj 肯定比从 ij 的最短路径长。

好,现在我们对于 k 棵树将结论2累加,又能推出什么?

$\forall x\in [1,n],\sum\limits_{s=1}^{k}diss+\sum\limits_{s=1}^{k}diss=\sum\limits_{s=1}^{k}diss$ (结论3),其中 $disssm,nxi\to j$ 的路径上时取等。

这个结论是本题的突破口。

对于每张图,我们可以 bfs 求出 dis 数组,然后对这个数组求和得到 powd 数组($powd(i,j)=\sum\limits_{s=1}^{k}diss$)。

然后,我们用 n2 的时间复杂度暴力枚举每对 (i,j) ,再用 n 的时间复杂度暴力枚举每个点,看看这个点是否满足结论3中的取等条件。若满足,则 ans++ 。最后输出 ans 即可。

时间复杂度分析如下:

  1. powd 数组的预处理为 O(n2k)

  2. 最后的枚举复杂度为 O(n2n)=O(n3)

所以总复杂度为 O(n2(n+k)),可以通过。

其实赛场上写的时候忘记 bfs 全源最短路怎么写了,所以写了个 dijstra ,这样复杂度多了个 logn ,但是卡几发波动也过去了…

注意:dis 数组不能开三维,但其实第三维可以省略,具体见代码。

完整代码(赛后的bfs版)

#include<bits/stdc++.h>
using namespace std;
inline int read()
{
	int s=0,w=1;
	char ch=getchar();
	while(ch<'0'||ch>'9')
	{
		if(ch=='-') w=-1;
		ch=getchar();
	}
	while(ch>='0'&&ch<='9') s=(s<<3)+(s<<1)+ch-'0',ch=getchar();
	return s*w;
}
const int N=505;
int n,kkk,d[N][N],powd[N][N];//d数组就是题解部分说的dis数组
bool vis[N][N];
vector<int> g[N][N];
queue<int> q;
void bfs(int s,int k)
{
        memset(vis[k],0,sizeof vis[k]);
	for(int i=1;i<=n;i++) d[s][i]=INT_MAX;
	d[s][s]=0; 
        q.push(s);
	while(q.size()) 
        {
		int now=q.front();
                 q.pop();
		if(vis[k][now]) continue;
		vis[k][now]=true;
		for(auto v:g[k][now]) if(d[s][v]>d[s][now]+1) d[s][v]=d[s][now]+1,q.push(v);
	}
}
int main()
{
    freopen("trees.in","r",stdin);
    freopen("trees.out","w",stdout);
    n=read(),kkk=read();
    for(int i=1;i<=kkk;i++) for(int j=1,x,y;j<n;j++) x=read(),y=read(),g[i][x].emplace_back(y),g[i][y].emplace_back(x);
    for(int i=1;i<=kkk;i++)
    {
        for(int j=1;j<=n;j++) bfs(j,i);
        for(int j=1;j<=n;j++) for(int k=1;k<=n;k++) powd[j][k]+=d[j][k];
    } 
    for(int i=1;i<=n;i++)
    {
        for(int j=1;j<=n;j++)
        {
            int ans=0;
            for(int k=1;k<=n;k++)  if(powd[k][i]+powd[k][j]==powd[i][j]) ans++;
            printf("%d ",ans);
        }
        putchar('\n');
    }
}