如何有逻辑的实现思路

我最近发现,很多时候我知道了思路,想清楚了做法,却没有动手实现。

很大一部分原因是我并不知道该怎么下手,从哪里开始写,怎么写,写什么。

昨天觉得这样不好,时间长了之后代码能力就完全退化了,所以想了想,应该怎么做。


以这道题为例:

https://loj.ac/p/3646

Solution

这道题就是说从高位开始考虑,看一下这一位都是 \(1\) 的够不够 \(m\) 个,然后够了的话就把不是 \(1\) 的扔掉免得之后不好判定。

如果这一位不够了那只能是 \(0\) 了,然后如果刚好不够 \(m\) 个了就说明这一位也只能是 \(0\),最后就能确定答案。

然后具体怎么维护这些数呢,显然我们暴力取出来复杂度就寄了。但是你考虑这个贪心过程,我们每次相当于是对每一个位置保留至少 \(m\) 个。

可以注意到我们保留最大的,显然不会比使用更小的获取的答案更劣,所以我们不妨,对每一位都只保留 \(m\) 个,数量级在 \(O(m \log V)\) 左右。

这个复杂度就是可以接受的了(实际卡不满),于是我们现在做的事情就是把这些数取出来,这很简单,用我们 Count on the tree 的套路,每个节点维护从根到它的值域线段树,然后利用可持久化维护,做一个树上差分,同步维护四个指针就行了。

现在我们知道了思路,应该怎么实现呢?

首先明确一下,思路的每个细节大致能知道怎么实现,然后如果有细节问题先想清楚,不然的话会浪费大把时间。

比如说贪心部分,大概知道是从高位开始考虑,逐位确定答案。

所以实现上应该是先枚举位数,然后看一看取出来的这些数当中,那些能让当前位为 \(1\),把这些数标记(注意,我们实现的时候不一定是标记,如果按照自己写出来的思路直接实现很多时候容易过于麻烦),check 一下个数够不够 \(m\),更新答案,删掉数就行了。

然后步骤就是,根据题面给出的输入格式,在 main() 当中一行一行写下去。

比如这题是先给出一棵树。

于是我们知道应该存图,如果使用前向星的话需要 memset

然后看一眼数据范围,\(10^6\),要开双向边,于是我们写下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
const int si = 1e6 + 10;

int n;
int tot = 0, head[si];
struct Edge { int ver, Next; } e[si << 1];
inline void add(int u, int v) { e[tot] = (Edge){v, head[u]}, head[u] = tot++; }

int main() {
    memset(head, -1, sizeof head);

    cin >> n;
    for(int i = 1; i < n; ++i) {
        int u, v;
        cin >> u >> v;
        add(u, v), add(v, u);
    }
}

然后注意到还要有权值,继续写:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
const int si = 1e6 + 10;

int n;
int tot = 0, head[si];
struct Edge { int ver, Next; } e[si << 1];
inline void add(int u, int v) { e[tot] = (Edge){v, head[u]}, head[u] = tot++; }

int w[si];

int main() {
    memset(head, -1, sizeof head);

    cin >> n;
    for(int i = 1; i < n; ++i) {
        int u, v;
        cin >> u >> v;
        add(u, v), add(v, u);
    }
    for(int i = 1; i <= n; ++i) cin >> w[i];
}

接下来读到有 \(Q\) 组询问,需要强制在线解密。

于是写下:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
const int si = 1e6 + 10;

int n, m, q;
int tot = 0, head[si];
struct Edge { int ver, Next; } e[si << 1];
inline void add(int u, int v) { e[tot] = (Edge){v, head[u]}, head[u] = tot++; }

int w[si];
int lastans = 0;

int main() {
    memset(head, -1, sizeof head);

    cin >> n;
    for(int i = 1; i < n; ++i) {
        int u, v;
        cin >> u >> v;
        add(u, v), add(v, u);
    }
    for(int i = 1; i <= n; ++i) cin >> w[i];

    cin >> q;
    for(int nw = 1; nw <= q; ++nw) {
        int x, y;
        cin >> x >> y;
        x = ((x ^ lastans) % n) + 1;
        y = ((y ^ lastans) % n) + 1;
    }
}

