图解ConcurrentHashMap的前世今生

发布时间:2022-03-01 11:11:03 作者:yexindonglai@163.com 阅读(411)

前言

     首先呢,想要了解ConcurrentHashMap, 你得先了解HashMap,可以看我另一个帖子 : HashMap底层原理以及 LinkedHashMap、HashTable 、HashSet 四者区别

    为什么要先了解HashMap呢? 因为HashMap是线程不安全的类,只适合在单线程上使用,既然使用受限,那就意味着它的结构相对比较简单,所以呢,先学HashMap在来了解ConcurrentHashMap将会更好理解,达到锦上添花的作用,事实也是如此,因为ConcurrentHashMap也是基于HashMap发展而来的;如果你一上来就直接看ConcurrentHashMap的源码,会非常懵逼,源码的可读性都不太好,所以呀,学东西最好由浅入深,不然一下子太难了,学起来也会很费劲,甚至像打退堂鼓;

1、ConcurrentHashMap 是什么?

       ConcurrentHashMap 是JDK1.5之后新出一个在并发包里面类,包名叫 java.util.current;简称JUC,既然叫并发包,那肯定就意味着它是线程安全的,里面有一个概念:分而治之,这是ConcurrentHashMap的核心思想,并且在jdk7里面用到一个非常新颖且时髦的技术 :分段锁;由此可见,ConcurrentHashMap的出现就是为了高并发而准备的;并且使用方式和HashMap一样,用key-value方式存储数据;连方法名都一样;只不过区别是一个线程安全,一个线程不安全。

HashTable的实现

但是不对啊,线程的安全的map不是已经有HashTable了吗?为什么还要正处一个ConcurrentHashMap出来呢?这是个好问题,首先我们先来看看HashTable的实现;

HashTable 的每个修饰为 public 的方法都加上了 synchronized 的同步方法,也就是说,不管我对map的增删改查都会上锁,也正因为它的锁简单粗暴,不管你干嘛我都给你锁住,造成一个原因就是效率低下;高并发场景下,只要有一个线程对HashTable操作,其他线程都会进入阻塞状态,线程数量太多的情况下会造成响应时间缓慢,所以你会看到,现在几乎没人用HashTable来实现线程安全;

2、JDK1.7 中的ConcurrentHashMap实现原理

    在jdk1.7及其以下的版本中,结构是用Segments数组 + HashEntry数组 + 链表实现的,

final Segment<K,V>[] segments;

Segment 继承了ReentrantLock,所以它除了是一个独占锁之外,还是一种可重入锁(ReentrantLock),多个锁同时存在时会自动合并为一把锁来操作,ConcurrentHashMap 使用了分段锁技术来保证线程安全,它把数据分成一段一段的,也就是Segments [] 数组,Segments数组中每一个元素就是一个段,每个元素里面又存储了一个Enter数组,这个enter数组就相当于是一个HashMap,所以在高并发场景下,每次修改内容时只会锁住segment数组的每个元素,多个元素之间各自负责自己的锁,分段后,多个段(元素)之间的插入修改不会有任何影响,既做到了并发,又提升了效率;按照默认的并发级别 concurrentLevel 来说 ,默认是16,所以理论上支持同时16个线程并发操作,并且还互不冲突,是不是很牛掰!

2.1、构造函数

  • 当我们new一个ConcurrentHashMap后, 它会先计算出Segments数组的大小,segment数组的大小是根据 concurrentLevel 计算出来的,ssize 就是Segments数组的大小,并且一定是2的次幂;
  • 默认情况下concurrentLevel是16,则ssize为16;
  • 若concurrentLevel为14,ssize为16;
  • 若concurrentLevel为17,则ssize为32。

