跳至主要內容

字典树(0-1 Trie)

大约 6 分钟

字典树(0-1 Trie)

题目 | 难度 | 备注
208. 实现 Trie (前缀树) | 中等 | 模板
211. 添加与搜索单词 - 数据结构设计 | 中等
212. 单词搜索 II | 困难
421. 数组中两个数的最大异或值 | 中等
$1804. 实现 Trie (前缀树) II | 中等 | 模板
题目 | 难度 | 备注
421. 数组中两个数的最大异或值 | 中等 | 0-1 Trie
1707. 与数组中元素的最大异或值 | 中等 | 0-1 Trie
1938. 查询最大基因差 | 中等 | 0-1 Trie
CF282E | rating 2200 | 0-1 Trie

模板

https://codeforces.com/contest/282/submission/212508137

208. 实现 Trie (前缀树)

public class Solution208 {
    /**
     * Prefix tree over the lowercase letters 'a'..'z'.
     *
     * <p>Every instance is a node with 26 child slots; the instance a caller creates
     * serves as the root of the tree.
     */
    static class Trie {
        private final Trie[] children;
        private boolean isWord;

        /** Creates an empty node with no children and no word ending on it. */
        public Trie() {
            children = new Trie[26];
        }

        /**
         * Stores {@code word}, creating any missing nodes along its path.
         *
         * @param word the word to add; each character is mapped to slot {@code ch - 'a'}
         */
        public void insert(String word) {
            Trie cur = this;
            for (int i = 0; i < word.length(); i++) {
                int slot = word.charAt(i) - 'a';
                Trie next = cur.children[slot];
                if (next == null) {
                    next = new Trie();
                    cur.children[slot] = next;
                }
                cur = next;
            }
            cur.isWord = true;
        }

        /**
         * Reports whether exactly {@code word} was inserted before.
         *
         * @param word the word to look up
         * @return {@code true} only if the full path exists and ends on a word marker
         */
        public boolean search(String word) {
            Trie end = walk(word);
            return end != null && end.isWord;
        }

        /**
         * Reports whether any inserted word begins with {@code prefix}.
         *
         * @param prefix the prefix to look up
         * @return {@code true} if the path for {@code prefix} exists in the tree
         */
        public boolean startsWith(String prefix) {
            return walk(prefix) != null;
        }

        /** Follows {@code text} from this node; returns the node reached, or null if the path breaks. */
        private Trie walk(String text) {
            Trie cur = this;
            for (int i = 0; i < text.length() && cur != null; i++) {
                cur = cur.children[text.charAt(i) - 'a'];
            }
            return cur;
        }
    }
}

212. 单词搜索 II

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class Solution212 {
    private static final int[][] DIRECTIONS = {{1, 0}, {0, 1}, {-1, 0}, {0, -1}};
    private char[][] board;
    private int m, n;
    private Set<String> set;

    /**
     * Finds every word of {@code words} that can be spelled on {@code board} by moving
     * between 4-neighbouring cells without reusing a cell.
     *
     * @param board grid of lowercase letters; modified temporarily during the search and
     *              restored before returning
     * @param words candidate words made of lowercase letters
     * @return the distinct words found, in the iteration order of a {@link HashSet}
     */
    public List<String> findWords(char[][] board, String[] words) {
        this.board = board;
        m = board.length;
        n = board[0].length;

        Trie root = new Trie();
        for (int w = 0; w < words.length; w++) {
            root.insert(words[w]);
        }

        set = new HashSet<>();
        // Visit the cells in row-major order, starting one search from each of them.
        for (int cell = 0; cell < m * n; cell++) {
            dfs(root, cell / n, cell % n);
        }
        return new ArrayList<>(set);
    }

    /**
     * Extends the current trie path with the letter at {@code (x, y)} and keeps exploring
     * the neighbours while the trie still has a matching branch.
     */
    private void dfs(Trie trie, int x, int y) {
        char saved = board[x][y];
        Trie next = trie.children[saved - 'a'];
        if (next == null) {
            return;
        }
        if (next.word != null) {
            set.add(next.word);
        }
        // '{' is the character right after 'z' in ASCII; it maps to slot 26, which insert
        // never fills, so a cell marked this way stops any path that reaches it.
        board[x][y] = '{';
        for (int[] step : DIRECTIONS) {
            int nx = x + step[0];
            int ny = y + step[1];
            if (nx < 0 || nx >= m || ny < 0 || ny >= n) {
                continue;
            }
            dfs(next, nx, ny);
        }
        board[x][y] = saved;
    }

    /** Trie node that remembers the complete word ending on it (null when none does). */
    private static class Trie {
        Trie[] children;
        String word;

        public Trie() {
            // 26 slots for 'a'..'z' plus one for the '{' visited marker.
            children = new Trie[27];
        }

        /** Adds {@code word} below this node and records it on the final node. */
        public void insert(String word) {
            Trie cur = this;
            for (int i = 0; i < word.length(); i++) {
                int slot = word.charAt(i) - 'a';
                Trie child = cur.children[slot];
                if (child == null) {
                    child = new Trie();
                    cur.children[slot] = child;
                }
                cur = child;
            }
            cur.word = word;
        }
    }
}

