跳至主要內容

最近公共祖先

大约 11 分钟

最近公共祖先

题目 | 难度 | 备注
235. 二叉搜索树的最近公共祖先 | 简单 | BST
236. 二叉树的最近公共祖先 | 中等 |
$1644. 二叉树的最近公共祖先 II | 中等 |
$1650. 二叉树的最近公共祖先 III | 中等 | 相交链表
$1676. 二叉树的最近公共祖先 IV | 中等 |

倍增法(Binary Lifting)

题目 | 难度 | 备注
1483. 树节点的第 K 个祖先 | 困难 |
CF587C | rating 2200 | pa, pamins
CF609E | rating 2100 | pa, pamax
CF733F | rating 2200 | pa, pamax, paeid
CF980E | rating 2200 | pa, only java 8/11 AC, java 17 64bit MLE
CF1702G2 | rating 2000 | pa

定义

最近公共祖先简称 LCA(Lowest Common Ancestor)。两个节点的最近公共祖先,就是这两个点的所有公共祖先里面,离根最远(深度最大)的那个。

1483. 树节点的第 K 个祖先

/**
 * LeetCode 1483: k-th ancestor of a tree node, answered with binary lifting.
 */
public class Solution1483 {
    static class TreeAncestor {
        // pa[x][i] is the 2^i-th ancestor of node x, or -1 if it does not exist.
        private final int[][] pa;

        /**
         * Builds the binary-lifting table.
         *
         * @param n      number of nodes, labelled 0..n-1
         * @param parent parent[x] is the parent of node x; the root's parent is -1
         */
        public TreeAncestor(int n, int[] parent) {
            // Bit length of n: every k <= n fits in m bits.
            int m = 32 - Integer.numberOfLeadingZeros(n);
            pa = new int[n][m];
            for (int x = 0; x < n; x++) {
                pa[x][0] = parent[x];
            }
            // Fill level by level. Level i+1 of x reads level i of its 2^i-th ancestor,
            // whose label may be larger than x (parent[x] < x is not guaranteed), so the
            // whole of level i must be complete before level i+1 is started.
            for (int i = 0; i + 1 < m; i++) {
                for (int x = 0; x < n; x++) {
                    int fa = pa[x][i];
                    pa[x][i + 1] = fa < 0 ? -1 : pa[fa][i];
                }
            }
        }

        /**
         * Returns the k-th ancestor of {@code node}, or -1 if it does not exist.
         *
         * @param node starting node
         * @param k    number of levels to climb (k >= 0)
         * @return the ancestor k levels above node, or -1
         */
        public int getKthAncestor(int node, int k) {
            // k's binary length: climb 2^i levels for every set bit i.
            int m = 32 - Integer.numberOfLeadingZeros(k);
            for (int i = 0; i < m && node >= 0; i++) {
                if ((k >> i & 1) == 1) {
                    if (i >= pa[node].length) {
                        // k >= 2^i >= 2^(bit length of n) > n - 1 >= any depth.
                        return -1;
                    }
                    node = pa[node][i];
                }
            }
            return node;
        }
    }
}

CF587C

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.StringTokenizer;
import java.util.stream.Collectors;

/**
 * Codeforces 587C: for each query (v, u, a) prints how many person ids lie on the tree
 * path v-u (capped at a), followed by the smallest such ids in increasing order.
 *
 * <p>Binary lifting: pa[v][i] is the 2^i-th ancestor of v (or -1), and pamins[v][i]
 * holds the (at most MAX_A) smallest person ids living on the 2^i vertices starting at v
 * and going up, excluding pa[v][i] itself.
 */
public class CF587C {
    static int n, m, q;
    static int[] c;
    static int[][] uva;

    public static void main(String[] args) {
        FastReader scanner = new FastReader();
        n = scanner.nextInt();
        m = scanner.nextInt();
        q = scanner.nextInt();
        adj = new HashMap<>();
        for (int i = 0; i < n - 1; i++) {
            int x = scanner.nextInt() - 1;
            int y = scanner.nextInt() - 1;
            adj.computeIfAbsent(x, key -> new ArrayList<>()).add(y);
            adj.computeIfAbsent(y, key -> new ArrayList<>()).add(x);
        }
        // c[i] is the (1-based) city of person i + 1.
        c = new int[m];
        for (int i = 0; i < m; i++) {
            c[i] = scanner.nextInt();
        }
        // Each query: 0-based v, 0-based u, cap a.
        uva = new int[q][3];
        for (int i = 0; i < q; i++) {
            uva[i][0] = scanner.nextInt() - 1;
            uva[i][1] = scanner.nextInt() - 1;
            uva[i][2] = scanner.nextInt();
        }
        System.out.println(solve());
    }

