跳至主要內容

线段树

大约 9 分钟

线段树

题目难度
218. 天际线问题open in new window困难区间修改(upd),区间最值(max)
307. 区域和检索 - 数组可修改open in new window中等区间修改(upd),区间求和
699. 掉落的方块open in new window困难区间修改(upd),区间最值(max)(同 218)
731. 我的日程安排表 IIopen in new window中等区间加,区间最值
LCP 05. 发 LeetCoinopen in new window困难区间加,区间求和
715. Range 模块open in new window困难区间修改(upd),区间求和
2276. 统计区间中的整数数目open in new window困难区间修改,区间求和(同 715)

最大子数组和

题目难度
53. 最大子数组和open in new window简单最大子数组和 MaxSubArraySegTree
2382. 删除操作后的最大子段和open in new window困难最大子数组和 MaxSubArraySegTree
CF1692Hopen in new windowrating 1700最大子数组和 MaxSubArraySegTree

maintain

题目难度
2213. 由单个字符重复的最长子字符串open in new window困难maintain T4
2916. 子数组不同元素数目的平方和 IIopen in new window困难formula T4

矩阵线段树

题目难度
魔术师【算法赛】open in new windowLQ231223T7

区间赋值,单点查询

题目难度
CF292Eopen in new windowrating 1900
CF522Dopen in new windowrating 2000

线段树基础

注意点:

  • oi-wiki 版本中,s, t 表示当前节点区间(初始为 1, n),l, r 是查询参数。而个人更喜欢 lc2916 @TsReaper 的写法:`if (ql <= l && r <= qr) { ... }`,`int mid = l + (r - l) / 2;`
  • p << 1 等价于 p * 2,左子树;
  • p << 1 | 1 等价于 p * 2 + 1,右子树;
  • 时间复杂度:建树 O(n),单次修改/查询 O(log n)(n 次操作合计 O(n log n))
  • 空间复杂度:O(4n)

模板

  • SegTreeUpd_max.java
  • SegTreeUpd_sum.java
  • DynamicSegTreeUpd_max.java
  • DynamicSegTreeUpd_sum.java

线段树维护最大子数组和

详情
  2382. 删除操作后的最大子段和
import java.util.Arrays;

/**
 * LeetCode 2382: maximum segment sum after each removal, answered with a segment tree that
 * maintains the maximum sub-array sum of every node interval.
 */
public class Solution2382 {
    // Magnitude written (negated) into removed positions: 1e5 * 1e9 + 10.
    // NOTE(review): presumably chosen to exceed any reachable segment sum under the problem limits.
    private static final long INF = (long) (1e5 * 1e9 + 10);

    // 137ms
    /**
     * Applies the removals in order and records the best segment sum after each one.
     *
     * @param nums          the initial values
     * @param removeQueries 0-based indices to remove, in order
     * @return for every query except the last, the maximum sub-array sum after that removal;
     *         the last slot is never written and keeps the array default {@code 0}
     */
    public long[] maximumSegmentSum(int[] nums, int[] removeQueries) {
        int n = nums.length;
        MaxSubArraySegTree seg = new MaxSubArraySegTree(nums);

        int q = removeQueries.length;
        long[] ans = new long[q];
        // The final removal is skipped on purpose: its answer stays at the default 0.
        for (int i = 0; i < q - 1; i++) {
            // The tree is 1-based, hence the +1.
            seg.modify(removeQueries[i] + 1, -INF);
            ans[i] = seg.query(1, n);
        }
        return ans;
    }

    /** Segment tree over positions {@code 1..n} supporting point assignment and max sub-array sum queries. */
    static class MaxSubArraySegTree {
        Node[] tree;
        // Placeholder magnitude for the node returned when a query misses the node interval entirely.
        static final int INF = (int) 1e9;

        /** Aggregates of one interval [l, r]. */
        static class Node {
            // Best prefix sum, best suffix sum, best sub-array sum, and total sum of the interval.
            long maxL, maxR, maxSum, sum;

