/*
 * Copyright 2005 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
17 package org.wamblee.test;
19 import java.sql.Connection;
20 import java.sql.PreparedStatement;
21 import java.sql.ResultSet;
22 import java.sql.SQLException;
23 import java.util.ArrayList;
24 import java.util.HashMap;
25 import java.util.Iterator;
26 import java.util.List;
28 import java.util.TreeMap;
30 import javax.sql.DataSource;
32 import junit.framework.TestCase;
34 import org.apache.commons.logging.Log;
35 import org.apache.commons.logging.LogFactory;
36 import org.dbunit.DatabaseUnitException;
37 import org.dbunit.database.DatabaseConnection;
38 import org.dbunit.database.DatabaseSequenceFilter;
39 import org.dbunit.database.IDatabaseConnection;
40 import org.dbunit.dataset.FilteredDataSet;
41 import org.dbunit.dataset.IDataSet;
42 import org.dbunit.dataset.filter.ITableFilter;
43 import org.dbunit.operation.DatabaseOperation;
44 import org.hibernate.SessionFactory;
45 import org.jmock.cglib.MockObjectTestCase;
46 import org.springframework.beans.factory.NoSuchBeanDefinitionException;
47 import org.springframework.beans.factory.config.BeanDefinition;
48 import org.springframework.beans.factory.support.RootBeanDefinition;
49 import org.springframework.context.ApplicationContext;
50 import org.springframework.context.support.ClassPathXmlApplicationContext;
51 import org.springframework.context.support.GenericApplicationContext;
52 import org.springframework.jdbc.core.JdbcTemplate;
53 import org.springframework.jdbc.datasource.DataSourceUtils;
54 import org.springframework.jdbc.datasource.DriverManagerDataSource;
55 import org.springframework.orm.hibernate3.HibernateTemplate;
56 import org.springframework.transaction.PlatformTransactionManager;
57 import org.springframework.transaction.TransactionDefinition;
58 import org.springframework.transaction.TransactionStatus;
59 import org.springframework.transaction.support.DefaultTransactionDefinition;
60 import org.springframework.transaction.support.TransactionCallback;
61 import org.springframework.transaction.support.TransactionCallbackWithoutResult;
62 import org.springframework.transaction.support.TransactionTemplate;
63 import org.wamblee.general.BeanKernel;
64 import org.wamblee.persistence.hibernate.HibernateMappingFiles;
/**
 * Test case support class for Spring-based tests.
 */
