深入学习ConcurrentHashMap

近期深入学习了ConcurrentHashMap,便整理成一篇博文记录一下,请注意:此博文针对的是JDK1.6,因此如果你看到的源码跟我文中的不同,则可能是由于版本不一样。

ConcurrentHashMap的锁分段技术

HashTable容器在竞争激烈的并发环境下表现出效率低下的原因,是因为所有访问HashTable的线程必须竞争同一把锁。如果容器里有多把锁,每一把锁用于锁容器的其中一部分数据,那么当多线程访问容器里不同数据段的数据时,线程间就不会存在锁竞争,从而可以有效的提高并发访问效率,这就是ConcurrentHashMap所使用的锁分段技术。首先将数据分成一段一段的存储,然后给每一段数据配一把锁,当一个线程占用锁访问其中一个段数据的时候,其他段的数据也能被其他线程访问。

ConcurrentHashMap的结构

我们通过ConcurrentHashMap的类图来分析ConcurrentHashMap的结构。

ConcurrentHashMap是由Segment数组结构和HashEntry数组结构组成。Segment是一种可重入锁ReentrantLock,在ConcurrentHashMap里扮演锁的角色,HashEntry则用于存储键值对数据。一个ConcurrentHashMap里包含一个Segment数组,Segment的结构和HashMap类似,是一种数组和链表结构,一个Segment里包含一个HashEntry数组,每个HashEntry是一个链表结构的元素,每个Segment守护着一个HashEntry数组里的元素,当对HashEntry数组的数据进行修改时,必须首先获得它对应的Segment锁。

ConcurrentHashMap方法源码解读

请注意,如果一个方法中我贴了几段代码,那么一般是:第一段代码为方法的入口,其他的为被入口方法调用过的方法。

初始化方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
public ConcurrentHashMap(int initialCapacity,
float loadFactor, int concurrencyLevel)
{

if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
throw new IllegalArgumentException();

if (concurrencyLevel > MAX_SEGMENTS)
concurrencyLevel = MAX_SEGMENTS;

// Find power-of-two sizes best matching arguments
int sshift = 0;
int ssize = 1;
while (ssize < concurrencyLevel) {
++sshift;
ssize <<= 1;
}
segmentShift = 32 - sshift;
segmentMask = ssize - 1;
this.segments = Segment.newArray(ssize);

if (initialCapacity > MAXIMUM_CAPACITY)
initialCapacity = MAXIMUM_CAPACITY;
int c = initialCapacity / ssize;
if (c * ssize < initialCapacity)
++c;
int cap = 1;
while (cap < c)
cap <<= 1;

for (int i = 0; i < this.segments.length; ++i)
this.segments[i] = new Segment<K,V>(cap, loadFactor);
}

代码中的第一个while循环是用来计算segments数组的大小ssize(必须为2的N次方)。segmentShift和segmentMask是用来定位当前元素在哪个segment,前者用于移位,后者用于进行位与运算。第二个while循环是用来计算每个segment中HashEntry数组的大小cap(必须为2的N次方),最后对segments数组进行初始化。

1
2
3
4
Segment(int initialCapacity, float lf) {
loadFactor = lf;
setTable(HashEntry.<K,V>newArray(initialCapacity));
}

1
2
3
4
void setTable(HashEntry<K,V>[] newTable) {
threshold = (int)(newTable.length * loadFactor);
table = newTable;
}
1
2
3
static final <K,V> HashEntry<K,V>[] newArray(int i) {
return new HashEntry[i];
}

对segments数组进行初始化的同时,也对segment类里面的HashEntry进行初始化,并给loadFactor和threshold赋值。

get方法

1
2
3
4
public V get(Object key) {
int hash = hash(key.hashCode());
return segmentFor(hash).get(key, hash);
}

根据key的hashcode重新计算hash值(主要是为了减少hash冲突),通过segmentFor方法定位到具体的哪个segment,然后调用segment的get方法。

1
2
3
final Segment<K,V> segmentFor(int hash) {
return segments[(hash >>> segmentShift) & segmentMask];
}

