diff --git a/SECURITY_CONFIG.md b/SECURITY_CONFIG.md index ee87b300..ad96085b 100644 --- a/SECURITY_CONFIG.md +++ b/SECURITY_CONFIG.md @@ -146,11 +146,15 @@ trust.host = * ### Q4: 如何允许子域名? -目前不支持通配符域名匹配,需要明确列出每个子域名: +已支持通配符域名匹配,可使用 `*.example.com`: ```properties -trust.host = cdn.example.com,api.example.com,storage.example.com +trust.host = *.example.com ``` +说明: +- `*.example.com` 会匹配 `cdn.example.com`、`api.internal.example.com`,但不匹配根域 `example.com` +- 对于 IP 风格通配(如 `192.168.*`、`10.*`),仅匹配字面量 IPv4 地址,不匹配域名 + ## 🚨 安全事件响应 如果发现可疑的预览请求: diff --git a/server/src/main/java/cn/keking/web/filter/TrustHostFilter.java b/server/src/main/java/cn/keking/web/filter/TrustHostFilter.java index e661844f..53441769 100644 --- a/server/src/main/java/cn/keking/web/filter/TrustHostFilter.java +++ b/server/src/main/java/cn/keking/web/filter/TrustHostFilter.java @@ -4,13 +4,19 @@ import cn.keking.config.ConfigConstants; import cn.keking.utils.WebUtils; import java.io.IOException; +import java.util.Map; +import java.util.Locale; +import java.util.concurrent.ConcurrentHashMap; import java.nio.charset.StandardCharsets; +import java.util.Set; +import java.util.regex.Pattern; import jakarta.servlet.Filter; import jakarta.servlet.FilterChain; import jakarta.servlet.FilterConfig; import jakarta.servlet.ServletException; import jakarta.servlet.ServletRequest; import jakarta.servlet.ServletResponse; +import jakarta.servlet.http.HttpServletResponse; import org.apache.commons.collections4.CollectionUtils; import org.slf4j.Logger; @@ -25,6 +31,7 @@ import org.springframework.util.FileCopyUtils; public class TrustHostFilter implements Filter { private static final Logger logger = LoggerFactory.getLogger(TrustHostFilter.class); + private final Map wildcardPatternCache = new ConcurrentHashMap<>(); private String notTrustHostHtmlView; @Override @@ -43,9 +50,16 @@ public class TrustHostFilter implements Filter { public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { String url = WebUtils.getSourceUrl(request); String host = WebUtils.getHost(url); - assert host != null; if (isNotTrustHost(host)) { - String html = this.notTrustHostHtmlView.replace("${current_host}", host); + String currentHost = host == null ? "UNKNOWN" : host; + if (response instanceof HttpServletResponse httpServletResponse) { + httpServletResponse.setStatus(HttpServletResponse.SC_FORBIDDEN); + } + response.setCharacterEncoding(StandardCharsets.UTF_8.name()); + response.setContentType("text/html;charset=UTF-8"); + String html = this.notTrustHostHtmlView == null + ? "当前预览文件来自不受信任的站点:" + currentHost + "" + : this.notTrustHostHtmlView.replace("${current_host}", currentHost); response.getWriter().write(html); response.getWriter().close(); } else { @@ -54,9 +68,15 @@ public class TrustHostFilter implements Filter { } public boolean isNotTrustHost(String host) { + if (host == null || host.trim().isEmpty()) { + logger.warn("主机名为空或无效,拒绝访问"); + return true; + } + // 如果配置了黑名单,优先检查黑名单 - if (CollectionUtils.isNotEmpty(ConfigConstants.getNotTrustHostSet())) { - return ConfigConstants.getNotTrustHostSet().contains(host); + if (CollectionUtils.isNotEmpty(ConfigConstants.getNotTrustHostSet()) + && matchAnyPattern(host, ConfigConstants.getNotTrustHostSet())) { + return true; } // 如果配置了白名单,检查是否在白名单中 @@ -66,7 +86,7 @@ public class TrustHostFilter implements Filter { logger.debug("允许所有主机访问(通配符模式): {}", host); return false; } - return !ConfigConstants.getTrustHostSet().contains(host); + return !matchAnyPattern(host, ConfigConstants.getTrustHostSet()); } // 安全加固:默认拒绝所有未配置的主机(防止SSRF攻击) @@ -75,6 +95,136 @@ public class TrustHostFilter implements Filter { return true; } + private boolean matchAnyPattern(String host, Set hostPatterns) { + String normalizedHost = host.toLowerCase(Locale.ROOT); + for (String hostPattern : hostPatterns) { + if (matchHostPattern(normalizedHost, hostPattern)) { + return true; + } + } + return false; + } + + /** + * 支持三种匹配方式: + * 1. 精确匹配:example.com + * 2. 通配符匹配:*.example.com、192.168.* + * 3. IPv4 CIDR:192.168.0.0/16 + */ + private boolean matchHostPattern(String host, String hostPattern) { + if (hostPattern == null || hostPattern.trim().isEmpty()) { + return false; + } + String pattern = hostPattern.trim().toLowerCase(Locale.ROOT); + + if ("*".equals(pattern)) { + return true; + } + + if (pattern.contains("/")) { + return matchIpv4Cidr(host, pattern); + } + + if (pattern.contains("*")) { + if (isIpv4WildcardPattern(pattern)) { + return matchIpv4Wildcard(host, pattern); + } + Pattern compiledPattern = wildcardPatternCache.computeIfAbsent(pattern, key -> Pattern.compile(wildcardToRegex(key))); + return compiledPattern.matcher(host).matches(); + } + + return host.equals(pattern); + } + + private boolean isIpv4WildcardPattern(String pattern) { + return pattern.matches("^[0-9.*]+$") && pattern.contains("."); + } + + private boolean matchIpv4Wildcard(String host, String pattern) { + if (parseLiteralIpv4(host) == null) { + return false; + } + String[] hostParts = host.split("\\."); + String[] patternParts = pattern.split("\\."); + if (hostParts.length != 4 || patternParts.length < 1 || patternParts.length > 4) { + return false; + } + for (int i = 0; i < patternParts.length; i++) { + String p = patternParts[i]; + if ("*".equals(p)) { + continue; + } + if (!p.equals(hostParts[i])) { + return false; + } + } + return true; + } + + private String wildcardToRegex(String wildcard) { + StringBuilder regexBuilder = new StringBuilder("^"); + String[] parts = wildcard.split("\\*", -1); + for (int i = 0; i < parts.length; i++) { + regexBuilder.append(Pattern.quote(parts[i])); + if (i < parts.length - 1) { + regexBuilder.append(".*"); + } + } + regexBuilder.append("$"); + return regexBuilder.toString(); + } + + private boolean matchIpv4Cidr(String host, String cidr) { + try { + String[] parts = cidr.split("/"); + if (parts.length != 2) { + return false; + } + Long hostInt = parseLiteralIpv4(host); + Long networkInt = parseLiteralIpv4(parts[0]); + int prefixLength = Integer.parseInt(parts[1]); + + if (hostInt == null || networkInt == null || prefixLength < 0 || prefixLength > 32) { + return false; + } + + long mask = prefixLength == 0 ? 0L : (0xFFFFFFFFL << (32 - prefixLength)) & 0xFFFFFFFFL; + return (hostInt & mask) == (networkInt & mask); + } catch (NumberFormatException e) { + return false; + } + } + + /** + * 仅解析字面量 IPv4 地址(不做 DNS 解析),防止 DNS rebinding/TOCTOU 风险。 + */ + private Long parseLiteralIpv4(String input) { + if (input == null || input.trim().isEmpty()) { + return null; + } + String[] parts = input.trim().split("\\."); + if (parts.length != 4) { + return null; + } + long result = 0L; + for (String part : parts) { + if (part.isEmpty() || part.length() > 3) { + return null; + } + int value; + try { + value = Integer.parseInt(part); + } catch (NumberFormatException e) { + return null; + } + if (value < 0 || value > 255) { + return null; + } + result = (result << 8) | value; + } + return result; + } + @Override public void destroy() { diff --git a/server/src/test/java/cn/keking/web/filter/TrustHostFilterTests.java b/server/src/test/java/cn/keking/web/filter/TrustHostFilterTests.java new file mode 100644 index 00000000..4ae8029b --- /dev/null +++ b/server/src/test/java/cn/keking/web/filter/TrustHostFilterTests.java @@ -0,0 +1,92 @@ +package cn.keking.web.filter; + +import cn.keking.config.ConfigConstants; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +public class TrustHostFilterTests { + + private final TrustHostFilter trustHostFilter = new TrustHostFilter(); + + @AfterEach + void tearDown() { + ConfigConstants.setTrustHostValue("default"); + ConfigConstants.setNotTrustHostValue("default"); + } + + @Test + void shouldBlockWildcardNotTrustHostPattern() { + ConfigConstants.setTrustHostValue("*"); + ConfigConstants.setNotTrustHostValue("192.168.*"); + + assert trustHostFilter.isNotTrustHost("192.168.1.10"); + assert !trustHostFilter.isNotTrustHost("8.8.8.8"); + assert !trustHostFilter.isNotTrustHost("192.168.evil.com"); + } + + @Test + void shouldBlockCidrNotTrustHostPattern() { + ConfigConstants.setTrustHostValue("*"); + ConfigConstants.setNotTrustHostValue("10.0.0.0/8"); + + assert trustHostFilter.isNotTrustHost("10.1.2.3"); + assert !trustHostFilter.isNotTrustHost("11.1.2.3"); + // Ensure hostnames are not matched by CIDR-based not-trust rules (no DNS resolution) + assert !trustHostFilter.isNotTrustHost("localhost"); + } + + @Test + void shouldSupportHighBitIpv4InCidrMatching() { + ConfigConstants.setTrustHostValue("*"); + ConfigConstants.setNotTrustHostValue("200.0.0.0/8"); + + assert trustHostFilter.isNotTrustHost("200.1.2.3"); + assert !trustHostFilter.isNotTrustHost("199.1.2.3"); + } + + @Test + void shouldSupportIpv4UpperBoundaryCidrMatching() { + ConfigConstants.setTrustHostValue("*"); + ConfigConstants.setNotTrustHostValue("255.255.255.255/32"); + + assert trustHostFilter.isNotTrustHost("255.255.255.255"); + assert !trustHostFilter.isNotTrustHost("255.255.255.254"); + } + + @Test + void shouldDenyWhenHostIsBlankOrNull() { + ConfigConstants.setTrustHostValue("*"); + ConfigConstants.setNotTrustHostValue("default"); + + assert trustHostFilter.isNotTrustHost(null); + assert trustHostFilter.isNotTrustHost(" "); + } + + @Test + void shouldAllowWildcardTrustHostPattern() { + ConfigConstants.setTrustHostValue("*.trusted.com"); + ConfigConstants.setNotTrustHostValue("default"); + + assert !trustHostFilter.isNotTrustHost("api.trusted.com"); + assert trustHostFilter.isNotTrustHost("api.evil.com"); + } + + @Test + void shouldKeepBlacklistHigherPriorityThanWhitelist() { + ConfigConstants.setTrustHostValue("*"); + ConfigConstants.setNotTrustHostValue("127.0.0.1,10.*"); + + assert trustHostFilter.isNotTrustHost("127.0.0.1"); + assert trustHostFilter.isNotTrustHost("10.1.2.3"); + assert !trustHostFilter.isNotTrustHost("8.8.8.8"); + } + + @Test + void shouldStillEnforceWhitelistWhenBlacklistConfigured() { + ConfigConstants.setTrustHostValue("internal.example.com"); + ConfigConstants.setNotTrustHostValue("127.0.0.1"); + + assert !trustHostFilter.isNotTrustHost("internal.example.com"); + assert trustHostFilter.isNotTrustHost("8.8.8.8"); + } +}