跳转至

树分治

点分治ψ(`∇´)ψ

泛化ψ(`∇´)ψ

点分治也叫重心剖分(国外这么叫),一般用来处理一些静态的路径相关询问,且通常所需的路径(状态)比较多。

其思想是,我们假设当前处理到一个点,我们把以这个点为根的子树中的所有路径分为两类,一类是经过这个点的,一类是不经过的。

此时利用分治的思想,我们处理前者,递归后者分割子问题最终转化为前者,但为了保证复杂度,我们每次处理前者,需要选择当前子树的重心,以保证子问题规模每次缩小 \(\dfrac{1}{2}\),从而使得总复杂度为 \(O(n \log n)\)

对于前者,也有两种情况,一种是以这个点为端点,另一种是这个点为路径上一点。

为了方便处理,我们将第二种情况转化为第一种处理,这个是容易的。

这是最基础的思想,核心部分就在于,怎么处理第一种情况对于答案的贡献。

下面看几道题来熟悉一下:

Luogu3806 【模板】点分治1ψ(`∇´)ψ

给定一棵有 \(n\) 个点的带边权树,\(m\) 次询问,每次询问给出 \(k\),询问树上距离为 \(k\) 的点对是否存在。

\(n\le 10000,m\le 100,k\le 10000000\)

做法是显然的,我们考虑怎么处理前面提到的“第一种情况”的贡献。

我们记 \(tf(len) = \texttt{true/false}\) 表示,在以当前节点 \(u\) 为子树根的情况下,是否存在一条经过 \(u\) 且长度为 \(len\) 的路径。

更新时枚举子树,以 \(u\) 为根计算一下子树节点到 \(u\) 的距离 \(dis\),然后暴力枚举所有可能的 \(dis\),对于一个 \(d\),我们判断是否存在 \(tr(k - d) = \texttt{true}\) 即可。

之后清空 \(tf\),先计算一下当前要递归的子树的重心 \(G_v\),然后以 \(G_v\) 为根继续计算答案即可。

注意清空 \(tf\) 的时候为了保证复杂度,要标记一下哪些位置被更改了,实现可以用队列。

Code
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
// author : black_trees

#include <cmath>
#include <queue>
#include <cstdio>
#include <bitset>
#include <cstring>
#include <iostream>
#include <algorithm>

#define endl '\n'

using namespace std;
using i64 = long long;

const int si = 3e5 + 10;
const int inf = 1e9 + 7;

int n, m, q[si];
int tot = 0, head[si];
struct Edge { int ver, Next, w; } e[si << 1];
inline void add(int u, int v, int w) { e[tot] = (Edge){v, head[u], w}, head[u] = tot++; } 

std::queue<int> rec;
bool tf[10000010], can[si], vis[si];
// tf: 当前子树的可行性。

int cnt = 0, sum = 0;
int maxv[si], rt = 0;
int d[si], dis[si], siz[si];
// d: 当前子树的 节点-根 距离。
void calcsiz(int u, int fa) {
    siz[u] = 1, maxv[u] = 0;
    for(int i = head[u]; ~i; i = e[i].Next) {
        int v = e[i].ver;
        if(v == fa || vis[v]) continue;
        calcsiz(v, u);
        maxv[u] = max(maxv[u], siz[v]), siz[u] += siz[v];
    }
    maxv[u] = max(maxv[u], sum - siz[u]); // 注意这里是当前子树的节点个数。
    if(maxv[rt] > maxv[u]) rt = u;
}
void calcdis(int u, int fa) {
    d[++cnt] = dis[u]; // 这里复制是为了枚举的时候不全部枚举,保证复杂度。
    for(int i = head[u]; ~i; i = e[i].Next) {
        int v = e[i].ver, w = e[i].w;
        if(v == fa || vis[v]) continue;
        dis[v] = dis[u] + w, calcdis(v, u);
    }
}
void dfs(int u, int fa) {
    tf[0] = true, rec.push(0), vis[u] = true; // 打 vis 是为了确保在子树中进行操作,不会递归出去。
    // 或者不妨说,我们是利用 vis,将树划分成了一个个联通块来处理,因为它每次都只会标记到重心嘛。
    for(int i = head[u]; ~i; i = e[i].Next) {
        int v = e[i].ver, w = e[i].w;
        if(v == fa || vis[v]) continue;
        dis[v] = w, calcdis(v, u);
        for(int j = 1; j <= cnt; ++j) {
            for(int k = 1; k <= m; ++k) {
                if(q[k] >= d[j]) can[q[k]] |= tf[q[k] - d[j]];
            }
        } // 先判断再添加,不然算的不是除了自己子树的情况,这样会多算。
        for(int j = 1; j <= cnt; ++j) {
            if(d[j] < 10000010) rec.push(d[j]), tf[d[j]] = true;
        }
        cnt = 0;
    }

    while(!rec.empty()) tf[rec.front()] = false, rec.pop();
    for(int i = head[u]; ~i; i = e[i].Next) {
        int v = e[i].ver;
        if(v == fa || vis[v]) continue;
        rt = 0, maxv[rt] = inf, sum = siz[v];
        calcsiz(v, u), calcsiz(rt, -1), dfs(rt, u); // 先找重心再递归。
    }
}

int main() {

    cin.tie(0) -> sync_with_stdio(false);
    cin.exceptions(cin.failbit | cin.badbit);

    memset(tf, false, sizeof tf);
    memset(head, -1, sizeof head);
    memset(vis, false, sizeof vis);
    memset(can, false, sizeof can);

    cin >> n >> m;
    for(int i = 1; i < n; ++i) {
        int u, v, w;
        cin >> u >> v >> w;
        add(u, v, w), add(v, u, w);
    }
    for(int nw = 1; nw <= m; ++nw) {
        cin >> q[nw];
    }

    rt = 0, maxv[rt] = inf, sum = n;
    calcsiz(1, -1), calcsiz(rt, -1), dfs(rt, -1); // 因为本题需要用到 tf(0) 所以 fa 就用 -1 了。

    for(int nw = 1; nw <= m; ++nw) {
        if(can[q[nw]]) cout << "AYE" << endl;
        else cout << "NAY" << endl;
    }

    return 0;
}

Luogu4178 Treeψ(`∇´)ψ