为什么Segment的数组大小一定是2的次幂?其实主要是便于通过按位与的散列算法来定位Segment的index

  1. public ConcurrentHashMap(int initialCapacity,
  2. float loadFactor, int concurrencyLevel) {
  3. if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
  4. throw new IllegalArgumentException();
  5. //MAX_SEGMENTS 为1<<16=65536,也就是最大并发数为65536
  6. if (concurrencyLevel > MAX_SEGMENTS)
  7. concurrencyLevel = MAX_SEGMENTS;
  8. //2的sshif次方等于ssize,例:ssize=16,sshift=4;ssize=32,sshif=5
  9. int sshift = 0;
  10. //ssize 为segments数组长度,根据concurrentLevel计算得出
  11. int ssize = 1;
  12. while (ssize < concurrencyLevel) {
  13. ++sshift;
  14. ssize <<= 1;
  15. }
  16. //segmentShift和segmentMask这两个变量在定位segment时会用到,后面会详细讲
  17. this.segmentShift = 32 - sshift;
  18. this.segmentMask = ssize - 1;
  19. if (initialCapacity > MAXIMUM_CAPACITY)
  20. initialCapacity = MAXIMUM_CAPACITY;
  21. //计算cap的大小,即Segment中HashEntry的数组长度,cap也一定为2的n次方.
  22. int c = initialCapacity / ssize;
  23. if (c * ssize < initialCapacity)
  24. ++c;
  25. int cap = MIN_SEGMENT_TABLE_CAPACITY;
  26. while (cap < c)
  27. cap <<= 1;
  28. //创建segments数组并初始化第一个Segment,其余的Segment延迟初始化
  29. Segment<K,V> s0 =
  30. new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
  31. (HashEntry<K,V>[])new HashEntry[cap]);
  32. Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
  33. UNSAFE.putOrderedObject(ss, SBASE, s0);
  34. this.segments = ss;
  35. }

2.2、put方法

通过源码可以看到,在进入put方法后,会先判断key和value是否为空,ConcurrentHashMap是不允许key/value为空的;下一步就是定位segment数组的下标位置,通过hash按位与算法得出,并确保segments下标位置的HashEnter数组已初始化;

  1. public V put(K key, V value) {
  2. Segment<K,V> s;
  3. //concurrentHashMap不允许key/value为空
  4. if (value == null)
  5. throw new NullPointerException();
  6. //hash函数对key的hashCode重新散列,避免差劲的不合理的hashcode,保证散列均匀
  7. int hash = hash(key);
  8. //返回的hash值无符号右移segmentShift位与段掩码进行位运算,定位segment
  9. int j = (hash >>> segmentShift) & segmentMask;
  10. if ((s = (Segment<K,V>)UNSAFE.getObject // nonvolatile; recheck
  11. (segments, (j << SSHIFT) + SBASE)) == null) // in ensureSegment
  12. s = ensureSegment(j);
  13. return s.put(key, hash, value, false);
  14. }

关于segmentShift和segmentMask

  segmentShift和segmentMask这两个全局变量的主要作用是用来定位Segment,int j =(hash >>> segmentShift) & segmentMask。

  segmentMask:段掩码,假如segments数组长度为16,则段掩码为16-1=15;segments长度为32,段掩码为32-1=31。这样得到的所有bit位都为1,可以更好地保证散列的均匀性

  segmentShift:2的sshift次方等于ssize,segmentShift=32-sshift。若segments长度为16,segmentShift=32-4=28;若segments长度为32,segmentShift=32-5=27。而计算得出的hash值最大为32位,无符号右移segmentShift,则意味着只保留高几位(其余位是没用的),然后与段掩码segmentMask位运算来定位Segment。

因为Segments是个数组,里面每个元素里面又套了一个数组,所以每次put一个元素,需要计算2次hash;第一次计算Segment[] 数组的下标,第二次计算HashEnter数组的下标,

  1. static final class HashEntry<K,V> {
  2. final int hash;
  3. final K key;
  4. volatile V value;
  5. volatile HashEntry<K,V> next;
  6. //其他省略
  7. }

2.3、get方法

get方法无需加锁,由于其中涉及到的共享变量都使用volatile修饰,volatile可以保证内存可见性,且防止了指令重排序,所以不会读取到过期数据。

2.4、size方法

size操作是先统计2次,如果2次的结果不一样,就代表着有线程正在修改数据,然后对put、remove、clean进行加锁后,在统计一次,之后unlock。所以最多会统计3次,

