Spring Cloud Gateway 自定义 ReadBodyPredicateFactory 实现动态路由

在互网企业当中网关的重要性我就不再赘述了,相信大家都比较清楚。我们公司网关采用的是 Spring Cloud Gateway。并且是通过自定义 RouteLocator 来实现动态路由的。路由规则是请求参数里面的 bizType ,比如接收 JSON 格式体的请求对象并且业务方请求的是创建支付订单接口,下面就是业务方需要传递的参数:

{
	"bizType" : "createOrder",
	.... 其它业务参数 
}

下面就是读取 requestBody 里面的主动参数,然后解析请求对象里面的 bizType ,来决定它的路由地址:


由于历史原因,网关不仅需要 application/json 这种 Json 格式 MediaType 的请求对象,还需要支持 MediaType 为 application/x-www-form-urlencoded 这种请求。而网关之前的处理方式比较粗暴,当有请求来临的时候因为有可能是 application/x-www-form-urlencoded 所以直接 URLDecoder :

将请求中特殊字符转义

  public static String requestDecode(String requestBody){
       try {
           return URLDecoder.decode(convertStringForAdd(requestBody), "UTF-8");
       } catch (UnsupportedEncodingException e) {
           log.error("requestBody decode error: {}", e);
       }
       return requestBody;
   }

这种处理方式导致的问题就是如果 JSON 请求参数里面带有 % 就会报以下错误:

针对这种问题其实有两种处理方式:

  • 把对象进行转换成 JSON,如果转换成功就OK,否则就先 UrlEncode,然后再用 & 分离 Key/value。
  • 还有一种方式就是在进行读取 requestBody 之前获取到它的 MediaType

上面两种方式当然就第二种方式更加优雅。下面我们来想一想如何在读取 requestBody 之前获取到 Http 请求的 MediaType 的。

当我们在路由调用 readBody 的时候其实就是调用下面的方法:

org.springframework.cloud.gateway.route.builder.PredicateSpec#readBody

public <T> BooleanSpec readBody(Class<T> inClass, Predicate<T> predicate) {
	return asyncPredicate(getBean(ReadBodyPredicateFactory.class)
			.applyAsync(c -> c.setPredicate(inClass, predicate)));
}

Spring Cloud 中 ReadBodyPredicateFactory 的实现方式如下:

public class ReadBodyPredicateFactory
		extends AbstractRoutePredicateFactory<ReadBodyPredicateFactory.Config> {
	...
	
	@Override
	@SuppressWarnings("unchecked")
	public AsyncPredicate<ServerWebExchange> applyAsync(Config config) {
		return exchange -> {
			Class inClass = config.getInClass();

			Object cachedBody = exchange.getAttribute(CACHE_REQUEST_BODY_OBJECT_KEY);
			Mono<?> modifiedBody;
			// We can only read the body from the request once, once that happens if we
			// try to read the body again an exception will be thrown. The below if/else
			// caches the body object as a request attribute in the ServerWebExchange
			// so if this filter is run more than once (due to more than one route
			// using it) we do not try to read the request body multiple times
			if (cachedBody != null) {
				try {
					boolean test = config.predicate.test(cachedBody);
					exchange.getAttributes().put(TEST_ATTRIBUTE, test);
					return Mono.just(test);
				}
				catch (ClassCastException e) {
					if (log.isDebugEnabled()) {
						log.debug("Predicate test failed because class in predicate "
								+ "does not match the cached body object", e);
					}
				}
				return Mono.just(false);
			}
			else {
				return ServerWebExchangeUtils.cacheRequestBodyAndRequest(exchange,
						(serverHttpRequest) -> ServerRequest
								.create(exchange.mutate().request(serverHttpRequest)
										.build(), messageReaders)
								.bodyToMono(inClass)
								.doOnNext(objectValue -> exchange.getAttributes()
										.put(CACHE_REQUEST_BODY_OBJECT_KEY, objectValue))
								.map(objectValue -> config.getPredicate()
										.test(objectValue)));
			}
		};
	}
	...
}

我们可以看到这里使用了对象 ServerWebExchange ,而这个对象就是 Spring webflux 定义的 Http 请求对象。上面的代码逻辑是判断 exchange 中的属性中是否包含属性为 cachedRequestBodyObjectrequestBody 对象,如果不包含就解析并添加 cachedRequestBodyObjectexchange 。在这里可以看到我们对 ReadBodyPredicateFactory 对象并不可以扩展,所以唯一的方式就是继承这个类,因为在读取 MediaType 的时候参数只有 requestBody:String ,所以我们只有通过 ThreadLocal 来进行参数传递。在真正 PredicateSpec#readBody 获取到 MediaType,就可以很好的解析 requestBody 。下面就是具体的代码实现:

1、GatewayContext.java

GatewayContext 定义网关上下文,保存 MediaType 用于 readBody 时解析。

GatewayContext.java

@Getter
@Setter
public class GatewayContext {

	private MediaType mediaType;

}

2、GatewayContextHolder.java

GatewayContextHolder 通过 ThreadLocal 传递 GatewayContext ,在请求对象解析时使用。