给定一棵有 \(n\) 个点的带权树,给出 \(k\),询问树上距离小于等于 \(k\) 的点对数量。

\(n\le 40000,k\le 20000,w_i\le 1000\)

类似上一题即可,这次我们不维护 \(tf\) 了,直接维护一颗线段树来记录每个长度出现了多少次就可以。

Code
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
// author : black_trees

#include <cmath>
#include <queue>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

#define endl '\n'
#define int long long

using namespace std;
// using i64 = long long;

const int si = 2e5 + 10;
const int inf = 0x3f3f3f3f3f3f3f3fll;

int n, m, V;
int tot = 0, head[si];
struct Edge { int ver, Next, w; } e[si << 1];
inline void add(int u, int v, int w) { e[tot] = (Edge){v, head[u], w}, head[u] = tot++; }

class Segment_Tree {
    private:
        int ls[si << 2], rs[si << 2], val[si << 2];
        int node() { cot++, ls[cot] = rs[cot] = val[cot] = 0; return cot; }
        void pushup(int p) { val[p] = val[ls[p]] + val[rs[p]]; }
    public:
        int rt, cot;
        void modify(int &p, int l, int r, int x, int v) {
            if(!p) p = node();
            if(l == r) return val[p] += v, void();
            int mid = (l + r) >> 1;
            if(x <= mid) modify(ls[p], l, mid, x, v);
            else modify(rs[p], mid + 1, r, x, v);
            pushup(p);
        }
        int query(int p, int l, int r, int ql, int qr) {
            if(!p) return 0;
            if(ql <= l && r <= qr) return val[p];
            int mid = (l + r) >> 1, ret = 0;
            if(ql <= mid) ret += query(ls[p], l, mid, ql, qr);
            if(qr > mid) ret += query(rs[p], mid + 1, r, ql, qr);
            return ret;
        }
} tr;

int rt = 0, maxv[si], sum;
int siz[si], dis[si], d[si], cnt = 0;
bool vis[si];
std::queue<int> rec;
void calcsiz(int u, int fa) {
    siz[u] = 1, maxv[u] = 0;
    for(int i = head[u]; ~i; i = e[i].Next) {
        int v = e[i].ver;
        if(v == fa || vis[v]) continue;
        calcsiz(v, u), maxv[u] = max(maxv[u], siz[v]), siz[u] += siz[v];
    }
    maxv[u] = max(maxv[u], sum - siz[u]);
    if(maxv[rt] > maxv[u]) rt = u;
}
void calcdis(int u, int fa) {
    d[++cnt] = dis[u];
    for(int i = head[u]; ~i; i = e[i].Next) {
        int v= e[i].ver, w = e[i].w;
        if(v == fa || vis[v]) continue;
        dis[v] = dis[u] + w, calcdis(v, u);
    }
}
int ans = 0;
void dfs(int u, int fa) {
    vis[u] = true;
    tr.modify(tr.rt, 1, V, 1, 1), rec.push(1);
    for(int i = head[u]; ~i; i = e[i].Next) {
        int v = e[i].ver, w = e[i].w;
        if(v == fa || vis[v]) continue;
        dis[v] = w, cnt = 0, calcdis(v, u);
        for(int j = 1; j <= cnt; ++j) {
            if(m >= d[j]) 
                ans += tr.query(tr.rt, 1, V, max(0ll, 1 - d[j]) + 1, max(0ll, m - d[j]) + 1);
                // 因为 w >= 0 所以先整体右移一下。
        }
        for(int j = 1; j <= cnt; ++j) {
            tr.modify(tr.rt, 1, V, d[j] + 1, 1), rec.push(d[j] + 1);
        }
    }
    while(!rec.empty()) tr.modify(tr.rt, 1, V, rec.front(), -1), rec.pop();
    for(int i = head[u]; ~i; i = e[i].Next) {
        int v = e[i].ver;
        if(v == fa || vis[v]) continue;
        rt = 0, sum = siz[v], maxv[rt] = inf;
        calcsiz(v, u), calcsiz(rt, -1);
        dfs(rt, u);
    }
}

