分布式 ID 算法
2023年12月17日大约 4 分钟
分布式 ID 算法
1 前言
《深入分析 mysql 为什么不推荐使用 uuid 或者雪花 id 作为主键》一文建议主键 id 使用 auto_increment,但在分布式场景中,自增 id 存在很多问题:譬如无法在落表之前提前预知 id,多 IDC 部署(如全球部署)时难以保证全局唯一,高并发场景下获取 id 的性能也会成为瓶颈。因此出现了分布式 ID 这一方案。
2 Scala 版本
2010 年 Twitter 开源 Snowflake 分布式 ID 算法,该算法使用 Scala 编写,支持在不依赖第三方组件的情况下生成分布式唯一 ID,因其长度比 UUID 更短而风靡一时。
- Github: https://github.com/twitter-archive/snowflake
- 源码: https://github.com/twitter-archive/snowflake/releases/tag/snowflake-2010
/** Copyright 2010-2012 Twitter, Inc.*/
package com.twitter.service.snowflake
import com.twitter.ostrich.stats.Stats
import com.twitter.service.snowflake.gen._
import java.util.Random
import com.twitter.logging.Logger
/**
 * An object that generates IDs.
 * This is broken into a separate class in case
 * we ever want to support multiple worker threads
 * per process
 */
class IdWorker(val workerId: Long, val datacenterId: Long, private val reporter: Reporter, var sequence: Long = 0L)
extends Snowflake.Iface {
  /**
   * Records one generated id in the stats: once in the global
   * "ids_generated" counter and once in the counter named after the agent.
   */
  private[this] def genCounter(agent: String) = {
    val perAgentKey = "ids_generated_%s".format(agent)
    Stats.incr("ids_generated")
    Stats.incr(perAgentKey)
  }
  // Counter incremented whenever this worker rejects its arguments or a request.
  private[this] val exceptionCounter = Stats.getCounter("exceptions")
  private[this] val log = Logger.get
  // Supplies the random long attached to each audit log entry in get_id.
  private[this] val rand = new Random
  // Custom epoch in milliseconds; subtracted from the current time before it is packed into an id.
  val twepoch = 1288834974657L
  // Bit widths of the worker and datacenter fields inside an id.
  private[this] val workerIdBits = 5L
  private[this] val datacenterIdBits = 5L
  // Largest value that fits in each 5-bit field: -1L ^ (-1L << 5) == 31.
  private[this] val maxWorkerId = -1L ^ (-1L << workerIdBits)
  private[this] val maxDatacenterId = -1L ^ (-1L << datacenterIdBits)
  // Bit width of the per-millisecond sequence counter.
  private[this] val sequenceBits = 12L
  // Left-shift offsets of the worker, datacenter and timestamp fields: 12, 17 and 22 bits.
  private[this] val workerIdShift = sequenceBits
  private[this] val datacenterIdShift = sequenceBits + workerIdBits
  private[this] val timestampLeftShift = sequenceBits + workerIdBits + datacenterIdBits
  // Low 12 bits set (4095); used to wrap the sequence counter back to 0.
  private[this] val sequenceMask = -1L ^ (-1L << sequenceBits)
  // Timestamp of the most recently generated id; -1 until nextId has run once.
  private[this] var lastTimestamp = -1L
  // sanity check for workerId: it must lie in 0..maxWorkerId, otherwise construction fails
  if (workerId > maxWorkerId || workerId < 0) {
    exceptionCounter.incr(1)
    throw new IllegalArgumentException("worker Id can't be greater than %d or less than 0".format(maxWorkerId))
  }
  // same range check for datacenterId against maxDatacenterId
  if (datacenterId > maxDatacenterId || datacenterId < 0) {
    exceptionCounter.incr(1)
    throw new IllegalArgumentException("datacenter Id can't be greater than %d or less than 0".format(maxDatacenterId))
  }
  // log the bit layout and worker id once both ids have been accepted
  log.info("worker starting. timestamp left shift %d, datacenter id bits %d, worker id bits %d, sequence bits %d, workerid %d",
    timestampLeftShift, datacenterIdBits, workerIdBits, sequenceBits, workerId)
  /**
   * Generates a new id on behalf of `useragent`.
   *
   * A valid agent gets an id from `nextId()`; the generation is then counted
   * in the stats and reported as an audit log entry carrying a random long.
   * An invalid agent bumps the exception counter and is rejected.
   *
   * @param useragent name of the calling agent, checked with `validUseragent`
   * @return the newly generated id
   * @throws InvalidUserAgentError if `useragent` fails validation
   */
  def get_id(useragent: String): Long = {
    if (validUseragent(useragent)) {
      val generated = nextId()
      genCounter(useragent)
      val auditEntry = new AuditLogEntry(generated, useragent, rand.nextLong)
      reporter.report(auditEntry)
      generated
    } else {
      exceptionCounter.incr(1)
      throw new InvalidUserAgentError
    }
  }
  /** Returns the worker id this instance was constructed with. */
  def get_worker_id(): Long = workerId
  /** Returns the datacenter id this instance was constructed with. */
  def get_datacenter_id(): Long = datacenterId
  /** Returns the current wall-clock time in milliseconds, read directly from the system clock. */
  def get_timestamp() = System.currentTimeMillis
  /**
   * Produces the next id while holding this object's monitor.
   *
   * The id is the bitwise OR of four parts: milliseconds since `twepoch`
   * (shifted by `timestampLeftShift`), the datacenter id, the worker id and
   * the sequence counter. Within one millisecond the sequence counter is
   * incremented; when it wraps to 0 the call spins until the next
   * millisecond. In a later millisecond the counter restarts at 0.
   *
   * @return the generated id
   * @throws InvalidSystemClock if the clock reads earlier than the
   *         timestamp of the previously generated id
   */
  protected[snowflake] def nextId(): Long = synchronized {
    var now = timeGen()
    if (now < lastTimestamp) {
      exceptionCounter.incr(1)
      log.error("clock is moving backwards.  Rejecting requests until %d.", lastTimestamp)
      val behindBy = lastTimestamp - now
      throw new InvalidSystemClock(
        "Clock moved backwards.  Refusing to generate id for %d milliseconds".format(behindBy))
    }
    if (now != lastTimestamp) {
      // First id of a new millisecond.
      sequence = 0
    } else {
      sequence = (sequence + 1) & sequenceMask
      // Counter wrapped: this millisecond is exhausted, wait for the next one.
      if (sequence == 0) {
        now = tilNextMillis(lastTimestamp)
      }
    }
    lastTimestamp = now
    val timePart = (now - twepoch) << timestampLeftShift
    val datacenterPart = datacenterId << datacenterIdShift
    val workerPart = workerId << workerIdShift
    timePart | datacenterPart | workerPart | sequence
  }
  /**
   * Busy-waits until `timeGen()` reports a time strictly greater than
   * `lastTimestamp`.
   *
   * @param lastTimestamp the timestamp that must be exceeded
   * @return the first observed timestamp greater than `lastTimestamp`
   */
  protected def tilNextMillis(lastTimestamp: Long): Long = {
    @scala.annotation.tailrec
    def spin(candidate: Long): Long =
      if (candidate > lastTimestamp) candidate else spin(timeGen())
    spin(timeGen())
  }
  /** Current time in milliseconds, read from the system clock. */
  protected def timeGen(): Long = System.currentTimeMillis()
  // An agent name is an ASCII letter followed by any number of ASCII letters, digits or hyphens.
  val AgentParser = """([a-zA-Z][a-zA-Z\-0-9]*)""".r
  /**
   * Tells whether `useragent` as a whole matches [[AgentParser]].
   *
   * @param useragent the agent name to check
   * @return true if the entire string matches, false otherwise
   */
  def validUseragent(useragent: String): Boolean =
    // Direct extractor call; the pattern has exactly one capture group, so a
    // defined result is the same as a successful `case AgentParser(_)` match.
    AgentParser.unapplySeq(useragent).isDefined
}
3 Java 版本
根据 Scala 版本代码,我们很容易写出一个 Java 版本:
package com.twitter.service.snowflake;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Locale;
public class SnowflakeIdWorker {
    private static final Logger log = LoggerFactory.getLogger(SnowflakeIdWorker.class);
    // Custom epoch in milliseconds; subtracted from the current time before it is packed into an id.
    private final long twepoch = 1288834974657L;
    // Bit widths of the worker and datacenter fields inside an id.
    private final long workerIdBits = 5L;
    private final long datacenterIdBits = 5L;
    // Largest value that fits in each 5-bit field: -1L ^ (-1L << 5) == 31.
    private final long maxWorkerId = -1L ^ (-1L << workerIdBits);
    private final long maxDatacenterId = -1L ^ (-1L << datacenterIdBits);
    // Bit width of the per-millisecond sequence counter.
    private final long sequenceBits = 12L;
    // Left-shift offsets of the worker, datacenter and timestamp fields: 12, 17 and 22 bits.
    private final long workerIdShift = sequenceBits;
    private final long datacenterIdShift = sequenceBits + workerIdBits;
    private final long timestampLeftShift = sequenceBits + workerIdBits + datacenterIdBits;
    // Low 12 bits set (4095); used to wrap the sequence counter back to 0.
    private final long sequenceMask = -1L ^ (-1L << sequenceBits);
    // Timestamp of the most recently generated id; -1 until nextId() has run once.
    private long lastTimestamp = -1L;
    // Ids supplied to the constructor; range-checked there.
    private final long workerId;
    private final long datacenterId;
    // Counter distinguishing ids generated within the same millisecond.
    private long sequence = 0L;
    /**
     * Creates a worker bound to the given worker and datacenter ids.
     *
     * @param workerId     worker id, from 0 to maxWorkerId (31) inclusive
     * @param datacenterId datacenter id, from 0 to maxDatacenterId (31) inclusive
     * @throws IllegalArgumentException if either id is outside its range
     */
    public SnowflakeIdWorker(long workerId, long datacenterId) {
        // IllegalArgumentException matches the Scala original and is still a
        // RuntimeException, so callers catching RuntimeException keep working.
        if (workerId > maxWorkerId || workerId < 0) {
            throw new IllegalArgumentException(String.format(Locale.ENGLISH, "worker Id can't be greater than %d or less than 0", maxWorkerId));
        }
        if (datacenterId > maxDatacenterId || datacenterId < 0) {
            throw new IllegalArgumentException(String.format(Locale.ENGLISH, "datacenter Id can't be greater than %d or less than 0", maxDatacenterId));
        }
        // SLF4J placeholders: the message is only rendered when INFO is enabled.
        log.info("worker starting. timestamp left shift {}, datacenter id bits {}, worker id bits {}, sequence bits {}, workerid {}",
                timestampLeftShift, datacenterIdBits, workerIdBits, sequenceBits, workerId);
        this.workerId = workerId;
        this.datacenterId = datacenterId;
    }
    /**
     * Produces the next id; synchronized on this worker.
     *
     * The id is the bitwise OR of four parts: milliseconds since twepoch
     * (shifted by timestampLeftShift), the datacenter id, the worker id and
     * the sequence counter. Within one millisecond the sequence counter is
     * incremented; when it wraps to 0 the call spins until the next
     * millisecond. In a later millisecond the counter restarts at 0.
     *
     * @return the generated id
     * @throws RuntimeException if the clock reads earlier than the timestamp
     *                          of the previously generated id
     */
    public synchronized long nextId() {
        long now = timeGen();
        if (now < lastTimestamp) {
            log.error(String.format(Locale.ENGLISH, "clock is moving backwards.  Rejecting requests until %d.", lastTimestamp));
            long behindBy = lastTimestamp - now;
            throw new RuntimeException(String.format(Locale.ENGLISH,
                    "Clock moved backwards.  Refusing to generate id for %d milliseconds", behindBy));
        }
        if (now != lastTimestamp) {
            // First id of a new millisecond.
            sequence = 0;
        } else {
            sequence = (sequence + 1) & sequenceMask;
            // Counter wrapped: this millisecond is exhausted, wait for the next one.
            if (sequence == 0) {
                now = tilNextMillis(lastTimestamp);
            }
        }
        lastTimestamp = now;
        long timePart = (now - twepoch) << timestampLeftShift;
        long datacenterPart = datacenterId << datacenterIdShift;
        long workerPart = workerId << workerIdShift;
        return timePart | datacenterPart | workerPart | sequence;
    }
    /**
     * Busy-waits until timeGen() reports a time strictly greater than
     * the given timestamp.
     *
     * @param lastTimestamp the timestamp that must be exceeded
     * @return the first observed timestamp greater than lastTimestamp
     */
    protected long tilNextMillis(long lastTimestamp) {
        long candidate;
        do {
            candidate = timeGen();
        } while (candidate <= lastTimestamp);
        return candidate;
    }
    /**
     * Returns the current time in milliseconds as reported by
     * System.currentTimeMillis(). Being protected and non-final, it can be
     * overridden by a subclass.
     */
    protected long timeGen() {
        return System.currentTimeMillis();
    }
}
4 基准测试
在 MyBatis-Plus (3.3.0+) 中,默认集成了雪花算法 + UUID (不含中划线),通过坐标引入相关代码
<!-- mybatis-plus -->
<dependency>
    <groupId>com.baomidou</groupId>
    <artifactId>mybatis-plus</artifactId>
    <version>3.5.0</version>
