implement Milestone 8: SQLite Database with injection prevention

This commit is contained in:
2026-01-18 15:18:47 +01:00
parent 2bb083fd7d
commit a94e0d5d63
7 changed files with 1396 additions and 1 deletions

View File

@@ -0,0 +1,598 @@
#include "database_manager.h"
#include <sqlite3.h>
#include <lua.hpp>
#include <filesystem>
#include <algorithm>
#include <cctype>
namespace fs = std::filesystem;
namespace mosis {
// Helper to set a global in the real _G (bypassing any proxy)
static void SetGlobalInRealG(lua_State* L, const char* name) {
// Stack: value to set as global
// Get _G (might be a proxy)
lua_rawgeti(L, LUA_REGISTRYINDEX, LUA_RIDX_GLOBALS);
// Check if it has a metatable with __index (proxy pattern)
if (lua_getmetatable(L, -1)) {
lua_getfield(L, -1, "__index");
if (lua_istable(L, -1)) {
// Found real _G through proxy's __index
// Stack: value, proxy, mt, real_G
lua_pushvalue(L, -4); // Copy value
lua_setfield(L, -2, name); // real_G[name] = value
lua_pop(L, 4); // pop real_G, mt, proxy, original value
return;
}
lua_pop(L, 2); // pop __index, metatable
}
// No proxy, set directly in _G
// Stack: value, _G
lua_pushvalue(L, -2); // Copy value
lua_setfield(L, -2, name); // _G[name] = value
lua_pop(L, 2); // pop _G, original value
}
// ============================================================================
// DatabaseManager
// ============================================================================
DatabaseManager::DatabaseManager(const std::string& app_id,
const std::string& app_root,
const DatabaseLimits& limits)
: m_app_id(app_id)
, m_app_root(app_root)
, m_limits(limits) {
}
DatabaseManager::~DatabaseManager() {
CloseAll();
}
bool DatabaseManager::ValidateName(const std::string& name, std::string& error) {
if (name.empty()) {
error = "Database name cannot be empty";
return false;
}
if (name.length() > 64) {
error = "Database name too long (max 64 characters)";
return false;
}
// Check for path traversal
if (name.find("..") != std::string::npos) {
error = "Database name contains invalid path traversal";
return false;
}
// Check for path separators
if (name.find('/') != std::string::npos || name.find('\\') != std::string::npos) {
error = "Database name cannot contain path separators";
return false;
}
// Only allow alphanumeric, underscore, hyphen
for (char c : name) {
if (!std::isalnum(static_cast<unsigned char>(c)) && c != '_' && c != '-') {
error = "Database name contains invalid characters (only alphanumeric, underscore, hyphen allowed)";
return false;
}
}
return true;
}
std::string DatabaseManager::ResolvePath(const std::string& name) {
fs::path db_dir = fs::path(m_app_root) / "db";
return (db_dir / (name + ".db")).string();
}
std::shared_ptr<DatabaseHandle> DatabaseManager::Open(const std::string& name, std::string& error) {
// Validate name
if (!ValidateName(name, error)) {
return nullptr;
}
// Check if already open
auto it = m_databases.find(name);
if (it != m_databases.end() && it->second->IsOpen()) {
return it->second;
}
// Check max databases limit
if (m_databases.size() >= static_cast<size_t>(m_limits.max_databases_per_app)) {
error = "Maximum number of open databases reached";
return nullptr;
}
// Resolve path and ensure directory exists
std::string db_path = ResolvePath(name);
fs::path parent = fs::path(db_path).parent_path();
std::error_code ec;
fs::create_directories(parent, ec);
if (ec) {
error = "Failed to create database directory: " + ec.message();
return nullptr;
}
// Open SQLite database
sqlite3* db = nullptr;
int rc = sqlite3_open(db_path.c_str(), &db);
if (rc != SQLITE_OK) {
error = "Failed to open database: " + std::string(sqlite3_errmsg(db));
sqlite3_close(db);
return nullptr;
}
// Create handle
auto handle = std::make_shared<DatabaseHandle>(db, db_path, m_limits);
m_databases[name] = handle;
return handle;
}
void DatabaseManager::CloseAll() {
for (auto& [name, handle] : m_databases) {
if (handle) {
handle->Close();
}
}
m_databases.clear();
}
size_t DatabaseManager::GetOpenDatabaseCount() const {
size_t count = 0;
for (const auto& [name, handle] : m_databases) {
if (handle && handle->IsOpen()) {
count++;
}
}
return count;
}
// ============================================================================
// DatabaseHandle
// ============================================================================
DatabaseHandle::DatabaseHandle(sqlite3* db, const std::string& path, const DatabaseLimits& limits)
: m_db(db)
, m_path(path)
, m_limits(limits) {
if (m_db) {
// Set up authorizer
sqlite3_set_authorizer(m_db, Authorizer, this);
// Set busy timeout
sqlite3_busy_timeout(m_db, m_limits.max_query_time_ms);
}
}
DatabaseHandle::~DatabaseHandle() {
Close();
}
void DatabaseHandle::Close() {
if (m_db) {
sqlite3_close(m_db);
m_db = nullptr;
}
}
int DatabaseHandle::Authorizer(void* user_data, int action, const char* arg1,
const char* arg2, const char* arg3, const char* arg4) {
(void)user_data;
(void)arg3;
(void)arg4;
switch (action) {
case SQLITE_ATTACH:
case SQLITE_DETACH:
// Block attaching/detaching databases
return SQLITE_DENY;
case SQLITE_PRAGMA: {
// Allow safe pragmas only
if (arg1) {
std::string pragma(arg1);
// Convert to lowercase for comparison
std::transform(pragma.begin(), pragma.end(), pragma.begin(),
[](unsigned char c) { return std::tolower(c); });
// Whitelist of safe pragmas
if (pragma == "table_info" ||
pragma == "index_list" ||
pragma == "index_info" ||
pragma == "foreign_keys" ||
pragma == "foreign_key_list" ||
pragma == "database_list" ||
pragma == "table_list" ||
pragma == "integrity_check" ||
pragma == "quick_check") {
return SQLITE_OK;
}
// Block all other pragmas
return SQLITE_DENY;
}
return SQLITE_DENY;
}
case SQLITE_FUNCTION: {
// Block dangerous functions
if (arg2) {
std::string func(arg2);
std::transform(func.begin(), func.end(), func.begin(),
[](unsigned char c) { return std::tolower(c); });
if (func == "load_extension") {
return SQLITE_DENY;
}
}
return SQLITE_OK;
}
default:
return SQLITE_OK;
}
}
bool DatabaseHandle::BindParameters(void* stmt_ptr, const std::vector<SqlValue>& params, std::string& error) {
sqlite3_stmt* stmt = static_cast<sqlite3_stmt*>(stmt_ptr);
for (size_t i = 0; i < params.size(); i++) {
int idx = static_cast<int>(i + 1); // SQLite parameters are 1-indexed
int rc = SQLITE_OK;
std::visit([&](auto&& arg) {
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, std::nullptr_t>) {
rc = sqlite3_bind_null(stmt, idx);
} else if constexpr (std::is_same_v<T, int64_t>) {
rc = sqlite3_bind_int64(stmt, idx, arg);
} else if constexpr (std::is_same_v<T, double>) {
rc = sqlite3_bind_double(stmt, idx, arg);
} else if constexpr (std::is_same_v<T, std::string>) {
rc = sqlite3_bind_text(stmt, idx, arg.c_str(), static_cast<int>(arg.size()), SQLITE_TRANSIENT);
} else if constexpr (std::is_same_v<T, std::vector<uint8_t>>) {
rc = sqlite3_bind_blob(stmt, idx, arg.data(), static_cast<int>(arg.size()), SQLITE_TRANSIENT);
}
}, params[i]);
if (rc != SQLITE_OK) {
error = "Failed to bind parameter " + std::to_string(i) + ": " + sqlite3_errmsg(m_db);
return false;
}
}
return true;
}
bool DatabaseHandle::Execute(const std::string& sql, const std::vector<SqlValue>& params, std::string& error) {
if (!m_db) {
error = "Database not open";
return false;
}
sqlite3_stmt* stmt = nullptr;
int rc = sqlite3_prepare_v2(m_db, sql.c_str(), static_cast<int>(sql.size()), &stmt, nullptr);
if (rc != SQLITE_OK) {
error = "SQL prepare error: " + std::string(sqlite3_errmsg(m_db));
return false;
}
if (!BindParameters(stmt, params, error)) {
sqlite3_finalize(stmt);
return false;
}
rc = sqlite3_step(stmt);
sqlite3_finalize(stmt);
if (rc != SQLITE_DONE && rc != SQLITE_ROW) {
error = "SQL execution error: " + std::string(sqlite3_errmsg(m_db));
return false;
}
return true;
}
std::optional<SqlResult> DatabaseHandle::Query(const std::string& sql, const std::vector<SqlValue>& params,
std::string& error) {
if (!m_db) {
error = "Database not open";
return std::nullopt;
}
sqlite3_stmt* stmt = nullptr;
int rc = sqlite3_prepare_v2(m_db, sql.c_str(), static_cast<int>(sql.size()), &stmt, nullptr);
if (rc != SQLITE_OK) {
error = "SQL prepare error: " + std::string(sqlite3_errmsg(m_db));
return std::nullopt;
}
if (!BindParameters(stmt, params, error)) {
sqlite3_finalize(stmt);
return std::nullopt;
}
SqlResult result;
int row_count = 0;
while ((rc = sqlite3_step(stmt)) == SQLITE_ROW) {
if (row_count >= m_limits.max_result_rows) {
error = "Result row limit exceeded";
sqlite3_finalize(stmt);
return std::nullopt;
}
int col_count = sqlite3_column_count(stmt);
SqlRow row;
row.reserve(col_count);
for (int i = 0; i < col_count; i++) {
int type = sqlite3_column_type(stmt, i);
switch (type) {
case SQLITE_NULL:
row.push_back(nullptr);
break;
case SQLITE_INTEGER:
row.push_back(sqlite3_column_int64(stmt, i));
break;
case SQLITE_FLOAT:
row.push_back(sqlite3_column_double(stmt, i));
break;
case SQLITE_TEXT: {
const char* text = reinterpret_cast<const char*>(sqlite3_column_text(stmt, i));
int len = sqlite3_column_bytes(stmt, i);
row.push_back(std::string(text, len));
break;
}
case SQLITE_BLOB: {
const uint8_t* data = static_cast<const uint8_t*>(sqlite3_column_blob(stmt, i));
int len = sqlite3_column_bytes(stmt, i);
row.push_back(std::vector<uint8_t>(data, data + len));
break;
}
}
}
result.push_back(std::move(row));
row_count++;
}
sqlite3_finalize(stmt);
if (rc != SQLITE_DONE) {
error = "SQL query error: " + std::string(sqlite3_errmsg(m_db));
return std::nullopt;
}
return result;
}
int64_t DatabaseHandle::GetLastInsertRowId() const {
if (!m_db) return 0;
return sqlite3_last_insert_rowid(m_db);
}
int DatabaseHandle::GetChanges() const {
if (!m_db) return 0;
return sqlite3_changes(m_db);
}
// ============================================================================
// Lua API
// ============================================================================
struct LuaDatabaseHandle {
std::shared_ptr<DatabaseHandle> handle;
};
static int Lua_DatabaseHandle_Execute(lua_State* L) {
LuaDatabaseHandle* lh = static_cast<LuaDatabaseHandle*>(luaL_checkudata(L, 1, "DatabaseHandle"));
if (!lh->handle || !lh->handle->IsOpen()) {
lua_pushboolean(L, 0);
lua_pushstring(L, "Database not open");
return 2;
}
const char* sql = luaL_checkstring(L, 2);
// Get parameters from optional table
std::vector<SqlValue> params;
if (lua_gettop(L) >= 3 && lua_istable(L, 3)) {
lua_pushnil(L);
while (lua_next(L, 3) != 0) {
if (lua_isnil(L, -1)) {
params.push_back(nullptr);
} else if (lua_isinteger(L, -1)) {
params.push_back(static_cast<int64_t>(lua_tointeger(L, -1)));
} else if (lua_isnumber(L, -1)) {
params.push_back(lua_tonumber(L, -1));
} else if (lua_isstring(L, -1)) {
size_t len;
const char* str = lua_tolstring(L, -1, &len);
params.push_back(std::string(str, len));
} else if (lua_isboolean(L, -1)) {
params.push_back(static_cast<int64_t>(lua_toboolean(L, -1)));
}
lua_pop(L, 1);
}
}
std::string error;
if (lh->handle->Execute(sql, params, error)) {
lua_pushboolean(L, 1);
return 1;
} else {
lua_pushboolean(L, 0);
lua_pushstring(L, error.c_str());
return 2;
}
}
static int Lua_DatabaseHandle_Query(lua_State* L) {
LuaDatabaseHandle* lh = static_cast<LuaDatabaseHandle*>(luaL_checkudata(L, 1, "DatabaseHandle"));
if (!lh->handle || !lh->handle->IsOpen()) {
lua_pushnil(L);
lua_pushstring(L, "Database not open");
return 2;
}
const char* sql = luaL_checkstring(L, 2);
// Get parameters from optional table
std::vector<SqlValue> params;
if (lua_gettop(L) >= 3 && lua_istable(L, 3)) {
lua_pushnil(L);
while (lua_next(L, 3) != 0) {
if (lua_isnil(L, -1)) {
params.push_back(nullptr);
} else if (lua_isinteger(L, -1)) {
params.push_back(static_cast<int64_t>(lua_tointeger(L, -1)));
} else if (lua_isnumber(L, -1)) {
params.push_back(lua_tonumber(L, -1));
} else if (lua_isstring(L, -1)) {
size_t len;
const char* str = lua_tolstring(L, -1, &len);
params.push_back(std::string(str, len));
} else if (lua_isboolean(L, -1)) {
params.push_back(static_cast<int64_t>(lua_toboolean(L, -1)));
}
lua_pop(L, 1);
}
}
std::string error;
auto result = lh->handle->Query(sql, params, error);
if (!result.has_value()) {
lua_pushnil(L);
lua_pushstring(L, error.c_str());
return 2;
}
// Create result table
lua_createtable(L, static_cast<int>(result->size()), 0);
int row_idx = 1;
for (const auto& row : *result) {
lua_createtable(L, static_cast<int>(row.size()), 0);
int col_idx = 1;
for (const auto& val : row) {
std::visit([L](auto&& arg) {
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, std::nullptr_t>) {
lua_pushnil(L);
} else if constexpr (std::is_same_v<T, int64_t>) {
lua_pushinteger(L, arg);
} else if constexpr (std::is_same_v<T, double>) {
lua_pushnumber(L, arg);
} else if constexpr (std::is_same_v<T, std::string>) {
lua_pushlstring(L, arg.c_str(), arg.size());
} else if constexpr (std::is_same_v<T, std::vector<uint8_t>>) {
lua_pushlstring(L, reinterpret_cast<const char*>(arg.data()), arg.size());
}
}, val);
lua_rawseti(L, -2, col_idx++);
}
lua_rawseti(L, -2, row_idx++);
}
return 1;
}
static int Lua_DatabaseHandle_LastInsertId(lua_State* L) {
LuaDatabaseHandle* lh = static_cast<LuaDatabaseHandle*>(luaL_checkudata(L, 1, "DatabaseHandle"));
if (!lh->handle) {
lua_pushinteger(L, 0);
return 1;
}
lua_pushinteger(L, lh->handle->GetLastInsertRowId());
return 1;
}
static int Lua_DatabaseHandle_Changes(lua_State* L) {
LuaDatabaseHandle* lh = static_cast<LuaDatabaseHandle*>(luaL_checkudata(L, 1, "DatabaseHandle"));
if (!lh->handle) {
lua_pushinteger(L, 0);
return 1;
}
lua_pushinteger(L, lh->handle->GetChanges());
return 1;
}
static int Lua_DatabaseHandle_Close(lua_State* L) {
LuaDatabaseHandle* lh = static_cast<LuaDatabaseHandle*>(luaL_checkudata(L, 1, "DatabaseHandle"));
if (lh->handle) {
lh->handle->Close();
}
return 0;
}
static int Lua_DatabaseHandle_GC(lua_State* L) {
LuaDatabaseHandle* lh = static_cast<LuaDatabaseHandle*>(luaL_checkudata(L, 1, "DatabaseHandle"));
lh->~LuaDatabaseHandle();
return 0;
}
static const luaL_Reg DatabaseHandle_methods[] = {
{"execute", Lua_DatabaseHandle_Execute},
{"query", Lua_DatabaseHandle_Query},
{"lastInsertId", Lua_DatabaseHandle_LastInsertId},
{"changes", Lua_DatabaseHandle_Changes},
{"close", Lua_DatabaseHandle_Close},
{nullptr, nullptr}
};
static int Lua_Database_Open(lua_State* L) {
DatabaseManager* manager = static_cast<DatabaseManager*>(lua_touserdata(L, lua_upvalueindex(1)));
const char* name = luaL_checkstring(L, 1);
std::string error;
auto handle = manager->Open(name, error);
if (!handle) {
lua_pushnil(L);
lua_pushstring(L, error.c_str());
return 2;
}
// Create userdata
LuaDatabaseHandle* lh = static_cast<LuaDatabaseHandle*>(lua_newuserdata(L, sizeof(LuaDatabaseHandle)));
new (lh) LuaDatabaseHandle{handle};
luaL_getmetatable(L, "DatabaseHandle");
lua_setmetatable(L, -2);
return 1;
}
void RegisterDatabaseAPI(lua_State* L, DatabaseManager* manager) {
// Create DatabaseHandle metatable
luaL_newmetatable(L, "DatabaseHandle");
lua_pushvalue(L, -1);
lua_setfield(L, -2, "__index");
lua_pushcfunction(L, Lua_DatabaseHandle_GC);
lua_setfield(L, -2, "__gc");
luaL_setfuncs(L, DatabaseHandle_methods, 0);
lua_pop(L, 1);
// Create database table
lua_newtable(L);
// database.open
lua_pushlightuserdata(L, manager);
lua_pushcclosure(L, Lua_Database_Open, 1);
lua_setfield(L, -2, "open");
// Set as global
SetGlobalInRealG(L, "database");
}
} // namespace mosis

