字典树(0-1 Trie)
大约 6 分钟
字典树(0-1 Trie)
- OI Wiki: https://oi-wiki.org/string/trie/
题目 | 难度 | |
---|---|---|
208. 实现 Trie (前缀树) | 中等 | 模板 |
211. 添加与搜索单词 - 数据结构设计 | 中等 | |
212. 单词搜索 II | 困难 | |
421. 数组中两个数的最大异或值 | 中等 | |
$1804. 实现 Trie (前缀树) II | 中等 | 模板 |
题目 | 难度 | |
---|---|---|
421. 数组中两个数的最大异或值 | 中等 | 0-1 Trie |
1707. 与数组中元素的最大异或值 | 困难 | 0-1 Trie |
1938. 查询最大基因差 | 困难 | 0-1 Trie |
CF282E | rating 2200 | 0-1 Trie |
模板
https://codeforces.com/contest/282/submission/212508137
208. 实现 Trie (前缀树)
/**
 * LeetCode 208: a prefix tree over the lowercase alphabet 'a'..'z'.
 */
public class Solution208 {
    static class Trie {
        /** One slot per lowercase letter; a null slot means no stored string continues with that letter. */
        private final Trie[] next = new Trie[26];
        /** True when some inserted word ends exactly at this node. */
        private boolean terminal;

        public Trie() {
        }

        /** Adds {@code word}, creating any missing nodes along its path. */
        public void insert(String word) {
            Trie cur = this;
            for (int i = 0; i < word.length(); i++) {
                int slot = word.charAt(i) - 'a';
                if (cur.next[slot] == null) {
                    cur.next[slot] = new Trie();
                }
                cur = cur.next[slot];
            }
            cur.terminal = true;
        }

        /** Returns whether {@code word} was previously inserted as a whole word. */
        public boolean search(String word) {
            Trie end = walk(word);
            return end != null && end.terminal;
        }

        /** Returns whether any inserted word begins with {@code prefix}. */
        public boolean startsWith(String prefix) {
            return walk(prefix) != null;
        }

        /** Follows {@code s} from this node; returns the node reached, or null as soon as the path breaks. */
        private Trie walk(String s) {
            Trie cur = this;
            for (int i = 0; i < s.length() && cur != null; i++) {
                cur = cur.next[s.charAt(i) - 'a'];
            }
            return cur;
        }
    }
}
212. 单词搜索 II
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
/**
 * LeetCode 212 (Word Search II): returns every word from {@code words} that can be spelled along a
 * path of horizontally/vertically adjacent board cells, using each cell at most once per path.
 * Duplicates are removed through a {@link HashSet}; the result order is the set's iteration order.
 */
public class Solution212 {
    /** Neighbour offsets, tried in this order: down, right, up, left. */
    private static final int[][] DIRECTIONS = {{1, 0}, {0, 1}, {-1, 0}, {0, -1}};
    /**
     * Marker written into a cell while it lies on the current path. It is the ASCII character
     * right after 'z', so it maps to child slot 26, which lowercase words never fill.
     */
    private static final char VISITED = '{';

    private char[][] grid;
    private int rows, cols;
    private Set<String> found;

    public List<String> findWords(char[][] board, String[] words) {
        this.grid = board;
        this.rows = board.length;
        this.cols = board[0].length;

        Trie dictionary = new Trie();
        for (String w : words) {
            dictionary.insert(w);
        }

        found = new HashSet<>();
        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                explore(dictionary, r, c);
            }
        }
        return new ArrayList<>(found);
    }

    /** Tries to extend the path that reached {@code parent} with cell (r, c), backtracking afterwards. */
    private void explore(Trie parent, int r, int c) {
        char letter = grid[r][c];
        Trie node = parent.children[letter - 'a'];
        if (node == null) {
            return;  // no dictionary word continues with this letter, or the cell is already on the path
        }
        if (node.word != null) {
            found.add(node.word);
        }
        grid[r][c] = VISITED;
        for (int[] d : DIRECTIONS) {
            int nr = r + d[0];
            int nc = c + d[1];
            boolean inside = nr >= 0 && nr < rows && nc >= 0 && nc < cols;
            if (inside) {
                explore(node, nr, nc);
            }
        }
        grid[r][c] = letter;  // restore the board on the way back
    }

    private static class Trie {
        /** Slots 0..25 for 'a'..'z'; slot 26 exists only so VISITED can be indexed safely. */
        final Trie[] children = new Trie[27];
        /** The complete word that ends at this node, or null. */
        String word;

        void insert(String w) {
            Trie cur = this;
            for (int i = 0; i < w.length(); i++) {
                int slot = w.charAt(i) - 'a';
                if (cur.children[slot] == null) {
                    cur.children[slot] = new Trie();
                }
                cur = cur.children[slot];
            }
            cur.word = w;
        }
    }
}
$1804. 实现 Trie (前缀树) II
/**
 * LeetCode 1804: a lowercase trie that counts duplicates and supports erase.
 */
