在前面的一篇文章中的结尾提到过,使用MyBatis-Plus的多数据源方案,当调用MyBatis-Plus提供的IService的saveBatch、updateBatchById等批量方法时,会使多数据源失效。有兴趣的可以去看一下这篇文章:

https://blog.flycat.tech/archives/spring-transaction-annotation-causes-mybatis-dynamic-datasource-failure

导致多数据源失效的原因有两个:一个是@Transactional导致多数据源失效,另一个是在执行SQL时没有设置动态数据源名称。

对于我们自己开发的代码,当然是可以避免@DSTransactional和@Transactional混用的,但是在MyBatis-Plus提供的IService中,其中的批量方法(包括saveBatch、updateBatchById等)上被标记上了@Transactional注解,而我们又无法去更改第三方包的代码,那么这个时候就只能手动复制saveBatch、updateBatchById等,然后创建一个新的方法,并且在这个新的方法中增加@DS注解,以后所有的批量调用都必须调用这个新的方法。类似以下的形式:

830F4DF28E0B4430B44CACDB4ECC7786.png

需要注意的是,不能直接重写IService中的这些批量方法,因为@Transactional是直接标注在了IService接口的方法上的,即使重写了saveBatch、updateBatchById,然后在子类的实现类中增加@DSTransactional注解,多数据源同样会失效。

有没有一种可以一劳永逸的方法呢?

对于MyBatis-Plus提供的IService,其实有两个问题需要解决:

  1. 需要去掉批量方法上面的@Transactional注解,改成@DSTransactional

  2. 批量方法(saveBatch、updateBatchById)实际最终是调用executeBatch来实现的,而在executeBatch中是直接通过sqlSession来执行的。也就是说,通过sqlSession执行的sql,没有办法获取到mapper类上@DS注解的数据源配置。

对于第一个问题很好解决,我们可以实现一个自定义的IDsService,在这个service中将@Transactional注解全部替换成@DSTransactional。然后将原来继承IService的类全部改成继承IDsService。

然而如果@DS注解是放在Mapper类上的,这样仍然不能解决问题。

/**
 * @author <a href="mailto:me@flycat.tech">Bryon Zen</a>
 * @since 2025/1/15
 */
// 动态数据源的注解放在Mapper类上,但是如果在service中是直接通过sqlSession执行的,则拿不到这个配置
@DS(DSConstant.DB_2)
public interface ProductMapper extends BaseMapper<Product> {
}

在前面的文章中我们看过MP的动态数据源的源码(不熟悉的可以去看看:MyBatis-Plus动态数据源实现原理)。

其实就是将动态数据源名称push到一个栈的结构中,如果我们可以实现在执行executeBatch方法时手动push这个数据源,是不是就可以了。

由于代码是不会变的,我们可以在IDsService的实现类DsServiceImpl中增加一个动态数据源名称的字段dsName,并且在加载类时解析Mapper类中的@DS注解的配置。代码如下:

/**
 * @author <a href="mailto:me@flycat.tech">Bryon Zen</a>
 * @since 2025/1/15
 */
public abstract class DsServiceImpl<M extends BaseMapper<T>, T> implements IDsService<T> {
    // ...
    // 省略其他不重要的代码

    @Autowired
    protected M baseMapper;
    protected Class<T> entityClass = currentModelClass();
    // Mapper的Class对象
    protected Class<T> mapperClass = currentMapperClass();

    /**
     * 动态数据源名称
     */
    private final String dsName = parseDsName();

    /**
     * 子类可继承该方法自定义批量执行使用的数据源
     * @return
     */
    protected String parseDsName() {
        // 先从service类找动态数据源名称
        DS serviceDsAnnotation = AnnotationUtils.findAnnotation(this.getClass(), DS.class);
        if (serviceDsAnnotation != null && StringUtils.isNotBlank(serviceDsAnnotation.value())) {
            return serviceDsAnnotation.value();
        }

        // 再从mapper类找动态数据源名称
        DS mapperDsAnnotation = AnnotationUtils.findAnnotation(mapperClass, DS.class);
        if (mapperDsAnnotation != null && StringUtils.isNotBlank(mapperDsAnnotation.value())) {
            return mapperDsAnnotation.value();
        }

        return null;
    }

