分布式 ID 算法
大约 4 分钟
分布式 ID 算法
1 前言
《深入分析 mysql 为什么不推荐使用 uuid 或者雪花 id 作为主键》一文建议主键使用 auto_increment。但在分布式场景中,自增 id 存在不少问题:例如 id 无法在数据落表之前提前获知;多 IDC 部署(如全球部署)时难以保证全局唯一;高并发场景下获取 id 的性能也会成为瓶颈。因此出现了分布式 ID 这一方案。
2 Scala 版本
2010 年 Twitter 开源了 Snowflake 分布式 ID 算法。该算法使用 Scala 编写,无需依赖第三方组件即可生成分布式唯一 ID;生成的 ID 是 64 位长整型,比 128 位的 UUID 更短,因而风靡一时。
- Github: https://github.com/twitter-archive/snowflake
- 源码: https://github.com/twitter-archive/snowflake/releases/tag/snowflake-2010
/** Copyright 2010-2012 Twitter, Inc.*/
package com.twitter.service.snowflake
import com.twitter.ostrich.stats.Stats
import com.twitter.service.snowflake.gen._
import java.util.Random
import com.twitter.logging.Logger
/**
* An object that generates IDs.
* This is broken into a separate class in case
* we ever want to support multiple worker threads
* per process
*
* Each generated ID is a Long assembled (low to high bits) from:
*  - a 12-bit per-millisecond sequence,
*  - a 5-bit worker id,
*  - a 5-bit datacenter id,
*  - the milliseconds elapsed since `twepoch`, shifted left by 22.
*
* @param workerId     worker id; must be in [0, 31], checked when the instance is constructed
* @param datacenterId datacenter id; must be in [0, 31], checked when the instance is constructed
* @param reporter     receives an AuditLogEntry for every ID handed out by `get_id`
* @param sequence     initial value of the per-millisecond sequence counter
* @throws IllegalArgumentException if workerId or datacenterId is out of range
*/
class IdWorker(val workerId: Long, val datacenterId: Long, private val reporter: Reporter, var sequence: Long = 0L)
extends Snowflake.Iface {
// Bumps the global "ids_generated" counter and the per-agent "ids_generated_<agent>" counter.
private[this] def genCounter(agent: String) = {
Stats.incr("ids_generated")
Stats.incr("ids_generated_%s".format(agent))
}
// Incremented on every rejected request: bad constructor ids, invalid user agents, clock regressions.
private[this] val exceptionCounter = Stats.getCounter("exceptions")
private[this] val log = Logger.get
// Supplies the random value attached to each AuditLogEntry in get_id.
private[this] val rand = new Random
// Custom epoch in milliseconds (2010-11-04T01:42:54.657Z); IDs store (timestamp - twepoch).
val twepoch = 1288834974657L
private[this] val workerIdBits = 5L
private[this] val datacenterIdBits = 5L
// -1L ^ (-1L << n) == 2^n - 1, the largest value that fits in n bits (31 for 5 bits).
private[this] val maxWorkerId = -1L ^ (-1L << workerIdBits)
private[this] val maxDatacenterId = -1L ^ (-1L << datacenterIdBits)
private[this] val sequenceBits = 12L
// Shifts place each component in its bit range:
// sequence [0,12), workerId [12,17), datacenterId [17,22), timestamp [22, 64).
private[this] val workerIdShift = sequenceBits
private[this] val datacenterIdShift = sequenceBits + workerIdBits
private[this] val timestampLeftShift = sequenceBits + workerIdBits + datacenterIdBits
// 4095: wraps the per-millisecond sequence back to 0 on overflow.
private[this] val sequenceMask = -1L ^ (-1L << sequenceBits)
// Timestamp (ms) of the most recently issued ID; -1 until the first ID is generated.
private[this] var lastTimestamp = -1L
// sanity check for workerId
// (runs in the constructor, after the fields above have been initialized)
if (workerId > maxWorkerId || workerId < 0) {
exceptionCounter.incr(1)
throw new IllegalArgumentException("worker Id can't be greater than %d or less than 0".format(maxWorkerId))
}
// sanity check for datacenterId
if (datacenterId > maxDatacenterId || datacenterId < 0) {
exceptionCounter.incr(1)
throw new IllegalArgumentException("datacenter Id can't be greater than %d or less than 0".format(maxDatacenterId))
}
log.info("worker starting. timestamp left shift %d, datacenter id bits %d, worker id bits %d, sequence bits %d, workerid %d",
timestampLeftShift, datacenterIdBits, workerIdBits, sequenceBits, workerId)
/**
* Service entry point: validates the caller's user agent, generates an ID,
* updates the generation counters and reports an audit log entry.
*
* @param useragent caller identifier; must match `AgentParser`
* @return the newly generated ID
* @throws InvalidUserAgentError if `useragent` is not valid
*/
def get_id(useragent: String): Long = {
if (!validUseragent(useragent)) {
exceptionCounter.incr(1)
throw new InvalidUserAgentError
}
val id = nextId()
genCounter(useragent)
reporter.report(new AuditLogEntry(id, useragent, rand.nextLong))
id
}
/** Returns the worker id this instance embeds in its IDs. */
def get_worker_id(): Long = workerId
/** Returns the datacenter id this instance embeds in its IDs. */
def get_datacenter_id(): Long = datacenterId
/** Returns the current wall-clock time in milliseconds. */
def get_timestamp() = System.currentTimeMillis
/**
* Generates the next ID. Synchronized on this instance, so concurrent callers are serialized.
*
* Within one millisecond the sequence increments; if it wraps around to 0 (more than 4096 IDs
* in that millisecond) the call spins until the next millisecond. A new millisecond resets
* the sequence to 0.
*
* @throws InvalidSystemClock if the clock reads earlier than the last issued ID's timestamp
*/
protected[snowflake] def nextId(): Long = synchronized {
var timestamp = timeGen()
// Refuse to issue IDs while the clock is behind lastTimestamp, since they could collide.
if (timestamp < lastTimestamp) {
exceptionCounter.incr(1)
log.error("clock is moving backwards. Rejecting requests until %d.", lastTimestamp);
throw new InvalidSystemClock("Clock moved backwards. Refusing to generate id for %d milliseconds".format(
lastTimestamp - timestamp))
}
if (lastTimestamp == timestamp) {
sequence = (sequence + 1) & sequenceMask
// Sequence exhausted for this millisecond: wait for the clock to advance.
if (sequence == 0) {
timestamp = tilNextMillis(lastTimestamp)
}
} else {
sequence = 0
}
lastTimestamp = timestamp
// Pack the components into their bit ranges.
((timestamp - twepoch) << timestampLeftShift) |
(datacenterId << datacenterIdShift) |
(workerId << workerIdShift) |
sequence
}
/** Busy-waits until timeGen() returns a value strictly greater than `lastTimestamp`, and returns it. */
protected def tilNextMillis(lastTimestamp: Long): Long = {
var timestamp = timeGen()
while (timestamp <= lastTimestamp) {
timestamp = timeGen()
}
timestamp
}
/** Clock source for ID generation; protected, presumably so it can be overridden (e.g. in tests). */
protected def timeGen(): Long = System.currentTimeMillis()
// A valid user agent is a letter followed by letters, digits or '-'; the whole string must match.
val AgentParser = """([a-zA-Z][a-zA-Z\-0-9]*)""".r
/** Returns true if `useragent` matches `AgentParser` in its entirety. */
def validUseragent(useragent: String): Boolean = useragent match {
case AgentParser(_) => true
case _ => false
}
}
3 Java 版本
根据 Scala 版本代码,我们很容易写出一个 Java 版本:
package com.twitter.service.snowflake;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Locale;
/**
 * Java port of Twitter's Scala Snowflake {@code IdWorker}: generates unique, roughly time-ordered
 * 64-bit IDs.
 *
 * <p>Bit layout of each ID, from low to high bits:
 * <ul>
 *   <li>bits 0-11: per-millisecond sequence (0..4095)</li>
 *   <li>bits 12-16: worker id (0..31)</li>
 *   <li>bits 17-21: datacenter id (0..31)</li>
 *   <li>bits 22 and up: milliseconds elapsed since {@code TWEPOCH}</li>
 * </ul>
 *
 * <p>{@link #nextId()} is {@code synchronized}, so one instance may be shared between threads.
 * IDs from different instances are only guaranteed distinct if each instance uses a different
 * (datacenterId, workerId) pair.
 */
public class SnowflakeIdWorker {
    private static final Logger log = LoggerFactory.getLogger(SnowflakeIdWorker.class);

    /** Custom epoch in milliseconds (2010-11-04T01:42:54.657Z), same value as the Scala original. */
    private static final long TWEPOCH = 1288834974657L;

    private static final long WORKER_ID_BITS = 5L;
    private static final long DATACENTER_ID_BITS = 5L;
    private static final long SEQUENCE_BITS = 12L;

    // -1L ^ (-1L << n) == 2^n - 1: the largest value representable in n bits.
    private static final long MAX_WORKER_ID = -1L ^ (-1L << WORKER_ID_BITS);
    private static final long MAX_DATACENTER_ID = -1L ^ (-1L << DATACENTER_ID_BITS);
    private static final long SEQUENCE_MASK = -1L ^ (-1L << SEQUENCE_BITS);

    private static final long WORKER_ID_SHIFT = SEQUENCE_BITS;
    private static final long DATACENTER_ID_SHIFT = SEQUENCE_BITS + WORKER_ID_BITS;
    private static final long TIMESTAMP_LEFT_SHIFT = SEQUENCE_BITS + WORKER_ID_BITS + DATACENTER_ID_BITS;

    private final long workerId;
    private final long datacenterId;

    /** Timestamp (ms) of the most recently issued ID; -1 until the first call. Guarded by {@code this}. */
    private long lastTimestamp = -1L;
    /** Sequence number within {@code lastTimestamp}. Guarded by {@code this}. */
    private long sequence = 0L;

    /**
     * Creates a worker for the given identity.
     *
     * @param workerId     worker id, must be in [0, 31]
     * @param datacenterId datacenter id, must be in [0, 31]
     * @throws IllegalArgumentException if either id is out of range. This is a
     *         {@link RuntimeException}, as before, so existing handlers still catch it.
     */
    public SnowflakeIdWorker(long workerId, long datacenterId) {
        if (workerId > MAX_WORKER_ID || workerId < 0) {
            throw new IllegalArgumentException(String.format(Locale.ENGLISH,
                    "worker Id can't be greater than %d or less than 0", MAX_WORKER_ID));
        }
        if (datacenterId > MAX_DATACENTER_ID || datacenterId < 0) {
            throw new IllegalArgumentException(String.format(Locale.ENGLISH,
                    "datacenter Id can't be greater than %d or less than 0", MAX_DATACENTER_ID));
        }
        this.workerId = workerId;
        this.datacenterId = datacenterId;
        // Parameterized SLF4J message: only formatted when INFO is enabled.
        log.info("worker starting. timestamp left shift {}, datacenter id bits {}, worker id bits {}, "
                        + "sequence bits {}, workerid {}, datacenterid {}",
                TIMESTAMP_LEFT_SHIFT, DATACENTER_ID_BITS, WORKER_ID_BITS, SEQUENCE_BITS, workerId, datacenterId);
    }

    /**
     * Returns the next ID.
     *
     * <p>Within one millisecond the sequence increments. If it wraps around (more than 4096 IDs
     * in that millisecond), this call spins until the next millisecond.
     *
     * @return a new ID, unique for this (datacenterId, workerId) pair
     * @throws IllegalStateException if the system clock has moved backwards since the last issued ID.
     *         This is a {@link RuntimeException}, as before.
     */
    public synchronized long nextId() {
        long timestamp = timeGen();
        if (timestamp < lastTimestamp) {
            // Issuing IDs while the clock is behind could repeat timestamps already used.
            log.error("clock is moving backwards. Rejecting requests until {}.", lastTimestamp);
            throw new IllegalStateException(String.format(Locale.ENGLISH,
                    "Clock moved backwards. Refusing to generate id for %d milliseconds",
                    lastTimestamp - timestamp));
        }
        if (lastTimestamp == timestamp) {
            sequence = (sequence + 1) & SEQUENCE_MASK;
            if (sequence == 0) {
                // Sequence exhausted for this millisecond: wait for the clock to advance.
                timestamp = tilNextMillis(lastTimestamp);
            }
        } else {
            sequence = 0;
        }
        lastTimestamp = timestamp;
        return ((timestamp - TWEPOCH) << TIMESTAMP_LEFT_SHIFT)
                | (datacenterId << DATACENTER_ID_SHIFT)
                | (workerId << WORKER_ID_SHIFT)
                | sequence;
    }

    /**
     * Busy-waits until {@link #timeGen()} returns a value strictly greater than {@code lastTimestamp}.
     *
     * @param lastTimestamp the millisecond whose sequence space is exhausted
     * @return the first observed timestamp after {@code lastTimestamp}
     */
    protected long tilNextMillis(long lastTimestamp) {
        long timestamp = timeGen();
        while (timestamp <= lastTimestamp) {
            timestamp = timeGen();
        }
        return timestamp;
    }

    /** Clock source in milliseconds; protected so subclasses can substitute another clock. */
    protected long timeGen() {
        return System.currentTimeMillis();
    }
}
4 基准测试
MyBatis-Plus(3.3.0+)默认集成了雪花算法和 UUID(不含中划线)两种 ID 生成方式。可通过以下 Maven 坐标引入 MyBatis-Plus,以及基准测试所需的 JMH 依赖:
<!-- mybatis-plus -->
<dependency>
<groupId>com.baomidou</groupId>
<artifactId>mybatis-plus</artifactId>
<version>3.5.0</version>
</dependency>
<!-- jmh -->
<dependency>
<groupId>org.openjdk.jmh</groupId>
<artifactId>jmh-core</artifactId>
<version>1.34</version>
</dependency>
<dependency>
<groupId>org.openjdk.jmh</groupId>
<artifactId>jmh-generator-annprocess</artifactId>
<version>1.34</version>
</dependency>
4.1 32 字符 UUID(128 bit,去掉中划线)
package com.twitter.service.snowflake;
import java.util.UUID;
import java.util.concurrent.ThreadLocalRandom;
public class UUIDIdWorker {
public static String getId() {
ThreadLocalRandom random = ThreadLocalRandom.current();
return new UUID(random.nextLong(), random.nextLong()).toString().replace("-", "");
}
}
4.2 JMH 代码
package com.devyy;
import com.baomidou.mybatisplus.core.toolkit.IdWorker;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.infra.Blackhole;
import org.openjdk.jmh.results.format.ResultFormatType;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;
/**
 * JMH throughput benchmark comparing two MyBatis-Plus ID generators:
 * snowflake IDs as strings ({@code IdWorker.getIdStr()}) and dash-free UUIDs
 * ({@code IdWorker.get32UUID()}).
 *
 * <p>Launch via {@link #main(String[])}; results are written to {@code result.json}.
 */
public class DistributedIdBenchmark {

    /** Measures {@code IdWorker.getIdStr()}; the id goes to the Blackhole so JIT cannot drop the call. */
    @Benchmark
    public void snowflake(Blackhole blackhole) {
        blackhole.consume(IdWorker.getIdStr());
    }

    /** Measures {@code IdWorker.get32UUID()}; the id goes to the Blackhole so JIT cannot drop the call. */
    @Benchmark
    public void uuid(Blackhole blackhole) {
        blackhole.consume(IdWorker.get32UUID());
    }

    /**
     * Runs this class's benchmarks in one forked JVM and writes JSON results to {@code result.json}.
     *
     * @param args ignored
     * @throws RunnerException if JMH fails to run the benchmarks
     */
    public static void main(String[] args) throws RunnerException {
        // include() takes a regex; the simple class name selects both benchmarks above.
        Options options = new OptionsBuilder()
                .include(DistributedIdBenchmark.class.getSimpleName())
                .forks(1)
                .resultFormat(ResultFormatType.JSON)
                .result("result.json")
                .build();
        Runner runner = new Runner(options);
        runner.run();
    }
}
测试结果:UUID 在 JDK 8 与 JDK 17 上的表现差异极大:吞吐量从约 175 万 ops/s 提升到约 1438 万 ops/s,相差约 8 倍。雪花算法在两个版本上则基本不变,都约为 26 万 ops/s。
# JMH version: 1.34
# VM version: JDK 1.8.0_202, Java HotSpot(TM) 64-Bit Server VM, 25.202-b08
# VM invoker: C:\Program Files\Java\jdk1.8.0_202\jre\bin\java.exe
...
Benchmark Mode Cnt Score Error Units
DistributedIdBenchmark.snowflake thrpt 5 260907.704 ± 175.112 ops/s
DistributedIdBenchmark.uuid thrpt 5 1750558.071 ± 3696.515 ops/s
# JMH version: 1.34
# VM version: JDK 17.0.1, OpenJDK 64-Bit Server VM, 17.0.1+12-39
# VM invoker: C:\Program Files\Java\jdk-17.0.1\bin\java.exe
...
Benchmark Mode Cnt Score Error Units
DistributedIdBenchmark.snowflake thrpt 5 259901.958 ± 2868.592 ops/s
DistributedIdBenchmark.uuid thrpt 5 14384998.019 ± 45547.378 ops/s
5 存在的问题
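- System.currentTimeMillis() 缓慢问题:每生成一个 ID 都至少要调用一次 System.currentTimeMillis()。该调用在某些环境下开销较大,可能成为性能瓶颈,参见 http://pzemtsov.github.io/2017/07/23/the-slow-currenttimemillis.html
- 时钟回拨问题:若系统时钟回拨(当前时间 < lastTimestamp),nextId() 会直接抛出异常,在时钟追上之前拒绝生成 ID。

参考资料:
- MyBatis-Plus: https://baomidou.com/pages/568eb2/#spring-boot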
(全文完)