(no commit message)
[utils] / test / enterprise / src / main / java / org / wamblee / support / persistence / DatabaseUtils.java
diff --git a/test/enterprise/src/main/java/org/wamblee/support/persistence/DatabaseUtils.java b/test/enterprise/src/main/java/org/wamblee/support/persistence/DatabaseUtils.java
new file mode 100644 (file)
index 0000000..312e38a
--- /dev/null
@@ -0,0 +1,389 @@
+package org.wamblee.support.persistence;
+
+import java.sql.Connection;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.logging.Logger;
+
+import javax.sql.DataSource;
+
+import junit.framework.TestCase;
+
+import org.dbunit.DataSourceDatabaseTester;
+import org.dbunit.IDatabaseTester;
+import org.dbunit.database.DatabaseConnection;
+import org.dbunit.database.DatabaseSequenceFilter;
+import org.dbunit.database.IDatabaseConnection;
+import org.dbunit.dataset.FilteredDataSet;
+import org.dbunit.dataset.IDataSet;
+import org.dbunit.dataset.filter.ITableFilter;
+import org.dbunit.dataset.filter.ITableFilterSimple;
+import org.dbunit.operation.DatabaseOperation;
+
+/**
+ * Database utilities is a simple support class for common tasks in working with
+ * databases.
+ */
+public class DatabaseUtils {
+
+       public static interface TableSet {
+               boolean contains(String aTableName);
+       }
+
+       public static interface JdbcUnitOfWork<T> {
+               T execute(Connection aConnection) throws Exception;
+       }
+
+       public static interface TableSetOperation {
+               void execute(String aTable) throws Exception;
+       }
+
+       private static final Logger LOG = Logger.getLogger(DatabaseUtils.class
+                       .getName());
+
+       /**
+        * Schema pattern.
+        */
+       private static final String SCHEMA_PATTERN = "%";
+       private DataSource dataSource;
+       private ITableFilterSimple tables;
+       
+
+       public DatabaseUtils(DataSource aDataSource, ITableFilterSimple aTables) {
+               dataSource = aDataSource;
+               tables = aTables;
+       }
+
+       public IDatabaseTester createDbTester() throws Exception {
+               return createDbTester(getTableNames(tables));
+       }
+
+       public IDatabaseTester createDbTester(String[] aTables) throws Exception {
+               IDatabaseTester dbtester = new DataSourceDatabaseTester(dataSource);
+               dbtester.setDataSet(dbtester.getConnection().createDataSet(aTables));
+               return dbtester;
+       }
+
+       public void cleanDatabase() throws Exception {
+               cleanDatabase(tables);
+       }
+
+       public void executeOnTables(ITableFilterSimple aTables,
+                       final TableSetOperation aOperation) throws Exception {
+               final String[] tables = getTableNames(aTables);
+               executeInTransaction(new JdbcUnitOfWork<Void>() {
+                       public Void execute(Connection aConnection) throws Exception {
+                               for (int i = tables.length - 1; i >= 0; i--) {
+                                       aOperation.execute(tables[i]);
+                               }
+                               return null;
+                       }
+               });
+               for (String table : tables) {
+
+               }
+       }
+
+       public void cleanDatabase(ITableFilterSimple aSelection) throws Exception {
+
+               final String[] tables = getTableNames(aSelection);
+               executeInTransaction(new JdbcUnitOfWork<Void>() {
+
+                       public Void execute(Connection aConnection) throws Exception {
+                               IDatabaseConnection connection = new DatabaseConnection(
+                                               aConnection);
+                               ITableFilter filter = new DatabaseSequenceFilter(connection,
+                                               tables);
+                               IDataSet dataset = new FilteredDataSet(filter, connection
+                                               .createDataSet(tables));
+                               DatabaseOperation.DELETE_ALL.execute(connection, dataset);
+                               return null;
+                       }
+               });
+
+       }
+
+       public <T> T executeInTransaction(JdbcUnitOfWork<T> aCallback)
+                       throws Exception {
+               Connection connection = dataSource.getConnection();
+               try {
+                       T value = aCallback.execute(connection);
+                       connection.commit();
+                       return value;
+               } finally {
+                       connection.close();
+               }
+       }
+
+       public String[] getTableNames() throws Exception {
+               return getTableNames(tables);
+       }
+
+       /**
+        * @throws SQLException
+        */
+       public String[] getTableNames(ITableFilterSimple aSelection)
+                       throws Exception {
+
+               List<String> result = new ArrayList<String>();
+               LOG.fine("Getting database table names to clean (schema: '"
+                               + SCHEMA_PATTERN + "'");
+
+               ResultSet tables = dataSource.getConnection().getMetaData().getTables(
+                               null, SCHEMA_PATTERN, "%", new String[] { "TABLE" });
+               while (tables.next()) {
+                       String table = tables.getString("TABLE_NAME");
+                       if (aSelection.accept(table)) {
+                               result.add(table);
+                       }
+               }
+               return (String[]) result.toArray(new String[0]);
+       }
+
+       public void emptyTables() throws Exception {
+               executeOnTables(tables, new TableSetOperation() {
+                       public void execute(String aTable) throws Exception {
+                               emptyTable(aTable);
+                       }
+               });
+       }
+
+       /**
+        * @return
+        * @throws SQLException
+        */
+       public void emptyTables(final ITableFilterSimple aSelection)
+                       throws Exception {
+               executeOnTables(aSelection, new TableSetOperation() {
+                       public void execute(String aTable) throws Exception {
+                               emptyTable(aTable);
+                       }
+               });
+       }
+
+       /**
+        * @return
+        * @throws SQLException
+        */
+       public void emptyTable(String aTable) throws Exception {
+               executeSql("delete from " + aTable);
+       }
+       
+       public void dropTables() throws Exception {
+               executeOnTables(tables, new TableSetOperation() {
+                       
+                       public void execute(String aTable) throws Exception {
+                               dropTable(aTable);      
+                       }
+               });
+       }
+
+       
+       public void dropTables(ITableFilterSimple aTables) throws Exception {
+               executeOnTables(aTables, new TableSetOperation() {
+                       
+                       public void execute(String aTable) throws Exception {
+                               dropTable(aTable);      
+                       }
+               });
+       }
+
+       /**
+        * @return
+        * @throws SQLException
+        */
+       public void dropTable(final String aTable) throws Exception {
+               executeInTransaction(new JdbcUnitOfWork<Void>() {
+                       public Void execute(Connection aConnection) throws Exception {
+                               executeUpdate(aConnection, "drop table " + aTable);
+                               return null;
+                       }
+               });
+
+       }
+
+       /**
+        * Executes an SQL statement within a transaction.
+        * 
+        * @param aSql
+        *            SQL statement.
+        * @return Return code of the corresponding JDBC call.
+        */
+       public int executeSql(final String aSql) throws Exception {
+               return executeSql(aSql, new Object[0]);
+       }
+
+       /**
+        * Executes an SQL statement within a transaction. See
+        * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on
+        * supported argument types.
+        * 
+        * @param aSql
+        *            SQL statement.
+        * @param aArg
+        *            Argument of the sql statement.
+        * @return Return code of the corresponding JDBC call.
+        */
+       public int executeSql(final String aSql, final Object aArg)
+                       throws Exception {
+               return executeSql(aSql, new Object[] { aArg });
+       }
+
+       /**
+        * Executes an sql statement. See
+        * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on
+        * supported argument types.
+        * 
+        * @param aSql
+        *            SQL query to execute.
+        * @param aArgs
+        *            Arguments.
+        * @return Number of rows updated.
+        */
+       public int executeSql(final String aSql, final Object[] aArgs)
+                       throws Exception {
+               return executeInTransaction(new JdbcUnitOfWork<Integer>() {
+                       public Integer execute(Connection aConnection) throws Exception {
+                               PreparedStatement stmt = aConnection.prepareStatement(aSql);
+                               setPreparedParams(aArgs, stmt);
+                               return stmt.executeUpdate();
+                       }
+               });
+       }
+
+       /**
+        * Executes an SQL query.
+        * 
+        * @param aSql
+        *            Query to execute.
+        * @return Result set.
+        */
+       public ResultSet executeQuery(Connection aConnection, String aSql) {
+               return executeQuery(aConnection, aSql, new Object[0]);
+       }
+
+       /**
+        * Executes a query with a single argument. See
+        * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on
+        * supported argument types.
+        * 
+        * @param aSql
+        *            Query.
+        * @param aArg
+        *            Argument.
+        * @return Result set.
+        */
+       public ResultSet executeQuery(Connection aConnection, String aSql,
+                       Object aArg) {
+               return executeQuery(aConnection, aSql, new Object[] { aArg });
+       }
+
+       /**
+        * Executes a query. See
+        * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on
+        * supported argument types.
+        * 
+        * @param aSql
+        *            Sql query.
+        * @param aArgs
+        *            Arguments to the query.
+        * @return Result set.
+        */
+       public ResultSet executeQuery(Connection aConnection, final String aSql,
+                       final Object[] aArgs) {
+               try {
+                       PreparedStatement statement = aConnection.prepareStatement(aSql);
+                       setPreparedParams(aArgs, statement);
+
+                       return statement.executeQuery();
+               } catch (SQLException e) {
+                       throw new RuntimeException(e);
+               }
+       }
+       
+       public int executeUpdate(Connection aConnection, final String aSql,
+                       final Object... aArgs) {
+               try {
+                       PreparedStatement statement = aConnection.prepareStatement(aSql);
+                       setPreparedParams(aArgs, statement);
+
+                       return statement.executeUpdate();
+               } catch (SQLException e) {
+                       throw new RuntimeException(e);
+               }
+       }
+
+       /**
+        * Sets the values of a prepared statement. See
+        * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on
+        * supported argument types.
+        * 
+        * @param aArgs
+        *            Arguments to the prepared statement.
+        * @param aStatement
+        *            Prepared statement
+        * @throws SQLException
+        */
+       private void setPreparedParams(final Object[] aArgs,
+                       PreparedStatement aStatement) throws SQLException {
+               for (int i = 1; i <= aArgs.length; i++) {
+                       setPreparedParam(i, aStatement, aArgs[i - 1]);
+               }
+       }
+
+       /**
+        * Sets a prepared statement parameter.
+        * 
+        * @param aIndex
+        *            Index of the parameter.
+        * @param aStatement
+        *            Prepared statement.
+        * @param aObject
+        *            Value Must be of type Integer, Long, or String.
+        * @throws SQLException
+        */
+       private void setPreparedParam(int aIndex, PreparedStatement aStatement,
+                       Object aObject) throws SQLException {
+               if (aObject instanceof Integer) {
+                       aStatement.setInt(aIndex, ((Integer) aObject).intValue());
+               } else if (aObject instanceof Long) {
+                       aStatement.setLong(aIndex, ((Integer) aObject).longValue());
+               } else if (aObject instanceof String) {
+                       aStatement.setString(aIndex, (String) aObject);
+               } else {
+                       TestCase.fail("Unsupported object type for prepared statement: "
+                                       + aObject.getClass() + " value: " + aObject
+                                       + " statement: " + aStatement);
+               }
+       }
+
+       /**
+        * @return
+        * @throws SQLException
+        */
+       public int getTableSize(final String aTable) throws Exception {
+               return executeInTransaction(new JdbcUnitOfWork<Integer>() {
+                       public Integer execute(Connection aConnection) throws Exception {
+                               ResultSet resultSet = executeQuery(aConnection,
+                                               "select count(*) from " + aTable);
+                               resultSet.next();
+                               return resultSet.getInt(1);
+                       }
+               });
+
+       }
+
+       public int countResultSet(ResultSet aResultSet) throws SQLException {
+               int count = 0;
+
+               while (aResultSet.next()) {
+                       count++;
+               }
+
+               return count;
+       }
+
+}