segmentFor方法是用来定位到具体的segment的,主要是通过使用hash值的高位与掩码进行位运算,segmentShift和segmentMask是通过上文初始化方法计算而来。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
V get(Object key, int hash) {
if (count != 0) { // read-volatile
HashEntry<K,V> e = getFirst(hash);
while (e != null) {
if (e.hash == hash && key.equals(e.key)) {
V v = e.value;
if (v != null)
return v;
return readValueUnderLock(e); // recheck
}
e = e.next;
}
}
return null;
}

1
2
3
4
HashEntry<K,V> getFirst(int hash) {
HashEntry<K,V>[] tab = table;
return tab[hash & (tab.length - 1)];
}

根据hash值跟数组长度-1进行位与运算,定位到具体的HashEntry(getFirst方法),遍历该HashEntry链表,找到链表中某个元素的hash值与传入的hash值相同并且使用equals方法比较key相同的元素,如果该元素的value不为空,返回value值;如果为空,则尝试在加锁的情况下再读一次。
get操作的高效之处在于整个get过程不需要加锁,除非读到的值是空的才会加锁重读,我们知道HashTable容器的get方法是需要加锁的,那么ConcurrentHashMap的get操作是如何做到不加锁的呢?原因是它的get方法里将要使用的共享变量都定义成volatile,如用于统计当前Segement大小的count字段和用于存储值的HashEntry的value,定义成volatile的变量,能够在线程之间保持可见性,能够被多线程同时读,并且保证不会读到过期的值,但是只能被单线程写(有一种情况可以被多线程写,就是写入的值不依赖于原值),在get操作里只需要读不需要写共享变量count和value,所以可以不用加锁。之所以不会读到过期的值,是根据Java内存模型的happen before原则,对volatile字段的写入操作先于读操作,即使两个线程同时修改和获取volatile变量,get操作也能拿到最新的值,这是用volatile替换锁的经典应用场景。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
/**
* Reads value field of an entry under lock. Called if value
* field ever appears to be null. This is possible only if a
* compiler happens to reorder a HashEntry initialization with
* its table assignment, which is legal under memory model
* but is not known to ever occur.
*/

V readValueUnderLock(HashEntry<K,V> e) {
lock();
try {
return e.value;
} finally {
unlock();
}
}

readValueUnderLock:在有锁的状态下再读一次。这似乎有些费解,理论上结点的值不可能为空,这是因为put的时候就进行了判断,如果为空就要抛NullPointerException。空值的唯一源头就是HashEntry中的默认值,因为HashEntry中的value不是final的,非同步读取有可能读取到空值。仔细看下put操作的语句:tab[index] = new HashEntry(key, hash, first, value),在这条语句中,HashEntry构造函数中对value的赋值以及对tab[index]的赋值可能被重新排序(方法上面的一大段注释有提到,英语好的可以直接读注释),这就可能导致结点的值为空。这里当value为空时,可能是一个线程正在改变节点,而之前的get操作都未进行锁定,根据bernstein条件,读后写或写后读都会引起数据的不一致,所以这里要对这个e重新上锁再读一遍,以保证得到的是正确值。

put方法

1
2
3
4
5
6
public V put(K key, V value) {
if (value == null)
throw new NullPointerException();
int hash = hash(key.hashCode());
return segmentFor(hash).put(key, hash, value, false);
}

根据key的hashcode重新计算hash值(跟get方法一样),通过segmentFor方法定位到具体的哪个segment,然后调用segment的put方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
V put(K key, int hash, V value, boolean onlyIfAbsent) {
lock();
try {
int c = count;
if (c++ > threshold) // ensure capacity
rehash();
HashEntry<K,V>[] tab = table;
int index = hash & (tab.length - 1);
HashEntry<K,V> first = tab[index];
HashEntry<K,V> e = first;
while (e != null && (e.hash != hash || !key.equals(e.key)))
e = e.next;

V oldValue;
if (e != null) {
oldValue = e.value;
if (!onlyIfAbsent)
e.value = value;
}
else {
oldValue = null;
++modCount;
tab[index] = new HashEntry<K,V>(key, hash, first, value);
count = c; // write-volatile
}
return oldValue;
} finally {
unlock();
}
}

