likes
comments
collection
share

利用Spring 自带类写一个自动初始化数据库的工具组件

作者站长头像
站长
· 阅读数 17

写了多个组件,如日志组件、权限组件等等,每个组件都各自需要几个表,这个时候往往一个新项目建立后,都需要引用几个组件,每次都要翻来翻去找sql文件建表,为了简化这个过程,决定在通用组件内添加一个统一的自动建表功能

参考spring.factories的思路,我们将表名和建表sql的key-value键值对放在约定好的sql-init.properties文件中,文件格式如下:

# suppress inspection "UnusedProperty" for whole file
login_log=classpath:login_log.sql
operation_log=classpath:operation_log.sql

sql-init.properties文件位于src/main/resources/sql-init.properties login_log.sql文件位于 src/main/sql/login_log.sql pom文件需要确保sql文件和properties文件被打包入jar包中

<build>
    <resources>
        <resource>
            <directory>src/main/resources</directory>
        </resource>
        <resource>
            <directory>src/main/sql</directory>
            <includes>
                <include>**/*.sql</include>
            </includes>
        </resource>
    </resources>
</build>

1.读取sql-init.properties文件

我们使用PathMatchingResourcePatternResolver 来查找不同jar包中所有的sql-init.properties文件,classpath*:开头的路径代表查找所有jar包:

Resource[] resources = new PathMatchingResourcePatternResolver().getResources("classpath*:sql-init.properties");

2.解析sql-init.properties文件

文件被读取之后,是Spring 的Resouce 类,想要读取其中key-value对,我们借用Spring 的PropertiesPropertySourceLoader 来读取文件,存储时使用Map<String, List<String>>,因为可能有不同包使用了同样的表名,某些sql文件是用来初始化表的,某些sql文件是用来添加一些初始数据的

Map<String, List<String>> schemaSqlFiles = new HashMap<>();
PropertiesPropertySourceLoader loader = new PropertiesPropertySourceLoader();
for (Resource resource : new PathMatchingResourcePatternResolver().getResources("classpath*:sql-init.properties")) {
    for (PropertySource<?> propertySource : loader.load("sql-init", resource)) {
        for (Map.Entry<String, Object> entry : ((OriginTrackedMapPropertySource) propertySource).getSource().entrySet()) {
            String location = (String) ((OriginTrackedValue) entry.getValue()).getValue();
            if (StringUtils.hasText(entry.getKey()) && StringUtils.hasText(location)) {
                schemaSqlFiles.computeIfAbsent(entry.getKey(), t -> new ArrayList<>()).add(location);
            }
        }
    }
}

如果遇到了同key下同value相同的文件名时,因为可能都要执行,所以需要将文件路径从classpath:改为classpath*:因为前者找到一个文件就停止了,不会找到全部的sql文件

如果存在不同key内,value sql文件路径相同的情况,这种情况下也会出现多个文件只找到一个的情况,也可以将其改为classpath*:

但是最好不要存在任何情况下文件名重复的问题,否则并不能保证PathMatchingResourcePatternResolver找到的文件是你想要的文件

要求严格的情况下,也可以将sql文件放入软件包内,如Oauth2 Client中就将sql文件放到了java文件同级目录 classpath:org/springframework/security/oauth2/client/oauth2-client-schema.sql 这种情况下就基本保证文件路径不会重复了

3.判断表是否已经初始化

现在已经读取了所有的sql-init.properties文件,获得了所有初始化sql文件的路径了,然后我们通过JdbcTemplate执行sql判断对应的表是否已经建立,如果建立我们就跳过这个表对应的sql,使用mysql的 database()函数来获取当前的默认数据库名

// language=sql
private static final String TABLE_NAME_SQL = "SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_NAME in (?) and TABLE_SCHEMA = database()";

private List<String> getExistSchemas(Collection<String> schemas, JdbcTemplate jdbcTemplate) {
    if (CollectionUtils.isEmpty(schemas)) {
        return Collections.emptyList();
    }
    StringJoiner stringJoiner = new StringJoiner(",");
    schemas.forEach(e -> {
        if (!e.startsWith("'") || !e.endsWith("'")) {
            stringJoiner.add("'" + e + "'");
        } else {
            stringJoiner.add(e);
        }
    });
    String sql = TABLE_NAME_SQL.replace("?", stringJoiner.toString());
    return jdbcTemplate.queryForList(sql, String.class);
}

(如果你是用的是idea 那么// language=sql 可以标识将下面一行的字符串注入sql语言,语言注入功能效果杠杠的)

4.执行数据库初始化

spring自带的DataSourceScriptDatabaseInitializer本身就能读取并执行sql文件,我们直接拿来用,省去了超多代码,我们这里用SqlDataSourceScriptDatabaseInitializer(springboot2.6,2.5可以用DataSourceScriptDatabaseInitializer

private void initDataSource(List<String> sqlLocations) throws Exception {
    if (sqlLocations.size() <= 0) {
        return;
    }
    SqlInitializationProperties properties = new SqlInitializationProperties();
    properties.setEncoding(StandardCharsets.UTF_8);
    properties.setMode(DatabaseInitializationMode.ALWAYS);
    properties.setSchemaLocations(sqlLocations);
    new SqlDataSourceScriptDatabaseInitializer(dataSource, properties).afterPropertiesSet();
    log.info("\n数据库初始化sql已执行:{}",Arrays.toString(sqlLocations.toArray()));
}

完整代码

import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.boot.autoconfigure.sql.init.SqlDataSourceScriptDatabaseInitializer;
import org.springframework.boot.autoconfigure.sql.init.SqlInitializationProperties;
import org.springframework.boot.env.OriginTrackedMapPropertySource;
import org.springframework.boot.env.PropertiesPropertySourceLoader;
import org.springframework.boot.origin.OriginTrackedValue;
import org.springframework.boot.sql.init.DatabaseInitializationMode;
import org.springframework.core.env.PropertySource;
import org.springframework.core.io.Resource;
import org.springframework.core.io.ResourceLoader;
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
import org.springframework.core.io.support.ResourcePatternResolver;
import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.SqlParameterValue;
import org.springframework.util.StringUtils;

import javax.sql.DataSource;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.sql.Types;
import java.util.*;
import java.util.stream.Collectors;

/**
 * 数据库自动初始化器自动检测默认数据库是否含有指定表,如果不含有则读取并执行初始化sql<br>
 * <H1>使用方法:在"src/main/resources/sql-init.properties"中注册👇</H1>
 * <pre>
 *     表名=文件名
 *     brief_chain_log=classpath:brief_chain_log.sql
 * </pre>
 * src/main/sql/brief_chain_log.sql sql文件保证打包到jar包classes目录下即可
 * pom要把sql文件和properties文件打包进去
 * <pre>
 * &lt;build&gt;
 *     &lt;resources&gt;
 *        &lt;resource&gt;
 *           &lt;directory&gt;src/main/sql&lt;/directory&gt;
 *        &lt;/resource&gt;
 *        &lt;resource&gt;
 *           &lt;directory&gt;src/main/resources&lt;/directory&gt;
 *        &lt;/resource&gt;
 *     &lt;/resources&gt;
 * &lt;/build&gt;
 * </pre>
 * 文件名读取采用{@link PathMatchingResourcePatternResolver} 可采用通配符等ant匹配一次指定多个文件
 * <pre>
 *     superHero=classpath:SSS-*.sql #所有SSS-开头的都会被找到,使用通配符时应注意,本地运行、本地测试类运行、以jar包在linux运行结果可能会不同
 * </pre>
 */