    // 2^17 = 131072 > 1e5
    private static final int mx = 17;
    // Largest number of ids any precomputed list keeps (the largest cap a query may use).
    private static final int MAX_A = 10;
    private static Map<Integer, List<Integer>> adj;
    private static int[] depth;
    private static int[][] pa;
    static List<Integer>[][] pamins;

    /** Builds the lifting tables and answers every query, one answer per line. */
    @SuppressWarnings("unchecked")
    private static String solve() {
        depth = new int[n];
        pa = new int[n][mx];
        pamins = new ArrayList[n][mx];
        for (int i = 0; i < n; i++) {
            pamins[i][0] = new ArrayList<>();
        }
        // People are visited in increasing id order, so every list stays sorted.
        for (int i = 0; i < m; i++) {
            int v = c[i] - 1;
            if (pamins[v][0].size() < MAX_A) {
                pamins[v][0].add(i + 1);
            }
        }
        dfs(0, -1, 0);
        for (int i = 0; i + 1 < mx; i++) {
            for (int v = 0; v < n; v++) {
                int p = pa[v][i];
                if (p != -1) {
                    pa[v][i + 1] = pa[p][i];
                    pamins[v][i + 1] = merge(pamins[v][i], pamins[p][i], MAX_A);
                } else {
                    pa[v][i + 1] = -1;
                    pamins[v][i + 1] = new ArrayList<>();
                }
            }
        }

        List<String> resList = new ArrayList<>(uva.length);
        for (int[] tuple : uva) {
            resList.add(answerQuery(tuple[0], tuple[1], tuple[2]));
        }
        return String.join(System.lineSeparator(), resList);
    }

    /** Formats "count id1 id2 ..." for the smallest (at most a) ids on the path v-w. */
    private static String answerQuery(int v, int w, int a) {
        if (depth[v] > depth[w]) {
            int tmp = v;
            v = w;
            w = tmp;
        }
        List<Integer> mins = new ArrayList<>();
        // Lift w to v's depth, collecting the ids of every skipped segment.
        int diff = depth[w] - depth[v];
        for (int i = 0; i < mx; i++) {
            if ((diff >> i & 1) == 1) {
                mins = merge(mins, pamins[w][i], a);
                w = pa[w][i];
            }
        }
        if (w != v) {
            // Lift both while their ancestors differ; they end as children of the LCA.
            for (int i = mx - 1; i >= 0; i--) {
                if (pa[v][i] != pa[w][i]) {
                    mins = merge(mins, merge(pamins[v][i], pamins[w][i], a), a);
                    v = pa[v][i];
                    w = pa[w][i];
                }
            }
            mins = merge(mins, merge(pamins[v][0], pamins[w][0], a), a);
            v = pa[v][0];
        }
        // v is now the LCA, which none of the segments above included.
        mins = merge(mins, pamins[v][0], a);

        StringBuilder sb = new StringBuilder();
        sb.append(mins.size());
        for (int id : mins) {
            sb.append(' ').append(id);
        }
        return sb.toString();
    }

    /** Records parent (pa[x][0]) and depth of every vertex of the subtree rooted at x. */
    private static void dfs(int x, int fa, int d) {
        pa[x][0] = fa;
        depth[x] = d;
        List<Integer> neighbours = adj.get(x);
        if (neighbours == null) {
            return;
        }
        for (int y : neighbours) {
            if (y != fa) {
                dfs(y, x, d + 1);
            }
        }
    }

    /**
     * Merges two ascending id lists, keeping only the smallest {@code cap} values.
     *
     * @param a   ascending list
     * @param b   ascending list
     * @param cap maximum size of the result
     * @return a new ascending list of at most cap elements
     */
    private static List<Integer> merge(List<Integer> a, List<Integer> b, int cap) {
        int limit = Math.min(a.size() + b.size(), cap);
        List<Integer> res = new ArrayList<>(limit);
        int i = 0, j = 0;
        while (res.size() < limit) {
            // When b is exhausted, i < a.size() holds because i + j < a.size() + b.size().
            if (j == b.size() || (i < a.size() && a.get(i) < b.get(j))) {
                res.add(a.get(i++));
            } else {
                res.add(b.get(j++));
            }
        }
        return res;
    }

    private static class FastReader {
        private final BufferedReader bufferedReader;
        private StringTokenizer stringTokenizer;

        public FastReader() {
            bufferedReader = new BufferedReader(new InputStreamReader(System.in));
        }

        /** Returns the next whitespace-separated token, reading further lines as needed. */
        public String next() {
            while (stringTokenizer == null || !stringTokenizer.hasMoreElements()) {
                String line;
                try {
                    line = bufferedReader.readLine();
                } catch (IOException e) {
                    // Retrying a failed read would loop forever; surface the error instead.
                    throw new RuntimeException(e);
                }
                if (line == null) {
                    throw new IllegalStateException("unexpected end of input");
                }
                stringTokenizer = new StringTokenizer(line);
            }
            return stringTokenizer.nextToken();
        }