之后考虑询问的部分,我们应当依次做什么。

首先我们要做的就是,提取出路径上的一些数,然后贪心。

提取的话需要使用可持久化线段树。要维护每个节点到根的链。

然后通过树上差分得到对应的线段树,在它上面线段树二分,之后把数提取出来。

那么我们首先需要写一个离散化,为了节省线段树的空间(这题空间卡的比较死)

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
#include <vector>

const int si = 1e6 + 10;

int n, m, q;
int tot = 0, head[si];
struct Edge { int ver, Next; } e[si << 1];
inline void add(int u, int v) { e[tot] = (Edge){v, head[u]}, head[u] = tot++; }

int w[si];
int lastans = 0;

int val[si];
std::vector<int> v;
int getVal(int value) {
    return lower_bound(v.begin(), v.end(), value) - v.begin() + 1;
}

int main() {
    memset(head, -1, sizeof head);

    cin >> n;
    for(int i = 1; i < n; ++i) {
        int u, v;
        cin >> u >> v;
        add(u, v), add(v, u);
    }
    for(int i = 1; i <= n; ++i) cin >> w[i], v.push_back(w[i]);
    sort(v.begin(), v.end()), v.erase(unique(v.begin(), v.end()), v.end());
    for(int i = 1; i <= n; ++i) val[i] = getVal(w[i]);

    cin >> q;
    for(int nw = 1; nw <= q; ++nw) {
        int x, y;
        cin >> x >> y;
        x = ((x ^ lastans) % n) + 1;
        y = ((y ^ lastans) % n) + 1;
    }
}

然后我们需要建出这个线段树,首先需要 dfs,而且树上差分要求 LCA,于是写一个倍增 LCA:

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include <vector>

const int si = 1e6 + 10;

int n, m, q;
int tot = 0, head[si];
struct Edge { int ver, Next; } e[si << 1];
inline void add(int u, int v) { e[tot] = (Edge){v, head[u]}, head[u] = tot++; }

int w[si];
int lastans = 0;

int val[si];
std::vector<int> v;
int getVal(int value) {
    return lower_bound(v.begin(), v.end(), value) - v.begin() + 1;
}

int dep[si], f[si][24];
void dfs(int u, int fa) {
    dep[u] = dep[fa] + 1, f[u][0] = fa;
    for(int i = 1; i <= 22; ++i) {
        f[u][i] = f[f[u][i - 1]][i - 1];
    }
    for(int i = head[u]; ~i; i = e[i].Next) {
        int v = e[i].ver;
        if(v == fa) continue;
        dfs(v, u);
    }
}
int getLca(int u, int v) {
    if(dep[u] < dep[v]) swap(u, v);
    for(int i = 22; i >= 0; --i) {
        if(dep[f[u][i]] >= dep[v]) u = f[u][i];
    }
    if(u == v) return u;
    for(int i = 22; i >= 0; --i) {
        if(f[u][i] != f[v][i]) u = f[u][i], v = f[v][i];
    }
    return f[u][0];
}

int main() {
    memset(head, -1, sizeof head);

    cin >> n;
    for(int i = 1; i < n; ++i) {
        int u, v;
        cin >> u >> v;
        add(u, v), add(v, u);
    }
    for(int i = 1; i <= n; ++i) cin >> w[i], v.push_back(w[i]);
    sort(v.begin(), v.end()), v.erase(unique(v.begin(), v.end()), v.end());
    for(int i = 1; i <= n; ++i) val[i] = getVal(w[i]);

    cin >> q;
    for(int nw = 1; nw <= q; ++nw) {
        int x, y;
        cin >> x >> y;
        x = ((x ^ lastans) % n) + 1;
        y = ((y ^ lastans) % n) + 1;
    }
}

然后就写个可持久化线段树就行了,记得在 dfs 里面加一行来建树。

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
#include <vector>

const int si = 1e6 + 10;

