mirror of
https://github.com/lordmathis/lemma.git
synced 2025-11-06 07:54:22 +00:00
Implement session and system tests
This commit is contained in:
@@ -108,6 +108,11 @@ func Init(dbPath string, secretsService secrets.Service) (Database, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Enable foreign keys for this connection
|
||||
if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
database := &database{
|
||||
DB: db,
|
||||
secretsService: secretsService,
|
||||
|
||||
@@ -49,9 +49,6 @@ var migrations = []Migration{
|
||||
{
|
||||
Version: 2,
|
||||
SQL: `
|
||||
-- Enable foreign key constraints
|
||||
PRAGMA foreign_keys = ON;
|
||||
|
||||
-- Create sessions table for authentication
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
|
||||
@@ -8,11 +8,6 @@ import (
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
type mockSecrets struct{}
|
||||
|
||||
func (m *mockSecrets) Encrypt(s string) (string, error) { return s, nil }
|
||||
func (m *mockSecrets) Decrypt(s string) (string, error) { return s, nil }
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
database, err := db.NewTestDB(":memory:", &mockSecrets{})
|
||||
if err != nil {
|
||||
|
||||
294
server/internal/db/sessions_test.go
Normal file
294
server/internal/db/sessions_test.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"novamd/internal/db"
|
||||
"novamd/internal/models"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func TestSessionOperations(t *testing.T) {
|
||||
database, err := db.NewTestDB(":memory:", &mockSecrets{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test database: %v", err)
|
||||
}
|
||||
defer database.Close()
|
||||
|
||||
if err := database.Migrate(); err != nil {
|
||||
t.Fatalf("failed to run migrations: %v", err)
|
||||
}
|
||||
|
||||
// Create a test user first since sessions need a valid user ID
|
||||
user, err := database.CreateUser(&models.User{
|
||||
Email: "test@example.com",
|
||||
DisplayName: "Test User",
|
||||
PasswordHash: "hash",
|
||||
Role: "editor",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test user: %v", err)
|
||||
}
|
||||
|
||||
t.Run("CreateSession", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
session *models.Session
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "valid session",
|
||||
session: &models.Session{
|
||||
ID: uuid.New().String(),
|
||||
UserID: user.ID,
|
||||
RefreshToken: "valid-token",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid user ID",
|
||||
session: &models.Session{
|
||||
ID: uuid.New().String(),
|
||||
UserID: 99999, // Non-existent user ID
|
||||
RefreshToken: "invalid-user-token",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
},
|
||||
wantErr: true,
|
||||
errContains: "FOREIGN KEY constraint failed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := database.CreateSession(tc.session)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
} else if tc.errContains != "" && !strings.Contains(err.Error(), tc.errContains) {
|
||||
t.Errorf("error = %v, want error containing %v", err, tc.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Verify session was stored
|
||||
stored, err := database.GetSessionByRefreshToken(tc.session.RefreshToken)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to retrieve stored session: %v", err)
|
||||
}
|
||||
|
||||
// Compare fields
|
||||
if stored.ID != tc.session.ID {
|
||||
t.Errorf("ID = %v, want %v", stored.ID, tc.session.ID)
|
||||
}
|
||||
if stored.UserID != tc.session.UserID {
|
||||
t.Errorf("UserID = %v, want %v", stored.UserID, tc.session.UserID)
|
||||
}
|
||||
if stored.RefreshToken != tc.session.RefreshToken {
|
||||
t.Errorf("RefreshToken = %v, want %v", stored.RefreshToken, tc.session.RefreshToken)
|
||||
}
|
||||
// Compare times within a reasonable threshold
|
||||
if diff := stored.ExpiresAt.Sub(tc.session.ExpiresAt); diff > time.Second || diff < -time.Second {
|
||||
t.Errorf("ExpiresAt differs by %v, want difference less than 1s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetSessionByRefreshToken", func(t *testing.T) {
|
||||
// Create test sessions
|
||||
validSession := &models.Session{
|
||||
ID: uuid.New().String(),
|
||||
UserID: user.ID,
|
||||
RefreshToken: "valid-get-token",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
expiredSession := &models.Session{
|
||||
ID: uuid.New().String(),
|
||||
UserID: user.ID,
|
||||
RefreshToken: "expired-token",
|
||||
ExpiresAt: time.Now().Add(-1 * time.Hour), // Expired
|
||||
CreatedAt: time.Now().Add(-2 * time.Hour),
|
||||
}
|
||||
|
||||
if err := database.CreateSession(validSession); err != nil {
|
||||
t.Fatalf("failed to create valid session: %v", err)
|
||||
}
|
||||
if err := database.CreateSession(expiredSession); err != nil {
|
||||
t.Fatalf("failed to create expired session: %v", err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
refreshToken string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "valid token",
|
||||
refreshToken: "valid-get-token",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "expired token",
|
||||
refreshToken: "expired-token",
|
||||
wantErr: true,
|
||||
errContains: "session not found or expired",
|
||||
},
|
||||
{
|
||||
name: "non-existent token",
|
||||
refreshToken: "nonexistent-token",
|
||||
wantErr: true,
|
||||
errContains: "session not found",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
session, err := database.GetSessionByRefreshToken(tc.refreshToken)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
} else if tc.errContains != "" && !strings.Contains(err.Error(), tc.errContains) {
|
||||
t.Errorf("error = %v, want error containing %v", err, tc.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if session.RefreshToken != tc.refreshToken {
|
||||
t.Errorf("RefreshToken = %v, want %v", session.RefreshToken, tc.refreshToken)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("DeleteSession", func(t *testing.T) {
|
||||
session := &models.Session{
|
||||
ID: uuid.New().String(),
|
||||
UserID: user.ID,
|
||||
RefreshToken: "delete-token",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := database.CreateSession(session); err != nil {
|
||||
t.Fatalf("failed to create session: %v", err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
sessionID string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "valid session ID",
|
||||
sessionID: session.ID,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "non-existent session ID",
|
||||
sessionID: "nonexistent-id",
|
||||
wantErr: true,
|
||||
errContains: "session not found",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := database.DeleteSession(tc.sessionID)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
} else if tc.errContains != "" && !strings.Contains(err.Error(), tc.errContains) {
|
||||
t.Errorf("error = %v, want error containing %v", err, tc.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Verify session was deleted
|
||||
_, err = database.GetSessionByRefreshToken(session.RefreshToken)
|
||||
if err == nil {
|
||||
t.Error("session still exists after deletion")
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("CleanExpiredSessions", func(t *testing.T) {
|
||||
// Create a mix of valid and expired sessions
|
||||
sessions := []*models.Session{
|
||||
{
|
||||
ID: uuid.New().String(),
|
||||
UserID: user.ID,
|
||||
RefreshToken: "valid-clean-token",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
},
|
||||
{
|
||||
ID: uuid.New().String(),
|
||||
UserID: user.ID,
|
||||
RefreshToken: "expired-clean-token-1",
|
||||
ExpiresAt: time.Now().Add(-1 * time.Hour),
|
||||
CreatedAt: time.Now().Add(-2 * time.Hour),
|
||||
},
|
||||
{
|
||||
ID: uuid.New().String(),
|
||||
UserID: user.ID,
|
||||
RefreshToken: "expired-clean-token-2",
|
||||
ExpiresAt: time.Now().Add(-2 * time.Hour),
|
||||
CreatedAt: time.Now().Add(-3 * time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
for _, s := range sessions {
|
||||
if err := database.CreateSession(s); err != nil {
|
||||
t.Fatalf("failed to create session: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Clean expired sessions
|
||||
if err := database.CleanExpiredSessions(); err != nil {
|
||||
t.Fatalf("failed to clean expired sessions: %v", err)
|
||||
}
|
||||
|
||||
// Verify valid session still exists
|
||||
validSession, err := database.GetSessionByRefreshToken("valid-clean-token")
|
||||
if err != nil {
|
||||
t.Errorf("valid session was unexpectedly deleted: %v", err)
|
||||
}
|
||||
if validSession == nil {
|
||||
t.Error("valid session was unexpectedly deleted")
|
||||
}
|
||||
|
||||
// Verify expired sessions were deleted
|
||||
expiredTokens := []string{"expired-clean-token-1", "expired-clean-token-2"}
|
||||
for _, token := range expiredTokens {
|
||||
if _, err := database.GetSessionByRefreshToken(token); err == nil {
|
||||
t.Errorf("expired session with token %s still exists", token)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
213
server/internal/db/system_test.go
Normal file
213
server/internal/db/system_test.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"novamd/internal/db"
|
||||
"novamd/internal/models"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func TestSystemOperations(t *testing.T) {
|
||||
database, err := db.NewTestDB(":memory:", &mockSecrets{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test database: %v", err)
|
||||
}
|
||||
defer database.Close()
|
||||
|
||||
if err := database.Migrate(); err != nil {
|
||||
t.Fatalf("failed to run migrations: %v", err)
|
||||
}
|
||||
|
||||
t.Run("GetSystemSettings", func(t *testing.T) {
|
||||
t.Run("non-existent setting", func(t *testing.T) {
|
||||
_, err := database.GetSystemSetting("nonexistent-key")
|
||||
if err == nil {
|
||||
t.Error("expected error for non-existent key, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("existing setting", func(t *testing.T) {
|
||||
// First set a value
|
||||
err := database.SetSystemSetting("test-key", "test-value")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to set system setting: %v", err)
|
||||
}
|
||||
|
||||
// Then get it back
|
||||
value, err := database.GetSystemSetting("test-key")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get system setting: %v", err)
|
||||
}
|
||||
|
||||
if value != "test-value" {
|
||||
t.Errorf("got value %q, want %q", value, "test-value")
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("SetSystemSettings", func(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
key string
|
||||
value string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "new setting",
|
||||
key: "new-key",
|
||||
value: "new-value",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "update existing setting",
|
||||
key: "update-key",
|
||||
value: "original-value",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := database.SetSystemSetting(tc.key, tc.value)
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
} else if !strings.Contains(err.Error(), tc.errContains) {
|
||||
t.Errorf("error = %v, want error containing %v", err, tc.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Verify the setting was stored
|
||||
stored, err := database.GetSystemSetting(tc.key)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to retrieve stored setting: %v", err)
|
||||
}
|
||||
if stored != tc.value {
|
||||
t.Errorf("got value %q, want %q", stored, tc.value)
|
||||
}
|
||||
|
||||
// For the update case, test updating the value
|
||||
if tc.name == "update existing setting" {
|
||||
newValue := "updated-value"
|
||||
err := database.SetSystemSetting(tc.key, newValue)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to update setting: %v", err)
|
||||
}
|
||||
|
||||
stored, err := database.GetSystemSetting(tc.key)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to retrieve updated setting: %v", err)
|
||||
}
|
||||
if stored != newValue {
|
||||
t.Errorf("got updated value %q, want %q", stored, newValue)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("EnsureJWTSecret", func(t *testing.T) {
|
||||
// First call should generate a new secret
|
||||
secret1, err := database.EnsureJWTSecret()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to ensure JWT secret: %v", err)
|
||||
}
|
||||
|
||||
// Verify the secret is a valid base64-encoded string of sufficient length
|
||||
decoded, err := base64.StdEncoding.DecodeString(secret1)
|
||||
if err != nil {
|
||||
t.Errorf("secret is not valid base64: %v", err)
|
||||
}
|
||||
if len(decoded) < 32 {
|
||||
t.Errorf("secret length = %d, want >= 32", len(decoded))
|
||||
}
|
||||
|
||||
// Second call should return the same secret
|
||||
secret2, err := database.EnsureJWTSecret()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get existing JWT secret: %v", err)
|
||||
}
|
||||
|
||||
if secret2 != secret1 {
|
||||
t.Errorf("got different secret on second call")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetSystemStats", func(t *testing.T) {
|
||||
// Create some test users and sessions
|
||||
users := []*models.User{
|
||||
{
|
||||
Email: "user1@test.com",
|
||||
DisplayName: "User 1",
|
||||
PasswordHash: "hash1",
|
||||
Role: "editor",
|
||||
},
|
||||
{
|
||||
Email: "user2@test.com",
|
||||
DisplayName: "User 2",
|
||||
PasswordHash: "hash2",
|
||||
Role: "viewer",
|
||||
},
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
createdUser, err := database.CreateUser(u)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test user: %v", err)
|
||||
}
|
||||
|
||||
// Create multiple workspaces per user
|
||||
// Each user has one default workspace
|
||||
for i := 0; i < 2; i++ {
|
||||
workspace := &models.Workspace{
|
||||
UserID: createdUser.ID,
|
||||
Name: fmt.Sprintf("Workspace %d", i),
|
||||
}
|
||||
if err := database.CreateWorkspace(workspace); err != nil {
|
||||
t.Fatalf("failed to create test workspace: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create an active session for the first user
|
||||
if createdUser.Email == "user1@test.com" {
|
||||
session := &models.Session{
|
||||
ID: uuid.New().String(),
|
||||
UserID: createdUser.ID,
|
||||
RefreshToken: "test-token",
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if err := database.CreateSession(session); err != nil {
|
||||
t.Fatalf("failed to create test session: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := database.GetSystemStats()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get system stats: %v", err)
|
||||
}
|
||||
|
||||
// Verify stats
|
||||
if stats.TotalUsers != 2 {
|
||||
t.Errorf("TotalUsers = %d, want 2", stats.TotalUsers)
|
||||
}
|
||||
if stats.TotalWorkspaces != 6 { // 2 + 1 default workspace per user
|
||||
t.Errorf("TotalWorkspaces = %d, want 6", stats.TotalWorkspaces)
|
||||
}
|
||||
if stats.ActiveUsers != 1 { // Only user1 has an active session
|
||||
t.Errorf("ActiveUsers = %d, want 1", stats.ActiveUsers)
|
||||
}
|
||||
})
|
||||
}
|
||||
6
server/internal/db/testutil_test.go
Normal file
6
server/internal/db/testutil_test.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package db_test
|
||||
|
||||
type mockSecrets struct{}
|
||||
|
||||
func (m *mockSecrets) Encrypt(s string) (string, error) { return s, nil }
|
||||
func (m *mockSecrets) Decrypt(s string) (string, error) { return s, nil }
|
||||
@@ -69,7 +69,7 @@ func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) e
|
||||
theme, auto_save, show_hidden_files,
|
||||
git_enabled, git_url, git_user, git_token,
|
||||
git_auto_commit, git_commit_msg_template
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
workspace.UserID, workspace.Name,
|
||||
workspace.Theme, workspace.AutoSave, workspace.ShowHiddenFiles,
|
||||
workspace.GitEnabled, workspace.GitURL, workspace.GitUser, workspace.GitToken,
|
||||
|
||||
@@ -24,7 +24,7 @@ func (db *database) CreateWorkspace(workspace *models.Workspace) error {
|
||||
user_id, name, theme, auto_save, show_hidden_files,
|
||||
git_enabled, git_url, git_user, git_token,
|
||||
git_auto_commit, git_commit_msg_template
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
workspace.UserID, workspace.Name, workspace.Theme, workspace.AutoSave, workspace.ShowHiddenFiles,
|
||||
workspace.GitEnabled, workspace.GitURL, workspace.GitUser, encryptedToken,
|
||||
workspace.GitAutoCommit, workspace.GitCommitMsgTemplate,
|
||||
|
||||
Reference in New Issue
Block a user