利用Spring 自带类写一个自动初始化数据库的工具组件
写了多个组件,如日志组件、权限组件等等,每个组件都各自需要几个表,这个时候往往一个新项目建立后,都需要引用几个组件,每次都要翻来翻去找sql文件建表,为了简化这个过程,决定在通用组件内添加一个统一的自动建表功能
参考spring.factories
的思路,我们将表名和建表sql的key-value键值对放在约定好的sql-init.properties
文件中,文件格式如下:
# suppress inspection "UnusedProperty" for whole file
login_log=classpath:login_log.sql
operation_log=classpath:operation_log.sql
sql-init.properties
文件位于src/main/resources/sql-init.properties
login_log.sql
文件位于 src/main/sql/login_log.sql
pom文件需要确保sql文件和properties文件被打包入jar包中
<build>
<resources>
<resource>
<directory>src/main/resources</directory>
</resource>
<resource>
<directory>src/main/sql</directory>
<includes>
<include>**/*.sql</include>
</includes>
</resource>
</resources>
</build>
1.读取sql-init.properties
文件
我们使用PathMatchingResourcePatternResolver
来查找不同jar包中所有的sql-init.properties
文件,classpath*:
开头的路径代表查找所有jar包:
Resource[] resources = new PathMatchingResourcePatternResolver().getResources("classpath*:sql-init.properties");
2.解析sql-init.properties
文件
文件被读取之后,是Spring 的Resouce 类,想要读取其中key-value对,我们借用Spring 的PropertiesPropertySourceLoader
来读取文件,存储时使用Map<String, List<String>>
,因为可能有不同包使用了同样的表名,某些sql文件是用来初始化表的,某些sql文件是用来添加一些初始数据的
Map<String, List<String>> schemaSqlFiles = new HashMap<>();
PropertiesPropertySourceLoader loader = new PropertiesPropertySourceLoader();
for (Resource resource : new PathMatchingResourcePatternResolver().getResources("classpath*:sql-init.properties")) {
for (PropertySource<?> propertySource : loader.load("sql-init", resource)) {
for (Map.Entry<String, Object> entry : ((OriginTrackedMapPropertySource) propertySource).getSource().entrySet()) {
String location = (String) ((OriginTrackedValue) entry.getValue()).getValue();
if (StringUtils.hasText(entry.getKey()) && StringUtils.hasText(location)) {
schemaSqlFiles.computeIfAbsent(entry.getKey(), t -> new ArrayList<>()).add(location);
}
}
}
}
如果遇到了同key下同value相同的文件名时,因为可能都要执行,所以需要将文件路径从classpath:
改为classpath*:
因为前者找到一个文件就停止了,不会找到全部的sql文件
如果存在不同key内,value sql文件路径相同的情况,这种情况下也会出现多个文件只找到一个的情况,也可以将其改为classpath*:
但是最好不要存在任何情况下文件名重复的问题,否则并不能保证PathMatchingResourcePatternResolver
找到的文件是你想要的文件
要求严格的情况下,也可以将sql文件放入软件包内,如Oauth2 Client中就将sql文件放到了java文件同级目录 classpath:org/springframework/security/oauth2/client/oauth2-client-schema.sql
这种情况下就基本保证文件路径不会重复了
3.判断表是否已经初始化
现在已经读取了所有的sql-init.properties
文件,获得了所有初始化sql文件的路径了,然后我们通过JdbcTemplate
执行sql判断对应的表是否已经建立,如果建立我们就跳过这个表对应的sql,使用mysql的 database()
函数来获取当前的默认数据库名
// language=sql
private static final String TABLE_NAME_SQL = "SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_NAME in (?) and TABLE_SCHEMA = database()";
private List<String> getExistSchemas(Collection<String> schemas, JdbcTemplate jdbcTemplate) {
if (CollectionUtils.isEmpty(schemas)) {
return Collections.emptyList();
}
StringJoiner stringJoiner = new StringJoiner(",");
schemas.forEach(e -> {
if (!e.startsWith("'") || !e.endsWith("'")) {
stringJoiner.add("'" + e + "'");
} else {
stringJoiner.add(e);
}
});
String sql = TABLE_NAME_SQL.replace("?", stringJoiner.toString());
return jdbcTemplate.queryForList(sql, String.class);
}
(如果你是用的是idea 那么// language=sql
可以标识将下面一行的字符串注入sql语言,语言注入功能效果杠杠的)
4.执行数据库初始化
spring自带的DataSourceScriptDatabaseInitializer
本身就能读取并执行sql文件,我们直接拿来用,省去了超多代码,我们这里用SqlDataSourceScriptDatabaseInitializer
(springboot2.6,2.5可以用DataSourceScriptDatabaseInitializer
)
private void initDataSource(List<String> sqlLocations) throws Exception {
if (sqlLocations.size() <= 0) {
return;
}
SqlInitializationProperties properties = new SqlInitializationProperties();
properties.setEncoding(StandardCharsets.UTF_8);
properties.setMode(DatabaseInitializationMode.ALWAYS);
properties.setSchemaLocations(sqlLocations);
new SqlDataSourceScriptDatabaseInitializer(dataSource, properties).afterPropertiesSet();
log.info("\n数据库初始化sql已执行:{}",Arrays.toString(sqlLocations.toArray()));
}
完整代码
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.boot.autoconfigure.sql.init.SqlDataSourceScriptDatabaseInitializer;
import org.springframework.boot.autoconfigure.sql.init.SqlInitializationProperties;
import org.springframework.boot.env.OriginTrackedMapPropertySource;
import org.springframework.boot.env.PropertiesPropertySourceLoader;
import org.springframework.boot.origin.OriginTrackedValue;
import org.springframework.boot.sql.init.DatabaseInitializationMode;
import org.springframework.core.env.PropertySource;
import org.springframework.core.io.Resource;
import org.springframework.core.io.ResourceLoader;
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
import org.springframework.core.io.support.ResourcePatternResolver;
import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.SqlParameterValue;
import org.springframework.util.StringUtils;
import javax.sql.DataSource;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.sql.Types;
import java.util.*;
import java.util.stream.Collectors;
/**
* 数据库自动初始化器自动检测默认数据库是否含有指定表,如果不含有则读取并执行初始化sql<br>
* <H1>使用方法:在"src/main/resources/sql-init.properties"中注册👇</H1>
* <pre>
* 表名=文件名
* brief_chain_log=classpath:brief_chain_log.sql
* </pre>
* src/main/sql/brief_chain_log.sql sql文件保证打包到jar包classes目录下即可
* pom要把sql文件和properties文件打包进去
* <pre>
* <build>
* <resources>
* <resource>
* <directory>src/main/sql</directory>
* </resource>
* <resource>
* <directory>src/main/resources</directory>
* </resource>
* </resources>
* </build>
* </pre>
* 文件名读取采用{@link PathMatchingResourcePatternResolver} 可采用通配符等ant匹配一次指定多个文件
* <pre>
* superHero=classpath:SSS-*.sql #所有SSS-开头的都会被找到,使用通配符时应注意,本地运行、本地测试类运行、以jar包在linux运行结果可能会不同
* </pre>
*/
@Slf4j
@RequiredArgsConstructor
public class SpringDatabaseInitializer implements InitializingBean {
public static final String BEAN_NAME = "springDatabaseInitializer";
private final JdbcTemplate jdbcTemplate;
private final DataSource dataSource;
private final PathMatchingResourcePatternResolver resourceResolver;
private final Map<String, List<String>> schemaSqlFiles = new HashMap<>();
// language=sql
private static final String TABLE_NAME_SQL = "SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_NAME in (?) and TABLE_SCHEMA = database()";
@Override
public void afterPropertiesSet() throws Exception {
initDatabase();
destroySelf();
}
public synchronized void initDatabase() throws Exception {
initSchemaSqlFiles();
try {
initDataSource();
} catch (Exception e) {
if (!isEmbeddedDatabase()) {
throw e;
}
log.warn("内嵌数据库不支持某些SQL,数据库自动初始化失败!", e);
}
}
private synchronized void initSchemaSqlFiles() throws IOException {
PropertiesPropertySourceLoader loader = new PropertiesPropertySourceLoader();
for (Resource resource : resourceResolver.getResources("classpath*:sql-init.properties")) {
for (PropertySource<?> propertySource : loader.load("sql-init", resource)) {
for (Map.Entry<String, Object> entry : ((OriginTrackedMapPropertySource) propertySource).getSource().entrySet()) {
String location = (String) ((OriginTrackedValue) entry.getValue()).getValue();
if (StringUtils.hasText(entry.getKey()) && StringUtils.hasText(location)) {
List<String> locations = schemaSqlFiles.computeIfAbsent(entry.getKey(), t -> new ArrayList<>());
//如果有重复的,则改为 classpath*: 搜索全部同名文件
if (locations.contains(location)) {
location = convertToAllClassPath(location);
}
locations.add(location);
}
}
}
}
if (CollectionUtils.isEmpty(schemaSqlFiles)) {
return;
}
getExistSchemas(schemaSqlFiles.keySet(), jdbcTemplate).forEach(schemaSqlFiles::remove);
}
private synchronized void initDataSource() throws Exception {
if (CollectionUtils.isEmpty(schemaSqlFiles)) {
return;
}
List<String> sqlLocations = schemaSqlFiles.values().stream().flatMap(Collection::stream).collect(Collectors.toList());
if (CollectionUtils.isEmpty(sqlLocations)) {
return;
}
if (sqlLocations.stream().distinct().count() != sqlLocations.size()) {
log.warn("A file with the same name appears!");
if (log.isDebugEnabled()) {
System.err.println(Arrays.toString(sqlLocations.toArray()));
}
}
SqlInitializationProperties properties = new SqlInitializationProperties();
properties.setEncoding(StandardCharsets.UTF_8);
properties.setMode(DatabaseInitializationMode.ALWAYS);
properties.setSchemaLocations(sqlLocations);
new SqlDataSourceScriptDatabaseInitializer(dataSource, properties).afterPropertiesSet();
log.info("\n数据库初始化sql已执行:{}", Arrays.toString(sqlLocations.toArray()));
}
private void destroySelf() {
SpringBeanRegistryUtil.unregisterBean(BEAN_NAME);
}
private List<String> getExistSchemas(Collection<String> schemas, JdbcTemplate jdbcTemplate) {
if (CollectionUtils.isEmpty(schemas)) {
return Collections.emptyList();
}
StringJoiner stringJoiner = new StringJoiner(",");
schemas.forEach(e -> {
if (!e.startsWith("'") || !e.endsWith("'")) {
stringJoiner.add("'" + e + "'");
} else {
stringJoiner.add(e);
}
});
String sql = TABLE_NAME_SQL.replace("?", stringJoiner.toString());
return jdbcTemplate.queryForList(sql, String.class);
}
private String convertToAllClassPath(String location) {
if (location.startsWith(ResourceLoader.CLASSPATH_URL_PREFIX)) {
return ResourcePatternResolver.CLASSPATH_ALL_URL_PREFIX + location.substring(ResourceLoader.CLASSPATH_URL_PREFIX.length());
}
return location;
}
private boolean isEmbeddedDatabase() {
try {
return EmbeddedDatabaseConnection.isEmbedded(this.dataSource);
} catch (Exception ex) {
log.debug("Could not determine if datasource is embedded", ex);
return false;
}
}
}
使用方法
@Bean(SpringDatabaseInitializer.BEAN_NAME)
public static SpringDatabaseInitializer springDatabaseInitializer(JdbcTemplate jdbcTemplate, DataSource dataSource) {
return new SpringDatabaseInitializer(jdbcTemplate, dataSource, new PathMatchingResourcePatternResolver());
}
转载自:https://juejin.cn/post/7068188517012078606