跳至主要內容

拒绝采样

大约 2 分钟

拒绝采样

题目难度
470. 用 Rand7() 实现 Rand10()open in new window中等TODO
478. 在圆内随机生成点open in new window中等
710. 黑名单中的随机数open in new window困难用于辨析,此题不能用 拒绝采样

定义

在数值分析和计算统计中,拒绝采样是用于从分布生成观测值的基本技术。它通常也被称为接受拒绝方法或「接受拒绝算法」,是一种精确的模拟方法。该方法适用于具有密度的任何分布。

拒绝采样基于以下观察:在一维中对随机变量进行采样,可以对二维笛卡尔图执行均匀随机采样,并将样本保持在其密度函数图下的区域中。

470. 用 Rand7() 实现 Rand10()

class Solution extends SolBase {
    public int rand10() {
        int a, b, idx;
        while (true) {
            a = rand7();
            b = rand7();
            idx = b + (a - 1) * 7;
            if (idx <= 40) {
                return 1 + (idx - 1) % 10;
            }
            a = idx - 40;
            b = rand7();
            // get uniform dist from 1 - 63
            idx = b + (a - 1) * 7;
            if (idx <= 60) {
                return 1 + (idx - 1) % 10;
            }
            a = idx - 60;
            b = rand7();
            // get uniform dist from 1 - 21
            idx = b + (a - 1) * 7;
            if (idx <= 20) {
                return 1 + (idx - 1) % 10;
            }
        }
    }
}

478. 在圆内随机生成点

import java.util.Random;

public class Solution478 {
    static class Solution {
        private final Random random;
        private final double radius;
        private final double xCenter;
        private final double yCenter;
        private final double size;

        public Solution(double radius, double x_center, double y_center) {
            this.radius = radius;
            this.xCenter = x_center;
            this.yCenter = y_center;
            size = Math.PI * radius * radius;
            random = new Random();
        }

        public double[] randPoint() {
            while (true) {
                double x = random.nextDouble() * (2 * radius) - radius;
                double y = random.nextDouble() * (2 * radius) - radius;
                if (x * x + y * y <= radius * radius) {
                    return new double[]{xCenter + x, yCenter + y};
                }
            }
        }

        public double[] randPoint2() {
            double theta = random.nextDouble() * 2 * Math.PI, r = Math.sqrt(random.nextDouble() * size / Math.PI);
            return new double[]{xCenter + Math.cos(theta) * r, yCenter + Math.sin(theta) * r};
        }
    }
}

710. 黑名单中的随机数

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Random;
import java.util.Set;

public class Solution710 {
    static class Solution {
        private final Map<Integer, Integer> b2w;
        private final Random random;
        private final int bound;

        public Solution(int n, int[] blacklist) {
            b2w = new HashMap<>();
            random = new Random();
            int m = blacklist.length;
            bound = n - m;
            Set<Integer> black = new HashSet<>();
            for (int b : blacklist) {
                if (b >= bound) {
                    black.add(b);
                }
            }

            int w = bound;
            for (int b : blacklist) {
                if (b < bound) {
                    while (black.contains(w)) {
                        ++w;
                    }
                    b2w.put(b, w);
                    ++w;
                }
            }
        }

        public int pick() {
            int x = random.nextInt(bound);
            return b2w.getOrDefault(x, x);
        }
    }
}

(全文完)