        public int nextInt() {
            return Integer.parseInt(next());
        }

        public long nextLong() {
            return Long.parseLong(next());
        }

        public double nextDouble() {
            return Double.parseDouble(next());
        }

        /** Returns the rest of the current line, or the next line if it is used up. */
        public String nextLine() {
            try {
                // stringTokenizer is still null when no token has been read yet.
                if (stringTokenizer != null && stringTokenizer.hasMoreTokens()) {
                    return stringTokenizer.nextToken("\n");
                }
                return bufferedReader.readLine();
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
    }
}

CF609E

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.StringTokenizer;
import java.util.stream.Collectors;

/**
 * Codeforces 609E: for every edge e = (x, y, w) prints the weight of the lightest
 * spanning tree containing e.
 *
 * <p>The result is the MST weight s if e is an MST edge. Otherwise it is
 * s + w - (heaviest MST edge on the tree path x-y). The heaviest edge is found with
 * binary lifting: pa[v][i] is the 2^i-th ancestor of v, and pamax[v][i] is the largest
 * edge weight on those 2^i upward steps.
 */
public class CF609E {
    static int n, m;
    static Edge[] es;

    public static void main(String[] args) {
        FastReader scanner = new FastReader();
        n = scanner.nextInt();
        m = scanner.nextInt();
        es = new Edge[m];
        for (int i = 0; i < m; i++) {
            int u = scanner.nextInt() - 1;
            int v = scanner.nextInt() - 1;
            int w = scanner.nextInt();
            es[i] = new Edge(u, v, w, i);
        }
        System.out.println(solve());
    }

    // 2^18 = 262144 > 2e5
    private static final int mx = 18;
    private static Map<Integer, List<int[]>> adj;
    private static int[][] pa;
    private static int[][] pamax;
    private static int[] depth;

    /** Runs Kruskal, builds the lifting tables and returns one answer per input edge. */
    private static String solve() {
        // Arrays.sort on objects is stable: equal weights keep their input order.
        Arrays.sort(es, Comparator.comparingInt(o -> o.wt));

        adj = new HashMap<>(n);
        long s = 0;
        DSU dsu = new DSU(n);
        for (Edge e : es) {
            // union(...) == false means two components were merged, i.e. e joins the MST.
            if (!dsu.union(e.x, e.y)) {
                s += e.wt;
                adj.computeIfAbsent(e.x, key -> new ArrayList<>()).add(new int[]{e.y, e.wt});
                adj.computeIfAbsent(e.y, key -> new ArrayList<>()).add(new int[]{e.x, e.wt});
                e.inMst = true;
            }
        }

        depth = new int[n];
        pa = new int[n][mx];
        pamax = new int[n][mx];
        dfs(0, -1, 0);

        for (int i = 0; i + 1 < mx; i++) {
            for (int v = 0; v < n; v++) {
                int p = pa[v][i];
                if (p != -1) {
                    pa[v][i + 1] = pa[p][i];
                    pamax[v][i + 1] = Math.max(pamax[v][i], pamax[p][i]);
                } else {
                    pa[v][i + 1] = -1;
                    pamax[v][i + 1] = 0;
                }
            }
        }

        long[] ans = new long[m];
        Arrays.fill(ans, s);
        for (Edge e : es) {
            if (!e.inMst) {
                // Swap e in for the heaviest tree edge on the cycle it closes.
                ans[e.id] += e.wt - maxWt(e.x, e.y);
            }
        }
        return Arrays.stream(ans).mapToObj(String::valueOf).collect(Collectors.joining(System.lineSeparator()));
    }

    /** Returns the largest MST edge weight on the tree path v-w (0 if v == w). */
    private static int maxWt(int v, int w) {
        int mxWt = 0;
        if (depth[v] > depth[w]) {
            int tmp = v;
            v = w;
            w = tmp;
        }
        // Lift w to v's depth.
        int diff = depth[w] - depth[v];
        for (int i = 0; i < mx; i++) {
            if ((diff >> i & 1) == 1) {
                mxWt = Math.max(mxWt, pamax[w][i]);
                w = pa[w][i];
            }
        }
        if (v == w) {
            return mxWt;
        }
        // Lift both to the children of their LCA.
        for (int i = mx - 1; i >= 0; i--) {
            if (pa[v][i] != pa[w][i]) {
                mxWt = Math.max(mxWt, Math.max(pamax[v][i], pamax[w][i]));
                v = pa[v][i];
                w = pa[w][i];
            }
        }
        return Math.max(mxWt, Math.max(pamax[v][0], pamax[w][0]));
    }