signed main() {

    cin.tie(0) -> sync_with_stdio(false);
    cin.exceptions(cin.failbit | cin.badbit);

    tr.cot = tr.rt = 0;
    memset(head, -1, sizeof head);

    cin >> n, V = (int)2e7;
    for(int i = 1; i < n; ++i) {
        int u, v, w;
        cin >> u >> v >> w;
        add(u, v, w), add(v, u, w);
    }
    cin >> m;
    rt = 0, sum = n, maxv[rt] = inf;
    calcsiz(1, -1), calcsiz(rt, -1), dfs(rt, -1);
    cout << ans << endl;

    return 0;
}

XX Open cup GP of Korea, K. Wind of Changeψ(`∇´)ψ

给定 \(n\) 个点,两棵带权树 \(T1, T2\)

定义 \(dist(T, i, j)\) 表示 \(T\) 上的 \(i, j\) 两点之间的距离。

你需要对每个 \(i\),求出 \(\min\limits_{j \not= i} \{dist(T1, i, j) + dist(T2, i, j)\}\)

\(1 \le n \le 2.5 \times 10^5, 1 \le w_i \le 10^9\)

一个比较套路的想法是,我们尝试分开处理两个 dist,或者把某个 dist 尝试合并到另一个 dist 里面。

因为 \(i, j\) 在两棵树上是对应的,加上 \(\min\) 本身没有结合律,所以分开处理并合并的方法是不可行的。

于是我们考虑把一个 dist 合并到另一个 dist 里面处理。

然后,这种树上路径问题,你也只能拿树分治来维护了,没其他的办法,所以我们不妨对 \(T1\) 点分。

设当前重心为 \(g\),那么我们处理的问题就是 \(g\) 的子树当中的子问题了。

现在要做的,相当于是把 \(dist(T1, i, j)\) 拆分成两个部分,一个部分是 \(i \to g\),一个部分是 \(j \to g\)

这样我们就可以把 \(dist(T1, i, j)\) 拆开,分别挂载到 \(T2\) 对应的 \(i, j\) 两个节点上了。

然后问题转化为,给定 \(T2\) 上的一个点集 \(S(g)\),求这个点集中,每个点到点集中另外的点的最小距离。

这部分可以以 \(S(g)\) 为关键点集拖出来建虚树,然后跑一个换根 dp。

还有一个问题是,如果建出来的虚树上有一个节点,它在 \(T1\) 上并不在 \(g\) 的子树当中,它的权值应该是多少?

这个也很简单,因为这样,它就不是 \(i, j\)\(T1\) 上的 LCA,相当于,最极端的情况,\(i, j\) 走到 \(g\) 之后还要再各走一段,这样显然不优秀,所以权值设置成 \(+\infty\) 就好了。

做法受了 jiangly 和 Yunqian 的指点,在这里感谢。

Code
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
// author : black_trees

#include <cmath>
#include <cstdio>
#include <vector>
#include <cstring>
#include <utility>
#include <iostream>
#include <algorithm>

#define endl '\n'

using namespace std;
using i64 = long long;

const int si = 2.5e5 + 10;
const int inf = 0x3f3f3f3f;
const i64 infll = 0x3f3f3f3f3f3f3f3fll;

int n;
i64 Ans[si];

namespace TreeII {
    std::vector<int> ev[si];
    std::vector<std::pair<int, int> > ew[si];
    i64 dep[si], val[si], dp[si][2];
    int pa[si], top[si], siz[si], dfn[si], tim = 0;
    void dfs1(int u) {
        if(~pa[u])
            for(auto it = ew[u].begin(); ; ++it)
                if(it -> first == pa[u]) { ew[u].erase(it); break; }
        siz[u] = 1;
        for(auto &tmp : ew[u]) {
            auto [v, w] = tmp;
            pa[v] = u, dep[v] = dep[u] + w;
            dfs1(v), siz[u] += siz[v];
            if(siz[v] > siz[ew[u][0].first]) swap(ew[u][0], tmp);
        }
    }
    void dfs2(int u) {
        dfn[u] = tim++;
        for(auto tmp : ew[u]) {
            auto [v, w] = tmp;
            if(tmp == ew[u][0]) top[v] = top[u];
            else top[v] = v;
            dfs2(v);
        }
    }
    int lca(int u, int v) {
        while(top[u] != top[v]) {
            if(dep[top[u]] < dep[top[v]])
                swap(u, v);
            u = pa[top[u]];
        }
        if(dep[u] > dep[v]) swap(u, v);
        return u;
    }
    void dfs3(int u) {
        dp[u][0] = dp[u][1] = infll;
        for(auto v : ev[u]) {
            dfs3(v);
            i64 x = min(val[v], dp[v][0]) + dep[v] - dep[u];
            if(x < dp[u][0]) swap(x, dp[u][0]);
            dp[u][1] = min(dp[u][1], x);
        }
    }
    void dfs4(int u) {
        Ans[u] = min(Ans[u], val[u] + dp[u][0]);
        for(auto v : ev[u]) {
            i64 x = val[u];
            if(dp[u][0] == min(val[v], dp[v][0]) + dep[v] - dep[u])
                x = min(x, dp[u][1]);
            else x = min(x, dp[u][0]);
            x += dep[v] - dep[u];
            if(x < dp[v][0]) swap(x, dp[v][0]);
            dp[v][1] = min(dp[v][1], x), dfs4(v);
        }
        ev[u].clear(), val[u] = infll;
    }
    bool cmp(int x, int y) { return dfn[x] < dfn[y]; } 
    void build(int m, int node[], i64 dis[]) {
        for(int i = 0; i < m; ++i)
            val[node[i]] = dis[i];
        std::vector<int> st{0};
        st.reserve(m);
        sort(node, node + m, cmp);
        for(int i = 0; i < m; ++i) {
            int u = node[i];
            if(u == st.back()) continue;
            int Lca = lca(u, st.back());
            while(st.size() >= 2 && dep[Lca] <= dep[st[(int)st.size() - 2]]) {
                int v = st.back();
                st.pop_back(), ev[st.back()].emplace_back(v);
            }
            if(Lca != st.back()) ev[Lca].emplace_back(st.back()), st.back() = Lca;
            st.emplace_back(u);
        }
        while((int)st.size() >= 2) {
            int v = st.back();
            st.pop_back(), ev[st.back()].emplace_back(v);
        }
        dfs3(0), dfs4(0);
    }
    void init() {
        pa[0] = -1, dfs1(0), top[0] = 0, dfs2(0);
        for(int i = 0; i < n; ++i) val[i] = infll;
    }
}
namespace TreeI {
    i64 dis[si];
    bool vis[si];
    int siz[si], node[si], nw;
    std::vector<std::pair<int, int> > e[si];
    void dfs1(int u, int fa) {
        siz[u] = 1;
        for(auto [v, w] : e[u]) {
            if(v == fa || vis[v]) continue;
            dfs1(v, u), siz[u] += siz[v];
        }
    }
    int find(int u, int fa, int s) {
        for(auto [v, w] : e[u])
            if(v != fa && !vis[v] && 2 * siz[v] >= s)
                return find(v, u, s);
        return u;
    }
    void dfs2(int u, int fa, int64_t d) {
        node[nw] = u, dis[nw++] = d;
        for(auto [v, w] : e[u])
            if(v != fa && !vis[v]) dfs2(v, u, d + w);
    }
    void solve(int p) {
        dfs1(p, -1), p = find(p, -1, siz[p]);
        nw = 0, dfs2(p, -1, 0), TreeII::build(nw, node, dis), vis[p] = true;
        for(auto [v, w] : e[p]) if(!vis[v]) solve(v);
    }
}

int main() {

    cin.tie(0) -> sync_with_stdio(false);
    cin.exceptions(cin.failbit | cin.badbit);

    cin >> n;
    for(int i = 1; i < n; ++i) {
        int u, v, w;
        cin >> u >> v >> w;
        TreeI::e[u - 1].emplace_back(v - 1, w);
        TreeI::e[v - 1].emplace_back(u - 1, w);
    }
    for(int i = 1; i < n; ++i) {
        int u, v, w;
        cin >> u >> v >> w;
        TreeII::ew[u - 1].emplace_back(v - 1, w);
        TreeII::ew[v - 1].emplace_back(u - 1, w);
    }
    TreeII::init();
    for(int i = 0; i < n; ++i) 
        Ans[i] = infll;
    TreeI::solve(0);
    for(int i = 0; i < n; ++i)
        cout << Ans[i] << endl;


    return 0;
}

这个问题启发我们,点分治不仅可以直接对路径信息进行统计,它还可以起到拆分的作用,降低答案的信息之间的耦合度。

Luogu5351 Ruri Loves Mascheraψ(`∇´)ψ