            public Node(long maxL, long maxR, long maxSum, long sum) {
                this.maxL = maxL;
                this.maxR = maxR;
                this.maxSum = maxSum;
                this.sum = sum;
            }
        }

        int[] nums;
        int n;

        /**
         * Builds the tree over {@code nums}; position {@code i} (1-based) holds {@code nums[i - 1]}.
         *
         * @param nums source values
         */
        public MaxSubArraySegTree(int[] nums) {
            this.nums = nums;
            this.n = nums.length;
            tree = new Node[4 * n];
            Arrays.setAll(tree, e -> new Node(0, 0, 0, 0));

            build(1, 1, n);
        }

        /** Recursively fills node {@code p}, which covers positions {@code [l, r]}. */
        void build(int p, int l, int r) {
            if (l == r) {
                int val = nums[l - 1];
                tree[p].maxL = tree[p].maxR = tree[p].maxSum = tree[p].sum = val;
                return;
            }
            int mid = l + (r - l) / 2;
            build(p << 1, l, mid);
            build(p << 1 | 1, mid + 1, r);
            tree[p] = pushUp(tree[p << 1], tree[p << 1 | 1]);
        }

        /**
         * Sets the value at 1-based position {@code pos} to {@code val}.
         *
         * @param pos 1-based position; positions outside {@code [1, n]} are ignored
         * @param val new value
         */
        void modify(int pos, long val) {
            modify(1, 1, n, pos, val);
        }

        /** Point assignment inside node {@code p} covering {@code [l, r]}; no-op if {@code pos} is outside. */
        void modify(int p, int l, int r, int pos, long val) {
            if (pos < l || pos > r) {
                return;
            }
            if (l == r) {
                Node leaf = tree[p];
                leaf.maxL = leaf.maxR = leaf.maxSum = leaf.sum = val;
                return;
            }
            int mid = l + (r - l) / 2;
            // Only the child that contains pos can change.
            if (pos <= mid) {
                modify(p << 1, l, mid, pos, val);
            } else {
                modify(p << 1 | 1, mid + 1, r, pos, val);
            }
            tree[p] = pushUp(tree[p << 1], tree[p << 1 | 1]);
        }

        /**
         * Returns the maximum (non-empty) sub-array sum within positions {@code [ql, qr]}.
         *
         * @param ql 1-based left bound, inclusive
         * @param qr 1-based right bound, inclusive
         * @return the maximum sub-array sum, or {@code -INF} when the range does not hit the tree
         */
        long query(int ql, int qr) {
            return query(1, 1, n, ql, qr).maxSum;
        }

        /**
         * Range query inside node {@code p} covering {@code [l, r]}.
         *
         * <p>The placeholder node is only ever returned as-is and never merged with real data:
         * merging it would clamp results at {@code -INF} even though stored values are
         * {@code long} and may be far smaller.
         */
        Node query(int p, int l, int r, int ql, int qr) {
            if (l > qr || r < ql) {
                return new Node(-INF, -INF, -INF, 0);
            }
            if (ql <= l && r <= qr) {
                return tree[p];
            }
            int mid = l + (r - l) / 2;
            // Descend only into children that overlap the query range.
            if (qr <= mid) {
                return query(p << 1, l, mid, ql, qr);
            }
            if (ql > mid) {
                return query(p << 1 | 1, mid + 1, r, ql, qr);
            }
            Node ls = query(p << 1, l, mid, ql, qr);
            Node rs = query(p << 1 | 1, mid + 1, r, ql, qr);
            return pushUp(ls, rs);
        }

        /** Combines the aggregates of two adjacent intervals ({@code ls} left of {@code rs}). */
        Node pushUp(Node ls, Node rs) {
            long maxL = Math.max(ls.maxL, ls.sum + rs.maxL);
            long maxR = Math.max(rs.maxR, rs.sum + ls.maxR);
            // Best is entirely left, entirely right, or crosses the boundary.
            long maxSum = Math.max(Math.max(ls.maxSum, rs.maxSum), ls.maxR + rs.maxL);
            long sum = ls.sum + rs.sum;
            return new Node(maxL, maxR, maxSum, sum);
        }
    }
}