    /** Records parent, depth and parent-edge weight for the subtree rooted at x. */
    private static void dfs(int x, int fa, int d) {
        pa[x][0] = fa;
        depth[x] = d;
        List<int[]> neighbours = adj.get(x);
        if (neighbours == null) {
            return;
        }
        for (int[] tuple : neighbours) {
            int y = tuple[0], wt = tuple[1];
            if (y == fa) {
                continue;
            }
            pamax[y][0] = wt;
            dfs(y, x, d + 1);
        }
    }

    /** Union-find with path compression. */
    private static class DSU {
        int[] fa;

        public DSU(int n) {
            fa = new int[n];
            for (int i = 0; i < n; i++) {
                fa[i] = i;
            }
        }

        int find(int x) {
            if (x != fa[x]) {
                fa[x] = find(fa[x]);
            }
            return fa[x];
        }

        /**
         * Joins the sets of p and q. Note the inverted convention: returns true when p and
         * q were ALREADY connected (nothing merged), false when a merge happened.
         */
        boolean union(int p, int q) {
            int rootP = find(p);
            int rootQ = find(q);
            if (rootP == rootQ) {
                return true;
            }
            fa[rootQ] = rootP;
            return false;
        }
    }

    private static class Edge {
        int x, y, wt, id;
        // Set by solve() when Kruskal puts this edge in the MST.
        boolean inMst;

        public Edge(int x, int y, int wt, int id) {
            this.x = x;
            this.y = y;
            this.wt = wt;
            this.id = id;
        }
    }

    private static class FastReader {
        private final BufferedReader bufferedReader;
        private StringTokenizer stringTokenizer;

        public FastReader() {
            bufferedReader = new BufferedReader(new InputStreamReader(System.in));
        }

        /** Returns the next whitespace-separated token, reading further lines as needed. */
        public String next() {
            while (stringTokenizer == null || !stringTokenizer.hasMoreElements()) {
                String line;
                try {
                    line = bufferedReader.readLine();
                } catch (IOException e) {
                    // Retrying a failed read would loop forever; surface the error instead.
                    throw new RuntimeException(e);
                }
                if (line == null) {
                    throw new IllegalStateException("unexpected end of input");
                }
                stringTokenizer = new StringTokenizer(line);
            }
            return stringTokenizer.nextToken();
        }

        public int nextInt() {
            return Integer.parseInt(next());
        }

        public long nextLong() {
            return Long.parseLong(next());
        }

        public double nextDouble() {
            return Double.parseDouble(next());
        }

        /** Returns the rest of the current line, or the next line if it is used up. */
        public String nextLine() {
            try {
                // stringTokenizer is still null when no token has been read yet.
                if (stringTokenizer != null && stringTokenizer.hasMoreTokens()) {
                    return stringTokenizer.nextToken("\n");
                }
                return bufferedReader.readLine();
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
    }
}

CF733F

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.StringTokenizer;

/**
 * Codeforces 733F: an edge's weight can be lowered by 1 for every c of its cost units
 * spent from a budget of `money`. Prints the minimum achievable spanning tree weight and
 * the chosen tree's edges with their final weights.
 *
 * <p>The whole budget is spent on the single edge with the best effect. There are two
 * options:
 * <ul>
 *   <li>An MST edge: the decrease is money / c.</li>
 *   <li>A non-MST edge: it replaces the heaviest MST edge on its cycle. The decrease is
 *       mxWt - (wt - money / c).</li>
 * </ul>
 * The heaviest MST edge on a path (and its id) is found with binary lifting over
 * pa / pamax / paeid.
 */
public class CF733F {
    static int n, m;
    static Edge[] es;
    static int money;

    public static void main(String[] args) {
        FastReader scanner = new FastReader();
        n = scanner.nextInt();
        m = scanner.nextInt();
        es = new Edge[m];
        for (int i = 0; i < m; i++) {
            es[i] = new Edge();
            es[i].wt = scanner.nextInt();
            es[i].id = i;
        }
        for (int i = 0; i < m; i++) {
            es[i].c = scanner.nextInt();
        }
        for (int i = 0; i < m; i++) {
            es[i].x = scanner.nextInt() - 1;
            es[i].y = scanner.nextInt() - 1;
        }
        money = scanner.nextInt();
        System.out.println(solve());
    }

    // 2^18 = 262144 > 2e5
    private static final int mx = 18;
    private static Map<Integer, List<int[]>> adj;
    private static int[][] pa;
    // pamax[v][i] / paeid[v][i]: largest weight (and its edge id) on the 2^i steps above v.
    private static int[][] pamax;
    private static int[][] paeid;
    private static int[] depth;