3、JDK1.8中的ConcurrentHashMap

    在JDK1.8中,对ConcurrentHashMap的结构做了一些改进,其中最大的区别就是jdk1.8抛弃了Segments数组,摒弃了分段锁的方案,而是改用了和HashMap一样的结构操作,也就是数组 + 链表 + 红黑树结构,比jdk1.7中的ConcurrentHashMap提高了效率,在并发方面,使用了cas + synchronized的方式保证数据的一致性;因为去掉了分段锁,所以在高并发时锁住的就是数组的节点了,使得结构更加简单了;

3.1 链表转红黑树条件

    要知道,链表在遍历的时候一定是从头遍历到尾的,如果很不巧,get方法中我们要找的元素恰好在尾部,那每次获取元素的时候都得遍历一次链表,所以为了避免链表过长的情况发生,在jdk1.8中,在map的结构达到一定条件之后,将会把链表自动转为红黑树的结构,这2个条件分别是:

  1. 数组中任意个链表的长度超过8个
  2. 数组长度大于64个时

 转为红黑树后的结构如下图

3.2、put方法

 jdk8中的ConcurrentHashMap 的结构看起来虽然简单了,但是源代码却不那么容易读懂,我们先来看看put方法都做了哪些逻辑

通过上图可以看到,首次添加元素和二次添加元素所做的事情也不相同,第一次添加元素时会先初始化node数组,然后再初始化链表的头部,如果该map正在扩容,则会协助其他线程扩容,最后才是将元素插入链表中,需要注意的是,在统计元素数量时会判断是否需要扩容,是下一节就是我们要讲到的并发扩容了;

3.3、并发扩容

jdk8中的ConcurrentHashMap最复杂的就是扩容机制了,因为它不是一个个地扩容,它可以并发扩容,也就是同时进行多个节点的扩容,在默认情况下,每个cpu可以负责16个元素的长度进行扩容,比如node数组的长度为32,那么线程A负责0-16下标的数组扩容, 线程B负责17-31下标的扩容,并发扩容在transfer方法中进行,这样,2个线程分别负责高16位和低16位的扩容,不管怎样都不会产生冲突,提升了效率;

3.4、计算数组索引公式

     当我们put一个元素之后,把这个元素放到数组的哪个元素下是需要通过计算得出的,在此之前,先通过 spread() 方法计算出hash值,

计算Hash值公式: key的hashcode的低十六位 异或 高十六位,然后与Hash_bits相与(与hashmap唯一不同)

  1. // 用16进制表示,转为数字后数值为:2147483647
  2. static final int HASH_BITS = 0x7fffffff; // usable bits of normal node hash
  3. //计算hash值
  4. static final int spread(int h) {
  5. return (h ^ (h >>> 16)) & HASH_BITS;
  6. }

计算hash值后,在和数组的长度n进行与运算,因为下标是从0开始的,所以需要 -  1,

计算数组索引公式:hash & (n-1)    // n表示数组的长度

接下来我们手动计算一下,比如我调用了put方法

  1. ConcurrentHashMap<Object, Object> map = new ConcurrentHashMap<>();
  2. map.put("1",1);

其中 key为字符串"1",得出hashCode 为 49,得出总公式代码如下

  1. // key 值的hashCode ,通过key.hashCode() 方法获取
  2. int h = 49;
  3. // 数组长度,默认为16
  4. int n = 16;
  5. // hash字节,在源码中用0x7fffffff表示,在这里转为数字,方便调试
  6. int HASH_BITS = 2147483647;
  7. // 计算hash值
  8. int hash = (h ^ (h >>> 16)) & HASH_BITS;
  9. // 计算数组下标
  10. int arrayIndex = hash & (n - 1);
  11. System.out.println(arrayIndex);

运行后,结果为1,通过以上结果,数组的下标值就计算出来了;

     了解代码的实现原理有什么用呢?其实我们在实现业务场景的时候,有很多的算法和机制是可以复制的,我们在编码的时候大部分情况下都要去看别人写的代码,如果有注释还好,没有注释的话就只能去猜作者为什么要这样写,然后小心验证,所以熟悉源码的作用也是如此,只有看多了别人写的代码,才有经验去读懂别人写的代码,让我们自己实现起来的话也可以起到一个借鉴的作用,就像设计模式那样,通过复用的方式减少了冗余的代码,也提高了可读性;

关键字Java