前言

上篇文章介绍了 HashMap 源码后,在博客平台广受好评,让本来己经不打算更新这个系列的我,仿佛被打了一顿鸡血。真的,被读者认可的感觉,就是这么奇妙。

然后,有读者希望我能出一版 ConcurrentHashMap 的解析。所以,今天的这篇文章,我准备讲述一下 ConcurrentHashMap 分别在JDK1.7和 JDK1.8 的源码。文章较长,建议小伙伴们可以先收藏再看哦~

说一下为什么我要把源码解析写的这么详细吧。一方面,可以记录下当时自己的思考过程,也方便后续自己复习翻阅;另一方面,记录下来还能够帮助看到文章的小伙伴加深对源码的理解,简直是一举两得的事情。

更正错误

另外,上一篇文章,有个错误点,却没有读者给我指正出来。o(╥﹏╥)o 。因此,我只能自己在此更正一下。见下面截图,

put 方法,在新值替换旧值那里,应该是只有一种情况的,e 不包括新值。图中的方框也标注出来了。因为,判断 e=p.next==null , 然后新的节点是赋值给 p.next 了,并没有赋值给 e,此时 e 依旧是空的。所以 e!=null,代表当前的 e 是已经存在的旧值。

文章编写过程,难免出现作者考虑不周的地方,如果有朋友发现有错误的地方,还请不吝赐教,指正出来。知错能改,善莫大焉,对于技术,我们应该怀有一颗严谨的心态~

文章目录

这篇文章,我打算从以下几个方面来讲。

1)多线程下的 HashMap 有什么问题?

2)怎样保证线程安全,为什么选用 ConcurrentHashMap?

3)ConcurrentHashMap 1.7 源码解析

  • 底层存储结构
  • 常用变量
  • 构造函数
  • put() 方法
  • ensureSegment() 方法
  • scanAndLockForPut() 方法
  • rehash() 扩容机制
  • get() 获取元素方法
  • remove() 方法
  • size() 方法是怎么统计元素个数的

4)ConcurrentHashMap 1.8 源码解析

  • put()方法详解
  • initTable()初始化表
  • addCount()方法
  • fullAddCount()方法
  • transfer()是怎样扩容和迁移元素的
  • helpTransfer()方法帮助迁移元素

多线程下 HashMap 有什么问题?

在上一篇文章中,已经讲解了 HashMap 1.7 死循环的成因,也正因为如此,我们才说 HashMap 在多线程下是不安全的。但是,在JDK1.8 的 HashMap 改为采用尾插法,已经不存在死循环的问题了,为什么也会线程不安全呢?

我们以 put 方法为例(1.8),

假如现在有两个线程都执行到了上图中的划线处。当线程一判断为空之后,CPU 时间片到了,被挂起。线程二也执行到此处判断为空,继续执行下一句,创建了一个新节点,插入到此下标位置。然后,线程一解挂,同样认为此下标的元素为空,因此也创建了一个新节点放在此下标处,因此造成了元素的覆盖。

所以,可以看到不管是 JDK1.7 还是 1.8 的 HashMap 都存在线程安全的问题。那么,在多线程环境下,应该怎样去保证线程安全呢?

怎样保证线程安全,为什么选用 ConcurrentHashMap?

首先,你可能想到,在多线程环境下用 Hashtable 来解决线程安全的问题。这样确实是可以的,但是同样的它也有缺点,我们看下最常用的 put 方法和 get 方法。

Hashtable-put

Hatable-get

可以看到,不管是往 map 里边添加元素还是获取元素,都会用 synchronized 关键字加锁。当有多个元素之前存在资源竞争时,只能有一个线程可以获取到锁,操作资源。更不能忍的是,一个简单的读取操作,互相之间又不影响,为什么也不能同时进行呢?

所以,hashtable 的缺点显而易见,它不管是 get 还是 put 操作,都是锁住了整个 table,效率低下,因此 并不适合高并发场景。

也许,你还会想起来一个集合工具类 Collections,生成一个SynchronizedMap。其实,它和 Hashtable 差不多,同样的原因,锁住整张表,效率低下。

所以,思考一下,既然锁住整张表的话,并发效率低下,那我把整张表分成 N 个部分,并使元素尽量均匀的分布到每个部分中,分别给他们加锁,互相之间并不影响,这种方式岂不是更好 。这就是在 JDK1.7 中 ConcurrentHashMap 采用的方案,被叫做锁分段技术,每个部分就是一个 Segment(段)。

但是,在JDK1.8中,完全重构了,采用的是 Synchronized + CAS ,把锁的粒度进一步降低,而放弃了 Segment 分段。(此时的 Synchronized 已经升级了,效率得到了很大提升,锁升级可以了解一下)