    /** Builds the MST, picks the best edge to spend the budget on and formats the answer. */
    private static String solve() {
        Arrays.sort(es, Comparator.comparingInt(o -> o.wt));

        long s = 0;
        adj = new HashMap<>();
        DSU dsu = new DSU(n);
        for (Edge e : es) {
            int x = e.x, y = e.y, wt = e.wt, eid = e.id;
            // union(...) == false means two components were merged, i.e. e joins the MST.
            if (!dsu.union(x, y)) {
                s += wt;
                adj.computeIfAbsent(x, key -> new ArrayList<>()).add(new int[]{y, wt, eid});
                adj.computeIfAbsent(y, key -> new ArrayList<>()).add(new int[]{x, wt, eid});
                // Negated weight marks an MST edge (assumes every input weight is >= 1).
                e.wt = -e.wt;
            }
        }

        pa = new int[n][mx];
        pamax = new int[n][mx];
        paeid = new int[n][mx];
        depth = new int[n];
        dfs(0, -1, 0);

        for (int i = 0; i + 1 < mx; i++) {
            for (int v = 0; v < n; v++) {
                int p = pa[v][i];
                if (p != -1) {
                    pa[v][i + 1] = pa[p][i];
                    if (pamax[v][i] > pamax[p][i]) {
                        pamax[v][i + 1] = pamax[v][i];
                        paeid[v][i + 1] = paeid[v][i];
                    } else {
                        pamax[v][i + 1] = pamax[p][i];
                        paeid[v][i + 1] = paeid[p][i];
                    }
                } else {
                    pa[v][i + 1] = -1;
                }
            }
        }

        // mxDec: best total decrease; cur: edge receiving the budget;
        // ori: MST edge it replaces (-1 when cur is itself an MST edge).
        int mxDec = -1, ori = -1, cur = 0;
        for (Edge e : es) {
            int dec = money / e.c;
            if (e.wt > 0) {
                int[] tuple = maxWt(e.x, e.y);
                int mxWt = tuple[0], eid = tuple[1];
                dec = mxWt - (e.wt - dec);
                if (mxDec < dec) {
                    mxDec = dec;
                    ori = eid;
                    cur = e.id;
                }
            } else {
                if (mxDec < dec) {
                    mxDec = dec;
                    ori = -1;
                    cur = e.id;
                }
            }
        }

        StringBuilder ans = new StringBuilder();
        ans.append(s - mxDec).append(System.lineSeparator());
        for (Edge e : es) {
            // Skip the replaced MST edge and every non-MST edge except the chosen one.
            if (e.id == ori || e.id != cur && e.wt > 0) {
                continue;
            }
            int wt = e.wt;
            if (wt < 0) {
                wt = -wt;
            }
            if (e.id == cur) {
                wt -= money / e.c;
            }
            ans.append(e.id + 1).append(" ").append(wt).append(System.lineSeparator());
        }
        return ans.toString();
    }

    /** Returns {largest MST edge weight on path v-w, id of that edge}. */
    private static int[] maxWt(int v, int w) {
        int mxWt = 0, eid = 0;
        if (depth[v] > depth[w]) {
            int tmp = v;
            v = w;
            w = tmp;
        }
        for (int i = 0; i < mx; i++) {
            if (((depth[w] - depth[v]) >> i & 1) == 1) {
                if (mxWt < pamax[w][i]) {
                    mxWt = pamax[w][i];
                    eid = paeid[w][i];
                }
                w = pa[w][i];
            }
        }
        if (v == w) {
            return new int[]{mxWt, eid};
        }
        for (int i = mx - 1; i >= 0; i--) {
            if (pa[v][i] != pa[w][i]) {
                if (mxWt < pamax[v][i]) {
                    mxWt = pamax[v][i];
                    eid = paeid[v][i];
                }
                if (mxWt < pamax[w][i]) {
                    mxWt = pamax[w][i];
                    eid = paeid[w][i];
                }
                v = pa[v][i];
                w = pa[w][i];
            }
        }
        if (mxWt < pamax[v][0]) {
            mxWt = pamax[v][0];
            eid = paeid[v][0];
        }
        if (mxWt < pamax[w][0]) {
            mxWt = pamax[w][0];
            eid = paeid[w][0];
        }
        return new int[]{mxWt, eid};
    }

    /** Records parent, depth, parent-edge weight and id for the subtree rooted at x. */
    private static void dfs(int x, int fa, int d) {
        pa[x][0] = fa;
        depth[x] = d;
        List<int[]> neighbours = adj.get(x);
        if (neighbours == null) {
            return;
        }
        for (int[] tuple : neighbours) {
            int y = tuple[0], wt = tuple[1], eid = tuple[2];
            if (y == fa) {
                continue;
            }
            pamax[y][0] = wt;
            paeid[y][0] = eid;
            dfs(y, x, d + 1);
        }
    }

    /** Union-find with path compression. */
    private static class DSU {
        int[] fa;

        public DSU(int n) {
            fa = new int[n];
            for (int i = 0; i < n; i++) {
                fa[i] = i;
            }
        }