加锁进行以下操作:判断是否需要扩容,如果需要则调用rehash方法(下面有介绍)。根据hash值跟数组长度-1进行位与运算,定位到具体的HashEntry,遍历该HashEntry,根据传入的的key,使用equals方法找到需要的元素。如果能找到,则将该元素的value值覆盖为传入的value,否则将传入的key、value、hash值作为一个新元素放在该HashEntry的头部,最后进行解锁。入参中的onlyIfAbsent为true时,表示如果该key已经存在value值,则不会覆盖原value值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
void rehash() {
HashEntry<K,V>[] oldTable = table;
int oldCapacity = oldTable.length;
if (oldCapacity >= MAXIMUM_CAPACITY)
return;

/*
* Reclassify nodes in each list to new Map. Because we are
* using power-of-two expansion, the elements from each bin
* must either stay at same index, or move with a power of two
* offset. We eliminate unnecessary node creation by catching
* cases where old nodes can be reused because their next
* fields won't change. Statistically, at the default
* threshold, only about one-sixth of them need cloning when
* a table doubles. The nodes they replace will be garbage
* collectable as soon as they are no longer referenced by any
* reader thread that may be in the midst of traversing table
* right now.
*/


HashEntry<K,V>[] newTable = HashEntry.newArray(oldCapacity<<1); //新表扩容为原来大小的2倍
threshold = (int)(newTable.length * loadFactor); //重新计算阀值
int sizeMask = newTable.length - 1; //新表的掩码值还是为表长度-1
for (int i = 0; i < oldCapacity ; i++) {
// We need to guarantee that any existing reads of old Map can
// proceed. So we cannot yet null out each bin.
HashEntry<K,V> e = oldTable[i];

if (e != null) {
HashEntry<K,V> next = e.next; //元素e的下一个元素
int idx = e.hash & sizeMask; //计算元素e在新表中的索引位位置

// Single node on list
if (next == null) //如果当前位置只有一个元素,则直接移动到新表的对应位置
newTable[idx] = e;

else {
// Reuse trailing consecutive sequence at same slot
HashEntry<K,V> lastRun = e; //lastRun:最后一个需要处理的元素,初始值为元素e
int lastIdx = idx; //lastIdx:最后一个需要处理的元素的索引位置,初始值为元素e在新表中的索引值
for (HashEntry<K,V> last = next; //遍历该链表,找到最后一个需要处理的元素
last != null;
last = last.next) {
int k = last.hash & sizeMask;
if (k != lastIdx) { //如果当前元素的索引位置跟lastIdx不一致,则将lastIdx和lastRun替换成当前元素的相应值
lastIdx = k;
lastRun = last;
}
}
newTable[lastIdx] = lastRun; //将最后一个需要处理的元素放到新表中

// Clone all remaining nodes
for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {//遍历处理lastRun之前的所有元素
int k = p.hash & sizeMask; //计算当前遍历元素p在新表的索引k
HashEntry<K,V> n = newTable[k]; //取到新表中索引位置k的链表头元素赋值给n
newTable[k] = new HashEntry<K,V>(p.key, p.hash,
n, p.value); //将当前遍历元素p复制到新表的索引位置k的链表头部,next属性指向新表该索引位置原来的链表头n
}
}
}
}
table = newTable; //将新表赋值给table
}

lastRun:最后一个需要处理的元素的意思就是该元素之后的所有元素都跟该元素有相同的索引值(对于新表),所以只需要将该元素放到新表的对应位置,该元素之后的所有元素也就跟着到了新表的对应位置。相当于直接将该链表的最后一截(可能包含若干个元素)直接一次性移到了新表的某个位置。
如果整个循环结束,if (k != lastIdx) 语句没有成立过,就代表当前位置(oldTable[i])的整个HashEntry在新表中的索引位置是一致的,只需要移动一次即可将整个链表移到新表上。根据rehash方法中的那一大段注释提到的“ Statistically, at the default threshold, only about one-sixth of them need cloning when a table doubles”(据统计,在默认阈值下,当表扩大为原来的两倍时,只有约六分之一的元素需要克隆),可以想象,这个if语句没有成立过的可能性应该是挺大的。

remove方法

1
2
3
4
public V remove(Object key) {
int hash = hash(key.hashCode());
return segmentFor(hash).remove(key, hash, null);
}

