跳至主要內容

分布式 ID 算法

大约 4 分钟

分布式 ID 算法

1 前言

《深入分析 mysql 为什么不推荐使用 uuid 或者雪花 id 作为主键》一文建议主键使用 auto_increment。但在分布式场景中,自增 id 存在不少问题:在数据落表之前无法预先得知 id;多 IDC 部署(如全球部署)时难以保证全局唯一;高并发场景下获取 id 的性能也会成为瓶颈。因此出现了分布式 ID 这一方案。

2 Scala 版本

2010 年 Twitter 开源了 Snowflake 分布式 ID 算法。该算法使用 Scala 实现,无需依赖第三方即可生成分布式唯一 ID;生成的 ID 是一个 64 位整数,比 128 位的 UUID 更短,因而风靡一时。

/** Copyright 2010-2012 Twitter, Inc.*/
package com.twitter.service.snowflake

import com.twitter.ostrich.stats.Stats
import com.twitter.service.snowflake.gen._
import java.util.Random
import com.twitter.logging.Logger

/**
 * An object that generates IDs.
 * This is broken into a separate class in case
 * we ever want to support multiple worker threads
 * per process
 *
 * Ids produced by `nextId()` pack four fields into one 64-bit Long (see the shift
 * constants below):
 * {{{
 *   ((timestamp - twepoch) << 22) | (datacenterId << 17) | (workerId << 12) | sequence
 * }}}
 *
 * @param workerId     worker identifier; the constructor body rejects values outside [0, 31]
 * @param datacenterId datacenter identifier; the constructor body rejects values outside [0, 31]
 * @param reporter     receives an `AuditLogEntry` for every id returned by `get_id`
 * @param sequence     per-millisecond counter; its initial value is normally replaced on the
 *                     first `nextId()` call, because `lastTimestamp` starts at -1
 * @throws IllegalArgumentException if `workerId` or `datacenterId` is out of range
 */