$1804. 实现 Trie (前缀树) II

public class Solution1804 {
    /**
     * Counting prefix tree over the lowercase letters 'a'..'z'.
     *
     * <p>Each node keeps how many stored words end on it and how many stored words pass
     * through it, so the same word can be inserted and erased multiple times.
     */
    static class Trie {
        private final Trie[] children;
        // Number of stored words that end exactly on this node.
        private int cntWord;
        // Number of stored words whose path goes through this node (words having this prefix).
        private int cntPrefix;

        /** Creates an empty node with zeroed counters. */
        public Trie() {
            children = new Trie[26];
            cntWord = 0;
            cntPrefix = 0;
        }

        /**
         * Stores one more occurrence of {@code word}.
         *
         * @param word the word to add; each character is mapped to slot {@code ch - 'a'}
         */
        public void insert(String word) {
            Trie node = this;
            // The root stands for the empty prefix, which every stored word has.
            node.cntPrefix++;
            for (char ch : word.toCharArray()) {
                int idx = ch - 'a';
                if (node.children[idx] == null) {
                    node.children[idx] = new Trie();
                }
                node = node.children[idx];
                node.cntPrefix++;
            }
            node.cntWord++;
        }

        /**
         * @param word the word to count
         * @return how many occurrences of exactly {@code word} are currently stored
         */
        public int countWordsEqualTo(String word) {
            Trie node = this;
            for (char ch : word.toCharArray()) {
                int idx = ch - 'a';
                if (node.children[idx] == null) {
                    return 0;
                }
                node = node.children[idx];
            }
            return node.cntWord;
        }

        /**
         * @param prefix the prefix to count; the empty string matches every stored word
         * @return how many stored occurrences start with {@code prefix}
         */
        public int countWordsStartingWith(String prefix) {
            Trie node = this;
            for (char ch : prefix.toCharArray()) {
                int idx = ch - 'a';
                if (node.children[idx] == null) {
                    return 0;
                }
                node = node.children[idx];
            }
            return node.cntPrefix;
        }

        /**
         * Removes one occurrence of {@code word}. Does nothing when the word is not
         * currently stored, so the counters can never go negative or drift.
         *
         * @param word the word to remove
         */
        public void erase(String word) {
            // Without this guard, erasing an absent word would still decrement cntPrefix on
            // the part of the path that exists (and could push cntWord below zero).
            if (countWordsEqualTo(word) == 0) {
                return;
            }
            Trie node = this;
            node.cntPrefix--;
            for (char ch : word.toCharArray()) {
                // The guard above proved the full path exists, so no null check is needed.
                node = node.children[ch - 'a'];
                node.cntPrefix--;
            }
            node.cntWord--;
        }
    }
}

0-1 Trie

421. 数组中两个数的最大异或值

public class Solution421 {
    public int findMaximumXOR(int[] nums) {
        TrieNode root = buildTrie(nums);
        int max = 0;
        for (int num : nums) {
            TrieNode node = root;
            int xor = 0;
            for (int i = 31; i >= 0; i--) {
                int bit = (num >> i) & 1;
                if (node.children[1 - bit] != null) {
                    xor = (xor << 1) + 1;
                    node = node.children[1 - bit];
                } else {
                    xor = (xor << 1);
                    node = node.children[bit];
                }
            }
            max = Math.max(max, xor);
        }
        return max;
    }

    private TrieNode buildTrie(int[] nums) {
        TrieNode root = new TrieNode();
        for (int num : nums) {
            TrieNode node = root;
            for (int i = 31; i >= 0; i--) {
                int bit = (num >> i) & 1;
                if (node.children[bit] == null) {
                    node.children[bit] = new TrieNode();
                }
                node = node.children[bit];
            }
        }
        return root;
    }