根据key的hashcode重新计算hash值(跟get方法一样),通过segmentFor方法定位到具体的哪个segment,然后调用segment的remove方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
V remove(Object key, int hash, Object value) {
lock();
try {
int c = count - 1;
HashEntry<K,V>[] tab = table;
int index = hash & (tab.length - 1);
HashEntry<K,V> first = tab[index];
HashEntry<K,V> e = first;
while (e != null && (e.hash != hash || !key.equals(e.key)))
e = e.next;

V oldValue = null;
if (e != null) {
V v = e.value;
if (value == null || value.equals(v)) {
oldValue = v;
// All entries following removed node can stay
// in list, but all preceding ones need to be
// cloned.
++modCount;
HashEntry<K,V> newFirst = e.next;
for (HashEntry<K,V> p = first; p != e; p = p.next)
newFirst = new HashEntry<K,V>(p.key, p.hash,
newFirst, p.value);
tab[index] = newFirst;
count = c; // write-volatile
}
}
return oldValue;
} finally {
unlock();
}
}

加锁进行以下操作:根据hash值跟数组长度-1进行位运算,定位到具体的HashEntry,遍历该HashEntry,根据传入的的key,使用equals方法找到需要的元素,进行以下操作。

1
2
3
4
5
HashEntry<K,V> newFirst = e.next;
for (HashEntry<K,V> p = first; p != e; p = p.next)
newFirst = new HashEntry<K,V>(p.key, p.hash,
newFirst, p.value);
tab[index] = newFirst;

该段代码是remove方法中的片段,过程比较特殊,拿出来单独讨论。因为HashEntry使用final修饰,这意味着在第一次设置了next域之后便不能再改变它,因此,此处的remove操作是新建一个HashEntry并将它之前的节点全都克隆一次。至于HashEntry为什么要设置为不变性,这跟不变性的访问不需要同步从而节省时间有关。
用实际例子看上面这段代码更容易懂:
假设1:此时HashEntry为:1 2 3 4 5 6,其中1为链表头,并且1.next = 2,2.next = 3以此类推。
假设2:此时e = 4,即根据key匹配到的元素4是即将remove掉的。
则上面这段代码有以下流程:
HashEntry newFirst = 4.next = 5
for( p = 1; p != 4; p++)
newFirst = new HashEntry(p.key, p.hash, newFirst, p.value);
此循环如下:
p = 1:newFirst = new HashEntry(1.key, 1.hash, 5, 1.value)
p = 2:newFirst = new HashEntry(2.key, 2.hash, 1, 2.value)
p = 3:newFirst = new HashEntry(3.key, 3.hash, 2, 3.value)
p = 4:结束循环
tab[index] = 3;
index为当前链表在HashEntry中的索引位置,所以此时HashEntry为:3 2 1 5 6,被remove的元素之前的元素顺序颠倒了。

remove方法中还有以下这句代码,这句代码在代码中出现非常多次,主要是起什么作用?

1
HashEntry<K,V>[] tab = table;

这句代码是将table赋给一个局部变量tab,这是因为table是 volatile变量,读写volatile变量的开销很大,编译器也不能对volatile变量的读写做任何优化,直接多次访问非volatile实例变量没有多大影响,编译器会做相应优化。

replace方法

1
2
3
4
5
6
public boolean replace(K key, V oldValue, V newValue) {
if (oldValue == null || newValue == null)
throw new NullPointerException();
int hash = hash(key.hashCode());
return segmentFor(hash).replace(key, hash, oldValue, newValue);
}

根据key的hashcode重新计算hash值(跟get方法一样),通过segmentFor方法定位到具体的哪个segment,然后调用segment的replace方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
boolean replace(K key, int hash, V oldValue, V newValue) {
lock();
try {
HashEntry<K,V> e = getFirst(hash);
while (e != null && (e.hash != hash || !key.equals(e.key)))
e = e.next;

boolean replaced = false;
if (e != null && oldValue.equals(e.value)) {
replaced = true;
e.value = newValue;
}
return replaced;
} finally {
unlock();
}
}

加锁进行以下操作:根据hash值跟数组长度-1进行位运算,定位到具体的HashEntry(getFirst方法),遍历该HashEntry,使用equals方法比较传入的key和链表中元素中的key,找到所需元素。如果能找到并且该元素的value跟传入的oldValue相等,则将该元素的value替换成newValue。

clear方法