线段树维护最长重复子段

详情
  2213. 由单个字符重复的最长子字符串
import java.util.Arrays;

/**
 * LeetCode 2213: longest run of one repeated character after each point update, maintained
 * with a segment tree that stores prefix run, suffix run and best run per interval.
 */
public class Solution2213 {
    // 120ms
    /**
     * Applies each character replacement and reports the longest repeating run afterwards.
     *
     * @param s               initial string
     * @param queryCharacters replacement characters, one per query
     * @param queryIndices    0-based positions to replace, one per query
     * @return the longest single-character run after each query
     */
    public int[] longestRepeating(String s, String queryCharacters, int[] queryIndices) {
        char[] chars = s.toCharArray();
        int len = s.length();
        int queryCount = queryCharacters.length();

        int[] answers = new int[queryCount];
        SegmentTree seg = new SegmentTree(chars);
        int qi = 0;
        while (qi < queryCount) {
            int at = queryIndices[qi];
            seg.cs[at] = queryCharacters.charAt(qi);
            // The tree is 1-based; re-maintain every node on the path to the changed leaf.
            seg.update(1, 1, len, at + 1);
            answers[qi] = seg.tree[1].max;
            qi++;
        }
        return answers;
    }

    /** Segment tree over positions 1..n of the character array it was built from. */
    private static class SegmentTree {
        Node[] tree;

        /** Longest run touching the left edge, touching the right edge, and anywhere in the interval. */
        static class Node {
            int pre, suf, max;

            public Node(int pre, int suf, int max) {
                this.pre = pre;
                this.suf = suf;
                this.max = max;
            }
        }

        char[] cs;

        public SegmentTree(char[] cs) {
            this.cs = cs;
            int size = cs.length;
            tree = new Node[4 * size];
            Arrays.setAll(tree, e -> new Node(0, 0, 0));

            build(1, 1, size);
        }

        /** Initialises node p covering [l, r]; every leaf is a run of length one. */
        void build(int p, int l, int r) {
            if (l < r) {
                int mid = l + (r - l) / 2;
                build(p << 1, l, mid);
                build(p << 1 | 1, mid + 1, r);
                maintain(l, r, p);
                return;
            }
            Node leaf = tree[p];
            leaf.pre = 1;
            leaf.suf = 1;
            leaf.max = 1;
        }

        /** Recomputes node p covering [l, r] from its two children. */
        private void maintain(int l, int r, int p) {
            Node left = tree[p << 1];
            Node right = tree[p << 1 | 1];
            int mid = l + (r - l) / 2;

            int pre = left.pre;
            int suf = right.suf;
            int best = Math.max(left.max, right.max);
            // Positions mid and mid + 1 (1-based) hold the same character: runs can join.
            if (cs[mid - 1] == cs[mid]) {
                int leftLen = mid - l + 1;
                int rightLen = r - mid;
                if (left.suf == leftLen) {
                    pre += right.pre;
                }
                if (right.suf == rightLen) {
                    suf += left.suf;
                }
                best = Math.max(best, left.suf + right.pre);
            }

            Node cur = tree[p];
            cur.pre = pre;
            cur.suf = suf;
            cur.max = best;
        }

        /** Re-maintains the nodes on the path from node p (covering [l, r]) down to position i. */
        private void update(int p, int l, int r, int i) {
            if (l != r) {
                int mid = l + (r - l) / 2;
                if (i > mid) {
                    update(p << 1 | 1, mid + 1, r, i);
                } else {
                    update(p << 1, l, mid, i);
                }
                maintain(l, r, p);
            }
        }
    }
}

最长递增子序列

详情
  2407. 最长递增子序列 II
/**
 * LeetCode 2407: longest increasing subsequence whose adjacent difference is at most k,
 * computed with a dynamically allocated max segment tree indexed by value.
 */
