后缀数组(SA)
2025年3月30日大约 5 分钟
后缀数组(SA)
- OI Wiki: https://oi-wiki.org/string/sa/
题目 | 难度 | |
---|---|---|
1698. 字符串的不同子字符串个数 | 困难 | |
3213. 最小代价构造字符串 | 困难 |
定义
后缀数组(Suffix Array)主要关系到两个数组:`sa` 和 `rk`。
其中,`sa[i]` 表示将所有后缀排序后第 `i` 小的后缀的编号,也是所说的后缀数组,后文也称编号数组 `sa`;
`rk[i]` 表示后缀 `i` 的排名,是重要的辅助数组,后文也称排名数组 `rk`。
这两个数组满足性质:$sa[rk[i]] = rk[sa[i]] = i$。
3213. 最小代价构造字符串
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Queue;
import java.util.function.Function;
public class Solution3213 {
    /** Marks a prefix that cannot be built; {@code f[n] == INF} is reported as -1. */
    private static final int INF = (int) 1e9;

    /**
     * Computes the minimum total cost of building {@code target} by concatenating words.
     *
     * <p>Every occurrence of every distinct word inside {@code target} is located through a
     * suffix array and stored as an edge {@code end -> start} weighted by the cheapest cost
     * of that word. A prefix DP over the end positions then picks the cheapest decomposition.
     *
     * @param target the string to build; may be empty, in which case the cost is 0
     * @param words  candidate words; duplicates are allowed
     * @param costs  {@code costs[i]} is the cost of appending {@code words[i]}
     * @return the minimum total cost, or -1 when {@code target} cannot be built
     */
    public int minimumCost(String target, String[] words, int[] costs) {
        // Only the cheapest copy of each distinct word can be part of an optimal answer.
        Map<String, Integer> minCost = new HashMap<>();
        for (int i = 0; i < words.length; i++) {
            minCost.merge(words[i], costs[i], Integer::min);
        }

        int n = target.length();
        SuffixArray suffixArray = new SuffixArray(target);
        int[] sa = suffixArray.sa0;

        // Pass 1: find the suffix-array range of every word and count all occurrences,
        // so the edge arrays can be sized exactly once instead of for the global worst case.
        int distinct = minCost.size();
        int[] wordLen = new int[distinct];
        int[] wordCost = new int[distinct];
        int[] lo = new int[distinct];
        int[] hi = new int[distinct];
        int totalEdges = 0;
        int k = 0;
        for (Map.Entry<String, Integer> entry : minCost.entrySet()) {
            String word = entry.getKey();
            int[] lr = suffixArray.lookupAll(word);
            wordLen[k] = word.length();
            wordCost[k] = entry.getValue();
            lo[k] = lr[0];
            hi[k] = lr[1];
            totalEdges = Math.addExact(totalEdges, lr[1] - lr[0]);
            k++;
        }

        // Reset the shared adjacency structure for this call.
        idx = 0;
        if (he.length < n + 1) {
            he = new int[n + 1];
        }
        Arrays.fill(he, -1);
        ensureEdgeCapacity(totalEdges);

        // Pass 2: one edge per occurrence, keyed by the (exclusive) end position.
        for (int w = 0; w < distinct; w++) {
            for (int i = lo[w]; i < hi[w]; i++) {
                int l = sa[i];
                add(l + wordLen[w], l, wordCost[w]);
            }
        }

        // f[i] = minimum cost to build target[0, i); f[0] = 0.
        int[] f = new int[n + 1];
        for (int i = 1; i <= n; i++) {
            f[i] = INF;
            for (int j = he[i]; j != -1; j = ne[j]) {
                f[i] = Math.min(f[i], f[ed[j]] + we[j]);
            }
        }
        return f[n] == INF ? -1 : f[n];
    }

    // Forward-star (linked) adjacency list shared by all calls.
    // N is the initial number of head slots. M records the worst case for the problem limits:
    // target = 50000 * 'a' and words = [a, aa, aaa, ...] (315 words, total length
    // (1 + 315) * 315 / 2 = 49770) give (50000 + (50000 - 314)) * 315 / 2 = 15700545 matches.
    // The edge arrays are no longer preallocated to M; they grow to what a call needs.
    static int N = (int) (5e4 + 5), M = 15700545 + 5;
    static int[] he = new int[N], ne = new int[0], ed = new int[0], we = new int[0];
    static int idx = 0;