ConcurrentHashMap 1.7 源码解析

我们看下在 JDK1.7中 ConcurrentHashMap 是怎么实现的。墙裂建议,在本文之前了解一下多线程的基本知识,如JMM内存模型,volatile关键字作用,CAS和自旋,ReentranLock重入锁。

底层存储结构

在 JDK1.7中,本质上还是采用链表+数组的形式存储键值对的。但是,为了提高并发,把原来的整个 table 划分为 n 个 Segment 。所以,从整体来看,它是一个由 Segment 组成的数组。然后,每个 Segment 里边是由 HashEntry 组成的数组,每个 HashEntry之间又可以形成链表。我们可以把每个 Segment 看成是一个小的 HashMap,其内部结构和 HashMap 是一模一样的。

当对某个 Segment 加锁时,如图中 Segment2,并不会影响到其他 Segment 的读写。每个 Segment 内部自己操作自己的数据。这样一来,我们要做的就是尽可能的让元素均匀的分布在不同的 Segment中。最理想的状态是,所有执行的线程操作的元素都是不同的 Segment,这样就可以降低锁的竞争。

废话了这么多,还是来看底层源码吧,因为所有的思想都在代码里体现。借用 Linus的一句话,“No BB . Show me the code ” (改编版,哈哈)

常用变量

先看下 1.7 中常用的变量和内部类都有哪些,这有助于我们了解 ConcurrentHashMap 的整体结构。

//默认初始化容量,这个和 HashMap中的容量是一个概念,表示的是整个 Map的容量
static final int DEFAULT_INITIAL_CAPACITY = 16;

//默认加载因子
static final float DEFAULT_LOAD_FACTOR = 0.75f;

//默认的并发级别,这个参数决定了 Segment 数组的长度
static final int DEFAULT_CONCURRENCY_LEVEL = 16;

//最大的容量
static final int MAXIMUM_CAPACITY = 1 << 30;

//每个Segment中table数组的最小长度为2,且必须是2的n次幂。
//由于每个Segment是懒加载的,用的时候才会初始化,因此为了避免使用时立即调整大小,设定了最小容量2
static final int MIN_SEGMENT_TABLE_CAPACITY = 2;

//用于限制Segment数量的最大值,必须是2的n次幂
static final int MAX_SEGMENTS = 1 << 16; // slightly conservative

//在size方法和containsValue方法,会优先采用乐观的方式不加锁,直到重试次数达到2,才会对所有Segment加锁
//这个值的设定,是为了避免无限次的重试。后边size方法会详讲怎么实现乐观机制的。
static final int RETRIES_BEFORE_LOCK = 2;

//segment掩码值,用于根据元素的hash值定位所在的 Segment 下标。后边会细讲
final int segmentMask;

//和 segmentMask 配合使用来定位 Segment 的数组下标,后边讲。
final int segmentShift;

// Segment 组成的数组,每一个 Segment 都可以看做是一个特殊的 HashMap
final Segment<K,V>[] segments;

//Segment 对象,继承自 ReentrantLock 可重入锁。
//其内部的属性和方法和 HashMap 神似,只是多了一些拓展功能。
static final class Segment<K,V> extends ReentrantLock implements Serializable {
	
	//这是在 scanAndLockForPut 方法中用到的一个参数,用于计算最大重试次数
	//获取当前可用的处理器的数量,若大于1,则返回64,否则返回1。
	static final int MAX_SCAN_RETRIES =
		Runtime.getRuntime().availableProcessors() > 1 ? 64 : 1;

	//用于表示每个Segment中的 table,是一个用HashEntry组成的数组。
	transient volatile HashEntry<K,V>[] table;

	//Segment中的元素个数,每个Segment单独计数(下边的几个参数同样的都是单独计数)
	transient int count;

	//每次 table 结构修改时,如put,remove等,此变量都会自增
	transient int modCount;

	//当前Segment扩容的阈值,同HashMap计算方法一样也是容量乘以加载因子
	//需要知道的是,每个Segment都是单独处理扩容的,互相之间不会产生影响
	transient int threshold;

	//加载因子
	final float loadFactor;

	//Segment构造函数
	Segment(float lf, int threshold, HashEntry<K,V>[] tab) {
		this.loadFactor = lf;
		this.threshold = threshold;
		this.table = tab;
	}
	
	...
	// put(),remove(),rehash() 方法都在此类定义
}

