book
归档: solution 
flag

做法很巧妙。

Problem

给你一棵树,和数条铁路的起点和终点,求有多少个火车会在路上相遇。

每个火车的速度相同,起终点也算相遇,但经过终点后火车消失。

Solution

首先如果在边上相遇一定是在中点,所以可以给每条边单独开一个点转化问题。

然后将路径分为上行和下行,lca算在上行,求上下相交和上上相交即可。

上上相交可以线段树合并,上下相交可以对每条重链分别处理,以重链的深度为 $x$ 轴,离起点的距离为 $y$ 轴,得到每条路径在每个重链的函数和值域,那么问题转化成求 $k=1/-1$ 的直线的交点数,可以直接离散到 $y$ 轴后扫描线求交点个数。

注意上上相交的起点重合不要漏算。

Code

// Code by ajcxsu
// Problem: correction

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
template<typename T> inline void gn(T &x) {
    char ch=getchar(); x=0;
    while(!isdigit(ch)) ch=getchar();
    while(isdigit(ch)) x=x*10+ch-'0', ch=getchar();
}
template<typename T, typename ...Args> inline void gn(T &x, Args &...args) { gn(x), gn(args...); }

const int N=2e5+10;
int h[N], to[N<<1], nexp[N<<1], p=1;
inline void ins(int a, int b) { nexp[p]=h[a], h[a]=p, to[p]=b, p++; }

int siz[N], fa[N], top[N], son[N], dep[N], dfn[N], len[N], idx;
int bot[N];
void dfs1(int x, int k) {
    dep[x]=k, siz[x]=1;
    for(int u=h[x];u;u=nexp[u])
        if(!dep[to[u]]) {
            fa[to[u]]=x, dfs1(to[u], k+1), siz[x]+=siz[to[u]];
            if(siz[son[x]]<siz[to[u]]) son[x]=to[u];
        }
}
void dfs2(int x, int t) {
    top[x]=t, dfn[x]=++idx;
    len[t]++;
    if(son[x]) dfs2(son[x], t);
    else bot[t]=x;
    for(int u=h[x];u;u=nexp[u])
        if(!dfn[to[u]]) dfs2(to[u], to[u]);
}

int lca(int s, int t, int mode=0) {
    while(top[s]!=top[t]) {
        if(dep[top[s]]<dep[top[t]]) swap(s, t);
        if(mode && fa[top[s]]==t) return top[s];
        s=fa[top[s]];
    }
    if(mode) return dep[s]<dep[t]?son[s]:son[t];
    return dep[s]<dep[t]?s:t;
}

struct Seg { int k, b, l, r; } ;
vector<Seg> td[N];

struct Node *nil;
struct Node {
    int v, t;
    Node *ls, *rs;
    Node () { v=t=0; ls=rs=nil; }
} *nd[N], *nd2[N];
void ini() { nil=new Node(), nil->ls=nil->rs=nil, fill(nd, nd+N, nil), fill(nd2, nd2+N, nil); }
ll updata(Node *&x, int l, int r, int d, int v) {
    if(x==nil) x=new Node();
    x->v+=v; int mid=(l+r)>>1;
    if(l==r) { x->t=l; return x->v-v; }
    if(d<=mid) return updata(x->ls, l, mid, d, v);
    else return updata(x->rs, mid+1, r, d, v);
}
void Merge(Node *&x, Node *a, Node *b, ll &cnt, int mode=0) {
    if(a==nil) { x=b; return; }
    if(b==nil) { x=a; return; }
    if(x==nil) x=new Node();
    if(!mode && a->t) cnt+=1ll*a->v*b->v;
    x->v=a->v+b->v;
    Merge(x->ls, a->ls, b->ls, cnt, mode);
    Merge(x->rs, a->rs, b->rs, cnt, mode);
    delete b;
}

