
线段树作为处理区间问题的万能神器,能轻松搞定区间修改、区间查询等基础操作,但当遇到区间方差、区间 GCD这类和数学紧密结合的问题时,光靠模板就远远不够了。此时的核心解题思路不再是单纯的代码实现,而是先通过数学公式推导,把复杂问题转化为线段树可维护的基础信息,再用线段树完成后续的操作。 这就是 “线段树 + 数学” 的核心魅力 —— 数学推导为线段树指明维护方向,线段树为数学计算提供高效支撑。本文将从两道经典的硬核例题入手,拆解 “公式推导→确定维护信息→线段树实现” 的完整解题流程,让你彻底掌握这种强强联合的解题思路!下面就让我们正式开始吧!
先看两个看似棘手的区间问题:
如果直接硬刚,这两个问题都无从下手:
但通过数学公式的推导和变形,我们可以把这些复杂的待维护量,转化为线段树能轻松维护的基础信息:
简单来说,数学是解决这类问题的钥匙,线段树是执行的工具。没有数学推导,线段树就没有维护的目标;没有线段树,数学推导后的结果无法高效处理多次操作。
方差是概率论中的基础概念,直接计算需要先求平均数,再计算每个数与平均数的差的平方和,最后取平均。但在编程题中,直接计算会引入浮点误差,且无法直接用线段树维护,这时候就需要对方差公式进行代数变形,转化为整数运算。
题目链接:https://www.luogu.com.cn/problem/P5142
给定长度为 n 的序列,支持两种操作:
首先回顾方差的定义:对于区间[l,r],长度为len=r−l+1,区间和为

,平均数为

,则方差d为:

直接计算这个公式有两个问题:一是涉及浮点数,二是(ai−A)2无法通过子区间的信息合并得到父区间的结果。因此我们需要对公式进行展开变形,消去平均数A,转化为区间和与区间平方和的组合。
展开推导过程:

推导结论:要计算区间方差,只需维护两个基础量:
这两个量都能通过线段树的 pushup 函数轻松合并,完美解决了方差的维护问题!
题目要求以分数取模形式输出,而模运算中除法不能直接计算,需要转化为乘法逆元:

最终方差的模运算公式为:

注意:结果可能为负数,需要加上 mod 后再取模,保证结果非负。
线段树的每个节点需要维护区间左右边界 l/r、区间和 sum、区间平方和 qsum:
typedef long long LL;
const int N = 1e5 + 10, mod = 1e9 + 7;
struct node {
int l, r;
LL sum; // 区间和
LL qsum; // 区间平方和
} tr[N << 2];
LL a[N]; // 原始序列(1)pushup:合并左右孩子的 sum 和 qsum,直接相加即可,符合数学推导的合并规则:
void pushup(node& p, node& l, node& r) {
p.sum = (l.sum + r.sum) % mod;
p.qsum = (l.qsum + r.qsum) % mod;
}(2)build:建树时,叶子节点的 sum 为a[i],qsum 为

