跳转至

2021-02-08-2月8日题解

img

题解

先进行LCA算法,计算出每个点的深度、到根节点的距离等,然后对五个点(分别为s1,s2,s3,s4,s5)进行处理:

  1. 将第s1加入图中,如图,增加的权值为dis[s1]

img

  1. 将s2加入图中,假如s1和s2的最小公共祖先为fa1即为lca(s1,s2),则增加的权值为dis[s2]-dis[fa1]

image-20210208133953025

  1. 将s3加入图中,假设s1和s3的最小公共祖先为fa2,s2和s3的最小公共最先为fa1,假设fa1的深度小于fa2的深度,则增加的权值为dis[s3]-dis[fa2]

img

  1. 重复第3过程,将s4,s5两个点加入到图当中,最后权值和减去所有点的最小公共祖先到根节点1的距离即为答案。
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
#include <bits/stdc++.h>
using namespace std;
const int maxn=1e5+500;
struct node
{
    int to,nex,w;
} road[maxn*2];
int n,q,cnt;
int pre[maxn][32],head[maxn],depth[maxn];
int dis[maxn];
void add(int u,int v,int w)
{
    road[cnt].to=v;
    road[cnt].w=w;
    road[cnt].nex=head[u];
    head[u]=cnt++;
}
void dfs(int u,int fa)
{
    pre[u][0]=fa;
    depth[u]=depth[fa]+1;
    for(int i=1; (1<<i)<=depth[u]; i++) //倍增
        pre[u][i]=pre[pre[u][i-1]][i-1];
    for(int i=head[u]; ~i; i=road[i].nex)
    {
        int v=road[i].to;
        if(v!=fa)
        {
            dis[v]=dis[u]+road[i].w;
            dfs(v,u);
        }
    }
}
int lca(int u,int v)
{
    if(depth[u]<depth[v])
    {
        swap(u,v);
    }
    int i=-1,j;
    while((1<<(i+1))<=depth[u])
        i++;
    for(j=i; j>=0; j--)
    {
        if(depth[u]-(1<<j)>=depth[v])
        {
            u=pre[u][j];
        }
    }
    if(u==v)
        return u;
    for(int j=i; j>=0; j--)
    {
        if(pre[u][j]!=pre[v][j])
        {
            u=pre[u][j];
            v=pre[v][j];
        }
    }
    return pre[u][0];
}
int main()
{
    scanf("%d",&n);
    memset(head,-1,sizeof(head));
    memset(depth,0,sizeof(depth));
    cnt=0;
    for(int i=1; i<n; i++)
    {
        int u,v,w;
        scanf("%d %d %d",&u,&v,&w);
        add(u+1,v+1,w);
        add(v+1,u+1,w);
    }
    dis[1]=0;
    dfs(1,0);
    scanf("%d",&q);
    while(q--)
    {
        int point[10]={0};
        for(int i=1;i<=5;i++)
            scanf("%d",&point[i]),point[i]++;
        int ans=0;
        for(int i=1;i<=5;i++)
        {
            int dep=1;
            for(int j=1;j<i;j++)
            {
                int fa=lca(point[i],point[j]);
                if(depth[fa]>depth[dep])
                    dep=fa;
            }
            ans+=dis[point[i]];
            ans-=dis[dep];
        }
        int dep=point[1];
        for(int i=1;i<=5;i++)
            dep=lca(point[i],dep);
        ans-=dis[dep];
        cout<<ans<<endl;
    }
}
img 题解

该题点数较少,可以直接暴力做,但是dfs会超时,改用bfs即可,详细见代码。