View File

@@ -0,0 +1,88 @@
#pragma once
#include <string>
#include <vector>
#include <variant>
#include <optional>
#include <memory>
#include <unordered_map>
struct sqlite3;
struct lua_State;
namespace mosis {
// SQL value types
using SqlValue = std::variant<std::nullptr_t, int64_t, double, std::string, std::vector<uint8_t>>;
using SqlRow = std::vector<SqlValue>;
using SqlResult = std::vector<SqlRow>;
struct DatabaseLimits {
size_t max_database_size = 50 * 1024 * 1024; // 50 MB per database
int max_databases_per_app = 5; // Max open databases
int max_query_time_ms = 5000; // 5 second query timeout
int max_result_rows = 10000; // Max rows returned
};
class DatabaseHandle;
class DatabaseManager {
public:
DatabaseManager(const std::string& app_id,
const std::string& app_root,
const DatabaseLimits& limits = DatabaseLimits{});
~DatabaseManager();
// Database operations
std::shared_ptr<DatabaseHandle> Open(const std::string& name, std::string& error);
void CloseAll();
// Stats
size_t GetOpenDatabaseCount() const;
private:
std::string m_app_id;
std::string m_app_root;
DatabaseLimits m_limits;
std::unordered_map<std::string, std::shared_ptr<DatabaseHandle>> m_databases;
std::string ResolvePath(const std::string& name);
bool ValidateName(const std::string& name, std::string& error);
};
class DatabaseHandle {
public:
DatabaseHandle(sqlite3* db, const std::string& path, const DatabaseLimits& limits);
~DatabaseHandle();
// Execute (INSERT, UPDATE, DELETE, CREATE, etc.)
bool Execute(const std::string& sql, const std::vector<SqlValue>& params, std::string& error);
// Query (SELECT)
std::optional<SqlResult> Query(const std::string& sql, const std::vector<SqlValue>& params,
std::string& error);
// Get last insert rowid
int64_t GetLastInsertRowId() const;
// Get affected rows
int GetChanges() const;
bool IsOpen() const { return m_db != nullptr; }
void Close();
private:
sqlite3* m_db;
std::string m_path;
DatabaseLimits m_limits;
static int Authorizer(void* user_data, int action, const char* arg1,
const char* arg2, const char* arg3, const char* arg4);
bool BindParameters(void* stmt, const std::vector<SqlValue>& params, std::string& error);
};
// Register database.* APIs as globals
void RegisterDatabaseAPI(lua_State* L, DatabaseManager* manager);
} // namespace mosis