/* * Copyright 2005 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.wamblee.test; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.TreeMap; import javax.sql.DataSource; import junit.framework.TestCase; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.dbunit.DatabaseUnitException; import org.dbunit.database.DatabaseConnection; import org.dbunit.database.DatabaseSequenceFilter; import org.dbunit.database.IDatabaseConnection; import org.dbunit.dataset.FilteredDataSet; import org.dbunit.dataset.IDataSet; import org.dbunit.dataset.filter.ITableFilter; import org.dbunit.operation.DatabaseOperation; import org.hibernate.SessionFactory; import org.jmock.cglib.MockObjectTestCase; import org.springframework.beans.factory.NoSuchBeanDefinitionException; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.support.RootBeanDefinition; import org.springframework.context.ApplicationContext; import org.springframework.context.support.ClassPathXmlApplicationContext; import org.springframework.context.support.GenericApplicationContext; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.datasource.DataSourceUtils; import org.springframework.jdbc.datasource.DriverManagerDataSource; import org.springframework.orm.hibernate3.HibernateTemplate; import org.springframework.transaction.PlatformTransactionManager; import org.springframework.transaction.TransactionDefinition; import org.springframework.transaction.TransactionStatus; import org.springframework.transaction.support.DefaultTransactionDefinition; import org.springframework.transaction.support.TransactionCallback; import org.springframework.transaction.support.TransactionCallbackWithoutResult; import org.springframework.transaction.support.TransactionTemplate; import org.wamblee.general.BeanKernel; import org.wamblee.persistence.hibernate.HibernateMappingFiles; /** * Test case support class for spring tests. */ public class SpringTestCase extends MockObjectTestCase { private Log LOG = LogFactory.getLog(SpringTestCase.class); /** * Session factory bean name. */ private static final String SESSION_FACTORY = "sessionFactory"; /** * Data source bean name. */ private static final String DATA_SOURCE = "dataSource"; /** * Transaction manager bean name. */ private static final String TRANSACTION_MANAGER = "transactionManager"; /** * Name of the ConfigFileList bean that describes the Hibernate mapping * files to use. */ private static final String HIBERNATE_CONFIG_FILES = "hibernateMappingFiles"; /** * Schema pattern. */ private static final String SCHEMA_PATTERN = "%"; /** * List of (String) configuration file locations for spring. */ private String[] _configLocations; /** * Application context for storing bean definitions that vary on a test by * test basis and cannot be hardcoded in the spring configuration files. */ private GenericApplicationContext _parentContext; /** * Cached spring application context. */ private ApplicationContext _context; public SpringTestCase(Class aSpringFiles, Class aMappingFiles) { try { SpringConfigFiles springFiles = aSpringFiles.newInstance(); _configLocations = springFiles.toArray(new String[0]); } catch (Exception e) { fail("Could not construct spring config files class '" + aSpringFiles.getName() + "'"); } // Register the Hibernate mapping files as a bean. _parentContext = new GenericApplicationContext(); BeanDefinition lDefinition = new RootBeanDefinition(aMappingFiles); _parentContext.registerBeanDefinition(HIBERNATE_CONFIG_FILES, lDefinition); _parentContext.refresh(); } /** * Gets the spring context. * * @return Spring context. */ protected synchronized ApplicationContext getSpringContext() { if ( _context == null ) { _context = new ClassPathXmlApplicationContext( (String[]) _configLocations, _parentContext); assertNotNull(_context); } return _context; } /** * @return Hibernate session factory. */ protected SessionFactory getSessionFactory() { SessionFactory factory = (SessionFactory) getSpringContext().getBean(SESSION_FACTORY); assertNotNull(factory); return factory; } protected void setUp() throws Exception { LOG.info("Performing setUp()"); super.setUp(); _context = null; // make sure we get a new application context for every // new test. BeanKernel .overrideBeanFactory(new TestSpringBeanFactory(getSpringContext())); cleanDatabase(); } /* * (non-Javadoc) * * @see junit.framework.TestCase#tearDown() */ @Override protected void tearDown() throws Exception { try { super.tearDown(); } finally { LOG.info("tearDown() complete"); } } /** * @return Transaction manager */ protected PlatformTransactionManager getTransactionManager() { PlatformTransactionManager manager = (PlatformTransactionManager) getSpringContext() .getBean(TRANSACTION_MANAGER); assertNotNull(manager); return manager; } /** * @return Starts a new transaction. */ protected TransactionStatus getTransaction() { DefaultTransactionDefinition def = new DefaultTransactionDefinition(); def.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRED); return getTransactionManager().getTransaction(def); } /** * Returns the hibernate template for executing hibernate-specific * functionality. * * @return Hibernate template. */ protected HibernateTemplate getTemplate() { HibernateTemplate template = (HibernateTemplate) getSpringContext().getBean(HibernateTemplate.class.getName()); assertNotNull(template); return template; } /** * Flushes the session. Should be called after some Hibernate work and * before JDBC is used to check results. * */ protected void flush() { getTemplate().flush(); } /** * Flushes the session first and then removes all objects from the Session * cache. Should be called after some Hibernate work and before JDBC is used * to check results. * */ protected void clear() { flush(); getTemplate().clear(); } /** * Evicts the object from the session. This is essential for the * implementation of unit tests where first an object is saved and is * retrieved later. By removing the object from the session, Hibernate must * retrieve the object again from the database. * * @param aObject */ protected void evict(Object aObject) { getTemplate().evict(aObject); } /** * Gets the connection. * * @return Connection. */ public Connection getConnection() { return DataSourceUtils.getConnection(getDataSource()); } public void cleanDatabase() throws SQLException { if (! isDatabaseConfigured() ) { return; } String[] tables = getTableNames(); try { IDatabaseConnection connection = new DatabaseConnection( getConnection()); ITableFilter filter = new DatabaseSequenceFilter(connection, tables); IDataSet dataset = new FilteredDataSet(filter, connection .createDataSet(tables)); DatabaseOperation.DELETE_ALL.execute(connection, dataset); } catch (DatabaseUnitException e) { SQLException exc = new SQLException(e.getMessage()); exc.initCause(e); throw exc; } } /** * @throws SQLException */ public String[] getTableNames() throws SQLException { List result = new ArrayList(); LOG.debug("Getting database table names to clean (schema: '" + SCHEMA_PATTERN + "'"); ResultSet tables = getConnection().getMetaData().getTables(null, SCHEMA_PATTERN, "%", new String[] { "TABLE" }); while (tables.next()) { String table = tables.getString("TABLE_NAME"); // Make sure we do not touch hibernate's specific // infrastructure tables. if (!table.toLowerCase().startsWith("hibernate")) { result.add(table); LOG.debug("Adding " + table + " to list of tables to be cleaned."); } } return (String[]) result.toArray(new String[0]); } /** * @return * @throws SQLException */ public void emptyTables(List aTableList) throws SQLException { Iterator liTable = aTableList.iterator(); while (liTable.hasNext()) { emptyTable((String) liTable.next()); } } /** * @return * @throws SQLException */ public void emptyTable(String aTable) throws SQLException { executeSql("delete from " + aTable); } /** * @return * @throws SQLException */ public void dropTable(String aTable) throws SQLException { executeQuery("drop table " + aTable); } /** * Executes an SQL statement within a transaction. * * @param aSql * SQL statement. * @return Return code of the corresponding JDBC call. */ public int executeSql(final String aSql) { return executeSql(aSql, new Object[0]); } /** * Executes an SQL statement within a transaction. See * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on * supported argument types. * * @param aSql * SQL statement. * @param aArg * Argument of the sql statement. * @return Return code of the corresponding JDBC call. */ public int executeSql(final String aSql, final Object aArg) { return executeSql(aSql, new Object[] { aArg }); } /** * Executes an sql statement. See * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on * supported argument types. * * @param aSql * SQL query to execute. * @param aArgs * Arguments. * @return Number of rows updated. */ public int executeSql(final String aSql, final Object[] aArgs) { Map results = executeTransaction(new TestTransactionCallback() { public Map execute() throws Exception { JdbcTemplate template = new JdbcTemplate(getDataSource()); int result = template.update(aSql, aArgs); Map map = new TreeMap(); map.put("result", new Integer(result)); return map; } }); return ((Integer) results.get("result")).intValue(); } /** * Executes a transaction with a result. * * @param aCallback * Callback to do your transactional work. * @return Result. */ public Object executeTransaction(TransactionCallback aCallback) { TransactionTemplate lTemplate = new TransactionTemplate( getTransactionManager()); return lTemplate.execute(aCallback); } /** * Executes a transaction without a result. * * @param aCallback * Callback to do your transactional work. . */ protected void executeTransaction(TransactionCallbackWithoutResult aCallback) { TransactionTemplate template = new TransactionTemplate( getTransactionManager()); template.execute(aCallback); } /** * Executes a transaction with a result, causing the testcase to fail if any * type of exception is thrown. * * @param aCallback * Code to be executed within the transaction. * @return Result. */ public Map executeTransaction(final TestTransactionCallback aCallback) { return (Map) executeTransaction(new TransactionCallback() { public Object doInTransaction(TransactionStatus aArg) { try { return aCallback.execute(); } catch (Exception e) { // test case must fail. e.printStackTrace(); throw new RuntimeException(e); } } }); } /** * Executes a transaction with a result, causing the testcase to fail if any * type of exception is thrown. * * @param aCallback * Code to be executed within the transaction. */ public void executeTransaction( final TestTransactionCallbackWithoutResult aCallback) { executeTransaction(new TransactionCallbackWithoutResult() { public void doInTransactionWithoutResult(TransactionStatus aArg) { try { aCallback.execute(); } catch (Exception e) { // test case must fail. throw new RuntimeException(e.getMessage(), e); } } }); } /** * Executes an SQL query within a transaction. * * @param aSql * Query to execute. * @return Result set. */ public ResultSet executeQuery(String aSql) { return executeQuery(aSql, new Object[0]); } /** * Executes a query with a single argument. See * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on * supported argument types. * * @param aSql * Query. * @param aArg * Argument. * @return Result set. */ public ResultSet executeQuery(String aSql, Object aArg) { return executeQuery(aSql, new Object[] { aArg }); } /** * Executes a query within a transaction. See * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on * supported argument types. * * @param aSql * Sql query. * @param aArgs * Arguments to the query. * @return Result set. */ public ResultSet executeQuery(final String aSql, final Object[] aArgs) { Map results = executeTransaction(new TestTransactionCallback() { public Map execute() throws Exception { Connection connection = getConnection(); PreparedStatement statement = connection.prepareStatement(aSql); setPreparedParams(aArgs, statement); ResultSet resultSet = statement.executeQuery(); TreeMap results = new TreeMap(); results.put("resultSet", resultSet); return results; } }); return (ResultSet) results.get("resultSet"); } /** * Sets the values of a prepared statement. See * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on * supported argument types. * * @param aArgs * Arguments to the prepared statement. * @param aStatement * Prepared statement * @throws SQLException */ private void setPreparedParams(final Object[] aArgs, PreparedStatement aStatement) throws SQLException { for (int i = 1; i <= aArgs.length; i++) { setPreparedParam(i, aStatement, aArgs[i - 1]); } } /** * Sets a prepared statement parameter. * * @param aIndex * Index of the parameter. * @param aStatement * Prepared statement. * @param aObject * Value Must be of type Integer, Long, or String. TODO extend * with more types of values. * @throws SQLException */ private void setPreparedParam(int aIndex, PreparedStatement aStatement, Object aObject) throws SQLException { if (aObject instanceof Integer) { aStatement.setInt(aIndex, ((Integer) aObject).intValue()); } else if (aObject instanceof Long) { aStatement.setLong(aIndex, ((Integer) aObject).longValue()); } else if (aObject instanceof String) { aStatement.setString(aIndex, (String) aObject); } else { TestCase.fail("Unsupported object type for prepared statement: " + aObject.getClass() + " value: " + aObject + " statement: " + aStatement); } } private boolean isDatabaseConfigured() { try { getDataSource(); } catch (NoSuchBeanDefinitionException e ) { return false; } return true; } /** * @return Returns the dataSource. */ public DataSource getDataSource() { DataSource ds = (DriverManagerDataSource) getSpringContext().getBean(DATA_SOURCE); assertNotNull(ds); return ds; } /** * @return * @throws SQLException */ protected int getTableSize(String aTable) throws SQLException { ResultSet resultSet = executeQuery("select * from " + aTable); int count = 0; while (resultSet.next()) { count++; } return count; } protected int countResultSet(ResultSet aResultSet) throws SQLException { int count = 0; while (aResultSet.next()) { count++; } return count; } }