跳转至

虚树

概述ψ(`∇´)ψ

有一类问题,每次询问会对树上的一个点集进行操作。

在单次操作操作复杂度比较高,但是点集大小总和级别不大的时候,我们就可以考虑使用虚树优化。

具体来说,对于一棵树 \(T = \{V, E\}\),一个点集 \(S \subset V\),点集 \(S\) 的虚树是这样的一个东西:

包含了 \(S\) 中所有点,并且包含 \(S\) 中任意两个节点的 LCA 的一棵树。

大概长这样:

vt-1

其中红色节点为 \(S\) 中的节点,我们称为关键点,蓝色节点则是 LCA(虚树中的非关键点)。

可以注意到,虚树不会改变原树上的祖先关系,也就是说,实际上虚树是,将原树中对于当前询问没有用的一些节点给去掉,得到的一颗新树。

保留 LCA 就是因为,LCA 也会保留一些原问题的信息,当然你也可以把虚树看作“将原树压缩了”,比如我们只关心 \(\delta(u,v)\) 的一些信息,不关心上面的节点。

我们就直接把 \(\delta(u, v)\) 压成一条边 \((u \to v)\) 就行了,(用于节省空间的压缩 01trie 就用了同样的思想)。

然后只需要在新树上处理问题就可以了。

构建ψ(`∇´)ψ

一种做法是直接按照 dfn 排序,然后相邻的两个节点求 LCA,去重。

但是这个比较麻烦,不如使用单调栈的构造方法,所以这里只讲单调栈做法。

这个做法的思想是,每次只维护虚树的一条链

首先把树根 (\(1\) 号节点) 入栈,并且,我们保证单调栈从顶到底,节点的 dfn 单调递减。

然后考虑,当前的栈顶是 \(top\),加入节点是 \(nw\),分两种情况讨论:

  • \(\text{LCA}(top, nw) = top\),证明 \(nw\) 是当前链上的节点,直接加入即可。
  • 否则,考虑栈中次大节点 \(stop\)\(\text{LCA}(top, nw)\) 的关系,显然此时已经维护完了上一条链,然后我们考虑:
    • 如果 \(dfn(stop) > dfn(\text{LCA}(top, nw))\),那么说明,\(\text{LCA}(top, nw)\) 已经在栈中了,此时我们不断弹栈,直到 \(top = \text{LCA}\),弹栈的时候,记得让被弹出的节点和弹出后的栈顶连边(因为它们是父子关系)。
    • 否则证明 \(\text{LCA}(top, nw)\) 还没有入栈,先连 \(top\)\(\text{LCA}\),然后继续弹栈,最后加入 \(\text{LCA}\)\(nw\) 即可。

vt-1.gif

Code
 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
int k, a[si];
int stk[si], top = 0;
bool cmp(int x, int y) { return dfn[x] < dfn[y]; }
inline void ADD(int u, int v, int w) { E[Tot] = (Edge){v, Head[u], w}, Head[u] = Tot++; }
inline void Add(int u, int v) { int w = dist(u, v); ADD(u, v, w), ADD(v, u, w); }
void build() {
    sort(a + 1, a + 1 + k, cmp);
    stk[top = 1] = 1, Tot = 0, Head[1] = -1; // 这样清空复杂度才是对的。
    for(int i = 1, Lca; i <= k; ++i) {
        if(a[i] == 1) continue;
        Lca = lca(a[i], stk[top]);
        if(Lca != stk[top]) {
            while(dfn[Lca] < dfn[stk[top - 1]])
                Add(stk[top - 1], stk[top]), --top;
            if(dfn[Lca] > dfn[stk[top - 1]])
                Head[Lca] = -1, Add(Lca, stk[top]), stk[top] = Lca;
            else Add(Lca, stk[top--]); // Lca = stk[top - 1].
        }
        Head[a[i]] = -1, stk[++top] = a[i];
    }
    for(int i = 1; i < top; ++i)
        Add(stk[i], stk[i + 1]);
    return;
}

