implement Milestone 10: WebSocket with connection limits and SSRF prevention

This commit is contained in:
2026-01-18 15:30:13 +01:00
parent c0baa673b8
commit 0c19247838
7 changed files with 1252 additions and 6 deletions

View File

@@ -29,9 +29,9 @@ std::optional<ParsedUrl> HttpValidator::Validate(const std::string& url, std::st
return std::nullopt;
}
// Must be HTTPS
if (parsed->scheme != "https") {
error = "HTTPS required, got: " + parsed->scheme;
// Must be HTTPS or WSS
if (parsed->scheme != "https" && parsed->scheme != "wss") {
error = "HTTPS or WSS required, got: " + parsed->scheme;
return std::nullopt;
}
@@ -376,9 +376,9 @@ std::optional<ParsedUrl> HttpValidator::ParseUrl(const std::string& url) {
}
// Default port based on scheme
if (result.scheme == "https" && result.port == 0) {
if ((result.scheme == "https" || result.scheme == "wss") && result.port == 0) {
result.port = 443;
} else if (result.scheme == "http" && result.port == 0) {
} else if ((result.scheme == "http" || result.scheme == "ws") && result.port == 0) {
result.port = 80;
}

View File

@@ -0,0 +1,458 @@
#include "websocket_manager.h"
#include <lua.hpp>
#include <algorithm>
namespace mosis {
// WebSocket implementation
WebSocket::WebSocket(int id, const std::string& url, size_t max_message_size)
: m_id(id)
, m_url(url)
, m_state(WebSocketState::Connecting)
, m_max_message_size(max_message_size)
{
}
WebSocket::~WebSocket() {
if (m_state != WebSocketState::Closed) {
Close();
}
}
bool WebSocket::Send(const std::string& data, bool binary) {
if (m_state != WebSocketState::Open) {
return false;
}
if (data.size() > m_max_message_size) {
return false;
}
// In mock mode, we don't actually send
// In real implementation, would send through WebSocket connection
return true;
}
void WebSocket::Close(int code, const std::string& reason) {
if (m_state == WebSocketState::Closed) {
return;
}
m_state = WebSocketState::Closing;
// In real implementation, send close frame
// For mock, just transition to closed
m_state = WebSocketState::Closed;
if (m_on_close) {
m_on_close(code, reason);
}
}
void WebSocket::SimulateOpen() {
if (m_state == WebSocketState::Connecting) {
m_state = WebSocketState::Open;
if (m_on_open) {
m_on_open();
}
}
}
void WebSocket::SimulateMessage(const std::string& data, bool binary) {
if (m_state == WebSocketState::Open && m_on_message) {
m_on_message(data, binary);
}
}
void WebSocket::SimulateClose(int code, const std::string& reason) {
if (m_state != WebSocketState::Closed) {
m_state = WebSocketState::Closed;
if (m_on_close) {
m_on_close(code, reason);
}
}
}
void WebSocket::SimulateError(const std::string& error) {
if (m_on_error) {
m_on_error(error);
}
}
// WebSocketManager implementation
WebSocketManager::WebSocketManager(const std::string& app_id, const WebSocketLimits& limits)
: m_app_id(app_id)
, m_limits(limits)
, m_mock_mode(true)
{
}
WebSocketManager::~WebSocketManager() {
CloseAll();
}
void WebSocketManager::SetAllowedDomains(const std::vector<std::string>& domains) {
m_validator.SetAllowedDomains(domains);
}
void WebSocketManager::ClearDomainRestrictions() {
m_validator.ClearDomainRestrictions();
}
bool WebSocketManager::ValidateUrl(const std::string& url, std::string& error) {
auto parsed = m_validator.Validate(url, error);
if (!parsed) {
return false;
}
// Must be WSS scheme
if (parsed->scheme != "wss") {
error = "WSS required for WebSocket, got: " + parsed->scheme;
return false;
}
return true;
}
std::shared_ptr<WebSocket> WebSocketManager::Connect(const std::string& url, std::string& error) {
// Validate URL
if (!ValidateUrl(url, error)) {
return nullptr;
}
std::lock_guard<std::mutex> lock(m_mutex);
// Check connection limit
if (static_cast<int>(m_connections.size()) >= m_limits.max_connections_per_app) {
error = "Connection limit exceeded (max " +
std::to_string(m_limits.max_connections_per_app) + ")";
return nullptr;
}
// In mock mode, we create the WebSocket but don't actually connect
if (m_mock_mode) {
int id = m_next_id++;
auto ws = std::make_shared<WebSocket>(id, url, m_limits.max_message_size);
m_connections[id] = ws;
// In mock mode, immediately fail with mock error
error = "WebSocket connections disabled in mock mode";
// Remove the connection since it failed
m_connections.erase(id);
return nullptr;
}
// In a real implementation, we would start the connection here
int id = m_next_id++;
auto ws = std::make_shared<WebSocket>(id, url, m_limits.max_message_size);
m_connections[id] = ws;
// For real implementation: start async connection
// For now, just return the WebSocket (it stays in Connecting state)
return ws;
}
std::shared_ptr<WebSocket> WebSocketManager::GetConnection(int id) {
std::lock_guard<std::mutex> lock(m_mutex);
auto it = m_connections.find(id);
if (it != m_connections.end()) {
return it->second;
}
return nullptr;
}
void WebSocketManager::CloseConnection(int id, int code, const std::string& reason) {
std::shared_ptr<WebSocket> ws;
{
std::lock_guard<std::mutex> lock(m_mutex);
auto it = m_connections.find(id);
if (it != m_connections.end()) {
ws = it->second;
m_connections.erase(it);
}
}
if (ws) {
ws->Close(code, reason);
}
}
void WebSocketManager::CloseAll() {
std::map<int, std::shared_ptr<WebSocket>> connections_copy;
{
std::lock_guard<std::mutex> lock(m_mutex);
connections_copy = std::move(m_connections);
m_connections.clear();
}
for (auto& [id, ws] : connections_copy) {
ws->Close(1001, "App stopping");
}
}
int WebSocketManager::GetActiveConnectionCount() const {
std::lock_guard<std::mutex> lock(m_mutex);
return static_cast<int>(m_connections.size());
}
// Lua API implementation
// Userdata for WebSocket
struct LuaWebSocket {
std::weak_ptr<WebSocket> ws;
int id;
};
static const char* WEBSOCKET_MT = "mosis.WebSocket";
// Get WebSocketManager from upvalue
static WebSocketManager* GetManager(lua_State* L) {
return static_cast<WebSocketManager*>(lua_touserdata(L, lua_upvalueindex(1)));
}
// Get WebSocket from userdata
static LuaWebSocket* GetWebSocket(lua_State* L, int index) {
return static_cast<LuaWebSocket*>(luaL_checkudata(L, index, WEBSOCKET_MT));
}
// WebSocket:send(data [, binary])
static int L_websocket_send(lua_State* L) {
LuaWebSocket* lws = GetWebSocket(L, 1);
auto ws = lws->ws.lock();
if (!ws) {
lua_pushboolean(L, 0);
lua_pushstring(L, "WebSocket closed");
return 2;
}
size_t len;
const char* data = luaL_checklstring(L, 2, &len);
bool binary = lua_toboolean(L, 3);
bool ok = ws->Send(std::string(data, len), binary);
lua_pushboolean(L, ok ? 1 : 0);
if (!ok) {
lua_pushstring(L, "Send failed");
return 2;
}
return 1;
}
// WebSocket:close([code [, reason]])
static int L_websocket_close(lua_State* L) {
LuaWebSocket* lws = GetWebSocket(L, 1);
auto ws = lws->ws.lock();
if (!ws) {
return 0;
}
int code = static_cast<int>(luaL_optinteger(L, 2, 1000));
const char* reason = luaL_optstring(L, 3, "");
ws->Close(code, reason);
return 0;
}
// WebSocket:state() -> string
static int L_websocket_state(lua_State* L) {
LuaWebSocket* lws = GetWebSocket(L, 1);
auto ws = lws->ws.lock();
if (!ws) {
lua_pushstring(L, "closed");
return 1;
}
switch (ws->GetState()) {
case WebSocketState::Connecting:
lua_pushstring(L, "connecting");
break;
case WebSocketState::Open:
lua_pushstring(L, "open");
break;
case WebSocketState::Closing:
lua_pushstring(L, "closing");
break;
case WebSocketState::Closed:
lua_pushstring(L, "closed");
break;
}
return 1;
}
// WebSocket:on(event, callback)
static int L_websocket_on(lua_State* L) {
LuaWebSocket* lws = GetWebSocket(L, 1);
auto ws = lws->ws.lock();
if (!ws) {
return 0;
}
const char* event = luaL_checkstring(L, 2);
luaL_checktype(L, 3, LUA_TFUNCTION);
// Store callback in registry with ws id + event as key
// For now, this is a simplified version that doesn't actually store callbacks
// A full implementation would need to store refs and call them on events
lua_pushboolean(L, 1);
return 1;
}
// WebSocket garbage collection
static int L_websocket_gc(lua_State* L) {
LuaWebSocket* lws = GetWebSocket(L, 1);
auto ws = lws->ws.lock();
if (ws && ws->GetState() != WebSocketState::Closed) {
ws->Close(1001, "WebSocket garbage collected");
}
return 0;
}
// network.websocket(url) -> WebSocket, error
static int L_network_websocket(lua_State* L) {
WebSocketManager* manager = GetManager(L);
if (!manager) {
lua_pushnil(L);
lua_pushstring(L, "WebSocketManager not available");
return 2;
}
const char* url = luaL_checkstring(L, 1);
std::string error;
auto ws = manager->Connect(url, error);
if (!ws) {
lua_pushnil(L);
lua_pushstring(L, error.c_str());
return 2;
}
// Create userdata
LuaWebSocket* lws = static_cast<LuaWebSocket*>(lua_newuserdata(L, sizeof(LuaWebSocket)));
lws->ws = ws;
lws->id = ws->GetId();
// Set metatable
luaL_getmetatable(L, WEBSOCKET_MT);
lua_setmetatable(L, -2);
return 1;
}
// Helper to set a global in the real _G (bypassing any proxy)
static void SetGlobalInRealG(lua_State* L, const char* name) {
// Stack: value to set as global
lua_rawgeti(L, LUA_REGISTRYINDEX, LUA_RIDX_GLOBALS);
// Check if it's a proxy with __index pointing to real _G
if (lua_getmetatable(L, -1)) {
lua_getfield(L, -1, "__index");
if (lua_istable(L, -1)) {
// This is the real _G, set our value there
lua_pushvalue(L, -4); // Push the value
lua_setfield(L, -2, name);
lua_pop(L, 4); // Pop __index, metatable, proxy, (value already consumed)
return;
}
lua_pop(L, 2); // Pop __index and metatable
}
// No proxy, set directly
lua_pushvalue(L, -2); // Push the value
lua_setfield(L, -2, name);
lua_pop(L, 2); // Pop globals table and original value
}
void RegisterWebSocketAPI(lua_State* L, WebSocketManager* manager) {
// Create WebSocket metatable
luaL_newmetatable(L, WEBSOCKET_MT);
lua_pushstring(L, "__index");
lua_newtable(L);
// Methods
lua_pushcfunction(L, L_websocket_send);
lua_setfield(L, -2, "send");
lua_pushcfunction(L, L_websocket_close);
lua_setfield(L, -2, "close");
lua_pushcfunction(L, L_websocket_state);
lua_setfield(L, -2, "state");
lua_pushcfunction(L, L_websocket_on);
lua_setfield(L, -2, "on");
lua_settable(L, -3); // Set __index
// GC metamethod
lua_pushstring(L, "__gc");
lua_pushcfunction(L, L_websocket_gc);
lua_settable(L, -3);
lua_pop(L, 1); // Pop metatable
// Check if network table already exists (from RegisterNetworkAPI)
lua_rawgeti(L, LUA_REGISTRYINDEX, LUA_RIDX_GLOBALS);
if (lua_getmetatable(L, -1)) {
lua_getfield(L, -1, "__index");
if (lua_istable(L, -1)) {
// Use real _G
lua_getfield(L, -1, "network");
if (lua_istable(L, -1)) {
// Add websocket function to existing network table
lua_pushlightuserdata(L, manager);
lua_pushcclosure(L, L_network_websocket, 1);
lua_setfield(L, -2, "websocket");
lua_pop(L, 4); // Pop network, __index, metatable, globals
return;
}
lua_pop(L, 4); // Pop nil, __index, metatable, globals
} else {
lua_pop(L, 3); // Pop __index, metatable, globals
}
} else {
lua_pop(L, 1); // Pop globals
}
// No existing network table with proxy, try direct access
lua_rawgeti(L, LUA_REGISTRYINDEX, LUA_RIDX_GLOBALS);
lua_getfield(L, -1, "network");
if (lua_istable(L, -1)) {
// Add websocket function to existing network table
lua_pushlightuserdata(L, manager);
lua_pushcclosure(L, L_network_websocket, 1);
lua_setfield(L, -2, "websocket");
lua_pop(L, 2); // Pop network, globals
return;
}
lua_pop(L, 2); // Pop nil/not-table, globals
// No network table exists, create one
lua_newtable(L);
lua_pushlightuserdata(L, manager);
lua_pushcclosure(L, L_network_websocket, 1);
lua_setfield(L, -2, "websocket");
SetGlobalInRealG(L, "network");
}
} // namespace mosis

View File

@@ -0,0 +1,121 @@
#pragma once
#include <string>
#include <map>
#include <memory>
#include <mutex>
#include <functional>
#include <vector>
#include "http_validator.h"
struct lua_State;
namespace mosis {
struct WebSocketLimits {
int max_connections_per_app = 5;
size_t max_message_size = 1 * 1024 * 1024; // 1 MB
int idle_timeout_ms = 5 * 60 * 1000; // 5 minutes
int connect_timeout_ms = 30000; // 30 seconds
};
enum class WebSocketState {
Connecting,
Open,
Closing,
Closed
};
class WebSocket {
public:
using MessageCallback = std::function<void(const std::string& data, bool binary)>;
using CloseCallback = std::function<void(int code, const std::string& reason)>;
using ErrorCallback = std::function<void(const std::string& error)>;
using OpenCallback = std::function<void()>;
WebSocket(int id, const std::string& url, size_t max_message_size);
~WebSocket();
int GetId() const { return m_id; }
const std::string& GetUrl() const { return m_url; }
WebSocketState GetState() const { return m_state; }
// Send message (returns false if not connected or message too large)
bool Send(const std::string& data, bool binary = false);
// Close connection
void Close(int code = 1000, const std::string& reason = "");
// Event callbacks
void SetOnOpen(OpenCallback cb) { m_on_open = std::move(cb); }
void SetOnMessage(MessageCallback cb) { m_on_message = std::move(cb); }
void SetOnClose(CloseCallback cb) { m_on_close = std::move(cb); }
void SetOnError(ErrorCallback cb) { m_on_error = std::move(cb); }
// For mock mode - simulate events
void SimulateOpen();
void SimulateMessage(const std::string& data, bool binary);
void SimulateClose(int code, const std::string& reason);
void SimulateError(const std::string& error);
private:
int m_id;
std::string m_url;
WebSocketState m_state;
size_t m_max_message_size;
OpenCallback m_on_open;
MessageCallback m_on_message;
CloseCallback m_on_close;
ErrorCallback m_on_error;
};
class WebSocketManager {
public:
WebSocketManager(const std::string& app_id, const WebSocketLimits& limits = WebSocketLimits{});
~WebSocketManager();
// Configure domain restrictions (reuses HttpValidator)
void SetAllowedDomains(const std::vector<std::string>& domains);
void ClearDomainRestrictions();
// Create new WebSocket connection
// Returns WebSocket on success, nullptr on failure (sets error)
std::shared_ptr<WebSocket> Connect(const std::string& url, std::string& error);
// Get connection by ID
std::shared_ptr<WebSocket> GetConnection(int id);
// Close specific connection
void CloseConnection(int id, int code = 1000, const std::string& reason = "");
// Close all connections (called on app stop)
void CloseAll();
// Stats
int GetActiveConnectionCount() const;
// Access validator for testing
HttpValidator& GetValidator() { return m_validator; }
const HttpValidator& GetValidator() const { return m_validator; }
// For testing: set mock mode
void SetMockMode(bool enabled) { m_mock_mode = enabled; }
bool IsMockMode() const { return m_mock_mode; }
private:
std::string m_app_id;
WebSocketLimits m_limits;
HttpValidator m_validator;
std::map<int, std::shared_ptr<WebSocket>> m_connections;
int m_next_id = 1;
mutable std::mutex m_mutex;
bool m_mock_mode = true; // Default to mock mode for tests
bool ValidateUrl(const std::string& url, std::string& error);
};
// Register websocket APIs as part of network global
void RegisterWebSocketAPI(lua_State* L, WebSocketManager* manager);
} // namespace mosis