        int find(int x) {
            if (x != fa[x]) {
                fa[x] = find(fa[x]);
            }
            return fa[x];
        }

        /** Returns true if p and q were ALREADY connected; false after merging them. */
        boolean union(int p, int q) {
            int rootP = find(p);
            int rootQ = find(q);
            if (rootP == rootQ) {
                return true;
            }
            fa[rootQ] = rootP;
            return false;
        }
    }

    private static class Edge {
        int x, y, wt, c, id;
    }

    private static class FastReader {
        private final BufferedReader bufferedReader;
        private StringTokenizer stringTokenizer;

        public FastReader() {
            bufferedReader = new BufferedReader(new InputStreamReader(System.in));
        }

        /** Returns the next whitespace-separated token, reading further lines as needed. */
        public String next() {
            while (stringTokenizer == null || !stringTokenizer.hasMoreElements()) {
                String line;
                try {
                    line = bufferedReader.readLine();
                } catch (IOException e) {
                    // Retrying a failed read would loop forever; surface the error instead.
                    throw new RuntimeException(e);
                }
                if (line == null) {
                    throw new IllegalStateException("unexpected end of input");
                }
                stringTokenizer = new StringTokenizer(line);
            }
            return stringTokenizer.nextToken();
        }

        public int nextInt() {
            return Integer.parseInt(next());
        }

        public long nextLong() {
            return Long.parseLong(next());
        }

        public double nextDouble() {
            return Double.parseDouble(next());
        }

        /** Returns the rest of the current line, or the next line if it is used up. */
        public String nextLine() {
            try {
                // stringTokenizer is still null when no token has been read yet.
                if (stringTokenizer != null && stringTokenizer.hasMoreTokens()) {
                    return stringTokenizer.nextToken("\n");
                }
                return bufferedReader.readLine();
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
    }
}

CF980E

import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.Arrays;

/**
 * Codeforces 980E: remove k vertices so that the rest stays connected and the sum of
 * 2^label over kept vertices is maximal. Prints the removed labels in increasing order.
 *
 * <p>Since 2^i exceeds the sum of all smaller powers, labels are decided greedily from
 * the largest. The tree is rooted at n, which is always kept. Vertex i is kept together
 * with its whole chain of not-yet-kept ancestors if that chain fits in the remaining
 * budget. Binary lifting finds the topmost not-yet-kept ancestor in O(log n).
 */
public class CF980E {
    static int n, k;

    public static void main(String[] args) throws IOException {
        // NOTE(review): the original version (int[n][19] table, recursive DFS) was reported
        // to pass only on Java 8/11 and hit MLE on 64-bit Java 17. This version stores the
        // table as int[mx][n] and uses an iterative DFS; re-verify on the judge.
        Reader scanner = new Reader();
        n = scanner.nextInt();
        k = scanner.nextInt();
        Arrays.fill(he, -1);
        for (int i = 0; i < n - 1; i++) {
            int x = scanner.nextInt() - 1;
            int y = scanner.nextInt() - 1;
            add(x, y);
            add(y, x);
        }
        solve();
    }

    static int N = (int) (1e6), M = N * 2;
    // Linked adjacency list ("链式前向星"): he[u] = first edge of u (-1 if none),
    // ne[e] = next edge with the same source, ed[e] = destination of edge e.
    static int[] he = new int[N], ne = new int[M], ed = new int[M];
    static int idx = 0;

    /** Appends the directed edge u -> v. */
    static void add(int u, int v) {
        ed[idx] = v;
        ne[idx] = he[u];
        he[u] = idx;
        idx++;
    }

    // 2^20 = 1048576 > 1e6: jumps of 2^0..2^19 can climb any depth below 2^20.
    private static final int mx = 20;
    // pa[j][v] is the 2^j-th ancestor of v, or -1. Level-major layout keeps the table as
    // mx large arrays instead of n tiny ones (much less per-array overhead).
    private static int[][] pa;
    private static int[] depth;

    /** Runs the greedy and prints the removed vertices (1-based, increasing). */
    private static void solve() {
        pa = new int[mx][n];
        depth = new int[n];
        buildParents(n - 1);

        for (int j = 0; j + 1 < mx; j++) {
            int[] cur = pa[j];
            int[] nxt = pa[j + 1];
            for (int v = 0; v < n; v++) {
                int p = cur[v];
                nxt[v] = p == -1 ? -1 : cur[p];
            }
        }

        boolean[] save = new boolean[n];
        save[n - 1] = true;
        // Number of vertices that may still be kept besides the root.
        int left = n - 1 - k;
        for (int i = n - 2; i >= 0; i--) {
            if (save[i]) {
                continue;
            }
            // Kept vertices form a connected subtree containing the root, so the unkept
            // ancestors of i form a contiguous chain: lift to its topmost element.
            int v = i;
            for (int j = mx - 1; j >= 0; j--) {
                int p = pa[j][v];
                if (p != -1 && !save[p]) {
                    v = p;
                }
            }
            int d = depth[i] - depth[v] + 1;
            if (d <= left) {
                left -= d;
                for (v = i; !save[v]; v = pa[0][v]) {
                    save[v] = true;
                }
            }
        }

        StringBuilder ans = new StringBuilder();
        for (int i = 0; i < n; i++) {
            if (!save[i]) {
                ans.append(i + 1).append(" ");
            }
        }
        System.out.println(ans);
    }

