跳至主要內容

树形 DP

大约 4 分钟

树形 DP

| 题目 | 难度 | 备注 |
| --- | --- | --- |
| 543. 二叉树的直径 | 简单 | 灵神 b 站 |
| 124. 二叉树中的最大路径和 | 困难 | 灵神 b 站 |
| 2246. 相邻字符不同的最长路径 | 困难 | T4、灵神 b 站 |
| $1245. 树的直径 | 中等 | |
| 2538. 最大价值和与最小价值和的差值 | 困难 | T4 |
| 310. 最小高度树 | 中等 | |
| 834. 树中距离之和 | 困难 | |
| 1617. 统计子树中城市之间最大距离 | 困难 | TODO |

定义

543. 二叉树的直径

public class Solution543 {
    /** Longest path (in edges) found so far; reset on every public call. */
    private int max;

    /**
     * Returns the diameter of a binary tree: the number of edges on the longest path
     * between any two nodes. An empty tree has diameter 0.
     *
     * @param root the tree root, may be {@code null}
     * @return the diameter in edges
     */
    public int diameterOfBinaryTree(TreeNode root) {
        max = 0;
        depth(root);
        return max;
    }

    /**
     * Returns the height of the subtree rooted at {@code node} in edges (-1 for an empty
     * subtree). As a side effect, records the longest path that bends at {@code node}.
     */
    private int depth(TreeNode node) {
        if (node == null) {
            return -1;
        }
        // Length of the longest downward arm through each child, counting the edge to it.
        int leftArm = depth(node.left) + 1;
        int rightArm = depth(node.right) + 1;
        max = Math.max(max, leftArm + rightArm);
        return Math.max(leftArm, rightArm);
    }
}

124. 二叉树中的最大路径和

public class Solution124 {
    /** Best path sum found so far; reset on every public call. */
    private int max;

    /**
     * Returns the maximum sum of node values over all non-empty paths in the tree.
     *
     * @param root the tree root
     * @return the maximum path sum ({@code Integer.MIN_VALUE} if the tree is empty)
     */
    public int maxPathSum(TreeNode root) {
        max = Integer.MIN_VALUE;
        gain(root);
        return max;
    }

    /**
     * Returns the best sum of a downward path that starts at {@code node} (0 for an empty
     * subtree). As a side effect, records the best path that bends at {@code node}.
     */
    private int gain(TreeNode node) {
        if (node == null) {
            return 0;
        }
        // A negative branch never helps, so treat it as "don't extend into it".
        int fromLeft = Math.max(gain(node.left), 0);
        int fromRight = Math.max(gain(node.right), 0);
        int through = node.val + fromLeft + fromRight;
        if (through > max) {
            max = through;
        }
        return node.val + Math.max(fromLeft, fromRight);
    }
}

2246. 相邻字符不同的最长路径

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class Solution2246 {
    /**
     * Returns the number of nodes on the longest path in the tree on which no two
     * adjacent nodes carry the same character.
     *
     * <p>This is an iterative form of the usual post-order DFS. Nodes are listed in BFS
     * order from the root, which puts every parent before its children. They are then
     * processed in reverse, so each child is finished before its parent. A recursive
     * DFS needs one stack frame per tree level, and that can overflow the default JVM
     * thread stack on a path-shaped tree with about 10^5 nodes.
     *
     * @param parent parent[i] is the parent of node i for i >= 1; parent[0] is not read
     * @param s      s.charAt(i) is the character assigned to node i
     * @return the node count of the longest valid path, or 0 when there are no nodes
     */
    public int longestPath(int[] parent, String s) {
        int n = parent.length;
        if (n == 0) {
            return 0;
        }
        List<List<Integer>> children = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            children.add(new ArrayList<>());
        }
        for (int i = 1; i < n; i++) {
            children.get(parent[i]).add(i);
        }

        // BFS from the root: a parent always appears before its children in order[].
        int[] order = new int[n];
        int head = 0;
        int tail = 0;
        order[tail++] = 0;
        while (head < tail) {
            for (int child : children.get(order[head++])) {
                order[tail++] = child;
            }
        }

        // down[x]: edges on the longest valid path that starts at x and goes downward.
        int[] down = new int[n];
        int best = 0; // longest valid path seen so far, in edges
        for (int i = n - 1; i >= 0; i--) {
            int x = order[i];
            int xLen = 0;
            for (int y : children.get(x)) {
                if (s.charAt(x) == s.charAt(y)) {
                    continue; // edge x-y can never lie on a valid path
                }
                int yLen = down[y] + 1;
                // Join the best earlier branch with this branch through x.
                best = Math.max(best, xLen + yLen);
                xLen = Math.max(xLen, yLen);
            }
            down[x] = xLen;
        }
        return best + 1; // convert edges to nodes
    }
}

