「学习笔记」树链剖分

树链剖分用于将树分割成若干条链的形式,以维护树上路径的信息。
具体来说,将整棵树剖分为若干条链,使它组合成线性结构,然后用其他的数据结构维护信息。

树链剖分有很多种形式,本文要讲的是其中的轻重链剖分。

树链剖分本质上就是把链从树上砍下来,然后放到树状数组或线段树上来维护。

轻重链剖分

我们给出一些定义:

定义 重子节点 表示其子节点中 子树最大 的子结点。如果有多个子树最大的子结点,取其一。如果没有子节点,就无重子节点。

定义 轻子节点 表示剩余的所有子结点。

从这个结点到重子节点的边为 重边

到其他轻子节点的边为 轻边

若干条首尾衔接的重边构成 重链

把落单的结点也当作重链,那么整棵树就被剖分成若干条重链。

如图:

「学习笔记」树链剖分

图片来自 (texttt{OI-Wiki})

实现

要进行树链剖分,得先有一棵树,通过一个 dfs 来得到所有的子树大小、重儿子以及每个结点的深度。

void dfs(int u, int fat) {     siz[u] = 1;     fa[u] = fat;     dep[u] = dep[fat] + 1;     for (int v : e[u]) {         if (v == fat)   continue ;         dfs(v, u);         siz[u] += siz[v];         if (siz[v] > siz[son[u]]) {             son[u] = v;         }     } } 

得到这些信息后,我们就要把树分成好多的链,通过另一个 dfs 来确定好链顶,有时我们要用线段树或树状数组来维护这些链,所以可能还要维护在线段树上的位置(dfs 序)等信息。

void getpos(int u, int top) {     dfn[u] = ++ tim; // 在线段树上的位置 (dfs 序)     pos[tim] = u; // 线段树这个位置所代表的节点     tp[u] = top; // 链顶     if (!son[u])    return ;     getpos(son[u], top); // 优先跑重儿子     for (int v : e[u]) {         if (v == fa[u] || v == son[u])  continue ;         getpos(v, v); // 轻儿子再单独分链     } } 

到这里,处理部分就完成了,接下来就是根据题目来对这些链进行处理了,可以使用线段树、树状数组等数据结构来维护链的信息,也可以利用跳链顶的优秀复杂度求 LCA 等。

例题

P3384 【模板】重链剖分/树链剖分 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)

模板题。

要求我们对链上节点的权值进行修改、求和,以及对子树内的所有节点进行修改、求和,我们使用线段树来进行维护。

对于链上的操作,由于我们已经把链都摆在了线段树上,所以只需要对线段树进行区间操作即可。对于子树的操作,根据 dfs 序的性质,一棵子树在 dfs 序上的范围是 ([dfn_{rt}, dfn_{rt} + siz_{rt} - 1]),对这个区间进行区间操作即可。

void Modify(int x, int y, ll z) {     while (tp[x] != tp[y]) {         if (dep[tp[x]] < dep[tp[y]]) {             swap(x, y);         }         modify(1, 1, n, dfn[tp[x]], dfn[x], z);         x = fa[tp[x]];     }     if (dep[x] > dep[y]) {         swap(x, y);     }     modify(1, 1, n, dfn[x], dfn[y], z);     return ; }  ll Query(int x, int y) {     ll ans = 0;     while (tp[x] != tp[y]) {         if (dep[tp[x]] < dep[tp[y]]) {             swap(x, y);         }         ans = (ans + query(1, 1, n, dfn[tp[x]], dfn[x])) % mod;         x = fa[tp[x]];     }     if (dep[x] > dep[y]) {         swap(x, y);     }     ans = (ans + query(1, 1, n, dfn[x], dfn[y])) % mod;     return ans; } 

这两段代码就是跳链的过程,可以这样理解一下,如果两个节点的链顶不相同,说明他们不在同一条链中,我们让链顶深度大的节点向上跳(这样可以防止跳过头),在跳之前,先对这段链的信息进行修改维护,也就是 Modify 中的 modify 函数和 Query 中的 query 函数,然后,跳到这个链顶的父亲,离开这条链,以此继续,直到这两个节点链顶一样时(即在同一条链上时),对这两个节点之间的链进行操作,退出函数。