public class Solution1804 {
    static class Trie {
        private final Trie[] children;
        // Number of stored (inserted and not yet erased) words that end exactly at this node.
        private int cntWord;
        // Number of stored words whose path passes through this node, i.e. that have this node's
        // path as a prefix. The root counts every stored word (every word has the empty prefix).
        private int cntPrefix;

        public Trie() {
            children = new Trie[26];
            cntWord = 0;
            cntPrefix = 0;
        }

        /** Stores one more copy of {@code word}. */
        public void insert(String word) {
            Trie node = this;
            node.cntPrefix++;
            for (char ch : word.toCharArray()) {
                int idx = ch - 'a';
                if (node.children[idx] == null) {
                    node.children[idx] = new Trie();
                }
                node = node.children[idx];
                node.cntPrefix++;
            }
            node.cntWord++;
        }

        /** Returns how many copies of {@code word} are stored. */
        public int countWordsEqualTo(String word) {
            Trie node = locate(word);
            return node == null ? 0 : node.cntWord;
        }

        /** Returns how many stored words (counting duplicates) start with {@code prefix}. */
        public int countWordsStartingWith(String prefix) {
            Trie node = locate(prefix);
            return node == null ? 0 : node.cntPrefix;
        }

        /**
         * Removes one copy of {@code word}. Does nothing if no copy is stored.
         */
        public void erase(String word) {
            // Check existence first: decrementing along the path of a word that is not stored
            // (e.g. a mere prefix of a stored word, or a longer unknown word) would corrupt the counters.
            Trie end = locate(word);
            if (end == null || end.cntWord == 0) {
                return;
            }
            Trie node = this;
            node.cntPrefix--;
            for (char ch : word.toCharArray()) {
                node = node.children[ch - 'a'];
                node.cntPrefix--;
            }
            node.cntWord--;
        }

        /** Follows {@code s} from this node; returns the final node, or null if the path does not exist. */
        private Trie locate(String s) {
            Trie node = this;
            for (char ch : s.toCharArray()) {
                node = node.children[ch - 'a'];
                if (node == null) {
                    return null;
                }
            }
            return node;
        }
    }
}
0-1 Trie
421. 数组中两个数的最大异或值
public class Solution421 {
public int findMaximumXOR(int[] nums) {
TrieNode root = buildTrie(nums);
int max = 0;
for (int num : nums) {
TrieNode node = root;
int xor = 0;
for (int i = 31; i >= 0; i--) {
int bit = (num >> i) & 1;
if (node.children[1 - bit] != null) {
xor = (xor << 1) + 1;
node = node.children[1 - bit];
} else {
xor = (xor << 1);
node = node.children[bit];
}
}
max = Math.max(max, xor);
}
return max;
}
private TrieNode buildTrie(int[] nums) {
TrieNode root = new TrieNode();
for (int num : nums) {
TrieNode node = root;
for (int i = 31; i >= 0; i--) {
int bit = (num >> i) & 1;
if (node.children[bit] == null) {
node.children[bit] = new TrieNode();
}
node = node.children[bit];
}
}
return root;
}
private static class TrieNode {
TrieNode[] children;
public TrieNode() {
children = new TrieNode[2];
}
}
public int findMaximumXOR2(int[] nums) {
int n = nums.length;
Trie trie = new Trie(n, 32);
for (int x : nums) {
trie.insert(x);
}
int ans = 0;
for (int x : nums) {
ans = Math.max(ans, trie.query(x));
}
return ans;
}
// 0-1 Trie
// 2^31
private static class Trie {
int[][] dict;
int nextIdx, m;
// n:长度 m:2^m
public Trie(int n, int m) {
this.dict = new int[2][n * m + 2];
this.nextIdx = 1;
this.m = m;
}
public void insert(int x) {
int idx = 0;
for (int k = m - 1; k >= 0; k--) {
int pos = x >> k & 1;
if (dict[pos][idx] == 0) {
dict[pos][idx] = nextIdx++;
}
idx = dict[pos][idx];
}
}
public int query(int x) {
int res = 0;
int idx = 0;
for (int k = m - 1; k >= 0; k--) {
int pos = x >> k & 1;
if (dict[1 - pos][idx] != 0) {
res |= 1 << k;
idx = dict[1 - pos][idx];
} else {
idx = dict[pos][idx];
}
}
return res;
}
}
}
1707. 与数组中元素的最大异或值
import java.util.Arrays;
import java.util.Comparator;
/**
 * LeetCode 1707: for each query (x, limit) returns the maximum of x XOR num over nums with
 * num <= limit, or -1 when no such num exists. Queries are answered offline in increasing order
 * of limit while the sorted nums are fed into a 0-1 trie. Note: {@code nums} is sorted in place.
 */
