记忆化搜索
大约 4 分钟
记忆化搜索
- OI Wiki: https://oi-wiki.org/dp/memo/
题目 | 难度 | |
---|---|---|
1269. 停在原地的方案数 | 困难 | 组合数学 |
2400. 恰好移动 k 步到达某一位置的方法数目 | 困难 | 组合数学 |
DD-2019001. 排列小球 | 简单 | 组合数学 |
DD-2019006. 宅男的生活 | 简单 | 组合数学 |
注:数位 DP 同样是记忆化搜索,详见对应章节,此处不再赘述。
定义
记忆化搜索是一种通过记录已经遍历过的状态的信息,从而避免对同一状态重复遍历的搜索实现方式。
因为记忆化搜索确保了每个状态只访问一次,它也是一种常见的动态规划实现方式。
1269. 停在原地的方案数
import java.util.Arrays;
public class Solution1269 {
private static final long MOD = (long) (1e9 + 7);
private long[][] memo;
public int numWays(int steps, int arrLen) {
arrLen = Math.min(arrLen, steps);
memo = new long[steps + 1][arrLen];
for (int i = 0; i < steps + 1; i++) {
Arrays.fill(memo[i], -1);
}
return (int) dfs(steps, arrLen, 0);
}
// steps:剩余步数(0~steps) pos:位置(0~arrLen-1)
private long dfs(int steps, int arrLen, int pos) {
if (steps == 0) {
return pos == 0 ? 1 : 0;
}
if (memo[steps][pos] != -1) {
return memo[steps][pos];
}
// 原地不动
long res = dfs(steps - 1, arrLen, pos) % MOD;
// 向右
if (pos + 1 < arrLen) {
res += dfs(steps - 1, arrLen, pos + 1);
res %= MOD;
}
// 向左
if (pos - 1 >= 0) {
res += dfs(steps - 1, arrLen, pos - 1);
res %= MOD;
}
memo[steps][pos] = res;
return res;
}
}
2400. 恰好移动 k 步到达某一位置的方法数目
import java.util.Arrays;
public class Solution2400 {
private static final long MOD = (long) (1e9 + 7);
private long[][] memo;
// 记忆化搜索 时间复杂度 O(k^2)
public int numberOfWays(int startPos, int endPos, int k) {
int diff = Math.abs(startPos - endPos);
if (k - diff >= 0 && (k - diff) % 2 == 0) {
// 重叠部分
int overlap = (k - diff) / 2;
// 向一个方向走 diff+overlap 步, 向另一个方向走 overlap 步
int a = diff + overlap;
int b = overlap;
memo = new long[a + 1][b + 1];
for (int i = 0; i < a + 1; i++) {
Arrays.fill(memo[i], -1);
}
return (int) dfs(a, b);
}
return 0;
}
private long dfs(int a, int b) {
if (a + b == 0) {
return 1;
}
if (memo[a][b] != -1) {
return memo[a][b];
}
long res = 0;
if (a > 0) {
res += dfs(a - 1, b);
res %= MOD;
}
if (b > 0) {
res += dfs(a, b - 1);
res %= MOD;
}
memo[a][b] = res;
return res;
}
}
DD-2019001. 排列小球
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Scanner;
public class DD2019001 {
public static void main(String[] args) {
Scanner scanner = new Scanner(System.in, StandardCharsets.UTF_8);
int np = scanner.nextInt();
int nq = scanner.nextInt();
int nr = scanner.nextInt();
System.out.println(solve(np, nq, nr));
}
private static long[][][][] memo;
private static String solve(int np, int nq, int nr) {
memo = new long[3][np + 1][nq + 1][nr + 1];
for (int i = 0; i < 3; i++) {
for (int j = 0; j < np + 1; j++) {
for (int k = 0; k < nq + 1; k++) {
Arrays.fill(memo[i][j][k], -1);
}
}
}
long res = 0;
if (np > 0) {
res += dfs(0, np - 1, nq, nr);
}
if (nq > 0) {
res += dfs(1, np, nq - 1, nr);
}
if (nr > 0) {
res += dfs(2, np, nq, nr - 1);
}
return String.valueOf(res);
}
private static long dfs(int pre, int np, int nq, int nr) {
if (np + nq + nr == 0) {
return 1;
}
if (memo[pre][np][nq][nr] != -1) {
return memo[pre][np][nq][nr];
}
long res = 0;
if (pre != 0 && np > 0) {
res += dfs(0, np - 1, nq, nr);
}
if (pre != 1 && nq > 0) {
res += dfs(1, np, nq - 1, nr);
}
if (pre != 2 && nr > 0) {
res += dfs(2, np, nq, nr - 1);
}
memo[pre][np][nq][nr] = res;
return res;
}
}
DD-2019006. 宅男的生活
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Scanner;
public class DD2019006 {
public static void main(String[] args) {
Scanner scanner = new Scanner(System.in, StandardCharsets.UTF_8);
int t = scanner.nextInt();
for (int i = 0; i < t; i++) {
int n = scanner.nextInt();
int d = scanner.nextInt();
System.out.println(solve(n, d));
}
}
// 记忆化搜索
private static long[][][][][] memo;
private static String solve(int n, int d) {
// momo[first][pos][posVal][subStrLen][cnt0]
memo = new long[2][n + 1][2][n / 2 + 1][n / 2 + 1];
for (int first = 0; first < 2; first++) {
for (int pos = 0; pos < n + 1; pos++) {
for (int posVal = 0; posVal < 2; posVal++) {
for (int subStrLen = 0; subStrLen < n / 2 + 1; subStrLen++) {
Arrays.fill(memo[first][pos][posVal][subStrLen], -1);
}
}
}
}
long res = 0;
// 第一位为 0
res += dfs(n, d, 0, 1, 0, 1, 1);
// 第一位为 1
res += dfs(n, d, 1, 1, 1, 1, 0);
return String.valueOf(res);
}
// 总状态 2 * 64 * 2 * 32 * 32 = 262,144
// first:第一位数字(0/1) pos:当前下标(1~n) posVal:当前状态(0/1) subStrLen:连续子串长度(1~n/2) cnt0:0的数量
private static long dfs(int n, int d, int first, int pos, int posVal, int subStrLen, int cnt0) {
if (pos == n) {
return first == posVal ? 0 : 1;
}
// 1 的数量
int cost1 = pos - cnt0;
// 必定以 1 结尾
if (cnt0 == n / 2) {
return first == 1 ? 0 : 1;
}
// 必定以 0 结尾
if (cost1 == n / 2) {
return first == 0 ? 0 : 1;
}
if (memo[first][pos][posVal][subStrLen][cnt0] != -1) {
return memo[first][pos][posVal][subStrLen][cnt0];
}
long res = 0;
// 可以继续做同一件事
if (subStrLen < d) {
res += dfs(n, d, first, pos + 1, posVal, subStrLen + 1, posVal == 0 ? cnt0 + 1 : cnt0);
}
// 做不同的事
res += dfs(n, d, first, pos + 1, 1 - posVal, 1, posVal == 0 ? cnt0 : cnt0 + 1);
memo[first][pos][posVal][subStrLen][cnt0] = res;
return res;
}
}
(全文完)