311d2da3189afcfcf09141b3b52996cc1895140d
[utils] / test / enterprise / src / main / java / org / wamblee / support / persistence / DatabaseUtils.java
1 /*
2  * Copyright 2005-2010 the original author or authors.
3  * 
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  * 
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  * 
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package org.wamblee.support.persistence;
17
18 import java.sql.Connection;
19 import java.sql.PreparedStatement;
20 import java.sql.ResultSet;
21 import java.sql.SQLException;
22 import java.util.ArrayList;
23 import java.util.List;
24 import java.util.logging.Logger;
25
26 import javax.sql.DataSource;
27
28 import junit.framework.TestCase;
29
30 import org.dbunit.DataSourceDatabaseTester;
31 import org.dbunit.IDatabaseTester;
32 import org.dbunit.database.DatabaseConnection;
33 import org.dbunit.database.DatabaseSequenceFilter;
34 import org.dbunit.database.IDatabaseConnection;
35 import org.dbunit.dataset.FilteredDataSet;
36 import org.dbunit.dataset.IDataSet;
37 import org.dbunit.dataset.filter.ITableFilter;
38 import org.dbunit.dataset.filter.ITableFilterSimple;
39 import org.dbunit.operation.DatabaseOperation;
40
41 /**
42  * Database utilities is a simple support class for common tasks in working with
43  * databases.
44  */
45 public class DatabaseUtils {
46
47     public static interface TableSet {
48         boolean contains(String aTableName);
49     }
50
51     public static interface JdbcUnitOfWork<T> {
52         T execute(Connection aConnection) throws Exception;
53     }
54
55     public static interface TableSetOperation {
56         void execute(String aTable) throws Exception;
57     }
58
59     private static final Logger LOG = Logger.getLogger(DatabaseUtils.class
60         .getName());
61
62     /**
63      * Schema pattern.
64      */
65     private static final String SCHEMA_PATTERN = "%";
66     private DataSource dataSource;
67     private ITableFilterSimple tables;
68
69     public DatabaseUtils(DataSource aDataSource, ITableFilterSimple aTables) {
70         dataSource = aDataSource;
71         tables = aTables;
72     }
73
74     public IDatabaseTester createDbTester() throws Exception {
75         return createDbTester(getTableNames(tables));
76     }
77
78     public IDatabaseTester createDbTester(String[] aTables) throws Exception {
79         IDatabaseTester dbtester = new DataSourceDatabaseTester(dataSource);
80         dbtester.setDataSet(dbtester.getConnection().createDataSet(aTables));
81         return dbtester;
82     }
83
84     public void cleanDatabase() throws Exception {
85         cleanDatabase(tables);
86     }
87
88     public void executeOnTables(ITableFilterSimple aTables,
89         final TableSetOperation aOperation) throws Exception {
90         final String[] tables = getTableNames(aTables);
91         executeInTransaction(new JdbcUnitOfWork<Void>() {
92             public Void execute(Connection aConnection) throws Exception {
93                 for (int i = tables.length-1; i >= 0; i--) {
94                     aOperation.execute(tables[i]);
95                 }
96                 return null;
97             }
98         });
99     }
100
101     public void cleanDatabase(ITableFilterSimple aSelection) throws Exception {
102
103         final String[] tables = getTableNames(aSelection);
104         executeInTransaction(new JdbcUnitOfWork<Void>() {
105
106             public Void execute(Connection aConnection) throws Exception {
107                 IDatabaseConnection connection = new DatabaseConnection(
108                     aConnection);
109                 ITableFilter filter = new DatabaseSequenceFilter(connection,
110                     tables);
111                 IDataSet dataset = new FilteredDataSet(filter, connection
112                     .createDataSet(tables));
113                 DatabaseOperation.DELETE_ALL.execute(connection, dataset);
114                 return null;
115             }
116         });
117
118     }
119
120     public <T> T executeInTransaction(JdbcUnitOfWork<T> aCallback)
121         throws Exception {
122         Connection connection = dataSource.getConnection();
123         try {
124             T value = aCallback.execute(connection);
125             connection.commit();
126             return value;
127         } finally {
128             connection.close();
129         }
130     }
131
132     public String[] getTableNames() throws Exception {
133         return getTableNames(tables);
134     }
135
136     /**
137      * @throws SQLException
138      */
139     public String[] getTableNames(ITableFilterSimple aSelection)
140         throws Exception {
141
142         List<String> result = new ArrayList<String>();
143         LOG.fine("Getting database table names to clean (schema: '" +
144             SCHEMA_PATTERN + "'");
145
146         Connection connection = dataSource.getConnection();
147         try {
148             ResultSet tables = connection.getMetaData().getTables(null,
149                 SCHEMA_PATTERN, "%", new String[] { "TABLE" });
150             while (tables.next()) {
151                 String table = tables.getString("TABLE_NAME");
152                 if (aSelection.accept(table)) {
153                     result.add(table);
154                 }
155             }
156             return (String[]) result.toArray(new String[0]);
157         } finally {
158             connection.close();
159         }
160     }
161
162     public void emptyTables() throws Exception {
163         executeOnTables(tables, new TableSetOperation() {
164             public void execute(String aTable) throws Exception {
165                 emptyTable(aTable);
166             }
167         });
168     }
169
170     /**
171      * @return
172      * @throws SQLException
173      */
174     public void emptyTables(final ITableFilterSimple aSelection)
175         throws Exception {
176         executeOnTables(aSelection, new TableSetOperation() {
177             public void execute(String aTable) throws Exception {
178                 emptyTable(aTable);
179             }
180         });
181     }
182
183     /**
184      * @return
185      * @throws SQLException
186      */
187     public void emptyTable(String aTable) throws Exception {
188         executeSql("delete from " + aTable);
189     }
190     
191     public void dropTables() throws Exception { 
192         dropTables(tables);
193     }
194
195     public void dropTables(ITableFilterSimple aTables) throws Exception {
196         final String[] tables = getTableNames(aTables);
197         String[] sortedTables = executeInTransaction(new JdbcUnitOfWork<String[]>() {
198
199             public String[] execute(Connection aConnection) throws Exception {
200                 IDatabaseConnection connection = new DatabaseConnection(
201                     aConnection);
202                 ITableFilter filter = new DatabaseSequenceFilter(connection,
203                     tables);
204                 IDataSet dataset = new FilteredDataSet(filter, connection
205                     .createDataSet(tables));
206                 return dataset.getTableNames();
207             }
208         });
209         for (int i = sortedTables.length-1; i >= 0; i--) { 
210             dropTable(sortedTables[i]);
211         }
212     }
213     
214     /**
215      * @return
216      * @throws SQLException
217      */
218     public void dropTable(final String aTable) throws Exception {
219         executeInTransaction(new JdbcUnitOfWork<Void>() {
220             public Void execute(Connection aConnection) throws Exception {
221                 executeUpdate(aConnection, "drop table " + aTable);
222                 return null;
223             }
224         });
225
226     }
227
228     /**
229      * Executes an SQL statement within a transaction.
230      * 
231      * @param aSql
232      *            SQL statement.
233      * @return Return code of the corresponding JDBC call.
234      */
235     public int executeSql(final String aSql) throws Exception {
236         return executeSql(aSql, new Object[0]);
237     }
238
239     /**
240      * Executes an SQL statement within a transaction. See
241      * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on
242      * supported argument types.
243      * 
244      * @param aSql
245      *            SQL statement.
246      * @param aArg
247      *            Argument of the sql statement.
248      * @return Return code of the corresponding JDBC call.
249      */
250     public int executeSql(final String aSql, final Object aArg)
251         throws Exception {
252         return executeSql(aSql, new Object[] { aArg });
253     }
254
255     /**
256      * Executes an sql statement. See
257      * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on
258      * supported argument types.
259      * 
260      * @param aSql
261      *            SQL query to execute.
262      * @param aArgs
263      *            Arguments.
264      * @return Number of rows updated.
265      */
266     public int executeSql(final String aSql, final Object[] aArgs)
267         throws Exception {
268         return executeInTransaction(new JdbcUnitOfWork<Integer>() {
269             public Integer execute(Connection aConnection) throws Exception {
270                 PreparedStatement stmt = aConnection.prepareStatement(aSql);
271                 setPreparedParams(aArgs, stmt);
272                 return stmt.executeUpdate();
273             }
274         });
275     }
276
277     /**
278      * Executes an SQL query.
279      * 
280      * @param aSql
281      *            Query to execute.
282      * @return Result set.
283      */
284     public ResultSet executeQuery(Connection aConnection, String aSql) {
285         return executeQuery(aConnection, aSql, new Object[0]);
286     }
287
288     /**
289      * Executes a query with a single argument. See
290      * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on
291      * supported argument types.
292      * 
293      * @param aSql
294      *            Query.
295      * @param aArg
296      *            Argument.
297      * @return Result set.
298      */
299     public ResultSet executeQuery(Connection aConnection, String aSql,
300         Object aArg) {
301         return executeQuery(aConnection, aSql, new Object[] { aArg });
302     }
303
304     /**
305      * Executes a query. See
306      * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on
307      * supported argument types.
308      * 
309      * @param aSql
310      *            Sql query.
311      * @param aArgs
312      *            Arguments to the query.
313      * @return Result set.
314      */
315     public ResultSet executeQuery(Connection aConnection, final String aSql,
316         final Object[] aArgs) {
317         try {
318             PreparedStatement statement = aConnection.prepareStatement(aSql);
319             setPreparedParams(aArgs, statement);
320
321             return statement.executeQuery();
322         } catch (SQLException e) {
323             throw new RuntimeException(e);
324         }
325     }
326
327     public int executeUpdate(Connection aConnection, final String aSql,
328         final Object... aArgs) {
329         try {
330             PreparedStatement statement = aConnection.prepareStatement(aSql);
331             setPreparedParams(aArgs, statement);
332
333             return statement.executeUpdate();
334         } catch (SQLException e) {
335             throw new RuntimeException(e);
336         }
337     }
338
339     /**
340      * Sets the values of a prepared statement. See
341      * {@link #setPreparedParam(int, PreparedStatement, Object)}for details on
342      * supported argument types.
343      * 
344      * @param aArgs
345      *            Arguments to the prepared statement.
346      * @param aStatement
347      *            Prepared statement
348      * @throws SQLException
349      */
350     private void setPreparedParams(final Object[] aArgs,
351         PreparedStatement aStatement) throws SQLException {
352         for (int i = 1; i <= aArgs.length; i++) {
353             setPreparedParam(i, aStatement, aArgs[i - 1]);
354         }
355     }
356
357     /**
358      * Sets a prepared statement parameter.
359      * 
360      * @param aIndex
361      *            Index of the parameter.
362      * @param aStatement
363      *            Prepared statement.
364      * @param aObject
365      *            Value Must be of type Integer, Long, or String.
366      * @throws SQLException
367      */
368     private void setPreparedParam(int aIndex, PreparedStatement aStatement,
369         Object aObject) throws SQLException {
370         if (aObject instanceof Integer) {
371             aStatement.setInt(aIndex, ((Integer) aObject).intValue());
372         } else if (aObject instanceof Long) {
373             aStatement.setLong(aIndex, ((Long) aObject).longValue());
374         } else if (aObject instanceof String) {
375             aStatement.setString(aIndex, (String) aObject);
376         } else {
377             TestCase.fail("Unsupported object type for prepared statement: " +
378                 aObject.getClass() + " value: " + aObject + " statement: " +
379                 aStatement);
380         }
381     }
382
383     /**
384      * @return
385      * @throws SQLException
386      */
387     public int getTableSize(final String aTable) throws Exception {
388         return executeInTransaction(new JdbcUnitOfWork<Integer>() {
389             public Integer execute(Connection aConnection) throws Exception {
390                 ResultSet resultSet = executeQuery(aConnection,
391                     "select count(*) from " + aTable);
392                 resultSet.next();
393                 return resultSet.getInt(1);
394             }
395         });
396
397     }
398
399     public int countResultSet(ResultSet aResultSet) throws SQLException {
400         int count = 0;
401
402         while (aResultSet.next()) {
403             count++;
404         }
405
406         return count;
407     }
408
409 }