public class Solution2407 {
    // Upper bound of the value domain covered by the tree.
    private static final int N = (int) 1e5;

    /**
     * For each value x, extends the best subsequence ending at a value in [x - k, x - 1].
     *
     * @param nums input values
     * @param k    maximum allowed difference between adjacent picked values
     * @return the maximum stored length over values 1..N
     */
    public int lengthOfLIS(int[] nums, int k) {
        DynamicSegTree seg = new DynamicSegTree();
        for (int idx = 0; idx < nums.length; idx++) {
            int value = nums[idx];
            int bestBefore = seg.getMax(value - k, value - 1);
            seg.update(value, value, bestBefore + 1);
        }
        return seg.getMax(1, N);
    }

    /** Max segment tree over [0, N] whose nodes are created on demand. */
    private static class DynamicSegTree {
        static class Node {
            Node ls, rs;
            int max, lazy;
        }

        final Node root = new Node();

        /** Assigns val to every position in [l, r]. */
        void update(int l, int r, int val) {
            this.update(root, 0, N, l, r, val);
        }

        /** Returns the maximum over [l, r]. */
        int getMax(int l, int r) {
            return this.getMax(root, 0, N, l, r);
        }

        void update(Node p, int l, int r, int ql, int qr, int val) {
            boolean covered = ql <= l && r <= qr;
            if (covered) {
                p.lazy = val;
                p.max = val;
                return;
            }
            pushDown(p);
            int mid = l + (r - l) / 2;
            if (mid >= ql) {
                update(p.ls, l, mid, ql, qr, val);
            }
            if (mid < qr) {
                update(p.rs, mid + 1, r, ql, qr, val);
            }
            pushUp(p);
        }

        int getMax(Node p, int l, int r, int ql, int qr) {
            boolean covered = ql <= l && r <= qr;
            if (covered) {
                return p.max;
            }
            pushDown(p);
            int mid = l + (r - l) / 2;
            int best = 0;
            if (mid >= ql) {
                best = Math.max(best, getMax(p.ls, l, mid, ql, qr));
            }
            if (mid < qr) {
                best = Math.max(best, getMax(p.rs, mid + 1, r, ql, qr));
            }
            return best;
        }

        /** Creates missing children and hands a pending positive assignment down to them. */
        void pushDown(Node p) {
            if (p.ls == null) {
                p.ls = new Node();
            }
            if (p.rs == null) {
                p.rs = new Node();
            }
            int pending = p.lazy;
            if (pending <= 0) {
                return;
            }
            Node left = p.ls;
            Node right = p.rs;
            left.max = pending;
            left.lazy = pending;
            right.max = pending;
            right.lazy = pending;
            p.lazy = 0;
        }

        /** Recomputes the maximum of p from its two children. */
        void pushUp(Node p) {
            int leftMax = p.ls.max;
            int rightMax = p.rs.max;
            p.max = Math.max(leftMax, rightMax);
        }
    }
}

线段树维护矩阵乘

详情

魔术师【算法赛】

import java.util.Arrays;
import java.util.Scanner;
import java.util.stream.Collectors;

/**
 * Segment tree with 3x3 matrix lazy tags: every position carries a row vector of three color
 * counts, and each operation multiplies a range by one of the fixed matrices below.
 */
public class LQ231223T7 {
    static int n, m;

    /**
     * Reads n, m, the n initial colors and m operations ({@code l r opt a [b]}) from standard
     * input, and prints the three total color counts after every operation.
     */
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        n = in.nextInt();
        m = in.nextInt();