// HashEntry,存在于每个Segment中,它就类似于HashMap中的Node,用于存储键值对的具体数据和维护单向链表的关系
static final class HashEntry<K,V> {
	//每个key通过哈希运算后的结果,用的是 Wang/Jenkins hash 的变种算法,此处不细讲,感兴趣的可自行查阅相关资料
	final int hash;
	final K key;
	//value和next都用 volatile 修饰,用于保证内存可见性和禁止指令重排序
	volatile V value;
	//指向下一个节点
	volatile HashEntry<K,V> next;

	HashEntry(int hash, K key, V value, HashEntry<K,V> next) {
		this.hash = hash;
		this.key = key;
		this.value = value;
		this.next = next;
	}
}

构造函数

ConcurrentHashMap 有五种构造函数,但是最终都会调用同一个构造函数,所以只需要搞明白这一个核心的构造函数就可以了。

PS: 文章注释中 (1)(2)(3) 等序号都是用来方便做标记,不是计算值

public ConcurrentHashMap(int initialCapacity,
                             float loadFactor, int concurrencyLevel) {
	//检验参数是否合法。值得说的是,并发级别一定要大于0,否则就没办法实现分段锁了。
	if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
		throw new IllegalArgumentException();
	//并发级别不能超过最大值
	if (concurrencyLevel > MAX_SEGMENTS)
		concurrencyLevel = MAX_SEGMENTS;
	// Find power-of-two sizes best matching arguments
	//偏移量,是为了对hash值做位移操作,计算元素所在的Segment下标,put方法详讲
	int sshift = 0;
	//用于设定最终Segment数组的长度,必须是2的n次幂
	int ssize = 1;
	//这里就是计算 sshift 和 ssize 值的过程  (1) 
	while (ssize < concurrencyLevel) {
		++sshift;
		ssize <<= 1;
	}
	this.segmentShift = 32 - sshift;
	//Segment的掩码
	this.segmentMask = ssize - 1;
	if (initialCapacity > MAXIMUM_CAPACITY)
		initialCapacity = MAXIMUM_CAPACITY;
	//c用于辅助计算cap的值   (2)
	int c = initialCapacity / ssize;
	if (c * ssize < initialCapacity)
		++c;
	// cap 用于确定某个Segment的容量,即Segment中HashEntry数组的长度
	int cap = MIN_SEGMENT_TABLE_CAPACITY;
	//(3)
	while (cap < c)
		cap <<= 1;
	// create segments and segments[0]
	//这里用 loadFactor做为加载因子,cap乘以加载因子作为扩容阈值,创建长度为cap的HashEntry数组,
	//三个参数,创建一个Segment对象,保存到S0对象中。后边在 ensureSegment 方法会用到S0作为原型对象去创建对应的Segment。
	Segment<K,V> s0 =
		new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
						 (HashEntry<K,V>[])new HashEntry[cap]);
	//创建出长度为 ssize 的一个 Segment数组
	Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
	//把S0存到Segment数组中去。在这里,我们就可以发现,此时只是创建了一个Segment数组,
	//但是并没有把数组中的每个Segment对象创建出来,仅仅创建了一个Segment用来作为原型对象。
	UNSAFE.putOrderedObject(ss, SBASE, s0); // ordered write of segments[0]
	this.segments = ss;
}				

上边的注释中留了 (1)(2)(3) 三个地方还没有细说。我们现在假设一组数据,把涉及到的几个变量计算出来,就能明白这些参数的含义了。

//假设调用了默认构造,都用的是默认参数,即 initialCapacity 和 concurrencyLevel 都是16
//(1)  sshift 和 ssize 值的计算过程为,每次循环,都会把 sshift 自增1,并且 ssize 左移一位,即乘以2,
//直到 ssize 的值大于等于 concurrencyLevel 的值 16。
sshfit=0,1,2,3,4
ssize=1,2,4,8,16
//可以看到,初始他们的值分别是0和1,最终结果是4和16
//sshfit是为了辅助计算segmentShift值,ssize是为了确定Segment数组长度。
//(2)  此时,计算c的值,
c = 16/16 = 1;
//判断 c * 16 < 16 是否为真,真的话 c 自增1,此处为false,因此 c的值为1不变。
//(3)  此时,由于c为1, cap为2 ,因此判断 cap < c 为false,最终cap为2。
//总结一下,以上三个步骤,最终都是为了确定以下几个关键参数的值,
//确定 segmentShift ,这个用于后边计算hash值的偏移量,此处即为 32-4=28,
//确定 ssize,必须是一个大于等于 concurrencyLevel 的一个2的n次幂值
//确定 cap,必须是一个大于等于2的一个2的n次幂值
//感兴趣的小伙伴,还可以用另外几组参数来计算上边的参数值,可以加深理解参数的含义。
//例如initialCapacity和concurrencyLevel分别传入10和5,或者传入33和16

put()方法