// The code was written by yifan, and yifan is neutral!!!  #include <bits/stdc++.h> using namespace std; typedef long long ll; #define bug puts("NOIP rp ++!"); #define rep(i, a, b, c) for (int i = (a); i <= (b); i += (c)) #define per(i, a, b, c) for (int i = (a); i >= (b); i -= (c)) #define Mod(x) ((x) >= mod ? (x) %= mod : (x))  template<typename T> inline T read() {     T x = 0;     bool fg = 0;     char ch = getchar();     while (ch < '0' || ch > '9') {         fg |= (ch == '-');         ch = getchar();     }     while (ch >= '0' && ch <= '9') {         x = (x << 3) + (x << 1) + (ch ^ 48);         ch = getchar();     }     return fg ? ~x + 1 : x; }  const int N = 1e5 + 5;  int n, m, rt, mod, tim; int val[N], siz[N], son[N], fa[N]; int dfn[N], pos[N], tp[N], dep[N]; vector<int> e[N];  struct seg {     int len;     ll val, tag; } t[N << 2];  void dfs(int u, int fat) {     siz[u] = 1;     fa[u] = fat;     dep[u] = dep[fat] + 1;     for (int v : e[u]) {         if (v == fat)   continue ;         dfs(v, u);         siz[u] += siz[v];         if (siz[v] > siz[son[u]]) {             son[u] = v;         }     } }  void getpos(int u, int top) {     dfn[u] = ++ tim;     pos[tim] = u;     tp[u] = top;     if (!son[u])    return ;     getpos(son[u], top);     for (int v : e[u]) {         if (v == fa[u] || v == son[u])  continue ;         getpos(v, v);     } }  #define ls (u << 1) #define rs (u << 1 | 1) #define mid ((l + r) >> 1)  void pushup(int u) {     t[u].val = (t[ls].val + t[rs].val) % mod; }  void pushdown(int u, int l, int r) {     if (!t[u].tag)  return ;     if (l == r) {         t[u].tag = 0;         return ;     }     t[ls].tag = (t[ls].tag + t[u].tag) % mod;     t[ls].val = (t[ls].val + t[u].tag * t[ls].len % mod) % mod;     t[rs].tag = (t[rs].tag + t[u].tag) % mod;     t[rs].val = (t[rs].val + t[u].tag * t[rs].len % mod) % mod;     t[u].tag = 0;     return ; }  void build(int u, int l, int r) {     t[u].tag = 0;     t[u].len = r - l + 1;     if (l == r) {         t[u].val = val[pos[l]];         return ;     }     build(ls, l, mid);     build(rs, mid + 1, r);     pushup(u); }  void modify(int u, int l, int r, int lr, int rr, ll v) {     if (lr <= l && r <= rr) {         t[u].tag = (t[u].tag + v) % mod;         t[u].val = (t[u].val + (t[u].len * v) % mod) % mod;         return ;     }     pushdown(u, l, r);     if (lr <= mid) {         modify(ls, l, mid, lr, rr, v);     }     if (rr > mid) {         modify(rs, mid + 1, r, lr, rr, v);     }     pushup(u); }  void Modify(int x, int y, ll z) {     while (tp[x] != tp[y]) {         if (dep[tp[x]] < dep[tp[y]]) {             swap(x, y);         }         modify(1, 1, n, dfn[tp[x]], dfn[x], z);         x = fa[tp[x]];     }     if (dep[x] > dep[y]) {         swap(x, y);     }     modify(1, 1, n, dfn[x], dfn[y], z);     return ; }  ll query(int u, int l, int r, int lr, int rr) {     if (lr <= l && r <= rr) {         return t[u].val;     }     pushdown(u, l, r);     ll ans = 0;     if (lr <= mid) {         ans = (ans + query(ls, l, mid, lr, rr)) % mod;     }     if (rr > mid) {         ans = (ans + query(rs, mid + 1, r, lr, rr)) % mod;     }     return ans; }  ll Query(int x, int y) {     ll ans = 0;     while (tp[x] != tp[y]) {         if (dep[tp[x]] < dep[tp[y]]) {             swap(x, y);         }         ans = (ans + query(1, 1, n, dfn[tp[x]], dfn[x])) % mod;         x = fa[tp[x]];     }     if (dep[x] > dep[y]) {         swap(x, y);     }     ans = (ans + query(1, 1, n, dfn[x], dfn[y])) % mod;     return ans; }  #undef ls #undef rs #undef mid  int main() {     n = read<int>(), m = read<int>();     rt = read<int>(), mod = read<int>();     rep (i, 1, n, 1) {         val[i] = read<int>();     }     int x, y;     rep (i, 1, n - 1, 1) {         x = read<int>(), y = read<int>();         e[x].emplace_back(y);         e[y].emplace_back(x);     }     dfs(rt, 0);     getpos(rt, rt);     build(1, 1, n);     int op, z;     rep (i, 1, m, 1) {         op = read<int>(), x = read<int>();         if (op == 1) {             y = read<int>(), z = read<ll>();             Modify(x, y, z);         }         if (op == 2) {             y = read<int>();             cout << Query(x, y) % mod << 'n';         }         if (op == 3) {             z = read<ll>();             modify(1, 1, n, dfn[x], dfn[x] + siz[x] - 1, z);         }         if (op == 4) {             cout << query(1, 1, n, dfn[x], dfn[x] + siz[x] - 1) % mod << 'n';         }     }     return 0; } 