$1245. 树的直径

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class Solution1245 {
    /** Undirected adjacency lists keyed by node id. */
    private Map<Integer, List<Integer>> adj;
    /** Longest path (in edges) found so far; reset on every public call. */
    private int ans;

    /**
     * Returns the diameter of an undirected tree given as an edge list: the number of
     * edges on its longest simple path. With no edges the result is 0.
     *
     * @param edges each entry {a, b} is an undirected edge between nodes a and b
     * @return the diameter in edges
     */
    public int treeDiameter(int[][] edges) {
        adj = new HashMap<>();
        for (int[] e : edges) {
            int a = e[0];
            int b = e[1];
            adj.computeIfAbsent(a, k -> new ArrayList<>()).add(b);
            adj.computeIfAbsent(b, k -> new ArrayList<>()).add(a);
        }
        ans = 0;
        chain(0, -1);
        return ans;
    }

    /**
     * Returns the number of nodes on the longest downward chain that starts at
     * {@code node}. As a side effect, records the longest path that bends at {@code node}.
     */
    private int chain(int node, int parent) {
        int deepest = 0;
        List<Integer> neighbours = adj.get(node);
        if (neighbours != null) {
            for (int next : neighbours) {
                if (next != parent) {
                    int sub = chain(next, node);
                    // deepest + sub nodes below `node` => exactly that many edges via `node`.
                    ans = Math.max(ans, deepest + sub);
                    deepest = Math.max(deepest, sub);
                }
            }
        }
        return deepest + 1;
    }
}

2538. 最大价值和与最小价值和的差值

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// NOTE(review): the class is named 6294 while this section is problem 2538 (presumably the
// contest-era id); the name is kept unchanged so existing callers still compile.
public class Solution6294 {
    /**
     * Returns the maximum, over all paths in the tree, of the path's price sum minus the
     * price of one of its endpoints.
     *
     * <p>This is an iterative form of the post-order DFS. Nodes are listed in BFS order
     * from node 0 and then processed in reverse, so each child is finished before its
     * parent. That avoids one stack frame per tree level, which could overflow the
     * default JVM thread stack on a path-shaped tree with about 10^5 nodes.
     *
     * @param n     number of nodes, labelled 0..n-1
     * @param edges undirected tree edges {a, b}
     * @param price price[i] is the price of node i
     * @return the best value; 0 when the tree has at most one node
     */
    public long maxOutput(int n, int[][] edges, int[] price) {
        if (n <= 1) {
            return 0;
        }
        List<List<Integer>> adj = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            adj.add(new ArrayList<>());
        }
        for (int[] edge : edges) {
            adj.get(edge[0]).add(edge[1]);
            adj.get(edge[1]).add(edge[0]);
        }

        // BFS from node 0, recording each node's parent; parents precede children in order[].
        int[] order = new int[n];
        int[] fa = new int[n];
        fa[0] = -1;
        int head = 0;
        int tail = 0;
        order[tail++] = 0;
        while (head < tail) {
            int x = order[head++];
            for (int y : adj.get(x)) {
                if (y != fa[x]) {
                    fa[y] = x;
                    order[tail++] = y;
                }
            }
        }

        // withEnd[x]: max price sum of a downward path from x, including its lower endpoint.
        // withoutEnd[x]: same, but with the lower endpoint's price left out.
        long[] withEnd = new long[n];
        long[] withoutEnd = new long[n];
        long ans = 0;
        for (int i = n - 1; i >= 0; i--) {
            int x = order[i];
            long sum1 = price[x];
            long sum2 = 0;
            for (int y : adj.get(x)) {
                if (y == fa[x]) {
                    continue;
                }
                // Join this branch with the best earlier branch; exactly one side drops its endpoint.
                ans = Math.max(ans, Math.max(withEnd[y] + sum2, withoutEnd[y] + sum1));
                sum1 = Math.max(sum1, withEnd[y] + price[x]);
                sum2 = Math.max(sum2, withoutEnd[y] + price[x]);
            }
            withEnd[x] = sum1;
            withoutEnd[x] = sum2;
        }
        return ans;
    }
}

