2021-02-08-2月8日题解
题解¶
先进行LCA算法,计算出每个点的深度、到根节点的距离等,然后对五个点(分别为s1,s2,s3,s4,s5)进行处理:
- 将第s1加入图中,如图,增加的权值为dis[s1]
- 将s2加入图中,假如s1和s2的最小公共祖先为fa1即为lca(s1,s2),则增加的权值为dis[s2]-dis[fa1]
- 将s3加入图中,假设s1和s3的最小公共祖先为fa2,s2和s3的最小公共最先为fa1,假设fa1的深度小于fa2的深度,则增加的权值为dis[s3]-dis[fa2]
- 重复第3过程,将s4,s5两个点加入到图当中,最后权值和减去所有点的最小公共祖先到根节点1的距离即为答案。
代码¶
#include<cstdio>
#include<cstring>
#include<algorithm>
#include <bits/stdc++.h>
using namespace std;
const int maxn=1e5+500;
struct node
{
int to,nex,w;
} road[maxn*2];
int n,q,cnt;
int pre[maxn][32],head[maxn],depth[maxn];
int dis[maxn];
void add(int u,int v,int w)
{
road[cnt].to=v;
road[cnt].w=w;
road[cnt].nex=head[u];
head[u]=cnt++;
}
void dfs(int u,int fa)
{
pre[u][0]=fa;
depth[u]=depth[fa]+1;
for(int i=1; (1<<i)<=depth[u]; i++) //倍增
pre[u][i]=pre[pre[u][i-1]][i-1];
for(int i=head[u]; ~i; i=road[i].nex)
{
int v=road[i].to;
if(v!=fa)
{
dis[v]=dis[u]+road[i].w;
dfs(v,u);
}
}
}
int lca(int u,int v)
{
if(depth[u]<depth[v])
{
swap(u,v);
}
int i=-1,j;
while((1<<(i+1))<=depth[u])
i++;
for(j=i; j>=0; j--)
{
if(depth[u]-(1<<j)>=depth[v])
{
u=pre[u][j];
}
}
if(u==v)
return u;
for(int j=i; j>=0; j--)
{
if(pre[u][j]!=pre[v][j])
{
u=pre[u][j];
v=pre[v][j];
}
}
return pre[u][0];
}
int main()
{
scanf("%d",&n);
memset(head,-1,sizeof(head));
memset(depth,0,sizeof(depth));
cnt=0;
for(int i=1; i<n; i++)
{
int u,v,w;
scanf("%d %d %d",&u,&v,&w);
add(u+1,v+1,w);
add(v+1,u+1,w);
}
dis[1]=0;
dfs(1,0);
scanf("%d",&q);
while(q--)
{
int point[10]={0};
for(int i=1;i<=5;i++)
scanf("%d",&point[i]),point[i]++;
int ans=0;
for(int i=1;i<=5;i++)
{
int dep=1;
for(int j=1;j<i;j++)
{
int fa=lca(point[i],point[j]);
if(depth[fa]>depth[dep])
dep=fa;
}
ans+=dis[point[i]];
ans-=dis[dep];
}
int dep=point[1];
for(int i=1;i<=5;i++)
dep=lca(point[i],dep);
ans-=dis[dep];
cout<<ans<<endl;
}
}
题解¶
该题点数较少,可以直接暴力做,但是dfs会超时,改用bfs即可,详细见代码。
代码¶
#include <bits/stdc++.h>
#define inf 0x3f3f3f3f
using namespace std;
typedef long long ll;
vector<ll>v[100500]; //存图用
ll vis[5050]={0}; ///记录距离用
ll to[5050]={0}; ///到该点的道路个数
ll sum_vis[5050]={0}; ///bfs统计用
long double val_all[5050]={0}; ///统计权值用
ll s,e; ///开始结尾
ll sum_size=0;
struct node
{
ll pos,val;
};
void bfs()
{
queue<node>que;
que.push({s,0});
vis[s]=0;
to[s]=1;
while(!que.empty())
{
node now=que.front();
que.pop();
vis[now.pos]=min(vis[now.pos],now.val);
for(ll i:v[now.pos])
{
if(vis[i]>now.val+1)
{
vis[i]=now.val+1;
to[i]=to[now.pos];
que.push({i,now.val+1});
}
else if(vis[i]==now.val+1)
{
to[i]+=to[now.pos];
}
}
}
sum_size=to[e];
}
void bfs1()
{
queue<int>que;
que.push(e);
while(!que.empty())
{
int now=que.front();
que.pop();
sum_vis[vis[now]]+=to[now];
for(int i:v[now])
{
if(vis[i]==vis[now]-1)
que.push(i);
}
}
que.push(e);
while(!que.empty())
{
int now=que.front();
que.pop();
val_all[now]+=to[now]*1.0/sum_vis[vis[now]]; ///到该点路的条数占相同长度条数的比例
for(int i:v[now])
{
if(vis[i]==vis[now]-1)
{
que.push(i);
}
}
}
}
ll dfs(ll k,ll sum)
{
if(k==e)
{
if(sum==vis[k])
sum_size++;
else if(sum<vis[k])
vis[k]=sum,sum_size=1;
}
vis[k]=min(vis[k],sum);
for(ll i:v[k])
{
if(vis[i]>=sum+1)
dfs(i,sum+1);
}
return 0;
}
long double dfs1(ll k,ll sum)
{
if(sum>vis[e])
return 0;
if(k==e)
{
val_all[k]+=1.0/sum_size;
return 1.0/sum_size;
}
long double val=0;
for(ll i:v[k])
{
if(vis[i]==vis[k]+1)
val+=dfs1(i,sum+1);
}
val_all[k]+=val;
return val;
}
int main()
{
ll n,m;
scanf("%lld%lld",&n,&m);
for(ll i=1;i<=m;i++)
{
scanf("%lld%lld",&s,&e);
v[s].push_back(e);
v[e].push_back(s);
}
ll k;
scanf("%lld",&k);
for(ll i=1;i<=k;i++)
{
scanf("%lld%lld",&s,&e);
memset(vis,inf,sizeof(vis));
memset(to,0,sizeof(to));
memset(sum_vis,0,sizeof(sum_vis));
bfs(); //dfs(s,0) ///这里dfs也可以得出正确答案,但是时间会超,qwq
bfs1();//dfs1(s,0); ///bfs和dfs均可,bfs更快一点,dfs更好理解一点,这里用dfs时间不会超
}
ll max1=0;
for(ll i=0;i<n;i++)
{
if(val_all[i]>val_all[max1])
max1=i;
}
cout<<max1<<endl;
}
题解¶
该题目参照博客https://m-sea-blog.com/archives/2139,附一张图:
我们可以发现取余后的结果为周期为a的周期函数,任取一点k,如果(2 * k)%a <= k%a,则一定有k < a <= 2 * k,预先缩小范围然后二分即可得到答案。
代码¶
#include <bits/stdc++.h>
#define inf 0x3f3f3f3f
using namespace std;
typedef long long ll;
char query(int x,int y)
{
cout<<"? "<<x<<" "<<y<<endl;
cout.flush();
char ans;
cin>>ans;
return ans;
}
int main()
{
while(1)
{
string s;
cin>>s;
if(s=="end")
break;
ll l,r;
for(ll i=1;;i=i*2)
{
l=i,r=min(i*2,(ll)2e9);
char ans=query(l,r);
if(ans=='x')
{
l=i,r=min(i*2,(ll)1e9);
break;
}
}
while(l<r)
{
int mid=(l+r)/2;
char ans=query(mid*2,mid);
if(ans=='x')
r=mid;
else
l=mid+1;
}
if(l==2)
{
char ans=query(2,1);
if(ans=='x')
{
cout<<"! 1"<<endl;
cout.flush();
continue;
}
else
{
cout<<"! 2"<<endl;
cout.flush();
continue;
}
}
cout<<"! "<<l<<endl;
cout.flush();
}
}