diff --git a/server/src/main/java/cn/keking/web/filter/TrustHostFilter.java b/server/src/main/java/cn/keking/web/filter/TrustHostFilter.java index e661844f..00e75815 100644 --- a/server/src/main/java/cn/keking/web/filter/TrustHostFilter.java +++ b/server/src/main/java/cn/keking/web/filter/TrustHostFilter.java @@ -4,7 +4,12 @@ import cn.keking.config.ConfigConstants; import cn.keking.utils.WebUtils; import java.io.IOException; +import java.net.Inet4Address; +import java.net.InetAddress; +import java.net.UnknownHostException; import java.nio.charset.StandardCharsets; +import java.util.Set; +import java.util.regex.Pattern; import jakarta.servlet.Filter; import jakarta.servlet.FilterChain; import jakarta.servlet.FilterConfig; @@ -56,7 +61,7 @@ public class TrustHostFilter implements Filter { public boolean isNotTrustHost(String host) { // 如果配置了黑名单,优先检查黑名单 if (CollectionUtils.isNotEmpty(ConfigConstants.getNotTrustHostSet())) { - return ConfigConstants.getNotTrustHostSet().contains(host); + return matchAnyPattern(host, ConfigConstants.getNotTrustHostSet()); } // 如果配置了白名单,检查是否在白名单中 @@ -66,7 +71,7 @@ public class TrustHostFilter implements Filter { logger.debug("允许所有主机访问(通配符模式): {}", host); return false; } - return !ConfigConstants.getTrustHostSet().contains(host); + return !matchAnyPattern(host, ConfigConstants.getTrustHostSet()); } // 安全加固:默认拒绝所有未配置的主机(防止SSRF攻击) @@ -75,6 +80,88 @@ public class TrustHostFilter implements Filter { return true; } + private boolean matchAnyPattern(String host, Set hostPatterns) { + String normalizedHost = host == null ? "" : host.toLowerCase(); + for (String hostPattern : hostPatterns) { + if (matchHostPattern(normalizedHost, hostPattern)) { + return true; + } + } + return false; + } + + /** + * 支持三种匹配方式: + * 1. 精确匹配:example.com + * 2. 通配符匹配:*.example.com、192.168.* + * 3. IPv4 CIDR:192.168.0.0/16 + */ + private boolean matchHostPattern(String host, String hostPattern) { + if (hostPattern == null || hostPattern.trim().isEmpty()) { + return false; + } + String pattern = hostPattern.trim().toLowerCase(); + + if ("*".equals(pattern)) { + return true; + } + + if (pattern.contains("/")) { + return matchIpv4Cidr(host, pattern); + } + + if (pattern.contains("*")) { + String regex = wildcardToRegex(pattern); + return host.matches(regex); + } + + return host.equals(pattern); + } + + private String wildcardToRegex(String wildcard) { + StringBuilder regexBuilder = new StringBuilder("^"); + String[] parts = wildcard.split("\\*", -1); + for (int i = 0; i < parts.length; i++) { + regexBuilder.append(Pattern.quote(parts[i])); + if (i < parts.length - 1) { + regexBuilder.append(".*"); + } + } + regexBuilder.append("$"); + return regexBuilder.toString(); + } + + private boolean matchIpv4Cidr(String host, String cidr) { + try { + String[] parts = cidr.split("/"); + if (parts.length != 2) { + return false; + } + InetAddress hostAddress = InetAddress.getByName(host); + InetAddress networkAddress = InetAddress.getByName(parts[0]); + int prefixLength = Integer.parseInt(parts[1]); + + if (!(hostAddress instanceof Inet4Address) || !(networkAddress instanceof Inet4Address) || prefixLength < 0 || prefixLength > 32) { + return false; + } + + int mask = prefixLength == 0 ? 0 : -1 << (32 - prefixLength); + int hostInt = inet4ToInt(hostAddress); + int networkInt = inet4ToInt(networkAddress); + return (hostInt & mask) == (networkInt & mask); + } catch (UnknownHostException | NumberFormatException e) { + return false; + } + } + + private int inet4ToInt(InetAddress address) { + byte[] bytes = address.getAddress(); + return ((bytes[0] & 0xFF) << 24) + | ((bytes[1] & 0xFF) << 16) + | ((bytes[2] & 0xFF) << 8) + | (bytes[3] & 0xFF); + } + @Override public void destroy() { diff --git a/server/src/test/java/cn/keking/web/filter/TrustHostFilterTests.java b/server/src/test/java/cn/keking/web/filter/TrustHostFilterTests.java new file mode 100644 index 00000000..5b97e94b --- /dev/null +++ b/server/src/test/java/cn/keking/web/filter/TrustHostFilterTests.java @@ -0,0 +1,53 @@ +package cn.keking.web.filter; + +import cn.keking.config.ConfigConstants; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +public class TrustHostFilterTests { + + private final TrustHostFilter trustHostFilter = new TrustHostFilter(); + + @AfterEach + void tearDown() { + ConfigConstants.setTrustHostValue("default"); + ConfigConstants.setNotTrustHostValue("default"); + } + + @Test + void shouldBlockWildcardNotTrustHostPattern() { + ConfigConstants.setTrustHostValue("*"); + ConfigConstants.setNotTrustHostValue("192.168.*"); + + assert trustHostFilter.isNotTrustHost("192.168.1.10"); + assert !trustHostFilter.isNotTrustHost("8.8.8.8"); + } + + @Test + void shouldBlockCidrNotTrustHostPattern() { + ConfigConstants.setTrustHostValue("*"); + ConfigConstants.setNotTrustHostValue("10.0.0.0/8"); + + assert trustHostFilter.isNotTrustHost("10.1.2.3"); + assert !trustHostFilter.isNotTrustHost("11.1.2.3"); + } + + @Test + void shouldAllowWildcardTrustHostPattern() { + ConfigConstants.setTrustHostValue("*.trusted.com"); + ConfigConstants.setNotTrustHostValue("default"); + + assert !trustHostFilter.isNotTrustHost("api.trusted.com"); + assert trustHostFilter.isNotTrustHost("api.evil.com"); + } + + @Test + void shouldKeepBlacklistHigherPriorityThanWhitelist() { + ConfigConstants.setTrustHostValue("*"); + ConfigConstants.setNotTrustHostValue("127.0.0.1,10.*"); + + assert trustHostFilter.isNotTrustHost("127.0.0.1"); + assert trustHostFilter.isNotTrustHost("10.1.2.3"); + assert !trustHostFilter.isNotTrustHost("8.8.8.8"); + } +}