310. 最小高度树

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class Solution310 {
    /** Undirected adjacency lists keyed by node id. */
    private Map<Integer, List<Integer>> adj;
    /** Tree rooted at 0: longest and second-longest downward path (edges) from each node. */
    private int[] down1, down2;
    /** Tree rooted at 0: longest path (edges) from each node whose first step goes to its parent. */
    private int[] up;
    /** The child through which down1[u] is achieved. */
    private int[] bestChild;

    /**
     * Returns, in ascending order, every node that yields a minimum-height tree when
     * chosen as the root. Uses rerooting DP:
     * https://leetcode.cn/problems/minimum-height-trees/solution/by-ac_oier-7xio/
     *
     * @param n     number of nodes, labelled 0..n-1
     * @param edges undirected tree edges {a, b}
     * @return the roots of all minimum-height trees
     */
    public List<Integer> findMinHeightTrees(int n, int[][] edges) {
        adj = new HashMap<>();
        for (int[] e : edges) {
            adj.computeIfAbsent(e[0], k -> new ArrayList<>()).add(e[1]);
            adj.computeIfAbsent(e[1], k -> new ArrayList<>()).add(e[0]);
        }
        down1 = new int[n];
        down2 = new int[n];
        up = new int[n];
        bestChild = new int[n];

        computeDown(0, -1);
        computeUp(0, -1);

        // The height with node i as root is the longer of its downward and upward reach.
        int[] height = new int[n];
        int min = n;
        for (int i = 0; i < n; i++) {
            height[i] = Math.max(down1[i], up[i]);
            min = Math.min(min, height[i]);
        }
        List<Integer> roots = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            if (height[i] == min) {
                roots.add(i);
            }
        }
        return roots;
    }

    /** Fills down1/down2/bestChild for the subtree of u and returns down1[u]. */
    private int computeDown(int u, int parent) {
        for (int v : neighbours(u)) {
            if (v == parent) {
                continue;
            }
            int viaV = computeDown(v, u) + 1;
            if (viaV > down1[u]) {
                down2[u] = down1[u];
                down1[u] = viaV;
                bestChild[u] = v;
            } else if (viaV > down2[u]) {
                down2[u] = viaV;
            }
        }
        return down1[u];
    }

    /** Propagates up[] from u to each of its children, top-down. */
    private void computeUp(int u, int parent) {
        for (int v : neighbours(u)) {
            if (v == parent) {
                continue;
            }
            // From v, step to u and then either go down another branch of u or keep going up.
            // If v itself carries u's longest branch, fall back to u's second-longest one.
            int sideways = (bestChild[u] == v ? down2[u] : down1[u]) + 1;
            up[v] = Math.max(up[v], Math.max(sideways, up[u] + 1));
            computeUp(v, u);
        }
    }

    /** Returns u's adjacency list, or an empty list when u has no recorded edges. */
    private List<Integer> neighbours(int u) {
        List<Integer> list = adj.get(u);
        return list == null ? new ArrayList<>() : list;
    }
}

834. 树中距离之和

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class Solution834 {
    /** Undirected adjacency lists keyed by node id. */
    private Map<Integer, List<Integer>> adj;
    /** size[u]: number of nodes in u's subtree under the current rooting. */
    private int[] size;
    /** dist[u]: sum of distances from u to every node in its subtree under the current rooting. */
    private int[] dist;
    /** answer[u]: sum of distances from u to every node in the tree. */
    private int[] answer;

    /**
     * Returns, for every node, the sum of its distances to all other nodes. Uses rerooting:
     * https://leetcode.cn/problems/sum-of-distances-in-tree/solution/shu-zhong-ju-chi-zhi-he-by-leetcode-solution/
     *
     * @param n     number of nodes, labelled 0..n-1
     * @param edges undirected tree edges {a, b}
     * @return array whose i-th entry is the distance sum for node i
     */
    public int[] sumOfDistancesInTree(int n, int[][] edges) {
        adj = new HashMap<>();
        for (int[] e : edges) {
            link(e[0], e[1]);
            link(e[1], e[0]);
        }
        size = new int[n];
        dist = new int[n];
        answer = new int[n];

        collect(0, -1);
        reroot(0, -1);
        return answer;
    }

    /** Adds the directed half-edge from -> to. */
    private void link(int from, int to) {
        adj.computeIfAbsent(from, k -> new ArrayList<>()).add(to);
    }

    /** Post-order pass with node 0 as root: fills size[] and dist[]. */
    private void collect(int u, int parent) {
        int nodes = 1;
        int total = 0;
        for (int v : neighbours(u)) {
            if (v != parent) {
                collect(v, u);
                nodes += size[v];
                // Each of v's size[v] nodes is one edge farther from u than from v.
                total += dist[v] + size[v];
            }
        }
        size[u] = nodes;
        dist[u] = total;
    }

    /** Pre-order pass: with u currently the root, records its answer and moves the root to each child. */
    private void reroot(int u, int parent) {
        answer[u] = dist[u];
        for (int v : neighbours(u)) {
            if (v == parent) {
                continue;
            }
            int[] saved = {dist[u], dist[v], size[u], size[v]};

            // Detach v's subtree from u ...
            dist[u] -= dist[v] + size[v];
            size[u] -= size[v];
            // ... and hang what remains of u's tree beneath v, making v the root.
            dist[v] += dist[u] + size[u];
            size[v] += size[u];

            reroot(v, u);

            // Undo the move so the next sibling sees u as root again.
            dist[u] = saved[0];
            dist[v] = saved[1];
            size[u] = saved[2];
            size[v] = saved[3];
        }
    }

    /** Returns u's adjacency list, or an empty list when u has no recorded edges. */
    private List<Integer> neighbours(int u) {
        List<Integer> list = adj.get(u);
        return list != null ? list : new ArrayList<>();
    }
}

(全文完)