implement Milestone 9: Network HTTP with SSRF prevention

This commit is contained in:
2026-01-18 15:24:56 +01:00
parent a94e0d5d63
commit c0baa673b8
8 changed files with 1501 additions and 1 deletions

View File

@@ -0,0 +1,388 @@
#include "http_validator.h"
#include <algorithm>
#include <cctype>
#include <regex>
#include <array>
namespace mosis {
HttpValidator::HttpValidator()
: m_domain_restrictions_enabled(false)
{
}
// Replaces the domain allow-list.
// An empty list switches restrictions off instead of denying every host.
void HttpValidator::SetAllowedDomains(const std::vector<std::string>& domains) {
    m_domain_restrictions_enabled = !domains.empty();
    m_allowed_domains = domains;
}
// Drops the allow-list and turns domain restrictions off, so any host that
// passes the other checks is accepted.
void HttpValidator::ClearDomainRestrictions() {
    m_domain_restrictions_enabled = false;
    m_allowed_domains.clear();
}
// Validates a URL against the SSRF policy.
//
// Checks, in order: the URL parses, the scheme is https, the host is not a
// localhost or cloud-metadata name, an IP-literal host is not in a blocked
// range, a non-IP host does not look like a numeric address, and (when
// restrictions are enabled) the host is on the domain allow-list.
//
// @param url    URL to validate.
// @param error  Set to a human-readable reason when validation fails.
// @return The parsed URL on success, std::nullopt on failure.
std::optional<ParsedUrl> HttpValidator::Validate(const std::string& url, std::string& error) {
    auto parsed = ParseUrl(url);
    if (!parsed || parsed->host.empty()) {
        error = "Invalid URL format";
        return std::nullopt;
    }

    // Must be HTTPS
    if (parsed->scheme != "https") {
        error = "HTTPS required, got: " + parsed->scheme;
        return std::nullopt;
    }

    // Check for localhost names
    if (IsLocalhostName(parsed->host)) {
        error = "localhost blocked for security";
        return std::nullopt;
    }

    // Check for metadata hostnames
    if (IsMetadataHostname(parsed->host)) {
        error = "Cloud metadata hostname blocked for security";
        return std::nullopt;
    }

    // True when the last label of the host is a decimal or 0x-prefixed hex
    // number. Resolvers (inet_aton-style) and URL libraries interpret such
    // hosts as IPv4 addresses ("2130706433", "0x7f.0.0.1", "127.1",
    // "127.0.0.1."), so they must never be treated as ordinary domain names.
    const auto ends_in_number = [](std::string name) {
        while (!name.empty() && name.back() == '.') name.pop_back();
        const size_t dot = name.rfind('.');
        const std::string label = (dot == std::string::npos) ? name : name.substr(dot + 1);
        if (label.empty()) return false;
        const auto is_digit = [](unsigned char c) { return std::isdigit(c) != 0; };
        const auto is_hex = [](unsigned char c) { return std::isxdigit(c) != 0; };
        if (std::all_of(label.begin(), label.end(), is_digit)) return true;
        const bool hex_prefix =
            label.size() >= 2 && label[0] == '0' && (label[1] == 'x' || label[1] == 'X');
        return hex_prefix && std::all_of(label.begin() + 2, label.end(), is_hex);
    };

    if (parsed->is_ip_address) {
        // IP literal: block private, loopback, link-local and metadata ranges
        if (IsBlockedIP(parsed->host)) {
            error = "IP address blocked: private, localhost, or metadata endpoint";
            return std::nullopt;
        }
    } else if (ends_in_number(parsed->host)) {
        // Not a canonical dotted quad, yet it would still resolve as an IP.
        error = "Numeric host blocked: not a canonical IPv4 address";
        return std::nullopt;
    }

    // Check domain whitelist
    if (m_domain_restrictions_enabled && !IsDomainAllowed(parsed->host)) {
        error = "Domain not in allowed list: " + parsed->host;
        return std::nullopt;
    }

    return parsed;
}
// Returns true only for a canonical dotted quad: exactly four '.'-separated
// segments, each 1-3 decimal digits with a value in 0..255.
bool HttpValidator::IsIPv4Address(const std::string& host) {
    if (host.empty()) return false;

    size_t separators = 0;
    size_t segment_begin = 0;
    while (true) {
        // A segment runs up to the next dot or the end of the string.
        const size_t segment_end = std::min(host.find('.', segment_begin), host.length());
        const size_t segment_len = segment_end - segment_begin;
        if (segment_len == 0 || segment_len > 3) return false;

        int value = 0;
        for (size_t i = segment_begin; i < segment_end; ++i) {
            const unsigned char ch = static_cast<unsigned char>(host[i]);
            if (!std::isdigit(ch)) return false;
            value = value * 10 + (ch - '0');
        }
        if (value > 255) return false;

        if (segment_end == host.length()) break;
        ++separators;
        segment_begin = segment_end + 1;
    }
    return separators == 3;
}
// Heuristic IPv6 detection: anything in [brackets] counts, as does a raw
// string with at least two colons and at most three dots.
bool HttpValidator::IsIPv6Address(const std::string& host) {
    if (host.length() < 2) return false;

    // URL form: [::1]. The contents are not inspected here.
    const bool bracketed = host.front() == '[' && host.back() == ']';
    if (bracketed) return true;

    // Raw form; up to three dots allows an embedded IPv4 tail.
    const auto colon_count = std::count(host.begin(), host.end(), ':');
    const auto dot_count = std::count(host.begin(), host.end(), '.');
    return !(colon_count < 2 || dot_count > 3);
}
// Returns true when `ip` is a dotted-quad IPv4 address in a range that must
// never be reached from sandboxed code (loopback, private, link-local,
// carrier-grade NAT, benchmarking, multicast, reserved).
//
// Returns false for public addresses and for any string that is not a
// canonical dotted quad (four segments of 1-3 digits, each 0..255).
bool HttpValidator::IsPrivateIPv4(const std::string& ip) {
    // Strict parse: digits and dots only, so signs, whitespace and trailing
    // garbage are not accepted.
    std::array<int, 4> octets{};
    size_t index = 0;
    size_t digits = 0;
    for (const char c : ip) {
        if (c == '.') {
            if (digits == 0 || index == 3) return false;
            ++index;
            digits = 0;
        } else if (c >= '0' && c <= '9') {
            if (++digits > 3) return false;
            octets[index] = octets[index] * 10 + (c - '0');
        } else {
            return false;
        }
    }
    if (index != 3 || digits == 0) return false;
    for (const int octet : octets) {
        if (octet > 255) return false;
    }

    const int a = octets[0];
    const int b = octets[1];

    // 0.0.0.0/8 - "this network"; 0.0.0.0 itself reaches local listeners
    if (a == 0) return true;
    // 127.0.0.0/8 - loopback
    if (a == 127) return true;
    // 10.0.0.0/8 - private Class A
    if (a == 10) return true;
    // 100.64.0.0/10 - carrier-grade NAT shared address space
    if (a == 100 && b >= 64 && b <= 127) return true;
    // 172.16.0.0/12 - private Class B (172.16.0.0 - 172.31.255.255)
    if (a == 172 && b >= 16 && b <= 31) return true;
    // 192.168.0.0/16 - private Class C
    if (a == 192 && b == 168) return true;
    // 169.254.0.0/16 - link-local (includes the 169.254.169.254 metadata IP)
    if (a == 169 && b == 254) return true;
    // 198.18.0.0/15 - network benchmarking
    if (a == 198 && (b == 18 || b == 19)) return true;
    // 224.0.0.0/3 - multicast, reserved and limited broadcast
    if (a >= 224) return true;

    return false;
}
bool HttpValidator::IsPrivateIPv6(const std::string& ip) {
std::string addr = ip;
// Remove brackets if present
if (!addr.empty() && addr.front() == '[') addr = addr.substr(1);
if (!addr.empty() && addr.back() == ']') addr.pop_back();
// Convert to lowercase for comparison
std::transform(addr.begin(), addr.end(), addr.begin(),
[](unsigned char c) { return std::tolower(c); });
// ::1 - loopback
if (addr == "::1" || addr == "0:0:0:0:0:0:0:1") {
return true;
}
// :: - unspecified (equivalent to 0.0.0.0)
if (addr == "::" || addr == "0:0:0:0:0:0:0:0") {
return true;
}
// fc00::/7 - unique local addresses (fc00:: to fdff::)
if (addr.length() >= 2) {
char first = addr[0];
char second = addr.length() > 1 ? addr[1] : '0';
if (first == 'f' && (second == 'c' || second == 'd')) {
return true;
}
}
// fe80::/10 - link-local
if (addr.rfind("fe80:", 0) == 0 || addr.rfind("fe8", 0) == 0 ||
addr.rfind("fe9", 0) == 0 || addr.rfind("fea", 0) == 0 ||
addr.rfind("feb", 0) == 0) {
return true;
}
return false;
}
// Returns true for an IPv4 address in 127.0.0.0/8 or for the IPv6 loopback
// literal written as "::1" or "0:0:0:0:0:0:0:1" (brackets optional).
bool HttpValidator::IsLocalhostIP(const std::string& host) {
    if (IsIPv4Address(host)) {
        // Already validated as a dotted quad, so the text before the first
        // dot is a 1-3 digit decimal octet.
        const int first_octet = std::stoi(host.substr(0, host.find('.')));
        return first_octet == 127;
    }

    // IPv6: strip optional brackets and compare case-insensitively.
    std::string literal = host;
    if (!literal.empty() && literal.front() == '[') literal.erase(0, 1);
    if (!literal.empty() && literal.back() == ']') literal.pop_back();
    for (char& ch : literal) {
        ch = static_cast<char>(std::tolower(static_cast<unsigned char>(ch)));
    }
    if (literal == "::1") return true;
    return literal == "0:0:0:0:0:0:0:1";
}
// Returns true when `host` is exactly the cloud metadata address
// 169.254.169.254 or the GCP alias "metadata.google.internal".
// The comparison is case-sensitive.
bool HttpValidator::IsMetadataIP(const std::string& host) {
    const bool is_metadata_address = (host == "169.254.169.254");
    const bool is_gcp_alias = (host == "metadata.google.internal");
    return is_metadata_address || is_gcp_alias;
}
// Decides whether an IP-literal host must be rejected.
// IPv4 literals are checked against the private ranges and the metadata
// address; IPv6 literals against the private IPv6 ranges. Anything that is
// neither is reported as not blocked.
bool HttpValidator::IsBlockedIP(const std::string& host) {
    if (IsIPv4Address(host)) {
        return IsPrivateIPv4(host) || IsMetadataIP(host);
    }
    return IsIPv6Address(host) && IsPrivateIPv6(host);
}
// Returns true when `host` is on the allow-list, either as an exact match or
// as a subdomain of an allowed domain. Comparison is case-insensitive.
// Always true while restrictions are disabled.
bool HttpValidator::IsDomainAllowed(const std::string& host) {
    if (!m_domain_restrictions_enabled) {
        return true;
    }

    const auto to_lower = [](std::string text) {
        std::transform(text.begin(), text.end(), text.begin(),
                       [](unsigned char c) { return std::tolower(c); });
        return text;
    };

    const std::string needle = to_lower(host);
    return std::any_of(
        m_allowed_domains.begin(), m_allowed_domains.end(),
        [&](const std::string& entry) {
            const std::string allowed = to_lower(entry);
            if (needle == allowed) {
                return true;
            }
            // Subdomain match: the host must end with ".<allowed>", so
            // "api.example.com" matches "example.com" but "badexample.com"
            // does not.
            if (needle.length() <= allowed.length()) {
                return false;
            }
            const size_t suffix_at = needle.length() - allowed.length();
            return needle[suffix_at - 1] == '.' &&
                   needle.compare(suffix_at, std::string::npos, allowed) == 0;
        });
}
// Returns true when `host` names the local machine: "localhost",
// "localhost.localdomain", or any name under the ".localhost" domain.
// The comparison is case-insensitive and ignores trailing dots, so the
// fully-qualified spelling "localhost." cannot slip past the check.
bool HttpValidator::IsLocalhostName(const std::string& host) {
    std::string lower = host;
    std::transform(lower.begin(), lower.end(), lower.begin(),
                   [](unsigned char c) { return std::tolower(c); });

    // "localhost." resolves exactly like "localhost"; normalize it away.
    while (!lower.empty() && lower.back() == '.') {
        lower.pop_back();
    }

    // Common localhost names
    if (lower == "localhost") return true;
    if (lower == "localhost.localdomain") return true;

    // Ends with .localhost
    const std::string suffix = ".localhost";
    if (lower.length() > suffix.length() &&
        lower.compare(lower.length() - suffix.length(), suffix.length(), suffix) == 0) {
        return true;
    }
    return false;
}
// Returns true when `host` is a known cloud metadata hostname
// ("metadata.google.internal", "metadata", "metadata.azure.internal").
// The comparison is case-insensitive and ignores trailing dots, so the
// fully-qualified spelling "metadata.google.internal." is also caught.
bool HttpValidator::IsMetadataHostname(const std::string& host) {
    std::string lower = host;
    std::transform(lower.begin(), lower.end(), lower.begin(),
                   [](unsigned char c) { return std::tolower(c); });

    // A trailing dot resolves to the same host; normalize it away.
    while (!lower.empty() && lower.back() == '.') {
        lower.pop_back();
    }

    // GCP metadata
    if (lower == "metadata.google.internal") return true;
    if (lower == "metadata") return true;
    // Azure metadata
    if (lower == "metadata.azure.internal") return true;
    return false;
}
// Splits a URL of the form scheme://host[:port][/path][?query] into its parts.
//
// The parser is deliberately strict, because its output drives SSRF checks:
//  - the authority ends at the first '/', '?' or '#', so "https://127.0.0.1?x"
//    cannot hide an IP address inside the host string;
//  - userinfo ('@'), backslashes, whitespace and control characters in the
//    authority are rejected, since HTTP clients disagree on how to read them;
//  - an IPv6 literal must be bracketed, and only ":port" may follow it;
//  - a port must be 1-5 decimal digits in the range 1..65535.
//
// @return The parsed URL with the scheme lowercased; the port defaults to 80
//         for http and 443 otherwise. std::nullopt when the URL is malformed.
std::optional<ParsedUrl> HttpValidator::ParseUrl(const std::string& url) {
    ParsedUrl result;
    result.port = 0;  // 0 = "not specified"; replaced by the default below
    result.is_ip_address = false;

    // Find scheme
    const size_t scheme_end = url.find("://");
    if (scheme_end == std::string::npos) {
        return std::nullopt;
    }
    result.scheme = url.substr(0, scheme_end);
    std::transform(result.scheme.begin(), result.scheme.end(), result.scheme.begin(),
                   [](unsigned char c) { return std::tolower(c); });

    // Authority runs from after "://" to the first path/query/fragment delimiter
    const size_t auth_start = scheme_end + 3;
    if (auth_start >= url.length()) {
        return std::nullopt;
    }
    const size_t auth_end = url.find_first_of("/?#", auth_start);
    const std::string authority =
        url.substr(auth_start,
                   auth_end == std::string::npos ? std::string::npos : auth_end - auth_start);
    if (authority.empty()) {
        return std::nullopt;
    }
    for (const char ch : authority) {
        const unsigned char c = static_cast<unsigned char>(ch);
        if (c <= 0x20 || c == 0x7f || c == '@' || c == '\\') {
            return std::nullopt;
        }
    }

    // Path and query
    result.path = "/";
    if (auth_end != std::string::npos) {
        if (url[auth_end] == '/') {
            const size_t query_start = url.find('?', auth_end);
            if (query_start != std::string::npos) {
                result.path = url.substr(auth_end, query_start - auth_end);
                result.query = url.substr(query_start);
            } else {
                result.path = url.substr(auth_end);
            }
        } else if (url[auth_end] == '?') {
            result.query = url.substr(auth_end);
        }
        // '#': fragment only; the path stays "/" and there is no query
    }

    // Strict decimal port parser: rejects "", "80abc", "-1", "0", "99999".
    const auto parse_port = [](const std::string& text, uint16_t& out) -> bool {
        if (text.empty() || text.length() > 5) return false;
        unsigned value = 0;
        for (const char ch : text) {
            const unsigned char c = static_cast<unsigned char>(ch);
            if (!std::isdigit(c)) return false;
            value = value * 10 + static_cast<unsigned>(c - '0');
        }
        if (value == 0 || value > 65535) return false;
        out = static_cast<uint16_t>(value);
        return true;
    };

    if (authority.front() == '[') {
        // IPv6 literal in brackets, optionally followed by ":port"
        const size_t bracket_end = authority.find(']');
        if (bracket_end == std::string::npos) {
            return std::nullopt;  // Malformed IPv6
        }
        result.host = authority.substr(0, bracket_end + 1);
        result.is_ip_address = true;
        const std::string rest = authority.substr(bracket_end + 1);
        if (!rest.empty()) {
            if (rest.front() != ':' || !parse_port(rest.substr(1), result.port)) {
                return std::nullopt;
            }
        }
    } else {
        // Regular host or IPv4, optionally followed by ":port"
        const size_t port_pos = authority.rfind(':');
        if (port_pos != std::string::npos) {
            result.host = authority.substr(0, port_pos);
            if (!parse_port(authority.substr(port_pos + 1), result.port)) {
                return std::nullopt;
            }
        } else {
            result.host = authority;
        }
        // An unbracketed host may not be empty or contain further colons.
        if (result.host.empty() || result.host.find(':') != std::string::npos) {
            return std::nullopt;
        }
        result.is_ip_address = IsIPv4Address(result.host);
    }

    // Default port based on scheme
    if (result.port == 0) {
        result.port = (result.scheme == "http") ? 80 : 443;
    }
    return result;
}
} // namespace mosis

