「笔记」可持久化权值线段树三题
「笔记」可持久化权值线段树三题
有人能给我续一秒吗|主播晚上在吃合味道晕过去了
可持久化权值线段树是什么呢,小编也很想知道。于是小编找到了三道神秘的题目来假装自己学会了这个东西。
首先我们需要知道什么是权值线段树,顾名思义就是维护一个权值区间信息的线段树。然而从某种意义上来说我们更多的是需要取维护一个区间权值信息的,于是聪明的先人们就发明了可持久化权值线段树。具体的原理呢就是维护若干个版本的线段树,然后查询某个区间的时候用两个版本的线段树减一下就好了。显然直接维护若干个版本的线段树空间会爆炸,于是可以通过一些方式来优化一下空间。
于是我们就得到了下面的模板:
struct HJT {
int cntNodes, root[N];
struct Node {
int l, r;
int cnt;
} tr[32 * N];
void modify(int &u, int v, int l, int r, int x) {
u = ++cntNodes;
tr[u] = tr[v];
tr[u].cnt++;
if (l == r) return;
int mid = (l + r) / 2;
if (x <= mid) {
modify(tr[u].l, tr[v].l, l, mid, x);
} else {
modify(tr[u].r, tr[v].r, mid + 1, r, x);
}
}
int query(int u, int v, int l, int r, int ql, int qr) {
if (ql <= l and r <= qr)
return tr[v].cnt - tr[u].cnt;
int mid = (l + r) / 2;
int res = 0;
if (ql <= mid)
res += query(tr[u].l, tr[v].l, l, mid, ql, qr);
if (qr > mid)
res += query(tr[u].r, tr[v].r, mid + 1, r, ql, qr);
return res;
}
int kth(int u, int v, int l, int r, int k) {
if (l == r) return l;
int res = tr[tr[v].l].cnt - tr[tr[u].l].cnt;
int mid = (l + r) / 2;
if (k <= res)
return kth(tr[u].l, tr[v].l, l, mid, k);
return kth(tr[u].r, tr[v].r, mid + 1, r, k - res);
}
} hjt;因为这个不太像我的风格所以就没有塞进 MT Folder 里面了。
虽然好像 01 Trie 那篇也不是很像我的风格啊。
显然我们上面的模板代码就是可以用来直接解决这道例题的。
考虑区间第 小值可以通过在频域上二分得到(我在说什么鬼话),于是我们可以建立一个可持久化权值线段树维护区间 内的每一个值的出现次数(当然这里要离散化一下)。于是我们就做完了这道题。
代码
#include <bits/stdc++.h>
using namespace std;
constexpr int N = 2e5 + 10;
struct HJT {
int cntNodes, root[N];
struct Node {
int l, r;
int cnt;
} tr[4 * N + 17 * N];
void modify(int &u, int v, int l, int r, int x) {
u = ++cntNodes;
tr[u] = tr[v];
tr[u].cnt++;
if (l == r) return;
int mid = (l + r) / 2;
if (x <= mid) {
modify(tr[u].l, tr[v].l, l, mid, x);
} else {
modify(tr[u].r, tr[v].r, mid + 1, r, x);
}
}
int kth(int u, int v, int l, int r, int k) {
if (l == r) return l;
int res = tr[tr[v].l].cnt - tr[tr[u].l].cnt;
int mid = (l + r) / 2;
if (k <= res)
return kth(tr[u].l, tr[v].l, l, mid, k);
return kth(tr[u].r, tr[v].r, mid + 1, r, k - res);
}
};
int n, m;
int a[N];
vector<int> b;
HJT hjt;
int getIndex(int x) {
return lower_bound(b.begin(), b.end(), x) - b.begin() + 1;
}
signed main() {
cin >> n >> m;
for (int i = 1; i <= n; i++) {
cin >> a[i];
b.emplace_back(a[i]);
}
sort(b.begin(), b.end());
int len = b.size();
for (int i = 1; i <= n; i++) {
int idx = getIndex(a[i]);
hjt.modify(hjt.root[i], hjt.root[i - 1], 1, len, idx);
}
while (m--) {
int l, r, k;
cin >> l >> r >> k;
cout << b[hjt.kth(hjt.root[l - 1], hjt.root[r], 1, len, k) - 1] << "\n";
}
}
auto FAST_IO = cin.tie(nullptr)->sync_with_stdio(false);这道题一个众所周知的解法就是离线后排序做,比如这篇文章(显然我是随便找的一篇)。然而使用可持久化权值线段树我们完全可以在线做这道题:
我们考虑用一个数组 pre[i] 来维护 a[i] 上次出现的下标,然后用我们的神奇可持久化权值线段树来查询区间 内所有小于 的下标的次数和。不难发现这刚好是种类数。然后我们就做完了。
具体的理解需要吃一些灵感菇通灵得到。
代码
#include <bits/stdc++.h>
using namespace std;
constexpr int N = 1e6 + 10;
struct HJT {
int cntNodes, root[N];
struct Node {
int l, r;
int cnt;
} tr[22 * N];
void modify(int &u, int v, int l, int r, int x) {
u = ++cntNodes;
tr[u] = tr[v];
tr[u].cnt++;
if (l == r) return;
int mid = (l + r) / 2;
if (x <= mid) {
modify(tr[u].l, tr[v].l, l, mid, x);
} else {
modify(tr[u].r, tr[v].r, mid + 1, r, x);
}
}
int query(int u, int v, int l, int r, int ql, int qr) {
if (ql <= l and r <= qr)
return tr[v].cnt - tr[u].cnt;
int mid = (l + r) / 2;
int res = 0;
if (ql <= mid)
res += query(tr[u].l, tr[v].l, l, mid, ql, qr);
if (qr > mid)
res += query(tr[u].r, tr[v].r, mid + 1, r, ql, qr);
return res;
}
int kth(int u, int v, int l, int r, int k) {
if (l == r) return l;
int res = tr[tr[v].l].cnt - tr[tr[u].l].cnt;
int mid = (l + r) / 2;
if (k <= res)
return kth(tr[u].l, tr[v].l, l, mid, k);
return kth(tr[u].r, tr[v].r, mid + 1, r, k - res);
}
} hjt;
int n, m;
int a[N];
int pre[N];
unordered_map<int, int> lst;
signed main() {
cin >> n;
for (int i = 1; i <= n; i++) {
cin >> a[i];
if (lst.contains(a[i]))
pre[i] = lst[a[i]];
else
pre[i] = 0;
lst[a[i]] = i;
}
for (int i = 1; i <= n; i++) {
hjt.modify(hjt.root[i], hjt.root[i - 1], 0, n, pre[i]);
}
cin >> m;
while (m--) {
int l, r;
cin >> l >> r;
cout << hjt.query(hjt.root[l - 1], hjt.root[r], 0, n, 0, l - 1) << "\n";
}
}
auto FAST_IO = cin.tie(nullptr)->sync_with_stdio(false);看题目就很神秘。实际上也很神秘。因为这个转换就很神秘。
我们考虑一段区间 ,同时能表示出来的数为 。对于区间中每一个数,当且仅当 存在于区间内答案才会被更新。
因为:
如果 :
我们原来能拼出 。现在有了 ,我们还能拼出 (通过 加上 里的所有数)。
因为 ,所以新区间 与原区间 是相邻或重叠的。
两个区间合并,我们现在能连续拼出的新范围变成了 。如果 :
我们原来能拼出 。我们能拼出的下一个最小的数是 。
这意味着 这个数,我们永远也拼不出来。
因此,最小不能表示的数就是 。
于是我们可以不断的更新这个 的值,直到不能再更新为止。显然我们在这里需要用可持久化权值线段树来维护区间求和。这里直接把次数换成了区间和,算是一种偷懒的写法。这里 query 出来的是区间 中, 之内的总和。反正显然这个总和就是我们能表示的一个数字嘛。只有当它满足上面的条件的时候才会被更新。
换句话说:
这个算法的核心思想是迭代地扩展“当前能连续拼出的最大值”,我们称这个值为 sum,初始 sum = 0。
while 循环在做的事情是:
- 用
hjt.query(...)查出区间 中,所有值 的元素总和,记为res。 - 如果
res > sum:这说明我们找到了新的可用元素(它们的值在 之间)。根据贪心,我们把这些新元素全部“吸收”,能连续拼出的最大值被扩展到了res。于是,我们更新sum = res并继续循环。 - 如果
res == sum:这说明在区间 中,所有 的元素总和 就是sum。这意味着,不存在能帮我们“接上”sum + 1的新元素了,所有未被计入sum的元素都 。因此,sum + 1就是最小拼不出的数,循环结束。
代码
#include <bits/stdc++.h>
using namespace std;
constexpr int N = 1e5 + 10;
constexpr int INF = 1e9;
struct HJT {
int cntNodes, root[N];
struct Node {
int l, r;
int cnt;
} tr[32 * N];
void modify(int &u, int v, int l, int r, int x) {
u = ++cntNodes;
tr[u] = tr[v];
tr[u].cnt += x;
if (l == r) return;
int mid = (l + r) / 2;
if (x <= mid) {
modify(tr[u].l, tr[v].l, l, mid, x);
} else {
modify(tr[u].r, tr[v].r, mid + 1, r, x);
}
}
int query(int u, int v, int l, int r, int ql, int qr) {
if (ql <= l and r <= qr)
return tr[v].cnt - tr[u].cnt;
int mid = (l + r) / 2;
int res = 0;
if (ql <= mid)
res += query(tr[u].l, tr[v].l, l, mid, ql, qr);
if (qr > mid)
res += query(tr[u].r, tr[v].r, mid + 1, r, ql, qr);
return res;
}
int kth(int u, int v, int l, int r, int k) {
if (l == r) return l;
int res = tr[tr[v].l].cnt - tr[tr[u].l].cnt;
int mid = (l + r) / 2;
if (k <= res)
return kth(tr[u].l, tr[v].l, l, mid, k);
return kth(tr[u].r, tr[v].r, mid + 1, r, k - res);
}
} hjt;
int n, m;
int a[N];
signed main() {
cin >> n;
for (int i = 1; i <= n; i++)
cin >> a[i];
for (int i = 1; i <= n; i++) {
hjt.modify(hjt.root[i], hjt.root[i - 1], 1, INF, a[i]);
}
cin >> m;
while (m--) {
int l, r;
cin >> l >> r;
int sum = 0;
while (true) {
int res = hjt.query(hjt.root[l - 1], hjt.root[r], 1, INF, 1, sum + 1);
if (res == sum) {
break;
} else {
sum = res;
}
}
cout << sum + 1 << "\n";
}
}
auto FAST_IO = cin.tie(nullptr)->sync_with_stdio(false);练习一
洛谷:P7416 [USACO21FEB] No Time to Dry P
假设存在一个全是 的区间 ,若要把它覆盖成 ,每次覆盖都能选中连续的一段子数组,把颜色设为 ,但是要求「颜色小的不能覆盖颜色大的」,对于每一个查询 求最小覆盖次数。
首先我们肯定要思考一个最优的覆盖策略,否则一切都无从谈起。
我们要想办法快速求出「对于 来说,上一个能把它染色的节点 」,记上一个值为 的节点为 , 显然是满足不存在 使得 的某个 节点。然后为了判断这段区间内有没有比 更小的数,这里使用单调栈维护。最后我们不难得出答案就是满足 的节点数量。
于是在这里我们把问题转化为了求区间内小于某个数的节点数量,看过上一期的同学们想必非常熟悉这个东西。我们在 HH 的项链 那里也有查询区间 内所有小于 的下标的次数和的操作。
显然我们发现,可持久化权值线段树可以完美地做到查询区间内小于某个数值的元素的次数和,于是我们直接把那题的代码抄一下改一下就做完了。
代码
#include <bits/stdc++.h>
using namespace std;
constexpr int N = 2e5 + 10;
struct HJT {
int cntNodes, root[N];
struct Node {
int l, r;
int cnt;
} tr[32 * N];
void modify(int &u, int v, int l, int r, int x) {
u = ++cntNodes;
tr[u] = tr[v];
tr[u].cnt++;
if (l == r) return;
int mid = (l + r) / 2;
if (x <= mid) {
modify(tr[u].l, tr[v].l, l, mid, x);
} else {
modify(tr[u].r, tr[v].r, mid + 1, r, x);
}
}
int query(int u, int v, int l, int r, int ql, int qr) {
if (ql <= l and r <= qr)
return tr[v].cnt - tr[u].cnt;
int mid = (l + r) / 2;
int res = 0;
if (ql <= mid)
res += query(tr[u].l, tr[v].l, l, mid, ql, qr);
if (qr > mid)
res += query(tr[u].r, tr[v].r, mid + 1, r, ql, qr);
return res;
}
int kth(int u, int v, int l, int r, int k) {
if (l == r) return l;
int res = tr[tr[v].l].cnt - tr[tr[u].l].cnt;
int mid = (l + r) / 2;
if (k <= res)
return kth(tr[u].l, tr[v].l, l, mid, k);
return kth(tr[u].r, tr[v].r, mid + 1, r, k - res);
}
} hjt;
int a[N];
int p[N];
int lst[N];
signed main() {
int n, q;
cin >> n >> q;
for (int i = 1; i <= n; i++) {
cin >> a[i];
}
stack<int> stk;
for (int i = 1; i <= n; i++) {
while (not stk.empty() and a[stk.top()] >= a[i])
stk.pop();
int l = lst[a[i]];
if (l > 0 and (stk.empty() or stk.top() < l))
p[i] = l;
stk.push(i);
lst[a[i]] = i;
}
for (int i = 1; i <= n; i++) {
hjt.modify(hjt.root[i], hjt.root[i - 1], 0, n, p[i]);
}
while (q--) {
int l, r;
cin >> l >> r;
cout << hjt.query(hjt.root[l - 1], hjt.root[r], 0, n, 0, l - 1) << "\n";
}
}
auto FAST_IO = cin.tie(nullptr)->sync_with_stdio(false);以上。