fix(security): harden host matching against null and DNS rebinding

This commit is contained in:
kl
2026-03-03 13:51:30 +08:00
parent 1aa477bdf8
commit 084bdc9012
2 changed files with 55 additions and 20 deletions

View File

@@ -4,9 +4,8 @@ import cn.keking.config.ConfigConstants;
import cn.keking.utils.WebUtils;
import java.io.IOException;
import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.nio.charset.StandardCharsets;
import java.util.Set;
import java.util.regex.Pattern;
@@ -30,6 +29,7 @@ import org.springframework.util.FileCopyUtils;
public class TrustHostFilter implements Filter {
private static final Logger logger = LoggerFactory.getLogger(TrustHostFilter.class);
private final Map<String, Pattern> wildcardPatternCache = new ConcurrentHashMap<>();
private String notTrustHostHtmlView;
@Override
@@ -48,9 +48,8 @@ public class TrustHostFilter implements Filter {
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
String url = WebUtils.getSourceUrl(request);
String host = WebUtils.getHost(url);
assert host != null;
if (isNotTrustHost(host)) {
String html = this.notTrustHostHtmlView.replace("${current_host}", host);
String html = this.notTrustHostHtmlView.replace("${current_host}", host == null ? "UNKNOWN" : host);
response.getWriter().write(html);
response.getWriter().close();
} else {
@@ -59,6 +58,11 @@ public class TrustHostFilter implements Filter {
}
public boolean isNotTrustHost(String host) {
if (host == null || host.trim().isEmpty()) {
logger.warn("主机名为空或无效,拒绝访问");
return true;
}
// 如果配置了黑名单,优先检查黑名单
if (CollectionUtils.isNotEmpty(ConfigConstants.getNotTrustHostSet())) {
return matchAnyPattern(host, ConfigConstants.getNotTrustHostSet());
@@ -81,7 +85,7 @@ public class TrustHostFilter implements Filter {
}
private boolean matchAnyPattern(String host, Set<String> hostPatterns) {
String normalizedHost = host == null ? "" : host.toLowerCase();
String normalizedHost = host.toLowerCase();
for (String hostPattern : hostPatterns) {
if (matchHostPattern(normalizedHost, hostPattern)) {
return true;
@@ -111,8 +115,8 @@ public class TrustHostFilter implements Filter {
}
if (pattern.contains("*")) {
String regex = wildcardToRegex(pattern);
return host.matches(regex);
Pattern compiledPattern = wildcardPatternCache.computeIfAbsent(pattern, key -> Pattern.compile(wildcardToRegex(key)));
return compiledPattern.matcher(host).matches();
}
return host.equals(pattern);
@@ -137,29 +141,49 @@ public class TrustHostFilter implements Filter {
if (parts.length != 2) {
return false;
}
InetAddress hostAddress = InetAddress.getByName(host);
InetAddress networkAddress = InetAddress.getByName(parts[0]);
int hostInt = parseLiteralIpv4(host);
int networkInt = parseLiteralIpv4(parts[0]);
int prefixLength = Integer.parseInt(parts[1]);
if (!(hostAddress instanceof Inet4Address) || !(networkAddress instanceof Inet4Address) || prefixLength < 0 || prefixLength > 32) {
if (hostInt < 0 || networkInt < 0 || prefixLength < 0 || prefixLength > 32) {
return false;
}
int mask = prefixLength == 0 ? 0 : -1 << (32 - prefixLength);
int hostInt = inet4ToInt(hostAddress);
int networkInt = inet4ToInt(networkAddress);
return (hostInt & mask) == (networkInt & mask);
} catch (UnknownHostException | NumberFormatException e) {
} catch (NumberFormatException e) {
return false;
}
}
private int inet4ToInt(InetAddress address) {
byte[] bytes = address.getAddress();
return ((bytes[0] & 0xFF) << 24)
| ((bytes[1] & 0xFF) << 16)
| ((bytes[2] & 0xFF) << 8)
| (bytes[3] & 0xFF);
/**
* 仅解析字面量 IPv4 地址(不做 DNS 解析),防止 DNS rebinding/TOCTOU 风险。
*/
private int parseLiteralIpv4(String input) {
if (input == null || input.trim().isEmpty()) {
return -1;
}
String[] parts = input.trim().split("\\.");
if (parts.length != 4) {
return -1;
}
int result = 0;
for (String part : parts) {
if (part.isEmpty() || part.length() > 3) {
return -1;
}
int value;
try {
value = Integer.parseInt(part);
} catch (NumberFormatException e) {
return -1;
}
if (value < 0 || value > 255) {
return -1;
}
result = (result << 8) | value;
}
return result;
}
@Override

View File

@@ -30,6 +30,17 @@ public class TrustHostFilterTests {
assert trustHostFilter.isNotTrustHost("10.1.2.3");
assert !trustHostFilter.isNotTrustHost("11.1.2.3");
// Ensure hostnames are not matched by CIDR-based not-trust rules (no DNS resolution)
assert !trustHostFilter.isNotTrustHost("localhost");
}
@Test
void shouldDenyWhenHostIsBlankOrNull() {
ConfigConstants.setTrustHostValue("*");
ConfigConstants.setNotTrustHostValue("default");
assert trustHostFilter.isNotTrustHost(null);
assert trustHostFilter.isNotTrustHost(" ");
}
@Test