给你一棵树,定义 \(val(\delta(u, v))\) 表示 \(\delta(u, v)\) 上的 \(\max\{w_i\}\)

你需要对于 \(\forall (u, v), len(\delta(u, v)) \in [L, R]\),求 \(\sum val(\delta(u, v))\)\(len\) 是路径上的边数。

\(1\le L \le R < n \le 10^5, 1\le w_i \le 10^5\)

本题钦定,\(\delta(u, v) \not= \delta(v, u)\)

可以很快速的想到点分治处理这个路径。

我们把 \(T\) 点分,然后 \(val\) 就转化为了 \(u \to g\)\(v \to g\) 两部分。

于是我们记 \(mx(u)\) 表示 \(\delta(u, g)\) 上的 \(\max\{w_i\}\)

那么一条路径 \(\delta(u, v)\) 的贡献就是 \(\max(mx(u), mx(v))\)

然后路径有序这个东西可以很套路的转化一下,我们把贡献丢到大的那一边算,最后答案翻个倍就可以了。

假设做贡献的是 \(mx(u)\),可以注意到它的贡献是 \(2\times mx(u) \times cnt\),其中 \(cnt\) 是满足 \(mx(v) \le mx(u), len(\delta(u, v)) \in [L, R]\)\(v\) 的数量,且 \(u, v\) 不在 \(g\) 的同一个子树当中,后面这句是点分的限制,不加要算重。

我们先不管 \([L, R]\) 的限制,考虑更简单的情况。

那么此时就是,对 \(u\) 要算不同子树的 \(v\),这个是一维数点,枚举一下是很好做的。

然后考虑加上另一个条件,这个区间本质就是前缀和,所以转化为二维数点。

定义 \(dis(u)\) 表示,\(u\)\(g\) 的路径条数

我们可以对 \(mx\) 排序,要求的就是满足 \(L - dis(u) \le dis(v) \le R - dis(u)\)\(v\)

于是树状数组维护一下 \(dis\) 即可。

但是注意到这样复杂度是 \(O(n^2\log n)\) 的,因为考虑每个子树的时候我们都要重新排序一次,不然你会考虑到自己子树的部分,会算错。