put 方法的总体流程是,

  1. 通过哈希算法计算出当前 key 的 hash 值
  2. 通过这个 hash 值找到它所对应的 Segment 数组的下标
  3. 再通过 hash 值计算出它在对应 Segment 的 HashEntry数组 的下标
  4. 找到合适的位置插入元素
//这是Map的put方法
public V put(K key, V value) {
	Segment<K,V> s;
	//不支持value为空
	if (value == null)
		throw new NullPointerException();
	//通过 Wang/Jenkins 算法的一个变种算法,计算出当前key对应的hash值
	int hash = hash(key);
	//上边我们计算出的 segmentShift为28,因此hash值右移28位,说明此时用的是hash的高4位,
	//然后把它和掩码15进行与运算,得到的值一定是一个 0000 ~ 1111 范围内的值,即 0~15 。
	int j = (hash >>> segmentShift) & segmentMask;
	//这里是用Unsafe类的原子操作找到Segment数组中j下标的 Segment 对象
	if ((s = (Segment<K,V>)UNSAFE.getObject          // nonvolatile; recheck
		 (segments, (j << SSHIFT) + SBASE)) == null) //  in ensureSegment
		//初始化j下标的Segment
		s = ensureSegment(j);
	//在此Segment中添加元素
	return s.put(key, hash, value, false);
}