,非叶子节点递归构建后 pushup:
void build(int p, int l, int r) {
tr[p] = {l, r, a[l] % mod, (a[l] * a[l]) % mod};
if (l == r) return;
int mid = (l + r) >> 1;
build(p << 1, l, mid);
build(p << 1 | 1, mid + 1, r);
pushup(tr[p], tr[p << 1], tr[p << 1 | 1]);
}(3)modify:单点修改,找到叶子节点后更新 sum 和 qsum,向上回溯 pushup:
void modify(int p, int x, LL k) {
int l = tr[p].l, r = tr[p].r;
if (l == r) {
tr[p].sum = k % mod;
tr[p].qsum = (k * k) % mod;
return;
}
int mid = (l + r) >> 1;
if (x <= mid) modify(p << 1, x, k);
else modify(p << 1 | 1, x, k);
pushup(tr[p], tr[p << 1], tr[p << 1 | 1]);
}(4)query:区间查询,返回包含该区间 sum 和 qsum 的 node 结构体,跨区间时合并左右子树的结果:
node query(int p, int x, int y) {
int l = tr[p].l, r = tr[p].r;
if (x <= l && r <= y) return tr[p];
int mid = (l + r) >> 1;
if (y <= mid) return query(p << 1, x, y);
else if (x > mid) return query(p << 1 | 1, x, y);
else {
node L = query(p << 1, x, y);
node R = query(p << 1 | 1, x, y);
node res;
pushup(res, L, R);
return res;
}
}(5)快速幂求逆元:实现费马小定理求逆元的快速幂函数,时间复杂度O(logmod):
LL qpow(LL a, LL b, LL p) {
LL ret = 1;
a %= p;
while (b) {
if (b & 1) ret = (ret * a) % p;
a = (a * a) % p;
b >>= 1;
}
return ret;
}#include <iostream>
#include <algorithm>
using namespace std;
typedef long long LL;
const int N = 1e5 + 10, mod = 1e9 + 7;
struct node {
int l, r;
LL sum;
LL qsum;
} tr[N << 2];
LL a[N];
void pushup(node& p, node& l, node& r) {
p.sum = (l.sum + r.sum) % mod;
p.qsum = (l.qsum + r.qsum) % mod;
}
void build(int p, int l, int r) {
tr[p] = {l, r, a[l] % mod, (a[l] * a[l]) % mod};
if (l == r) return;
int mid = (l + r) >> 1;
build(p << 1, l, mid);
build(p << 1 | 1, mid + 1, r);
pushup(tr[p], tr[p << 1], tr[p << 1 | 1]);
}
void modify(int p, int x, LL k) {
int l = tr[p].l, r = tr[p].r;
if (l == r) {
tr[p].sum = k % mod;
tr[p].qsum = (k * k) % mod;
return;
}
int mid = (l + r) >> 1;
if (x <= mid) modify(p << 1, x, k);
else modify(p << 1 | 1, x, k);
pushup(tr[p], tr[p << 1], tr[p << 1 | 1]);
}
node query(int p, int x, int y) {
int l = tr[p].l, r = tr[p].r;
if (x <= l && r <= y) return tr[p];
int mid = (l + r) >> 1;
if (y <= mid) return query(p << 1, x, y);
else if (x > mid) return query(p << 1 | 1, x, y);
else {
node L = query(p << 1, x, y);
node R = query(p << 1 | 1, x, y);
node res;
pushup(res, L, R);
return res;
}
}
// 快速幂求逆元(费马小定理)
LL qpow(LL a, LL b, LL p) {
LL ret = 1;
a %= p;
while (b) {
if (b & 1) ret = (ret * a) % p;
a = (a * a) % p;
b >>= 1;
}
return ret;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
int n, m;
cin >> n >> m;
for (int i = 1; i <= n; i++) {
cin >> a[i];
}
build(1, 1, n);
while (m--) {
int op, x, y;
cin >> op >> x >> y;
if (op == 1) {
// 单点修改:将x位置赋值为y
modify(1, x, y);
} else {
// 区间查询:查询[x,y]的方差
node t = query(1, x, y);
LL sum = t.sum, qsum = t.qsum;
LL len = y - x + 1;
LL inv = qpow(len, mod - 2, mod); // 求1/len的逆元
LL A = (sum * inv) % mod; // 平均数的模运算结果
LL part1 = (qsum * inv) % mod; // qsum/len
LL part2 = (A * A) % mod; // (sum/len)^2
LL ans = (part1 - part2 + mod) % mod; // 保证非负
cout << ans << endl;
}
}
return 0;
}GCD(最大公约数)是数论中的基础概念,普通的区间 GCD 查询可以用 ST 表实现,但如果加上区间加修改,ST 表就无能为力了。此时需要利用数论中的差分结论,将区间加操作转化为单点修改,再用线段树维护差分序列的 GCD,从而解决问题。
题目链接:https://www.luogu.com.cn/problem/P10463

给定长度为 n 的序列,支持两种操作:
首先定义差分序列b:对于原序列a,b1=a1,bi=ai−ai−1(i≥2)。原序列可以由差分序列的前缀和还原:

。
关键结论:
原序列区间[l,r]的 GCD,等于差分序列的bl与差分序列区间[l+1,r]的 GCD的最大公约数,即:

结论证明:
我们用数学归纳法和 GCD 的性质证明:

这是因为gcd(x,y)=gcd(x,y−x),该性质可以推广到多个数的 GCD;

结论得证!
原序列的区间加[x,y]+d,对应差分序列的两个单点修改:

:因为ax增加 d,ax−1不变,所以bx=ax−ax−1增加 d;