    /**
     * Fills pa[0] (parent, -1 for the root) and depth for every vertex, rooting the tree
     * at {@code root}. Iterative so that a 1e6-deep path cannot overflow the call stack.
     */
    private static void buildParents(int root) {
        int[] stack = new int[n];
        int top = 0;
        pa[0][root] = -1;
        depth[root] = 0;
        stack[top++] = root;
        while (top > 0) {
            int x = stack[--top];
            for (int e = he[x]; e != -1; e = ne[e]) {
                int y = ed[e];
                if (y == pa[0][x]) {
                    continue;
                }
                pa[0][y] = x;
                depth[y] = depth[x] + 1;
                stack[top++] = y;
            }
        }
    }

    private static class Reader {
        private final int BUFFER_SIZE = 1 << 16;
        private final DataInputStream dataInputStream;
        private final byte[] buffer;
        private int bufferPointer, bytesRead;

        public Reader() {
            dataInputStream = new DataInputStream(System.in);
            buffer = new byte[BUFFER_SIZE];
            bufferPointer = bytesRead = 0;
        }

        public Reader(String file_name) throws IOException {
            dataInputStream = new DataInputStream(new FileInputStream(file_name));
            buffer = new byte[BUFFER_SIZE];
            bufferPointer = bytesRead = 0;
        }

        /** Reads the next non-empty line (without the trailing '\n'). */
        public String readLine() throws IOException {
            byte[] buf = new byte[64];
            int cnt = 0, c;
            while ((c = read()) != -1) {
                if (c == '\n') {
                    if (cnt != 0) {
                        break;
                    } else {
                        continue;
                    }
                }
                if (cnt == buf.length) {
                    // Grow instead of overflowing on lines longer than the initial buffer.
                    buf = Arrays.copyOf(buf, cnt * 2);
                }
                buf[cnt++] = (byte) c;
            }
            return new String(buf, 0, cnt);
        }

        public int nextInt() throws IOException {
            int ret = 0;
            byte c = read();
            while (c <= ' ') {
                c = read();
            }
            boolean neg = (c == '-');
            if (neg) {
                c = read();
            }
            do {
                ret = ret * 10 + c - '0';
            } while ((c = read()) >= '0' && c <= '9');
            if (neg) {
                return -ret;
            }
            return ret;
        }

        public long nextLong() throws IOException {
            long ret = 0;
            byte c = read();
            while (c <= ' ') {
                c = read();
            }
            boolean neg = (c == '-');
            if (neg) {
                c = read();
            }
            do {
                ret = ret * 10 + c - '0';
            } while ((c = read()) >= '0' && c <= '9');
            if (neg) {
                return -ret;
            }
            return ret;
        }

        public double nextDouble() throws IOException {
            double ret = 0, div = 1;
            byte c = read();
            while (c <= ' ') {
                c = read();
            }
            boolean neg = (c == '-');
            if (neg) {
                c = read();
            }
            do {
                ret = ret * 10 + c - '0';
            } while ((c = read()) >= '0' && c <= '9');
            if (c == '.') {
                while ((c = read()) >= '0' && c <= '9') {
                    ret += (c - '0') / (div *= 10);
                }
            }
            if (neg) {
                return -ret;
            }
            return ret;
        }

        private void fillBuffer() throws IOException {
            bytesRead = dataInputStream.read(buffer, bufferPointer = 0, BUFFER_SIZE);
            if (bytesRead == -1) {
                // EOF: the next read() returns -1.
                buffer[0] = -1;
            }
        }

        private byte read() throws IOException {
            if (bufferPointer == bytesRead) {
                fillBuffer();
            }
            return buffer[bufferPointer++];
        }

