SpringBoot 整合 Mockito 和 testcontainers 单元测试
目录结构
sql
pursue-project
--src
----main
----test
------java
--------pub.pursue.xxx
----------controller
------------PursueControllerTest.java
----------dao
------------PursueMapperTest.java
----------service
------------PursueServiceTest.java
----------App.java
----------BaseMapperTest.java
----------TestUtils.java
------resources
--------application.yml
--------clean.sql
--------create.sql
相关依赖
xml
<!-- NOTE(review): no <version> elements are declared in this snippet; presumably they are managed by a parent POM / BOM - confirm. -->
<!-- mockito-inline: used for mocking static methods with Mockito -->
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-inline</artifactId>
<scope>test</scope>
</dependency>
<!-- Testcontainers JUnit 5 (Jupiter) integration -->
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
<!-- Testcontainers core and the TiDB module -->
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>testcontainers</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>tidb</artifactId>
<scope>test</scope>
</dependency>
<!-- MyBatis Spring Boot test starter -->
<dependency>
<groupId>org.mybatis.spring.boot</groupId>
<artifactId>mybatis-spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<!-- JaCoCo agent (runtime classifier) -->
<dependency>
<groupId>org.jacoco</groupId>
<artifactId>org.jacoco.agent</artifactId>
<classifier>runtime</classifier>
<scope>test</scope>
</dependency>
使用 Squaretest 生成相关的测试代码
使用 IDEA 的正则替换功能（Replace in Files，勾选 Regex），把生成代码中 when(...) 里的具体参数批量替换为 any()
一个参数的替换正则
java
(when\(\w+\.\w+\().{1,20}[^any()]\)\)\)
$1any()))
//替换前
when(xxMapper.xx(Arrays.asList("value"))).thenReturn(Collections.emptyList());
//替换后
when(xxMapper.xx(any())).thenReturn(Collections.emptyList());
java
(when\(\w+\.\w+\()\w+ \w+\(\)\)\)
$1any()))
//替换前
when(xxMapper.xx(new User())).thenReturn(xx);
//替换后
when(xxMapper.xx(any())).thenReturn(xx);
两个参数的替换正则
java
(when\(\w+\.\w+\().*\,.*[^any\(\)](\)\))
$1any(), any()$2
//替换前
when(xxMapper.delete("mecNo", "updater")).thenReturn(0);
//替换后
when(xxMapper.delete(any(), any())).thenReturn(0);
java
import com.google.common.collect.ImmutableMap;
import lombok.extern.slf4j.Slf4j;
import org.mockito.stubbing.Answer;
import java.lang.reflect.Array;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.sql.Timestamp;
import java.util.*;
import java.util.function.BiFunction;
import java.util.function.Function;
/**
* @author pursue
*/
@Slf4j
public class TestUtils {
private static final Map<Class<?>, Function<Class<?>, Object>> handlerHashMap = new HashMap<Class<?>, Function<Class<?>, Object>>() {{
put(String.class, c -> "string");
put(Integer.class, c -> 1);
put(Byte.class, c -> (byte) 0);
put(Short.class, c -> (short) 0);
put(Long.class, c -> 0L);
put(Float.class, c -> 0F);
put(Double.class, c -> 0D);
put(Boolean.class, c -> true);
put(Timestamp.class, c -> new Timestamp(new Date().getTime()));
put(Map.class, c -> Collections.emptyMap());
put(Set.class, c -> Collections.emptySet());
}};
static BiFunction<Class<?>, Map<String, Object>, Object> defaultHandlerWithMap = (type, fieldNameValMap) -> {
if (type.isEnum()) {
try {
Object values = type.getMethod("values").invoke(null);
Object[] enumObj = (Object[]) values;
return enumObj[0];
} catch (Exception e) {
return null;
}
}
if (type.isArray()) {
Class<?> componentType = type.getComponentType();
Object arrInstance = Array.newInstance(componentType, 1);
Object object = createObjAndFillFieldWithDefaultVal(componentType, fieldNameValMap);
Array.set(arrInstance, 0, object);
return arrInstance;
}
if (type.getName().contains("com.sensetime")) {
return createObjAndFillFieldWithDefaultVal(type, fieldNameValMap);
}
return null;
};
static Function<Class<?>, Object> defaultHandler = type -> defaultHandlerWithMap.apply(type, null);
public static Answer<?> AnswerByTypeDefault = invocation -> handlerHashMap.getOrDefault(invocation.getMethod().getReturnType(), defaultHandler)
.apply(invocation.getMethod().getReturnType());
public static <T> T createObjAndFillFieldWithDefaultVal(Class<T> clazz, Map<String, Object> fieldNameValMap) {
try {
T object = clazz.newInstance();
Field[] fields = clazz.getDeclaredFields();
for (Field field : fields) {
field.setAccessible(true);
Class<?> type = field.getType();
if (field.get(object) != null) {
continue;
}
// 返回指定在map中的值, 否则根据类型返回默认值
Object obj = fieldNameValMap.getOrDefault(field.getName(), handlerHashMap.getOrDefault(type, c -> defaultHandlerWithMap.apply(c, fieldNameValMap)).apply(type));
field.set(object, obj);
}
return object;
} catch (Exception e) {
log.error("FillFieldWithDefaultVal", e);
return null;
}
}
public static <T> T createObjAndFillFieldWithDefaultVal(Class<T> clazz) {
return createObjAndFillFieldWithDefaultVal(clazz, ImmutableMap.of());
}
}
相关代码
java
import com.google.common.collect.Maps;
import lombok.extern.slf4j.Slf4j;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mybatis.spring.annotation.MapperScan;
import org.mybatis.spring.boot.test.autoconfigure.MybatisTest;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.autoconfigure.jdbc.AutoConfigureTestDatabase;
import org.springframework.boot.test.util.TestPropertyValues;
import org.springframework.context.ApplicationContextInitializer;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.jdbc.Sql;
import org.springframework.test.context.junit.jupiter.SpringExtension;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.tidb.TiDBContainer;
import org.testcontainers.utility.DockerImageName;
import java.lang.reflect.*;
import java.util.Arrays;
import java.util.Map;
import java.util.function.Supplier;
/**
* @author pursue
*/
@ExtendWith(SpringExtension.class)
@MybatisTest
@MapperScan(basePackages = {"xxx.**.dao"})
@Sql(scripts = "/create.sql")
@Sql(scripts = "/clean.sql", executionPhase = Sql.ExecutionPhase.AFTER_TEST_METHOD)
@AutoConfigureTestDatabase(replace = AutoConfigureTestDatabase.Replace.NONE)
@Slf4j
@Testcontainers
@ContextConfiguration(initializers = BaseMapperTest.MyTiDBContainer.class)
public abstract class BaseMapperTest<M, T> {
@Container
protected TiDBContainer tidb = initTiDBContainer();
@Autowired
protected M mapper;
private TiDBContainer initTiDBContainer() {
log.info("initTiDBContainer...");
try (TiDBContainer tiDBContainer = new MyTiDBContainer()) {
return tiDBContainer
.withReuse(false)
.withUrlParam("allowMultiQueries", "true"); // 解决 client has multi-statement capability disabled
}
}
protected Supplier<Map<String, Object>> givenObjDefaultVal() {
return Maps::newHashMap;
}
protected static final String defaultString = "defaultString";
public static class MyTiDBContainer extends TiDBContainer implements ApplicationContextInitializer<ConfigurableApplicationContext> {
protected static final String IMAGE_NAME = "registry.xxx.com/xxx/pingcap/tidb:6.5.0";
/** 修改 {@link TiDBContainer} DEFAULT_IMAGE_NAME 的值 */
static {
try {
Field nameField = TiDBContainer.class.getDeclaredField("DEFAULT_IMAGE_NAME");
nameField.setAccessible(true);
Field modifiers = nameField.getClass().getDeclaredField("modifiers");
modifiers.setAccessible(true);
modifiers.setInt(nameField, nameField.getModifiers() & ~Modifier.FINAL);
nameField.set(DockerImageName.parse("pingcap/tidb"), DockerImageName.parse(IMAGE_NAME));
modifiers.setInt(nameField, nameField.getModifiers() & ~Modifier.FINAL);
} catch (Exception e) {
e.printStackTrace();
}
}
public MyTiDBContainer() {
super(IMAGE_NAME);
}
public String getDriverClassName() {
try {
Class.forName("org.mariadb.jdbc.Driver");
return "org.mariadb.jdbc.Driver";
} catch (ClassNotFoundException var2) {
return "org.mariadb.jdbc.Driver";
}
}
private static final Integer TIDB_PORT = 4000;
public String getJdbcUrl() {
String additionalUrlParams = this.constructUrlParameters("?", "&");
return "jdbc:mariadb://" + this.getHost() + ":" + super.getMappedPort(TIDB_PORT) + "/" + super.getDatabaseName() + additionalUrlParams;
}
@Override
public void initialize(ConfigurableApplicationContext context) {
this.start();
TestPropertyValues.of(
"spring.datasource.username=" + this.getUsername(),
"spring.datasource.password=" + this.getPassword(),
"spring.datasource.url=" + this.getJdbcUrl()
).applyTo(context.getEnvironment());
}
}
/**
* 基础的 mapper insert 检测
*/
protected void insert() {
T tObj = getMapperEntity();
Object invoke = invokeMapperMethod("insert", tObj);
Assertions.assertEquals(invoke, 1);
}
/**
* 基础的 mapper update 检测
*/
protected <ID> void update() {
T tObj = getMapperEntity();
Object invoke = invokeMapperMethod("update", tObj);
Assertions.assertEquals(invoke, 1);
}
protected <ID> void deleteById(ID id) {
Object invoke = invokeMapperMethod("delete", id);
Assertions.assertEquals(invoke, 1);
}
protected <ID> void selectById(ID id) {
Object invoke = invokeMapperMethod("select", id);
Assertions.assertNotNull(invoke);
}
protected T getMapperEntity() {
try {
Object obj = TestUtils.createObjAndFillFieldWithDefaultVal(getGenericSuperclass(), givenObjDefaultVal().get());
return (T) obj;
} catch (Exception e) {
throw new RuntimeException(e);
}
}
protected Object invokeMapperMethod(String methodName, Object... args) {
Class<?> mapperClass = mapper.getClass();
Method[] methods = mapperClass.getMethods();
Method method = Arrays.stream(methods)
.filter(m -> m.getName().equalsIgnoreCase(methodName))
.findAny()
.orElseThrow(() -> new RuntimeException("method is not exist"));
try {
return method.invoke(mapper, args);
} catch (IllegalAccessException | InvocationTargetException e) {
throw new RuntimeException(e);
}
}
protected Class<?> getGenericSuperclass() {
try {
Type type = this.getClass().getGenericSuperclass();
if (type instanceof ParameterizedType) {
ParameterizedType parameterizedType = (ParameterizedType) type;
Type[] actualTypeArguments = parameterizedType.getActualTypeArguments();
return Arrays.stream(actualTypeArguments)
.map(c -> (Class<?>) c)
.filter(c -> !c.getName().contains("Mapper"))
.findAny()
.orElseThrow(RuntimeException::new);
}
} catch (Exception e) {
e.printStackTrace();
}
return null;
}
}