P4211 [LNOI2014] LCA - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)

询问 LCA 的深度,其实就是在树上差分中,一个节点权值 (+ 1),另一个点求该节点到根节点的路径和。在树链剖分中,就是将根节点到一个节点这条链上所有的点 (+ 1),另一个节点求该节点到根节点的路径的权值和。

可以经询问进行离线处理,离线来完成这道题,具体看代码。

//The code was written by yifan, and yifan is neutral!!!  #include <bits/stdc++.h> using namespace std; typedef long long ll; #define bug puts("NOIP rp ++!"); #define lowbit(x) (x & (-x)) #define ls (u << 1) #define rs (u << 1 | 1) #define mid ((l + r) >> 1)  template<typename T> inline T read() {     T x = 0;     bool fg = 0;     char ch = getchar();     while (ch < '0' || ch > '9') {         fg |= (ch == '-');         ch = getchar();     }     while (ch >= '0' && ch <= '9') {         x = (x << 3) + (x << 1) + (ch ^ 48);         ch = getchar();     }     return fg ? ~x + 1 : x; }  const int N = 5e4 + 5; const int mod = 201314;  int n, m; int fa[N], dep[N], siz[N], son[N], tp[N], pos[N]; int dfn[N]; vector<int> e[N];  struct ask {     int be, R, z;     bool fg;      int operator < (const ask& b) const {         return R < b.R;     } } xunwen[N << 1];  struct seg {     int val, tag;      seg operator + (const seg& b) {         seg& a = *this, res;         res.val = (a.val + b.val) % mod;         return res;     } } t[N << 2];  struct ANS {     ll ans1, ans2; } ans[N << 1];  void dfs(int u) {     siz[u] = 1;     dep[u] = dep[fa[u]] + 1;     for (int v : e[u]) {         if (v == fa[u]) continue ;         dfs(v);         siz[u] += siz[v];         if (siz[v] > siz[son[u]]) {             son[u] = v;         }     } }  void gettop(int u, int Top) {     static int t = 0;     tp[u] = Top;     dfn[u] = ++ t;     pos[t] = u;     if (!son[u])    return ;     gettop(son[u], Top);     for (int v : e[u]) {         if (v == fa[u] || v == son[u]) continue ;         gettop(v, v);     } }  void color(int u, int l, int r, int co) {     t[u].val = (t[u].val + (r - l + 1) * co) % mod;     if (l < r) {         t[u].tag = (t[u].tag + co) % mod;     } }  void pushdown(int u, int l, int r) {     if (t[u].tag && l < r) {         color(ls, l, mid, t[u].tag);         color(rs, mid + 1, r, t[u].tag);     }     t[u].tag = 0; }  void modify(int u, int l, int r, int lr, int rr) {     if (lr <= l && r <= rr) {         color(u, l, r, 1);         return ;     }     pushdown(u, l, r);     if (lr <= mid) {         modify(ls, l, mid, lr, rr);     }     if (rr > mid) {         modify(rs, mid + 1, r, lr, rr);     }     t[u] = t[ls] + t[rs]; }  void Modify(int x, int y) {     while (tp[x] != tp[y]) {         if (dep[tp[x]] < dep[tp[y]]) {             swap(x, y);         }         modify(1, 1, n, dfn[tp[x]], dfn[x]);         x = fa[tp[x]];     }     if (dep[x] > dep[y]) {         swap(x, y);     }     modify(1, 1, n, dfn[x], dfn[y]); }  ll query(int u, int l, int r, int lr, int rr) {     if (lr <= l && r <= rr) {         return t[u].val;     }     pushdown(u, l, r);     ll ans = 0;     if (lr <= mid) {         ans += query(ls, l, mid, lr, rr);     }     if (rr > mid) {         ans += query(rs, mid + 1, r, lr, rr);     }     return ans % mod; }  ll Query(int x, int y) {     ll ans = 0;     while (tp[x] != tp[y]) {         if (dep[tp[x]] < dep[tp[y]]) {             swap(x, y);         }         ans += query(1, 1, n, dfn[tp[x]], dfn[x]);         x = fa[tp[x]];     }     if (dep[x] > dep[y]) {         swap(x, y);     }     ans += query(1, 1, n, dfn[x], dfn[y]);     return ans % mod; }  int main() {     n = read<int>(), m = read<int>();     for (int i = 2; i <= n; ++ i) {         fa[i] = read<int>() + 1;         e[fa[i]].emplace_back(i);         e[i].emplace_back(fa[i]);     }     for (int i = 1; i <= m; ++ i) {         int l = read<int>(), r = read<int>() + 1, z = read<int>() + 1;         xunwen[i] = ask{i, l, z, 0};         xunwen[i + m] = ask{i, r, z, 1};     }     dep[1] = 1;     dfs(1);     gettop(1, 1);     m <<= 1;     sort(xunwen + 1, xunwen + m + 1, [](ask& a, ask& b) {         return a.R < b.R;     });     int now = 0;     for (int i = 1; i <= m; ++ i) {         while (now < xunwen[i].R) {             Modify(1, ++ now);         }         int j = xunwen[i].be;         if (xunwen[i].fg) {             ans[j].ans1 = Query(1, xunwen[i].z);         } else {             ans[j].ans2 = Query(1, xunwen[i].z);         }     }     m >>= 1;     for (int i = 1; i <= m; ++ i) {         printf("%lldn", (ans[i].ans1 - ans[i].ans2 + mod) % mod);     }     return 0; } 