    private static class TrieNode {
        TrieNode[] children;

        public TrieNode() {
            children = new TrieNode[2];
        }
    }

    public int findMaximumXOR2(int[] nums) {
        int n = nums.length;
        Trie trie = new Trie(n, 32);
        for (int x : nums) {
            trie.insert(x);
        }

        int ans = 0;
        for (int x : nums) {
            ans = Math.max(ans, trie.query(x));
        }
        return ans;
    }

    // 0-1 Trie
    // 2^31
    private static class Trie {
        int[][] dict;
        int nextIdx, m;

        // n:长度 m:2^m
        public Trie(int n, int m) {
            this.dict = new int[2][n * m + 2];
            this.nextIdx = 1;
            this.m = m;
        }

        public void insert(int x) {
            int idx = 0;
            for (int k = m - 1; k >= 0; k--) {
                int pos = x >> k & 1;
                if (dict[pos][idx] == 0) {
                    dict[pos][idx] = nextIdx++;
                }
                idx = dict[pos][idx];
            }
        }

        public int query(int x) {
            int res = 0;
            int idx = 0;
            for (int k = m - 1; k >= 0; k--) {
                int pos = x >> k & 1;
                if (dict[1 - pos][idx] != 0) {
                    res |= 1 << k;
                    idx = dict[1 - pos][idx];
                } else {
                    idx = dict[pos][idx];
                }
            }
            return res;
        }
    }
}

1707. 与数组中元素的最大异或值

import java.util.Arrays;
import java.util.Comparator;

public class Solution1707 {
    /**
     * Answers each query {@code {x, limit}} with the largest {@code x ^ v} over all values
     * {@code v} of {@code nums} with {@code v <= limit}, or -1 when no such value exists.
     *
     * <p>Queries are processed offline in ascending order of {@code limit}; values are fed
     * into a 0-1 trie as soon as they become eligible. The caller's {@code nums} array is
     * left untouched: sorting is done on a private copy.
     *
     * @param nums    the candidate values
     * @param queries the queries, each a two-element array {@code {x, limit}}
     * @return one answer per query, in the original query order
     */
    public int[] maximizeXor(int[] nums, int[][] queries) {
        int n = nums.length;
        int q = queries.length;

        // Sort a copy so the caller's array is not reordered as a hidden side effect.
        int[] sorted = nums.clone();
        Arrays.sort(sorted);

        Integer[] ids = new Integer[q];
        for (int i = 0; i < q; i++) {
            ids[i] = i;
        }
        Arrays.sort(ids, Comparator.comparingInt(o -> queries[o][1]));

        Trie trie = new Trie(n, 32);

        int[] ans = new int[q];
        int loc = 0; // every sorted[j] with j < loc is already in the trie
        for (int i = 0; i < q; i++) {
            int id = ids[i];
            int x = queries[id][0], limit = queries[id][1];
            // Add every value that is <= limit and not yet in the trie.
            while (loc < n && sorted[loc] <= limit) {
                trie.insert(sorted[loc]);
                loc++;
            }
            // loc == 0 means even the smallest value exceeds this limit.
            ans[id] = (loc == 0) ? -1 : trie.query(x);
        }
        return ans;
    }

    /**
     * Array-backed 0-1 trie. {@code dict[b][i]} is the index of node {@code i}'s child
     * along bit {@code b}; 0 means "no child", since index 0 is the root and node indices
     * are handed out starting from 1.
     */
    private static class Trie {
        int[][] dict;
        int nextIdx, m;

        /**
         * @param n number of values that will be inserted
         * @param m number of bits stored per value
         */
        public Trie(int n, int m) {
            this.dict = new int[2][n * m + 2];
            this.nextIdx = 1;
            this.m = m;
        }

        /** Inserts the low {@code m} bits of {@code x}, highest of those bits first. */
        public void insert(int x) {
            int idx = 0;
            for (int k = m - 1; k >= 0; k--) {
                int pos = x >> k & 1;
                if (dict[pos][idx] == 0) {
                    dict[pos][idx] = nextIdx++;
                }
                idx = dict[pos][idx];
            }
        }

        /** Returns the largest XOR of {@code x} with any inserted value (0 if none). */
        public int query(int x) {
            int res = 0;
            int idx = 0;
            for (int k = m - 1; k >= 0; k--) {
                int pos = x >> k & 1;
                if (dict[1 - pos][idx] != 0) {
                    // A stored value has the opposite bit here, so this result bit is 1.
                    res |= 1 << k;
                    idx = dict[1 - pos][idx];
                } else {
                    idx = dict[pos][idx];
                }
            }
            return res;
        }
    }
}

