跳至主要內容

并查集

大约 4 分钟

并查集

题目难度
200. 岛屿数量open in new window中等推荐
$261. 以图判树open in new window 中等
$305. 岛屿数量 IIopen in new window 困难推荐
$323. 无向图中连通分量的数目open in new window 中等
547. 省份数量open in new window中等
684. 冗余连接open in new window中等
685. 冗余连接 IIopen in new window困难推荐
721. 账户合并open in new window中等离散化
$737. 句子相似性 IIopen in new window 中等离散化
947. 移除最多的同行或同列石头open in new window中等HashMap 版本并查集
2334. 元素值大于变化阈值的子数组open in new window困难连通分量合并到较小的节点

并查集是一种树形的数据结构,顾名思义,它用于处理一些不交集的 合并 及 查询 问题。 它支持两种操作:

  • 查找(Find):确定某个元素处于哪个子集;
  • 合并(Union):将两个子集合并成一个集合。

初始化、查找、合并

public class DSU {
    // 父节点数组/祖先数组
    int[] fa;

    // 初始化
    public DSU(int n) {
        fa = new int[n];
        for (int i = 0; i < n; i++) {
            fa[i] = i;
        }
    }

    // 查找
    int find(int x) {
        // 路径压缩
        if (x != fa[x]) {
            fa[x] = find(fa[x]);
        }
        return fa[x];
    }

    // 合并
    void union(int p, int q) {
        int rootP = find(p);
        int rootQ = find(q);
        if (rootP == rootQ) {
            return;
        }
        fa[rootQ] = rootP;
    }
}

时间复杂度

优化平均时间复杂度最坏时间复杂度
无优化O(logn)O(n)
路径压缩O(α(n))O(logn)
按秩合并O(logn)O(logn)
路径压缩 + 按秩合并O(α(n))O(α(n))

这里 α 表示阿克曼函数的反函数,在宇宙可观测的 n 内(例如宇宙中包含的粒子总数),α(n) 不会超过 5。

200. 岛屿数量

连通分量 逐渐合并减少。

public class Solution200 {
    public int numIslands(char[][] grid) {
        int M = grid.length;
        int N = grid[0].length;

        // 岛屿数目
        int cnt = 0;
        for (char[] chars : grid) {
            for (char aChar : chars) {
                if (aChar == '1') {
                    cnt++;
                }
            }
        }

        DSU dsu = new DSU(M * N);
        dsu.sz = cnt;
        for (int i = 0; i < M; i++) {
            for (int j = 0; j < N; j++) {
                if (grid[i][j] == '1') {
                    int p = i * N + j;
                    // up
                    if (i - 1 >= 0 && grid[i - 1][j] == '1') {
                        dsu.union(p, p - N);
                    }
                    // down
                    if (i + 1 < M && grid[i + 1][j] == '1') {
                        dsu.union(p, p + N);
                    }
                    // left
                    if (j - 1 >= 0 && grid[i][j - 1] == '1') {
                        dsu.union(p, p - 1);
                    }
                    // right
                    if (j + 1 < N && grid[i][j + 1] == '1') {
                        dsu.union(p, p + 1);
                    }
                }
            }
        }
        return dsu.sz;
    }

    private static class DSU {
        int[] fa;
        int sz;

        public DSU(int n) {
            fa = new int[n];
            for (int i = 0; i < n; i++) {
                fa[i] = i;
            }
        }

        int find(int x) {
            if (x != fa[x]) {
                fa[x] = find(fa[x]);
            }
            return fa[x];
        }

        void union(int p, int q) {
            int rootP = find(p);
            int rootQ = find(q);
            if (rootP == rootQ) {
                return;
            }
            fa[rootQ] = rootP;
            sz--;
        }
    }
}

$305. 岛屿数量 II

连通分量 要从 0 开始增加。

import java.util.ArrayList;
import java.util.List;