    /**
     * Appends the edge {@code u -> v} with weight {@code w} to the adjacency list of {@code u}.
     *
     * @param u source vertex (index into {@code he})
     * @param v destination vertex
     * @param w edge weight
     */
    static void add(int u, int v, int w) {
        if (idx == ne.length) {
            ensureEdgeCapacity(idx + 1);
        }
        ed[idx] = v;
        ne[idx] = he[u];
        he[u] = idx;
        we[idx] = w;
        idx++;
    }

    /** Grows the three parallel edge arrays so that they hold at least {@code needed} edges. */
    private static void ensureEdgeCapacity(int needed) {
        if (needed <= ne.length) {
            return;
        }
        int newCap = Math.max(needed, ne.length + (ne.length >> 1));
        ne = Arrays.copyOf(ne, newCap);
        ed = Arrays.copyOf(ed, newCap);
        we = Arrays.copyOf(we, newCap);
    }

    /**
     * Suffix array built by prefix doubling with radix sort, O(n log n).
     * See https://oi-wiki.org/string/sa/.
     *
     * <p>Rank 0 doubles as the "past the end" sentinel, so input containing the NUL
     * character (code 0) is not supported.
     */
    static class SuffixArray {
        String data;
        int[] rk, old_rk, sa, id, cnt;
        int[] sa0; // 0-based: sa0[i] is the start index of the i-th smallest suffix

        /**
         * Builds the suffix array of {@code S}.
         *
         * @param S the text; an empty string yields an empty {@code sa0}
         */
        public SuffixArray(String S) {
            data = S;
            int n = S.length(), m = 128, p;
            // The counting sort is keyed by char code; widen the key range for non-ASCII text.
            for (int i = 0; i < n; i++) {
                m = Math.max(m, S.charAt(i));
            }
            rk = new int[n * 2 + 1];
            old_rk = new int[n * 2 + 1];
            id = new int[n + 1];
            cnt = new int[Math.max(n + 1, m + 1)];
            sa = new int[n + 1];
            sa0 = new int[n];
            if (n == 0) {
                // Nothing to sort; the doubling loop below assumes at least one suffix.
                return;
            }

            char[] s = (" " + S).toCharArray(); // shift to 1-based positions

            // Initial counting sort by first character.
            for (int i = 1; i <= n; i++) cnt[rk[i] = s[i]]++;
            for (int i = 1; i <= m; i++) cnt[i] += cnt[i - 1];
            for (int i = n; i >= 1; i--) sa[cnt[rk[i]]--] = i;

            for (int w = 1; ; w <<= 1, m = p) { // m = p shrinks the key range to the rank count
                // id lists suffixes ordered by their second key (rank of suffix i + w);
                // suffixes without a second key come first.
                int cur = 0;
                for (int i = n - w + 1; i <= n; i++) id[++cur] = i;
                for (int i = 1; i <= n; i++) {
                    if (sa[i] > w) id[++cur] = sa[i] - w;
                }
                // Stable counting sort by first key.
                Arrays.fill(cnt, 0);
                for (int i = 1; i <= n; i++) cnt[rk[i]]++;
                for (int i = 1; i <= m; i++) cnt[i] += cnt[i - 1];
                for (int i = n; i >= 1; i--) sa[cnt[rk[id[i]]]--] = id[i];
                // Re-rank: equal (first, second) key pairs share a rank.
                p = 0;
                System.arraycopy(rk, 0, old_rk, 0, old_rk.length);
                for (int i = 1; i <= n; i++) {
                    if (old_rk[sa[i]] == old_rk[sa[i - 1]] && old_rk[sa[i] + w] == old_rk[sa[i - 1] + w])
                        rk[sa[i]] = p;
                    else rk[sa[i]] = ++p;
                }
                if (p == n) break; // all ranks distinct: fully sorted
            }
            for (int i = 0; i < n; i++) sa0[i] = sa[i + 1] - 1;
        }

        /**
         * Finds the suffixes that start with {@code s}, modelled on Go's
         * {@code index/suffixarray} {@code lookupAll}.
         *
         * <p>Compares in place instead of materialising substrings and returns only the
         * bounds instead of copying the range.
         *
         * @param s the pattern to look up
         * @return {@code {i, j}} such that exactly {@code sa0[i..j)} are the matching suffixes
         */
        int[] lookupAll(String s) {
            int i = sortSearch(sa0.length, m -> compareTo(data, sa0[m], s) >= 0);
            int j = i + sortSearch(sa0.length - i, m -> !data.startsWith(s, sa0[m + i]));
            return new int[]{i, j};
        }

