java实现一个简单的线程池

池是一种非常流行的资源管理模式,比如线程池、连接池等。对于线程池,在java的类库中已经提供了线程池的一些基本实现,

java实现一个简单的线程池

在平常应用中,我们使用类库中给我们提供的已经足够了。但不知大伙有没有跟我一样的疑问,线程池中的线程是如何保持空闲状态的?

这个问题其实很好找到答案,看下源码就知道了!废话不多说,直接贴上JDK中的源码

 final void runWorker(Worker w) {
        Thread wt = Thread.currentThread();
        Runnable task = w.firstTask;
        w.firstTask = null;
        w.unlock(); // allow interrupts
        boolean completedAbruptly = true;
        try {
            while (task != null || (task = getTask()) != null) {
                w.lock();
                // If pool is stopping, ensure thread is interrupted;
                // if not, ensure thread is not interrupted.  This
                // requires a recheck in second case to deal with
                // shutdownNow race while clearing interrupt
                if ((runStateAtLeast(ctl.get(), STOP) ||
                     (Thread.interrupted() &&
                      runStateAtLeast(ctl.get(), STOP))) &&
                    !wt.isInterrupted())
                    wt.interrupt();
                try {
                    beforeExecute(wt, task);
                    Throwable thrown = null;
                    try {
                        task.run();
                    } catch (RuntimeException x) {
                        thrown = x; throw x;
                    } catch (Error x) {
                        thrown = x; throw x;
                    } catch (Throwable x) {
                        thrown = x; throw new Error(x);
                    } finally {
                        afterExecute(task, thrown);
                    }
                } finally {
                    task = null;
                    w.completedTasks++;
                    w.unlock();
                }
            }
            completedAbruptly = false;
        } finally {
            processWorkerExit(w, completedAbruptly);
        }
    }

这个是java.util.concurrent.ThreadPoolExecutor中的方法,对于此方法,我们只需关注两个地方,getTask()和task.run(),getTask()是个阻塞的方法,用来从任务队列中获取用户提交的任务,当任务队列为空时,线程将一直处于阻塞状态,即线程池中空闲线程;task.run()则是真正执行用户任务代码之处。由此,我们就可以看出线程池是通过一个while循环来复用线程的,并且使用阻塞来使线程保持空闲状态。

知道了原理,下面我们就利用此原理来实现一个简单的线程池:

package com.sqxww.blog.thread;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

public class MyExecutor implements Executor{
	
	private TaskQueue taskQueue = new TaskQueue();	/*任务队列*/
	private final int maxThreads;	/*最大线程数*/
	private final int maxIdleThreads; /*最大空闲线程数*/
	private AtomicInteger idleThreads = new AtomicInteger(0);/*空闲线程数*/
	private boolean shutdown = false;	/*关闭标识*/
	private final Map<Long, Thread> threadMap = new ConcurrentHashMap<Long, Thread>();	/*线程容器*/
	
	public MyExecutor() {
		this(5, 1);
	}
	
	public MyExecutor(int maxThreads, int maxIdleThreads) {
		maxIdleThreads = maxIdleThreads < 1 ? 1 : maxIdleThreads;
		this.maxThreads = maxThreads;
		this.maxIdleThreads = maxIdleThreads;
	}
	
	@Override
	public void execute(Runnable task) {
		if(shutdown) {
			return;
		}
		//判断是否有空闲线程
		if(idleThreads.get()  <= 0 && threadMap.size() < maxThreads) {
			//创建工作线程
			createWorker();
		}
		//提交任务
		taskQueue.add(task);
	}
	
	/**
	 * 平缓关闭线程池
	 */
	public void shutdown() {
		shutdown = true;
		taskQueue.shutdown();
	}
	
	/**
	 * 立即关闭线程池
	 */
	public void shutdownNow() {
		shutdown();
		//清空任务队列
		taskQueue.clear();
		//尝试中断线程
		interruptAll();
	}
	
	private synchronized void createWorker() {
		//双重判断是否有空闲线程
		if(idleThreads.get()  > 0 || threadMap.size() >= maxThreads) {
			return;
		}
		
		Thread worker = new Thread(new Runnable() {
			
			@Override
			public void run() {
				//空闲线程数加加
				idleThreads.getAndIncrement();
				while(true) {
					//获取任务
					Runnable task = taskQueue.pop();
					//空闲线程数减减
					idleThreads.getAndDecrement();
					if(null == task)
						break;
					try {
						//执行任务
						task.run();
					} catch (Exception e) {
						e.printStackTrace();
					}
					if(shutdown) {
						break;
					}
					int temp = idleThreads.get();
					//利用CAS方法将空闲线程数加加
					while(temp < maxIdleThreads && !idleThreads.compareAndSet(temp, temp + 1)) {
						temp = idleThreads.get();
					}
					//判断线程是否可以进入空闲状态
					if(temp < maxIdleThreads && !shutdown) {
						continue;
					}
					break;
				}
				//将线程从线程池中移除
				threadMap.remove(Thread.currentThread().getId());
			}
		});
		//将工作加入到容器中
		threadMap.put(worker.getId(), worker);
		//开启工作线程
		worker.start();
	}
	
	private void interruptAll() {
		for(Entry<Long, Thread> entry : threadMap.entrySet()) {
			entry.getValue().interrupt();
		}
	}

}

class TaskQueue{
	private final List<Runnable> taskList = new ArrayList<Runnable>();
	private Lock lock = new ReentrantLock();
	private Condition notEmpty = lock.newCondition();
	private Condition empty = lock.newCondition();
	private boolean shutdown = false;
	
	public void add(Runnable task) {
		lock.lock();
		try {
			if(shutdown)
				return;
			taskList.add(task);
			notEmpty.signal();
		} finally {
			lock.unlock();
		}
	}
	
	public Runnable pop() {
		lock.lock();
		try {
			while (taskList.isEmpty()) {
				if(shutdown) {
					return null;
				}
				notEmpty.await(1, TimeUnit.SECONDS);
			}
			Runnable task = taskList.remove(0);
			if(taskList.isEmpty())
				empty.signalAll();
			return task;
		} catch (InterruptedException e) {
			return null;
		} finally {
			lock.unlock();
		}
	}
	
	public void clear() {
		lock.lock();
		try {
			taskList.clear();
		}finally {
			lock.unlock();
		}
	}
	
	public void shutdown() {
		shutdown = true;
	}
	
}