1938. 查询最大基因差

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class Solution1938 {
    private Map<Integer, List<Integer>> g;
    private Map<Integer, List<int[]>> qs;
    private int[] ans;
    private Trie trie;

    /**
     * For each query {@code {node, val}}, returns the largest {@code val ^ p} where
     * {@code p} ranges over the node numbers on the path from the root to {@code node}
     * (both ends included).
     *
     * <p>Queries are grouped by node and answered during one depth-first walk; the trie
     * always contains exactly the nodes on the current root-to-node path.
     *
     * @param parents {@code parents[x]} is the parent of node {@code x}, or -1 for the root
     * @param queries the queries, each a two-element array {@code {node, val}}
     * @return one answer per query, in the original query order
     */
    public int[] maxGeneticDifference(int[] parents, int[][] queries) {
        int nodeCount = parents.length;
        int queryCount = queries.length;

        // Children lists, keyed by parent.
        g = new HashMap<>();
        int root = 0;
        for (int child = 0; child < nodeCount; child++) {
            int parent = parents[child];
            if (parent == -1) {
                root = child;
                continue;
            }
            g.computeIfAbsent(parent, key -> new ArrayList<>()).add(child);
        }

        // Pending queries per node, stored as {val, index of the query}.
        qs = new HashMap<>();
        for (int i = 0; i < queryCount; i++) {
            int[] query = queries[i];
            qs.computeIfAbsent(query[0], key -> new ArrayList<>()).add(new int[]{query[1], i});
        }

        ans = new int[queryCount];
        trie = new Trie(nodeCount, 32);

        dfs(root);
        return ans;
    }

    /** Adds {@code v} to the trie, answers its queries, recurses into children, then removes it. */
    private void dfs(int v) {
        trie.insert(v, 1);

        List<int[]> pending = qs.get(v);
        if (pending != null) {
            for (int[] entry : pending) {
                ans[entry[1]] = trie.query(entry[0]);
            }
        }

        List<Integer> kids = g.get(v);
        if (kids != null) {
            for (Integer kid : kids) {
                dfs(kid);
            }
        }

        trie.insert(v, -1);
    }

    /**
     * Array-backed 0-1 trie with per-node counters, which makes removal possible.
     * {@code dict[b][i]} is the index of node {@code i}'s child along bit {@code b}
     * (0 = no child; node indices start at 1), and {@code cnt[i]} is the number of
     * currently stored values whose path goes through node {@code i}.
     */
    private static class Trie {
        int[][] dict;
        int[] cnt;
        int nextIdx, m;

        /**
         * @param n number of distinct values that will be inserted
         * @param m number of bits stored per value
         */
        public Trie(int n, int m) {
            int capacity = n * m + 2;
            this.dict = new int[2][capacity];
            this.cnt = new int[capacity];
            this.nextIdx = 1;
            this.m = m;
        }

        /**
         * Adds {@code op} to the counter of every node on the path of {@code x},
         * creating missing nodes first.
         *
         * @param x  the value whose low {@code m} bits form the path
         * @param op 1 to insert the value, -1 to remove it
         */
        public void insert(int x, int op) {
            int cur = 0;
            for (int k = m - 1; k >= 0; k--) {
                int b = (x >> k) & 1;
                int child = dict[b][cur];
                if (child == 0) {
                    child = nextIdx++;
                    dict[b][cur] = child;
                }
                cur = child;
                cnt[cur] += op;
            }
        }

        /** Returns the largest XOR of {@code x} with any value currently stored. */
        public int query(int x) {
            int best = 0;
            int cur = 0;
            for (int k = m - 1; k >= 0; k--) {
                int b = (x >> k) & 1;
                // A missing child has index 0, and cnt[0] is never modified, so one test
                // covers both "no such child" and "child exists but holds no value".
                int opposite = dict[1 - b][cur];
                if (cnt[opposite] != 0) {
                    best |= 1 << k;
                    cur = opposite;
                } else {
                    cur = dict[b][cur];
                }
            }
            return best;
        }
    }
}

CF282E

import java.nio.charset.StandardCharsets;
import java.util.Scanner;

public class CF282E {
    static int n;
    static long[] a;