习题ψ(`∇´)ψ

「SDOI2011」消耗战

题目描述

在一场战争中,战场由 \(n\) 个岛屿和 \(n-1\) 个桥梁组成,保证每两个岛屿间有且仅有一条路径可达。现在,我军已经侦查到敌军的总部在编号为 \(1\) 的岛屿,而且他们已经没有足够多的能源维系战斗,我军胜利在望。已知在其他 \(k\) 个岛屿上有丰富能源,为了防止敌军获取能源,我军的任务是炸毁一些桥梁,使得敌军不能到达任何能源丰富的岛屿。由于不同桥梁的材质和结构不同,所以炸毁不同的桥梁有不同的代价,我军希望在满足目标的同时使得总代价最小。

侦查部门还发现,敌军有一台神秘机器。即使我军切断所有能源之后,他们也可以用那台机器。机器产生的效果不仅仅会修复所有我军炸毁的桥梁,而且会重新随机资源分布(但可以保证的是,资源不会分布到 \(1\) 号岛屿上)。不过侦查部门还发现了这台机器只能够使用 \(m\) 次,所以我们只需要把每次任务完成即可。

输入格式

第一行一个整数 \(n\),代表岛屿数量。

接下来 n-1 行,每行三个整数 \(u,v,w\),代表 \(u\) 号岛屿和 \(v\) 号岛屿由一条代价为 \(c\) 的桥梁直接相连,保证 \(1\le u,v\le n\)\(1\le c\le 10^5\)

\(n+1\) 行,一个整数 \(m\),代表敌方机器能使用的次数。

接下来 \(m\) 行,每行一个整数 \(k_i\),代表第 \(i\) 次后,有 \(k_i\) 个岛屿资源丰富,接下来 \(k\) 个整数 \(h_1,h_2,\cdots ,h_k\),表示资源丰富岛屿的编号。

输出格式

输出有 \(m\) 行,分别代表每次任务的最小代价。

数据范围

对于 \(100\%\) 的数据,\(2\le n\le 2.5\times 10^5,m\ge 1,\sum k_i\le 5\times 10^5,1\le k_i\le n-1\)

考虑设 \(dp(u)\) 表示,使得 \(u\) 和它的子树中任意一个关键点不相连的最小代价。

复杂度是 \(O(nq)\) 的,注意到 \(\sum k_i \le 5 \times 10^5\),于是我们直接把关键点拖出来建虚树,在虚树上 dp 即可。

Code
  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
// author : black_trees

#include <cmath>
#include <cstdio>
#include <cstring>
#include <utility>
#include <iostream>
#include <algorithm>

#define endl '\n'
#define int long long 

using namespace std;
using i64 = long long;

const int si = 5e5 + 10;

int n, m;
int tot = 0, Tot = 0;
int head[si], Head[si];
struct Edge { int ver, Next, w; } e[si << 1], E[si << 1];
inline void add(int u, int v, int w) { e[tot] = (Edge){v, head[u], w}, head[u] = tot++; }

int tim = 0, dfn[si];
int f[si][20], dep[si], dis[si][20];
void dfs1(int u, int fa) {
    dfn[u] = ++tim, dep[u] = dep[fa] + 1, f[u][0] = fa;
    for(int i = 1; i <= 19; ++i) 
        f[u][i] = f[f[u][i - 1]][i - 1],
        dis[u][i] = min(dis[u][i - 1], dis[f[u][i - 1]][i - 1]);
    for(int i = head[u]; ~i; i = e[i].Next) {
        int v = e[i].ver, w = e[i].w;
        if(v == fa) continue;
        dis[v][0] = w, dfs1(v, u);
    }
}
int lca(int u, int v) {
    if(dep[u] < dep[v]) swap(u, v);
    for(int i = 19; i >= 0; --i)
        if(dep[f[u][i]] >= dep[v]) u = f[u][i];
    if(u == v) return u;
    for(int i = 19; i >= 0; --i)
        if(f[u][i] != f[v][i]) u = f[u][i], v = f[v][i];
    return f[u][0];
}
int dist(int u, int v) {
    int ans = 1e9 + 8;
    if(dep[u] < dep[v]) swap(u, v);
    for(int i = 19; i >= 0; --i) {
        if(dep[f[u][i]] >= dep[v]) 
            ans = min(ans, dis[u][i]), u = f[u][i];
        if(u == v) return ans;
    }
    return ans;
}

bool vis[si]; // is key point or not.
int k, a[si];
int stk[si], top = 0;
bool cmp(int x, int y) { return dfn[x] < dfn[y]; }
inline void ADD(int u, int v, int w) { E[Tot] = (Edge){v, Head[u], w}, Head[u] = Tot++; }
inline void Add(int u, int v) { int w = dist(u, v); ADD(u, v, w), ADD(v, u, w); }
void build() {
    sort(a + 1, a + 1 + k, cmp);
    stk[top = 1] = 1, Tot = 0, Head[1] = -1;
    for(int i = 1, Lca; i <= k; ++i) {
        if(a[i] == 1) continue;
        Lca = lca(a[i], stk[top]);
        if(Lca != stk[top]) {
            while(dfn[Lca] < dfn[stk[top - 1]])
                Add(stk[top - 1], stk[top]), --top;
            if(dfn[Lca] > dfn[stk[top - 1]])
                Head[Lca] = -1, Add(Lca, stk[top]), stk[top] = Lca;
            else Add(Lca, stk[top--]);
        }
        Head[a[i]] = -1, stk[++top] = a[i];
    }
    for(int i = 1; i < top; ++i)
        Add(stk[i], stk[i + 1]);
    return;
}

int dp[si];
void dfs2(int u, int fa) {
    dp[u] = 0;
    for(int i = Head[u]; ~i; i = E[i].Next) {
        int v = E[i].ver, w = E[i].w;
        if(v == fa) continue;
        dfs2(v, u);
        if(vis[v]) dp[u] = dp[u] + w;
        else dp[u] = dp[u] + min(dp[v], w);
    }
}

signed main() {

    // freopen("test.txt", "r", stdin);

    cin.tie(0) -> sync_with_stdio(false);
    cin.exceptions(cin.failbit | cin.badbit);

    memset(head, -1, sizeof head);  

    cin >> n;
    for(int i = 1; i < n; ++i) {
        int u, v, w;
        cin >> u >> v >> w;
        add(u, v, w), add(v, u, w);
    }
    dfs1(1, 0);

    cin >> m;
    for(int nw = 1; nw <= m; ++nw) {
        cin >> k;
        for(int i = 1; i <= k; ++i)
            cin >> a[i], vis[a[i]] = true;
        build(), dfs2(1, 0);
        cout << dp[1] << endl;
        for(int i = 1; i <= k; ++i)
            vis[a[i]] = false;
    }
    return 0;
}

最后更新: May 9, 2023