From 084bdc9012c0e4609094e4e4fe9ec37b4f3a18f6 Mon Sep 17 00:00:00 2001 From: kl Date: Tue, 3 Mar 2026 13:51:30 +0800 Subject: [PATCH] fix(security): harden host matching against null and DNS rebinding --- .../cn/keking/web/filter/TrustHostFilter.java | 64 +++++++++++++------ .../web/filter/TrustHostFilterTests.java | 11 ++++ 2 files changed, 55 insertions(+), 20 deletions(-) diff --git a/server/src/main/java/cn/keking/web/filter/TrustHostFilter.java b/server/src/main/java/cn/keking/web/filter/TrustHostFilter.java index 00e75815..ca27344a 100644 --- a/server/src/main/java/cn/keking/web/filter/TrustHostFilter.java +++ b/server/src/main/java/cn/keking/web/filter/TrustHostFilter.java @@ -4,9 +4,8 @@ import cn.keking.config.ConfigConstants; import cn.keking.utils.WebUtils; import java.io.IOException; -import java.net.Inet4Address; -import java.net.InetAddress; -import java.net.UnknownHostException; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.nio.charset.StandardCharsets; import java.util.Set; import java.util.regex.Pattern; @@ -30,6 +29,7 @@ import org.springframework.util.FileCopyUtils; public class TrustHostFilter implements Filter { private static final Logger logger = LoggerFactory.getLogger(TrustHostFilter.class); + private final Map wildcardPatternCache = new ConcurrentHashMap<>(); private String notTrustHostHtmlView; @Override @@ -48,9 +48,8 @@ public class TrustHostFilter implements Filter { public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { String url = WebUtils.getSourceUrl(request); String host = WebUtils.getHost(url); - assert host != null; if (isNotTrustHost(host)) { - String html = this.notTrustHostHtmlView.replace("${current_host}", host); + String html = this.notTrustHostHtmlView.replace("${current_host}", host == null ? "UNKNOWN" : host); response.getWriter().write(html); response.getWriter().close(); } else { @@ -59,6 +58,11 @@ public class TrustHostFilter implements Filter { } public boolean isNotTrustHost(String host) { + if (host == null || host.trim().isEmpty()) { + logger.warn("主机名为空或无效,拒绝访问"); + return true; + } + // 如果配置了黑名单,优先检查黑名单 if (CollectionUtils.isNotEmpty(ConfigConstants.getNotTrustHostSet())) { return matchAnyPattern(host, ConfigConstants.getNotTrustHostSet()); @@ -81,7 +85,7 @@ public class TrustHostFilter implements Filter { } private boolean matchAnyPattern(String host, Set hostPatterns) { - String normalizedHost = host == null ? "" : host.toLowerCase(); + String normalizedHost = host.toLowerCase(); for (String hostPattern : hostPatterns) { if (matchHostPattern(normalizedHost, hostPattern)) { return true; @@ -111,8 +115,8 @@ public class TrustHostFilter implements Filter { } if (pattern.contains("*")) { - String regex = wildcardToRegex(pattern); - return host.matches(regex); + Pattern compiledPattern = wildcardPatternCache.computeIfAbsent(pattern, key -> Pattern.compile(wildcardToRegex(key))); + return compiledPattern.matcher(host).matches(); } return host.equals(pattern); @@ -137,29 +141,49 @@ public class TrustHostFilter implements Filter { if (parts.length != 2) { return false; } - InetAddress hostAddress = InetAddress.getByName(host); - InetAddress networkAddress = InetAddress.getByName(parts[0]); + int hostInt = parseLiteralIpv4(host); + int networkInt = parseLiteralIpv4(parts[0]); int prefixLength = Integer.parseInt(parts[1]); - if (!(hostAddress instanceof Inet4Address) || !(networkAddress instanceof Inet4Address) || prefixLength < 0 || prefixLength > 32) { + if (hostInt < 0 || networkInt < 0 || prefixLength < 0 || prefixLength > 32) { return false; } int mask = prefixLength == 0 ? 0 : -1 << (32 - prefixLength); - int hostInt = inet4ToInt(hostAddress); - int networkInt = inet4ToInt(networkAddress); return (hostInt & mask) == (networkInt & mask); - } catch (UnknownHostException | NumberFormatException e) { + } catch (NumberFormatException e) { return false; } } - private int inet4ToInt(InetAddress address) { - byte[] bytes = address.getAddress(); - return ((bytes[0] & 0xFF) << 24) - | ((bytes[1] & 0xFF) << 16) - | ((bytes[2] & 0xFF) << 8) - | (bytes[3] & 0xFF); + /** + * 仅解析字面量 IPv4 地址(不做 DNS 解析),防止 DNS rebinding/TOCTOU 风险。 + */ + private int parseLiteralIpv4(String input) { + if (input == null || input.trim().isEmpty()) { + return -1; + } + String[] parts = input.trim().split("\\."); + if (parts.length != 4) { + return -1; + } + int result = 0; + for (String part : parts) { + if (part.isEmpty() || part.length() > 3) { + return -1; + } + int value; + try { + value = Integer.parseInt(part); + } catch (NumberFormatException e) { + return -1; + } + if (value < 0 || value > 255) { + return -1; + } + result = (result << 8) | value; + } + return result; } @Override diff --git a/server/src/test/java/cn/keking/web/filter/TrustHostFilterTests.java b/server/src/test/java/cn/keking/web/filter/TrustHostFilterTests.java index 5b97e94b..30eed014 100644 --- a/server/src/test/java/cn/keking/web/filter/TrustHostFilterTests.java +++ b/server/src/test/java/cn/keking/web/filter/TrustHostFilterTests.java @@ -30,6 +30,17 @@ public class TrustHostFilterTests { assert trustHostFilter.isNotTrustHost("10.1.2.3"); assert !trustHostFilter.isNotTrustHost("11.1.2.3"); + // Ensure hostnames are not matched by CIDR-based not-trust rules (no DNS resolution) + assert !trustHostFilter.isNotTrustHost("localhost"); + } + + @Test + void shouldDenyWhenHostIsBlankOrNull() { + ConfigConstants.setTrustHostValue("*"); + ConfigConstants.setNotTrustHostValue("default"); + + assert trustHostFilter.isNotTrustHost(null); + assert trustHostFilter.isNotTrustHost(" "); } @Test