@Slf4j
@RequiredArgsConstructor
public class SpringDatabaseInitializer implements InitializingBean {
    public static final String BEAN_NAME = "springDatabaseInitializer";
    private final JdbcTemplate jdbcTemplate;
    private final DataSource dataSource;
    private final PathMatchingResourcePatternResolver resourceResolver;
    private final Map<String, List<String>> schemaSqlFiles = new HashMap<>();
    // language=sql
    private static final String TABLE_NAME_SQL = "SELECT TABLE_NAME FROM information_schema.TABLES WHERE TABLE_NAME in (?) and TABLE_SCHEMA = database()";

    @Override
    public void afterPropertiesSet() throws Exception {
        initDatabase();
        destroySelf();
    }

    public synchronized void initDatabase() throws Exception {
        initSchemaSqlFiles();
        try {
            initDataSource();
        } catch (Exception e) {
            if (!isEmbeddedDatabase()) {
                throw e;
            }
            log.warn("内嵌数据库不支持某些SQL,数据库自动初始化失败!", e);
        }
    }

    private synchronized void initSchemaSqlFiles() throws IOException {
        PropertiesPropertySourceLoader loader = new PropertiesPropertySourceLoader();
        for (Resource resource : resourceResolver.getResources("classpath*:sql-init.properties")) {
            for (PropertySource<?> propertySource : loader.load("sql-init", resource)) {
                for (Map.Entry<String, Object> entry : ((OriginTrackedMapPropertySource) propertySource).getSource().entrySet()) {
                    String location = (String) ((OriginTrackedValue) entry.getValue()).getValue();
                    if (StringUtils.hasText(entry.getKey()) && StringUtils.hasText(location)) {
                        List<String> locations = schemaSqlFiles.computeIfAbsent(entry.getKey(), t -> new ArrayList<>());
                        //如果有重复的,则改为 classpath*: 搜索全部同名文件
                        if (locations.contains(location)) {
                            location = convertToAllClassPath(location);
                        }
                        locations.add(location);
                    }
                }
            }
        }
        if (CollectionUtils.isEmpty(schemaSqlFiles)) {
            return;
        }
        getExistSchemas(schemaSqlFiles.keySet(), jdbcTemplate).forEach(schemaSqlFiles::remove);
    }