1
2
3
4
public void clear() {
for (int i = 0; i < segments.length; ++i)
segments[i].clear();
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
void clear() {
if (count != 0) {
lock();
try {
HashEntry<K,V>[] tab = table;
for (int i = 0; i < tab.length ; i++)
tab[i] = null;
++modCount;
count = 0; // write-volatile
} finally {
unlock();
}
}
}

遍历segments,对每一个segment进行清空操作:加锁进行以下操作,遍历HashEntry数组,将每个HashEntry设置为null,并将count设置为0。

size方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
public int size() {
final Segment<K,V>[] segments = this.segments;
long sum = 0;
long check = 0;
int[] mc = new int[segments.length];
// Try a few times to get accurate count. On failure due to
// continuous async changes in table, resort to locking.
for (int k = 0; k < RETRIES_BEFORE_LOCK; ++k) {
check = 0;
sum = 0;
int mcsum = 0;
for (int i = 0; i < segments.length; ++i) {//第一次统计
sum += segments[i].count;
mcsum += mc[i] = segments[i].modCount;
}
if (mcsum != 0) {
for (int i = 0; i < segments.length; ++i) {//第二次统计
check += segments[i].count;
if (mc[i] != segments[i].modCount) {//modCount发生该变则结束当次尝试
check = -1; // force retry
break;
}
}
}
if (check == sum)
break;
}
if (check != sum) { // Resort to locking all segments
sum = 0;
for (int i = 0; i < segments.length; ++i)
segments[i].lock();
for (int i = 0; i < segments.length; ++i)
sum += segments[i].count;
for (int i = 0; i < segments.length; ++i)
segments[i].unlock();
}
if (sum > Integer.MAX_VALUE)
return Integer.MAX_VALUE;
else
return (int)sum;

}

先在不加锁的情况下尝试进行统计,如果两次统计结果相同,并且两次统计之间没有任何对segment的修改操作(即每个segment的modCount没有改变),则返回统计结果。否则,对每个segment进行加锁,然后统计出结果,返回结果。

containsValue方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
public boolean containsValue(Object value) {
if (value == null)
throw new NullPointerException();

// See explanation of modCount use above

final Segment<K,V>[] segments = this.segments;
int[] mc = new int[segments.length];

// Try a few times without locking
for (int k = 0; k < RETRIES_BEFORE_LOCK; ++k) {
int sum = 0;
int mcsum = 0;
for (int i = 0; i < segments.length; ++i) {
int c = segments[i].count;
mcsum += mc[i] = segments[i].modCount;
if (segments[i].containsValue(value))//遍历该segment里面的所有HashEntry的所有元素
return true;
}
boolean cleanSweep = true;
if (mcsum != 0) {
for (int i = 0; i < segments.length; ++i) {
int c = segments[i].count;
if (mc[i] != segments[i].modCount) {//如果modCount发生改变则结束尝试,进行加锁操作
cleanSweep = false;
break;
}
}
}
if (cleanSweep) //cleanSweep为true表示所有segment的modCount没有发生过改变
return false;
}
// Resort to locking all segments
for (int i = 0; i < segments.length; ++i)
segments[i].lock(); //对所有segment进行加锁
boolean found = false;
try {
for (int i = 0; i < segments.length; ++i) {
if (segments[i].containsValue(value)) {//遍历该segment里面的所有HashEntry的所有元素
found = true;
break;
}
}
} finally {
for (int i = 0; i < segments.length; ++i)
segments[i].unlock();
}
return found;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
boolean containsValue(Object value) {
if (count != 0) { // read-volatile
HashEntry<K,V>[] tab = table;
int len = tab.length;
for (int i = 0 ; i < len; i++) { //遍历所有HashEntry
for (HashEntry<K,V> e = tab[i]; e != null; e = e.next) { //遍历每个HashEntry的所有元素
V v = e.value;
if (v == null) // recheck
v = readValueUnderLock(e);
if (value.equals(v))
return true;
}
}
}
return false;
}

先在不加锁的情况下尝试进行查找,遍历所有segment的所有HashEntry的所有元素,如果找到则返回true,如果找不到且在遍历期间没有任何对segment的修改操作(即每个segment的modCount没有改变)则返回false。如果在遍历期间segment进行过修改操作,则结束不加锁的尝试。循环对每个segment进行加锁,然后进行遍历查找是否存在。

参考:

JDK1.6源码
Java集合—-ConcurrentHashMap原理分析