这很不好维护,或者说,二维数点本身是静态的,我们这里因为要分别考虑子树,所以本质上是个动态问题。

但是这个贡献是具有可减性的(因为贡献是,数量乘上权值的形式,乘法有分配律的嘛),所以一个套路的想法是,算一下 \(g\) 整体的贡献,分别减去子树的贡献即可.

然后有一个要注意的点,算子树的时候不要忘记,\(dis\) 是到 \(g\) 的距离,不是到选定子树树根的距离。

Code
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
// author : black_trees

#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

#define endl '\n'
#define int long long 

using namespace std;
using i64 = long long;

const int si = 1e5 + 10;
const int inf = 0x3f3f3f3f3f3f3f3fll;

int n, L, R;
int tot = 0, head[si];
struct Edge { int ver, Next, w; } e[si << 1];
inline void add(int u, int v, int w) { e[tot] = (Edge){v, head[u], w}, head[u] = tot++; }

class Fenwick {
    private:
        int t[si * 2], V;
        int lowbit(int x) { return x & -x; }
    public:
        void init(int x) { for(int i = 0; i <= x; ++i) t[i] = 0; V = x; }
        void add(int x, int v) { while(x <= V) t[x] += v, x += lowbit(x); }
        void sub(int x, int v) { while(x <= V) t[x] -= v, x += lowbit(x); } 
        int que(int x) { int ret = 0; while(x > 0) ret += t[x], x -= lowbit(x); return ret; }
} tr;

bool vis[si];
int sum, rt, siz[si], maxv[si];
void calcsiz(int u, int fa) {
    siz[u] = 1, maxv[u] = 0;
    for(int i = head[u]; ~i; i = e[i].Next) {
        int v = e[i].ver;
        if(v == fa || vis[v]) continue;
        calcsiz(v, u), siz[u] += siz[v];
        maxv[u] = max(maxv[u], siz[v]);
    }
    maxv[u] = max(maxv[u], sum - siz[u]);
    if(maxv[u] < maxv[rt]) rt = u;
}
struct Node {
    int val, dep;
    bool operator < (const Node &rhs) const {
        return val < rhs.val;
    }
} q[si];
int cnt;
void calcpath(int u, int fa, int dep, int value) {
    q[++cnt] = (Node){value, dep};
    for(int i = head[u]; ~i; i = e[i].Next) {
        int v = e[i].ver, w = e[i].w;
        if(v == fa || vis[v]) continue;
        calcpath(v, u, dep + 1, max(value, w));
    }
} // do not change the original infomation if copies can works, or you have to know what you are doing!
int solve(int u, int fa, int edge) {
    int ret = 0;
    cnt = 0;
    if(edge == 0) calcpath(u, fa, 0, edge);
    else calcpath(u, fa, 1, edge); // notice here!!
    sort(q + 1, q + 1 + cnt);
    for(int i = 1; i <= cnt; ++i) {
        auto [val, dep] = q[i];
        ret += val * 
            (tr.que(R - dep + 1) - tr.que(L - dep));    
        tr.add(dep + 1, 1);
    }
    for(int i = 1; i <= cnt; ++i) tr.sub(q[i].dep + 1, 1);
    return ret;
}
int ans = 0;
void dfs(int u, int fa) {
    vis[u] = true;
    for(int i = head[u]; ~i; i = e[i].Next) {
        int v = e[i].ver, w = e[i].w;
        if(v == fa || vis[v]) continue;
        ans -= solve(v, u, w);
    }
    ans += solve(u, fa, 0);
    for(int i = head[u]; ~i; i = e[i].Next) {
        int v = e[i].ver;
        if(v == fa || vis[v]) continue;
        rt = 0, maxv[rt] = inf, sum = siz[v];
        calcsiz(v, u), calcsiz(rt, 0), dfs(rt, u);
    }
}

signed main() {

    cin.tie(0) -> sync_with_stdio(false);
    cin.exceptions(cin.failbit | cin.badbit);

    memset(head, -1, sizeof head);

    cin >> n >> L >> R, tr.init(n + 10);
    for(int i = 1; i < n; ++i) {
        int u, v, w;
        cin >> u >> v >> w;
        add(u, v, w), add(v, u, w);
    }
    rt = 0, maxv[rt] = inf, sum = n;
    calcsiz(1, 0), calcsiz(rt, 0), dfs(rt, 0);

    cout << ans * 2 << endl;
    return 0;
}

本题提供了一种,当贡献要区分子树计算时的处理方法,也就是考虑减去重复贡献,这种常见于一类算贡献的问题。

还有一种在下面的暴力写挂那题会提到,是对子树“染色”,当然这种常见于 dp。

