/* * Copyright 2005-2010 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.wamblee.support.persistence; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.util.ArrayList; import java.util.List; import java.util.logging.Logger; import javax.sql.DataSource; import junit.framework.TestCase; import org.dbunit.DataSourceDatabaseTester; import org.dbunit.IDatabaseTester; import org.dbunit.database.DatabaseConnection; import org.dbunit.database.DatabaseSequenceFilter; import org.dbunit.database.IDatabaseConnection; import org.dbunit.dataset.FilteredDataSet; import org.dbunit.dataset.IDataSet; import org.dbunit.dataset.filter.ITableFilter; import org.dbunit.dataset.filter.ITableFilterSimple; import org.dbunit.operation.DatabaseOperation; /** * Database utilities is a simple support class for common tasks in working with * databases. */ public class DatabaseUtils { public static interface TableSet { boolean contains(String aTableName); } public static interface JdbcUnitOfWork { T execute(Connection aConnection) throws Exception; } public static interface TableSetOperation { void execute(String aTable) throws Exception; } private static final Logger LOG = Logger.getLogger(DatabaseUtils.class .getName()); /** * Schema pattern. */ private static final String SCHEMA_PATTERN = "%"; private DataSource dataSource; private ITableFilterSimple tables; public DatabaseUtils(DataSource aDataSource, ITableFilterSimple aTables) { dataSource = aDataSource; tables = aTables; } public IDatabaseTester createDbTester() throws Exception { return createDbTester(getTableNames(tables)); } public IDatabaseTester createDbTester(String[] aTables) throws Exception { IDatabaseTester dbtester = new DataSourceDatabaseTester(dataSource); dbtester.setDataSet(dbtester.getConnection().createDataSet(aTables)); return dbtester; } public void cleanDatabase() throws Exception { cleanDatabase(tables); } public void executeOnTables(ITableFilterSimple aTables, final TableSetOperation aOperation) throws Exception { final String[] tables = getTableNames(aTables); executeInTransaction(new JdbcUnitOfWork() { public Void execute(Connection aConnection) throws Exception { for (int i = tables.length-1; i >= 0; i--) { aOperation.execute(tables[i]); } return null; } }); } public void cleanDatabase(ITableFilterSimple aSelection) throws Exception { final String[] tables = getTableNames(aSelection); executeInTransaction(new JdbcUnitOfWork() { public Void execute(Connection aConnection) throws Exception { IDatabaseConnection connection = new DatabaseConnection( aConnection); ITableFilter filter = new DatabaseSequenceFilter(connection, tables); IDataSet dataset = new FilteredDataSet(filter, connection .createDataSet(tables)); DatabaseOperation.DELETE_ALL.execute(connection, dataset); return null; } }); } public T executeInTransaction(JdbcUnitOfWork aCallback) throws Exception { Connection connection = dataSource.getConnection(); try { T value = aCallback.execute(connection); connection.commit(); return value; } finally { connection.close(); } } public String[] getTableNames() throws Exception { return getTableNames(tables); } /** * @throws SQLException */ public String[] getTableNames(ITableFilterSimple aSelection) throws Exception { List result = new ArrayList(); LOG.fine("Getting database table names to clean (schema: '" + SCHEMA_PATTERN + "'"); Connection connection = dataSource.getConnection(); try { ResultSet tables = connection.getMetaData().getTables(null, SCHEMA_PATTERN, "%", new String[] { "TABLE" }); while (tables.next()) { String table = tables.getString("TABLE_NAME"); if (aSelection.accept(table)) { result.add(table); } } return (String[]) result.toArray(new String[0]); } finally { connection.close(); } } public void emptyTables() throws Exception { executeOnTables(tables, new TableSetOperation() { public void execute(String aTable) throws Exception { emptyTable(aTable); } }); } /** * @return * @throws SQLException */ public void emptyTables(final ITableFilterSimple aSelection) throws Exception { executeOnTables(aSelection, new TableSetOperation() { public void execute(String aTable) throws Exception { emptyTable(aTable); } }); } /** * @return * @throws SQLException */ public void emptyTable(String aTable) throws Exception { executeSql("delete from " + aTable); } public void dropTables() throws Exception { dropTables(tables); } public void dropTables(ITableFilterSimple aTables) throws Exception { final String[] tables = getTableNames(aTables); String[] sortedTables = executeInTransaction(new JdbcUnitOfWork() { public String[] execute(Connection aConnection) throws Exception { IDatabaseConnection connection = new DatabaseConnection( aConnection); ITableFilter filter = new DatabaseSequenceFilter(connection, tables); IDataSet dataset = new FilteredDataSet(filter, connection .createDataSet(tables)); return dataset.getTableNames(); } }); for (int i = sortedTables.length-1; i >= 0; i--) { dropTable(sortedTables[i]); } } /** * @return * @throws SQLException */ public void dropTable(final String aTable) throws Exception { executeInTransaction(new JdbcUnitOfWork() { public Void execute(Connection aConnection) throws Exception { executeUpdate(aConnection, "drop table " + aTable); return null; } }); } /** * Executes an SQL statement within a transaction. * * @param aSql * SQL statement. * @return Return code of the corresponding JDBC call. */ public int executeSql(final String aSql) throws Exception { return executeSql(aSql, new Object[0]); } /** * Executes an SQL statement within a transaction. See * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on * supported argument types. * * @param aSql * SQL statement. * @param aArg * Argument of the sql statement. * @return Return code of the corresponding JDBC call. */ public int executeSql(final String aSql, final Object aArg) throws Exception { return executeSql(aSql, new Object[] { aArg }); } /** * Executes an sql statement. See * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on * supported argument types. * * @param aSql * SQL query to execute. * @param aArgs * Arguments. * @return Number of rows updated. */ public int executeSql(final String aSql, final Object[] aArgs) throws Exception { return executeInTransaction(new JdbcUnitOfWork() { public Integer execute(Connection aConnection) throws Exception { PreparedStatement stmt = aConnection.prepareStatement(aSql); setPreparedParams(aArgs, stmt); return stmt.executeUpdate(); } }); } /** * Executes an SQL query. * * @param aSql * Query to execute. * @return Result set. */ public ResultSet executeQuery(Connection aConnection, String aSql) { return executeQuery(aConnection, aSql, new Object[0]); } /** * Executes a query with a single argument. See * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on * supported argument types. * * @param aSql * Query. * @param aArg * Argument. * @return Result set. */ public ResultSet executeQuery(Connection aConnection, String aSql, Object aArg) { return executeQuery(aConnection, aSql, new Object[] { aArg }); } /** * Executes a query. See * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on * supported argument types. * * @param aSql * Sql query. * @param aArgs * Arguments to the query. * @return Result set. */ public ResultSet executeQuery(Connection aConnection, final String aSql, final Object[] aArgs) { try { PreparedStatement statement = aConnection.prepareStatement(aSql); setPreparedParams(aArgs, statement); return statement.executeQuery(); } catch (SQLException e) { throw new RuntimeException(e); } } public int executeUpdate(Connection aConnection, final String aSql, final Object... aArgs) { try { PreparedStatement statement = aConnection.prepareStatement(aSql); setPreparedParams(aArgs, statement); return statement.executeUpdate(); } catch (SQLException e) { throw new RuntimeException(e); } } /** * Sets the values of a prepared statement. See * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on * supported argument types. * * @param aArgs * Arguments to the prepared statement. * @param aStatement * Prepared statement * @throws SQLException */ private void setPreparedParams(final Object[] aArgs, PreparedStatement aStatement) throws SQLException { for (int i = 1; i <= aArgs.length; i++) { setPreparedParam(i, aStatement, aArgs[i - 1]); } } /** * Sets a prepared statement parameter. * * @param aIndex * Index of the parameter. * @param aStatement * Prepared statement. * @param aObject * Value Must be of type Integer, Long, or String. * @throws SQLException */ private void setPreparedParam(int aIndex, PreparedStatement aStatement, Object aObject) throws SQLException { if (aObject instanceof Integer) { aStatement.setInt(aIndex, ((Integer) aObject).intValue()); } else if (aObject instanceof Long) { aStatement.setLong(aIndex, ((Long) aObject).longValue()); } else if (aObject instanceof String) { aStatement.setString(aIndex, (String) aObject); } else { TestCase.fail("Unsupported object type for prepared statement: " + aObject.getClass() + " value: " + aObject + " statement: " + aStatement); } } /** * @return * @throws SQLException */ public int getTableSize(final String aTable) throws Exception { return executeInTransaction(new JdbcUnitOfWork() { public Integer execute(Connection aConnection) throws Exception { ResultSet resultSet = executeQuery(aConnection, "select count(*) from " + aTable); resultSet.next(); return resultSet.getInt(1); } }); } public int countResultSet(ResultSet aResultSet) throws SQLException { int count = 0; while (aResultSet.next()) { count++; } return count; } }