        // Row i (1-based) is the one-hot count vector of position i's color.
        long[][] nums = new long[n + 1][3];
        for (int pos = 1; pos <= n; pos++) {
            int color = in.nextInt();
            nums[pos][color - 1] = 1;
        }
        SegmentTreeMat seg = new SegmentTreeMat(n, nums);
        String[] output = new String[m];
        for (int step = 0; step < m; step++) {
            int l = in.nextInt();
            int r = in.nextInt();
            int opt = in.nextInt();

            int a = in.nextInt();
            long[][] op;
            if (opt == 1) {
                int b = in.nextInt();
                // A swap is symmetric, so order the pair before picking the matrix.
                op = swapMatrix(Math.min(a, b), Math.max(a, b));
            } else if (opt == 2) {
                int b = in.nextInt();
                op = changeMatrix(a, b);
            } else {
                op = splitMatrix(a);
            }
            if (op != null) {
                seg.update(l, r, op);
            }
            long[] sum = seg.getSum(1, n);
            output[step] = sum[0] + " " + sum[1] + " " + sum[2];
        }
        System.out.println(String.join(System.lineSeparator(), output));
    }

    /** Matrix exchanging colors a and b (expects a < b), or null when the pair is not recognised. */
    private static long[][] swapMatrix(int a, int b) {
        if (a == 1 && b == 2) {
            return swap1_2;
        }
        if (a == 1 && b == 3) {
            return swap1_3;
        }
        if (a == 2 && b == 3) {
            return swap2_3;
        }
        return null;
    }

    /** Matrix recoloring a into b, or null when the pair is not recognised. */
    private static long[][] changeMatrix(int a, int b) {
        if (a == 1) {
            if (b == 2) {
                return change1_2;
            }
            return b == 3 ? change1_3 : null;
        }
        if (a == 2) {
            if (b == 3) {
                return change2_3;
            }
            return b == 1 ? change2_1 : null;
        }
        if (a == 3) {
            if (b == 1) {
                return change3_1;
            }
            return b == 2 ? change3_2 : null;
        }
        return null;
    }

    /** Matrix doubling the count of color a, or null when a is not 1, 2 or 3. */
    private static long[][] splitMatrix(int a) {
        switch (a) {
            case 1:
                return split1;
            case 2:
                return split2;
            case 3:
                return split3;
            default:
                return null;
        }
    }

    // Color swap
    static final long[][] swap1_2 = {{0, 1, 0}, {1, 0, 0}, {0, 0, 1}};
    static final long[][] swap1_3 = {{0, 0, 1}, {0, 1, 0}, {1, 0, 0}};
    static final long[][] swap2_3 = {{1, 0, 0}, {0, 0, 1}, {0, 1, 0}};
    // Recolor
    static final long[][] change1_2 = {{0, 1, 0}, {0, 1, 0}, {0, 0, 1}};
    static final long[][] change1_3 = {{0, 0, 1}, {0, 1, 0}, {0, 0, 1}};
    static final long[][] change2_3 = {{1, 0, 0}, {0, 0, 1}, {0, 0, 1}};
    static final long[][] change2_1 = {{1, 0, 0}, {1, 0, 0}, {0, 0, 1}};
    static final long[][] change3_1 = {{1, 0, 0}, {0, 1, 0}, {1, 0, 0}};
    static final long[][] change3_2 = {{1, 0, 0}, {0, 1, 0}, {0, 1, 0}};
    // Split
    static final long[][] split1 = {{2, 0, 0}, {0, 1, 0}, {0, 0, 1}};
    static final long[][] split2 = {{1, 0, 0}, {0, 2, 0}, {0, 0, 1}};
    static final long[][] split3 = {{1, 0, 0}, {0, 1, 0}, {0, 0, 2}};

    /** Segment tree over positions 1..n storing a count vector per node and a matrix lazy tag. */
    static class SegmentTreeMat {
        static final int MOD = 998244353;
        static final long[][] ONE = {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}}; // identity matrix
        int n;
        long[][] nums;

        long[][] tree;
        long[][][] lazy;

        public SegmentTreeMat(int n, long[][] nums) {
            this.n = n;
            int size = n * 4;
            this.tree = new long[size][3];
            this.lazy = new long[size][][];
            for (int node = 0; node < size; node++) {
                lazy[node] = identity();
            }

            this.nums = nums;
            build(1, 1, n);
        }

        /** Returns a fresh 3x3 identity matrix. */
        private static long[][] identity() {
            return new long[][]{{1, 0, 0}, {0, 1, 0}, {0, 0, 1}};
        }

        void build(int p, int l, int r) {
            if (l != r) {
                int mid = l + (r - l) / 2;
                build(p << 1, l, mid);
                build(p << 1 | 1, mid + 1, r);
                pushUp(p);
                return;
            }
            tree[p] = nums[l].clone();
        }

        /** Multiplies every position in [ql, qr] by val. */
        void update(int ql, int qr, long[][] val) {
            update(1, 1, n, ql, qr, val);
        }

        void update(int p, int l, int r, int ql, int qr, long[][] val) {
            boolean covered = ql <= l && r <= qr;
            if (covered) {
                tree[p] = matMul(tree[p], val);
                lazy[p] = matMul(lazy[p], val);
                return;
            }
            pushDown(p);
            int mid = l + (r - l) / 2;
            if (mid >= ql) {
                update(p << 1, l, mid, ql, qr, val);
            }
            if (mid < qr) {
                update(p << 1 | 1, mid + 1, r, ql, qr, val);
            }
            pushUp(p);
        }

        /** Returns the summed count vector of positions [ql, qr]. */
        long[] getSum(int ql, int qr) {
            return getSum(1, 1, n, ql, qr);
        }

        long[] getSum(int p, int l, int r, int ql, int qr) {
            boolean covered = ql <= l && r <= qr;
            if (covered) {
                return tree[p];
            }
            pushDown(p);
            int mid = l + (r - l) / 2;
            long[] total = new long[3];
            if (mid >= ql) {
                total = mulAdd(total, getSum(p << 1, l, mid, ql, qr));
            }
            if (mid < qr) {
                total = mulAdd(total, getSum(p << 1 | 1, mid + 1, r, ql, qr));
            }
            return total;
        }

        /** Applies the pending matrix of p to both children, then resets it to the identity. */
        void pushDown(int p) {
            long[][] pending = lazy[p];
            if (Arrays.deepEquals(ONE, pending)) {
                return;
            }
            int left = p << 1;
            int right = p << 1 | 1;
            tree[left] = matMul(tree[left], pending);
            tree[right] = matMul(tree[right], pending);
            // Child tag first, then the parent's: the operations were applied in that order.
            lazy[left] = matMul(lazy[left], pending);
            lazy[right] = matMul(lazy[right], pending);
            lazy[p] = identity();
        }

        void pushUp(int p) {
            tree[p] = mulAdd(tree[p << 1], tree[p << 1 | 1]);
        }

        // Vector addition: res[] = a[] + b[] (mod MOD)
        long[] mulAdd(long[] a, long[] b) {
            int dim = a.length;
            long[] out = new long[dim];
            for (int idx = 0; idx < dim; idx++) {
                out[idx] = (a[idx] + b[idx]) % MOD;
            }
            return out;
        }

        // Row vector times matrix: res[] = a[] * b[][] (mod MOD)
        long[] matMul(long[] a, long[][] b) {
            int dim = a.length;
            long[] out = new long[dim];
            for (int col = 0; col < dim; col++) {
                long acc = 0;
                for (int k = 0; k < dim; k++) {
                    acc = (acc + a[k] * b[k][col] % MOD) % MOD;
                }
                out[col] = acc;
            }
            return out;
        }

        // Matrix product: res[][] = a[][] * b[][] (mod MOD)
        long[][] matMul(long[][] a, long[][] b) {
            int dim = a.length;
            long[][] out = new long[dim][dim];
            for (int row = 0; row < dim; row++) {
                for (int col = 0; col < dim; col++) {
                    long acc = 0;
                    for (int k = 0; k < dim; k++) {
                        acc = (acc + a[row][k] * b[k][col] % MOD) % MOD;
                    }
                    out[row][col] = acc;
                }
            }
            return out;
        }
    }
}

