fix(security): handle ipv4 unsigned range and deny template fallback

This commit is contained in:
kl
2026-03-03 14:08:39 +08:00
parent 084bdc9012
commit 6bcefedd08
2 changed files with 24 additions and 12 deletions

View File

@@ -49,7 +49,10 @@ public class TrustHostFilter implements Filter {
String url = WebUtils.getSourceUrl(request); String url = WebUtils.getSourceUrl(request);
String host = WebUtils.getHost(url); String host = WebUtils.getHost(url);
if (isNotTrustHost(host)) { if (isNotTrustHost(host)) {
String html = this.notTrustHostHtmlView.replace("${current_host}", host == null ? "UNKNOWN" : host); String currentHost = host == null ? "UNKNOWN" : host;
String html = this.notTrustHostHtmlView == null
? "<html><body>当前预览文件来自不受信任的站点:" + currentHost + "</body></html>"
: this.notTrustHostHtmlView.replace("${current_host}", currentHost);
response.getWriter().write(html); response.getWriter().write(html);
response.getWriter().close(); response.getWriter().close();
} else { } else {
@@ -141,15 +144,15 @@ public class TrustHostFilter implements Filter {
if (parts.length != 2) { if (parts.length != 2) {
return false; return false;
} }
int hostInt = parseLiteralIpv4(host); Long hostInt = parseLiteralIpv4(host);
int networkInt = parseLiteralIpv4(parts[0]); Long networkInt = parseLiteralIpv4(parts[0]);
int prefixLength = Integer.parseInt(parts[1]); int prefixLength = Integer.parseInt(parts[1]);
if (hostInt < 0 || networkInt < 0 || prefixLength < 0 || prefixLength > 32) { if (hostInt == null || networkInt == null || prefixLength < 0 || prefixLength > 32) {
return false; return false;
} }
int mask = prefixLength == 0 ? 0 : -1 << (32 - prefixLength); long mask = prefixLength == 0 ? 0L : (0xFFFFFFFFL << (32 - prefixLength)) & 0xFFFFFFFFL;
return (hostInt & mask) == (networkInt & mask); return (hostInt & mask) == (networkInt & mask);
} catch (NumberFormatException e) { } catch (NumberFormatException e) {
return false; return false;
@@ -159,27 +162,27 @@ public class TrustHostFilter implements Filter {
/** /**
* 仅解析字面量 IPv4 地址(不做 DNS 解析),防止 DNS rebinding/TOCTOU 风险。 * 仅解析字面量 IPv4 地址(不做 DNS 解析),防止 DNS rebinding/TOCTOU 风险。
*/ */
private int parseLiteralIpv4(String input) { private Long parseLiteralIpv4(String input) {
if (input == null || input.trim().isEmpty()) { if (input == null || input.trim().isEmpty()) {
return -1; return null;
} }
String[] parts = input.trim().split("\\."); String[] parts = input.trim().split("\\.");
if (parts.length != 4) { if (parts.length != 4) {
return -1; return null;
} }
int result = 0; long result = 0L;
for (String part : parts) { for (String part : parts) {
if (part.isEmpty() || part.length() > 3) { if (part.isEmpty() || part.length() > 3) {
return -1; return null;
} }
int value; int value;
try { try {
value = Integer.parseInt(part); value = Integer.parseInt(part);
} catch (NumberFormatException e) { } catch (NumberFormatException e) {
return -1; return null;
} }
if (value < 0 || value > 255) { if (value < 0 || value > 255) {
return -1; return null;
} }
result = (result << 8) | value; result = (result << 8) | value;
} }

View File

@@ -34,6 +34,15 @@ public class TrustHostFilterTests {
assert !trustHostFilter.isNotTrustHost("localhost"); assert !trustHostFilter.isNotTrustHost("localhost");
} }
@Test
void shouldSupportHighBitIpv4InCidrMatching() {
ConfigConstants.setTrustHostValue("*");
ConfigConstants.setNotTrustHostValue("200.0.0.0/8");
assert trustHostFilter.isNotTrustHost("200.1.2.3");
assert !trustHostFilter.isNotTrustHost("199.1.2.3");
}
@Test @Test
void shouldDenyWhenHostIsBlankOrNull() { void shouldDenyWhenHostIsBlankOrNull() {
ConfigConstants.setTrustHostValue("*"); ConfigConstants.setTrustHostValue("*");