Tuesday, July 7, 2009

Java thread pooling and connection pooling: How to handle proper connection cleanup on thread exit.

I wanted a fixed pool of threads for executing tasks that need MySQL database access. I also couldn't afford to create and destroy a connection for each task: I needed at most one persistent connection per thread, so that the many tasks that execute successively on that thread all re-use the same connection.
Another problem was how to clean up the connection when the thread pool (the ExecutorService) decides to terminate one of my threads. The following code addresses both problems.

The following code runs fine for me on Java SE 1.6. Please let me know if you find any flaws in it.


import java.sql.Connection;
import java.sql.SQLException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;

public class MyManager {
    private ExecutorService pool;

    public static void main(String[] args) throws Exception {
        (new MyManager()).run();
    }

    private void run() throws Exception {
        pool = Executors.newFixedThreadPool(10 /* thread pool size */,
                new MyThreadFactory());

        WorkerCounter counter = new WorkerCounter();

        while (true) {
            for(int i = 1; i < 1000; ++i)
                pool.execute(new AngolanTask(counter));

            /*
             * This is a requirement I had; you may not need this: I had to execute 1000
             * tasks (for loop above), then way for every task to finish, and loop again (while loop).
             * That's why I need the WorkerCounter class I created.
             */
            counter.blockTillEveryonesDone();
        }

        // You may care to:
        // pool.shutdown();
    }
}

/**
 * The actual task run code.
 */
class AngolanTask implements Runnable {
    private WorkerCounter counter;

    public AngolanTask(WorkerCounter counter) {
        this.counter = counter;
    }

    /**
     * Called by the thread pool's thread
     */
    public void run() {
        counter.workerStarting();
        try {
            /*
             * Put your actual task code here. In particular, I need my thread-local connection.
             */
            Connection conn = DbDispenser.getDB();
        } finally {
            counter.workerDone();
        }
    }

}

/**
 * Counts workers that are currently running, so that a master thread can block
 * until the count drops back to zero. Inspired by
 * http://www.ibm.com/developerworks/library/j-thread.html
 *
 * All methods synchronize on this instance.
 *
 * A worker is counted only between its calls to workerStarting() and
 * workerDone(). Work that has been submitted somewhere but has not called
 * workerStarting() yet is invisible to this counter.
 */
class WorkerCounter {
    /** Workers between workerStarting() and workerDone(); guarded by this. */
    private int activeWorkers = 0;

    /**
     * Blocks until the number of active workers is zero. Returns immediately if
     * it already is.
     *
     * An interrupt does not abort the wait (the method keeps waiting). However,
     * the thread's interrupt status is restored before returning, so the caller
     * can still observe it.
     */
    public synchronized void blockTillEveryonesDone() {
        boolean interrupted = false;
        while (activeWorkers > 0) {
            try {
                wait();
            } catch (InterruptedException e) {
                // Keep waiting, but remember the interrupt instead of discarding it.
                interrupted = true;
            }
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }

    /** Records that a worker has started. */
    public synchronized void workerStarting() {
        activeWorkers++;
    }

    /**
     * Records that a worker has finished. When the count reaches zero, wakes
     * every thread blocked in blockTillEveryonesDone().
     */
    public synchronized void workerDone() {
        activeWorkers--;
        if (activeWorkers == 0) {
            // notifyAll, not notify: several threads may be blocked waiting for
            // zero, and notify() would wake only one of them.
            notifyAll();
        }
    }
}

/**
 * Creates pool threads that close their thread-local database connection
 * (held by DbDispenser) when they terminate.
 *
 * Naming and thread-group selection are copied from
 * Executors.defaultThreadFactory(). Threads are named "pool-N-thread-M" and are
 * always non-daemon.
 */
class MyThreadFactory implements ThreadFactory {
    static final AtomicInteger poolNumber = new AtomicInteger(1);
    final ThreadGroup group;
    final AtomicInteger threadNumber = new AtomicInteger(1);
    final String namePrefix;

    MyThreadFactory() {
        SecurityManager s = System.getSecurityManager();
        group = (s != null) ? s.getThreadGroup() : Thread.currentThread()
                .getThreadGroup();
        namePrefix = "pool-" + poolNumber.getAndIncrement() + "-thread-";
    }

    @Override
    public Thread newThread(Runnable r) {
        Thread t = new AngolanThread(group, r, namePrefix
                + threadNumber.getAndIncrement(), 0);
        if (t.isDaemon()) {
            t.setDaemon(false);
        }

        // Set thread priority here, if needed

        return t;
    }

    /**
     * Pool thread that releases the connection it owns when its run() ends.
     * Static because it needs nothing from the enclosing factory.
     */
    static class AngolanThread extends Thread {
        public AngolanThread(ThreadGroup group, Runnable target, String name,
                long stackSize) {
            super(group, target, name, stackSize);
        }

        @Override
        public void run() {
            try {
                super.run();
            } finally {
                // Runs whether run() returns normally (e.g. the pool was shut
                // down) or a task's exception is propagating out of it.
                DbDispenser.releaseDB();
            }
        }
    }
}

// Connections are not thread-safe, so each thread gets its own.
// http://www.ibm.com/developerworks/java/library/j-threads3.html
/**
 * Hands out one database connection per thread, opened lazily on first use.
 */
class DbDispenser {
    /** The calling thread's connection; null until getDB() has opened one. */
    private static final ThreadLocal<Connection> CONNECTION = new ThreadLocal<Connection>();

    /**
     * Returns the calling thread's connection, opening it on first use.
     *
     * @return the connection (null while openConnection() is still a placeholder)
     * @throws RuntimeException wrapping any failure to open the connection
     */
    public static Connection getDB() {
        Connection c = CONNECTION.get();
        if (c == null) {
            c = openConnection();
            CONNECTION.set(c);
        }
        return c;
    }

    /**
     * Closes the calling thread's connection, if it has one, and forgets it.
     * A later getDB() on the same thread opens a fresh connection. Never opens
     * a connection just to close it.
     */
    public static void releaseDB() {
        Connection c = CONNECTION.get();
        CONNECTION.remove();
        if (c == null) {
            return;
        }
        try {
            c.close();
        } catch (SQLException e) {
            // Best effort: this is called from a finally block, and throwing
            // here would hide any exception that is ending the thread.
            System.err.println("Failed to close connection for thread "
                    + Thread.currentThread().getName() + ": " + e);
        }
    }

    /** Opens a new connection for the calling thread. */
    private static Connection openConnection() {
        try {
            return null; // Create your connection here.
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}