int n;
void modify(int x, int k, int l, int r, int ry) {
    int b=ry-(r-dep[x])*k;
    td[x].push_back({k, b, l-dep[x], r-dep[x]});
}
void modifyup(int s, int t, ll &ans) {
    if(dep[s]<dep[t]) swap(s, t);
    /* ↑ */
    ans+=updata(nd[s], 1, n, dep[s], 1);
    updata(nd2[t], 1, n, dep[s], -1);
    int bg=dep[s];
    /* ↓ */
    while(top[s]!=top[t]) {
        if(dep[top[s]]<dep[top[t]]) swap(s, t);
        modify(top[s], -1, dep[top[s]], dep[s], bg-dep[s]);
        s=fa[top[s]];
    }
    if(dep[s]>dep[t]) swap(s, t);
    modify(top[s], -1, dep[s], dep[t], bg-dep[t]);
}
void modifydown(int s, int t, int rua) {
    int bg=min(dep[s], dep[t]);
    while(top[s]!=top[t]) {
        if(dep[top[s]]<dep[top[t]]) swap(s, t);
        modify(top[s], 1, dep[top[s]], dep[s], dep[s]-bg+rua);
        s=fa[top[s]];
    }
    if(dep[s]>dep[t]) swap(s, t);
    modify(top[s], 1, dep[s], dep[t], dep[t]-bg+rua);
}

void dfs(int x, ll &cnt) {
    for(int u=h[x];u;u=nexp[u])
        if(to[u]!=fa[x]) {
            dfs(to[u], cnt);
            Merge(nd[x], nd[x], nd[to[u]], cnt);
        }
    Merge(nd[x], nd[x], nd2[x], cnt, 1);
}


#define lowbit(x) x&-x
namespace BIT {
    const int V=N<<2;
    int C[V], stk[V][2], t;
    void updata(int x, int v, int mode=0) {
        if(!mode) stk[++t][0]=x, stk[t][1]=v;
        while(x<V) C[x]+=v, x+=lowbit(x);
    }
    int query(int x) {
        int ret=0;
        while(x) ret+=C[x], x-=lowbit(x);
        return ret;
    }
    void clr() {
        while(t) updata(stk[t][0], -stk[t][1], 1), t--;
    }
}

struct Query { int t, v, x, l, r; } ;
Query tmp[N<<1]; int t;
bool cmp(const Query &a, const Query &b) { return a.x==b.x?a.t<b.t:a.x<b.x; }
ll count(int x) {
    t=0;
    for(Seg y:td[x]) if(y.k==1) {
        tmp[++t]={0, 1, 2*y.l+y.b, y.b+(N<<1)};
        tmp[++t]={2, -1, 2*y.r+y.b, y.b+(N<<1)};
    }
    else {
        tmp[++t]={1, 0, y.b, y.b-2*y.l+(N<<1), y.b-2*y.r+(N<<1)};
    }
    sort(tmp+1, tmp+1+t, cmp);
    ll ret=0; BIT::clr();
    for(int i=1; i<=t; i++) if(tmp[i].t==1)
        ret+=BIT::query(tmp[i].l)-BIT::query(tmp[i].r-1);
    else 
        BIT::updata(tmp[i].l, tmp[i].v);
    return ret;
}

int main() {
    int u, v;
    gn(n);
    for(int i=1; i<n; i++) {
        gn(u, v);
        ins(u, i+n), ins(i+n, u);
        ins(v, i+n), ins(i+n, v);
    }
    ini(); dfs1(1, 1), dfs2(1, 1);
    ll ans=0;
    int m; gn(m);
    for(int i=1; i<=m; i++) {
        gn(u, v); int l=lca(u, v);
        if(l==v) modifyup(u, v, ans);
        else {
            int l2=lca(l, v, 1);
            modifyup(u, l, ans);
            modifydown(v, l2, dep[u]-dep[l]+1);
        }
    }
    dfs(1, ans);
    for(int i=1; i<=2*n-1; i++) if(top[i]==i) ans+=count(i);
    printf("%lld\n", ans);
    return 0;
}
navigate_before navigate_next