上边有一个这样的方法, UNSAFE.getObject (segments, (j << SSHIFT) + SBASE。它是为了通过Unsafe这个类,找到 j 最新的实际值。这个计算 (j << SSHIFT) + SBASE ,在后边非常常见,我们只需要知道它代表的是 j 的一个偏移量,通过偏移量,就可以得到 j 的实际值。可以类比,AQS 中的 CAS 操作。 Unsafe中的操作,都需要一个偏移量,看下图,

(j << SSHIFT) + SBASE 就相当于图中的 stateOffset偏移量。只不过图中是 CAS 设置新值,而我们这里是取 j 的最新值。 后边很多这样的计算方式,就不赘述了。接着看 s.put 方法,这才是最终确定元素位置的方法。

//Segment中的 put 方法
final V put(K key, int hash, V value, boolean onlyIfAbsent) {
	//这里通过tryLock尝试加锁,如果加锁成功,返回null,否则执行 scanAndLockForPut方法
	//这里说明一下,tryLock 和 lock 是 ReentrantLock 中的方法,
	//区别是 tryLock 不会阻塞,抢锁成功就返回true,失败就立马返回false,
	//而 lock 方法是,抢锁成功则返回,失败则会进入同步队列,阻塞等待获取锁。
	HashEntry<K,V> node = tryLock() ? null :
		scanAndLockForPut(key, hash, value);
	V oldValue;
	try {
		//当前Segment的table数组
		HashEntry<K,V>[] tab = table;
		//这里就是通过hash值,与tab数组长度取模,找到其所在HashEntry数组的下标
		int index = (tab.length - 1) & hash;
		//当前下标位置的第一个HashEntry节点
		HashEntry<K,V> first = entryAt(tab, index);
		for (HashEntry<K,V> e = first;;) {
			//如果第一个节点不为空
			if (e != null) {
				K k;
				//并且第一个节点,就是要插入的节点,则替换value值,否则继续向后查找
				if ((k = e.key) == key ||
					(e.hash == hash && key.equals(k))) {
					//替换旧值
					oldValue = e.value;
					if (!onlyIfAbsent) {
						e.value = value;
						++modCount;
					}
					break;
				}
				e = e.next;
			}
			//说明当前index位置不存在任何节点,此时first为null,
			//或者当前index存在一条链表,并且已经遍历完了还没找到相等的key,此时first就是链表第一个元素
			else {
				//如果node不为空,则直接头插
				if (node != null)
					node.setNext(first);
				//否则,创建一个新的node,并头插
				else
					node = new HashEntry<K,V>(hash, key, value, first);
				int c = count + 1;
				//如果当前Segment中的元素大于阈值,并且tab长度没有超过容量最大值,则扩容
				if (c > threshold && tab.length < MAXIMUM_CAPACITY)
					rehash(node);
				//否则,就把当前node设置为index下标位置新的头结点
				else
					setEntryAt(tab, index, node);
				++modCount;
				//更新count值
				count = c;
				//这种情况说明旧值肯定为空
				oldValue = null;
				break;
			}
		}
	} finally {
		//需要注意ReentrantLock必须手动解锁
		unlock();
	}
	//返回旧值
	return oldValue;
}

这里说明一下计算 Segment 数组下标和计算 HashEntry 数组下标的不同点:

//下边的hash值是通过哈希运算后的hash值,不是hashCode
//计算 Segment 下标
 (hash >>> segmentShift) & segmentMask 
 //计算 HashEntry 数组下标
 (tab.length - 1) & hash

思考一下,为什么它们的算法不一样呢? 计算 Segment 数组下标是用的 hash值高几位(这里以高 4 位为例)和掩码做与运算,而计算 HashEntry 数组下标是直接用的 hash 值和数组长度减1做与运算。

我的理解是,这是为了尽量避免当前 hash 值计算出来的 Segment 数组下标和计算出来的 HashEntry 数组下标趋于相同。简单说,就是为了避免分配到同一个 Segment 中的元素扎堆现象,即避免它们都被分配到同一条链表上,导致链表过长。同时,也是为了减少并发。下面做一个运算,帮助理解一下(假设不用高 4 位运算,而是正常情况都用低位做运算)。

//我们以并发级别16,HashEntry数组容量 4 为例,则它们参与运算的掩码分别为 15 和 3
//hash值
0110 1101 0110 1111 0110 1110 0010 0010
//segmentMask = 15   ,标记为 (1)
0000 0000 0000 0000 0000 0000 0000 1111
//tab.length - 1 = 3     ,标记为 (2)
0000 0000 0000 0000 0000 0000 0000 0011
//用 hash 分别和 15 ,3 做与运算,会发现得到的结果是一样,都是十进制 2.
//这表明,当前 hash值被分配到下标为 2 的 Segment 中,同时,被分配到下标为 2 的 HashEntry 数组中
//现在若有另外一个 hash 值 h2,和第一个hash值,高位不同,但是低4位相同,
1010 1101 0110 1111 0110 1110 0010 0010
//我们会发现,最后它也会被分配到下标为 2 的 Segment 和 HashEntry 数组,就会和第一个元素形成链表。
//所以,为了避免这种扎堆现象,让元素尽量均匀分配,就让 hash 的高 4 位和 (1)处做与 运算,而用低位和 (2)处做与运算
//这样计算后,它们所在的Segment下标分别为 6(0110), 10(1010),即使它们在HashEntry数组中的下标都为 2(0010),也无所谓
//因为它们并不在一个 Segment 中,也就不会在同一个 HashEntry 数组中,更不会形成链表。
//更重要的是,它们不会有并发,因为在各自不同的 Segment 自己操作自己的加锁解锁,互不影响

可能有的小伙伴就会打岔了,那如果两个 hash 值,低位和高位都相同,怎么办呢。如果是这样,我只能说,这个 hash 算法也太烂了吧。(这里的 hash 算法也会尽量避免这种情况,当然只是减少几率,并不能杜绝)

我有个大胆的想法,这里的高低位不同的计算方式,是不是后边 1.8 HashMap 让 hash 高低位做异或运算的引子呢?不得而知。。

put 方法比较简单,只要能看懂 HashMap 中的 put 方法,这里也没问题。主要是它调用的子方法比较复杂,下边一个一个讲解。

ensureSegment()方法

回到 Map的 put 方法,判断 j 下标的 Segment为空后,则需要调用此方法,初始化一个 Segment 对象,以确保拿到的对象一定是不为空的,否则无法执行s.put了。

//k为 (hash >>> segmentShift) & segmentMask 算法计算出来的值
private Segment<K,V> ensureSegment(int k) {
	final Segment<K,V>[] ss = this.segments;
	//u代表 k 的偏移量,用于通过 UNSAFE 获取主内存最新的实际 K 值
	long u = (k << SSHIFT) + SBASE; // raw offset
	Segment<K,V> seg;
	//从内存中取到最新的下标位置的 Segment 对象,判断是否为空,(1)
	if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
		//之前构造函数说了,s0是作为一个原型对象,用于创建新的 Segment 对象
		Segment<K,V> proto = ss[0]; // use segment 0 as prototype
		//容量
		int cap = proto.table.length;
		//加载因子
		float lf = proto.loadFactor;
		//扩容阈值
		int threshold = (int)(cap * lf);
		//把 Segment 对应的 HashEntry 数组先创建出来
		HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
		//再次检查 K 下标位置的 Segment 是否为空, (2)
		if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
			== null) { // recheck
			//此处把 Segment 对象创建出来,并赋值给 s,
			Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
			//循环检查 K 下标位置的 Segment 是否为空, (3)
			//若不为空,则说明有其它线程抢先创建成功,并且已经成功同步到主内存中了,
			//则把它取出来,并返回
			while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
				   == null) {
				//CAS,若当前下标的Segment对象为空,就把它替换为最新创建出来的 s 对象。
				//若成功,就跳出循环,否则,就一直自旋直到成功,或者 seg 不为空(其他线程成功导致)。
				if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
					break;
			}
		}
	}
	return seg;
}

