diff --git a/server/internal/db/db.go b/server/internal/db/db.go index 04df282..a54ba4a 100644 --- a/server/internal/db/db.go +++ b/server/internal/db/db.go @@ -108,6 +108,11 @@ func Init(dbPath string, secretsService secrets.Service) (Database, error) { return nil, err } + // Enable foreign keys for this connection + if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil { + return nil, err + } + database := &database{ DB: db, secretsService: secretsService, diff --git a/server/internal/db/migrations.go b/server/internal/db/migrations.go index 6ecb3fb..f59b20f 100644 --- a/server/internal/db/migrations.go +++ b/server/internal/db/migrations.go @@ -49,9 +49,6 @@ var migrations = []Migration{ { Version: 2, SQL: ` - -- Enable foreign key constraints - PRAGMA foreign_keys = ON; - -- Create sessions table for authentication CREATE TABLE IF NOT EXISTS sessions ( id TEXT PRIMARY KEY, diff --git a/server/internal/db/migrations_test.go b/server/internal/db/migrations_test.go index 6d1a5fb..0b966c0 100644 --- a/server/internal/db/migrations_test.go +++ b/server/internal/db/migrations_test.go @@ -8,11 +8,6 @@ import ( _ "github.com/mattn/go-sqlite3" ) -type mockSecrets struct{} - -func (m *mockSecrets) Encrypt(s string) (string, error) { return s, nil } -func (m *mockSecrets) Decrypt(s string) (string, error) { return s, nil } - func TestMigrate(t *testing.T) { database, err := db.NewTestDB(":memory:", &mockSecrets{}) if err != nil { diff --git a/server/internal/db/sessions_test.go b/server/internal/db/sessions_test.go new file mode 100644 index 0000000..f5f87eb --- /dev/null +++ b/server/internal/db/sessions_test.go @@ -0,0 +1,294 @@ +package db_test + +import ( + "strings" + "testing" + "time" + + "novamd/internal/db" + "novamd/internal/models" + + "github.com/google/uuid" +) + +func TestSessionOperations(t *testing.T) { + database, err := db.NewTestDB(":memory:", &mockSecrets{}) + if err != nil { + t.Fatalf("failed to create test database: %v", err) + } + defer database.Close() + + if err := database.Migrate(); err != nil { + t.Fatalf("failed to run migrations: %v", err) + } + + // Create a test user first since sessions need a valid user ID + user, err := database.CreateUser(&models.User{ + Email: "test@example.com", + DisplayName: "Test User", + PasswordHash: "hash", + Role: "editor", + }) + if err != nil { + t.Fatalf("failed to create test user: %v", err) + } + + t.Run("CreateSession", func(t *testing.T) { + testCases := []struct { + name string + session *models.Session + wantErr bool + errContains string + }{ + { + name: "valid session", + session: &models.Session{ + ID: uuid.New().String(), + UserID: user.ID, + RefreshToken: "valid-token", + ExpiresAt: time.Now().Add(24 * time.Hour), + CreatedAt: time.Now(), + }, + wantErr: false, + }, + { + name: "invalid user ID", + session: &models.Session{ + ID: uuid.New().String(), + UserID: 99999, // Non-existent user ID + RefreshToken: "invalid-user-token", + ExpiresAt: time.Now().Add(24 * time.Hour), + CreatedAt: time.Now(), + }, + wantErr: true, + errContains: "FOREIGN KEY constraint failed", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := database.CreateSession(tc.session) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } else if tc.errContains != "" && !strings.Contains(err.Error(), tc.errContains) { + t.Errorf("error = %v, want error containing %v", err, tc.errContains) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify session was stored + stored, err := database.GetSessionByRefreshToken(tc.session.RefreshToken) + if err != nil { + t.Fatalf("failed to retrieve stored session: %v", err) + } + + // Compare fields + if stored.ID != tc.session.ID { + t.Errorf("ID = %v, want %v", stored.ID, tc.session.ID) + } + if stored.UserID != tc.session.UserID { + t.Errorf("UserID = %v, want %v", stored.UserID, tc.session.UserID) + } + if stored.RefreshToken != tc.session.RefreshToken { + t.Errorf("RefreshToken = %v, want %v", stored.RefreshToken, tc.session.RefreshToken) + } + // Compare times within a reasonable threshold + if diff := stored.ExpiresAt.Sub(tc.session.ExpiresAt); diff > time.Second || diff < -time.Second { + t.Errorf("ExpiresAt differs by %v, want difference less than 1s", diff) + } + }) + } + }) + + t.Run("GetSessionByRefreshToken", func(t *testing.T) { + // Create test sessions + validSession := &models.Session{ + ID: uuid.New().String(), + UserID: user.ID, + RefreshToken: "valid-get-token", + ExpiresAt: time.Now().Add(24 * time.Hour), + CreatedAt: time.Now(), + } + expiredSession := &models.Session{ + ID: uuid.New().String(), + UserID: user.ID, + RefreshToken: "expired-token", + ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired + CreatedAt: time.Now().Add(-2 * time.Hour), + } + + if err := database.CreateSession(validSession); err != nil { + t.Fatalf("failed to create valid session: %v", err) + } + if err := database.CreateSession(expiredSession); err != nil { + t.Fatalf("failed to create expired session: %v", err) + } + + testCases := []struct { + name string + refreshToken string + wantErr bool + errContains string + }{ + { + name: "valid token", + refreshToken: "valid-get-token", + wantErr: false, + }, + { + name: "expired token", + refreshToken: "expired-token", + wantErr: true, + errContains: "session not found or expired", + }, + { + name: "non-existent token", + refreshToken: "nonexistent-token", + wantErr: true, + errContains: "session not found", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + session, err := database.GetSessionByRefreshToken(tc.refreshToken) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } else if tc.errContains != "" && !strings.Contains(err.Error(), tc.errContains) { + t.Errorf("error = %v, want error containing %v", err, tc.errContains) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if session.RefreshToken != tc.refreshToken { + t.Errorf("RefreshToken = %v, want %v", session.RefreshToken, tc.refreshToken) + } + }) + } + }) + + t.Run("DeleteSession", func(t *testing.T) { + session := &models.Session{ + ID: uuid.New().String(), + UserID: user.ID, + RefreshToken: "delete-token", + ExpiresAt: time.Now().Add(24 * time.Hour), + CreatedAt: time.Now(), + } + + if err := database.CreateSession(session); err != nil { + t.Fatalf("failed to create session: %v", err) + } + + testCases := []struct { + name string + sessionID string + wantErr bool + errContains string + }{ + { + name: "valid session ID", + sessionID: session.ID, + wantErr: false, + }, + { + name: "non-existent session ID", + sessionID: "nonexistent-id", + wantErr: true, + errContains: "session not found", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := database.DeleteSession(tc.sessionID) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } else if tc.errContains != "" && !strings.Contains(err.Error(), tc.errContains) { + t.Errorf("error = %v, want error containing %v", err, tc.errContains) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify session was deleted + _, err = database.GetSessionByRefreshToken(session.RefreshToken) + if err == nil { + t.Error("session still exists after deletion") + } + }) + } + }) + + t.Run("CleanExpiredSessions", func(t *testing.T) { + // Create a mix of valid and expired sessions + sessions := []*models.Session{ + { + ID: uuid.New().String(), + UserID: user.ID, + RefreshToken: "valid-clean-token", + ExpiresAt: time.Now().Add(24 * time.Hour), + CreatedAt: time.Now(), + }, + { + ID: uuid.New().String(), + UserID: user.ID, + RefreshToken: "expired-clean-token-1", + ExpiresAt: time.Now().Add(-1 * time.Hour), + CreatedAt: time.Now().Add(-2 * time.Hour), + }, + { + ID: uuid.New().String(), + UserID: user.ID, + RefreshToken: "expired-clean-token-2", + ExpiresAt: time.Now().Add(-2 * time.Hour), + CreatedAt: time.Now().Add(-3 * time.Hour), + }, + } + + for _, s := range sessions { + if err := database.CreateSession(s); err != nil { + t.Fatalf("failed to create session: %v", err) + } + } + + // Clean expired sessions + if err := database.CleanExpiredSessions(); err != nil { + t.Fatalf("failed to clean expired sessions: %v", err) + } + + // Verify valid session still exists + validSession, err := database.GetSessionByRefreshToken("valid-clean-token") + if err != nil { + t.Errorf("valid session was unexpectedly deleted: %v", err) + } + if validSession == nil { + t.Error("valid session was unexpectedly deleted") + } + + // Verify expired sessions were deleted + expiredTokens := []string{"expired-clean-token-1", "expired-clean-token-2"} + for _, token := range expiredTokens { + if _, err := database.GetSessionByRefreshToken(token); err == nil { + t.Errorf("expired session with token %s still exists", token) + } + } + }) +} diff --git a/server/internal/db/system_test.go b/server/internal/db/system_test.go new file mode 100644 index 0000000..2e2ca03 --- /dev/null +++ b/server/internal/db/system_test.go @@ -0,0 +1,213 @@ +package db_test + +import ( + "encoding/base64" + "fmt" + "strings" + "testing" + "time" + + "novamd/internal/db" + "novamd/internal/models" + + "github.com/google/uuid" +) + +func TestSystemOperations(t *testing.T) { + database, err := db.NewTestDB(":memory:", &mockSecrets{}) + if err != nil { + t.Fatalf("failed to create test database: %v", err) + } + defer database.Close() + + if err := database.Migrate(); err != nil { + t.Fatalf("failed to run migrations: %v", err) + } + + t.Run("GetSystemSettings", func(t *testing.T) { + t.Run("non-existent setting", func(t *testing.T) { + _, err := database.GetSystemSetting("nonexistent-key") + if err == nil { + t.Error("expected error for non-existent key, got nil") + } + }) + + t.Run("existing setting", func(t *testing.T) { + // First set a value + err := database.SetSystemSetting("test-key", "test-value") + if err != nil { + t.Fatalf("failed to set system setting: %v", err) + } + + // Then get it back + value, err := database.GetSystemSetting("test-key") + if err != nil { + t.Fatalf("failed to get system setting: %v", err) + } + + if value != "test-value" { + t.Errorf("got value %q, want %q", value, "test-value") + } + }) + }) + + t.Run("SetSystemSettings", func(t *testing.T) { + testCases := []struct { + name string + key string + value string + wantErr bool + errContains string + }{ + { + name: "new setting", + key: "new-key", + value: "new-value", + wantErr: false, + }, + { + name: "update existing setting", + key: "update-key", + value: "original-value", + wantErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := database.SetSystemSetting(tc.key, tc.value) + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + } else if !strings.Contains(err.Error(), tc.errContains) { + t.Errorf("error = %v, want error containing %v", err, tc.errContains) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify the setting was stored + stored, err := database.GetSystemSetting(tc.key) + if err != nil { + t.Fatalf("failed to retrieve stored setting: %v", err) + } + if stored != tc.value { + t.Errorf("got value %q, want %q", stored, tc.value) + } + + // For the update case, test updating the value + if tc.name == "update existing setting" { + newValue := "updated-value" + err := database.SetSystemSetting(tc.key, newValue) + if err != nil { + t.Fatalf("failed to update setting: %v", err) + } + + stored, err := database.GetSystemSetting(tc.key) + if err != nil { + t.Fatalf("failed to retrieve updated setting: %v", err) + } + if stored != newValue { + t.Errorf("got updated value %q, want %q", stored, newValue) + } + } + }) + } + }) + + t.Run("EnsureJWTSecret", func(t *testing.T) { + // First call should generate a new secret + secret1, err := database.EnsureJWTSecret() + if err != nil { + t.Fatalf("failed to ensure JWT secret: %v", err) + } + + // Verify the secret is a valid base64-encoded string of sufficient length + decoded, err := base64.StdEncoding.DecodeString(secret1) + if err != nil { + t.Errorf("secret is not valid base64: %v", err) + } + if len(decoded) < 32 { + t.Errorf("secret length = %d, want >= 32", len(decoded)) + } + + // Second call should return the same secret + secret2, err := database.EnsureJWTSecret() + if err != nil { + t.Fatalf("failed to get existing JWT secret: %v", err) + } + + if secret2 != secret1 { + t.Errorf("got different secret on second call") + } + }) + + t.Run("GetSystemStats", func(t *testing.T) { + // Create some test users and sessions + users := []*models.User{ + { + Email: "user1@test.com", + DisplayName: "User 1", + PasswordHash: "hash1", + Role: "editor", + }, + { + Email: "user2@test.com", + DisplayName: "User 2", + PasswordHash: "hash2", + Role: "viewer", + }, + } + + for _, u := range users { + createdUser, err := database.CreateUser(u) + if err != nil { + t.Fatalf("failed to create test user: %v", err) + } + + // Create multiple workspaces per user + // Each user has one default workspace + for i := 0; i < 2; i++ { + workspace := &models.Workspace{ + UserID: createdUser.ID, + Name: fmt.Sprintf("Workspace %d", i), + } + if err := database.CreateWorkspace(workspace); err != nil { + t.Fatalf("failed to create test workspace: %v", err) + } + } + + // Create an active session for the first user + if createdUser.Email == "user1@test.com" { + session := &models.Session{ + ID: uuid.New().String(), + UserID: createdUser.ID, + RefreshToken: "test-token", + ExpiresAt: time.Now().Add(24 * time.Hour), + CreatedAt: time.Now(), + } + if err := database.CreateSession(session); err != nil { + t.Fatalf("failed to create test session: %v", err) + } + } + } + + stats, err := database.GetSystemStats() + if err != nil { + t.Fatalf("failed to get system stats: %v", err) + } + + // Verify stats + if stats.TotalUsers != 2 { + t.Errorf("TotalUsers = %d, want 2", stats.TotalUsers) + } + if stats.TotalWorkspaces != 6 { // 2 + 1 default workspace per user + t.Errorf("TotalWorkspaces = %d, want 6", stats.TotalWorkspaces) + } + if stats.ActiveUsers != 1 { // Only user1 has an active session + t.Errorf("ActiveUsers = %d, want 1", stats.ActiveUsers) + } + }) +} diff --git a/server/internal/db/testing.go b/server/internal/db/testdb.go similarity index 100% rename from server/internal/db/testing.go rename to server/internal/db/testdb.go diff --git a/server/internal/db/testutil_test.go b/server/internal/db/testutil_test.go new file mode 100644 index 0000000..103f4b5 --- /dev/null +++ b/server/internal/db/testutil_test.go @@ -0,0 +1,6 @@ +package db_test + +type mockSecrets struct{} + +func (m *mockSecrets) Encrypt(s string) (string, error) { return s, nil } +func (m *mockSecrets) Decrypt(s string) (string, error) { return s, nil } diff --git a/server/internal/db/users.go b/server/internal/db/users.go index 17cc374..132264b 100644 --- a/server/internal/db/users.go +++ b/server/internal/db/users.go @@ -69,7 +69,7 @@ func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) e theme, auto_save, show_hidden_files, git_enabled, git_url, git_user, git_token, git_auto_commit, git_commit_msg_template - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, workspace.UserID, workspace.Name, workspace.Theme, workspace.AutoSave, workspace.ShowHiddenFiles, workspace.GitEnabled, workspace.GitURL, workspace.GitUser, workspace.GitToken, diff --git a/server/internal/db/workspaces.go b/server/internal/db/workspaces.go index ce39ce5..004dcf5 100644 --- a/server/internal/db/workspaces.go +++ b/server/internal/db/workspaces.go @@ -24,7 +24,7 @@ func (db *database) CreateWorkspace(workspace *models.Workspace) error { user_id, name, theme, auto_save, show_hidden_files, git_enabled, git_url, git_user, git_token, git_auto_commit, git_commit_msg_template - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, workspace.UserID, workspace.Name, workspace.Theme, workspace.AutoSave, workspace.ShowHiddenFiles, workspace.GitEnabled, workspace.GitURL, workspace.GitUser, encryptedToken, workspace.GitAutoCommit, workspace.GitCommitMsgTemplate,