Expected diameter of a tree | Codeforces #411

コメントつきコード

// 頂点の数、辺の数、連結成分のID、最も遠い点までの距離、連結成分の直径
int N, M, Q, cmp[100001], far[100001], nc, dia[100001];
// グラフ、連結成分ごとに最も遠い点までの距離をすべてもってソートしたもの
vi G[100001], F[100001];
// その部分和
vll S[100001];
// メモ
map<pii, ll> DS;

int bfs(int root, int c) {
    queue<tuple<int,int,int> > q;
    q.push({ root,-1, 0});
    int ma = -1, fa = 0;
    while (sz(q)) {
        int u, p, d;
        tie(u, p, d) = q.top();
        q.pop();
        cmp[u] = c;
        if (d > ma) {
            fa = u;
            ma = d;
        }
        /*
        頂点uから最も遠い点までの距離を求める。
        */
        smax(far[u], d);
        each(v, G[u])if (v != p) {
            q.push({ v,u,d + 1 });
        }
    }
    return fa;
}

ll solve(int U, int V) {
    if (DS.count({ U,V }))return DS[{U, V}];
    int madi = max(dia[U], dia[V]);
    ll &res = DS[{U, V}];
    /*
    ここの最悪ケースを考えてみよう。
    すべての連結成分の大きさが√nのとき
    √n個のUについてxは√N個あり、相手のVがc√N個
    中で二部探索しているので
    O(n^(3/2)log(n))
    **/
    each(x, F[U]) {
        /*
        madi > x + y
        連結しても直径はmadiのまま
        madi <= x + yのとき
        連結すると直径は
        x + y + 1
        y ∈ F[V]
        最小のyを求める
        madi - x <= y
        */
        int k = (int)distance(F[V].begin(), lower_bound(all(F[V]), madi - x));
        res += (ll)k*madi;
        // madi<=x+y
        // x+y+1のx+1の部分を足す
        res += (ll)(x + 1)*(sz(F[V]) - k);
        // x+y+1のyを足す
        res += S[V].back() - S[V][k];
    }
    return res;
}

int main(){
    cin >> N >> M >> Q;
    rep(i, M) {
        int u, v;
        scanf("%d%d", &u, &v);
        --u; --v;
        G[u].push_back(v);
        G[v].push_back(u);
    }
    MEM(cmp, -1);

    /*
    各木について直径を求める。
    また、頂点vから最も遠い点までの距離をすべて求める。
    */
    rep(i, N) if (cmp[i] == -1) {
        // iから最も遠い頂点はx
        int x = bfs(i, nc);
        // xから最も遠い頂点はy
        int y = bfs(x, nc);
        // (x, y)の距離が直径
        dia[nc] = far[y];
        // 任意の頂点についてxまたはyが最も遠い頂点の1つである。
        bfs(y, nc);
        nc++;
    }
    
    /*
    最も遠い点までの距離たちを連結成分ごとにまとめてからソートする
    */
    rep(i, N)F[cmp[i]].push_back(far[i]);
    rep(i, nc) {
        sort(all(F[i]));
        S[i].resize(sz(F[i]) + 1);
        rep(j, sz(F[i])) {
            S[i][j + 1] = F[i][j] + S[i][j];
        }
    }

    while (Q--) {
        int u, v;
        scanf("%d%d", &u, &v);
        --u;
        --v;
        int U = cmp[u], V = cmp[v];
        if (U == V) {
            printf("%d\n", -1);
            continue;
        }
        if (sz(F[U]) > sz(F[V]))swap(U, V);
        ll x = solve(U, V);
        double ans = (double)x / sz(F[U]) / sz(F[V]);
        printf("%0.10f\n", ans);
    }
}