69 public class SpringTestCase extends MockObjectTestCase {
71 private static final Log LOG = LogFactory.getLog(SpringTestCase.class);
74 * Session factory bean name.
76 private static final String SESSION_FACTORY = "sessionFactory";
79 * Data source bean name.
81 private static final String DATA_SOURCE = "dataSource";
84 * Transaction manager bean name.
86 private static final String TRANSACTION_MANAGER = "transactionManager";
89 * Name of the ConfigFileList bean that describes the Hibernate mapping
92 private static final String HIBERNATE_CONFIG_FILES = "hibernateMappingFiles";
97 private static final String SCHEMA_PATTERN = "%";
100 * List of (String) configuration file locations for spring.
102 private String[] _configLocations;
105 * Application context for storing bean definitions that vary on a test by
106 * test basis and cannot be hardcoded in the spring configuration files.
108 private GenericApplicationContext _parentContext;
111 * Cached spring application context.
113 private ApplicationContext _context;
115 public SpringTestCase(Class<? extends SpringConfigFiles> aSpringFiles,
116 Class<? extends HibernateMappingFiles> aMappingFiles) {
118 SpringConfigFiles springFiles = aSpringFiles.newInstance();
119 _configLocations = springFiles.toArray(new String[0]);
120 } catch (Exception e) {
121 fail("Could not construct spring config files class '"
122 + aSpringFiles.getName() + "'");
125 // Register the Hibernate mapping files as a bean.
126 _parentContext = new GenericApplicationContext();
127 BeanDefinition lDefinition = new RootBeanDefinition(aMappingFiles);
128 _parentContext.registerBeanDefinition(HIBERNATE_CONFIG_FILES,
130 _parentContext.refresh();
135 * Gets the spring context.
137 * @return Spring context.
139 protected synchronized ApplicationContext getSpringContext() {
140 if (_context == null) {
141 _context = new ClassPathXmlApplicationContext(
142 (String[]) _configLocations, _parentContext);
143 assertNotNull(_context);
149 * @return Hibernate session factory.
151 protected SessionFactory getSessionFactory() {
152 SessionFactory factory = (SessionFactory) getSpringContext().getBean(
154 assertNotNull(factory);
158 protected void setUp() throws Exception {
// Per-test setup: logs the start, discards the cached Spring application
// context and installs a TestSpringBeanFactory, wrapping the context returned
// by getSpringContext(), as the BeanKernel bean factory.
// NOTE(review): original lines 160-162, 164-166 and 169 onwards are missing
// from this listing. They may hold further setup calls (e.g. super.setUp());
// confirm against the full file.
159 LOG.info("Performing setUp()");
163 _context = null; // make sure we get a new application context for
167 BeanKernel.overrideBeanFactory(new TestSpringBeanFactory(
168 getSpringContext()));
176 * @see junit.framework.TestCase#tearDown()
179 protected void tearDown() throws Exception {
// Per-test teardown; only the final log statement is visible in this listing.
// NOTE(review): original lines 180-182 are missing; they presumably perform
// the actual teardown (e.g. super.tearDown()). Confirm against the full file.
183 LOG.info("tearDown() complete");
188 * @return Transaction manager
190 protected PlatformTransactionManager getTransactionManager() {
191 PlatformTransactionManager manager = (PlatformTransactionManager) getSpringContext()
192 .getBean(TRANSACTION_MANAGER);
193 assertNotNull(manager);
198 * @return Starts a new transaction.
200 protected TransactionStatus getTransaction() {
201 DefaultTransactionDefinition def = new DefaultTransactionDefinition();
202 def.setPropagationBehavior(TransactionDefinition.PROPAGATION_REQUIRED);
204 return getTransactionManager().getTransaction(def);
208 * Returns the hibernate template for executing hibernate-specific
211 * @return Hibernate template.
213 protected HibernateTemplate getTemplate() {
214 HibernateTemplate template = (HibernateTemplate) getSpringContext()
215 .getBean(HibernateTemplate.class.getName());
216 assertNotNull(template);
221 * Flushes the session. Should be called after some Hibernate work and
222 * before JDBC is used to check results.
225 protected void flush() {
226 getTemplate().flush();
230 * Flushes the session first and then removes all objects from the Session
231 * cache. Should be called after some Hibernate work and before JDBC is used
235 protected void clear() {
237 getTemplate().clear();
241 * Evicts the object from the session. This is essential for the
242 * implementation of unit tests where first an object is saved and is
243 * retrieved later. By removing the object from the session, Hibernate must
244 * retrieve the object again from the database.
248 protected void evict(Object aObject) {
249 getTemplate().evict(aObject);
253 * Gets the connection.
255 * @return Connection.
257 public Connection getConnection() {
258 return DataSourceUtils.getConnection(getDataSource());
261 public void cleanDatabase() throws SQLException {
263 if (!isDatabaseConfigured()) {
267 String[] tables = getTableNames();
270 IDatabaseConnection connection = new DatabaseConnection(
272 ITableFilter filter = new DatabaseSequenceFilter(connection, tables);
273 IDataSet dataset = new FilteredDataSet(filter, connection
274 .createDataSet(tables));
276 DatabaseOperation.DELETE_ALL.execute(connection, dataset);
277 } catch (DatabaseUnitException e) {
278 SQLException exc = new SQLException(e.getMessage());
285 * @throws SQLException
287 public String[] getTableNames() throws SQLException {
289 List<String> result = new ArrayList<String>();
290 LOG.debug("Getting database table names to clean (schema: '"
291 + SCHEMA_PATTERN + "'");
293 ResultSet tables = getConnection().getMetaData().getTables(null,
294 SCHEMA_PATTERN, "%", new String[] { "TABLE" });
295 while (tables.next()) {
296 String table = tables.getString("TABLE_NAME");
297 // Make sure we do not touch hibernate's specific
298 // infrastructure tables.
299 if (!table.toLowerCase().startsWith("hibernate")) {
301 LOG.debug("Adding " + table
302 + " to list of tables to be cleaned.");
305 return (String[]) result.toArray(new String[0]);
310 * @throws SQLException
312 public void emptyTables(List aTableList) throws SQLException {
313 Iterator liTable = aTableList.iterator();
314 while (liTable.hasNext()) {
315 emptyTable((String) liTable.next());
321 * @throws SQLException
323 public void emptyTable(String aTable) throws SQLException {
324 executeSql("delete from " + aTable);
329 * @throws SQLException
331 public void dropTable(String aTable) throws SQLException {
332 executeQuery("drop table " + aTable);
336 * Executes an SQL statement within a transaction.
340 * @return Return code of the corresponding JDBC call.
342 public int executeSql(final String aSql) {
343 return executeSql(aSql, new Object[0]);
347 * Executes an SQL statement within a transaction. See
348 * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on
349 * supported argument types.
354 * Argument of the sql statement.
355 * @return Return code of the corresponding JDBC call.
357 public int executeSql(final String aSql, final Object aArg) {
358 return executeSql(aSql, new Object[] { aArg });
362 * Executes an sql statement. See
363 * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on
364 * supported argument types.
367 * SQL query to execute.
370 * @return Number of rows updated.
372 public int executeSql(final String aSql, final Object[] aArgs) {
373 Map results = executeTransaction(new TestTransactionCallback() {
374 public Map execute() throws Exception {
375 JdbcTemplate template = new JdbcTemplate(getDataSource());
376 int result = template.update(aSql, aArgs);
378 Map<String, Integer> map = new TreeMap<String, Integer>();
379 map.put("result", new Integer(result));
385 return ((Integer) results.get("result")).intValue();
389 * Executes a transaction with a result.
392 * Callback to do your transactional work.
395 public Object executeTransaction(TransactionCallback aCallback) {
396 TransactionTemplate lTemplate = new TransactionTemplate(
397 getTransactionManager());
398 return lTemplate.execute(aCallback);
402 * Executes a transaction without a result.
405 * Callback to do your transactional work. .
407 protected void executeTransaction(TransactionCallbackWithoutResult aCallback) {
408 TransactionTemplate template = new TransactionTemplate(
409 getTransactionManager());
410 template.execute(aCallback);
414 * Executes a transaction with a result, causing the testcase to fail if any
415 * type of exception is thrown.
418 * Code to be executed within the transaction.
421 public Map executeTransaction(final TestTransactionCallback aCallback) {
// Adapts the TestTransactionCallback to Spring's TransactionCallback and runs
// it through executeTransaction(TransactionCallback). The callback's result
// is returned as a Map. Any exception thrown by the callback is rethrown
// wrapped in a RuntimeException, so it propagates to the calling test.
// NOTE(review): original lines 424 and 428 are missing from this listing.
// 424 is presumably "try {"; confirm what 428 does.
// NOTE(review): the TestTransactionCallbackWithoutResult overload wraps with
// new RuntimeException(e.getMessage(), e). Consider making both consistent.
422 return (Map) executeTransaction(new TransactionCallback() {
423 public Object doInTransaction(TransactionStatus aArg) {
425 return aCallback.execute();
426 } catch (Exception e) {
427 // test case must fail.
429 throw new RuntimeException(e);
436 * Executes a transaction with a result, causing the testcase to fail if any
437 * type of exception is thrown.
440 * Code to be executed within the transaction.
442 public void executeTransaction(
443 final TestTransactionCallbackWithoutResult aCallback) {
444 executeTransaction(new TransactionCallbackWithoutResult() {
445 public void doInTransactionWithoutResult(TransactionStatus aArg) {
448 } catch (Exception e) {
449 // test case must fail.
450 throw new RuntimeException(e.getMessage(), e);
457 * Executes an SQL query.
461 * @return Result set.
463 public ResultSet executeQuery(String aSql) {
464 return executeQuery(aSql, new Object[0]);
468 * Executes a query with a single argument. See
469 * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on
470 * supported argument types.
476 * @return Result set.
478 public ResultSet executeQuery(String aSql, Object aArg) {
479 return executeQuery(aSql, new Object[] { aArg });
483 * Executes a query. See
484 * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on
485 * supported argument types.
490 * Arguments to the query.
491 * @return Result set.
493 public ResultSet executeQuery(final String aSql, final Object[] aArgs) {
495 Connection connection = getConnection();
497 PreparedStatement statement = connection.prepareStatement(aSql);
498 setPreparedParams(aArgs, statement);
500 return statement.executeQuery();
501 } catch (SQLException e) {
502 throw new RuntimeException(e);
507 * Sets the values of a prepared statement. See
508 * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on
509 * supported argument types.
512 * Arguments to the prepared statement.
515 * @throws SQLException
517 private void setPreparedParams(final Object[] aArgs,
518 PreparedStatement aStatement) throws SQLException {
519 for (int i = 1; i <= aArgs.length; i++) {
520 setPreparedParam(i, aStatement, aArgs[i - 1]);
525 * Sets a prepared statement parameter.
528 * Index of the parameter.
530 * Prepared statement.
532 * Value Must be of type Integer, Long, or String. TODO extend
533 * with more types of values.
534 * @throws SQLException
536 private void setPreparedParam(int aIndex, PreparedStatement aStatement,
537 Object aObject) throws SQLException {
538 if (aObject instanceof Integer) {
539 aStatement.setInt(aIndex, ((Integer) aObject).intValue());
540 } else if (aObject instanceof Long) {
541 aStatement.setLong(aIndex, ((Integer) aObject).longValue());
542 } else if (aObject instanceof String) {
543 aStatement.setString(aIndex, (String) aObject);
545 TestCase.fail("Unsupported object type for prepared statement: "
546 + aObject.getClass() + " value: " + aObject
547 + " statement: " + aStatement);
551 private boolean isDatabaseConfigured() {
// Reports whether a database is available to the test (cleanDatabase() uses
// it as a guard). The visible catch of NoSuchBeanDefinitionException suggests
// it probes for a Spring bean, presumably the data source, and returns false
// when that bean is not defined.
// NOTE(review): original lines 552-553 and 555-558 are missing from this
// listing. Confirm the exact probe and return values.
554 } catch (NoSuchBeanDefinitionException e) {
561 * @return Returns the dataSource.
563 public DataSource getDataSource() {
564 DataSource ds = (DriverManagerDataSource) getSpringContext().getBean(
572 * @throws SQLException
574 protected int getTableSize(final String aTable) throws SQLException {
// Presumably returns the number of rows in aTable, counted by iterating over
// the result of "select * from " + aTable (aTable is concatenated verbatim).
// NOTE(review): the counting statements (original lines 577-584) are missing
// from this listing; confirm. The ResultSet from executeQuery(), and its
// statement and connection, do not appear to be closed here.
576 ResultSet resultSet = executeQuery("select * from " + aTable);
579 while (resultSet.next()) {
585 protected int countResultSet(ResultSet aResultSet) throws SQLException {
588 while (aResultSet.next()) {