可以发现,我标注了上边 (1)(2)(3) 个地方,每次都判断最新的Segment是否为空。可能有的小伙伴就会迷惑,为什么做这么多次判断,我直接去自旋不就好了,反正最后都要自旋的。

我的理解是,在多线程环境下,因为不确定是什么时候会有其它线程 CAS 成功,有可能发生在以上的任意时刻。所以,只要发现一旦内存中的对象已经存在了,则说明已经有其它线程把Segment对象创建好,并CAS成功同步到主内存了。此时,就可以直接返回,而不需要往下执行了。这样做,是为了代码执行效率考虑。

scanAndLockForPut()方法

put 方法第一步抢锁失败之后,就会执行此方法,

private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
	//根据hash值定位到它对应的HashEntry数组的下标位置,并找到链表的第一个节点
	//注意,这个操作会从主内存中获取到最新的状态,以确保获取到的first是最新值
	HashEntry<K,V> first = entryForHash(this, hash);
	HashEntry<K,V> e = first;
	HashEntry<K,V> node = null;
	//重试次数,初始化为 -1
	int retries = -1; // negative while locating node
	//若抢锁失败,就一直循环,直到成功获取到锁。有三种情况
	while (!tryLock()) {
		HashEntry<K,V> f; // to recheck first below
		//1.若 retries 小于0,
		if (retries < 0) {
			if (e == null) {
				//若 e 节点和 node 都为空,则创建一个 node 节点。这里只是预测性的创建一个node节点
				if (node == null) // speculatively create node
					node = new HashEntry<K,V>(hash, key, value, null);
				retries = 0;
			}
			//如当前遍历到的 e 节点不为空,则判断它的key是否等于传进来的key,若是则把 retries 设为0
			else if (key.equals(e.key))
				retries = 0;
			//否则,继续向后遍历节点
			else
				e = e.next;
		}
		//2.若是重试次数超过了最大尝试次数,则调用lock方法加锁。表明不再重试,我下定决心了一定要获取到锁。
		//要么当前线程可以获取到锁,要么获取不到就去排队等待获取锁。获取成功后,再 break。
		else if (++retries > MAX_SCAN_RETRIES) {
			lock();
			break;
		}
		//3.若 retries 的值为偶数,并且从内存中再次获取到最新的头节点,判断若不等于first
		//则说明有其他线程修改了当前下标位置的头结点,于是需要更新头结点信息。
		else if ((retries & 1) == 0 &&
				 (f = entryForHash(this, hash)) != first) {
			//更新头结点信息,并把重试次数重置为 -1,继续下一次循环,从最新的头结点遍历当前链表。
			e = first = f; // re-traverse if entry changed
			retries = -1;
		}
	}
	return node;
}

这个方法逻辑比较复杂,会一直循环尝试获取锁,若获取成功,则返回。否则的话,每次循环时,都会同时遍历当前链表。若遍历完了一次,还没找到和key相等的节点,就会预先创建一个节点。注意,这里只是预测性的创建一个新节点,也有可能在这之前,就已经获取锁成功了。

同时,当重试次每偶数次时,就会检查一次当前最新的头结点是否被改变。因为若有变化的话,还需要从最新的头结点开始遍历链表。

还有一种情况,就是循环次数达到了最大限制,则停止循环,用阻塞的方式去获取锁。这时,也就停止了遍历链表的动作,当前线程也不会再做其他预热(warm up)的事情。

关于为什么预测性的创建新节点,源码中原话是这样的:

Since traversal speed doesn’t matter, we might as well help warm up the associated code and accesses as well.

解释一下就是,因为遍历速度无所谓,所以,我们可以预先(warm up)做一些相关联代码的准备工作。这里相关联代码,指的就是循环中,在获取锁成功或者调用 lock 方法之前做的这些事情,当然也包括创建新节点。

在put 方法中可以看到,有一句是判断 node 是否为空,若创建了,就直接头插。否则的话,它也会自己创建这个新节点。

scanAndLockForPut 这个方法可以确保返回时,当前线程一定是获取到锁的状态。

rehash()方法

当 put 方法时,发现元素个数超过了阈值,则会扩容。需要注意的是,每个Segment只管它自己的扩容,互相之间并不影响。换句话说,可以出现这个 Segment的长度为2,另一个Segment的长度为4的情况(只要是2的n次幂)。