public class Solution1707 {
    public int[] maximizeXor(int[] nums, int[][] queries) {
        Arrays.sort(nums);
        Integer[] byLimit = new Integer[queries.length];
        for (int k = 0; k < byLimit.length; k++) {
            byLimit[k] = k;
        }
        Arrays.sort(byLimit, Comparator.comparingInt(k -> queries[k][1]));

        Trie trie = new Trie(nums.length, 32);
        int[] answers = new int[queries.length];
        int added = 0;  // nums[0..added) are already in the trie
        for (int qi : byLimit) {
            int x = queries[qi][0];
            int limit = queries[qi][1];
            while (added < nums.length && nums[added] <= limit) {
                trie.insert(nums[added]);
                added++;
            }
            answers[qi] = added == 0 ? -1 : trie.query(x);
        }
        return answers;
    }

    /**
     * Array-backed 0-1 trie. {@code next[b][node]} is the child of {@code node} along bit b;
     * node 0 is the root, and a stored value of 0 means "no child".
     */
    private static class Trie {
        private final int[][] next;
        private final int bits;
        private int size = 1;

        /**
         * @param n    maximum number of values that will be inserted
         * @param bits number of low-order bits kept per value
         */
        Trie(int n, int bits) {
            this.next = new int[2][n * bits + 2];
            this.bits = bits;
        }

        void insert(int x) {
            int node = 0;
            for (int k = bits - 1; k >= 0; k--) {
                int b = (x >> k) & 1;
                if (next[b][node] == 0) {
                    next[b][node] = size++;
                }
                node = next[b][node];
            }
        }

        /** Returns the largest {@code x ^ y} over stored y; requires at least one stored value. */
        int query(int x) {
            int node = 0;
            int res = 0;
            for (int k = bits - 1; k >= 0; k--) {
                int want = ((x >> k) & 1) ^ 1;
                if (next[want][node] != 0) {
                    res |= 1 << k;
                    node = next[want][node];
                } else {
                    node = next[want ^ 1][node];
                }
            }
            return res;
        }
    }
}
1938. 查询最大基因差
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * LeetCode 1938 (Maximum Genetic Difference Query): for each query (node, val), returns the
 * maximum of {@code val ^ p} over every node p on the path from {@code node} up to the root,
 * {@code node} included.
 *
 * <p>The tree is walked depth-first. A node's value is in the trie exactly while the walk is
 * inside that node's subtree, so when a node is entered the trie holds its root-to-node path.
 */
public class Solution1938 {
    private Map<Integer, List<Integer>> g;
    private Map<Integer, List<int[]>> qs;
    private int[] ans;
    private Trie trie;