public class Solution305 {
    public List<Integer> numIslands2(int m, int n, int[][] positions) {
        int[][] grid = new int[m][n];

        List<Integer> resList = new ArrayList<>();
        DSU dsu = new DSU(m * n);
        for (int[] position : positions) {
            int i = position[0];
            int j = position[1];

            if (grid[i][j] == 0) {
                grid[i][j] = 1;
                // 增加一个岛屿
                dsu.sz++;
                int p = position[0] * n + position[1];

                // up
                if (i - 1 >= 0 && grid[i - 1][j] == 1) {
                    dsu.union(p, p - n);
                }
                // down
                if (i + 1 < m && grid[i + 1][j] == 1) {
                    dsu.union(p, p + n);
                }
                // left
                if (j - 1 >= 0 && grid[i][j - 1] == 1) {
                    dsu.union(p, p - 1);
                }
                // right
                if (j + 1 < n && grid[i][j + 1] == 1) {
                    dsu.union(p, p + 1);
                }
            }
            resList.add(dsu.sz);
        }
        return resList;
    }

    private static class DSU {
        int[] fa;
        int sz;

        public DSU(int n) {
            fa = new int[n];
            for (int i = 0; i < n; i++) {
                fa[i] = i;
            }
        }

        int find(int x) {
            if (x != fa[x]) {
                fa[x] = find(fa[x]);
            }
            return fa[x];
        }

        void union(int p, int q) {
            int rootP = find(p);
            int rootQ = find(q);
            if (rootP == rootQ) {
                return;
            }
            fa[rootQ] = rootP;
            sz--;
        }
    }
}

947. 移除最多的同行或同列石头

import java.util.HashMap;
import java.util.Map;

public class Solution947 {
    public int removeStones(int[][] stones) {
        DSU dsu = new DSU();
        for (int[] stone : stones) {
            // 并查集里如何区分横纵坐标 下面这三种写法任选其一
//            dsu.union(~stone[0], stone[1]);
//            dsu.union(stone[0] - 10001, stone[1]);
            dsu.union(stone[0] + 10001, stone[1]);
        }
        return stones.length - dsu.sz;
    }

    private static class DSU {
        Map<Integer, Integer> faMap;
        int sz;

        public DSU() {
            faMap = new HashMap<>();
        }

        int find(int x) {
            if (!faMap.containsKey(x)) {
                faMap.put(x, x);
                sz++;
            }
            if (x != faMap.get(x)) {
                faMap.put(x, find(faMap.get(x)));
            }
            return faMap.get(x);
        }

        void union(int p, int q) {
            int rootP = find(p);
            int rootQ = find(q);
            if (rootP == rootQ) {
                return;
            }
            faMap.put(rootQ, rootP);
            sz--;
        }
    }
}

2334. 元素值大于变化阈值的子数组

import java.util.Arrays;
import java.util.stream.IntStream;

public class Solution2334 {
    public int validSubarraySize(int[] nums, int threshold) {
        int n = nums.length;

        Integer[] ids = IntStream.range(0, n).boxed().toArray(Integer[]::new);
        Arrays.sort(ids, (o1, o2) -> Integer.compare(nums[o2], nums[o1]));

        DSU dsu = new DSU(n);
        for (int i : ids) {
            int faI = dsu.find(i);
            int faJ = dsu.find(i + 1);
            // 贪心,优先 union 较大的数
            dsu.union(faI, faJ);

            int k = dsu.sz[faI] - 1;
            if (nums[i] > threshold / k) {
                return k;
            }
        }
        return -1;
    }

    private static class DSU {
        int[] fa;
        int[] sz;

        public DSU(int n) {
            int N = n + 1;
            fa = new int[N];
            sz = new int[N];
            for (int i = 0; i < N; i++) {
                fa[i] = i;
                sz[i] = 1;
            }
        }

        int find(int x) {
            if (x != fa[x]) {
                fa[x] = find(fa[x]);
            }
            return fa[x];
        }

        void union(int p, int q) {
            int rootP = find(p);
            int rootQ = find(q);
            if (rootP == rootQ) {
                return;
            }
            // 合并到较小的节点
            if (rootP < rootQ) {
                fa[rootQ] = rootP;
                sz[rootP] += sz[rootQ];
            } else {
                fa[rootP] = rootQ;
                sz[rootQ] += sz[rootP];
            }
        }
    }
}

参考链接

(全文完)