likes
comments
collection
share

Java ThreadLocal 源码解析

作者站长头像
站长
· 阅读数 16

ThreadLocal 线程本地变量,用于存储线程相关的数据,值存储在每个线程的持有的数据结构(ThreadLocalMap)当中,当在一个线程里面设置了ThreadLocal 变量的值,其他线程进行获取ThreadLocal变量值时是访问不到

构造函数

非常简单,就提供了一个公有的构造函数

public class ThreadLocal<T> {
    public ThreadLocal() {
    }
}

可用于继承 ThreadLocal 覆写 initialValue 提供初始值

    protected T initialValue() {
        return null;
    } 

方法

  • ThreadLocal get() 方法,获取线程本地变量存储的值
    public T get() {
        // 获取当前线程
        Thread t = Thread.currentThread();
        // 获取线程存储的数据结构 ThreadLocalMap 
        ThreadLocalMap map = getMap(t);
        // 若 ThreadLocalMap 不为空
        if (map != null) {
           // 可从当前线程的 ThreadLocalMap 中获取之前set 过的值
           // 以 ThreadLocal<T> 为 key 
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        // 若当前线程的 ThreadLocalMap为空或未存储数据当前 ThreadLocal 的数据 
        return setInitialValue();
    }
      // ThreadLocalMap 属于线程的成员变量 
    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

  
   private T setInitialValue() {
        // 调用 initialValue 获取初始值,默认为 null
        T value = initialValue();
        Thread t = Thread.currentThread();
        // 获取当前线程的 ThreadLocalMap
        ThreadLocalMap map = getMap(t);
        // 如果当前线程已经创建过 ThreadLocalMap 直接赋值
        // 以当前线程本地变量 this 为 Key ,value 就是 ThreadLocal 存储的变量值
        if (map != null) {
            map.set(this, value);
        } else {
        // 若当前线程没有创建过 ThreadLocalMap 进行创建
            createMap(t, value);
        }
        if (this instanceof TerminatingThreadLocal) {
            TerminatingThreadLocal.register((TerminatingThreadLocal<?>) this);
        }
        return value;
    }
   
   // 1. 创建 ThreadLocalMap ,给它赋值给当前线程
   // 2. 创建 ThreadLocalMap 并把第一个 key value ,key 为 线程本地变量 this 放入Map 
   void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }
  • ThreadLocal .set(T value) 方法 设置线程本地变量 ThreadLocal 的值
public void set(T value) {
    Thread t = Thread.currentThread();
    // 获取当前线程存储的 ThreadLocalMap
    ThreadLocalMap map = getMap(t);
    // 若不为空,
    if (map != null) {
        // 以 this ThreadLocal<T> 为 key ,为当前 ThreadLocal变量赋值,存储在 ThreadLocalMap 中
        map.set(this, value); 
    } else {
    // 若为空,创建 ThreadLocalMap ,并把为当前 ThreadLocal 变量的值,存储在 ThreadLocalMap 中
        createMap(t, value);
    }
}
  • ThreadLocal .remove() 方法 ,从 ThreadLocalMap 移除数据
public void remove() {
    ThreadLocalMap m = getMap(Thread.currentThread());
    if (m != null) {
        m.remove(this);
    }
}

使用示例

public class ThreadLocalTest {
    public static void main(String[] args) {
        ThreadLocalTest test = new ThreadLocalTest();
        test.done();
    }
    
    private ThreadLocal<Integer> mThreadLocal = new ThreadLocal<>();