    public int[] maxGeneticDifference(int[] parents, int[][] queries) {
        int n = parents.length;
        int q = queries.length;
        int root = 0;
        g = new HashMap<>();
        for (int x = 0; x < n; x++) {
            int pa = parents[x];
            if (pa == -1) {
                root = x;
            } else {
                g.computeIfAbsent(pa, key -> new ArrayList<>()).add(x);
            }
        }
        // Group queries by node: each entry is {val, original query index}.
        qs = new HashMap<>();
        for (int i = 0; i < q; i++) {
            int node = queries[i][0], val = queries[i][1];
            qs.computeIfAbsent(node, key -> new ArrayList<>()).add(new int[]{val, i});
        }
        ans = new int[q];
        trie = new Trie(n, 32);
        dfs(root);
        return ans;
    }

    /**
     * Iterative depth-first walk from {@code root}. An explicit stack replaces recursion so that a
     * path-shaped tree (depth up to parents.length) cannot overflow the JVM thread stack. Each
     * frame is {node, index of the next child to visit}. Children are visited in the same order as
     * a recursive walk would visit them.
     */
    private void dfs(int root) {
        List<int[]> stack = new ArrayList<>();
        enter(root);
        stack.add(new int[]{root, 0});
        while (!stack.isEmpty()) {
            int[] frame = stack.get(stack.size() - 1);
            List<Integer> children = g.getOrDefault(frame[0], List.of());
            if (frame[1] < children.size()) {
                int child = children.get(frame[1]++);
                enter(child);
                stack.add(new int[]{child, 0});
            } else {
                // Leaving frame[0]'s subtree: it is no longer on the current path.
                trie.insert(frame[0], -1);
                stack.remove(stack.size() - 1);
            }
        }
    }

    /** Adds {@code v} to the trie, then answers every query attached to {@code v}. */
    private void enter(int v) {
        trie.insert(v, 1);
        for (int[] tuple : qs.getOrDefault(v, List.of())) {
            int val = tuple[0], i = tuple[1];
            ans[i] = trie.query(val);
        }
    }

    /**
     * Array-backed 0-1 trie with per-node counts, so values can be removed again.
     * {@code dict[b][node]} is the child of {@code node} along bit b (node 0 is the root; 0 means
     * "no child"). {@code cnt[node]} is how many currently stored values pass through {@code node}.
     */
    private static class Trie {
        int[][] dict;
        int[] cnt;
        int nextIdx, m;

        /**
         * @param n maximum number of distinct values that will be inserted
         * @param m number of low-order bits kept per value
         */
        public Trie(int n, int m) {
            this.dict = new int[2][n * m + 2];
            this.cnt = new int[n * m + 2];
            this.nextIdx = 1;
            this.m = m;
        }

        /** op = 1 inserts {@code x}; op = -1 removes one previously inserted copy of {@code x}. */
        public void insert(int x, int op) {
            int idx = 0;
            for (int k = m - 1; k >= 0; k--) {
                int pos = x >> k & 1;
                if (dict[pos][idx] == 0) {
                    dict[pos][idx] = nextIdx++;
                }
                idx = dict[pos][idx];
                cnt[idx] += op;
            }
        }

        /** Returns the largest {@code x ^ y} over currently stored y; requires a non-empty trie. */
        public int query(int x) {
            int res = 0;
            int idx = 0;
            for (int k = m - 1; k >= 0; k--) {
                int pos = x >> k & 1;
                // A missing child reads cnt[0]; the root's count is never modified, so it stays 0.
                if (cnt[dict[1 - pos][idx]] != 0) {
                    res |= 1 << k;
                    idx = dict[1 - pos][idx];
                } else {
                    idx = dict[pos][idx];
                }
            }
            return res;
        }
    }
}
CF282E
import java.nio.charset.StandardCharsets;
import java.util.Scanner;
/**
 * Codeforces 282E (Sausage Maximization): choose a prefix and a non-overlapping suffix of
 * {@code a} (either may be empty) and maximise (XOR of the prefix) ^ (XOR of the suffix).
 * Reads n followed by n values from stdin and prints the answer computed by {@code solve1()}.
 */
public class CF282E {
    static int n;
    static long[] a;

