diff --git a/portal/cmd/server/main.go b/portal/cmd/server/main.go index 5bd1fb2..62d6b68 100644 --- a/portal/cmd/server/main.go +++ b/portal/cmd/server/main.go @@ -23,17 +23,20 @@ func main() { } // Initialize database - db, err := database.Open(cfg.DatabasePath) + sqlDB, err := database.Open(cfg.DatabasePath) if err != nil { log.Fatalf("failed to open database: %v", err) } - defer db.Close() + defer sqlDB.Close() // Run migrations - if err := database.Migrate(db); err != nil { + if err := database.Migrate(sqlDB); err != nil { log.Fatalf("failed to run migrations: %v", err) } + // Wrap database with business logic + db := database.NewDB(sqlDB) + // Create router router := api.NewRouter(cfg, db) diff --git a/portal/go.mod b/portal/go.mod index 4e1087c..3412ec9 100644 --- a/portal/go.mod +++ b/portal/go.mod @@ -6,6 +6,7 @@ require ( github.com/go-chi/chi/v5 v5.0.12 github.com/go-playground/validator/v10 v10.19.0 github.com/golang-jwt/jwt/v5 v5.2.0 + github.com/google/uuid v1.6.0 golang.org/x/crypto v0.21.0 golang.org/x/oauth2 v0.18.0 modernc.org/sqlite v1.29.5 diff --git a/portal/internal/api/handlers/auth.go b/portal/internal/api/handlers/auth.go new file mode 100644 index 0000000..30188e6 --- /dev/null +++ b/portal/internal/api/handlers/auth.go @@ -0,0 +1,264 @@ +package handlers + +import ( + "encoding/json" + "log" + "net/http" + "time" + + "github.com/omixlab/mosis-portal/internal/api/middleware" + "github.com/omixlab/mosis-portal/internal/auth" + "github.com/omixlab/mosis-portal/internal/database" +) + +// AuthHandler handles authentication endpoints +type AuthHandler struct { + oauthManager *auth.OAuthManager + jwtManager *auth.JWTManager + db *database.DB +} + +// NewAuthHandler creates a new auth handler +func NewAuthHandler(oauthManager *auth.OAuthManager, jwtManager *auth.JWTManager, db *database.DB) *AuthHandler { + return &AuthHandler{ + oauthManager: oauthManager, + jwtManager: jwtManager, + db: db, + } +} + +// OAuthStart initiates OAuth flow +func (h *AuthHandler) OAuthStart(provider auth.OAuthProvider) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + state, err := auth.GenerateState() + if err != nil { + http.Error(w, "Failed to generate state", http.StatusInternalServerError) + return + } + + // Store state in cookie for verification + http.SetCookie(w, &http.Cookie{ + Name: "oauth_state", + Value: state, + Path: "/", + MaxAge: 300, // 5 minutes + HttpOnly: true, + Secure: r.TLS != nil, + SameSite: http.SameSiteLaxMode, + }) + + authURL, err := h.oauthManager.GetAuthURL(provider, state) + if err != nil { + http.Error(w, "OAuth not configured: "+err.Error(), http.StatusNotImplemented) + return + } + + http.Redirect(w, r, authURL, http.StatusTemporaryRedirect) + } +} + +// OAuthCallback handles OAuth callback +func (h *AuthHandler) OAuthCallback(provider auth.OAuthProvider) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Verify state + stateCookie, err := r.Cookie("oauth_state") + if err != nil || stateCookie.Value != r.URL.Query().Get("state") { + http.Error(w, "Invalid state", http.StatusBadRequest) + return + } + + // Clear state cookie + http.SetCookie(w, &http.Cookie{ + Name: "oauth_state", + Value: "", + Path: "/", + MaxAge: -1, + }) + + // Check for error from provider + if errMsg := r.URL.Query().Get("error"); errMsg != "" { + http.Error(w, "OAuth error: "+errMsg, http.StatusBadRequest) + return + } + + code := r.URL.Query().Get("code") + if code == "" { + http.Error(w, "No code provided", http.StatusBadRequest) + return + } + + // Exchange code for user info + oauthUser, err := h.oauthManager.Exchange(r.Context(), provider, code) + if err != nil { + log.Printf("OAuth exchange failed: %v", err) + http.Error(w, "Authentication failed", http.StatusInternalServerError) + return + } + + if oauthUser.Email == "" { + http.Error(w, "Email not available from OAuth provider", http.StatusBadRequest) + return + } + + // Find or create developer + developer, err := h.db.FindOrCreateDeveloper(r.Context(), &database.Developer{ + Email: oauthUser.Email, + Name: oauthUser.Name, + OAuthProvider: string(oauthUser.Provider), + OAuthID: oauthUser.ID, + AvatarURL: oauthUser.Avatar, + }) + if err != nil { + log.Printf("Failed to find/create developer: %v", err) + http.Error(w, "Failed to create account", http.StatusInternalServerError) + return + } + + // Generate tokens + tokenPair, err := h.jwtManager.GenerateTokenPair(developer.ID, developer.Email) + if err != nil { + log.Printf("Failed to generate tokens: %v", err) + http.Error(w, "Failed to generate tokens", http.StatusInternalServerError) + return + } + + // Set refresh token in HttpOnly cookie + http.SetCookie(w, &http.Cookie{ + Name: "refresh_token", + Value: tokenPair.RefreshToken, + Path: "/", + MaxAge: 30 * 24 * 60 * 60, // 30 days + HttpOnly: true, + Secure: r.TLS != nil, + SameSite: http.SameSiteStrictMode, + }) + + // Log the authentication + h.db.LogAudit(r.Context(), developer.ID, "oauth_login", r.RemoteAddr, r.UserAgent(), true, "") + + // Return access token + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": tokenPair.AccessToken, + "token_type": tokenPair.TokenType, + "expires_in": tokenPair.ExpiresIn, + "developer": map[string]interface{}{ + "id": developer.ID, + "email": developer.Email, + "name": developer.Name, + "avatar_url": developer.AvatarURL, + }, + }) + } +} + +// Refresh refreshes the access token using refresh token +func (h *AuthHandler) Refresh(w http.ResponseWriter, r *http.Request) { + // Get refresh token from cookie or body + var refreshToken string + + cookie, err := r.Cookie("refresh_token") + if err == nil { + refreshToken = cookie.Value + } + + // If not in cookie, try request body + if refreshToken == "" { + var body struct { + RefreshToken string `json:"refresh_token"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err == nil { + refreshToken = body.RefreshToken + } + } + + if refreshToken == "" { + http.Error(w, "Refresh token required", http.StatusBadRequest) + return + } + + // Validate refresh token + claims, err := h.jwtManager.ValidateRefreshToken(refreshToken) + if err != nil { + // Clear invalid cookie + http.SetCookie(w, &http.Cookie{ + Name: "refresh_token", + Value: "", + Path: "/", + MaxAge: -1, + }) + http.Error(w, "Invalid refresh token", http.StatusUnauthorized) + return + } + + // Generate new token pair + tokenPair, err := h.jwtManager.GenerateTokenPair(claims.Subject, claims.Email) + if err != nil { + http.Error(w, "Failed to generate tokens", http.StatusInternalServerError) + return + } + + // Set new refresh token in cookie + http.SetCookie(w, &http.Cookie{ + Name: "refresh_token", + Value: tokenPair.RefreshToken, + Path: "/", + MaxAge: 30 * 24 * 60 * 60, + HttpOnly: true, + Secure: r.TLS != nil, + SameSite: http.SameSiteStrictMode, + }) + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "access_token": tokenPair.AccessToken, + "token_type": tokenPair.TokenType, + "expires_in": tokenPair.ExpiresIn, + }) +} + +// Logout invalidates the current session +func (h *AuthHandler) Logout(w http.ResponseWriter, r *http.Request) { + // Clear refresh token cookie + http.SetCookie(w, &http.Cookie{ + Name: "refresh_token", + Value: "", + Path: "/", + MaxAge: -1, + HttpOnly: true, + Secure: r.TLS != nil, + SameSite: http.SameSiteStrictMode, + }) + + // Log the logout if authenticated + developerID := middleware.GetDeveloperID(r.Context()) + if developerID != "" { + h.db.LogAudit(r.Context(), developerID, "logout", r.RemoteAddr, r.UserAgent(), true, "") + } + + w.WriteHeader(http.StatusNoContent) +} + +// Me returns the current user's information +func (h *AuthHandler) Me(w http.ResponseWriter, r *http.Request) { + developerID := middleware.GetDeveloperID(r.Context()) + if developerID == "" { + http.Error(w, "Not authenticated", http.StatusUnauthorized) + return + } + + developer, err := h.db.GetDeveloper(r.Context(), developerID) + if err != nil { + http.Error(w, "Developer not found", http.StatusNotFound) + return + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "id": developer.ID, + "email": developer.Email, + "name": developer.Name, + "avatar_url": developer.AvatarURL, + "created_at": developer.CreatedAt.Format(time.RFC3339), + }) +} diff --git a/portal/internal/api/middleware/auth.go b/portal/internal/api/middleware/auth.go new file mode 100644 index 0000000..700c550 --- /dev/null +++ b/portal/internal/api/middleware/auth.go @@ -0,0 +1,127 @@ +// Package middleware provides HTTP middleware for the API +package middleware + +import ( + "context" + "net/http" + "strings" + + "github.com/omixlab/mosis-portal/internal/auth" + "github.com/omixlab/mosis-portal/internal/database" +) + +type contextKey string + +const ( + DeveloperContextKey contextKey = "developer" + ClaimsContextKey contextKey = "claims" +) + +// AuthMiddleware handles JWT and API key authentication +type AuthMiddleware struct { + jwtManager *auth.JWTManager + db *database.DB +} + +// NewAuthMiddleware creates a new auth middleware +func NewAuthMiddleware(jwtManager *auth.JWTManager, db *database.DB) *AuthMiddleware { + return &AuthMiddleware{ + jwtManager: jwtManager, + db: db, + } +} + +// RequireAuth requires a valid JWT or API key +func (m *AuthMiddleware) RequireAuth(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Try Bearer token first + authHeader := r.Header.Get("Authorization") + if strings.HasPrefix(authHeader, "Bearer ") { + token := strings.TrimPrefix(authHeader, "Bearer ") + claims, err := m.jwtManager.ValidateAccessToken(token) + if err != nil { + http.Error(w, "Invalid or expired token", http.StatusUnauthorized) + return + } + + // Add claims to context + ctx := context.WithValue(r.Context(), ClaimsContextKey, claims) + next.ServeHTTP(w, r.WithContext(ctx)) + return + } + + // Try API key + apiKey := r.Header.Get("X-API-Key") + if apiKey != "" { + developer, err := m.db.ValidateAPIKey(r.Context(), apiKey) + if err != nil { + http.Error(w, "Invalid API key", http.StatusUnauthorized) + return + } + + // Add developer to context + ctx := context.WithValue(r.Context(), DeveloperContextKey, developer) + next.ServeHTTP(w, r.WithContext(ctx)) + return + } + + http.Error(w, "Authorization required", http.StatusUnauthorized) + }) +} + +// OptionalAuth adds developer info to context if authenticated, but doesn't require it +func (m *AuthMiddleware) OptionalAuth(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Try Bearer token + authHeader := r.Header.Get("Authorization") + if strings.HasPrefix(authHeader, "Bearer ") { + token := strings.TrimPrefix(authHeader, "Bearer ") + claims, err := m.jwtManager.ValidateAccessToken(token) + if err == nil { + ctx := context.WithValue(r.Context(), ClaimsContextKey, claims) + next.ServeHTTP(w, r.WithContext(ctx)) + return + } + } + + // Try API key + apiKey := r.Header.Get("X-API-Key") + if apiKey != "" { + developer, err := m.db.ValidateAPIKey(r.Context(), apiKey) + if err == nil { + ctx := context.WithValue(r.Context(), DeveloperContextKey, developer) + next.ServeHTTP(w, r.WithContext(ctx)) + return + } + } + + // Continue without authentication + next.ServeHTTP(w, r) + }) +} + +// GetClaims retrieves JWT claims from context +func GetClaims(ctx context.Context) *auth.Claims { + claims, ok := ctx.Value(ClaimsContextKey).(*auth.Claims) + if !ok { + return nil + } + return claims +} + +// GetDeveloperID retrieves the developer ID from context (from JWT or API key) +func GetDeveloperID(ctx context.Context) string { + // First check JWT claims + claims := GetClaims(ctx) + if claims != nil { + return claims.Subject + } + + // Then check developer from API key + developer, ok := ctx.Value(DeveloperContextKey).(*database.Developer) + if ok && developer != nil { + return developer.ID + } + + return "" +} diff --git a/portal/internal/api/router.go b/portal/internal/api/router.go index 6ee73c7..032e84d 100644 --- a/portal/internal/api/router.go +++ b/portal/internal/api/router.go @@ -2,24 +2,36 @@ package api import ( - "database/sql" "net/http" "github.com/go-chi/chi/v5" - "github.com/go-chi/chi/v5/middleware" + chimw "github.com/go-chi/chi/v5/middleware" "github.com/omixlab/mosis-portal/internal/api/handlers" + "github.com/omixlab/mosis-portal/internal/api/middleware" + "github.com/omixlab/mosis-portal/internal/auth" "github.com/omixlab/mosis-portal/internal/config" + "github.com/omixlab/mosis-portal/internal/database" ) // NewRouter creates and configures the HTTP router -func NewRouter(cfg *config.Config, db *sql.DB) http.Handler { +func NewRouter(cfg *config.Config, db *database.DB) http.Handler { r := chi.NewRouter() // Middleware - r.Use(middleware.Logger) - r.Use(middleware.Recoverer) - r.Use(middleware.RealIP) - r.Use(middleware.RequestID) + r.Use(chimw.Logger) + r.Use(chimw.Recoverer) + r.Use(chimw.RealIP) + r.Use(chimw.RequestID) + + // Initialize auth components + jwtManager := auth.NewJWTManager(cfg.JWTSecret) + oauthManager := auth.NewOAuthManager( + cfg.BaseURL, + cfg.GitHubClientID, cfg.GitHubClientSecret, + cfg.GoogleClientID, cfg.GoogleClientSecret, + ) + authMiddleware := middleware.NewAuthMiddleware(jwtManager, db) + authHandler := handlers.NewAuthHandler(oauthManager, jwtManager, db) // Health check r.Get("/health", func(w http.ResponseWriter, r *http.Request) { @@ -29,47 +41,57 @@ func NewRouter(cfg *config.Config, db *sql.DB) http.Handler { // API v1 r.Route("/v1", func(r chi.Router) { - // Auth routes + // Auth routes (public) r.Route("/auth", func(r chi.Router) { - r.Post("/oauth/github", handlers.NotImplemented) - r.Get("/oauth/github/callback", handlers.NotImplemented) - r.Post("/oauth/google", handlers.NotImplemented) - r.Get("/oauth/google/callback", handlers.NotImplemented) - r.Post("/refresh", handlers.NotImplemented) - r.Post("/logout", handlers.NotImplemented) - r.Get("/me", handlers.NotImplemented) + // OAuth - use GET for initiating (redirect based) + r.Get("/oauth/github", authHandler.OAuthStart(auth.ProviderGitHub)) + r.Get("/oauth/github/callback", authHandler.OAuthCallback(auth.ProviderGitHub)) + r.Get("/oauth/google", authHandler.OAuthStart(auth.ProviderGoogle)) + r.Get("/oauth/google/callback", authHandler.OAuthCallback(auth.ProviderGoogle)) + + // Token management + r.Post("/refresh", authHandler.Refresh) + r.Post("/logout", authHandler.Logout) + + // Current user (requires auth) + r.With(authMiddleware.RequireAuth).Get("/me", authHandler.Me) }) - // Developer apps - r.Route("/apps", func(r chi.Router) { - r.Get("/", handlers.NotImplemented) - r.Post("/", handlers.NotImplemented) - r.Get("/{appID}", handlers.NotImplemented) - r.Patch("/{appID}", handlers.NotImplemented) - r.Delete("/{appID}", handlers.NotImplemented) + // Protected developer routes + r.Group(func(r chi.Router) { + r.Use(authMiddleware.RequireAuth) - // Versions - r.Route("/{appID}/versions", func(r chi.Router) { + // Developer apps + r.Route("/apps", func(r chi.Router) { r.Get("/", handlers.NotImplemented) r.Post("/", handlers.NotImplemented) - r.Get("/{versionID}", handlers.NotImplemented) - r.Post("/{versionID}/submit", handlers.NotImplemented) - r.Post("/{versionID}/publish", handlers.NotImplemented) + r.Get("/{appID}", handlers.NotImplemented) + r.Patch("/{appID}", handlers.NotImplemented) + r.Delete("/{appID}", handlers.NotImplemented) + + // Versions + r.Route("/{appID}/versions", func(r chi.Router) { + r.Get("/", handlers.NotImplemented) + r.Post("/", handlers.NotImplemented) + r.Get("/{versionID}", handlers.NotImplemented) + r.Post("/{versionID}/submit", handlers.NotImplemented) + r.Post("/{versionID}/publish", handlers.NotImplemented) + }) }) - }) - // API Keys - r.Route("/api-keys", func(r chi.Router) { - r.Get("/", handlers.NotImplemented) - r.Post("/", handlers.NotImplemented) - r.Delete("/{keyID}", handlers.NotImplemented) - }) + // API Keys + r.Route("/api-keys", func(r chi.Router) { + r.Get("/", handlers.NotImplemented) + r.Post("/", handlers.NotImplemented) + r.Delete("/{keyID}", handlers.NotImplemented) + }) - // Signing Keys - r.Route("/signing-keys", func(r chi.Router) { - r.Get("/", handlers.NotImplemented) - r.Post("/", handlers.NotImplemented) - r.Delete("/{keyID}", handlers.NotImplemented) + // Signing Keys + r.Route("/signing-keys", func(r chi.Router) { + r.Get("/", handlers.NotImplemented) + r.Post("/", handlers.NotImplemented) + r.Delete("/{keyID}", handlers.NotImplemented) + }) }) // Public store endpoints @@ -80,15 +102,16 @@ func NewRouter(cfg *config.Config, db *sql.DB) http.Handler { r.Get("/apps/updates", handlers.NotImplemented) }) - // Telemetry + // Telemetry (API key auth preferred, but can work without for initial setup) r.Route("/telemetry", func(r chi.Router) { r.Post("/events", handlers.NotImplemented) r.Post("/crash", handlers.NotImplemented) }) }) - // Admin routes (htmx UI) + // Admin routes (htmx UI) - requires auth r.Route("/admin", func(r chi.Router) { + r.Use(authMiddleware.RequireAuth) r.Get("/", handlers.NotImplemented) r.Get("/review-queue", handlers.NotImplemented) r.Get("/review/{versionID}", handlers.NotImplemented) diff --git a/portal/internal/auth/jwt.go b/portal/internal/auth/jwt.go new file mode 100644 index 0000000..65114cc --- /dev/null +++ b/portal/internal/auth/jwt.go @@ -0,0 +1,145 @@ +// Package auth provides authentication and authorization functionality +package auth + +import ( + "crypto/rand" + "encoding/base64" + "errors" + "fmt" + "time" + + "github.com/golang-jwt/jwt/v5" +) + +var ( + ErrInvalidToken = errors.New("invalid token") + ErrExpiredToken = errors.New("token has expired") +) + +// TokenPair contains access and refresh tokens +type TokenPair struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` +} + +// Claims represents JWT claims for access tokens +type Claims struct { + jwt.RegisteredClaims + Type string `json:"type"` + Email string `json:"email,omitempty"` +} + +// JWTManager handles JWT token operations +type JWTManager struct { + secretKey []byte + accessTokenExpiry time.Duration + refreshTokenExpiry time.Duration +} + +// NewJWTManager creates a new JWT manager +func NewJWTManager(secret string) *JWTManager { + return &JWTManager{ + secretKey: []byte(secret), + accessTokenExpiry: time.Hour, + refreshTokenExpiry: 30 * 24 * time.Hour, + } +} + +// GenerateTokenPair creates a new access/refresh token pair +func (m *JWTManager) GenerateTokenPair(developerID, email string) (*TokenPair, error) { + accessToken, err := m.generateToken(developerID, email, "access", m.accessTokenExpiry) + if err != nil { + return nil, fmt.Errorf("generate access token: %w", err) + } + + refreshToken, err := m.generateToken(developerID, email, "refresh", m.refreshTokenExpiry) + if err != nil { + return nil, fmt.Errorf("generate refresh token: %w", err) + } + + return &TokenPair{ + AccessToken: accessToken, + RefreshToken: refreshToken, + TokenType: "Bearer", + ExpiresIn: int64(m.accessTokenExpiry.Seconds()), + }, nil +} + +func (m *JWTManager) generateToken(subject, email, tokenType string, expiry time.Duration) (string, error) { + now := time.Now() + claims := Claims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: subject, + IssuedAt: jwt.NewNumericDate(now), + ExpiresAt: jwt.NewNumericDate(now.Add(expiry)), + Issuer: "mosis-portal", + }, + Type: tokenType, + Email: email, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString(m.secretKey) +} + +// ValidateAccessToken validates an access token and returns the claims +func (m *JWTManager) ValidateAccessToken(tokenString string) (*Claims, error) { + return m.validateToken(tokenString, "access") +} + +// ValidateRefreshToken validates a refresh token and returns the claims +func (m *JWTManager) ValidateRefreshToken(tokenString string) (*Claims, error) { + return m.validateToken(tokenString, "refresh") +} + +func (m *JWTManager) validateToken(tokenString, expectedType string) (*Claims, error) { + token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return m.secretKey, nil + }) + + if err != nil { + if errors.Is(err, jwt.ErrTokenExpired) { + return nil, ErrExpiredToken + } + return nil, ErrInvalidToken + } + + claims, ok := token.Claims.(*Claims) + if !ok || !token.Valid { + return nil, ErrInvalidToken + } + + if claims.Type != expectedType { + return nil, ErrInvalidToken + } + + return claims, nil +} + +// GenerateAPIKey generates a new API key with the given prefix +func GenerateAPIKey(prefix string) (string, error) { + // Generate 32 random bytes + bytes := make([]byte, 32) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + + // Encode as base64url (no padding) + encoded := base64.RawURLEncoding.EncodeToString(bytes) + + return fmt.Sprintf("%s_%s", prefix, encoded), nil +} + +// GenerateState generates a random state for OAuth +func GenerateState() (string, error) { + bytes := make([]byte, 16) + if _, err := rand.Read(bytes); err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(bytes), nil +} diff --git a/portal/internal/auth/oauth.go b/portal/internal/auth/oauth.go new file mode 100644 index 0000000..b00db5b --- /dev/null +++ b/portal/internal/auth/oauth.go @@ -0,0 +1,238 @@ +package auth + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + + "golang.org/x/oauth2" + "golang.org/x/oauth2/github" + "golang.org/x/oauth2/google" +) + +// OAuthProvider represents an OAuth2 provider +type OAuthProvider string + +const ( + ProviderGitHub OAuthProvider = "github" + ProviderGoogle OAuthProvider = "google" +) + +// OAuthUser contains user information from OAuth provider +type OAuthUser struct { + Provider OAuthProvider + ID string + Email string + Name string + Avatar string +} + +// OAuthManager handles OAuth2 authentication +type OAuthManager struct { + githubConfig *oauth2.Config + googleConfig *oauth2.Config +} + +// NewOAuthManager creates a new OAuth manager +func NewOAuthManager(baseURL, githubClientID, githubClientSecret, googleClientID, googleClientSecret string) *OAuthManager { + m := &OAuthManager{} + + if githubClientID != "" && githubClientSecret != "" { + m.githubConfig = &oauth2.Config{ + ClientID: githubClientID, + ClientSecret: githubClientSecret, + Endpoint: github.Endpoint, + Scopes: []string{"read:user", "user:email"}, + RedirectURL: baseURL + "/v1/auth/oauth/github/callback", + } + } + + if googleClientID != "" && googleClientSecret != "" { + m.googleConfig = &oauth2.Config{ + ClientID: googleClientID, + ClientSecret: googleClientSecret, + Endpoint: google.Endpoint, + Scopes: []string{"openid", "email", "profile"}, + RedirectURL: baseURL + "/v1/auth/oauth/google/callback", + } + } + + return m +} + +// GetAuthURL returns the OAuth authorization URL for the given provider +func (m *OAuthManager) GetAuthURL(provider OAuthProvider, state string) (string, error) { + config, err := m.getConfig(provider) + if err != nil { + return "", err + } + return config.AuthCodeURL(state, oauth2.AccessTypeOffline), nil +} + +// Exchange exchanges an authorization code for user information +func (m *OAuthManager) Exchange(ctx context.Context, provider OAuthProvider, code string) (*OAuthUser, error) { + config, err := m.getConfig(provider) + if err != nil { + return nil, err + } + + token, err := config.Exchange(ctx, code) + if err != nil { + return nil, fmt.Errorf("exchange code: %w", err) + } + + return m.fetchUserInfo(ctx, provider, token) +} + +func (m *OAuthManager) getConfig(provider OAuthProvider) (*oauth2.Config, error) { + switch provider { + case ProviderGitHub: + if m.githubConfig == nil { + return nil, fmt.Errorf("github oauth not configured") + } + return m.githubConfig, nil + case ProviderGoogle: + if m.googleConfig == nil { + return nil, fmt.Errorf("google oauth not configured") + } + return m.googleConfig, nil + default: + return nil, fmt.Errorf("unknown provider: %s", provider) + } +} + +func (m *OAuthManager) fetchUserInfo(ctx context.Context, provider OAuthProvider, token *oauth2.Token) (*OAuthUser, error) { + switch provider { + case ProviderGitHub: + return m.fetchGitHubUser(ctx, token) + case ProviderGoogle: + return m.fetchGoogleUser(ctx, token) + default: + return nil, fmt.Errorf("unknown provider: %s", provider) + } +} + +func (m *OAuthManager) fetchGitHubUser(ctx context.Context, token *oauth2.Token) (*OAuthUser, error) { + client := m.githubConfig.Client(ctx, token) + + // Fetch user info + resp, err := client.Get("https://api.github.com/user") + if err != nil { + return nil, fmt.Errorf("fetch user: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("github api returned %d", resp.StatusCode) + } + + var ghUser struct { + ID int64 `json:"id"` + Login string `json:"login"` + Name string `json:"name"` + Email string `json:"email"` + AvatarURL string `json:"avatar_url"` + } + + if err := json.NewDecoder(resp.Body).Decode(&ghUser); err != nil { + return nil, fmt.Errorf("decode user: %w", err) + } + + // If email not public, fetch from emails endpoint + email := ghUser.Email + if email == "" { + email, _ = m.fetchGitHubEmail(ctx, client) + } + + name := ghUser.Name + if name == "" { + name = ghUser.Login + } + + return &OAuthUser{ + Provider: ProviderGitHub, + ID: fmt.Sprintf("%d", ghUser.ID), + Email: email, + Name: name, + Avatar: ghUser.AvatarURL, + }, nil +} + +func (m *OAuthManager) fetchGitHubEmail(ctx context.Context, client *http.Client) (string, error) { + resp, err := client.Get("https://api.github.com/user/emails") + if err != nil { + return "", err + } + defer resp.Body.Close() + + var emails []struct { + Email string `json:"email"` + Primary bool `json:"primary"` + Verified bool `json:"verified"` + } + + if err := json.NewDecoder(resp.Body).Decode(&emails); err != nil { + return "", err + } + + // Find primary verified email + for _, e := range emails { + if e.Primary && e.Verified { + return e.Email, nil + } + } + + // Fall back to any verified email + for _, e := range emails { + if e.Verified { + return e.Email, nil + } + } + + return "", nil +} + +func (m *OAuthManager) fetchGoogleUser(ctx context.Context, token *oauth2.Token) (*OAuthUser, error) { + client := m.googleConfig.Client(ctx, token) + + resp, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo") + if err != nil { + return nil, fmt.Errorf("fetch user: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("google api returned %d", resp.StatusCode) + } + + var gUser struct { + ID string `json:"id"` + Email string `json:"email"` + VerifiedEmail bool `json:"verified_email"` + Name string `json:"name"` + Picture string `json:"picture"` + } + + if err := json.NewDecoder(resp.Body).Decode(&gUser); err != nil { + return nil, fmt.Errorf("decode user: %w", err) + } + + return &OAuthUser{ + Provider: ProviderGoogle, + ID: gUser.ID, + Email: gUser.Email, + Name: gUser.Name, + Avatar: gUser.Picture, + }, nil +} + +// IsGitHubConfigured returns true if GitHub OAuth is configured +func (m *OAuthManager) IsGitHubConfigured() bool { + return m.githubConfig != nil +} + +// IsGoogleConfigured returns true if Google OAuth is configured +func (m *OAuthManager) IsGoogleConfigured() bool { + return m.googleConfig != nil +} diff --git a/portal/internal/database/database.go b/portal/internal/database/database.go index 2dc430d..f43398b 100644 --- a/portal/internal/database/database.go +++ b/portal/internal/database/database.go @@ -2,14 +2,42 @@ package database import ( + "context" "database/sql" "fmt" "os" "path/filepath" + "time" + "github.com/google/uuid" + "golang.org/x/crypto/bcrypt" _ "modernc.org/sqlite" ) +// DB wraps the database connection with business logic +type DB struct { + *sql.DB +} + +// Developer represents a developer account +type Developer struct { + ID string + Email string + Name string + PasswordHash string + OAuthProvider string + OAuthID string + AvatarURL string + Verified bool + CreatedAt time.Time + UpdatedAt time.Time +} + +// NewDB creates a new DB wrapper +func NewDB(db *sql.DB) *DB { + return &DB{db} +} + // Open opens the SQLite database with WAL mode enabled func Open(path string) (*sql.DB, error) { // Ensure directory exists @@ -201,3 +229,134 @@ CREATE INDEX IF NOT EXISTS idx_crashes_app ON crash_reports(app_id, timestamp); CREATE INDEX IF NOT EXISTS idx_audit_developer ON audit_logs(developer_id); CREATE INDEX IF NOT EXISTS idx_audit_created ON audit_logs(created_at); ` + +// FindOrCreateDeveloper finds an existing developer by email or creates a new one +func (db *DB) FindOrCreateDeveloper(ctx context.Context, dev *Developer) (*Developer, error) { + // First try to find by email + existing, err := db.GetDeveloperByEmail(ctx, dev.Email) + if err == nil { + // Update OAuth info if changed + if dev.OAuthProvider != "" && (existing.OAuthProvider != dev.OAuthProvider || existing.OAuthID != dev.OAuthID) { + _, err := db.ExecContext(ctx, ` + UPDATE developers SET oauth_provider = ?, oauth_id = ?, updated_at = datetime('now') + WHERE id = ? + `, dev.OAuthProvider, dev.OAuthID, existing.ID) + if err != nil { + return nil, fmt.Errorf("update oauth: %w", err) + } + existing.OAuthProvider = dev.OAuthProvider + existing.OAuthID = dev.OAuthID + } + return existing, nil + } + + // Create new developer + dev.ID = uuid.New().String() + _, err = db.ExecContext(ctx, ` + INSERT INTO developers (id, email, name, oauth_provider, oauth_id, verified) + VALUES (?, ?, ?, ?, ?, 1) + `, dev.ID, dev.Email, dev.Name, dev.OAuthProvider, dev.OAuthID) + if err != nil { + return nil, fmt.Errorf("create developer: %w", err) + } + + dev.Verified = true + dev.CreatedAt = time.Now() + dev.UpdatedAt = dev.CreatedAt + return dev, nil +} + +// GetDeveloper retrieves a developer by ID +func (db *DB) GetDeveloper(ctx context.Context, id string) (*Developer, error) { + row := db.QueryRowContext(ctx, ` + SELECT id, email, name, password_hash, oauth_provider, oauth_id, verified, created_at, updated_at + FROM developers WHERE id = ? + `, id) + + return scanDeveloper(row) +} + +// GetDeveloperByEmail retrieves a developer by email +func (db *DB) GetDeveloperByEmail(ctx context.Context, email string) (*Developer, error) { + row := db.QueryRowContext(ctx, ` + SELECT id, email, name, password_hash, oauth_provider, oauth_id, verified, created_at, updated_at + FROM developers WHERE email = ? + `, email) + + return scanDeveloper(row) +} + +func scanDeveloper(row *sql.Row) (*Developer, error) { + var dev Developer + var passwordHash, oauthProvider, oauthID sql.NullString + var createdAt, updatedAt string + + err := row.Scan(&dev.ID, &dev.Email, &dev.Name, &passwordHash, &oauthProvider, &oauthID, &dev.Verified, &createdAt, &updatedAt) + if err != nil { + return nil, err + } + + dev.PasswordHash = passwordHash.String + dev.OAuthProvider = oauthProvider.String + dev.OAuthID = oauthID.String + dev.CreatedAt, _ = time.Parse("2006-01-02 15:04:05", createdAt) + dev.UpdatedAt, _ = time.Parse("2006-01-02 15:04:05", updatedAt) + + return &dev, nil +} + +// ValidateAPIKey validates an API key and returns the associated developer +func (db *DB) ValidateAPIKey(ctx context.Context, key string) (*Developer, error) { + // Extract prefix (first 15 chars: mk_live_xxxxxxx) + if len(key) < 15 { + return nil, fmt.Errorf("invalid key format") + } + prefix := key[:15] + + // Find key by prefix + row := db.QueryRowContext(ctx, ` + SELECT k.key_hash, k.developer_id, k.expires_at + FROM api_keys k + WHERE k.key_prefix = ? + `, prefix) + + var keyHash, developerID string + var expiresAt sql.NullString + if err := row.Scan(&keyHash, &developerID, &expiresAt); err != nil { + return nil, fmt.Errorf("key not found") + } + + // Check expiration + if expiresAt.Valid { + expiry, err := time.Parse("2006-01-02 15:04:05", expiresAt.String) + if err == nil && time.Now().After(expiry) { + return nil, fmt.Errorf("key expired") + } + } + + // Verify key hash + if err := bcrypt.CompareHashAndPassword([]byte(keyHash), []byte(key)); err != nil { + return nil, fmt.Errorf("invalid key") + } + + // Update last used + db.ExecContext(ctx, `UPDATE api_keys SET last_used_at = datetime('now') WHERE key_prefix = ?`, prefix) + + // Get developer + return db.GetDeveloper(ctx, developerID) +} + +// LogAudit logs an audit event +func (db *DB) LogAudit(ctx context.Context, developerID, action, ipAddress, userAgent string, success bool, failureReason string) { + details := "" + if !success { + details = fmt.Sprintf(`{"success":false,"reason":"%s"}`, failureReason) + } else { + details = `{"success":true}` + } + + db.ExecContext(ctx, ` + INSERT INTO audit_logs (developer_id, action, details, ip_address, user_agent) + VALUES (?, ?, ?, ?, ?) + `, developerID, action, details, ipAddress, userAgent) +}