package org.wamblee.concurrency; import static junit.framework.Assert.*; import static org.mockito.Matchers.*; import static org.mockito.Mockito.*; import java.util.ArrayList; import java.util.List; import java.util.Timer; import java.util.TimerTask; import java.util.concurrent.atomic.AtomicInteger; import org.junit.Before; import org.junit.Test; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; public class ReadWriteLockProxyFactoryTest { /** * This test works with time intervals of this size. */ private static final int INTERVAL_MILLIS = 10; public static interface X { void doX(int aValue); void doNoLocking(); } public static interface Y { void doY(String aValue); } public static interface Z extends X, Y { } private final Runnable doX = new Runnable() { public void run() { proxyX.doX(10); } }; private final Runnable doNoLock = new Runnable() { public void run() { proxyX.doNoLocking(); } }; private final Runnable doY = new Runnable() { public void run() { proxyY.doY("hello"); } }; private ReadWriteLockProxyFactory factory; private Z service; private X proxyX; private Y proxyY; private boolean threadFailed; private Timer timer; private List threads; private long tstart; @Before public void setUp() { factory = new ReadWriteLockProxyFactory(); threadFailed = false; timer = new Timer(); threads = new ArrayList(); } private void startTiming() { tstart = System.currentTimeMillis(); } private float endTiming() { return (float) (System.currentTimeMillis() - tstart) / INTERVAL_MILLIS; } private void sleep(int aUnits) { try { Thread.sleep(aUnits * INTERVAL_MILLIS); } catch (InterruptedException e) { threadFailed = true; } } private void schedule(final AtomicInteger aUnstartedThreads, final Thread aThread, int aUnits) { aUnstartedThreads.incrementAndGet(); timer.schedule(new TimerTask() { @Override public void run() { aThread.start(); aUnstartedThreads.decrementAndGet(); } }, aUnits * INTERVAL_MILLIS); } private void schedule(AtomicInteger aUnstartedThreads, Runnable aTask, int aUnits) { Thread t = new Thread(aTask); schedule(aUnstartedThreads, t, aUnits); threads.add(t); } private void join(AtomicInteger aUnstartedThreads, List aThreads) { while (aUnstartedThreads.get() > 0) { sleep(1); } for (Thread t : aThreads) { try { t.join(); } catch (InterruptedException e) { e.printStackTrace(); threadFailed = true; } } assertFalse(threadFailed); } private void join(AtomicInteger aUnstartedThreads) { join(aUnstartedThreads, threads); } private void stubDelays() { service = new Z() { @ReadLock public void doX(int aValue) { sleep(10); } @WriteLock public void doY(String aValue) { sleep(10); } @Override public void doNoLocking() { sleep(10); } }; createProxy(); Answer sleep = new Answer() { @Override public Object answer(InvocationOnMock aInvocation) throws Throwable { sleep(10); return null; } }; } @Test public void testProxyDelegates() { service = mock(Z.class); createProxy(); assertTrue(proxyX instanceof X); assertTrue(proxyY instanceof Y); proxyX.doX(10); verify(service).doX(10); reset(service); proxyY.doY("hello"); verify(service).doY("hello"); reset(service); } private void createProxy() { proxyX = factory.getProxy(service, X.class, Y.class); proxyY = (Y) proxyX; } @Test public void testConcurrentReadCalls() { stubDelays(); startTiming(); final int n = 4; AtomicInteger unstarted = new AtomicInteger(); for (int i = 0; i < n; i++) { schedule(unstarted, doX, 0); } join(unstarted); float duration = endTiming(); assertTrue(duration < 15); } @Test public void testNoConcurrentWrites() { stubDelays(); startTiming(); final int n = 2; AtomicInteger unstarted = new AtomicInteger(); for (int i = 0; i < n; i++) { schedule(unstarted, doY, 0); } join(unstarted); float duration = endTiming(); System.out.println("no concurrent writes: duration " + duration); assertTrue(duration >= n * 10); } @Test public void testConcurrentWriteAndNoLock() { stubDelays(); startTiming(); AtomicInteger unstarted = new AtomicInteger(); schedule(unstarted, doY, 0); for (int i = 0; i < 10; i++) { schedule(unstarted, doNoLock, 0); } join(unstarted); float duration = endTiming(); System.out.println("concurrent write and no lock: duration: " + duration); assertTrue(duration < 15); } @Test public void testNoConcurrentReadAndWrite() { stubDelays(); startTiming(); AtomicInteger unstartedReaders = new AtomicInteger(); for (int i = 0; i < 4; i++) { schedule(unstartedReaders, doX, 0); } List readers = threads; threads = new ArrayList(); // start the write some time later. AtomicInteger unstartedWriters = new AtomicInteger(); schedule(unstartedWriters, doY, 5); join(unstartedReaders, readers); float duration = endTiming(); System.out.println("no concurrent read and write: readers duration: " + duration); assertTrue(duration < 15); join(unstartedWriters, threads); duration = endTiming(); System.out.println("no concurrent read and write: writer duration: " + duration); assertTrue(duration >= 20 && duration <= 25); } }