implement Milestone 10: WebSocket with connection limits and SSRF prevention

This commit is contained in:
2026-01-18 15:30:13 +01:00
parent c0baa673b8
commit 0c19247838
7 changed files with 1252 additions and 6 deletions

View File

@@ -23,6 +23,7 @@ add_library(mosis-sandbox STATIC
../src/main/cpp/sandbox/database_manager.cpp
../src/main/cpp/sandbox/http_validator.cpp
../src/main/cpp/sandbox/network_manager.cpp
../src/main/cpp/sandbox/websocket_manager.cpp
)
target_include_directories(mosis-sandbox PUBLIC
../src/main/cpp/sandbox

View File

@@ -13,6 +13,7 @@
#include "database_manager.h"
#include "http_validator.h"
#include "network_manager.h"
#include "websocket_manager.h"
#include <filesystem>
#include <fstream>
#include <sstream>
@@ -1667,6 +1668,177 @@ bool Test_NetworkLuaIntegration(std::string& error_msg) {
return true;
}
//=============================================================================
// MILESTONE 10: WebSocket
//=============================================================================
bool Test_WebSocketUrlValidation(std::string& error_msg) {
mosis::WebSocketManager manager("test.app");
manager.ClearDomainRestrictions();
std::string err;
// WSS should be allowed (will fail in mock mode but validation passes)
auto ws = manager.Connect("wss://example.com/socket", err);
EXPECT_TRUE(err.find("mock") != std::string::npos ||
err.find("disabled") != std::string::npos);
// WS (plain) should be blocked at validation
ws = manager.Connect("ws://example.com/socket", err);
EXPECT_TRUE(ws == nullptr);
EXPECT_TRUE(err.find("WSS") != std::string::npos ||
err.find("HTTPS") != std::string::npos ||
err.find("required") != std::string::npos);
return true;
}
bool Test_WebSocketConnectionLimits(std::string& error_msg) {
mosis::WebSocketLimits limits;
limits.max_connections_per_app = 2;
mosis::WebSocketManager manager("test.app", limits);
manager.ClearDomainRestrictions();
manager.SetMockMode(false); // Disable mock to test connection tracking
std::string err;
// Create max connections
auto ws1 = manager.Connect("wss://example.com/socket1", err);
auto ws2 = manager.Connect("wss://example.com/socket2", err);
// Third should fail
auto ws3 = manager.Connect("wss://example.com/socket3", err);
EXPECT_TRUE(ws3 == nullptr);
EXPECT_TRUE(err.find("limit") != std::string::npos ||
err.find("Connection") != std::string::npos);
return true;
}
bool Test_WebSocketBlocksPrivateIP(std::string& error_msg) {
mosis::WebSocketManager manager("test.app");
manager.ClearDomainRestrictions();
std::string err;
std::vector<std::string> private_urls = {
"wss://127.0.0.1/socket",
"wss://localhost/socket",
"wss://10.0.0.1/socket",
"wss://192.168.1.1/socket",
"wss://169.254.169.254/socket"
};
for (const auto& url : private_urls) {
auto ws = manager.Connect(url, err);
EXPECT_TRUE(ws == nullptr);
}
return true;
}
bool Test_WebSocketDomainWhitelist(std::string& error_msg) {
mosis::WebSocketManager manager("test.app");
manager.SetAllowedDomains({"api.example.com"});
std::string err;
// Allowed domain - should pass validation (may fail in mock mode for other reasons)
auto ws1 = manager.Connect("wss://api.example.com/socket", err);
bool allowed_passed = (ws1 != nullptr) ||
(err.find("mock") != std::string::npos) ||
(err.find("disabled") != std::string::npos);
EXPECT_TRUE(allowed_passed);
// Disallowed domain - should fail validation
auto ws2 = manager.Connect("wss://evil.com/socket", err);
EXPECT_TRUE(ws2 == nullptr);
EXPECT_TRUE(err.find("allowed") != std::string::npos ||
err.find("whitelist") != std::string::npos ||
err.find("Domain") != std::string::npos);
return true;
}
bool Test_WebSocketMessageLimits(std::string& error_msg) {
// Create a WebSocket directly to test send limits
mosis::WebSocket ws(1, "wss://example.com/socket", 1024); // 1 KB limit
// WebSocket is in Connecting state, so send should fail
std::string small_message(512, 'X');
bool send_result = ws.Send(small_message);
EXPECT_FALSE(send_result); // Not connected
// Simulate open
ws.SimulateOpen();
// Now small message should work
send_result = ws.Send(small_message);
EXPECT_TRUE(send_result);
// Large message should fail
std::string large_message(2048, 'X'); // 2 KB
send_result = ws.Send(large_message);
EXPECT_FALSE(send_result);
return true;
}
bool Test_WebSocketCloseAll(std::string& error_msg) {
mosis::WebSocketManager manager("test.app");
manager.ClearDomainRestrictions();
manager.SetMockMode(false); // Disable mock to track connections
std::string err;
// Create some connections
manager.Connect("wss://example.com/socket1", err);
manager.Connect("wss://example.com/socket2", err);
EXPECT_TRUE(manager.GetActiveConnectionCount() == 2);
// Close all
manager.CloseAll();
// Should have no active connections
EXPECT_TRUE(manager.GetActiveConnectionCount() == 0);
return true;
}
bool Test_WebSocketLuaIntegration(std::string& error_msg) {
SandboxContext ctx = TestContext();
LuaSandbox sandbox(ctx);
mosis::WebSocketManager manager("test.app");
manager.ClearDomainRestrictions();
mosis::RegisterWebSocketAPI(sandbox.GetState(), &manager);
std::string script = R"lua(
-- Test that network.websocket exists
if not network then
error("network global not found")
end
if not network.websocket then
error("network.websocket not found")
end
-- Test validation rejection (private IP)
local ws, err = network.websocket("wss://127.0.0.1/socket")
if ws then
error("expected private IP to be blocked")
end
)lua";
bool ok = sandbox.LoadString(script, "websocket_test");
if (!ok) {
error_msg = "Lua test failed: " + sandbox.GetLastError();
return false;
}
return true;
}
//=============================================================================
// MAIN
//=============================================================================
@@ -1799,6 +1971,15 @@ int main(int argc, char* argv[]) {
harness.AddTest("NetworkRequestLimits", Test_NetworkRequestLimits);
harness.AddTest("NetworkLuaIntegration", Test_NetworkLuaIntegration);
// Milestone 10: WebSocket
harness.AddTest("WebSocketUrlValidation", Test_WebSocketUrlValidation);
harness.AddTest("WebSocketConnectionLimits", Test_WebSocketConnectionLimits);
harness.AddTest("WebSocketBlocksPrivateIP", Test_WebSocketBlocksPrivateIP);
harness.AddTest("WebSocketDomainWhitelist", Test_WebSocketDomainWhitelist);
harness.AddTest("WebSocketMessageLimits", Test_WebSocketMessageLimits);
harness.AddTest("WebSocketCloseAll", Test_WebSocketCloseAll);
harness.AddTest("WebSocketLuaIntegration", Test_WebSocketLuaIntegration);
// Run tests
auto results = harness.Run(filter);