fix(security): support wildcard/cidr host pattern matching (#710)

* fix(security): support wildcard/cidr host pattern matching

* fix(security): harden host matching against null and DNS rebinding

* fix(security): handle ipv4 unsigned range and deny template fallback

* test(security): verify CIDR matching for IPv4 upper boundary

* fix(security): set UTF-8 deny response and use Locale.ROOT

* fix(security): enforce whitelist with blacklist and harden wildcard rules
This commit is contained in:
kl
2026-03-03 15:26:35 +08:00
committed by GitHub
parent 92ca92bee6
commit 8c3bc81e08
3 changed files with 253 additions and 7 deletions

View File

@@ -4,13 +4,19 @@ import cn.keking.config.ConfigConstants;
import cn.keking.utils.WebUtils;
import java.io.IOException;
import java.util.Map;
import java.util.Locale;
import java.util.concurrent.ConcurrentHashMap;
import java.nio.charset.StandardCharsets;
import java.util.Set;
import java.util.regex.Pattern;
import jakarta.servlet.Filter;
import jakarta.servlet.FilterChain;
import jakarta.servlet.FilterConfig;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpServletResponse;
import org.apache.commons.collections4.CollectionUtils;
import org.slf4j.Logger;
@@ -25,6 +31,7 @@ import org.springframework.util.FileCopyUtils;
public class TrustHostFilter implements Filter {
private static final Logger logger = LoggerFactory.getLogger(TrustHostFilter.class);
private final Map<String, Pattern> wildcardPatternCache = new ConcurrentHashMap<>();
private String notTrustHostHtmlView;
@Override
@@ -43,9 +50,16 @@ public class TrustHostFilter implements Filter {
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
String url = WebUtils.getSourceUrl(request);
String host = WebUtils.getHost(url);
assert host != null;
if (isNotTrustHost(host)) {
String html = this.notTrustHostHtmlView.replace("${current_host}", host);
String currentHost = host == null ? "UNKNOWN" : host;
if (response instanceof HttpServletResponse httpServletResponse) {
httpServletResponse.setStatus(HttpServletResponse.SC_FORBIDDEN);
}
response.setCharacterEncoding(StandardCharsets.UTF_8.name());
response.setContentType("text/html;charset=UTF-8");
String html = this.notTrustHostHtmlView == null
? "<html><head><meta charset=\"utf-8\"></head><body>当前预览文件来自不受信任的站点:" + currentHost + "</body></html>"
: this.notTrustHostHtmlView.replace("${current_host}", currentHost);
response.getWriter().write(html);
response.getWriter().close();
} else {
@@ -54,9 +68,15 @@ public class TrustHostFilter implements Filter {
}
public boolean isNotTrustHost(String host) {
if (host == null || host.trim().isEmpty()) {
logger.warn("主机名为空或无效,拒绝访问");
return true;
}
// 如果配置了黑名单,优先检查黑名单
if (CollectionUtils.isNotEmpty(ConfigConstants.getNotTrustHostSet())) {
return ConfigConstants.getNotTrustHostSet().contains(host);
if (CollectionUtils.isNotEmpty(ConfigConstants.getNotTrustHostSet())
&& matchAnyPattern(host, ConfigConstants.getNotTrustHostSet())) {
return true;
}
// 如果配置了白名单,检查是否在白名单中
@@ -66,7 +86,7 @@ public class TrustHostFilter implements Filter {
logger.debug("允许所有主机访问(通配符模式): {}", host);
return false;
}
return !ConfigConstants.getTrustHostSet().contains(host);
return !matchAnyPattern(host, ConfigConstants.getTrustHostSet());
}
// 安全加固默认拒绝所有未配置的主机防止SSRF攻击
@@ -75,6 +95,136 @@ public class TrustHostFilter implements Filter {
return true;
}
private boolean matchAnyPattern(String host, Set<String> hostPatterns) {
String normalizedHost = host.toLowerCase(Locale.ROOT);
for (String hostPattern : hostPatterns) {
if (matchHostPattern(normalizedHost, hostPattern)) {
return true;
}
}
return false;
}
/**
* 支持三种匹配方式:
* 1. 精确匹配example.com
* 2. 通配符匹配:*.example.com、192.168.*
* 3. IPv4 CIDR192.168.0.0/16
*/
private boolean matchHostPattern(String host, String hostPattern) {
if (hostPattern == null || hostPattern.trim().isEmpty()) {
return false;
}
String pattern = hostPattern.trim().toLowerCase(Locale.ROOT);
if ("*".equals(pattern)) {
return true;
}
if (pattern.contains("/")) {
return matchIpv4Cidr(host, pattern);
}
if (pattern.contains("*")) {
if (isIpv4WildcardPattern(pattern)) {
return matchIpv4Wildcard(host, pattern);
}
Pattern compiledPattern = wildcardPatternCache.computeIfAbsent(pattern, key -> Pattern.compile(wildcardToRegex(key)));
return compiledPattern.matcher(host).matches();
}
return host.equals(pattern);
}
private boolean isIpv4WildcardPattern(String pattern) {
return pattern.matches("^[0-9.*]+$") && pattern.contains(".");
}
private boolean matchIpv4Wildcard(String host, String pattern) {
if (parseLiteralIpv4(host) == null) {
return false;
}
String[] hostParts = host.split("\\.");
String[] patternParts = pattern.split("\\.");
if (hostParts.length != 4 || patternParts.length < 1 || patternParts.length > 4) {
return false;
}
for (int i = 0; i < patternParts.length; i++) {
String p = patternParts[i];
if ("*".equals(p)) {
continue;
}
if (!p.equals(hostParts[i])) {
return false;
}
}
return true;
}
private String wildcardToRegex(String wildcard) {
StringBuilder regexBuilder = new StringBuilder("^");
String[] parts = wildcard.split("\\*", -1);
for (int i = 0; i < parts.length; i++) {
regexBuilder.append(Pattern.quote(parts[i]));
if (i < parts.length - 1) {
regexBuilder.append(".*");
}
}
regexBuilder.append("$");
return regexBuilder.toString();
}
private boolean matchIpv4Cidr(String host, String cidr) {
try {
String[] parts = cidr.split("/");
if (parts.length != 2) {
return false;
}
Long hostInt = parseLiteralIpv4(host);
Long networkInt = parseLiteralIpv4(parts[0]);
int prefixLength = Integer.parseInt(parts[1]);
if (hostInt == null || networkInt == null || prefixLength < 0 || prefixLength > 32) {
return false;
}
long mask = prefixLength == 0 ? 0L : (0xFFFFFFFFL << (32 - prefixLength)) & 0xFFFFFFFFL;
return (hostInt & mask) == (networkInt & mask);
} catch (NumberFormatException e) {
return false;
}
}
/**
* 仅解析字面量 IPv4 地址(不做 DNS 解析),防止 DNS rebinding/TOCTOU 风险。
*/
private Long parseLiteralIpv4(String input) {
if (input == null || input.trim().isEmpty()) {
return null;
}
String[] parts = input.trim().split("\\.");
if (parts.length != 4) {
return null;
}
long result = 0L;
for (String part : parts) {
if (part.isEmpty() || part.length() > 3) {
return null;
}
int value;
try {
value = Integer.parseInt(part);
} catch (NumberFormatException e) {
return null;
}
if (value < 0 || value > 255) {
return null;
}
result = (result << 8) | value;
}
return result;
}
@Override
public void destroy() {

View File

@@ -0,0 +1,92 @@
package cn.keking.web.filter;
import cn.keking.config.ConfigConstants;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
public class TrustHostFilterTests {
private final TrustHostFilter trustHostFilter = new TrustHostFilter();
@AfterEach
void tearDown() {
ConfigConstants.setTrustHostValue("default");
ConfigConstants.setNotTrustHostValue("default");
}
@Test
void shouldBlockWildcardNotTrustHostPattern() {
ConfigConstants.setTrustHostValue("*");
ConfigConstants.setNotTrustHostValue("192.168.*");
assert trustHostFilter.isNotTrustHost("192.168.1.10");
assert !trustHostFilter.isNotTrustHost("8.8.8.8");
assert !trustHostFilter.isNotTrustHost("192.168.evil.com");
}
@Test
void shouldBlockCidrNotTrustHostPattern() {
ConfigConstants.setTrustHostValue("*");
ConfigConstants.setNotTrustHostValue("10.0.0.0/8");
assert trustHostFilter.isNotTrustHost("10.1.2.3");
assert !trustHostFilter.isNotTrustHost("11.1.2.3");
// Ensure hostnames are not matched by CIDR-based not-trust rules (no DNS resolution)
assert !trustHostFilter.isNotTrustHost("localhost");
}
@Test
void shouldSupportHighBitIpv4InCidrMatching() {
ConfigConstants.setTrustHostValue("*");
ConfigConstants.setNotTrustHostValue("200.0.0.0/8");
assert trustHostFilter.isNotTrustHost("200.1.2.3");
assert !trustHostFilter.isNotTrustHost("199.1.2.3");
}
@Test
void shouldSupportIpv4UpperBoundaryCidrMatching() {
ConfigConstants.setTrustHostValue("*");
ConfigConstants.setNotTrustHostValue("255.255.255.255/32");
assert trustHostFilter.isNotTrustHost("255.255.255.255");
assert !trustHostFilter.isNotTrustHost("255.255.255.254");
}
@Test
void shouldDenyWhenHostIsBlankOrNull() {
ConfigConstants.setTrustHostValue("*");
ConfigConstants.setNotTrustHostValue("default");
assert trustHostFilter.isNotTrustHost(null);
assert trustHostFilter.isNotTrustHost(" ");
}
@Test
void shouldAllowWildcardTrustHostPattern() {
ConfigConstants.setTrustHostValue("*.trusted.com");
ConfigConstants.setNotTrustHostValue("default");
assert !trustHostFilter.isNotTrustHost("api.trusted.com");
assert trustHostFilter.isNotTrustHost("api.evil.com");
}
@Test
void shouldKeepBlacklistHigherPriorityThanWhitelist() {
ConfigConstants.setTrustHostValue("*");
ConfigConstants.setNotTrustHostValue("127.0.0.1,10.*");
assert trustHostFilter.isNotTrustHost("127.0.0.1");
assert trustHostFilter.isNotTrustHost("10.1.2.3");
assert !trustHostFilter.isNotTrustHost("8.8.8.8");
}
@Test
void shouldStillEnforceWhitelistWhenBlacklistConfigured() {
ConfigConstants.setTrustHostValue("internal.example.com");
ConfigConstants.setNotTrustHostValue("127.0.0.1");
assert !trustHostFilter.isNotTrustHost("internal.example.com");
assert trustHostFilter.isNotTrustHost("8.8.8.8");
}
}