class IdWorker(val workerId: Long, val datacenterId: Long, private val reporter: Reporter, var sequence: Long = 0L)
extends Snowflake.Iface {
  // Bumps the global "ids_generated" counter and the per-user-agent counter.
  private[this] def genCounter(agent: String) = {
    Stats.incr("ids_generated")
    Stats.incr("ids_generated_%s".format(agent))
  }
  private[this] val exceptionCounter = Stats.getCounter("exceptions")
  private[this] val log = Logger.get
  // Only used for the last AuditLogEntry constructor argument in get_id.
  private[this] val rand = new Random

  // Custom epoch in milliseconds (2010-11-04T01:42:54.657Z); ids store the offset from it.
  val twepoch = 1288834974657L

  private[this] val workerIdBits = 5L
  private[this] val datacenterIdBits = 5L
  // -1L ^ (-1L << n) == 2^n - 1, the largest n-bit value (31 for 5 bits).
  private[this] val maxWorkerId = -1L ^ (-1L << workerIdBits)
  private[this] val maxDatacenterId = -1L ^ (-1L << datacenterIdBits)
  private[this] val sequenceBits = 12L

  private[this] val workerIdShift = sequenceBits
  private[this] val datacenterIdShift = sequenceBits + workerIdBits
  private[this] val timestampLeftShift = sequenceBits + workerIdBits + datacenterIdBits
  // 4095: at most 4096 ids per millisecond per worker.
  private[this] val sequenceMask = -1L ^ (-1L << sequenceBits)

  // Millisecond timestamp of the last id issued; -1 until the first nextId() call.
  private[this] var lastTimestamp = -1L

  // NOTE: the checks below run in the class body (i.e. the constructor), so they rely on
  // maxWorkerId / maxDatacenterId / exceptionCounter already being initialized above.
  // sanity check for workerId
  if (workerId > maxWorkerId || workerId < 0) {
    exceptionCounter.incr(1)
    throw new IllegalArgumentException("worker Id can't be greater than %d or less than 0".format(maxWorkerId))
  }

  if (datacenterId > maxDatacenterId || datacenterId < 0) {
    exceptionCounter.incr(1)
    throw new IllegalArgumentException("datacenter Id can't be greater than %d or less than 0".format(maxDatacenterId))
  }

  // Logs the bit layout and workerId once at construction (datacenterId is not logged).
  log.info("worker starting. timestamp left shift %d, datacenter id bits %d, worker id bits %d, sequence bits %d, workerid %d",
    timestampLeftShift, datacenterIdBits, workerIdBits, sequenceBits, workerId)

  /**
   * Validates `useragent`, generates the next id, bumps the stats counters and
   * reports an audit entry before returning the id.
   *
   * @throws InvalidUserAgentError if `useragent` does not match `AgentParser`
   * @throws InvalidSystemClock    (propagated from `nextId()`) if the clock moved backwards
   */
  def get_id(useragent: String): Long = {
    if (!validUseragent(useragent)) {
      exceptionCounter.incr(1)
      throw new InvalidUserAgentError
    }

    val id = nextId()
    genCounter(useragent)

    reporter.report(new AuditLogEntry(id, useragent, rand.nextLong))
    id
  }

  def get_worker_id(): Long = workerId
  def get_datacenter_id(): Long = datacenterId
  // Wall-clock time in ms since the Unix epoch (not relative to twepoch).
  def get_timestamp() = System.currentTimeMillis

  /**
   * Returns the next id. The whole body runs under `synchronized`, which serializes
   * reads and writes of `lastTimestamp` and `sequence`.
   *
   * @throws InvalidSystemClock if the current time is earlier than the last issued timestamp
   */
  protected[snowflake] def nextId(): Long = synchronized {
    var timestamp = timeGen()

    // Refuse rather than risk re-issuing ids from an already-used millisecond.
    if (timestamp < lastTimestamp) {
      exceptionCounter.incr(1)
      log.error("clock is moving backwards.  Rejecting requests until %d.", lastTimestamp);
      throw new InvalidSystemClock("Clock moved backwards.  Refusing to generate id for %d milliseconds".format(
        lastTimestamp - timestamp))
    }

    if (lastTimestamp == timestamp) {
      // Same millisecond: advance the sequence, wrapping to 0 after 4095.
      sequence = (sequence + 1) & sequenceMask
      if (sequence == 0) {
        // Sequence exhausted for this millisecond: spin until the clock advances.
        timestamp = tilNextMillis(lastTimestamp)
      }
    } else {
      // New millisecond: restart the sequence.
      sequence = 0
    }

    lastTimestamp = timestamp
    ((timestamp - twepoch) << timestampLeftShift) |
      (datacenterId << datacenterIdShift) |
      (workerId << workerIdShift) |
      sequence
  }

  // Busy-waits (no sleep) until timeGen() returns a value strictly greater than lastTimestamp.
  protected def tilNextMillis(lastTimestamp: Long): Long = {
    var timestamp = timeGen()
    while (timestamp <= lastTimestamp) {
      timestamp = timeGen()
    }
    timestamp
  }

  // Clock source; protected so subclasses (e.g. tests) can override it.
  protected def timeGen(): Long = System.currentTimeMillis()

  // A user agent is an ASCII letter followed by letters, digits or '-'.
  val AgentParser = """([a-zA-Z][a-zA-Z\-0-9]*)""".r

  // Matching against a Regex extractor requires the WHOLE string to match the pattern.
  def validUseragent(useragent: String): Boolean = useragent match {
    case AgentParser(_) => true
    case _ => false
  }
}

3 Java 版本

根据 Scala 版本代码,我们很容易写出一个 Java 版本:

package com.twitter.service.snowflake;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Locale;

/**
 * Java port of Twitter's Snowflake id generator.
 *
 * <p>Each id is a 64-bit {@code long} laid out as
 * {@code ((timestamp - TWEPOCH) << 22) | (datacenterId << 17) | (workerId << 12) | sequence},
 * so a single worker can issue at most 4096 ids per millisecond.
 *
 * <p>{@link #nextId()} is {@code synchronized}, so one instance may be shared between threads.
 */
public class SnowflakeIdWorker {
    private static final Logger log = LoggerFactory.getLogger(SnowflakeIdWorker.class);

    /** Custom epoch in ms (2010-11-04T01:42:54.657Z); ids store the offset from it. */
    private static final long TWEPOCH = 1288834974657L;

    private static final long WORKER_ID_BITS = 5L;
    private static final long DATACENTER_ID_BITS = 5L;
    private static final long SEQUENCE_BITS = 12L;