    /** Reads {@code n} and then {@code n} values from standard input and prints the answer. */
    public static void main(String[] args) {
        Scanner in = new Scanner(System.in, StandardCharsets.UTF_8);
        n = in.nextInt();
        a = new long[n];
        int read = 0;
        while (read < n) {
            a[read++] = in.nextLong();
        }
        System.out.println(solve1());
    }

    /**
     * Pointer-trie variant of {@link #solve1()}; currently not called by {@code main}.
     * Time limit exceeded on test 46.
     *
     * @return the maximum of "prefix XOR" ^ "suffix XOR", formatted as a decimal string
     */
    private static String solve() {
        TrieNode root = new TrieNode();
        long best = 0;

        // Insert the XOR of every prefix except the full one (lengths 0..n-1).
        long prefix = 0;
        for (int i = 0; i < n; i++) {
            TrieNode cur = root;
            for (int bit = 39; bit >= 0; bit--) {
                int b = (int) ((prefix >> bit) & 1);
                TrieNode child = cur.children[b];
                if (child == null) {
                    child = new TrieNode();
                    cur.children[b] = child;
                }
                cur = child;
                cur.cnt++;
            }
            prefix ^= a[i];
            // A prefix on its own (paired with the empty suffix) is also a candidate.
            best = Math.max(best, prefix);
        }

        // Pair each non-empty suffix XOR with the best stored prefix XOR.
        long suffix = 0;
        for (int i = n - 1; i >= 0; i--) {
            suffix ^= a[i];
            long gain = 0;
            TrieNode cur = root;
            for (int bit = 39; bit >= 0; bit--) {
                int same = (int) ((suffix >> bit) & 1);
                TrieNode flipped = cur.children[same ^ 1];
                if (flipped != null && flipped.cnt > 0) {
                    gain |= 1L << bit;
                    cur = flipped;
                } else {
                    cur = cur.children[same];
                }
            }
            best = Math.max(best, gain);
        }
        return String.valueOf(best);
    }

    /** Pointer-based binary trie node with a counter of the values passing through it. */
    private static class TrieNode {
        TrieNode[] children;
        int cnt;

        public TrieNode() {
            children = new TrieNode[2];
            cnt = 0;
        }
    }

    /**
     * Array-trie solution; see https://codeforces.com/contest/282/submission/212508137
     *
     * <p>Stores the XOR of every prefix of lengths 0..n-1, then queries the trie with the
     * XOR of every suffix of lengths 0..n and keeps the largest result.
     *
     * @return the maximum found, formatted as a decimal string
     */
    private static String solve1() {
        Trie trie = new Trie(n, 40);

        long prefix = 0;
        for (int i = 0; i < n; i++) {
            trie.insert(prefix);
            prefix ^= a[i];
        }

        // Start with the empty suffix, then grow the suffix one element at a time.
        long suffix = 0;
        long best = Math.max(0, trie.query(suffix));
        for (int i = n - 1; i >= 0; i--) {
            suffix ^= a[i];
            best = Math.max(best, trie.query(suffix));
        }
        return String.valueOf(best);
    }

    /**
     * Array-backed 0-1 trie over {@code long} values. {@code dict[b][i]} is the index of
     * node {@code i}'s child along bit {@code b}; 0 means "no child", since index 0 is
     * the root and node indices are handed out starting from 1.
     */
    private static class Trie {
        int[][] dict;
        int nextIdx, m;

        /**
         * @param n number of values that will be inserted
         * @param m number of bits stored per value
         */
        public Trie(int n, int m) {
            this.dict = new int[2][n * m + 2];
            this.nextIdx = 1;
            this.m = m;
        }

        /** Inserts the low {@code m} bits of {@code x}, highest of those bits first. */
        public void insert(long x) {
            int cur = 0;
            for (int k = m - 1; k >= 0; k--) {
                int b = (int) ((x >> k) & 1);
                int child = dict[b][cur];
                if (child == 0) {
                    child = nextIdx++;
                    dict[b][cur] = child;
                }
                cur = child;
            }
        }

        /** Returns the largest XOR of {@code x} with any inserted value (0 if none). */
        public long query(long x) {
            long best = 0;
            int cur = 0;
            for (int k = m - 1; k >= 0; k--) {
                int b = (int) ((x >> k) & 1);
                int opposite = dict[1 - b][cur];
                if (opposite != 0) {
                    best |= 1L << k;
                    cur = opposite;
                } else {
                    cur = dict[b][cur];
                }
            }
            return best;
        }
    }
}

(全文完)