需求
在使用Mybatis时,通过一个SQL注解就可以实现DAO接口。这个功能的基本实现很简单。主要是使用了动态代理来实现接口并使用Java反射来对查询结果和Java Bean进行映射。
项目地址
本文贴的代码文件DButils.java
对反射还不了解的可移步Java反射常用API解读
实现思路
- 封装JDBC操作,将SQL区分为查询和其他操作。因为非查询操作返回的是int。
- 定义注解增删改查 @Insert, @Delete, @Update, @Select
- 使用动态代理实现接口,并通过注解类型和返回结果调用步骤1封装的JDBC方法
依赖
这里以MySQL为例,使用HikariCP连接池,添加Lombok依赖方便写Java bean
<dependencies>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>1.18.16</version>
</dependency>
<dependency>
<groupId>com.zaxxer</groupId>
<artifactId>HikariCP</artifactId>
<version>3.4.5</version>
</dependency>
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
<version>8.0.21</version>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter-api</artifactId>
<version>5.7.0</version>
<scope>test</scope>
</dependency>
</dependencies>
封装JDBC
查询单条记录SelectOne
// DBUtils.java
/**
* 查询返回单条记录
* @param <T> 返回结果泛型
* @param connection 数据库连接
* @param clazz 返回结果的实际类型
* @param sql 需要执行的sql
* @param args sql的参数
* @return
*/
public <T> T selectOne(Connection connection, Class<T> clazz, String sql, Object... args) {
PreparedStatement statement = null;
ResultSet resultSet = null;
try {
statement = connection.prepareStatement(sql);
// 初始化prepareStatement, 对占位符?配参
initStatementArgs(statement, args);
resultSet = statement.executeQuery();
// 取出所有列名
List<String> columns = getColumns(resultSet.getMetaData());
if (resultSet.next()) {
// 只取第一条记录(如果有)
// Javabean 映射
return mapTo(clazz, resultSet, columns);
}
// 结果集为空,返回null
return null;
} catch (SQLException e) {
throw new RuntimeException("SQLException: " + e.getMessage());
} finally {
release(statement, resultSet);
}
}
查询结果和Java bean的映射mapTo
供selectOne 和 selectList调用。实现思路
- 通过ResultSetMateData获取所有columns
- 遍历所有的columns
- 根据column获取其在java bean的字段(Field)和set方法(Method)。这里约定了数据库字段名使用下划线,对应类字段名使用驼峰命名
- 根据Field的类型(fieldType),调用resultSet.getObject(column, fieldType)获取数据库对应列的值并转换为Java bean对应字段的类型
- 调用setter的method.invoke()方法将数据库的值注入到java bean实例中
/**
* 单条查询结果和Java bean的映射
* @param <T> 结果泛型定义
* @param clazz 返回实际类型
* @param row 单条ResultSet结果集
* @param columns 所有列名
* @return
* @throws SQLException
*/
private <T> T mapTo(Class<T> clazz, ResultSet row, List<String> columns) throws SQLException {
Object obj = null;
try {
// 使用默认构造器实例化一个对象
obj = clazz.getConstructor().newInstance();
// 根据返回结果的列名和Java bean的字段名实现映射,这里使用下划线转驼峰命名
for (String column: columns) {
String fieldName = toCamelCase(column);
Field field = clazz.getDeclaredField(fieldName);
Class<?> fieldType = field.getType();
Method setter = clazz.getMethod(getSetter(column), field.getType());
Object value = row.getObject(column, fieldType);
setter.invoke(obj, value);
}
} catch (NoSuchMethodException e){
throw new RuntimeException("NoSuchMethodException: " + e.getMessage());
} catch (InstantiationException | IllegalAccessException | IllegalArgumentException | InvocationTargetException
| SecurityException e) {
throw new RuntimeException(e.getMessage());
} catch (NoSuchFieldException e) {
throw new RuntimeException("NoSuchFieldException: " + e.getMessage());
}
return clazz.cast(obj);
}
查询List selectList
/**
* 查询返回List集合
* @param <T> List泛型定义
* @param connection 数据库连接
* @param clazz 泛型实际类型
* @param sql 需要执行的sql
* @param args sql参数
* @return
*/
public <T> List<T> selectList(Connection connection, Class<T> clazz, String sql, Object...args) {
PreparedStatement statement = null;
ResultSet resultSet = null;
List<T> list = new LinkedList<>(); // 使用LinkedList免去扩容消耗
try {
statement = connection.prepareStatement(sql);
initStatementArgs(statement, args);
resultSet = statement.executeQuery();
List<String> columns = getColumns(resultSet.getMetaData());
if (ReflectUtil.isPrimaryType(clazz)) {
// List的泛型类型为基本类型,直接取第一列
while (resultSet.next()) {
list.add(resultSet.getObject(1, clazz));
}
} else {
// List的泛型类型为自定义Javabean, 使用mapTo方法完成映射
while (resultSet.next()) {
T t = mapTo(clazz, resultSet, columns);
list.add(t);
}
}
} catch (SQLException e) {
throw new RuntimeException("SQLException: " + e.getMessage());
} finally {
release(statement, resultSet);
}
return list;
}
单列查询 selectValue
查询单列时使用这个方法,传入的class是基本类型,由于是查询单列直接取第一列的值。
/**
* 查询单列一条记录
* @param <T> 泛型定义
* @param connection 数据库连接
* @param clazz 返回结果的实际类型,需传入基本类型
* @param sql 需要执行的sql
* @param args sql参数
* @return
*/
public <T> T selectValue(Connection connection, Class<T> clazz, String sql, Object...args) {
PreparedStatement statement = null;
ResultSet resultSet = null;
try {
statement = connection.prepareStatement(sql);
initStatementArgs(statement, args);
resultSet = statement.executeQuery();
if (resultSet.next()) {
// 因为是查询单列操作,取第一列
return resultSet.getObject(1, clazz);
} else {
return null;
}
} catch (SQLException e) {
throw new RuntimeException("SQLException: " + e.getMessage());
} finally {
release(statement, resultSet);
}
}
非查询操作 execute
/**
* 非查询操作
* @param connection 数据库连接
* @param sql 需要执行的sql
* @param args sql参数
* @return
*/
public int execute(Connection connection, String sql, Object...args) {
PreparedStatement statement = null;
ResultSet resultSet = null;
try {
statement = connection.prepareStatement(sql);
initStatementArgs(statement, args);
return statement.executeUpdate(); // 执行sql
} catch (SQLException e) {
throw new RuntimeException("SQLException: " + e.getMessage());
} finally {
release(statement, resultSet);
}
}
测试接口
public class DBUtilTest {
private static final DBUtils dbUtils = DBUtils.getInstance();
private static Connection connection;
@BeforeEach
public void getConnection() {
connection = dbUtils.getConnection();
}
@AfterEach
public void releaseConnection() {
dbUtils.release(connection);
}
@Test
public void selectOne() {
String sql = "select * from t_student where id = ?";
Student student = dbUtils.selectOne(connection, Student.class, sql, 2);
Assertions.assertNotNull(student);
}
@Test
public void selectList() {
String sql = "select * from t_student";
List<Student> list = dbUtils.selectList(connection, Student.class, sql);
Assertions.assertTrue(list.size() > 0);
}
@Test
public void selectValue() {
String sql = "select birthday from t_student where id = ?";
LocalDate birthday = dbUtils.selectValue(connection, LocalDate.class, sql, 2);
Assertions.assertNotNull(birthday);
}
@Test
public void execute() {
String sql = "update t_student set name = ? where id = ?";
Integer res = dbUtils.execute(connection, sql, "李莫愁", 2);
Assertions.assertEquals(1, res);
}