int n, m, q;
int tot = 0, head[si];
struct Edge { int ver, Next; } e[si << 1];
inline void add(int u, int v) { e[tot] = (Edge){v, head[u]}, head[u] = tot++; }

int w[si];
int lastans = 0;

int val[si];
std::vector<int> v;
int getVal(int value) {
    return lower_bound(v.begin(), v.end(), value) - v.begin() + 1;
}

int V = 0;
int cnt = 0;
int root[si];
int dat[si << 2];
int ls[si << 2], rs[si << 2];
int build(int l, int r) {
    int p = ++cnt;
    if(l == r) return p;
    int mid = (l + r) >> 1;
    ls[p] = build(l, mid), rs[p] = build(mid + 1, r);
    return p;
}
int insert(int lst, int l, int r, int x) {
    int p = ++cnt;
    ls[p] = ls[lst], rs[p] = rs[lst], dat[p] = dat[lst];
    if(l == r) return p;
    int mid = (l + r) >> 1;
    if(x <= mid) ls[p] = insert(ls[lst], l, mid, x);
    else rs[p] = insert(rs[lst], mid + 1, r, x);
    return p;
}
int limit = 0;
int chain[si], cur = 0;
int temp[si], tcur = 0;
void getChain(int p, int q, int u, int v, int l, int r) {
    if(!limit || !(dat[p] + dat[q] - dat[u] - dat[v])) return;
    if(l == r) {
        int tmp = min(dat[p] + dat[q] - dat[u] - dat[v], limit);
        while(tmp--) chain[++cur] = w[p], --limit;
        return;
    }
    int mid = (l + r) >> 1;
    getChain(ls[p], ls[q], ls[u], ls[v], l, mid);
    getChain(rs[p], rs[q], rs[u], rs[v], mid + 1, r);
}

int dep[si], f[si][24];
void dfs(int u, int fa) {
    dep[u] = dep[fa] + 1, f[u][0] = fa;
    insert(root[fa], 1, V, val[u]);
    for(int i = 1; i <= 22; ++i) {
        f[u][i] = f[f[u][i - 1]][i - 1];
    }
    for(int i = head[u]; ~i; i = e[i].Next) {
        int v = e[i].ver;
        if(v == fa) continue;
        dfs(v, u);
    }
}
int getLca(int u, int v) {
    if(dep[u] < dep[v]) swap(u, v);
    for(int i = 22; i >= 0; --i) {
        if(dep[f[u][i]] >= dep[v]) u = f[u][i];
    }
    if(u == v) return u;
    for(int i = 22; i >= 0; --i) {
        if(f[u][i] != f[v][i]) u = f[u][i], v = f[v][i];
    }
    return f[u][0];
}

int main() {
    memset(head, -1, sizeof head);

    cin >> n;
    for(int i = 1; i < n; ++i) {
        int u, v;
        cin >> u >> v;
        add(u, v), add(v, u);
    }
    for(int i = 1; i <= n; ++i) cin >> w[i], v.push_back(w[i]);
    sort(v.begin(), v.end()), v.erase(unique(v.begin(), v.end()), v.end());
    for(int i = 1; i <= n; ++i) val[i] = getVal(w[i]);

    V = (int)v.size();
    root[0] = build(1, V);

    dfs(1, 0);

    cin >> q;
    for(int nw = 1; nw <= q; ++nw) {
        int x, y;
        cin >> x >> y;
        x = ((x ^ lastans) % n) + 1;
        y = ((y ^ lastans) % n) + 1;
    }
}

剩下的事情就是写一个贪心了,先提取出要用的数。

然后按位考虑,每一位选一些数出来看看够不够,如果够就删掉不满足条件的。

最后输出答案就行了。

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#include <vector>

const int si = 1e6 + 10;

int n, m, q;
int tot = 0, head[si];
struct Edge { int ver, Next; } e[si << 1];
inline void add(int u, int v) { e[tot] = (Edge){v, head[u]}, head[u] = tot++; }

int w[si];
int lastans = 0;

