initial version of support project with build support.
[utils] / test / org / wamblee / test / SpringTestCase.java
diff --git a/test/org/wamblee/test/SpringTestCase.java b/test/org/wamblee/test/SpringTestCase.java
new file mode 100644 (file)
index 0000000..8b43621
--- /dev/null
@@ -0,0 +1,595 @@
+/*
+ * Copyright 2005 the original author or authors.
+ * 
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ * 
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.wamblee.test;
+
+import java.sql.Connection;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
+
+import javax.sql.DataSource;
+
+import junit.framework.TestCase;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.dbunit.DatabaseUnitException;
+import org.dbunit.database.DatabaseConnection;
+import org.dbunit.database.DatabaseSequenceFilter;
+import org.dbunit.database.IDatabaseConnection;
+import org.dbunit.dataset.FilteredDataSet;
+import org.dbunit.dataset.IDataSet;
+import org.dbunit.dataset.filter.ITableFilter;
+import org.dbunit.operation.DatabaseOperation;
+import org.hibernate.SessionFactory;
+import org.jmock.cglib.MockObjectTestCase;
+import org.springframework.beans.factory.NoSuchBeanDefinitionException;
+import org.springframework.beans.factory.config.BeanDefinition;
+import org.springframework.beans.factory.support.RootBeanDefinition;
+import org.springframework.context.ApplicationContext;
+import org.springframework.context.support.ClassPathXmlApplicationContext;
+import org.springframework.context.support.GenericApplicationContext;
+import org.springframework.jdbc.core.JdbcTemplate;
+import org.springframework.jdbc.datasource.DataSourceUtils;
+import org.springframework.jdbc.datasource.DriverManagerDataSource;
+import org.springframework.orm.hibernate3.HibernateTemplate;
+import org.springframework.transaction.PlatformTransactionManager;
+import org.springframework.transaction.TransactionDefinition;
+import org.springframework.transaction.TransactionStatus;
+import org.springframework.transaction.support.DefaultTransactionDefinition;
+import org.springframework.transaction.support.TransactionCallback;
+import org.springframework.transaction.support.TransactionCallbackWithoutResult;
+import org.springframework.transaction.support.TransactionTemplate;
+import org.wamblee.general.BeanKernel;
+import org.wamblee.persistence.hibernate.HibernateMappingFiles;
+
+/**
+ * Test case support class for spring tests. 
+ */
+public class SpringTestCase extends MockObjectTestCase {
+
+    private Log LOG = LogFactory.getLog(SpringTestCase.class);
+
+    /**
+     * Session factory bean name.
+     */
+    private static final String SESSION_FACTORY = "sessionFactory";
+
+    /**
+     * Data source bean name.
+     */
+    private static final String DATA_SOURCE = "dataSource";
+
+    /**
+     * Transaction manager bean name.
+     */
+    private static final String TRANSACTION_MANAGER = "transactionManager";
+
+    /**
+     * Name of the ConfigFileList bean that describes the Hibernate mapping
+     * files to use.
+     */
+    private static final String HIBERNATE_CONFIG_FILES = "hibernateMappingFiles";
+
+    /**
+     * Schema pattern.
+     */
+    private static final String SCHEMA_PATTERN = "%";
+   
+    /**
+     * List of (String) configuration file locations for spring.
+     */
+    private String[] _configLocations;
+    
+    /**
+     * Application context for storing bean definitions that vary on a test by
+     * test basis and cannot be hardcoded in the spring configuration files.
+     */
+    private GenericApplicationContext _parentContext;
+    
+    /**
+     * Cached spring application context. 
+     */
+    private ApplicationContext _context; 
+
+    public SpringTestCase(Class<? extends SpringConfigFiles> aSpringFiles,
+            Class<? extends HibernateMappingFiles> aMappingFiles) {
+        try {
+            SpringConfigFiles springFiles = aSpringFiles.newInstance();
+            _configLocations = springFiles.toArray(new String[0]);
+        } catch (Exception e) {
+            fail("Could not construct spring config files class '" + aSpringFiles.getName() + "'"); 
+        }
+   
+        // Register the Hibernate mapping files as a bean.
+        _parentContext = new GenericApplicationContext();
+        BeanDefinition lDefinition = new RootBeanDefinition(aMappingFiles);
+        _parentContext.registerBeanDefinition(HIBERNATE_CONFIG_FILES,
+                lDefinition);
+        _parentContext.refresh();
+
+    }
+
+    /**
+     * Gets the spring context.
+     * 
+     * @return Spring context.
+     */
+    protected synchronized ApplicationContext getSpringContext() {
+        if ( _context == null ) { 
+            _context = new ClassPathXmlApplicationContext(
+                (String[]) _configLocations, _parentContext);
+            assertNotNull(_context);
+        }
+        return _context; 
+    }
+
+    /**
+     * @return Hibernate session factory.
+     */
+    protected SessionFactory getSessionFactory() {
+        SessionFactory factory = (SessionFactory) getSpringContext().getBean(SESSION_FACTORY);
+        assertNotNull(factory); 
+        return factory; 
+    }
+
+    protected void setUp() throws Exception {
+        LOG.info("Performing setUp()");
+
+        super.setUp();
+        
+        _context = null;  // make sure we get a new application context for every
+                          // new test. 
+
+        BeanKernel
+                .overrideBeanFactory(new TestSpringBeanFactory(getSpringContext()));
+  
+        cleanDatabase(); 
+    }
+
+    /*
+     * (non-Javadoc)
+     * 
+     * @see junit.framework.TestCase#tearDown()
+     */
+    @Override
+    protected void tearDown() throws Exception {
+        try {
+            super.tearDown();
+        } finally {
+            LOG.info("tearDown() complete");
+        }
+    }
+
+    /**
+     * @return Transaction manager
+     */
+    protected PlatformTransactionManager getTransactionManager() {
+        PlatformTransactionManager manager = (PlatformTransactionManager) getSpringContext()
+        .getBean(TRANSACTION_MANAGER);
+        assertNotNull(manager); 
+        return manager; 
+    }
+
+    /**
+     * @return Starts a new transaction.
+     */
+    protected TransactionStatus getTransaction() {
+        DefaultTransactionDefinition def = new DefaultTransactionDefinition();
+        def.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRED);
+
+        return getTransactionManager().getTransaction(def);
+    }
+
+    /**
+     * Returns the hibernate template for executing hibernate-specific
+     * functionality.
+     * 
+     * @return Hibernate template.
+     */
+    protected HibernateTemplate getTemplate() {
+        HibernateTemplate template = (HibernateTemplate) getSpringContext().getBean(HibernateTemplate.class.getName());
+        assertNotNull(template); 
+        return template; 
+    }
+
+    /**
+     * Flushes the session. Should be called after some Hibernate work and
+     * before JDBC is used to check results.
+     * 
+     */
+    protected void flush() {
+        getTemplate().flush();
+    }
+
+    /**
+     * Flushes the session first and then removes all objects from the Session
+     * cache. Should be called after some Hibernate work and before JDBC is used
+     * to check results.
+     * 
+     */
+    protected void clear() {
+        flush();
+        getTemplate().clear();
+    }
+
+    /**
+     * Evicts the object from the session. This is essential for the
+     * implementation of unit tests where first an object is saved and is
+     * retrieved later. By removing the object from the session, Hibernate must
+     * retrieve the object again from the database.
+     * 
+     * @param aObject
+     */
+    protected void evict(Object aObject) {
+        getTemplate().evict(aObject);
+    }
+
+    /**
+     * Gets the connection.
+     * 
+     * @return Connection.
+     */
+    public Connection getConnection() {
+        return DataSourceUtils.getConnection(getDataSource());
+    }
+
+    public void cleanDatabase() throws SQLException {
+        
+        if (! isDatabaseConfigured() ) {
+            return; 
+        }
+        
+        String[] tables = getTableNames();
+
+        try {
+            IDatabaseConnection connection = new DatabaseConnection(
+                    getConnection());
+            ITableFilter filter = new DatabaseSequenceFilter(connection, tables);
+            IDataSet dataset = new FilteredDataSet(filter, connection
+                    .createDataSet(tables));
+
+            DatabaseOperation.DELETE_ALL.execute(connection, dataset);
+        } catch (DatabaseUnitException e) {
+            SQLException exc = new SQLException(e.getMessage());
+            exc.initCause(e);
+            throw exc;
+        } 
+    }
+
+    /**
+     * @throws SQLException
+     */
+    public String[] getTableNames() throws SQLException {
+
+        List result = new ArrayList();
+        LOG.debug("Getting database table names to clean (schema: '"
+                + SCHEMA_PATTERN + "'");
+
+        ResultSet tables = getConnection().getMetaData().getTables(null,
+                SCHEMA_PATTERN, "%", new String[] { "TABLE" });
+        while (tables.next()) {
+            String table = tables.getString("TABLE_NAME");
+            // Make sure we do not touch hibernate's specific
+            // infrastructure tables.
+            if (!table.toLowerCase().startsWith("hibernate")) {
+                result.add(table);
+                LOG.debug("Adding " + table
+                        + " to list of tables to be cleaned.");
+            }
+        }
+        return (String[]) result.toArray(new String[0]);
+    }
+
+    /**
+     * @return
+     * @throws SQLException
+     */
+    public void emptyTables(List aTableList) throws SQLException {
+        Iterator liTable = aTableList.iterator();
+        while (liTable.hasNext()) {
+            emptyTable((String) liTable.next());
+        }
+    }
+
+    /**
+     * @return
+     * @throws SQLException
+     */
+    public void emptyTable(String aTable) throws SQLException {
+        executeSql("delete from " + aTable);
+    }
+
+    /**
+     * @return
+     * @throws SQLException
+     */
+    public void dropTable(String aTable) throws SQLException {
+        executeQuery("drop table " + aTable);
+    }
+
+    /**
+     * Executes an SQL statement within a transaction.
+     * 
+     * @param aSql
+     *            SQL statement.
+     * @return Return code of the corresponding JDBC call.
+     */
+    public int executeSql(final String aSql) {
+        return executeSql(aSql, new Object[0]);
+    }
+
+    /**
+     * Executes an SQL statement within a transaction. See
+     * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on
+     * supported argument types.
+     * 
+     * @param aSql
+     *            SQL statement.
+     * @param aArg
+     *            Argument of the sql statement.
+     * @return Return code of the corresponding JDBC call.
+     */
+    public int executeSql(final String aSql, final Object aArg) {
+        return executeSql(aSql, new Object[] { aArg });
+    }
+
+    /**
+     * Executes an sql statement. See
+     * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on
+     * supported argument types.
+     * 
+     * @param aSql
+     *            SQL query to execute.
+     * @param aArgs
+     *            Arguments.
+     * @return Number of rows updated.
+     */
+    public int executeSql(final String aSql, final Object[] aArgs) {
+        Map results = executeTransaction(new TestTransactionCallback() {
+            public Map execute() throws Exception {
+                JdbcTemplate template = new JdbcTemplate(getDataSource());
+                int result = template.update(aSql, aArgs);
+
+                Map map = new TreeMap();
+                map.put("result", new Integer(result));
+
+                return map;
+            }
+        });
+
+        return ((Integer) results.get("result")).intValue();
+    }
+
+    /**
+     * Executes a transaction with a result.
+     * 
+     * @param aCallback
+     *            Callback to do your transactional work.
+     * @return Result.
+     */
+    public Object executeTransaction(TransactionCallback aCallback) {
+        TransactionTemplate lTemplate = new TransactionTemplate(
+                getTransactionManager());
+        return lTemplate.execute(aCallback);
+    }
+
+    /**
+     * Executes a transaction without a result.
+     * 
+     * @param aCallback
+     *            Callback to do your transactional work. .
+     */
+    protected void executeTransaction(TransactionCallbackWithoutResult aCallback) {
+        TransactionTemplate template = new TransactionTemplate(
+                getTransactionManager());
+        template.execute(aCallback);
+    }
+
+    /**
+     * Executes a transaction with a result, causing the testcase to fail if any
+     * type of exception is thrown.
+     * 
+     * @param aCallback
+     *            Code to be executed within the transaction.
+     * @return Result.
+     */
+    public Map executeTransaction(final TestTransactionCallback aCallback) {
+        return (Map) executeTransaction(new TransactionCallback() {
+            public Object doInTransaction(TransactionStatus aArg) {
+                try {
+                    return aCallback.execute();
+                } catch (Exception e) {
+                    // test case must fail.
+                    e.printStackTrace();
+                    throw new RuntimeException(e);
+                }
+            }
+        });
+    }
+
+    /**
+     * Executes a transaction with a result, causing the testcase to fail if any
+     * type of exception is thrown.
+     * 
+     * @param aCallback
+     *            Code to be executed within the transaction.
+     */
+    public void executeTransaction(
+            final TestTransactionCallbackWithoutResult aCallback) {
+        executeTransaction(new TransactionCallbackWithoutResult() {
+            public void doInTransactionWithoutResult(TransactionStatus aArg) {
+                try {
+                    aCallback.execute();
+                } catch (Exception e) {
+                    // test case must fail.
+                    throw new RuntimeException(e.getMessage(), e);
+                }
+            }
+        });
+    }
+
+    /**
+     * Executes an SQL query within a transaction.
+     * 
+     * @param aSql
+     *            Query to execute.
+     * @return Result set.
+     */
+    public ResultSet executeQuery(String aSql) {
+        return executeQuery(aSql, new Object[0]);
+    }
+
+    /**
+     * Executes a query with a single argument. See
+     * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on
+     * supported argument types.
+     * 
+     * @param aSql
+     *            Query.
+     * @param aArg
+     *            Argument.
+     * @return Result set.
+     */
+    public ResultSet executeQuery(String aSql, Object aArg) {
+        return executeQuery(aSql, new Object[] { aArg });
+    }
+
+    /**
+     * Executes a query within a transaction. See
+     * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on
+     * supported argument types.
+     * 
+     * @param aSql
+     *            Sql query.
+     * @param aArgs
+     *            Arguments to the query.
+     * @return Result set.
+     */
+    public ResultSet executeQuery(final String aSql, final Object[] aArgs) {
+        Map results = executeTransaction(new TestTransactionCallback() {
+            public Map execute() throws Exception {
+                Connection connection = getConnection();
+
+                PreparedStatement statement = connection.prepareStatement(aSql);
+                setPreparedParams(aArgs, statement);
+
+                ResultSet resultSet = statement.executeQuery();
+                TreeMap results = new TreeMap();
+                results.put("resultSet", resultSet);
+
+                return results;
+            }
+        });
+
+        return (ResultSet) results.get("resultSet");
+    }
+
+    /**
+     * Sets the values of a prepared statement. See
+     * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on
+     * supported argument types.
+     * 
+     * @param aArgs
+     *            Arguments to the prepared statement.
+     * @param aStatement
+     *            Prepared statement
+     * @throws SQLException
+     */
+    private void setPreparedParams(final Object[] aArgs,
+            PreparedStatement aStatement) throws SQLException {
+        for (int i = 1; i <= aArgs.length; i++) {
+            setPreparedParam(i, aStatement, aArgs[i - 1]);
+        }
+    }
+
+    /**
+     * Sets a prepared statement parameter.
+     * 
+     * @param aIndex
+     *            Index of the parameter.
+     * @param aStatement
+     *            Prepared statement.
+     * @param aObject
+     *            Value Must be of type Integer, Long, or String. TODO extend
+     *            with more types of values.
+     * @throws SQLException
+     */
+    private void setPreparedParam(int aIndex, PreparedStatement aStatement,
+            Object aObject) throws SQLException {
+        if (aObject instanceof Integer) {
+            aStatement.setInt(aIndex, ((Integer) aObject).intValue());
+        } else if (aObject instanceof Long) {
+            aStatement.setLong(aIndex, ((Integer) aObject).longValue());
+        } else if (aObject instanceof String) {
+            aStatement.setString(aIndex, (String) aObject);
+        } else {
+            TestCase.fail("Unsupported object type for prepared statement: "
+                    + aObject.getClass() + " value: " + aObject
+                    + " statement: " + aStatement);
+        }
+    }
+    
+    private boolean isDatabaseConfigured() { 
+        try { 
+            getDataSource(); 
+        } catch (NoSuchBeanDefinitionException e ) { 
+            return false; 
+        }
+        return true; 
+    }
+
+    /**
+     * @return Returns the dataSource.
+     */
+    public DataSource getDataSource() {
+        DataSource ds = (DriverManagerDataSource) getSpringContext().getBean(DATA_SOURCE);
+        assertNotNull(ds); 
+        return ds; 
+    }
+
+    /**
+     * @return
+     * @throws SQLException
+     */
+    protected int getTableSize(String aTable) throws SQLException {
+        ResultSet resultSet = executeQuery("select * from " + aTable);
+        int count = 0;
+
+        while (resultSet.next()) {
+            count++;
+        }
+
+        return count;
+    }
+    
+    protected int countResultSet(ResultSet aResultSet) throws SQLException { 
+        int count = 0;
+
+        while (aResultSet.next()) {
+            count++;
+        }
+
+        return count;
+    }
+
+}