Skip to content

SpringBoot 整合 Mockito 和 testcontainers 单元测试

目录结构

sql
pursue-project
--src
----main
----test
------java
--------pub.pursue.xxx
----------controller
------------PursueControllerTest.java
----------dao
------------PursueMapperTest.java
----------service
------------PursueServiceTest.java
----------App.java
----------BaseMapperTest.java
----------TestUtils.java
------resources
--------application.yml
--------clean.sql
--------create.sql

相关依赖

xml
 <!-- NOTE(review): none of the dependencies below declares a <version>; presumably managed by a parent POM / BOM - confirm -->
<!-- mockito-inline: needed for mocking static methods -->
<dependency>
    <groupId>org.mockito</groupId>
    <artifactId>mockito-inline</artifactId>
    <scope>test</scope>
</dependency>
<!-- Testcontainers JUnit 5 integration (@Testcontainers / @Container used by BaseMapperTest) -->
<dependency>
    <groupId>org.testcontainers</groupId>
    <artifactId>junit-jupiter</artifactId>
    <scope>test</scope>
</dependency>
<!-- Testcontainers core and the TiDB module (TiDBContainer) -->
<dependency>
    <groupId>org.testcontainers</groupId>
    <artifactId>testcontainers</artifactId>
    <scope>test</scope>
</dependency>
<dependency>
    <groupId>org.testcontainers</groupId>
    <artifactId>tidb</artifactId>
    <scope>test</scope>
</dependency>
<!-- MyBatis slice-test support (@MybatisTest used by BaseMapperTest) -->
<dependency>
    <groupId>org.mybatis.spring.boot</groupId>
    <artifactId>mybatis-spring-boot-starter-test</artifactId>
    <scope>test</scope>
</dependency>
<!-- JaCoCo agent runtime jar (classifier "runtime") -->
<dependency>
    <groupId>org.jacoco</groupId>
    <artifactId>org.jacoco.agent</artifactId>
    <classifier>runtime</classifier>
    <scope>test</scope>
</dependency>

使用 Squaretest 生成相关的测试代码

使用 IDEA 的正则查找替换,把 Squaretest 生成的 `when(...)` 中的具体参数统一替换为 `any()`

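一个参数的替换正则(注意:`[^any()]` 是字符类,只匹配“一个不属于 a、n、y、(、) 的字符”,并不表示“不是 any()”;它在这里之所以不会重复匹配已替换成 `any()` 的行,只是因为 `any()` 的字符恰好都在这个字符类中,替换前请先预览结果)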

java
(when\(\w+\.\w+\().{1,20}[^any()]\)\)\)

$1any()))
 
//替换前
    when(xxMapper.xx(Arrays.asList("value"))).thenReturn(Collections.emptyList());
//替换后
    when(xxMapper.xx(any())).thenReturn(Collections.emptyList());
java
(when\(\w+\.\w+\()\w+ \w+\(\)\)\)

$1any()))
 
//替换前   
    when(xxMapper.xx(new User())).thenReturn(xx);
//替换后
    when(xxMapper.xx(any())).thenReturn(xx);

两个参数的替换正则

java
(when\(\w+\.\w+\().*\,.*[^any\(\)](\)\))
        
$1any(), any()$2
 
//替换前
    when(xxMapper.delete("mecNo", "updater")).thenReturn(0);
//替换后
    when(xxMapper.delete(any(), any())).thenReturn(0);
java
import com.google.common.collect.ImmutableMap;
import lombok.extern.slf4j.Slf4j;
import org.mockito.stubbing.Answer;

import java.lang.reflect.Array;
import java.lang.reflect.Field;
import java.sql.Timestamp;
import java.util.*;
import java.util.function.BiFunction;
import java.util.function.Function;

/**
 * Reflection-based helpers for unit tests.
 *
 * <p>{@link #createObjAndFillFieldWithDefaultVal(Class, Map)} instantiates a class through its
 * no-arg constructor and fills every instance field that is still {@code null}. The value is either
 * the caller-supplied value for that field name or a placeholder chosen by the field's type.
 * {@link #AnswerByTypeDefault} is a Mockito {@link Answer} that returns the same kind of placeholder
 * for a mocked method's return type.
 *
 * @author pursue
 */
@Slf4j
public class TestUtils {
    /**
     * Only classes whose fully-qualified name contains this marker are built recursively; any other
     * type without a registered handler is left {@code null}.
     */
    private static final String PROJECT_PACKAGE_MARKER = "com.sensetime";

    /**
     * Placeholder value per exact type. Primitive types share the handler of their wrapper, so mocked
     * methods returning {@code int}/{@code boolean}/... and primitive array elements get the same
     * placeholder as the boxed type.
     */
    private static final Map<Class<?>, Function<Class<?>, Object>> handlerHashMap = buildHandlers();

    /**
     * Fallback for types without a handler: the first constant of an enum, a one-element array, or a
     * recursively filled project object; {@code null} for anything else. A {@code null} map is treated
     * as "no explicit values".
     */
    static BiFunction<Class<?>, Map<String, Object>, Object> defaultHandlerWithMap =
            (type, fieldNameValMap) -> fallbackValue(type, nullToEmpty(fieldNameValMap), new HashSet<>());

    /** {@link #defaultHandlerWithMap} without explicit field values. */
    static Function<Class<?>, Object> defaultHandler = type -> defaultHandlerWithMap.apply(type, null);

    /** Mockito answer returning a type-based placeholder for the invoked method's return type. */
    public static Answer<?> AnswerByTypeDefault = invocation -> {
        Class<?> returnType = invocation.getMethod().getReturnType();
        return handlerHashMap.getOrDefault(returnType, defaultHandler).apply(returnType);
    };

    /**
     * Creates an instance of {@code clazz} and fills its {@code null} instance fields.
     *
     * @param clazz           class to instantiate; needs a no-arg constructor (any visibility)
     * @param fieldNameValMap explicit values keyed by field name, taking precedence over type-based
     *                        defaults and also applied to nested project objects; may be {@code null}
     * @return the filled object, or {@code null} if it could not be created or filled (error logged)
     */
    public static <T> T createObjAndFillFieldWithDefaultVal(Class<T> clazz, Map<String, Object> fieldNameValMap) {
        return fill(clazz, nullToEmpty(fieldNameValMap), new HashSet<>());
    }

    /** Same as {@link #createObjAndFillFieldWithDefaultVal(Class, Map)} with no explicit field values. */
    public static <T> T createObjAndFillFieldWithDefaultVal(Class<T> clazz) {
        return createObjAndFillFieldWithDefaultVal(clazz, ImmutableMap.of());
    }

    /** Builds the immutable type -> placeholder table. */
    private static Map<Class<?>, Function<Class<?>, Object>> buildHandlers() {
        Map<Class<?>, Function<Class<?>, Object>> handlers = new HashMap<>();
        register(handlers, c -> "string", String.class);
        register(handlers, c -> 1, Integer.class, int.class);
        register(handlers, c -> (byte) 0, Byte.class, byte.class);
        register(handlers, c -> (short) 0, Short.class, short.class);
        register(handlers, c -> 0L, Long.class, long.class);
        register(handlers, c -> 0F, Float.class, float.class);
        register(handlers, c -> 0D, Double.class, double.class);
        register(handlers, c -> true, Boolean.class, boolean.class);
        register(handlers, c -> new Timestamp(System.currentTimeMillis()), Timestamp.class);
        register(handlers, c -> Collections.emptyMap(), Map.class);
        register(handlers, c -> Collections.emptySet(), Set.class);
        register(handlers, c -> Collections.emptyList(), List.class);
        return Collections.unmodifiableMap(handlers);
    }

    /** Registers one handler for several types. */
    private static void register(Map<Class<?>, Function<Class<?>, Object>> handlers,
                                 Function<Class<?>, Object> handler, Class<?>... types) {
        for (Class<?> type : types) {
            handlers.put(type, handler);
        }
    }

    /** Returns {@code map}, or an empty map when it is {@code null}. */
    private static Map<String, Object> nullToEmpty(Map<String, Object> map) {
        return map == null ? Collections.<String, Object>emptyMap() : map;
    }

    /** Placeholder for {@code type}: registered handler first, then {@link #fallbackValue}. */
    private static Object defaultValue(Class<?> type, Map<String, Object> overrides, Set<Class<?>> inProgress) {
        Function<Class<?>, Object> handler = handlerHashMap.get(type);
        return handler != null ? handler.apply(type) : fallbackValue(type, overrides, inProgress);
    }

    /** Enum -> first constant, array -> one default element, project class -> filled object, else null. */
    private static Object fallbackValue(Class<?> type, Map<String, Object> overrides, Set<Class<?>> inProgress) {
        if (type.isEnum()) {
            Object[] constants = type.getEnumConstants();
            return constants == null || constants.length == 0 ? null : constants[0];
        }
        if (type.isArray()) {
            Class<?> componentType = type.getComponentType();
            Object array = Array.newInstance(componentType, 1);
            // Resolve the element like a field of the component type, so primitive and boxed arrays
            // get a real value; a null element is skipped because Array.set rejects null for primitives.
            Object element = defaultValue(componentType, overrides, inProgress);
            if (element != null) {
                Array.set(array, 0, element);
            }
            return array;
        }
        if (type.getName().contains(PROJECT_PACKAGE_MARKER)) {
            return fill(type, overrides, inProgress);
        }
        return null;
    }

    /** Instantiates {@code clazz} and fills its null instance fields; {@code inProgress} guards cycles. */
    private static <T> T fill(Class<T> clazz, Map<String, Object> overrides, Set<Class<?>> inProgress) {
        // A class already being built higher up the call chain means a reference cycle (e.g. a node that
        // holds its parent); leave that field null instead of recursing until StackOverflowError.
        if (!inProgress.add(clazz)) {
            return null;
        }
        try {
            java.lang.reflect.Constructor<T> constructor = clazz.getDeclaredConstructor();
            constructor.setAccessible(true);
            T object = constructor.newInstance();
            for (Field field : clazz.getDeclaredFields()) {
                // Static fields belong to the class and synthetic ones are generated by the compiler or
                // instrumentation; neither is part of this object's state.
                if (java.lang.reflect.Modifier.isStatic(field.getModifiers()) || field.isSynthetic()) {
                    continue;
                }
                field.setAccessible(true);
                if (field.get(object) != null) {
                    continue; // keep values assigned by the constructor or field initializers
                }
                String name = field.getName();
                // Only compute the type-based default when no explicit value was supplied.
                Object value = overrides.containsKey(name)
                        ? overrides.get(name)
                        : defaultValue(field.getType(), overrides, inProgress);
                field.set(object, value);
            }
            return object;
        } catch (Exception e) {
            log.error("FillFieldWithDefaultVal", e);
            return null;
        } finally {
            inProgress.remove(clazz);
        }
    }
}

Mapper 测试基类 BaseMapperTest 代码

java

import com.google.common.collect.Maps;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mybatis.spring.annotation.MapperScan;
import org.mybatis.spring.boot.test.autoconfigure.MybatisTest;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.autoconfigure.jdbc.AutoConfigureTestDatabase;
import org.springframework.boot.test.util.TestPropertyValues;
import org.springframework.context.ApplicationContextInitializer;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.jdbc.Sql;
import org.springframework.test.context.junit.jupiter.SpringExtension;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.tidb.TiDBContainer;
import org.testcontainers.utility.DockerImageName;

import java.lang.reflect.*;
import java.util.Arrays;
import java.util.Map;
import java.util.function.Supplier;

/**
 * @author pursue
 */
@ExtendWith(SpringExtension.class)
@MybatisTest
@MapperScan(basePackages = {"xxx.**.dao"})
@Sql(scripts = "/create.sql")
@Sql(scripts = "/clean.sql", executionPhase = Sql.ExecutionPhase.AFTER_TEST_METHOD)
@AutoConfigureTestDatabase(replace = AutoConfigureTestDatabase.Replace.NONE)
@Slf4j
@Testcontainers
@ContextConfiguration(initializers = BaseMapperTest.MyTiDBContainer.class)
public abstract class BaseMapperTest<M, T> {
    @Container
    protected TiDBContainer tidb = initTiDBContainer();
    @Autowired
    protected M mapper;

    private TiDBContainer initTiDBContainer() {
        log.info("initTiDBContainer...");
        try (TiDBContainer tiDBContainer = new MyTiDBContainer()) {
            return tiDBContainer
                    .withReuse(false)
                    .withUrlParam("allowMultiQueries", "true"); // 解决 client has multi-statement capability disabled
        }
    }

    protected Supplier<Map<String, Object>> givenObjDefaultVal() {
        return Maps::newHashMap;
    }

    protected static final String defaultString = "defaultString";

    public static class MyTiDBContainer extends TiDBContainer implements ApplicationContextInitializer<ConfigurableApplicationContext> {
        protected static final String IMAGE_NAME = "registry.xxx.com/xxx/pingcap/tidb:6.5.0";

        /** 修改 {@link TiDBContainer} DEFAULT_IMAGE_NAME 的值 */
        static {
            try {
                Field nameField = TiDBContainer.class.getDeclaredField("DEFAULT_IMAGE_NAME");
                nameField.setAccessible(true);
                Field modifiers = nameField.getClass().getDeclaredField("modifiers");
                modifiers.setAccessible(true);
                modifiers.setInt(nameField, nameField.getModifiers() & ~Modifier.FINAL);

                nameField.set(DockerImageName.parse("pingcap/tidb"), DockerImageName.parse(IMAGE_NAME));
                modifiers.setInt(nameField, nameField.getModifiers() & ~Modifier.FINAL);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }

        public MyTiDBContainer() {
            super(IMAGE_NAME);
        }

        public String getDriverClassName() {
            try {
                Class.forName("org.mariadb.jdbc.Driver");
                return "org.mariadb.jdbc.Driver";
            } catch (ClassNotFoundException var2) {
                return "org.mariadb.jdbc.Driver";
            }
        }

        private static final Integer TIDB_PORT = 4000;

        public String getJdbcUrl() {
            String additionalUrlParams = this.constructUrlParameters("?", "&");
            return "jdbc:mariadb://" + this.getHost() + ":" + super.getMappedPort(TIDB_PORT) + "/" + super.getDatabaseName() + additionalUrlParams;
        }

        @Override
        public void initialize(ConfigurableApplicationContext context) {
            this.start();
            TestPropertyValues.of(
                    "spring.datasource.username=" + this.getUsername(),
                    "spring.datasource.password=" + this.getPassword(),
                    "spring.datasource.url=" + this.getJdbcUrl()
            ).applyTo(context.getEnvironment());
        }
    }

    /**
     * 基础的 mapper insert 检测
     */
    protected void insert() {
        T tObj = getMapperEntity();
        Object invoke = invokeMapperMethod("insert", tObj);
        Assertions.assertEquals(invoke, 1);
    }

    /**
     * 基础的 mapper update 检测
     */
    protected <ID> void update() {
        T tObj = getMapperEntity();
        Object invoke = invokeMapperMethod("update", tObj);
        Assertions.assertEquals(invoke, 1);
    }

    protected <ID> void deleteById(ID id) {
        Object invoke = invokeMapperMethod("delete", id);
        Assertions.assertEquals(invoke, 1);
    }

    protected <ID> void selectById(ID id) {
        Object invoke = invokeMapperMethod("select", id);
        Assertions.assertNotNull(invoke);
    }

    protected T getMapperEntity() {
        try {
            Object obj = TestUtils.createObjAndFillFieldWithDefaultVal(getGenericSuperclass(), givenObjDefaultVal().get());
            return (T) obj;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    protected Object invokeMapperMethod(String methodName, Object... args) {
        Class<?> mapperClass = mapper.getClass();

        Method[] methods = mapperClass.getMethods();
        Method method = Arrays.stream(methods)
                .filter(m -> m.getName().equalsIgnoreCase(methodName))
                .findAny()
                .orElseThrow(() -> new RuntimeException("method is not exist"));

        try {
            return method.invoke(mapper, args);
        } catch (IllegalAccessException | InvocationTargetException e) {
            throw new RuntimeException(e);
        }
    }

    protected Class<?> getGenericSuperclass() {
        try {
            Type type = this.getClass().getGenericSuperclass();
            if (type instanceof ParameterizedType) {
                ParameterizedType parameterizedType = (ParameterizedType) type;
                Type[] actualTypeArguments = parameterizedType.getActualTypeArguments();
                return Arrays.stream(actualTypeArguments)
                        .map(c -> (Class<?>) c)
                        .filter(c -> !c.getName().contains("Mapper"))
                        .findAny()
                        .orElseThrow(RuntimeException::new);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }
}

文章来源于自己总结和网络转载,内容如有任何问题,请大佬斧正!联系我