//node为创建的新节点
private void rehash(HashEntry<K,V> node) {
	//当前Segment中的旧表
	HashEntry<K,V>[] oldTable = table;
	//旧的容量
	int oldCapacity = oldTable.length;
	//新容量为旧容量的2倍
	int newCapacity = oldCapacity << 1;
	//更新新的阈值
	threshold = (int)(newCapacity * loadFactor);
	//用新的容量创建一个新的 HashEntry 数组
	HashEntry<K,V>[] newTable =
		(HashEntry<K,V>[]) new HashEntry[newCapacity];
	//当前的掩码,用于计算节点在新数组中的下标
	int sizeMask = newCapacity - 1;
	//遍历旧表
	for (int i = 0; i < oldCapacity ; i++) {
		HashEntry<K,V> e = oldTable[i];
		//如果e不为空,说明当前链表不为空
		if (e != null) {
			HashEntry<K,V> next = e.next;
			//计算hash值再新数组中的下标位置
			int idx = e.hash & sizeMask;
			//如果e不为空,且它的下一个节点为空,则说明这条链表只有一个节点,
			//直接把这个节点放到新数组的对应下标位置即可
			if (next == null)   //  Single node on list
				newTable[idx] = e;
			//否则,处理当前链表的节点迁移操作
			else { // Reuse consecutive sequence at same slot
				//记录上一次遍历到的节点
				HashEntry<K,V> lastRun = e;
				//对应上一次遍历到的节点在新数组中的新下标
				int lastIdx = idx;
				for (HashEntry<K,V> last = next;
					 last != null;
					 last = last.next) {
					//计算当前遍历到的节点的新下标
					int k = last.hash & sizeMask;
					//若 k 不等于 lastIdx,则说明此次遍历到的节点和上次遍历到的节点不在同一个下标位置
					//需要把 lastRun 和 lastIdx 更新为当前遍历到的节点和下标值。
					//若相同,则不处理,继续下一次 for 循环。
					if (k != lastIdx) {
						lastIdx = k;
						lastRun = last;
					}
				}
				//把和 lastRun 节点的下标位置相同的链表最末尾的几个连续的节点放到新数组的对应下标位置
				newTable[lastIdx] = lastRun;
				//再把剩余的节点,复制到新数组
				//从旧数组的头结点开始遍历,直到 lastRun 节点,因为 lastRun节点后边的节点都已经迁移完成了。
				for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
					V v = p.value;
					int h = p.hash;
					int k = h & sizeMask;
					HashEntry<K,V> n = newTable[k];
					//用的是复制节点信息的方式,并不是把原来的节点直接迁移,区别于lastRun处理方式
					newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
				}
			}
		}
	}
	//所有节点都迁移完成之后,再处理传进来的新的node节点,把它头插到对应的下标位置
	int nodeIndex = node.hash & sizeMask; // add the new node
	//头插node节点
	node.setNext(newTable[nodeIndex]);
	newTable[nodeIndex] = node;
	//更新当前Segment的table信息
	table = newTable;
}

上边的迁移过程和 lastRun 和 lastIdx 变量可能不太好理解,我画个图就明白了。以其中一条链表处理方式为例。

从头结点开始向后遍历,找到当前链表的最后几个下标相同的连续的节点。如上图,虽然开头出现了有两个节点的下标都是 k2, 但是中间出现一个不同的下标 k1,打断了下标连续相同,因此从下一个k2,又重新开始算。好在后边三个连续的节点下标都是相同的,因此倒数第三个节点被标记为 lastRun,且变量无变化。

从lastRun节点到尾结点的这部分就可以整体迁移到新数组的对应下标位置了,因为它们的下标都是相同的,可以这样统一处理。

另外从头结点到 lastRun 之前的节点,无法统一处理,只能一个一个去复制了。且注意,这里不是直接迁移,而是复制节点到新的数组,旧的节点会在不久的将来,因为没有引用指向,被 JVM 垃圾回收处理掉。

(不知道为啥这个方法名起为 rehash,其实扩容时 hash 值并没有重新计算,变化的只是它们所在的下标而已。我猜测,可能是,借用了 1.7 HashMap 中的说法吧。。。)

get()

put 方法搞明白了之后,其实 get 方法就很好理解了。也是先定位到 Segment,然后再定位到 HashEntry 。

public V get(Object key) {
	Segment<K,V> s; // manually integrate access methods to reduce overhead
	HashEntry<K,V>[] tab;
	//计算hash值
	int h = hash(key);
	//同样的先定位到 key 所在的Segment ,然后从主内存中取出最新的节点
	long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
	if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
		(tab = s.table) != null) {
		//若Segment不为空,且链表也不为空,则遍历查找节点
		for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
				 (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
			 e != null; e = e.next) {
			K k;
			//找到则返回它的 value 值,否则返回 null
			if ((k = e.key) == key || (e.hash == h && key.equals(k)))
				return e.value;
		}
	}
	return null;
}

remove()

remove 方法和 put 方法类似,也不用做过多特殊的介绍,