(若y+1≤n):因为ay增加 d,a(y+1)不变,所以b(y+1)=a(y+1)−ay减少 d。
转化原理:差分序列的本质是相邻元素的差,区间加操作只会改变区间起点和终点下一个位置的差,中间位置的差保持不变。
通过数论结论和差分转化,原本的 “区间加 + 区间 GCD 查询” 问题,被转化为差分序列上的两个简单操作:
因此,线段树需要维护差分序列的区间和 sum(用于求原序列的ax)和区间 GCD gcd(用于求差分序列的区间 GCD)。
线段树的每个节点维护区间左右边界 l/r、区间和 sum、区间 GCD gcd:
typedef long long LL;
const int N = 5e5 + 10;
struct node {
int l, r;
LL sum; // 差分序列的区间和
LL gcd; // 差分序列的区间GCD
} tr[N << 2];
LL b[N]; // 差分序列(1)GCD 基础函数:实现求两个数的 GCD,注意处理负数(取绝对值):
LL gcd(LL a, LL b) {
a = abs(a), b = abs(b);
return b == 0 ? a : gcd(b, a % b);
}(2)pushup:合并左右孩子的 sum 和 gcd,sum 直接相加,gcd 取左右孩子 gcd 的最大公约数:
void pushup(int p) {
tr[p].sum = tr[p << 1].sum + tr[p << 1 | 1].sum;
tr[p].gcd = gcd(tr[p << 1].gcd, tr[p << 1 | 1].gcd);
}(3)build:建树时,叶子节点的 sum 和 gcd 均为差分序列的b[i],非叶子节点递归构建后 pushup:
void build(int p, int l, int r) {
tr[p] = {l, r, b[l], b[l]};
if (l == r) return;
int mid = (l + r) >> 1;
build(p << 1, l, mid);
build(p << 1 | 1, mid + 1, r);
pushup(p);
}(4)modify:单点修改,找到叶子节点后更新 sum 和 gcd,向上回溯 pushup:
void modify(int p, int x, LL d) {
int l = tr[p].l, r = tr[p].r;
if (l == r) {
tr[p].sum += d;
tr[p].gcd += d;
return;
}
int mid = (l + r) >> 1;
if (x <= mid) modify(p << 1, x, d);
else modify(p << 1 | 1, x, d);
pushup(p);
}(5)query_sum:查询差分序列的区间和,用于求原序列的ax=sum(b1∼bx):
LL query_sum(int p, int x, int y) {
int l = tr[p].l, r = tr[p].r;
if (x <= l && r <= y) return tr[p].sum;
int mid = (l + r) >> 1;
LL res = 0;
if (x <= mid) res += query_sum(p << 1, x, y);
if (y > mid) res += query_sum(p << 1 | 1, x, y);
return res;
}(6)query_gcd:查询差分序列的区间 GCD,注意处理查询区间为空的情况(返回 0):
LL query_gcd(int p, int x, int y) {
int l = tr[p].l, r = tr[p].r;
if (x > y) return 0; // 区间为空,返回0
if (x <= l && r <= y) return tr[p].gcd;
int mid = (l + r) >> 1;
LL res = 0;
if (x <= mid) res = gcd(res, query_gcd(p << 1, x, y));
if (y > mid) res = gcd(res, query_gcd(p << 1 | 1, x, y));
return res;
}#include <iostream>
#include <cstdlib> // 用于abs
#include <string>
using namespace std;
typedef long long LL;
const int N = 5e5 + 10;
struct node {
int l, r;
LL sum;
LL gcd;
} tr[N << 2];
LL b[N]; // 差分序列
LL a[N]; // 原始序列
// 求两个数的GCD,处理负数
LL gcd(LL a, LL b) {
a = abs(a), b = abs(b);
return b == 0 ? a : gcd(b, a % b);
}
void pushup(int p) {
tr[p].sum = tr[p << 1].sum + tr[p << 1 | 1].sum;
tr[p].gcd = gcd(tr[p << 1].gcd, tr[p << 1 | 1].gcd);
}
void build(int p, int l, int r) {
tr[p] = {l, r, b[l], b[l]};
if (l == r) return;
int mid = (l + r) >> 1;
build(p << 1, l, mid);
build(p << 1 | 1, mid + 1, r);
pushup(p);
}
void modify(int p, int x, LL d) {
int l = tr[p].l, r = tr[p].r;
if (l == r) {
tr[p].sum += d;
tr[p].gcd += d;
return;
}
int mid = (l + r) >> 1;
if (x <= mid) modify(p << 1, x, d);
else modify(p << 1 | 1, x, d);
pushup(p);
}
// 查询区间和
LL query_sum(int p, int x, int y) {
int l = tr[p].l, r = tr[p].r;
if (x <= l && r <= y) return tr[p].sum;
int mid = (l + r) >> 1;
LL res = 0;
if (x <= mid) res += query_sum(p << 1, x, y);
if (y > mid) res += query_sum(p << 1 | 1, x, y);
return res;
}
// 查询区间GCD
LL query_gcd(int p, int x, int y) {
int l = tr[p].l, r = tr[p].r;
if (x > y) return 0;
if (x <= l && r <= y) return tr[p].gcd;
int mid = (l + r) >> 1;
LL res = 0;
if (x <= mid) res = gcd(res, query_gcd(p << 1, x, y));
if (y > mid) res = gcd(res, query_gcd(p << 1 | 1, x, y));
return res;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
int n, m;
cin >> n >> m;
for (int i = 1; i <= n; i++) {
cin >> a[i];
}
// 构建差分序列
b[1] = a[1];
for (int i = 2; i <= n; i++) {
b[i] = a[i] - a[i-1];
}
build(1, 1, n);
while (m--) {
string op;
int l, r;
LL d;
cin >> op >> l >> r;
if (op == "C") {
// 区间加:[l,r] + d
cin >> d;
modify(1, l, d);
if (r + 1 <= n) {
modify(1, r + 1, -d);
}
} else if (op == "Q") {
// 区间查询GCD:[l,r]
LL al = query_sum(1, 1, l); // 原序列a[l] = 差分序列[1,l]的和
LL g = query_gcd(1, l + 1, r); // 差分序列[l+1,r]的GCD
LL ans = gcd(al, g);
cout << ans << endl;
}
}
return 0;
}通过以上两道例题,我们可以总结出 “线段树 + 数学” 解决硬核区间问题的三步通用解题框架,无论遇到哪种结合数学的线段树问题,都可以按这个思路推导:
明确题目要求维护的复杂量(如方差、GCD),回忆该概念的数学定义、公式和相关性质,这是推导的基础。例如:
这是最核心的一步,通过公式变形、数论结论、数据转化(如差分)等方式,将复杂的待维护量,转化为线段树可维护的基础量,要求基础量满足可合并性(即父节点的基础量能由左右孩子的基础量通过简单运算得到)。例如:
根据推导得到的基础量,设计线段树的结构体和核心函数(pushup、build、modify、query),实现基础量的维护;最后根据数学推导的公式,将线段树查询到的基础量还原为题目要求的结果(如方差的模运算、GCD 的组合计算)。
“线段树 + 数学” 的问题,代码实现本身并不复杂,难点在于数学推导和细节处理,以下是五大高频易错点,一定要避开!
这是最致命的错误,直接导致后续的线段树维护方向错误。解决方法:手动推导公式时,一步一步写清楚,不要跳步;推导完成后,用简单的测试用例验证公式的正确性。
涉及分数取模的问题,容易忘记逆元的计算,或忽略结果为负数的情况。解决方法:
GCD、差分序列等问题中容易出现负数,而 GCD 的定义是正整数,解决方法:计算 GCD 前,对所有数取绝对值。
区间加操作转化为差分序列的单点修改时,容易忘记 **y+1≤n** 的判断,导致数组越界。解决方法:修改by+1前,必须判断y+1是否在序列范围内。
线段树的 pushup 函数是基础量合并的核心,容易出现合并规则错误(如将 GCD 的合并写成相加)。解决方法:根据数学推导的结论,明确基础量的合并规则,用简单的测试用例验证 pushup 函数的正确性。
线段树 + 数学的问题,看似硬核,实则有章可循。它考察的不是单纯的代码能力,而是数学思维和问题转化能力—— 能否将陌生的复杂问题,转化为熟悉的基础问题。 很多同学遇到这类问题会直接放弃,其实只要沉下心来做数学推导,把复杂量转化为线段树可维护的基础量,剩下的就是模板化的代码实现。希望本文能让你掌握这种解题思路,在面对线段树 + 数学的问题时,不再畏惧,从容应对! 创作不易,如果本文对你有帮助,欢迎点赞、收藏、关注三连~