    /** {@code -1L ^ (-1L << n)} is {@code 2^n - 1}: the largest n-bit value (31 for 5 bits). */
    private static final long MAX_WORKER_ID = -1L ^ (-1L << WORKER_ID_BITS);
    private static final long MAX_DATACENTER_ID = -1L ^ (-1L << DATACENTER_ID_BITS);

    private static final long WORKER_ID_SHIFT = SEQUENCE_BITS;
    private static final long DATACENTER_ID_SHIFT = SEQUENCE_BITS + WORKER_ID_BITS;
    private static final long TIMESTAMP_LEFT_SHIFT = SEQUENCE_BITS + WORKER_ID_BITS + DATACENTER_ID_BITS;
    /** 4095: the per-millisecond sequence wraps to 0 after this value. */
    private static final long SEQUENCE_MASK = -1L ^ (-1L << SEQUENCE_BITS);

    private final long workerId;
    private final long datacenterId;

    /** Millisecond timestamp of the last issued id; -1 until the first call. Guarded by {@code this}. */
    private long lastTimestamp = -1L;
    /** Per-millisecond counter. Guarded by {@code this}. */
    private long sequence = 0L;

    /**
     * @param workerId     worker identifier in [0, 31]
     * @param datacenterId datacenter identifier in [0, 31]
     * @throws IllegalArgumentException if either identifier is out of range
     */
    public SnowflakeIdWorker(long workerId, long datacenterId) {
        if (workerId > MAX_WORKER_ID || workerId < 0) {
            throw new IllegalArgumentException(String.format(Locale.ENGLISH,
                    "worker Id can't be greater than %d or less than 0", MAX_WORKER_ID));
        }
        if (datacenterId > MAX_DATACENTER_ID || datacenterId < 0) {
            throw new IllegalArgumentException(String.format(Locale.ENGLISH,
                    "datacenter Id can't be greater than %d or less than 0", MAX_DATACENTER_ID));
        }
        this.workerId = workerId;
        this.datacenterId = datacenterId;
        // Parameterized logging: the message is only formatted if INFO is enabled.
        log.info("worker starting. timestamp left shift {}, datacenter id bits {}, worker id bits {}, sequence bits {}, workerid {}, datacenterid {}",
                TIMESTAMP_LEFT_SHIFT, DATACENTER_ID_BITS, WORKER_ID_BITS, SEQUENCE_BITS, workerId, datacenterId);
    }

    /**
     * Returns the next unique id for this worker.
     *
     * @return a Snowflake id (see the class comment for the bit layout)
     * @throws IllegalStateException if the system clock moved backwards since the last id
     */
    public synchronized long nextId() {
        long timestamp = timeGen();

        // Refuse rather than risk re-issuing ids from an already-used millisecond.
        if (timestamp < lastTimestamp) {
            log.error("clock is moving backwards.  Rejecting requests until {}.", lastTimestamp);
            throw new IllegalStateException(String.format(Locale.ENGLISH,
                    "Clock moved backwards.  Refusing to generate id for %d milliseconds",
                    lastTimestamp - timestamp));
        }

        if (lastTimestamp == timestamp) {
            // Same millisecond: advance the sequence, wrapping to 0 after 4095.
            sequence = (sequence + 1) & SEQUENCE_MASK;
            if (sequence == 0) {
                // Sequence exhausted for this millisecond: spin until the clock advances.
                timestamp = tilNextMillis(lastTimestamp);
            }
        } else {
            // New millisecond: restart the sequence.
            sequence = 0;
        }

        lastTimestamp = timestamp;
        return ((timestamp - TWEPOCH) << TIMESTAMP_LEFT_SHIFT)
                | (datacenterId << DATACENTER_ID_SHIFT)
                | (workerId << WORKER_ID_SHIFT)
                | sequence;
    }

    /** Busy-waits until {@link #timeGen()} returns a value strictly greater than {@code lastTimestamp}. */
    protected long tilNextMillis(long lastTimestamp) {
        long timestamp = timeGen();
        while (timestamp <= lastTimestamp) {
            timestamp = timeGen();
        }
        return timestamp;
    }