    /**
     * 注入动态数据源并执行相应的代码
     */
    protected <D> D pushDsNameAndDo(Supplier<D> supplier) {
        if (dsName == null || dsName.isEmpty()) {
            return supplier.get();
        }

        DynamicDataSourceContextHolder.push(dsName);
        try {
            return supplier.get();
        } finally {
            DynamicDataSourceContextHolder.poll();
        }
    }

    /**
     * 批量插入
     *
     * @param entityList ignore
     * @param batchSize  ignore
     * @return ignore
     */
    @DSTransactional
    @Override
    public boolean saveBatch(Collection<T> entityList, int batchSize) {
        String sqlStatement = getSqlStatement(SqlMethod.INSERT_ONE);
        return executeBatch(entityList, batchSize, (sqlSession, entity) -> sqlSession.insert(sqlStatement, entity));
    }

    @DSTransactional
    @Override
    public boolean saveOrUpdateBatch(Collection<T> entityList, int batchSize) {
        return pushDsNameAndDo(() -> {
            TableInfo tableInfo = TableInfoHelper.getTableInfo(entityClass);
            Assert.notNull(tableInfo, "error: can not execute. because can not find cache of TableInfo for entity!");
            String keyProperty = tableInfo.getKeyProperty();
            Assert.notEmpty(keyProperty, "error: can not execute. because can not find column for id from entity!");
            return SqlHelper.saveOrUpdateBatch(this.entityClass, this.mapperClass, this.log, entityList, batchSize, (sqlSession, entity) -> {
                Object idVal = ReflectionKit.getFieldValue(entity, keyProperty);
                return StringUtils.checkValNull(idVal)
                        || CollectionUtils.isEmpty(sqlSession.selectList(getSqlStatement(SqlMethod.SELECT_BY_ID), entity));
            }, (sqlSession, entity) -> {
                MapperMethod.ParamMap<T> param = new MapperMethod.ParamMap<>();
                param.put(Constants.ENTITY, entity);
                sqlSession.update(getSqlStatement(SqlMethod.UPDATE_BY_ID), param);
            });
        });
    }

    @DSTransactional
    @Override
    public boolean updateBatchById(Collection<T> entityList, int batchSize) {
        String sqlStatement = getSqlStatement(SqlMethod.UPDATE_BY_ID);
        return executeBatch(entityList, batchSize, (sqlSession, entity) -> {
            MapperMethod.ParamMap<T> param = new MapperMethod.ParamMap<>();
            param.put(Constants.ENTITY, entity);
            sqlSession.update(sqlStatement, param);
        });
    }

    /**
     * 执行批量操作
     *
     * @param list      数据集合
     * @param batchSize 批量大小
     * @param consumer  执行方法
     * @param <E>       泛型
     * @return 操作结果
     * @since 3.3.1
     */
    protected <E> boolean executeBatch(Collection<E> list, int batchSize, BiConsumer<SqlSession, E> consumer) {
        // 包了一层pushDsNameAndDo,在这个里面会设置动态数据源名称
        return pushDsNameAndDo(() -> SqlHelper.executeBatch(this.entityClass, this.log, list, batchSize, consumer));
    }

}

来解释一下上面的代码。

首先增加了一个final类型的dsName字段,当service类创建时会调用parseDsName方法解析动态数据源名称。

在parseDsName中,首先从当前类(Service类)查询有没有@DS注解,如果有的话就读取这个注解的动态数据源配置;如果当前类没有,则再从Mapper类找动态数据源名称,如果有的话就返回。

在执行saveBatch方法时,实际调用的是executeBatch方法,而在executeBatch中执行真正的SqlHelper.executeBatch方法之前,会先执行pushDsNameAndDo方法中的代码。

可以看到pushDsNameAndDo中其实就是将前面解析到的动态数据源名称dsName放到DynamicDataSourceContextHolder中,那么当执行SQL时就能拿到动态数据源配置了。

这样的话,就只需要将新的IDsService和IDsServiceImpl复制到项目中,然后替换原来的IService和IServiceImpl就行了。

完整代码和测试用例参见:https://github.com/bryonzen/mybatis-plus-ds-solution