Luogu2305 [NOI2014]购票ψ(`∇´)ψ

给你一颗以 \(1\) 为根的有根树。

现在你要从任意一个节点向上走,走到 \(1\)

节点有一些属性:\((p_u, q_u, lim_u)\),表示,\(u\) 每次只能向上走到,距离 \(d\) 不超过 \(lim_u\) 的节点 \(v\)

并且这一次移动的花费是 \(d\times p_u + q_u\),你需要对所有节点求出移动到 \(1\) 的最小花费。

\(1\le n \le 2\times 10^5, p_u \in [0, 10^6], q_u \in [0, 10^{12}], lim_u \in (0, 10^{12}]\)

暴力的 \(dp\) 是很好设计的:

\(dp(u)\) 表示从 \(u \to 1\) 的最小代价,则有转移方程:

\(dp(u) = \min\limits_{v \in anc(u), dep(u) - dep(v) \le lim(u)}\{dp(v) + (dep(u) - dep(v)) \times p(u) + q(u)\}\)

其中 \(anc(u)\) 表示 \(u\) 的祖先集合,\(dep(u)\)\(u\) 的带权深度。

显然可以想到斜率优化或者李超树的做法,在这里只提斜率优化做法(因为用得到树分治)

套用斜率优化的办法,我们拆开一下:

\(dp(u) - dep(u)\times p(u) - q(u) = dp(v) - dep(v)\times p(u)\)

可以发现 \(y = dp(v), x = dep(v), k = -p(u)\),此时因为取 \(\min\) 所以要维护一个下凸包(虽然这个画出来和斜率优化博客里的不太一样,但因为斜率是负数,所以画出来还是维护了一个单调递增的形式,二分可以直接同理)。

然后注意到,\(k\) 是不单调的,需要二分,\(x\) 关于 \(v\) 不一定单增,哪怕把 \(v\) 替换成 \(dfn(v)\) 也不是,所以还需要支持在任意位置插入,删除决策。

转移方程里还有一个 \(lim\) 的限制,这个也要求在任意位置删除。

所以我们应当使用 CDQ 分治来支持以上操作。

但是问题在于,我们现在是在树上,不是在序列上,相当于是,我们对于每个节点 \(u\),都要对 \(anc(u)\) 使用 CDQ 维护一次凸包,这十分浪费。

因为本质上,CDQ 的过程是静态的,我们这里需要“加入”子树,这部分是动态的。

原理类似上一题,我们需要想办法来做一些转化或者优化。

所以有什么办法不用支持撤销和加入的李超树来维护这个过程呢?

当然有!我会树上 CDQ 分治!

正常的 CDQ 分治的思想其实是,通过分治来批量处理一系列决策,而不是处理一系列状态来达到优化的目的。

对于决策下标范围的限制是通过分治处理掉的,然后决策点不单调则是利用,对前一段计算好的决策排序来处理的,斜率不单调则可以二分。

我们现在的问题就在于怎么分治,怎么确定批量处理哪些决策,要对什么排序。

我们可以类比正常的 CDQ,现在先考虑怎么处理“距离限制”,也可以说成是决策下标的限制。

这部分可以利用一个叫“有根树点分治”的做法来处理。

与正常的点分治不同,有根树点分治的路径不再是两头“挂在”重心上了,而是有一条延伸出去,向联通块的树根方向延伸。

这么做的原因是,本题的决策中有 \(anc(u)\) 的限制,所以,树根方向,对于当前节点的决策也是有影响的。

我们不妨把 \(g\) 上方的联通块当作 “已经计算好了” 的部分(等价于 \([l, mid]\)),先递归分治处理,计算出它们“最终” 的 dp 值。

然后我们考虑怎么计算 \(g\) 下面的 dp 值(对应普通 CDQ 中的 \([mid + 1, r]\)),显然,上面子树能对它们造成贡献的,实质上是过 \(g\) 向联通块的根 \(rt\) 延伸的一条路径。

我们相当于是,分治处理,每次的决策是,过 \(g\)\(rt\) 方向走的一系列决策(所以才叫做有根树点分治)。

于是 \(anc(u)\) 的限制得到了解决,我们只需要从 \(g\to rt\) 上的状态转移过来即可。

然后考虑怎么处理距离限制,这部分和普通 CDQ 不太相同,我们注意到,距离限制实际上和决策点不单调的限制是捆绑的,因为每个节点能向上走的距离不一样,从而导致在 \(g \to rt\) 的路径上的决策集合不一样,所以它是不单调变化的。

所以我们不妨对这些节点按照 \(lim - dep\) 升序排序,这样决策集合就是单调增加的,并且不需要支持弹出操作,于是我们用一个栈维护,每次从下往上加入决策,并在栈上二分,就可以解决本题了。

可以看一张图来加深理解:

1.png

其中紫色路径就是决策集合的位置,绿色是优先处理的部分,黄色是后处理的部分。

这里绿色部分并没有包括 \(g\),所以我们还需要暴力处理一次 \(dp(g)\),因为点分了所以复杂度是对的。

当然你也可以直接把 \(g\) 扔到绿色部分算掉,不过我觉得这可能要判一些边界,比较麻烦。

本题的实质是,通过点分治,给 CDQ 分治找到了一个分治的指导思想,从而达到批量计算决策而不是状态的目的,避免了动态地加入子树。

本题可以看作是 Cash 一题的树上扩展版本。

Code

这份代码还不能通过本题,只是大体思路是正确的。

我个人猜测可能是二分的边界能否取到,维护的紫色路径是否维护完了,或者是联通块边界的一些处理挂了。

也有可能是点分的时候重复计算了,不过这题从 5-21 调到 5-23,确实不想动了,之后有缘再调好了。

这里还有一份 cmd 大爷的代码,仅作参考。

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
// author : black_trees

#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

#define endl '\n'
#define int long long

using namespace std;
using i64 = long long;
using ldb = long double;

const ldb eps = 1e-6;
const int si = 2e5 + 10;
const ldb infdb = 1e18 + 1;
const int inf = 0x3f3f3f3f3f3f3f3fll;

int n, _t;
int tot = 0, head[si];
struct Edge { int ver, Next, w; } e[si << 1];
inline void add(int u, int v, int w) { e[tot] = (Edge){v, head[u], w}, head[u] = tot++; }

bool vis[si];
int pa[si], s[si], p[si], q[si], lim[si];
int dep[si], siz[si], rt = 0, sum = 0, maxv[si];

void calcsiz(int u, int fa) {
    siz[u] = 1, maxv[u] = 0;
    for(int i = head[u]; ~i; i = e[i].Next) {
        int v = e[i].ver;
        if(v == fa || vis[v]) continue;
        calcsiz(v, u), siz[u] += siz[v];
        maxv[u] = max(maxv[u], siz[v]);
    }
    maxv[u] = max(maxv[u], sum - siz[u]);
    if(maxv[u] < maxv[rt]) rt = u;
}
struct Node {
    int id, val;
    bool operator < (const Node &rhs) const {
        return val < rhs.val;
    }
} t[si];
int cnt = 0, udep;
void build(int u, int fa) {
    if(lim[u] >= dep[u] - udep)
        t[++cnt] = (Node){u, lim[u] - dep[u] + udep}; 
    for(int i = head[u]; ~i; i = e[i].Next) {
        int v = e[i].ver;
        if(v == fa || vis[v]) continue;
        build(v, u);
    }
}
int Q[si], cur = 0;
int dp[si];
ldb calc(int i, int j) { return (ldb)(dp[j] - dp[i]) / (ldb)(dep[j] - dep[i]); }
int find(ldb slope) {
    int l = 1, r = cur;
    while(l < r) {
        int mid = (l + r) >> 1;
        if(calc(Q[mid - 1], Q[mid]) < slope - eps) 
            r = mid;
        else l = mid + 1;
    }
    return Q[l];
}
void solve(int u, int top) {
    int nw = pa[u], dis = s[u];
    while(dis <= lim[u] && nw != pa[top]) { // top 似乎也需要被包含进去。
        dp[u] = min(dp[u], dp[nw] + dis * p[u] + q[u]);
        dis += s[nw], nw = pa[nw];
    }
    cnt = 0, udep = dep[u]; 
    for(int i = head[u]; ~i; i = e[i].Next) {
        int v = e[i].ver;
        if(vis[v]) continue; // pa[u] 已经 vis 过了。
        build(v, u);
    }
    sort(t + 1, t + 1 + cnt);
    nw = u, dis = 0, cur = 0;
    for(int i = 1; i <= cnt; ++i) {
        auto [v, val] = t[i];
        while(dis <= val && nw != pa[top]) {
            while(cur > 1 && calc(Q[cur - 1], Q[cur]) < calc(Q[cur - 1], nw) + eps)
                --cur;
            Q[++cur] = nw;
            dis += s[nw], nw = pa[nw];
        }
        if(cur > 0) {
            int opt = find(-p[v]);
            dp[v] = min(dp[v], dep[v] * p[v] + q[v] + dp[opt] - dep[opt] * p[v]);
        }
    }
}
void dfs(int u) { // 因为 vis 的缘故其实没有必要在点分的时候存 Father 了,除非统计信息要用
    int top = u;
    while(!vis[top]) top = pa[top];
    vis[u] = true;
    if(!vis[pa[u]]) {
        calcsiz(pa[u], 0);
        rt = 0, sum = siz[pa[u]], maxv[rt] = 0;
        calcsiz(pa[u], 0), dfs(pa[u]); // 算上面以 rt 为根的 siz 是没有必要的。
    }
    solve(u, top);
    for(int i = head[u]; ~i; i = e[i].Next) {
        int v = e[i].ver;
        if(vis[v]) continue;
        calcsiz(v, 0);
        rt = 0, sum = siz[v], maxv[rt] = 0;
        calcsiz(v, 0), dfs(v);
    }
}

signed main() {

    cin.tie(0) -> sync_with_stdio(false);
    cin.exceptions(cin.failbit | cin.badbit);

    memset(head, -1, sizeof head);

    cin >> n >> _t;
    pa[1] = 0, s[1] = dep[1] = 0, p[1] = q[1] = 0, lim[1] = inf;
    for(int i = 2; i <= n; ++i) {
        dp[i] = inf;
        cin >> pa[i] >> s[i];
        dep[i] = dep[pa[i]] + s[i]; 
        cin >> p[i] >> q[i] >> lim[i];
        add(i, pa[i], s[i]), add(pa[i], i, s[i]);
    }
    rt = 0, sum = n, maxv[rt] = inf;
    dp[1] = 0, vis[0] = true, calcsiz(1, 0), dfs(rt);

    for(int i = 2; i <= n; ++i)
        cout << dp[i] << endl;

    return 0;
}
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#include<algorithm>
#include<cstdio>
#include<vector>
#define ll long long
#define eps 1e-12
#define MaxN 200500
using namespace std;
inline ll read()
{
  register ll X=0;
  register char ch=0;
  while(ch<48||ch>57)ch=getchar();
  while(ch>=48&&ch<=57)X=X*10+(ch^48),ch=getchar();
  return X;
}
bool vis[MaxN];
vector<int> g[MaxN];
int rt,sum,siz[MaxN],maxp[MaxN];
void cnt(int u,int fa)
{
  sum++;
  for (int i=0,v;i<g[u].size();i++)
    if ((v=g[u][i])!=fa&&!vis[v])
      cnt(v,u);
}
void getrt(int u,int fa)
{
  siz[u]=1;maxp[u]=0;
  for (int i=0,v;i<g[u].size();i++)
    if ((v=g[u][i])!=fa&&!vis[v]){
      getrt(v,u);
      siz[u]+=siz[v];
      maxp[u]=max(maxp[u],siz[v]);
    }
  maxp[u]=max(maxp[u],sum-siz[u]);
  if (maxp[rt]>maxp[u])rt=u;
}
ll dep[MaxN],pf[MaxN];
int fa[MaxN],top;
struct Line
{
  ll k,b; double t;
  ll get(ll x){return k*x+b;}
}q[MaxN];
double inter(const Line &A,const Line &B)
{return 1.00*(B.b-A.b)/(A.k-B.k);}
struct Data
{int p;long long lim;}t[MaxN];
int tot; ll lim[MaxN],udep;
void dfs(int u,int fa)
{
  if (lim[u]+udep>=dep[u])
    t[++tot]=(Data){u,lim[u]-dep[u]+udep};
  for (int i=0,v;i<g[u].size();i++)
    if ((v=g[u][i])!=fa&&!vis[v])
      dfs(v,u);
}
bool cmp(const Data &A,const Data &B)
{return A.lim<B.lim;}
ll p[MaxN],c[MaxN],F[MaxN];
void calc(int u,int bar)
{
  tot=0;udep=dep[u];
  for (int i=0,v;i<g[u].size();i++)
    if (!vis[v=g[u][i]])
      dfs(v,u);
  sort(t+1,t+tot+1,cmp);

  int tp=fa[u];long long dis=pf[u];
  while (dis<=lim[u]&&tp!=bar){
    F[u]=min(F[u],F[tp]+p[u]*dis+c[u]);
    dis+=pf[tp];tp=fa[tp];
  }
  tp=u;dis=0;top=0;
  for (int i=1;i<=tot;i++){
    while (dis<=t[i].lim&&tp!=bar){
      Line sav=(Line){-dep[tp],F[tp]};
      while(top>1&&q[top].t<eps+inter(q[top-1],sav))
        top--;
      q[++top]=sav;
      if (top>1)
        q[top].t=inter(q[top-1],q[top]);
      else q[top].t=1e18;
      dis+=pf[tp];tp=fa[tp];
    }
    if (top){
      int v=t[i].p,l=1,r=top,mid;
      while(l<r){
        mid=(l+r+1)>>1;
        if (q[mid].t+eps>p[v])l=mid;
        else r=mid-1;
      }
      F[v]=min(F[v],
        q[l].get(p[v])+c[v]+p[v]*dep[v]
      );
    }
  }
}
void solve(int u)
{
  int bar=u;
  while(!vis[bar])bar=fa[bar];
  vis[u]=1;
  if (!vis[fa[u]]){
    rt=0;sum=0;cnt(fa[u],0);
    getrt(fa[u],0);
    solve(rt);
  }calc(u,bar);
  for (int i=0,v;i<g[u].size();i++)
    if (!vis[v=g[u][i]]){
      rt=0;sum=0;cnt(v,0);
      getrt(v,0);
      solve(rt);
    }
}
int n,_t;
int main()
{
  scanf("%d%d",&n,&_t);
  for (int i=2;i<=n;i++){
    fa[i]=read();
    g[fa[i]].push_back(i);
    g[i].push_back(fa[i]);
    dep[i]=dep[fa[i]]+(pf[i]=read());
    p[i]=read();c[i]=read();
    lim[i]=read();
    F[i]=1ll<<60;
  }
  maxp[0]=sum=n;getrt(1,0);
  F[1]=0;vis[0]=1;solve(rt);
  for (int i=2;i<=n;i++)
    printf("%lld\n",F[i]);
  return 0;
}

Cmd 大爷的实现也给了我一些启示。

其实点分治的那层 dfs 大部分时候不需要记录 fa,因为 vis 本身就起了分割联通块的作用。

然后有一些不必要的 Calcsiz 其实可以不写()

Luogu4886 快递员ψ(`∇´)ψ

[TODO]

边分治ψ(`∇´)ψ

泛化ψ(`∇´)ψ

[TODO]

Luogu4565 [CTSC2018]暴力写挂ψ(`∇´)ψ

[TODO]

[本题存在边分树合并/点(边)分+虚树两种做法]

Luogu4220 [WC2018]通道ψ(`∇´)ψ

[TODO]

[本题存在边分树合并/点(边)分+虚树两种做法]

点分树 & 边分树ψ(`∇´)ψ

泛化ψ(`∇´)ψ

[TODO]

Luogu4565 [CTSC2018]暴力写挂ψ(`∇´)ψ

[TODO]

[本题存在边分树合并/点(边)分+虚树两种做法]

Luogu4220 [WC2018]通道ψ(`∇´)ψ

[TODO]

[本题存在边分树合并/点(边)分+虚树两种做法]

Referenceψ(`∇´)ψ


最后更新: September 6, 2023