GatewayContextHolder.java

public class GatewayContextHolder {

    private static Logger logger = LoggerFactory.getLogger(GatewayContextHolder.class);

    private static ThreadLocal<GatewayContext> tl = new ThreadLocal<>();

    public static GatewayContext get() {
        if (tl.get() == null) {
            logger.error("gateway context not exist");
            throw new RuntimeException("gateway context is null");
        }
        return tl.get();
    }

    public static void set(GatewayContext sc) {
        if (tl.get() != null) {
            logger.error("gateway context not null");
            tl.remove();
        }
        tl.set(sc);
    }

    public static void cleanUp() {
        try {
            if (tl.get() != null) {
                tl.remove();
            }
        } catch (Exception e) {
            logger.error(e.getMessage(), e);
        }
    }

}

3、CustomReadBodyPredicateFactory.java

CustomReadBodyPredicateFactory 继承 ReadBodyPredicateFactory ,在原有解析 requestBody 的情况下,添加获取 MediaType 的逻辑。

CustomReadBodyPredicateFactory.java

public class CustomReadBodyPredicateFactory extends ReadBodyPredicateFactory {

    protected static final Log log = LogFactory.getLog(CustomReadBodyPredicateFactory.class);

    private static final String TEST_ATTRIBUTE = "read_body_predicate_test_attribute";

    private static final String CACHE_REQUEST_BODY_OBJECT_KEY = "cachedRequestBodyObject";

    private static final List<HttpMessageReader<?>> messageReaders = HandlerStrategies
            .withDefaults().messageReaders();

    public CustomReadBodyPredicateFactory() {
        super();
    }

    @Override
    public AsyncPredicate<ServerWebExchange> applyAsync(ReadBodyPredicateFactory.Config config) {
        return exchange -> {
            Class inClass = config.getInClass();

            Object cachedBody = exchange.getAttribute(CACHE_REQUEST_BODY_OBJECT_KEY);
            
            // 获取 MediaType
            MediaType mediaType = exchange.getRequest().getHeaders().getContentType();
            GatewayContext context = new GatewayContext();
            context.setMediaType(mediaType);
            GatewayContextHolder.set(context);
            // We can only read the body from the request once, once that happens if we
            // try to read the body again an exception will be thrown. The below if/else
            // caches the body object as a request attribute in the ServerWebExchange
            // so if this filter is run more than once (due to more than one route
            // using it) we do not try to read the request body multiple times
            if (cachedBody != null) {
                try {
                    boolean test = config.getPredicate().test(cachedBody);
                    exchange.getAttributes().put(TEST_ATTRIBUTE, test);
                    return Mono.just(test);
                }
                catch (ClassCastException e) {
                    if (log.isDebugEnabled()) {
                        log.debug("Predicate test failed because class in predicate "
                                + "does not match the cached body object", e);
                    }
                }
                return Mono.just(false);
            }
            else {
                return ServerWebExchangeUtils.cacheRequestBodyAndRequest(exchange,
                        (serverHttpRequest) -> ServerRequest
                                .create(exchange.mutate().request(serverHttpRequest)
                                        .build(), messageReaders)
                                .bodyToMono(inClass)
                                .doOnNext(objectValue -> exchange.getAttributes()
                                        .put(CACHE_REQUEST_BODY_OBJECT_KEY, objectValue))
                                .map(objectValue -> config.getPredicate()
                                        .test(objectValue)));
            }
        };
    }

}

4、GatewayBeanFactoryPostProcessor.java

通过 Spring framework 的 BeanDefinitionRegistryPostProcessor 扩展在实例化对象之前,把 readBody 的原有操作类 ReadBodyPredicateFactory 删除,替换成我们自定义类 CustomReadBodyPredicateFactory

GatewayBeanFactoryPostProcessor.java

@Component
public class GatewayBeanFactoryPostProcessor implements BeanDefinitionRegistryPostProcessor {

    @Override
    public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
        // do nothing
    }

    @Override
    public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
        registry.removeBeanDefinition("readBodyPredicateFactory");
        BeanDefinition beanDefinition = BeanDefinitionBuilder.rootBeanDefinition(CustomReadBodyPredicateFactory.class)
                .setScope(BeanDefinition.SCOPE_SINGLETON)
                .setRole(BeanDefinition.ROLE_SUPPORT)
                .getBeanDefinition();
        registry.registerBeanDefinition("readBodyPredicateFactory", beanDefinition);
    }

}

下面就是修改后的自定义路由规则。

public RouteLocatorBuilder.Builder route(RouteLocatorBuilder.Builder builder) {
    return builder.route(r -> r.readBody(String.class, requestBody -> {
        MediaType mediaType = GatewayContextHolder.get().getMediaType();
        // 通过 mediaType 解析 requestBody 然后从解析后的对象获取路由规则
        ...
    );
}        

后面就不会报之前的异常了。


原文:Spring Cloud Gateway 自定义 ReadBodyPredicateFactory 实现动态路由_readbodyroutepredicatefactory-CSDN博客
作者: carl-zhao