        /**
         * Binary search equivalent to Go's {@code sort.Search}: returns the smallest index in
         * {@code [0, n)} for which {@code f} is true, or {@code n} if there is none.
         * {@code f} must be false for a prefix of the range and true for the rest.
         */
        private int sortSearch(int n, Function<Integer, Boolean> f) {
            int l = 0, r = n;
            while (l < r) {
                int mid = l + (r - l) / 2;
                if (f.apply(mid)) r = mid;
                else l = mid + 1;
            }
            return l;
        }

        /** Equivalent to {@code s.substring(beginIndex).compareTo(t)} without the copy. */
        private int compareTo(String s, int beginIndex, String t) {
            int len1 = s.length() - beginIndex, len2 = t.length();
            int lim = Math.min(len1, len2);
            for (int k = 0; k < lim; k++) {
                char c1 = s.charAt(k + beginIndex), c2 = t.charAt(k);
                if (c1 != c2) return c1 - c2;
            }
            return len1 - len2;
        }
    }
}
$1698. 字符串的不同子字符串个数
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
public class SolutionP1698 {
    /** Counts distinct substrings with a suffix array and its height (LCP) array. */
    static class V1 {
        /**
         * Returns the number of distinct non-empty substrings of {@code s}.
         *
         * <p>There are {@code n (n + 1) / 2} substrings in total; each {@code height} entry
         * counts prefixes of a suffix that already occur in the previous suffix in sorted
         * order, so the sum of heights is the number of duplicates.
         *
         * @param s the input string; an empty string yields 0
         * @return the number of distinct substrings, narrowed to {@code int}
         */
        public int countDistinct(String s) {
            int n = s.length();
            SuffixArray suffixArray = new SuffixArray(s);
            // Accumulate in long: n * (n + 1) overflows int from n = 46341 even when the
            // final answer itself is small (e.g. a long run of one letter).
            long total = (long) n * (n + 1) / 2;
            long duplicates = 0;
            for (int h : suffixArray.height) {
                duplicates += h;
            }
            return (int) (total - duplicates);
        }

        /**
         * Suffix array built by prefix doubling with radix sort, plus the height array.
         * See https://oi-wiki.org/string/sa/.
         *
         * <p>Rank 0 doubles as the "past the end" sentinel, so input containing the NUL
         * character (code 0) is not supported.
         */
        static class SuffixArray {
            String data;
            int[] rk, old_rk, sa, id, cnt;
            int[] sa0; // 0-based: sa0[i] is the start index of the i-th smallest suffix
            int[] height; // 1-based by rank: LCP of the suffixes ranked r and r - 1