    private synchronized void initDataSource() throws Exception {
        if (CollectionUtils.isEmpty(schemaSqlFiles)) {
            return;
        }
        List<String> sqlLocations = schemaSqlFiles.values().stream().flatMap(Collection::stream).collect(Collectors.toList());
        if (CollectionUtils.isEmpty(sqlLocations)) {
            return;
        }
        if (sqlLocations.stream().distinct().count() != sqlLocations.size()) {
            log.warn("A file with the same name appears!");
            if (log.isDebugEnabled()) {
                System.err.println(Arrays.toString(sqlLocations.toArray()));
            }
        }
        SqlInitializationProperties properties = new SqlInitializationProperties();
        properties.setEncoding(StandardCharsets.UTF_8);
        properties.setMode(DatabaseInitializationMode.ALWAYS);
        properties.setSchemaLocations(sqlLocations);
        new SqlDataSourceScriptDatabaseInitializer(dataSource, properties).afterPropertiesSet();
        log.info("\n数据库初始化sql已执行:{}", Arrays.toString(sqlLocations.toArray()));
    }

    private void destroySelf() {
        SpringBeanRegistryUtil.unregisterBean(BEAN_NAME);
    }

    private List<String> getExistSchemas(Collection<String> schemas, JdbcTemplate jdbcTemplate) {
        if (CollectionUtils.isEmpty(schemas)) {
            return Collections.emptyList();
        }
        StringJoiner stringJoiner = new StringJoiner(",");
        schemas.forEach(e -> {
            if (!e.startsWith("'") || !e.endsWith("'")) {
                stringJoiner.add("'" + e + "'");
            } else {
                stringJoiner.add(e);
            }
        });
        String sql = TABLE_NAME_SQL.replace("?", stringJoiner.toString());
        return jdbcTemplate.queryForList(sql, String.class);
    }

    private String convertToAllClassPath(String location) {
        if (location.startsWith(ResourceLoader.CLASSPATH_URL_PREFIX)) {
            return ResourcePatternResolver.CLASSPATH_ALL_URL_PREFIX + location.substring(ResourceLoader.CLASSPATH_URL_PREFIX.length());
        }
        return location;
    }

    private boolean isEmbeddedDatabase() {
        try {
            return EmbeddedDatabaseConnection.isEmbedded(this.dataSource);
        } catch (Exception ex) {
            log.debug("Could not determine if datasource is embedded", ex);
            return false;
        }
    }
}

使用方法

@Bean(SpringDatabaseInitializer.BEAN_NAME)
public static SpringDatabaseInitializer springDatabaseInitializer(JdbcTemplate jdbcTemplate, DataSource dataSource) {
    return new SpringDatabaseInitializer(jdbcTemplate, dataSource, new PathMatchingResourcePatternResolver());
}
转载自:https://juejin.cn/post/7068188517012078606
评论
请登录