拒绝采样
大约 2 分钟
拒绝采样
题目 | 难度 | |
---|---|---|
470. 用 Rand7() 实现 Rand10() | 中等 | TODO |
478. 在圆内随机生成点 | 中等 | |
710. 黑名单中的随机数 | 困难 | 用于辨析,此题不能用 拒绝采样 |
定义
在数值分析和计算统计中,拒绝采样是用于从分布生成观测值的基本技术。它通常也被称为接受拒绝方法或「接受拒绝算法」,是一种精确的模拟方法。该方法适用于具有密度的任何分布。
拒绝采样基于以下观察:在一维中对随机变量进行采样,可以对二维笛卡尔图执行均匀随机采样,并将样本保持在其密度函数图下的区域中。
470. 用 Rand7() 实现 Rand10()
class Solution extends SolBase {
public int rand10() {
int a, b, idx;
while (true) {
a = rand7();
b = rand7();
idx = b + (a - 1) * 7;
if (idx <= 40) {
return 1 + (idx - 1) % 10;
}
a = idx - 40;
b = rand7();
// get uniform dist from 1 - 63
idx = b + (a - 1) * 7;
if (idx <= 60) {
return 1 + (idx - 1) % 10;
}
a = idx - 60;
b = rand7();
// get uniform dist from 1 - 21
idx = b + (a - 1) * 7;
if (idx <= 20) {
return 1 + (idx - 1) % 10;
}
}
}
}
478. 在圆内随机生成点
import java.util.Random;
public class Solution478 {
static class Solution {
private final Random random;
private final double radius;
private final double xCenter;
private final double yCenter;
private final double size;
public Solution(double radius, double x_center, double y_center) {
this.radius = radius;
this.xCenter = x_center;
this.yCenter = y_center;
size = Math.PI * radius * radius;
random = new Random();
}
public double[] randPoint() {
while (true) {
double x = random.nextDouble() * (2 * radius) - radius;
double y = random.nextDouble() * (2 * radius) - radius;
if (x * x + y * y <= radius * radius) {
return new double[]{xCenter + x, yCenter + y};
}
}
}
public double[] randPoint2() {
double theta = random.nextDouble() * 2 * Math.PI, r = Math.sqrt(random.nextDouble() * size / Math.PI);
return new double[]{xCenter + Math.cos(theta) * r, yCenter + Math.sin(theta) * r};
}
}
}
710. 黑名单中的随机数
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Random;
import java.util.Set;
public class Solution710 {
static class Solution {
private final Map<Integer, Integer> b2w;
private final Random random;
private final int bound;
public Solution(int n, int[] blacklist) {
b2w = new HashMap<>();
random = new Random();
int m = blacklist.length;
bound = n - m;
Set<Integer> black = new HashSet<>();
for (int b : blacklist) {
if (b >= bound) {
black.add(b);
}
}
int w = bound;
for (int b : blacklist) {
if (b < bound) {
while (black.contains(w)) {
++w;
}
b2w.put(b, w);
++w;
}
}
}
public int pick() {
int x = random.nextInt(bound);
return b2w.getOrDefault(x, x);
}
}
}
(全文完)