            /**
             * Builds the suffix array and height array of {@code S}.
             *
             * @param S the text; an empty string yields an empty {@code sa0} and a zero height
             */
            public SuffixArray(String S) {
                data = S;
                int n = S.length(), m = 128, p;
                // The counting sort is keyed by char code; widen the key range for non-ASCII text.
                for (int i = 0; i < n; i++) {
                    m = Math.max(m, S.charAt(i));
                }
                rk = new int[n * 2 + 1];
                old_rk = new int[n * 2 + 1];
                id = new int[n + 1];
                cnt = new int[Math.max(n + 1, m + 1)];
                sa = new int[n + 1];
                sa0 = new int[n];
                height = new int[n + 1];
                if (n == 0) {
                    // Nothing to sort; the doubling loop below assumes at least one suffix.
                    return;
                }

                char[] s = (" " + S).toCharArray(); // shift to 1-based positions

                // Initial counting sort by first character.
                for (int i = 1; i <= n; i++) cnt[rk[i] = s[i]]++;
                for (int i = 1; i <= m; i++) cnt[i] += cnt[i - 1];
                for (int i = n; i >= 1; i--) sa[cnt[rk[i]]--] = i;

                for (int w = 1; ; w <<= 1, m = p) { // m = p shrinks the key range to the rank count
                    // id lists suffixes ordered by their second key (rank of suffix i + w);
                    // suffixes without a second key come first.
                    int cur = 0;
                    for (int i = n - w + 1; i <= n; i++) id[++cur] = i;
                    for (int i = 1; i <= n; i++) {
                        if (sa[i] > w) id[++cur] = sa[i] - w;
                    }
                    // Stable counting sort by first key.
                    Arrays.fill(cnt, 0);
                    for (int i = 1; i <= n; i++) cnt[rk[i]]++;
                    for (int i = 1; i <= m; i++) cnt[i] += cnt[i - 1];
                    for (int i = n; i >= 1; i--) sa[cnt[rk[id[i]]]--] = id[i];
                    // Re-rank: equal (first, second) key pairs share a rank.
                    p = 0;
                    System.arraycopy(rk, 0, old_rk, 0, old_rk.length);
                    for (int i = 1; i <= n; i++) {
                        if (old_rk[sa[i]] == old_rk[sa[i - 1]] && old_rk[sa[i] + w] == old_rk[sa[i - 1] + w])
                            rk[sa[i]] = p;
                        else rk[sa[i]] = ++p;
                    }
                    if (p == n) break; // all ranks distinct: fully sorted
                }
                for (int i = 0; i < n; i++) sa0[i] = sa[i + 1] - 1;

                // Height array: walk suffixes in text order and reuse the previous match
                // length minus one as the starting point for the next comparison.
                int k = 0;
                for (int i = 1; i <= n; i++) {
                    if (rk[i] == 1) continue; // the smallest suffix has no predecessor
                    if (k > 0) --k;
                    int j = sa[rk[i] - 1];
                    while (i + k <= n && j + k <= n && s[i + k] == s[j + k]) k++;
                    height[rk[i]] = k;
                }
            }
        }
    }

    /** Counts distinct substrings with a suffix automaton. */
    static class V2 {
        /**
         * Returns the number of distinct non-empty substrings of {@code s}.
         *
         * <p>Each non-root state contributes {@code len - fa.len} substrings.
         *
         * @param s the input string; characters are mapped with {@code c - 'a'} into 26 slots
         * @return the number of distinct substrings
         */
        public int countDistinct(String s) {
            Sam sam = new Sam();
            sam.buildSam(s);
            int ans = 0;
            for (int i = 1; i < sam.nodes.size(); i++) {
                Node o = sam.nodes.get(i);
                ans += o.len - o.fa.len;
            }
            return ans;
        }

        /** A suffix-automaton state. */
        static class Node {
            Node fa; // suffix link; null only for the root
            Node[] ch = new Node[26]; // transitions indexed by c - 'a'
            int len; // length stored for this state

            public Node(Node fa, int len) {
                this.fa = fa;
                this.len = len;
            }
        }

        /** Suffix automaton built online, one character at a time. */
        static class Sam {
            List<Node> nodes = new ArrayList<>(); // nodes.get(0) is the root
            Node last; // state reached by the whole text appended so far

            public Sam() {
                last = new Node(null, 0);
                nodes.add(last);
            }

            /** Creates a state, registers it in {@code nodes} and returns it. */
            private Node newNode(Node fa, int len) {
                Node newNode = new Node(fa, len);
                nodes.add(newNode);
                return newNode;
            }

            /**
             * Extends the automaton with the character {@code c}.
             *
             * @param c the next character; {@code c - 'a'} must index the 26-slot table
             */
            public void append(char c) {
                int index = c - 'a';
                // The suffix link defaults to the root; it stays that way when the walk
                // below runs off the root without finding an existing transition.
                Node last = newNode(nodes.get(0), this.last.len + 1);
                for (Node o = this.last; o != null; o = o.fa) {
                    Node p = o.ch[index];
                    if (p == null) {
                        o.ch[index] = last;
                        continue;
                    }
                    if (o.len + 1 == p.len) {
                        last.fa = p;
                    } else {
                        // Split p: np takes over the shorter lengths. The clone copies the
                        // transition table itself; the target nodes are shared on purpose.
                        Node np = newNode(p.fa, o.len + 1);
                        np.ch = p.ch.clone();
                        p.fa = np;
                        for (; o != null && o.ch[index] == p; o = o.fa) {
                            o.ch[index] = np;
                        }
                        last.fa = np;
                    }
                    break;
                }
                this.last = last;
            }

            /** Appends every character of {@code s} in order. */
            public void buildSam(String s) {
                for (char c : s.toCharArray()) {
                    append(c);
                }
            }
        }
    }
}
(全文完)