public V remove(Object key) {
	int hash = hash(key);
	//定位到Segment
	Segment<K,V> s = segmentForHash(hash);
	//若 s为空,则返回 null,否则执行 remove
	return s == null ? null : s.remove(key, hash, null);
}

public boolean remove(Object key, Object value) {
	int hash = hash(key);
	Segment<K,V> s;
	return value != null && (s = segmentForHash(hash)) != null &&
		s.remove(key, hash, value) != null;
}

final V remove(Object key, int hash, Object value) {
	//尝试加锁,若失败,则执行 scanAndLock ,此方法和 scanAndLockForPut 方法类似
	if (!tryLock())
		scanAndLock(key, hash);
	V oldValue = null;
	try {
		HashEntry<K,V>[] tab = table;
		int index = (tab.length - 1) & hash;
		//从主内存中获取对应 table 的最新的头结点
		HashEntry<K,V> e = entryAt(tab, index);
		HashEntry<K,V> pred = null;
		while (e != null) {
			K k;
			HashEntry<K,V> next = e.next;
			//匹配到 key
			if ((k = e.key) == key ||
				(e.hash == hash && key.equals(k))) {
				V v = e.value;
				// value 为空,或者 value 也匹配成功
				if (value == null || value == v || value.equals(v)) {
					if (pred == null)
						setEntryAt(tab, index, next);
					else
						pred.setNext(next);
					++modCount;
					--count;
					oldValue = v;
				}
				break;
			}
			pred = e;
			e = next;
		}
	} finally {
		unlock();
	}
	return oldValue;
}

size()

size 方法需要重点说明一下。爱思考的小伙伴可能就会想到,并发情况下,有可能在统计期间,数组元素个数不停的变化,而且,整个表还被分成了 N个 Segment,怎样统计才能保证结果的准确性呢? 我们一起来看下吧。

public int size() {
	// Try a few times to get accurate count. On failure due to
	// continuous async changes in table, resort to locking.
	//segment数组
	final Segment<K,V>[] segments = this.segments;
	//统计所有Segment中元素的总个数
	int size;
	//如果size大小超过32位,则标记为溢出为true
	boolean overflow; 
	//统计每个Segment中的 modcount 之和
	long sum;         
	//上次记录的 sum 值
	long last = 0L;   
	//重试次数,初始化为 -1
	int retries = -1; 
	try {
		for (;;) {
			//如果超过重试次数,则不再重试,而是把所有Segment都加锁,再统计 size
			if (retries++ == RETRIES_BEFORE_LOCK) {
				for (int j = 0; j < segments.length; ++j)
					//强制加锁
					ensureSegment(j).lock(); // force creation
			}
			sum = 0L;
			size = 0;
			overflow = false;
			//遍历所有Segment
			for (int j = 0; j < segments.length; ++j) {
				Segment<K,V> seg = segmentAt(segments, j);
				//若当前遍历到的Segment不为空,则统计它的 modCount 和 count 元素个数
				if (seg != null) {
					//累加当前Segment的结构修改次数,如put,remove等操作都会影响modCount
					sum += seg.modCount;
					int c = seg.count;
					//若当前Segment的元素个数 c 小于0 或者 size 加上 c 的结果小于0,则认为溢出
					//因为若超过了 int 最大值,就会返回负数
					if (c < 0 || (size += c) < 0)
						overflow = true;
				}
			}
			//当此次尝试,统计的 sum 值和上次统计的值相同,则说明这段时间内,
			//并没有任何一个 Segment 的结构发生改变,就可以返回最后的统计结果
			if (sum == last)
				break;
			//不相等,则说明有 Segment 结构发生了改变,则记录最新的结构变化次数之和 sum,
			//并赋值给 last,用于下次重试的比较。
			last = sum;
		}
	} finally {
		//如果超过了指定重试次数,则说明表中的所有Segment都被加锁了,因此需要把它们都解锁
		if (retries > RETRIES_BEFORE_LOCK) {
			for (int j = 0; j < segments.length; ++j)
				segmentAt(segments, j).unlock();
		}
	}
	//若结果溢出,则返回 int 最大值,否则正常返回 size 值 
	return overflow ? Integer.MAX_VALUE : size;
}

其实源码中前两行的注释也说的非常清楚了。我们先采用乐观的方式,认为在统计 size 的过程中,并没有发生 put, remove 等会改变 Segment 结构的操作。 但是,如果发生了,就需要重试。如果重试2次都不成功(执行三次,第一次不能叫做重试),就只能强制把所有 Segment 都加锁之后,再统计了,以此来得到准确的结果。

版权声明:本文为starry-skys原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://www.cnblogs.com/starry-skys/p/12742500.html