--- /dev/null
+package org.wamblee.concurrency;
+
+import static junit.framework.Assert.*;
+import static org.mockito.Matchers.*;
+import static org.mockito.Mockito.*;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Timer;
+import java.util.TimerTask;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+public class ReadWriteLockProxyFactoryTest {
+
+ /**
+ * This test works with time intervals of this size.
+ */
+ private static final int INTERVAL_MILLIS = 10;
+
+ public static interface X {
+ void doX(int aValue);
+
+ void doNoLocking();
+ }
+
+ public static interface Y {
+ void doY(String aValue);
+ }
+
+ public static interface Z extends X, Y {
+
+ }
+
+ private final Runnable doX = new Runnable() {
+ public void run() {
+ proxyX.doX(10);
+ }
+ };
+
+ private final Runnable doNoLock = new Runnable() {
+ public void run() {
+ proxyX.doNoLocking();
+ }
+ };
+
+ private final Runnable doY = new Runnable() {
+ public void run() {
+ proxyY.doY("hello");
+ }
+ };
+
+ private ReadWriteLockProxyFactory<Z> factory;
+ private Z service;
+ private X proxyX;
+ private Y proxyY;
+ private boolean threadFailed;
+ private Timer timer;
+ private List<Thread> threads;
+ private long tstart;
+
+ @Before
+ public void setUp() {
+ factory = new ReadWriteLockProxyFactory<ReadWriteLockProxyFactoryTest.Z>();
+
+ threadFailed = false;
+ timer = new Timer();
+ threads = new ArrayList<Thread>();
+ }
+
+ private void startTiming() {
+ tstart = System.currentTimeMillis();
+ }
+
+ private float endTiming() {
+ return (float) (System.currentTimeMillis() - tstart) / INTERVAL_MILLIS;
+ }
+
+ private void sleep(int aUnits) {
+ try {
+ Thread.sleep(aUnits * INTERVAL_MILLIS);
+ } catch (InterruptedException e) {
+ threadFailed = true;
+ }
+ }
+
+ private void schedule(final AtomicInteger aUnstartedThreads,
+ final Thread aThread, int aUnits) {
+ aUnstartedThreads.incrementAndGet();
+ timer.schedule(new TimerTask() {
+ @Override
+ public void run() {
+ aThread.start();
+ aUnstartedThreads.decrementAndGet();
+ }
+ }, aUnits * INTERVAL_MILLIS);
+ }
+
+ private void schedule(AtomicInteger aUnstartedThreads, Runnable aTask,
+ int aUnits) {
+ Thread t = new Thread(aTask);
+ schedule(aUnstartedThreads, t, aUnits);
+ threads.add(t);
+ }
+
+ private void join(AtomicInteger aUnstartedThreads, List<Thread> aThreads) {
+ while (aUnstartedThreads.get() > 0) {
+ sleep(1);
+ }
+ for (Thread t : aThreads) {
+ try {
+ t.join();
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ threadFailed = true;
+ }
+ }
+ assertFalse(threadFailed);
+ }
+
+ private void join(AtomicInteger aUnstartedThreads) {
+ join(aUnstartedThreads, threads);
+ }
+
+ private void stubDelays() {
+ service = new Z() {
+ @ReadLock
+ public void doX(int aValue) {
+ sleep(10);
+ }
+
+ @WriteLock
+ public void doY(String aValue) {
+ sleep(10);
+ }
+
+ @Override
+ public void doNoLocking() {
+ sleep(10);
+ }
+ };
+
+ createProxy();
+ Answer sleep = new Answer() {
+ @Override
+ public Object answer(InvocationOnMock aInvocation) throws Throwable {
+ sleep(10);
+ return null;
+ }
+ };
+ }
+
+ @Test
+ public void testProxyDelegates() {
+ service = mock(Z.class);
+ createProxy();
+ assertTrue(proxyX instanceof X);
+ assertTrue(proxyY instanceof Y);
+
+ proxyX.doX(10);
+ verify(service).doX(10);
+ reset(service);
+
+ proxyY.doY("hello");
+ verify(service).doY("hello");
+ reset(service);
+ }
+
+ private void createProxy() {
+ proxyX = factory.getProxy(service, X.class, Y.class);
+ proxyY = (Y) proxyX;
+ }
+
+ @Test
+ public void testConcurrentReadCalls() {
+ stubDelays();
+ startTiming();
+ final int n = 4;
+ AtomicInteger unstarted = new AtomicInteger();
+ for (int i = 0; i < n; i++) {
+ schedule(unstarted, doX, 0);
+ }
+ join(unstarted);
+ float duration = endTiming();
+ assertTrue(duration < 15);
+ }
+
+ @Test
+ public void testNoConcurrentWrites() {
+ stubDelays();
+ startTiming();
+ final int n = 2;
+ AtomicInteger unstarted = new AtomicInteger();
+ for (int i = 0; i < n; i++) {
+ schedule(unstarted, doY, 0);
+ }
+ join(unstarted);
+ float duration = endTiming();
+ System.out.println("no concurrent writes: duration " + duration);
+ assertTrue(duration >= n * 10);
+ }
+
+ @Test
+ public void testConcurrentWriteAndNoLock() {
+ stubDelays();
+ startTiming();
+ AtomicInteger unstarted = new AtomicInteger();
+
+ schedule(unstarted, doY, 0);
+ for (int i = 0; i < 10; i++) {
+ schedule(unstarted, doNoLock, 0);
+ }
+ join(unstarted);
+ float duration = endTiming();
+ System.out.println("concurrent write and no lock: duration: " +
+ duration);
+ assertTrue(duration < 15);
+ }
+
+ @Test
+ public void testNoConcurrentReadAndWrite() {
+ stubDelays();
+ startTiming();
+
+ AtomicInteger unstartedReaders = new AtomicInteger();
+ for (int i = 0; i < 4; i++) {
+ schedule(unstartedReaders, doX, 0);
+ }
+ List<Thread> readers = threads;
+ threads = new ArrayList<Thread>();
+ // start the write some time later.
+ AtomicInteger unstartedWriters = new AtomicInteger();
+ schedule(unstartedWriters, doY, 5);
+
+ join(unstartedReaders, readers);
+ float duration = endTiming();
+ System.out.println("no concurrent read and write: readers duration: " +
+ duration);
+ assertTrue(duration < 15);
+ join(unstartedWriters, threads);
+ duration = endTiming();
+ System.out.println("no concurrent read and write: writer duration: " +
+ duration);
+ assertTrue(duration >= 20 && duration <= 25);
+ }
+}