int val[si];
std::vector<int> v;
int getVal(int value) {
    return lower_bound(v.begin(), v.end(), value) - v.begin() + 1;
}

int V = 0;
int cnt = 0;
int root[si];
int dat[si << 2];
int ls[si << 2], rs[si << 2];
int build(int l, int r) {
    int p = ++cnt;
    if(l == r) return p;
    int mid = (l + r) >> 1;
    ls[p] = build(l, mid), rs[p] = build(mid + 1, r);
    return p;
}
int insert(int lst, int l, int r, int x) {
    int p = ++cnt;
    ls[p] = ls[lst], rs[p] = rs[lst], dat[p] = dat[lst];
    if(l == r) return p;
    int mid = (l + r) >> 1;
    if(x <= mid) ls[p] = insert(ls[lst], l, mid, x);
    else rs[p] = insert(rs[lst], mid + 1, r, x);
    return p;
}
int limit = 0;
int chain[si], cur = 0;
int temp[si], tcur = 0;
void getChain(int p, int q, int u, int v, int l, int r) {
    if(!limit || !(dat[p] + dat[q] - dat[u] - dat[v])) return;
    if(l == r) {
        int tmp = min(dat[p] + dat[q] - dat[u] - dat[v], limit);
        while(tmp--) chain[++cur] = w[p], --limit;
        return;
    }
    int mid = (l + r) >> 1;
    getChain(ls[p], ls[q], ls[u], ls[v], l, mid);
    getChain(rs[p], rs[q], rs[u], rs[v], mid + 1, r);
}

int dep[si], f[si][24];
void dfs(int u, int fa) {
    dep[u] = dep[fa] + 1, f[u][0] = fa;
    insert(root[fa], 1, V, val[u]);
    for(int i = 1; i <= 22; ++i) {
        f[u][i] = f[f[u][i - 1]][i - 1];
    }
    for(int i = head[u]; ~i; i = e[i].Next) {
        int v = e[i].ver;
        if(v == fa) continue;
        dfs(v, u);
    }
}
int getLca(int u, int v) {
    if(dep[u] < dep[v]) swap(u, v);
    for(int i = 22; i >= 0; --i) {
        if(dep[f[u][i]] >= dep[v]) u = f[u][i];
    }
    if(u == v) return u;
    for(int i = 22; i >= 0; --i) {
        if(f[u][i] != f[v][i]) u = f[u][i], v = f[v][i];
    }
    return f[u][0];
}

int main() {
    memset(head, -1, sizeof head);

    cin >> n;
    for(int i = 1; i < n; ++i) {
        int u, v;
        cin >> u >> v;
        add(u, v), add(v, u);
    }
    for(int i = 1; i <= n; ++i) cin >> w[i], v.push_back(w[i]);
    sort(v.begin(), v.end()), v.erase(unique(v.begin(), v.end()), v.end());
    for(int i = 1; i <= n; ++i) val[i] = getVal(w[i]);

    V = (int)v.size();
    root[0] = build(1, V);

    dfs(1, 0);

    cin >> q;
    for(int nw = 1; nw <= q; ++nw) {
        int x, y;
        cin >> x >> y;
        x = ((x ^ lastans) % n) + 1;
        y = ((y ^ lastans) % n) + 1;
        int Lca = getLca(x, y), Fat = f[Lca][0];    
        cur = 0, limit = 198;
        getChain(root[x], root[y], root[Lca], root[Fat], 1, V);
        int ans = 0;
        for(int i = 60; i >= 0; --i) {
            ans |= (1 << i), tcur = 0;
            for(int j = 1; j <= cur; ++j) {
                if((chain[j] & ans) == ans)
                    temp[++tcur] = chain[j];
            }
            if(tcur < m) ans ^= (1 << i);
            else {
                for(int j = 1; j <= tcur; ++j) {
                    chain[j] = temp[j];
                }
                cur = tcur;
            }
        }
        cout << (lastans = ans) << endl;
    }
}

剩下的就是调试了(这个代码不保证能过)


最后更新: August 19, 2023