跳至主要內容

状压 DP

大约 4 分钟

状压 DP

题目难度
$465. 最优账单平衡open in new window 困难
473. 火柴拼正方形open in new window中等
698. 划分为 k 个相等的子集open in new window中等类似 473
1723. 完成所有工作的最短时间open in new window困难
1986. 完成任务的最少工作时间段open in new window中等
2305. 公平分发饼干open in new window中等同 1723

Bitmasking

  • 我们使用 0/1 来表示某物的状态。大多数情况下,1 代表有效状态,0 代表无效状态;
  • 我们可以使用 ((state >> k) & 1) == 1 来获得状态 state 的第 k 位;又或者 (state & (1 << k)) != 0
  • 我们可以使用 (x & y) == x 来检查 x 是否是 y 的子集;
  • 我们可以使用 (x & (x >> 1)) == 0 来检查 x 中是否没有相邻的有效状态;
  • Java 中可以使用 Integer.bitCount(state) 获取状态 state 1 的个数;
  • 或运算 | 可以将一个状态加入到集合中;
  • 异或运算 ^ 可以将一个状态从集合中剔除;

$465. 最优账单平衡

public class Solution465 {
    public int minTransfers(int[][] distributions) {
        int n = 12;

        int[] cnt = new int[n];
        for (int[] distribution : distributions) {
            cnt[distribution[0]] -= distribution[2];
            cnt[distribution[1]] += distribution[2];
        }

        int[] dp = new int[1 << n];
        for (int state = 1; state < (1 << n); state++) {
            int sum = 0;
            for (int k = 0; k < n; k++) {
                if (((state >> k) & 1) == 1) {
                    sum += cnt[k];
                }

                if (sum > 0) {
                    dp[state] = Integer.MAX_VALUE / 2;
                } else {
                    dp[state] = Integer.bitCount(state) - 1;
                    for (int subState = state; subState > 0; subState = (subState - 1) & state) {
                        dp[state] = Math.min(dp[state], dp[state ^ subState] + dp[subState]);
                    }
                }
            }
        }
        return dp[(1 << n) - 1];
    }
}
  • 时间复杂度:O(3^n)
  • 空间复杂度:O(2^n)

473. 火柴拼正方形

import java.util.Arrays;

public class Solution473 {
    public boolean makesquare(int[] matchsticks) {
        int totalLen = Arrays.stream(matchsticks).sum();
        if (totalLen % 4 != 0) {
            return false;
        }
        int edgeLen = totalLen / 4;
        int n = matchsticks.length;

        int[] dp = new int[1 << n];
        Arrays.fill(dp, -1);
        dp[0] = 0;
        for (int state = 0; state < (1 << n); state++) {
            for (int k = 0; k < n; k++) {
                // 第 k 位被选中
                if (((state >> k) & 1) == 1) {

                    // 去掉 state 的第 k 根火柴得到状态 s1
                    int s1 = state & ~(1 << k);
                    if (dp[s1] >= 0 && dp[s1] + matchsticks[k] <= edgeLen) {
                        dp[state] = (dp[s1] + matchsticks[k]) % edgeLen;
                        break;
                    }
                }
            }
        }

        return dp[(1 << n) - 1] == 0;
    }
}

698. 划分为 k 个相等的子集

import java.util.Arrays;

public class Solution698 {
    public boolean canPartitionKSubsets(int[] nums, int k) {
        int n = nums.length;
        int sum = Arrays.stream(nums).sum();
        if (sum % k != 0) {
            return false;
        }
        int partitionSum = sum / k;

        // 状态定义 dp[mask] 为选取状态 mask 时,未填满的组的和
        int[] dp = new int[1 << n];
        Arrays.fill(dp, -1);
        dp[0] = 0;
        for (int state = 0; state < (1 << n); state++) {
            for (int i = 0; i < n; i++) {
                // 第 k 位被选中
                if (((state >> i) & 1) == 1) {
                    // 去掉 state 的第 k 根火柴得到状态 s1
                    int s1 = state & ~(1 << i);
                    if (dp[s1] >= 0 && dp[s1] + nums[i] <= partitionSum) {
                        dp[state] = (dp[s1] + nums[i]) % partitionSum;
                        break;
                    }
                }
            }
        }
        return dp[(1 << n) - 1] == 0;
    }
}

1723. 完成所有工作的最短时间

public class Solution1723 {
    public int minimumTimeRequired(int[] jobs, int k) {
        int n = jobs.length;
        int[] sum = new int[1 << n];
        for (int state = 1; state < (1 << n); state++) {
            // 尾随 0
            int x = Integer.numberOfTrailingZeros(state);
            int y = state - (1 << x);
            sum[state] = sum[y] + jobs[x];
        }

        // dp[i][j] 表示给前 i 个人分配工作,工作分配情况为 j 时,完成所有工作量的最短时间。
        int[][] dp = new int[k][1 << n];

        // 初始状态 dp[0][state] = sum[state]
        System.arraycopy(sum, 0, dp[0], 0, (1 << n));

        // 状态转移 dp[i][j] = min(max(dp[i-1][j的子集的补集],sum[j的子集]))
        for (int i = 1; i < k; i++) {
            for (int state = 0; state < (1 << n); state++) {
                int min = Integer.MAX_VALUE;
                for (int subState = state; subState > 0; subState = (subState - 1) & state) {
                    min = Math.min(min, Math.max(dp[i - 1][state ^ subState], sum[subState]));
                }
                dp[i][state] = min;
            }
        }
        return dp[k - 1][(1 << n) - 1];
    }
}

1986. 完成任务的最少工作时间段

import java.util.Arrays;

public class Solution1986 {
    public int minSessions(int[] tasks, int sessionTime) {
        // 1 <= n <= 14
        int n = tasks.length;

        // 状态压缩 2^14 = 16384
        boolean[] valid = new boolean[1 << n];
        for (int state = 0; state < (1 << n); state++) {
            int totalTime = 0;
            for (int k = 0; k < n; k++) {
                // 第 k 位被选中
                if (((state >> k) & 1) == 1) {
                    totalTime += tasks[k];
                }
            }
            valid[state] = totalTime <= sessionTime;
        }

        // 动态规划
        int[] dp = new int[1 << n];
        Arrays.fill(dp, Integer.MAX_VALUE / 2);
        // 初始状态
        dp[0] = 0;
        for (int state = 1; state < (1 << n); state++) {
            // 使用按位与运算在 O(1) 的时间快速得到下一个(即更小的)mask 的子集
            for (int subState = state; subState > 0; subState = (subState - 1) & state) {
                if (valid[subState]) {
                    // 补集
                    dp[state] = Math.min(dp[state], dp[state ^ subState] + 1);
                }
            }
        }
        return dp[(1 << n) - 1];
    }
}

参考链接

(全文完)