线段树维护区间 top k

详情

CF1665E:维护区间内最小的 31 个数

import java.util.Arrays;
import java.util.Scanner;
import java.util.stream.Collectors;

/**
 * CF1665E: for every query range, the minimum bitwise OR over pairs of distinct positions,
 * using a segment tree whose nodes keep the (up to) 31 smallest values of their interval.
 */
public class CF1665E {
    static int n, q;
    static int[] a;
    static int[][] lr;

    /** Reads the test cases from standard input and prints the answers of each case. */
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in);
        for (int cases = in.nextInt(); cases > 0; cases--) {
            n = in.nextInt();
            a = new int[n];
            for (int i = 0; i < n; i++) {
                a[i] = in.nextInt();
            }
            q = in.nextInt();
            lr = new int[q][2];
            for (int i = 0; i < q; i++) {
                lr[i][0] = in.nextInt();
                lr[i][1] = in.nextInt();
            }
            System.out.println(solve());
        }
    }

    /** Answers all stored queries of the current test case, one answer per line. */
    private static String solve() {
        SegmentTree seg = new SegmentTree(n);
        seg.build(a, 1, 1, n);

        StringBuilder out = new StringBuilder();
        for (int qi = 0; qi < q; qi++) {
            if (qi > 0) {
                out.append(System.lineSeparator());
            }
            int[] smallest = seg.query(1, 1, n, lr[qi][0], lr[qi][1]);
            out.append(minPairOr(smallest));
        }
        return out.toString();
    }

    /** Minimum of x | y over all pairs of distinct indices, or 1 << 30 when there is no pair. */
    private static int minPairOr(int[] values) {
        int best = 1 << 30;
        for (int i = 0; i < values.length; i++) {
            for (int j = i + 1; j < values.length; j++) {
                best = Math.min(best, values[i] | values[j]);
            }
        }
        return best;
    }

    private static class SegmentTree {
        int n;
        Node[] t;

        /** Sorted smallest values of the node interval. */
        static class Node {
            int[] arr = new int[0];
        }

        public SegmentTree(int n) {
            this.n = n;
            this.t = new Node[4 * n];
            Arrays.setAll(t, e -> new Node());
        }

        // Merge two sorted arrays, keeping at most the 31 smallest values.
        int[] merge(int[] a, int[] b) {
            int keep = Math.min(31, a.length + b.length);
            int[] out = new int[keep];
            int ai = 0;
            int bi = 0;
            for (int w = 0; w < keep; w++) {
                // Take from a when b is exhausted, or when a's head is not larger (ties go to a).
                boolean takeA = bi == b.length || (ai < a.length && a[ai] <= b[bi]);
                out[w] = takeA ? a[ai++] : b[bi++];
            }
            return out;
        }

        /** Builds node p covering positions [l, r] (1-based) from the 0-based array a. */
        void build(int[] a, int p, int l, int r) {
            if (l < r) {
                int mid = l + (r - l) / 2;
                int left = p << 1;
                int right = p << 1 | 1;
                build(a, left, l, mid);
                build(a, right, mid + 1, r);
                t[p].arr = merge(t[left].arr, t[right].arr);
                return;
            }
            t[p].arr = Arrays.copyOfRange(a, l - 1, l);
        }

        /** Returns the sorted smallest values (at most 31) inside [ql, qr]. */
        int[] query(int p, int l, int r, int ql, int qr) {
            boolean covered = ql <= l && r <= qr;
            if (covered) {
                return t[p].arr;
            }
            int mid = l + (r - l) / 2;
            // The range lies entirely in one half: no merge needed.
            if (mid >= qr) {
                return query(p << 1, l, mid, ql, qr);
            }
            if (mid < ql) {
                return query(p << 1 | 1, mid + 1, r, ql, qr);
            }
            int[] fromLeft = query(p << 1, l, mid, ql, qr);
            int[] fromRight = query(p << 1 | 1, mid + 1, r, ql, qr);
            return merge(fromLeft, fromRight);
        }
    }
}

(全文完)