    private void done() {
        Integer value = mThreadLocal.get();
        // 当前主线程 mThreadLocal 未进行赋值,所以为空
        System.out.println(Thread.currentThread().getName() + " : " + value);
        // 当前主线程 mThreadLocal 赋值 为10010
        mThreadLocal.set(10010);
         // 获取到当前主线程 mThreadLocal 的值 为 10010
        value = mThreadLocal.get();
        System.out.println(Thread.currentThread().getName() + " : " + value);
        
        // 开启一个新线程
        final Thread thread = new Thread(() -> {
            // 子线程中进行获取 mThreadLocal ,
            // 虽然在主线程中mThreadLocal 已经赋值为 10010 ,但是子线程中是没有的
            Integer data = mThreadLocal.get();
            System.out.println(Thread.currentThread().getName() + " : " + data);
            // 给 mThreadLocal 变量赋值为 10086 
            mThreadLocal.set(10086);
            // 获取子线程中 mThreadLocal 变量的值
            data = mThreadLocal.get();
            System.out.println(Thread.currentThread().getName() + " : " + data);

        });
        // 启动子线程
        thread.start();
        try {
            // 等待子线程执行完成
            thread.join();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
        // 子线程中虽然更改了 mThreadLocal 变量的值,但是 主线程中的 mThreadLocal 值依然是 10086 
        value = mThreadLocal.get();
        System.out.println(Thread.currentThread().getName() + " : " + value);
    }
}

输出

main : null
main : 10010
Thread-0 : null
Thread-0 : 10086
main : 10010

ThreadLocalMap 介绍

public class ThreadLocal<T> {

    private final int threadLocalHashCode = nextHashCode();

    private static AtomicInteger nextHashCode =
        new AtomicInteger();

    /**
     * The difference between successively generated hash codes - turns
     * implicit sequential thread-local IDs into near-optimally spread
     * multiplicative hash values for power-of-two-sized tables.
     */
    private static final int HASH_INCREMENT = 0x61c88647;

    /**
     * Returns the next hash code.
     */
    private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }
 }

注意 threadLocalHashCode属于成员变量 ,nextHashCode 属于静态变量,每次创建 ThreadLocal 变量时构造函数初始化时会给 threadLocalHashCode 进行赋值为nextHashCode 加上 HASH_INCREMENT

HASH_INCREMENT 为

0x61c88647 = 1640531527 ≈ 2 ^ 32 * (1 - 1 / φ) , φ = (√5 + 1) ÷ 2 为黄金分割比

使用 0x61c88647 能让哈希码能均匀的分布在 2 的 N 次方的数组里。

构造函数

static class ThreadLocalMap {

    /**
       Entry 是弱引用,保存的 ThreadLocal<?> key, 
       当ThreadLocal<?> 没有被强引用引用时,会被回收,避免内存泄漏
     */
    static class Entry extends WeakReference<ThreadLocal<?>> {
        /** The value associated with this ThreadLocal. */
        Object value;

        Entry(ThreadLocal<?> k, Object v) {
            super(k);
            value = v;
        }
    }
       /**
           ThreadLocalMap 构造函数,在 ThreadLocal createMap 中调用,创建过程中就存入了一个 ThreadLocal变量。
         */
        ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
           // 创建一个 默认大小为 16 ,2 的 4次方 Entry数组
            table = new Entry[INITIAL_CAPACITY]; 
            // 通过ThreadLocal 的hash值获取到 Entry 数组的索引值
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1); 
            table[i] = new Entry(firstKey, firstValue); // 把key value 存放在数组中
            size = 1; // ThreadLocalMap 存储 key,value 的大小
            // 设置阈值 threshold = len * 2 / 3; 
            setThreshold(INITIAL_CAPACITY);
        }
    
}

方法

  • 给ThreadLocalMap 取值的过程

当通过 hash 计算出来的数组索引值中存储的 Key 不相同时,会找它下一个索引值的key进行判断是否相等,若下一个索引值 Entry 为空,则说明ThreadLocalMap 未存储该 Key,返回 null