View File

@@ -0,0 +1,55 @@
#pragma once
#include <string>
#include <vector>
#include <optional>
#include <cstdint>
namespace mosis {
// Components of a URL as produced by HttpValidator::ParseUrl.
// Every field has a defined default, so a default-constructed ParsedUrl never
// exposes indeterminate values.
struct ParsedUrl {
    std::string scheme;          // lowercased scheme, e.g. "https"
    std::string host;            // "api.example.com", "192.0.2.1" or a bracketed IPv6 literal
    uint16_t port = 0;           // e.g. 443; 0 until a parser fills it in
    std::string path;            // "/api/data"
    std::string query;           // "?key=value" (includes the leading '?')
    bool is_ip_address = false;  // true if host is an IP literal
};
// Validates outbound URLs before a request is made: requires https, rejects
// localhost/metadata hostnames and blocked IP literals, and optionally
// restricts hosts to an allow-list of domains.
class HttpValidator {
public:
// Creates a validator with no domain restrictions.
HttpValidator();
// Set allowed domains (from app manifest).
// An empty list disables domain restrictions.
void SetAllowedDomains(const std::vector<std::string>& domains);
// Clear domain restrictions (for testing)
void ClearDomainRestrictions();
// Validate URL
// Returns parsed URL on success; on failure returns std::nullopt and sets
// `error` to a human-readable reason.
std::optional<ParsedUrl> Validate(const std::string& url, std::string& error);
private:
// Allow-list of domains; consulted only while restrictions are enabled.
std::vector<std::string> m_allowed_domains;
// True while the allow-list is enforced.
bool m_domain_restrictions_enabled;
// IP address validation
// Syntax checks: does `host` look like an IPv4 / IPv6 literal?
bool IsIPv4Address(const std::string& host);
bool IsIPv6Address(const std::string& host);
// Range checks: is the literal in a private/loopback/link-local range?
bool IsPrivateIPv4(const std::string& ip);
bool IsPrivateIPv6(const std::string& ip);
// Loopback literal check (127.x.x.x or ::1).
bool IsLocalhostIP(const std::string& host);
// Cloud metadata endpoint check.
bool IsMetadataIP(const std::string& host);
// Combined decision used by Validate() for IP-literal hosts.
bool IsBlockedIP(const std::string& host);
// Domain validation
// Exact or subdomain match against the allow-list (case-insensitive).
bool IsDomainAllowed(const std::string& host);
// Hostname checks for localhost and cloud metadata names.
bool IsLocalhostName(const std::string& host);
bool IsMetadataHostname(const std::string& host);
// URL parsing
// Splits a URL into scheme, host, port, path and query; std::nullopt if malformed.
std::optional<ParsedUrl> ParseUrl(const std::string& url);
};
} // namespace mosis