    /** Clock source in ms since the Unix epoch; overridable for tests. */
    protected long timeGen() {
        return System.currentTimeMillis();
    }
}

4 基准测试

MyBatis-Plus 自 3.3.0 起默认集成了雪花算法和 UUID(不含中划线)两种 ID 生成方式。通过以下 Maven 坐标引入 MyBatis-Plus 及 JMH 依赖:

<!-- mybatis-plus -->
<dependency>
    <groupId>com.baomidou</groupId>
    <artifactId>mybatis-plus</artifactId>
    <version>3.5.0</version>
</dependency>

<!-- jmh -->
<dependency>
    <groupId>org.openjdk.jmh</groupId>
    <artifactId>jmh-core</artifactId>
    <version>1.34</version>
</dependency>
<dependency>
    <groupId>org.openjdk.jmh</groupId>
    <artifactId>jmh-generator-annprocess</artifactId>
    <version>1.34</version>
</dependency>

4.1 32 位字符 UUID(不含中划线)

package com.twitter.service.snowflake;

import java.util.UUID;
import java.util.concurrent.ThreadLocalRandom;

public class UUIDIdWorker {
    public static String getId() {
        ThreadLocalRandom random = ThreadLocalRandom.current();
        return new UUID(random.nextLong(), random.nextLong()).toString().replace("-", "");
    }
}

4.2 JMH 代码

package com.devyy;

import com.baomidou.mybatisplus.core.toolkit.IdWorker;

import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.infra.Blackhole;
import org.openjdk.jmh.results.format.ResultFormatType;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;

/**
 * JMH throughput comparison of two MyBatis-Plus id generators: the Snowflake-based
 * {@code IdWorker.getIdStr()} and the dash-free random UUID {@code IdWorker.get32UUID()}.
 */
public class DistributedIdBenchmark {
    /** One Snowflake id (as a decimal string) per invocation. */
    @Benchmark
    public void snowflake(Blackhole blackhole) {
        // Sink the result so the JIT cannot discard the call as dead code.
        blackhole.consume(IdWorker.getIdStr());
    }

    /** One 32-character UUID string per invocation. */
    @Benchmark
    public void uuid(Blackhole blackhole) {
        blackhole.consume(IdWorker.get32UUID());
    }

    /** Runs both benchmarks in one forked JVM and writes the results as JSON to result.json. */
    public static void main(String[] args) throws RunnerException {
        String includePattern = DistributedIdBenchmark.class.getSimpleName();
        Options options = new OptionsBuilder()
                .include(includePattern)
                .forks(1)
                .resultFormat(ResultFormatType.JSON)
                .result("result.json")
                .build();
        Runner runner = new Runner(options);
        runner.run();
    }
}

测试结果:雪花算法在 JDK 8 与 JDK 17 上的吞吐量基本持平(约 26 万 ops/s),而 UUID 的表现差异极大:JDK 17 下约为 JDK 8 的 8 倍(约 1438 万 vs 175 万 ops/s)!

# JMH version: 1.34
# VM version: JDK 1.8.0_202, Java HotSpot(TM) 64-Bit Server VM, 25.202-b08
# VM invoker: C:\Program Files\Java\jdk1.8.0_202\jre\bin\java.exe
...
Benchmark                          Mode  Cnt        Score      Error  Units
DistributedIdBenchmark.snowflake  thrpt    5   260907.704 ±  175.112  ops/s
DistributedIdBenchmark.uuid       thrpt    5  1750558.071 ± 3696.515  ops/s
# JMH version: 1.34
# VM version: JDK 17.0.1, OpenJDK 64-Bit Server VM, 17.0.1+12-39
# VM invoker: C:\Program Files\Java\jdk-17.0.1\bin\java.exe
...
Benchmark                          Mode  Cnt         Score       Error  Units
DistributedIdBenchmark.snowflake  thrpt    5    259901.958 ±  2868.592  ops/s
DistributedIdBenchmark.uuid       thrpt    5  14384998.019 ± 45547.378  ops/s

5 存在的问题

  1. System.currentTimeMillis() 缓慢问题: http://pzemtsov.github.io/2017/07/23/the-slow-currenttimemillis.html
  2. MyBatis-Plus: https://baomidou.com/pages/568eb2/#spring-boot

(全文完)