P4216 [SCOI2015] 情报传递 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)

这个题就是树链剖分与树状数组的搭配,由于风险值会随着时间变化,风险值有一个限度,我们可以利用当前时间减去风险值来得到一个时间节点,在这个时间节点之前开始搜集情报的人就会产生威胁。

由于只是对一条链产生威胁,所以可以使用差分,用树状数组来维护。

//The code was written by yifan, and yifan is neutral!!!  #include <bits/stdc++.h> using namespace std; typedef long long ll; #define bug puts("NOIP rp ++!"); #define lowbit(x) (x & (-x))  template<typename T> inline T read() {     T x = 0;     bool fg = 0;     char ch = getchar();     while (ch < '0' || ch > '9') {         fg |= (ch == '-');         ch = getchar();     }     while (ch >= '0' && ch <= '9') {         x = (x << 3) + (x << 1) + (ch ^ 48);         ch = getchar();     }     return fg ? ~x + 1 : x; }  const int N = 2e5 + 5;  int n, q, rt, tot; int fa[N], dep[N], siz[N], son[N], tp[N]; int opt[N], X[N], Y[N], L[N], R[N]; ll t[N], ans[N]; vector<int> e[N], tim[N];  void dfs(int u) {     L[u] = ++ tot;     dep[u] = dep[fa[u]] + 1;     siz[u] = 1;     for (int &v : e[u]) {         dfs(v);         siz[u] += siz[v];         if (siz[v] > siz[son[u]]) {             son[u] = v;         }     }     R[u] = tot; }  void gettop(int u, int Top) {     tp[u] = Top;     if (!son[u]) return ;     gettop(son[u], Top);     for (int v : e[u]) {         if (v == son[u])    continue ;         gettop(v, v);     } }  int Lca(int x, int y) {     while (tp[x] ^ tp[y]) {         if (dep[tp[x]] < dep[tp[y]]) {             swap(x, y);         }         x = fa[tp[x]];     }     if (dep[x] > dep[y]) {         swap(x, y);     }     return x; }  void modify(int x, int v) {     while (x <= n) {         t[x] += v;         x += lowbit(x);     } }  ll query(int x) {     ll ans = 0;     while (x) {         ans += t[x];         x -= lowbit(x);     }     return ans; }  ll Query(int x, int y) {     int lca = Lca(x, y);     return query(L[x]) + query(L[y]) - query(L[lca]) - query(L[fa[lca]]); }  int dis(int x, int y) {     int lca = Lca(x, y);     return dep[x] + dep[y] - 2 * dep[lca] + 1; }  int main() {     n = read<int>();     for (int i = 1; i <= n; ++ i) {         fa[i] = read<int>();         e[fa[i]].emplace_back(i);     }     for (rt = 1; fa[rt]; rt = fa[rt]);     dfs(rt);     gettop(rt, rt);     q = read<int>();     for (int i = 1; i <= q; ++ i) {         opt[i] = read<int>();         if (opt[i] == 1) {             X[i] = read<int>(), Y[i] = read<int>();             int c = read<int>();             if (c < i) {                 tim[i - c - 1].emplace_back(i);             }         } else {             X[i] = read<int>();         }     }     for (int i = 1; i <= q; ++ i) {         if (opt[i] == 2) {             modify(L[X[i]], 1);             modify(R[X[i]] + 1, -1);         }          for (int &j : tim[i]) {             ans[j] = Query(X[j], Y[j]);         }         if (opt[i] == 1) {             printf("%d %lldn", dis(X[i], Y[i]), ans[i]);         }     }     return 0; } 

发表评论

相关文章