// ThreadLocal 取值时会把自身当成 key 从 ThreadLocalMap 取出
private Entry getEntry(ThreadLocal<?> key) {
    // 通过 ThreadLocal的hash值拿到索引值
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    // 通过索引值拿到 Entry ,当 e不为空,e存的key与当前取值的key相同时,返回之前设置的 value
    if (e != null && e.get() == key)
        return e;
    else
        // 1. 当 e 为空,之前没存过,
       // 2. 或者 发生了hash 碰撞,可能之前存储过 ThreadLocal,但是被回收,新生成的 ThreadLocal 和之前回收的 ThreadLocal 是同一个索引值,但是不相等
        return getEntryAfterMiss(key, i, e);
}

 private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
     Entry[] tab = table;
     int len = tab.length;
    // 1. 当 e 为空,之前没存过,跳槽循环之前返回 null 
    // 2. 遍历之 e 后面的索引存储的 ThreadLocal,若已被回收,进行清理,若和取值的key相同则返回Entry
     while (e != null) {
         // 若获取到key 相同,直接返回 Entry 
         ThreadLocal<?> k = e.get();
         if (k == key)
             return e;
         // 若 key 为空,说明之前的ThreadLocal已被回收,但是 ThreadLocalMap 还未进行清理释放
         if (k == null)
             // 进行清理
             expungeStaleEntry(i);
         else
        // 获取Entry数组的下一个下一个索引 
             i = nextIndex(i, len); 
        // i 的下一个索引的 Entry 值赋值给 e,进行检查,若e的k值与key相同则返回,若为空则进行清理, 
        // Entry 数组没有存满,阈值时 size的三分之二 当碰到 Entry 为 null 时会跳出循环
         e = tab[i];
     }
     return null;
 }

// 清理当前staleSlot 索引的 Entry 
private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    
    // 清理当前staleSlot索引的 Entry ,size 减 1
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;

    // Rehash until we encounter null
    Entry e;
    int i;
    // 遍历staleSlot下一个索引,当下一个索引对应的值为null时循环中止
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        // 当对应的索引的 Entry 的key 值为 null时,该 Entry 设置为空, size 减 1
        if (k == null) { 
            e.value = null;
            tab[i] = null;
            size--;
        } else {
        // 通过 k 的 hash 获取索引值,一般情况下 h 应该等于 i 
        // rehash 时,之前的索引有值时并不会覆盖原来的值,只会存储在下一个空 Entry 索引值里面
            int h = k.threadLocalHashCode & (len - 1);
            if (h != i) {
                // 把 i 索引值设置为 null
                tab[i] = null; 
               // 若索引为 h 的 Entry 为空,正好 和 hash 索引值匹配上
               // 若不为空找,下一个为空的 Entry
                while (tab[h] != null) 
                    h = nextIndex(h, len);
                tab[h] = e; // 把索引为i的 Entry 值赋值给 tab[h] 
            }
        }
    }
    return i;
}
  • 给ThreadLocalMap 设置值的过程