View File

@@ -0,0 +1,249 @@
#include "network_manager.h"
#include <lua.hpp>
#include <algorithm>
namespace mosis {
NetworkManager::NetworkManager(const std::string& app_id, const NetworkLimits& limits)
: m_app_id(app_id)
, m_limits(limits)
, m_mock_mode(true)
{
}
// Nothing to release explicitly; members clean up after themselves.
NetworkManager::~NetworkManager() = default;
// Forwards the domain allow-list to the embedded HttpValidator.
void NetworkManager::SetAllowedDomains(const std::vector<std::string>& domains) {
m_validator.SetAllowedDomains(domains);
}
// Removes the domain allow-list from the embedded HttpValidator.
void NetworkManager::ClearDomainRestrictions() {
m_validator.ClearDomainRestrictions();
}
bool NetworkManager::ValidateRequest(const HttpRequest& request, std::string& error) {
// Validate URL
auto parsed = m_validator.Validate(request.url, error);
if (!parsed) {
return false;
}
// Validate method
std::string method = request.method;
std::transform(method.begin(), method.end(), method.begin(),
[](unsigned char c) { return std::toupper(c); });
static const std::vector<std::string> allowed_methods = {
"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"
};
bool method_valid = false;
for (const auto& m : allowed_methods) {
if (method == m) {
method_valid = true;
break;
}
}
if (!method_valid) {
error = "Invalid HTTP method: " + request.method;
return false;
}
// Validate request body size
if (request.body.size() > m_limits.max_request_body) {
error = "Request body too large: " + std::to_string(request.body.size()) +
" bytes (max " + std::to_string(m_limits.max_request_body) + ")";
return false;
}
// Validate timeout
if (request.timeout_ms > m_limits.max_timeout_ms) {
error = "Timeout too large: " + std::to_string(request.timeout_ms) +
"ms (max " + std::to_string(m_limits.max_timeout_ms) + "ms)";
return false;
}
// Check concurrent request limit
if (m_active_requests.load() >= m_limits.max_concurrent_requests) {
error = "Too many concurrent requests (max " +
std::to_string(m_limits.max_concurrent_requests) + ")";
return false;
}
return true;
}
// Validates and (eventually) performs a synchronous HTTP request.
// At present every path ends in an error: validation failure, mock mode, or
// the missing platform implementation. The same message is written to
// `error` and to the returned response's `error` field.
HttpResponse NetworkManager::Request(const HttpRequest& request, std::string& error) {
    HttpResponse response;

    const bool validated = ValidateRequest(request, error);
    if (validated && m_mock_mode) {
        // Mock mode exercises validation only; no network call is made.
        error = "Network requests disabled in mock mode";
    } else if (validated) {
        // Count the request as active for the duration of the (future) call.
        m_active_requests++;
        // In a real implementation, the HTTP request would happen here.
        error = "Network requests not implemented on this platform";
        m_active_requests--;
    }

    response.error = error;
    return response;
}
// Returns the number of requests currently counted as in flight.
int NetworkManager::GetActiveRequestCount() const {
    const int in_flight = m_active_requests;  // atomic load via implicit conversion
    return in_flight;
}
// Lua API implementation
// Fetches the NetworkManager stored as the first upvalue of the running
// C closure. lua_touserdata yields NULL when that upvalue is not userdata.
static NetworkManager* GetManager(lua_State* L) {
    void* const raw = lua_touserdata(L, lua_upvalueindex(1));
    return static_cast<NetworkManager*>(raw);
}
// network.request(options) -> response, error
//
// Lua binding for NetworkManager::Request. `options` is a table with:
//   url     (string, required)
//   method  (string, optional, default "GET")
//   headers (table of string -> string, optional)
//   body    (string, optional; may contain embedded NUL bytes)
//   timeout (number, optional, milliseconds)
//
// Returns a response table {status, body, headers[, error]} on success, or
// nil plus an error message on failure.
static int L_network_request(lua_State* L) {
    NetworkManager* manager = GetManager(L);
    if (!manager) {
        lua_pushnil(L);
        lua_pushstring(L, "NetworkManager not available");
        return 2;
    }

    // Expect table argument
    luaL_checktype(L, 1, LUA_TTABLE);
    HttpRequest request;

    // Get URL (required)
    lua_getfield(L, 1, "url");
    if (!lua_isstring(L, -1)) {
        lua_pushnil(L);
        lua_pushstring(L, "url is required and must be a string");
        return 2;
    }
    request.url = lua_tostring(L, -1);
    lua_pop(L, 1);

    // Get method (optional, default GET)
    lua_getfield(L, 1, "method");
    if (lua_isstring(L, -1)) {
        request.method = lua_tostring(L, -1);
    }
    lua_pop(L, 1);

    // Get headers (optional)
    lua_getfield(L, 1, "headers");
    if (lua_istable(L, -1)) {
        lua_pushnil(L);
        while (lua_next(L, -2) != 0) {
            // Only real string keys are accepted. lua_isstring() is also true
            // for numbers, and lua_tostring() on a numeric key would convert
            // the key in place and confuse the next lua_next() call.
            if (lua_type(L, -2) == LUA_TSTRING && lua_isstring(L, -1)) {
                request.headers[lua_tostring(L, -2)] = lua_tostring(L, -1);
            }
            lua_pop(L, 1);  // pop value, keep key for the next iteration
        }
    }
    lua_pop(L, 1);

    // Get body (optional); length-aware so binary payloads survive
    lua_getfield(L, 1, "body");
    if (lua_isstring(L, -1)) {
        size_t len = 0;
        const char* body = lua_tolstring(L, -1, &len);
        request.body = std::string(body, len);
    }
    lua_pop(L, 1);

    // Get timeout (optional)
    lua_getfield(L, 1, "timeout");
    if (lua_isnumber(L, -1)) {
        request.timeout_ms = static_cast<int>(lua_tointeger(L, -1));
    }
    lua_pop(L, 1);

    // Make request
    std::string error;
    HttpResponse response = manager->Request(request, error);
    if (!error.empty()) {
        lua_pushnil(L);
        lua_pushstring(L, error.c_str());
        return 2;
    }

    // Return response as table
    lua_newtable(L);
    lua_pushinteger(L, response.status_code);
    lua_setfield(L, -2, "status");
    // lua_pushlstring keeps embedded NUL bytes; lua_pushstring would truncate.
    lua_pushlstring(L, response.body.data(), response.body.size());
    lua_setfield(L, -2, "body");

    // Headers table
    lua_newtable(L);
    for (const auto& [key, value] : response.headers) {
        lua_pushstring(L, value.c_str());
        lua_setfield(L, -2, key.c_str());
    }
    lua_setfield(L, -2, "headers");

    if (!response.error.empty()) {
        lua_pushstring(L, response.error.c_str());
        lua_setfield(L, -2, "error");
    }
    return 1;  // Return response table
}
// Helper to set a global in the real _G (bypassing any proxy).
// Expects the value to assign on top of the stack and pops it before
// returning. If the globals table has a metatable whose __index is a table,
// the value is stored in that __index table; otherwise it is stored in the
// globals table itself.
static void SetGlobalInRealG(lua_State* L, const char* name) {
// Stack: value to set as global
lua_rawgeti(L, LUA_REGISTRYINDEX, LUA_RIDX_GLOBALS);
// Stack: value, globals
// Check if it's a proxy with __index pointing to real _G
if (lua_getmetatable(L, -1)) {
// Stack: value, globals, metatable
lua_getfield(L, -1, "__index");
// Stack: value, globals, metatable, __index
if (lua_istable(L, -1)) {
// This is the real _G, set our value there
lua_pushvalue(L, -4); // Push a copy of the value (index -4 is the original)
lua_setfield(L, -2, name); // __index[name] = copy; pops the copy
lua_pop(L, 4); // Pop __index, metatable, globals and the original value
return;
}
lua_pop(L, 2); // Pop __index and metatable
}
// No proxy, set directly
// Stack: value, globals
lua_pushvalue(L, -2); // Push a copy of the value
lua_setfield(L, -2, name); // globals[name] = copy; pops the copy
lua_pop(L, 2); // Pop globals table and original value
}
// Installs the global `network` table with a single function, `request`.
// `manager` is stored as a light userdata upvalue of the closure, so Lua does
// not own it; it has to stay alive for as long as scripts can call
// network.request.
void RegisterNetworkAPI(lua_State* L, NetworkManager* manager) {
// Create network table
lua_newtable(L);
// Add request function with manager as upvalue
lua_pushlightuserdata(L, manager);
lua_pushcclosure(L, L_network_request, 1);
lua_setfield(L, -2, "request");
// Set as global (consumes the table left on the stack)
SetGlobalInRealG(L, "network");
}
} // namespace mosis