代码
#include <bits/stdc++.h>
#define inf 0x3f3f3f3f
using namespace std;
typedef long long ll;
vector<ll>v[100500]; //存图用
ll vis[5050]={0};    ///记录距离用
ll to[5050]={0};     ///到该点的道路个数
ll sum_vis[5050]={0};  ///bfs统计用
long double val_all[5050]={0};  ///统计权值用
ll s,e;  ///开始结尾
ll sum_size=0;
struct node
{
    ll pos,val;
};
void bfs()
{
    queue<node>que;
    que.push({s,0});
    vis[s]=0;
    to[s]=1;
    while(!que.empty())
    {
        node now=que.front();
        que.pop();
        vis[now.pos]=min(vis[now.pos],now.val);
        for(ll i:v[now.pos])
        {
            if(vis[i]>now.val+1)
            {
                vis[i]=now.val+1;
                to[i]=to[now.pos];
                que.push({i,now.val+1});
            }
            else if(vis[i]==now.val+1)
            {
                to[i]+=to[now.pos];
            }
        }
    }
    sum_size=to[e];
}
void bfs1()
{
    queue<int>que;
    que.push(e);
    while(!que.empty())
    {
        int now=que.front();
        que.pop();
        sum_vis[vis[now]]+=to[now];
        for(int i:v[now])
        {
            if(vis[i]==vis[now]-1)
                que.push(i);
        }
    }
    que.push(e);
    while(!que.empty())
    {
        int now=que.front();
        que.pop();
        val_all[now]+=to[now]*1.0/sum_vis[vis[now]];   ///到该点路的条数占相同长度条数的比例
        for(int i:v[now])
        {
            if(vis[i]==vis[now]-1)
            {
                que.push(i);
            }
        }
    }
}
ll dfs(ll k,ll sum)
{
    if(k==e)
    {
        if(sum==vis[k])
            sum_size++;
        else if(sum<vis[k])
            vis[k]=sum,sum_size=1;
    }
    vis[k]=min(vis[k],sum);
    for(ll i:v[k])
    {
        if(vis[i]>=sum+1)
            dfs(i,sum+1);
    }
    return 0;
}
long double dfs1(ll k,ll sum)
{
    if(sum>vis[e])
        return 0;
    if(k==e)
    {
        val_all[k]+=1.0/sum_size;
        return 1.0/sum_size;
    }
    long double val=0;
    for(ll i:v[k])
    {
        if(vis[i]==vis[k]+1)
            val+=dfs1(i,sum+1);
    }
    val_all[k]+=val;
    return val;
}
int main()
{
    ll n,m;
    scanf("%lld%lld",&n,&m);
    for(ll i=1;i<=m;i++)
    {
        scanf("%lld%lld",&s,&e);
        v[s].push_back(e);
        v[e].push_back(s);
    }
    ll k;
    scanf("%lld",&k);
    for(ll i=1;i<=k;i++)
    {
        scanf("%lld%lld",&s,&e);
        memset(vis,inf,sizeof(vis));
        memset(to,0,sizeof(to));
        memset(sum_vis,0,sizeof(sum_vis));
        bfs(); //dfs(s,0)    ///这里dfs也可以得出正确答案,但是时间会超,qwq
        bfs1();//dfs1(s,0);   ///bfs和dfs均可,bfs更快一点,dfs更好理解一点,这里用dfs时间不会超
    }
    ll max1=0;
    for(ll i=0;i<n;i++)
    {
        if(val_all[i]>val_all[max1])
            max1=i;
    }
    cout<<max1<<endl;
}

image-20210208135020340

题解

该题目参照博客https://m-sea-blog.com/archives/2139,附一张图:

我们可以发现取余后的结果为周期为a的周期函数,任取一点k,如果(2 * k)%a <= k%a,则一定有k < a <= 2 * k,预先缩小范围然后二分即可得到答案。

代码
#include <bits/stdc++.h>
#define inf 0x3f3f3f3f
using namespace std;
typedef long long ll;
char query(int x,int y)
{
    cout<<"? "<<x<<" "<<y<<endl;
    cout.flush();
    char ans;
    cin>>ans;
    return ans;
}
int main()
{
    while(1)
    {
        string s;
        cin>>s;
        if(s=="end")
            break;
        ll l,r;
        for(ll i=1;;i=i*2)
        {
            l=i,r=min(i*2,(ll)2e9);
            char ans=query(l,r);
            if(ans=='x')
            {
                l=i,r=min(i*2,(ll)1e9);
                break;
            }
        }
        while(l<r)
        {
            int mid=(l+r)/2;
            char ans=query(mid*2,mid);
            if(ans=='x')
                r=mid;
            else
                l=mid+1;
        }
        if(l==2)
        {
            char ans=query(2,1);
            if(ans=='x')
            {
                cout<<"! 1"<<endl;
                cout.flush();
                continue;
            }
            else
            {
                cout<<"! 2"<<endl;
                cout.flush();
                continue;
            }
        }
        cout<<"! "<<l<<endl;
        cout.flush();
    }
}