    public static void main(String[] args) {
        Scanner input = new Scanner(System.in, StandardCharsets.UTF_8);
        n = input.nextInt();
        a = new long[n];
        for (int i = 0; i < n; i++) {
            a[i] = input.nextLong();
        }
        System.out.println(solve1());
    }

    /**
     * Pointer-based variant, kept for reference and not called by main.
     * Time limit exceeded on test 46.
     */
    private static String solve() {
        TrieNode root = new TrieNode();
        long best = 0;
        long prefix = 0;
        for (int i = 0; i < n; i++) {
            // Insert every prefix XOR except the one covering the whole array.
            addPath(root, prefix);
            prefix ^= a[i];
            // A prefix paired with an empty suffix.
            best = Math.max(best, prefix);
        }
        long suffix = 0;
        for (int i = n - 1; i >= 0; i--) {
            suffix ^= a[i];
            // Best "suffix XOR some stored prefix".
            best = Math.max(best, maxXorWith(root, suffix));
        }
        return String.valueOf(best);
    }

    /** Inserts the low 40 bits of {@code value} into the pointer trie, bumping counts along the path. */
    private static void addPath(TrieNode root, long value) {
        TrieNode cur = root;
        for (int bit = 39; bit >= 0; bit--) {
            int b = (int) ((value >>> bit) & 1);
            if (cur.children[b] == null) {
                cur.children[b] = new TrieNode();
            }
            cur = cur.children[b];
            cur.cnt++;
        }
    }

    /** Greedy walk over the pointer trie: takes the opposite bit whenever a populated child exists. */
    private static long maxXorWith(TrieNode root, long value) {
        TrieNode cur = root;
        long acc = 0;
        for (int bit = 39; bit >= 0; bit--) {
            int b = (int) ((value >>> bit) & 1);
            TrieNode flip = cur.children[b ^ 1];
            if (flip != null && flip.cnt > 0) {
                acc |= 1L << bit;
                cur = flip;
            } else {
                cur = cur.children[b];
            }
        }
        return acc;
    }

    private static class TrieNode {
        TrieNode[] children;
        int cnt;

        public TrieNode() {
            children = new TrieNode[2];
            cnt = 0;
        }
    }

    /**
     * Array-trie solution (https://codeforces.com/contest/282/submission/212508137).
     * All prefix XORs P(0..n-1) are stored, then every suffix XOR S(n..0) is queried against all
     * of them. Overlapping pairs do no harm: with X the total XOR, P(i) ^ S(j) = P(i) ^ P(j) ^ X,
     * which is symmetric in i and j. So an overlapping pair equals the non-overlapping pair P(j) ^ S(i).
     */
    private static String solve1() {
        Trie trie = new Trie(n, 40);
        long prefix = 0;
        for (int i = 0; i < n; i++) {
            trie.insert(prefix);
            prefix ^= a[i];
        }
        long best = 0;
        long suffix = 0;
        for (int start = n; start >= 0; start--) {
            if (start < n) {
                suffix ^= a[start];
            }
            best = Math.max(best, trie.query(suffix));
        }
        return String.valueOf(best);
    }

    /**
     * Array-backed 0-1 trie over longs. {@code next[b][node]} is the child of {@code node} along
     * bit b; node 0 is the root, and a stored value of 0 means "no child".
     */
    private static class Trie {
        private final int[][] next;
        private final int bits;
        private int size = 1;

        /**
         * @param n    maximum number of values that will be inserted
         * @param bits number of low-order bits kept per value
         */
        Trie(int n, int bits) {
            this.next = new int[2][n * bits + 2];
            this.bits = bits;
        }

        void insert(long x) {
            int node = 0;
            for (int k = bits - 1; k >= 0; k--) {
                int b = (int) ((x >> k) & 1);
                if (next[b][node] == 0) {
                    next[b][node] = size++;
                }
                node = next[b][node];
            }
        }

        long query(long x) {
            int node = 0;
            long res = 0;
            for (int k = bits - 1; k >= 0; k--) {
                int want = (int) ((x >> k) & 1) ^ 1;
                if (next[want][node] != 0) {
                    res |= 1L << k;
                    node = next[want][node];
                } else {
                    node = next[want ^ 1][node];
                }
            }
            return res;
        }
    }
}
(全文完)