private void set(ThreadLocal<?> key, Object value) {

    Entry[] tab = table;
    int len = tab.length;
    // 通过 ThreadLocal 的hash值获取到索引
    int i = key.threadLocalHashCode & (len-1);
    // 遍历数组 e  = null 跳出循环
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();

        if (k == key) {
            e.value = value; // 当key相等时,进行赋值,返回
            return;
        }
        // 当 k 为 null 时,替换 Entry 
        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }
    // 若当前索引位置为空,新建一个 Entry ,赋值给 tab[i] 
    tab[i] = new Entry(key, value);
    // 当前 size 加 1 
    int sz = ++size;
    // 若未清理出槽位并且,大小大于等于阈值进行 rehash
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}
    

 private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                int staleSlot) {
     Entry[] tab = table;
     int len = tab.length;
     Entry e;

    // slotToExpunge 代表预期需要清除的索引值,初始化为staleSlot,遍历 staleSlot 上一个索引的 Entry,
    // 若上一个索引的 Entry 不为 null 并且该 Entry 的key被回收,把该索引值赋值给 slotToExpunge
    // 若上一个索引 Entry 为null 跳出循环
// 为了保留 staleSlot 索引值 不用被清理,找之前需要被清理的索引值
     int slotToExpunge = staleSlot;
     for (int i = prevIndex(staleSlot, len);
          (e = tab[i]) != null;
          i = prevIndex(i, len))
         if (e.get() == null)
             slotToExpunge = i;

     // 遍历 staleSlot 下一个索引的 Entry,
    // 若下一个索引 Entry 为null 跳出循环
     for (int i = nextIndex(staleSlot, len);
          (e = tab[i]) != null;
          i = nextIndex(i, len)) {
         ThreadLocal<?> k = e.get();

         // 若下一个索引的 Entry 的 key 和设置值的 key 相等 
         if (k == key) {
             e.value = value; // 进行赋值
             // 交换 i 和 staleSlot 的数据
             tab[i] = tab[staleSlot]; 
             tab[staleSlot] = e;
    
            // 如果需要清除的索引值 slotToExpunge 等于 staleSlot 则slotToExpunge 赋值为 i,
    // 因为数据 i 和 staleSlot数据交换
             if (slotToExpunge == staleSlot)
                 slotToExpunge = i;
             // 清除 key 已被回收的数据
             cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
             return;
         }

       // 若未找到相等的key,并且 k 为null,并且 slotToExpunge 等于 staleSlot
       // 需要清理的索引 slotToExpunge 赋值为 i
      // 为了保留 staleSlot 索引不被清理, 因为key的哈希值计算出的索引值正好为 staleSlot
         if (k == null && slotToExpunge == staleSlot)
             slotToExpunge = i;
     }

     // 若未找到相等的 key ,则创建新的 Entry 进行赋值
     tab[staleSlot].value = null;
     tab[staleSlot] = new Entry(key, value);

     // 如果需要清除的索引值不等于 staleSlot 则进行清除 key 已被回收的数据
    // 若等于则不需要清除,因为已被赋值新 Entry
     if (slotToExpunge != staleSlot)
         cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
 }
 // 清理一下 Key 被回收的槽位   
 private boolean cleanSomeSlots(int i, int n) {
     boolean removed = false;
     Entry[] tab = table;
     int len = tab.length;
     do {
         i = nextIndex(i, len);
         Entry e = tab[i];
         if (e != null && e.get() == null) {
          // 当发现有需要清除的 Entry 数据时,再次把 n 赋值为 数组长度
             n = len; 
             removed = true;
             i = expungeStaleEntry(i); //清除索引为 i 的数据
         }
    // n 向右移动,时间复杂度为 logN ,为了加快时间,不进行全量清理
     } while ( (n >>>= 1) != 0);
     return removed;
 }    

 private void rehash() {
     // 在扩大数组大小之前进行全量清理
     expungeStaleEntries();
       
     // 使用较低的阈值,避免迟滞
     if (size >= threshold - threshold / 4)
         resize();
 }
    
 
  //  遍历数组清除所有被回收的 key 的 Entry
 private void expungeStaleEntries() {
     Entry[] tab = table;
     int len = tab.length;
     for (int j = 0; j < len; j++) {
         Entry e = tab[j];
         if (e != null && e.get() == null)
             expungeStaleEntry(j);
     }
 }
                             
// 数组容量翻倍                             
 private void resize() {
     Entry[] oldTab = table;
     int oldLen = oldTab.length;
     int newLen = oldLen * 2;
     // 进建一个之前大小2倍的数组                        
     Entry[] newTab = new Entry[newLen];
     int count = 0;
      // 对原数组进行拷贝
     for (int j = 0; j < oldLen; ++j) {
         Entry e = oldTab[j];
         if (e != null) {
             ThreadLocal<?> k = e.get();
             // 如果之前的key为空,说明已被回收,不进行拷贝直接略过
             if (k == null) {
                 e.value = null; // Help the GC
             } else {
                // 重新计算 hash 值 
                 int h = k.threadLocalHashCode & (newLen - 1);
                 // 最好这个计算出的hash索引值为空,如果已被占用找到一个不为空的槽位
                 while (newTab[h] != null)
                     h = nextIndex(h, newLen);
                // 把旧的值赋值到新数组
                 newTab[h] = e;
                 count++;
             }
         }
     }
      // 设置新的阈值,更新size 大小,更新数组
     setThreshold(newLen);
     size = count;
     table = newTab;
 }                             
                             
  • ThreadLocalMap 移除数据非常简单,先用 hash 值找到索引,若和 key 相等直接移除,若不相等查找下一个不为null的索引,判断是否等于key进行移除.
private void remove(ThreadLocal<?> key) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        if (e.get() == key) {
            e.clear();
            expungeStaleEntry(i);
            return;
        }
    }
}