</dependency>
<!-- jmh -->
<dependency>
    <groupId>org.openjdk.jmh</groupId>
    <artifactId>jmh-core</artifactId>
    <version>1.34</version>
</dependency>
<dependency>
    <groupId>org.openjdk.jmh</groupId>
    <artifactId>jmh-generator-annprocess</artifactId>
    <version>1.34</version>
</dependency>
4.1 32 位字符 UUID
package com.twitter.service.snowflake;
import java.util.UUID;
import java.util.concurrent.ThreadLocalRandom;
public class UUIDIdWorker {
    /**
     * Builds an identifier from two random longs taken from the current
     * thread's ThreadLocalRandom, formats it as a UUID string and strips
     * the hyphens.
     *
     * @return the UUID text with every "-" removed
     */
    public static String getId() {
        ThreadLocalRandom rng = ThreadLocalRandom.current();
        long mostSigBits = rng.nextLong();
        long leastSigBits = rng.nextLong();
        String dashed = new UUID(mostSigBits, leastSigBits).toString();
        return dashed.replace("-", "");
    }
}
4.2 JMH 代码
package com.devyy;
import com.baomidou.mybatisplus.core.toolkit.IdWorker;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.infra.Blackhole;
import org.openjdk.jmh.results.format.ResultFormatType;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;
public class DistributedIdBenchmark {
    /**
     * Benchmarks IdWorker.getIdStr(); the result is handed to the
     * Blackhole so the call cannot be optimized away.
     */
    @Benchmark
    public void snowflake(Blackhole blackhole) {
        blackhole.consume(IdWorker.getIdStr());
    }
    /**
     * Benchmarks IdWorker.get32UUID(); the result is handed to the
     * Blackhole so the call cannot be optimized away.
     */
    @Benchmark
    public void uuid(Blackhole blackhole) {
        blackhole.consume(IdWorker.get32UUID());
    }
    /**
     * Runs the benchmarks of this class in a single fork and writes the
     * results to result.json in JSON format.
     *
     * @param args unused
     * @throws RunnerException if the JMH runner fails
     */
    public static void main(String[] args) throws RunnerException {
        String benchmarkPattern = DistributedIdBenchmark.class.getSimpleName();
        Options options = new OptionsBuilder()
                .include(benchmarkPattern)
                .resultFormat(ResultFormatType.JSON)
                .result("result.json")
                .forks(1)
                .build();
        Runner runner = new Runner(options);
        runner.run();
    }
}
测试结果:UUID 在 jdk8 / jdk17 表现差异极大!
# JMH version: 1.34
# VM version: JDK 1.8.0_202, Java HotSpot(TM) 64-Bit Server VM, 25.202-b08
# VM invoker: C:\Program Files\Java\jdk1.8.0_202\jre\bin\java.exe
...
Benchmark                          Mode  Cnt        Score      Error  Units
DistributedIdBenchmark.snowflake  thrpt    5   260907.704 ±  175.112  ops/s
DistributedIdBenchmark.uuid       thrpt    5  1750558.071 ± 3696.515  ops/s
# JMH version: 1.34
# VM version: JDK 17.0.1, OpenJDK 64-Bit Server VM, 17.0.1+12-39
# VM invoker: C:\Program Files\Java\jdk-17.0.1\bin\java.exe
...
Benchmark                          Mode  Cnt         Score       Error  Units
DistributedIdBenchmark.snowflake  thrpt    5    259901.958 ±  2868.592  ops/s
DistributedIdBenchmark.uuid       thrpt    5  14384998.019 ± 45547.378  ops/s
5 存在的问题
- System.currentTimeMillis()缓慢问题: http://pzemtsov.github.io/2017/07/23/the-slow-currenttimemillis.html
- MyBatis-Plus: https://baomidou.com/pages/568eb2/#spring-boot
(全文完)