笔者最近在做一个项目,项目中为了提升吞吐量,使用了消息队列,中间实现了生产消费模式,在生产消费者模式中需要有一个集合,来存储生产者所生产的物品,笔者使用了最常见的List<T>集合类型。
由于生产者线程有很多个,消费者线程也有很多个,所以不可避免的就产生了线程同步的问题。开始笔者是使用lock关键字,进行线程同步,但是性能并不是特别理想,然后有网友说可以使用SynchronizedList<T>来代替使用List<T>达到线程安全的目的。于是笔者就替换成了SynchronizedList<T>,但是发现性能依旧糟糕,于是查看了SynchronizedList<T>的源代码,发现它就是简单的在List<T>提供的API的基础上加了lock,所以性能基本与笔者实现方式相差无几。
最后笔者找到了解决的方案,使用ConcurrentBag<T>类来实现,性能有很大的改观,于是笔者查看了ConcurrentBag<T>的源代码,实现非常精妙,特此在这记录一下。
二、ConcurrentBag类ConcurrentBag<T>实现了IProducerConsumerCollection<T>接口,该接口主要用于生产者消费者模式下,可见该类基本就是为生产消费者模式定制的。然后还实现了常规的IReadOnlyCollection<T>类,实现了该类就需要实现IEnumerable<T>、IEnumerable、 ICollection类。
ConcurrentBag<T>对外提供的方法没有List<T>那么多,但是同样有Enumerable实现的扩展方法。类本身提供的方法如下所示。
名称 说明Add 将对象添加到 ConcurrentBag 中。
CopyTo 从指定数组索引开始,将 ConcurrentBag 元素复制到现有的一维 Array 中。
Equals(Object) 确定指定的 Object 是否等于当前的 Object。 (继承自 Object。)
Finalize 允许对象在“垃圾回收”回收之前尝试释放资源并执行其他清理操作。 (继承自 Object。)
GetEnumerator 返回循环访问 ConcurrentBag 的枚举器。
GetHashCode 用作特定类型的哈希函数。 (继承自 Object。)
GetType 获取当前实例的 Type。 (继承自 Object。)
MemberwiseClone 创建当前 Object 的浅表副本。 (继承自 Object。)
ToArray 将 ConcurrentBag 元素复制到新数组。
ToString 返回表示当前对象的字符串。 (继承自 Object。)
TryPeek 尝试从 ConcurrentBag 返回一个对象但不移除该对象。
TryTake 尝试从 ConcurrentBag 中移除并返回对象。
三、 ConcurrentBag线程安全实现原理 1. ConcurrentBag的私有字段
ConcurrentBag线程安全实现主要是通过它的数据存储的结构和细颗粒度的锁。
public class ConcurrentBag<T> : IProducerConsumerCollection<T>, IReadOnlyCollection<T> { // ThreadLocalList对象包含每个线程的数据 ThreadLocal<ThreadLocalList> m_locals; // 这个头指针和尾指针指向中的第一个和最后一个本地列表,这些本地列表分散在不同线程中 // 允许在线程局部对象上枚举 volatile ThreadLocalList m_headList, m_tailList; // 这个标志是告知操作线程必须同步操作 // 在GlobalListsLock 锁中 设置 bool m_needSync; }首选我们来看它声明的私有字段,其中需要注意的是集合的数据是存放在ThreadLocal线程本地存储中的。也就是说访问它的每个线程会维护一个自己的集合数据列表,一个集合中的数据可能会存放在不同线程的本地存储空间中,所以如果线程访问自己本地存储的对象,那么是没有问题的,这就是实现线程安全的第一层,使用线程本地存储数据。
然后可以看到ThreadLocalList m_headList, m_tailList;这个是存放着本地列表对象的头指针和尾指针,通过这两个指针,我们就可以通过遍历的方式来访问所有本地列表。它使用volatile修饰,所以它是线程安全的。
最后又定义了一个标志,这个标志告知操作线程必须进行同步操作,这是实现了一个细颗粒度的锁,因为只有在几个条件满足的情况下才需要进行线程同步。
2. 用于数据存储的TrehadLocalList类接下来我们来看一下ThreadLocalList类的构造,该类就是实际存储了数据的位置。实际上它是使用双向链表这种结构进行数据存储。
[Serializable] // 构造了双向链表的节点 internal class Node { public Node(T value) { m_value = value; } public readonly T m_value; public Node m_next; public Node m_prev; } /// <summary> /// 集合操作类型 /// </summary> internal enum ListOperation { None, Add, Take }; /// <summary> /// 线程锁定的类 /// </summary> internal class ThreadLocalList { // 双向链表的头结点 如果为null那么表示链表为空 internal volatile Node m_head; // 双向链表的尾节点 private volatile Node m_tail; // 定义当前对List进行操作的种类 // 与前面的 ListOperation 相对应 internal volatile int m_currentOp; // 这个列表元素的计数 private int m_count; // The stealing count // 这个不是特别理解 好像是在本地列表中 删除某个Node 以后的计数 internal int m_stealCount; // 下一个列表 可能会在其它线程中 internal volatile ThreadLocalList m_nextList; // 设定锁定是否已进行 internal bool m_lockTaken; // The owner thread for this list internal Thread m_ownerThread; // 列表的版本,只有当列表从空变为非空统计是底层 internal volatile int m_version; /// <summary> /// ThreadLocalList 构造器 /// </summary> /// <param>拥有这个集合的线程</param> internal ThreadLocalList(Thread ownerThread) { m_ownerThread = ownerThread; } /// <summary> /// 添加一个新的item到链表首部 /// </summary> /// <param>The item to add.</param> /// <param>是否更新计数.</param> internal void Add(T item, bool updateCount) { checked { m_count++; } Node node = new Node(item); if (m_head == null) { Debug.Assert(m_tail == null); m_head = node; m_tail = node; m_version++; // 因为进行初始化了,所以将空状态改为非空状态 } else { // 使用头插法 将新的元素插入链表 node.m_next = m_head; m_head.m_prev = node; m_head = node; } if (updateCount) // 更新计数以避免此添加同步时溢出 { m_count = m_count - m_stealCount; m_stealCount = 0; } } /// <summary> /// 从列表的头部删除一个item /// </summary> /// <param>The removed item</param> internal void Remove(out T result) { // 双向链表删除头结点数据的流程 Debug.Assert(m_head != null); Node head = m_head; m_head = m_head.m_next; if (m_head != null) { m_head.m_prev = null; } else { m_tail = null; } m_count--; result = head.m_value; } /// <summary> /// 返回列表头部的元素 /// </summary> /// <param>the peeked item</param> /// <returns>True if succeeded, false otherwise</returns> internal bool Peek(out T result) { Node head = m_head; if (head != null) { result = head.m_value; return true; } result = default(T); return false; } /// <summary> /// 从列表的尾部获取一个item /// </summary> /// <param>the removed item</param> /// <param>remove or peek flag</param> internal void Steal(out T result, bool remove) { Node tail = m_tail; Debug.Assert(tail != null); if (remove) // Take operation { m_tail = m_tail.m_prev; if (m_tail != null) { m_tail.m_next = null; } else { m_head = null; } // Increment the steal count m_stealCount++; } result = tail.m_value; } /// <summary> /// 获取总计列表计数, 它不是线程安全的, 如果同时调用它, 则可能提供不正确的计数 /// </summary> internal int Count { get { return m_count - m_stealCount; } } }