跳至主要內容

记忆化搜索

大约 4 分钟

记忆化搜索

题目难度
1269. 停在原地的方案数open in new window困难组合数学
2400. 恰好移动 k 步到达某一位置的方法数目open in new window困难组合数学
DD-2019001. 排列小球open in new window简单组合数学
DD-2019006. 宅男的生活open in new window简单组合数学

注:数位 DP 同样是记忆化搜索,详见对应章节,此处不再赘述。

定义

记忆化搜索是一种通过记录已经遍历过的状态的信息,从而避免对同一状态重复遍历的搜索实现方式。

因为记忆化搜索确保了每个状态只访问一次,它也是一种常见的动态规划实现方式。

1269. 停在原地的方案数

import java.util.Arrays;

public class Solution1269 {
    private static final long MOD = (long) (1e9 + 7);
    private long[][] memo;

    public int numWays(int steps, int arrLen) {
        arrLen = Math.min(arrLen, steps);

        memo = new long[steps + 1][arrLen];
        for (int i = 0; i < steps + 1; i++) {
            Arrays.fill(memo[i], -1);
        }
        return (int) dfs(steps, arrLen, 0);
    }

    // steps:剩余步数(0~steps) pos:位置(0~arrLen-1)
    private long dfs(int steps, int arrLen, int pos) {
        if (steps == 0) {
            return pos == 0 ? 1 : 0;
        }
        if (memo[steps][pos] != -1) {
            return memo[steps][pos];
        }

        // 原地不动
        long res = dfs(steps - 1, arrLen, pos) % MOD;
        // 向右
        if (pos + 1 < arrLen) {
            res += dfs(steps - 1, arrLen, pos + 1);
            res %= MOD;
        }
        // 向左
        if (pos - 1 >= 0) {
            res += dfs(steps - 1, arrLen, pos - 1);
            res %= MOD;
        }
        memo[steps][pos] = res;
        return res;
    }
}

2400. 恰好移动 k 步到达某一位置的方法数目

import java.util.Arrays;

public class Solution2400 {
    private static final long MOD = (long) (1e9 + 7);
    private long[][] memo;

    // 记忆化搜索 时间复杂度 O(k^2)
    public int numberOfWays(int startPos, int endPos, int k) {
        int diff = Math.abs(startPos - endPos);
        if (k - diff >= 0 && (k - diff) % 2 == 0) {
            // 重叠部分
            int overlap = (k - diff) / 2;
            // 向一个方向走 diff+overlap 步, 向另一个方向走 overlap 步
            int a = diff + overlap;
            int b = overlap;

            memo = new long[a + 1][b + 1];
            for (int i = 0; i < a + 1; i++) {
                Arrays.fill(memo[i], -1);
            }
            return (int) dfs(a, b);
        }
        return 0;
    }

    private long dfs(int a, int b) {
        if (a + b == 0) {
            return 1;
        }
        if (memo[a][b] != -1) {
            return memo[a][b];
        }

        long res = 0;
        if (a > 0) {
            res += dfs(a - 1, b);
            res %= MOD;
        }
        if (b > 0) {
            res += dfs(a, b - 1);
            res %= MOD;
        }
        memo[a][b] = res;
        return res;
    }
}

DD-2019001. 排列小球

import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Scanner;

public class DD2019001 {
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in, StandardCharsets.UTF_8);
        int np = scanner.nextInt();
        int nq = scanner.nextInt();
        int nr = scanner.nextInt();
        System.out.println(solve(np, nq, nr));
    }

    private static long[][][][] memo;

    private static String solve(int np, int nq, int nr) {
        memo = new long[3][np + 1][nq + 1][nr + 1];
        for (int i = 0; i < 3; i++) {
            for (int j = 0; j < np + 1; j++) {
                for (int k = 0; k < nq + 1; k++) {
                    Arrays.fill(memo[i][j][k], -1);
                }
            }
        }

        long res = 0;
        if (np > 0) {
            res += dfs(0, np - 1, nq, nr);
        }
        if (nq > 0) {
            res += dfs(1, np, nq - 1, nr);
        }
        if (nr > 0) {
            res += dfs(2, np, nq, nr - 1);
        }
        return String.valueOf(res);
    }

    private static long dfs(int pre, int np, int nq, int nr) {
        if (np + nq + nr == 0) {
            return 1;
        }
        if (memo[pre][np][nq][nr] != -1) {
            return memo[pre][np][nq][nr];
        }

        long res = 0;
        if (pre != 0 && np > 0) {
            res += dfs(0, np - 1, nq, nr);
        }
        if (pre != 1 && nq > 0) {
            res += dfs(1, np, nq - 1, nr);
        }
        if (pre != 2 && nr > 0) {
            res += dfs(2, np, nq, nr - 1);
        }
        memo[pre][np][nq][nr] = res;
        return res;
    }
}

DD-2019006. 宅男的生活

import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Scanner;

public class DD2019006 {
    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in, StandardCharsets.UTF_8);
        int t = scanner.nextInt();
        for (int i = 0; i < t; i++) {
            int n = scanner.nextInt();
            int d = scanner.nextInt();
            System.out.println(solve(n, d));
        }
    }

    // 记忆化搜索
    private static long[][][][][] memo;

    private static String solve(int n, int d) {
        // momo[first][pos][posVal][subStrLen][cnt0]
        memo = new long[2][n + 1][2][n / 2 + 1][n / 2 + 1];
        for (int first = 0; first < 2; first++) {
            for (int pos = 0; pos < n + 1; pos++) {
                for (int posVal = 0; posVal < 2; posVal++) {
                    for (int subStrLen = 0; subStrLen < n / 2 + 1; subStrLen++) {
                        Arrays.fill(memo[first][pos][posVal][subStrLen], -1);
                    }
                }
            }
        }

        long res = 0;
        // 第一位为 0
        res += dfs(n, d, 0, 1, 0, 1, 1);
        // 第一位为 1
        res += dfs(n, d, 1, 1, 1, 1, 0);
        return String.valueOf(res);
    }

    // 总状态 2 * 64 * 2 * 32 * 32 = 262,144
    // first:第一位数字(0/1) pos:当前下标(1~n) posVal:当前状态(0/1) subStrLen:连续子串长度(1~n/2) cnt0:0的数量
    private static long dfs(int n, int d, int first, int pos, int posVal, int subStrLen, int cnt0) {
        if (pos == n) {
            return first == posVal ? 0 : 1;
        }
        // 1 的数量
        int cost1 = pos - cnt0;
        // 必定以 1 结尾
        if (cnt0 == n / 2) {
            return first == 1 ? 0 : 1;
        }
        // 必定以 0 结尾
        if (cost1 == n / 2) {
            return first == 0 ? 0 : 1;
        }
        if (memo[first][pos][posVal][subStrLen][cnt0] != -1) {
            return memo[first][pos][posVal][subStrLen][cnt0];
        }

        long res = 0;
        // 可以继续做同一件事
        if (subStrLen < d) {
            res += dfs(n, d, first, pos + 1, posVal, subStrLen + 1, posVal == 0 ? cnt0 + 1 : cnt0);
        }
        // 做不同的事
        res += dfs(n, d, first, pos + 1, 1 - posVal, 1, posVal == 0 ? cnt0 : cnt0 + 1);
        memo[first][pos][posVal][subStrLen][cnt0] = res;
        return res;
    }
}

(全文完)