        public void close() throws IOException {
            dataInputStream.close();
        }
    }
}

CF1702G2

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Codeforces 1702G2: for each query set of vertices, decides whether a single simple path
 * passes through all of them.
 *
 * <p>The set is sorted by depth, deepest first. p[0] must be one endpoint of the path.
 * The other endpoint p[f] is the deepest vertex that is not an ancestor of p[0]. The answer
 * is YES iff two conditions hold:
 * <ul>
 *   <li>every vertex is an ancestor of p[0] or of p[f];</li>
 *   <li>lca(p[0], p[f]) is no deeper than the shallowest query vertex.</li>
 * </ul>
 */
public class CF1702G {
    public static void main(String[] args) throws IOException {
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(System.in, StandardCharsets.UTF_8));
             BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(System.out, StandardCharsets.UTF_8))) {
            int n = Integer.parseInt(reader.readLine().trim());
            String[] lineNs = new String[n - 1];
            for (int i = 0; i < n - 1; i++) {
                lineNs[i] = reader.readLine();
            }
            int q = Integer.parseInt(reader.readLine().trim());
            String[] lineQs = new String[q];
            for (int i = 0; i < q; i++) {
                // The line holding k is skipped: k equals the token count of the next line.
                reader.readLine();
                lineQs[i] = reader.readLine();
            }
            List<String> res = solution(n, lineNs, q, lineQs);
            for (String re : res) {
                writer.write(re.concat(System.lineSeparator()));
            }
        }
    }

    // d[v]: depth of v (root = 0); up[v][i]: 2^i-th ancestor of v (the root maps to itself).
    static int[] d;
    static int sz;
    static int[][] up;
    static Map<Integer, List<Integer>> adj;

    /**
     * Answers every query.
     *
     * @param n      number of vertices
     * @param lineNs n - 1 lines "u v" (1-based edge endpoints)
     * @param q      number of queries
     * @param lineQs q lines of 1-based vertex labels
     * @return "YES" / "NO" per query
     */
    private static List<String> solution(int n, String[] lineNs, int q, String[] lineQs) {
        sz = 0;
        while ((1 << sz) < n) {
            sz++;
        }
        d = new int[n];
        // precalc(0, 0) computes d[0] = d[0] + 1, so the root starts from -1 to get depth 0.
        Arrays.fill(d, -1);
        up = new int[n][sz + 1];
        adj = new HashMap<>();

        for (String lineN : lineNs) {
            String[] line = tokens(lineN);
            int u = Integer.parseInt(line[0]) - 1;
            int v = Integer.parseInt(line[1]) - 1;
            adj.computeIfAbsent(u, key -> new ArrayList<>()).add(v);
            adj.computeIfAbsent(v, key -> new ArrayList<>()).add(u);
        }

        precalc(0, 0);

        List<String> resList = new ArrayList<>(lineQs.length);
        for (String lineQ : lineQs) {
            resList.add(passable(lineQ) ? "YES" : "NO");
        }
        return resList;
    }

    /** Splits a line on runs of whitespace, ignoring leading/trailing blanks. */
    private static String[] tokens(String line) {
        return line.trim().split("\\s+");
    }

    /** Returns whether the 1-based vertices listed in lineQ lie on one simple path. */
    private static boolean passable(String lineQ) {
        String[] line = tokens(lineQ);
        int k = line.length;
        Integer[] p = new Integer[k];
        for (int j = 0; j < k; j++) {
            p[j] = Integer.parseInt(line[j]) - 1;
        }
        // Deepest first.
        Arrays.sort(p, (o1, o2) -> Integer.compare(d[o2], d[o1]));

        int deepest = p[0];
        boolean[] used = new boolean[k];
        for (int i = 0; i < k; i++) {
            used[i] = lca(deepest, p[i]) == p[i];
        }
        int f = 0;
        while (f < k && used[f]) {
            f++;
        }
        if (f == k) {
            // All vertices lie on the root path of the deepest one.
            return true;
        }
        int other = p[f];
        for (int i = f; i < k; i++) {
            if (lca(other, p[i]) == p[i]) {
                used[i] = true;
            }
        }
        for (boolean u : used) {
            if (!u) {
                return false;
            }
        }
        // The bend of the path must not be below any query vertex.
        return d[lca(deepest, other)] <= d[p[k - 1]];
    }

    /** Sets depth and ancestor table of v (parent p), then recurses into its children. */
    static void precalc(int v, int p) {
        d[v] = d[p] + 1;
        up[v][0] = p;
        for (int i = 1; i <= sz; ++i) {
            up[v][i] = up[up[v][i - 1]][i - 1];
        }
        List<Integer> neighbours = adj.get(v);
        if (neighbours == null) {
            return;
        }
        for (int u : neighbours) {
            if (u == p) {
                continue;
            }
            precalc(u, v);
        }
    }

    /** Lowest common ancestor of u and v via binary lifting. */
    static int lca(int u, int v) {
        if (d[u] < d[v]) {
            int tmp = u;
            u = v;
            v = tmp;
        }
        for (int cur = sz; cur >= 0; --cur) {
            if (d[u] - (1 << cur) >= d[v]) {
                u = up[u][cur];
            }
        }
        for (int cur = sz; cur >= 0; --cur) {
            if (up[u][cur] != up[v][cur]) {
                u = up[u][cur];
                v = up[v][cur];
            }
        }
        return u == v ? u : up[u][0];
    }
}

(全文完)