View File

@@ -0,0 +1,76 @@
#pragma once
#include <string>
#include <vector>
#include <map>
#include <mutex>
#include <atomic>
#include "http_validator.h"
struct lua_State;
namespace mosis {
// Description of an outgoing HTTP request.
struct HttpRequest {
    std::string url;                             // absolute URL to request
    std::string method{"GET"};                   // HTTP method; defaults to GET
    std::map<std::string, std::string> headers;  // request headers, name -> value
    std::string body;                            // request payload; empty by default
    int timeout_ms{30000};                       // timeout in milliseconds
};
// Result of an HTTP request.
struct HttpResponse {
    int status_code{0};                          // HTTP status; 0 by default
    std::map<std::string, std::string> headers;  // response headers, name -> value
    std::string body;                            // response payload
    std::string error;                           // failure reason; empty by default
};
// Per-app limits applied to network requests.
struct NetworkLimits {
    size_t max_request_body{10 * 1024 * 1024};   // 10 MB
    size_t max_response_body{50 * 1024 * 1024};  // 50 MB
    int max_timeout_ms{60000};                   // 60 seconds
    int max_concurrent_requests{6};
    int default_timeout_ms{30000};               // 30 seconds
};
// Per-app front end for HTTP requests: validates each request against the
// URL policy (HttpValidator) and the configured NetworkLimits before it would
// be sent.
class NetworkManager {
public:
// Creates a manager for `app_id` with the given limits (defaults apply when
// omitted). Mock mode starts enabled.
NetworkManager(const std::string& app_id, const NetworkLimits& limits = NetworkLimits{});
~NetworkManager();
// Configure domain restrictions
void SetAllowedDomains(const std::vector<std::string>& domains);
void ClearDomainRestrictions();
// Synchronous request
// In test mode, validates but doesn't actually make network calls.
// On failure the reason is written to both `error` and the response's
// `error` field.
HttpResponse Request(const HttpRequest& request, std::string& error);
// Stats
int GetActiveRequestCount() const;
// Access validator for testing
HttpValidator& GetValidator() { return m_validator; }
const HttpValidator& GetValidator() const { return m_validator; }
// For testing: set mock mode (no actual network calls)
void SetMockMode(bool enabled) { m_mock_mode = enabled; }
bool IsMockMode() const { return m_mock_mode; }
private:
std::string m_app_id; // identifier of the owning app
NetworkLimits m_limits; // limits enforced by ValidateRequest()
HttpValidator m_validator; // URL policy checks
std::atomic<int> m_active_requests{0}; // requests currently in flight
// NOTE(review): not referenced by the member functions visible in this
// change; presumably reserved for the real network implementation.
std::mutex m_mutex;
bool m_mock_mode = true; // Default to mock mode for tests
// Validate request before sending
bool ValidateRequest(const HttpRequest& request, std::string& error);
};
// Register network.* APIs as globals.
// Installs a global `network` table whose `request` function calls into
// `manager`. The pointer is captured as a light userdata upvalue, so the
// manager is not owned by Lua and has to stay alive while scripts can call it.
void RegisterNetworkAPI(lua_State* L, NetworkManager* manager);
} // namespace mosis