+package org.wamblee.support.persistence;
+
+import java.sql.Connection;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
+import java.util.logging.Logger;
+
+import javax.sql.DataSource;
+
+import junit.framework.TestCase;
+
+import org.dbunit.DataSourceDatabaseTester;
+import org.dbunit.DatabaseTestCase;
+import org.dbunit.DatabaseUnitException;
+import org.dbunit.IDatabaseTester;
+import org.dbunit.database.DatabaseConnection;
+import org.dbunit.database.DatabaseSequenceFilter;
+import org.dbunit.database.IDatabaseConnection;
+import org.dbunit.dataset.FilteredDataSet;
+import org.dbunit.dataset.IDataSet;
+import org.dbunit.dataset.filter.ITableFilter;
+import org.dbunit.dataset.filter.ITableFilterSimple;
+import org.dbunit.operation.DatabaseOperation;
+
+/**
+ * Database utilities is a simple support class for common tasks in working with
+ * databases.
+ */
+public class DatabaseUtils {
+
+ public static interface TableSet {
+ boolean contains(String aTableName);
+ }
+
+ public static interface JdbcUnitOfWork<T> {
+ T execute(Connection aConnection) throws Exception;
+ }
+
+ public static interface TableSetOperation {
+ void execute(String aTable) throws Exception;
+ }
+
+ private static final Logger LOG = Logger.getLogger(DatabaseUtils.class
+ .getName());
+
+ /**
+ * Schema pattern.
+ */
+ private static final String SCHEMA_PATTERN = "%";
+ private DataSource dataSource;
+ private ITableFilterSimple tables;
+
+
+ public DatabaseUtils(DataSource aDataSource, ITableFilterSimple aTables) {
+ dataSource = aDataSource;
+ tables = aTables;
+ }
+
+ public IDatabaseTester createDbTester() throws Exception {
+ return createDbTester(getTableNames(tables));
+ }
+
+ public IDatabaseTester createDbTester(String[] aTables) throws Exception {
+ IDatabaseTester dbtester = new DataSourceDatabaseTester(dataSource);
+ dbtester.setDataSet(dbtester.getConnection().createDataSet(aTables));
+ return dbtester;
+ }
+
+ public void cleanDatabase() throws Exception {
+ cleanDatabase(tables);
+ }
+
+ public void executeOnTables(ITableFilterSimple aTables,
+ final TableSetOperation aOperation) throws Exception {
+ final String[] tables = getTableNames(aTables);
+ executeInTransaction(new JdbcUnitOfWork<Void>() {
+ public Void execute(Connection aConnection) throws Exception {
+ for (int i = tables.length - 1; i >= 0; i--) {
+ aOperation.execute(tables[i]);
+ }
+ return null;
+ }
+ });
+ for (String table : tables) {
+
+ }
+ }
+
+ public void cleanDatabase(ITableFilterSimple aSelection) throws Exception {
+
+ final String[] tables = getTableNames(aSelection);
+ executeInTransaction(new JdbcUnitOfWork<Void>() {
+
+ public Void execute(Connection aConnection) throws Exception {
+ IDatabaseConnection connection = new DatabaseConnection(
+ aConnection);
+ ITableFilter filter = new DatabaseSequenceFilter(connection,
+ tables);
+ IDataSet dataset = new FilteredDataSet(filter, connection
+ .createDataSet(tables));
+ DatabaseOperation.DELETE_ALL.execute(connection, dataset);
+ return null;
+ }
+ });
+
+ }
+
+ public <T> T executeInTransaction(JdbcUnitOfWork<T> aCallback)
+ throws Exception {
+ Connection connection = dataSource.getConnection();
+ try {
+ T value = aCallback.execute(connection);
+ connection.commit();
+ return value;
+ } finally {
+ connection.close();
+ }
+ }
+
+ public String[] getTableNames() throws Exception {
+ return getTableNames(tables);
+ }
+
+ /**
+ * @throws SQLException
+ */
+ public String[] getTableNames(ITableFilterSimple aSelection)
+ throws Exception {
+
+ List<String> result = new ArrayList<String>();
+ LOG.fine("Getting database table names to clean (schema: '"
+ + SCHEMA_PATTERN + "'");
+
+ ResultSet tables = dataSource.getConnection().getMetaData().getTables(
+ null, SCHEMA_PATTERN, "%", new String[] { "TABLE" });
+ while (tables.next()) {
+ String table = tables.getString("TABLE_NAME");
+ if (aSelection.accept(table)) {
+ result.add(table);
+ }
+ }
+ return (String[]) result.toArray(new String[0]);
+ }
+
+ public void emptyTables() throws Exception {
+ executeOnTables(tables, new TableSetOperation() {
+ public void execute(String aTable) throws Exception {
+ emptyTable(aTable);
+ }
+ });
+ }
+
+ /**
+ * @return
+ * @throws SQLException
+ */
+ public void emptyTables(final ITableFilterSimple aSelection)
+ throws Exception {
+ executeOnTables(aSelection, new TableSetOperation() {
+ public void execute(String aTable) throws Exception {
+ emptyTable(aTable);
+ }
+ });
+ }
+
+ /**
+ * @return
+ * @throws SQLException
+ */
+ public void emptyTable(String aTable) throws Exception {
+ executeSql("delete from " + aTable);
+ }
+
+ public void dropTables() throws Exception {
+ executeOnTables(tables, new TableSetOperation() {
+
+ public void execute(String aTable) throws Exception {
+ dropTable(aTable);
+ }
+ });
+ }
+
+
+ public void dropTables(ITableFilterSimple aTables) throws Exception {
+ executeOnTables(aTables, new TableSetOperation() {
+
+ public void execute(String aTable) throws Exception {
+ dropTable(aTable);
+ }
+ });
+ }
+
+ /**
+ * @return
+ * @throws SQLException
+ */
+ public void dropTable(final String aTable) throws Exception {
+ executeInTransaction(new JdbcUnitOfWork<Void>() {
+ public Void execute(Connection aConnection) throws Exception {
+ executeUpdate(aConnection, "drop table " + aTable);
+ return null;
+ }
+ });
+
+ }
+
+ /**
+ * Executes an SQL statement within a transaction.
+ *
+ * @param aSql
+ * SQL statement.
+ * @return Return code of the corresponding JDBC call.
+ */
+ public int executeSql(final String aSql) throws Exception {
+ return executeSql(aSql, new Object[0]);
+ }
+
+ /**
+ * Executes an SQL statement within a transaction. See
+ * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on
+ * supported argument types.
+ *
+ * @param aSql
+ * SQL statement.
+ * @param aArg
+ * Argument of the sql statement.
+ * @return Return code of the corresponding JDBC call.
+ */
+ public int executeSql(final String aSql, final Object aArg)
+ throws Exception {
+ return executeSql(aSql, new Object[] { aArg });
+ }
+
+ /**
+ * Executes an sql statement. See
+ * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on
+ * supported argument types.
+ *
+ * @param aSql
+ * SQL query to execute.
+ * @param aArgs
+ * Arguments.
+ * @return Number of rows updated.
+ */
+ public int executeSql(final String aSql, final Object[] aArgs)
+ throws Exception {
+ return executeInTransaction(new JdbcUnitOfWork<Integer>() {
+ public Integer execute(Connection aConnection) throws Exception {
+ PreparedStatement stmt = aConnection.prepareStatement(aSql);
+ setPreparedParams(aArgs, stmt);
+ return stmt.executeUpdate();
+ }
+ });
+ }
+
+ /**
+ * Executes an SQL query.
+ *
+ * @param aSql
+ * Query to execute.
+ * @return Result set.
+ */
+ public ResultSet executeQuery(Connection aConnection, String aSql) {
+ return executeQuery(aConnection, aSql, new Object[0]);
+ }
+
+ /**
+ * Executes a query with a single argument. See
+ * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on
+ * supported argument types.
+ *
+ * @param aSql
+ * Query.
+ * @param aArg
+ * Argument.
+ * @return Result set.
+ */
+ public ResultSet executeQuery(Connection aConnection, String aSql,
+ Object aArg) {
+ return executeQuery(aConnection, aSql, new Object[] { aArg });
+ }
+
+ /**
+ * Executes a query. See
+ * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on
+ * supported argument types.
+ *
+ * @param aSql
+ * Sql query.
+ * @param aArgs
+ * Arguments to the query.
+ * @return Result set.
+ */
+ public ResultSet executeQuery(Connection aConnection, final String aSql,
+ final Object[] aArgs) {
+ try {
+ PreparedStatement statement = aConnection.prepareStatement(aSql);
+ setPreparedParams(aArgs, statement);
+
+ return statement.executeQuery();
+ } catch (SQLException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ public int executeUpdate(Connection aConnection, final String aSql,
+ final Object... aArgs) {
+ try {
+ PreparedStatement statement = aConnection.prepareStatement(aSql);
+ setPreparedParams(aArgs, statement);
+
+ return statement.executeUpdate();
+ } catch (SQLException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ /**
+ * Sets the values of a prepared statement. See
+ * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on
+ * supported argument types.
+ *
+ * @param aArgs
+ * Arguments to the prepared statement.
+ * @param aStatement
+ * Prepared statement
+ * @throws SQLException
+ */
+ private void setPreparedParams(final Object[] aArgs,
+ PreparedStatement aStatement) throws SQLException {
+ for (int i = 1; i <= aArgs.length; i++) {
+ setPreparedParam(i, aStatement, aArgs[i - 1]);
+ }
+ }
+
+ /**
+ * Sets a prepared statement parameter.
+ *
+ * @param aIndex
+ * Index of the parameter.
+ * @param aStatement
+ * Prepared statement.
+ * @param aObject
+ * Value Must be of type Integer, Long, or String.
+ * @throws SQLException
+ */
+ private void setPreparedParam(int aIndex, PreparedStatement aStatement,
+ Object aObject) throws SQLException {
+ if (aObject instanceof Integer) {
+ aStatement.setInt(aIndex, ((Integer) aObject).intValue());
+ } else if (aObject instanceof Long) {
+ aStatement.setLong(aIndex, ((Integer) aObject).longValue());
+ } else if (aObject instanceof String) {
+ aStatement.setString(aIndex, (String) aObject);
+ } else {
+ TestCase.fail("Unsupported object type for prepared statement: "
+ + aObject.getClass() + " value: " + aObject
+ + " statement: " + aStatement);
+ }
+ }
+
+ /**
+ * @return
+ * @throws SQLException
+ */
+ public int getTableSize(final String aTable) throws Exception {
+ return executeInTransaction(new JdbcUnitOfWork<Integer>() {
+ public Integer execute(Connection aConnection) throws Exception {
+ ResultSet resultSet = executeQuery(aConnection,
+ "select count(*) from " + aTable);
+ resultSet.next();
+ return resultSet.getInt(1);
+ }
+ });
+
+ }
+
+ public int countResultSet(ResultSet aResultSet) throws SQLException {
+ int count = 0;
+
+ while (aResultSet.next()) {
+ count++;
+ }
+
+ return count;
+ }
+
+}