diff --git a/server/internal/db/db.go b/server/internal/db/db.go index e171158..e4e873b 100644 --- a/server/internal/db/db.go +++ b/server/internal/db/db.go @@ -76,12 +76,18 @@ type SystemStore interface { SetSystemSetting(key, value string) error } +type StructScanner interface { + ScanStruct(row *sql.Row, dest interface{}) error + ScanStructs(rows *sql.Rows, dest interface{}) error +} + // Database defines the methods for interacting with the database type Database interface { UserStore WorkspaceStore SessionStore SystemStore + StructScanner Begin() (*sql.Tx, error) Close() error Migrate() error @@ -101,6 +107,7 @@ var ( // Sub-interfaces _ WorkspaceReader = (*database)(nil) _ WorkspaceWriter = (*database)(nil) + _ StructScanner = (*database)(nil) ) var logger logging.Logger diff --git a/server/internal/db/struct_query_test.go b/server/internal/db/struct_query_test.go new file mode 100644 index 0000000..7c28799 --- /dev/null +++ b/server/internal/db/struct_query_test.go @@ -0,0 +1,507 @@ +package db_test + +import ( + "reflect" + "testing" + "time" + + "lemma/internal/db" + "lemma/internal/models" + _ "lemma/internal/testenv" +) + +// TestStructTagsToFields tests the exported StructTagsToFields function +func TestStructTagsToFields(t *testing.T) { + type testStruct struct { + ID int `db:"id"` + Name string `db:"custom_name"` + CreatedAt time.Time `db:"created_at,default"` + Skip string `db:"-"` + Empty string `db:"empty,omitempty"` + Secret string `db:"secret,encrypted"` + NoTag string + } + + tests := []struct { + name string + input interface{} + wantFields int + wantErr bool + }{ + { + name: "valid struct", + input: testStruct{ + ID: 1, + Name: "Test", + CreatedAt: time.Now(), + Skip: "skip me", + Secret: "secret value", + NoTag: "no tag", + }, + wantFields: 5, // ID, Name, CreatedAt, Secret, NoTag (Empty is omitted) + wantErr: false, + }, + { + name: "nil pointer", + input: (*testStruct)(nil), + wantFields: 0, + wantErr: true, + }, + { + name: "non-struct", + input: "not a struct", + wantFields: 0, + wantErr: true, + }, + { + name: "struct pointer", + input: &testStruct{ + ID: 2, + Name: "Test Pointer", + }, + wantFields: 5, // Same fields as above + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fields, err := db.StructTagsToFields(tt.input) + + if tt.wantErr { + if err == nil { + t.Error("Expected error, got nil") + } + return + } + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if len(fields) != tt.wantFields { + t.Errorf("Expected %d fields, got %d", tt.wantFields, len(fields)) + } + + // Check specific field handling for valid struct test + if tt.name == "valid struct" { + // Find fields by name + var idField, nameField, createdAtField, secretField, emptyField, noTagField *db.DBField + for i := range fields { + f := &fields[i] + switch f.Name { + case "id": + idField = f + case "custom_name": + nameField = f + case "created_at": + createdAtField = f + case "secret": + secretField = f + case "empty": + emptyField = f + case "no_tag": + noTagField = f + } + } + + // Check fields exist + if idField == nil { + t.Error("ID field not found") + } + if nameField == nil { + t.Error("Name field not found") + } + if createdAtField == nil { + t.Error("CreatedAt field not found") + } + if secretField == nil { + t.Error("Secret field not found") + } + if noTagField == nil { + t.Error("NoTag field not found") + } + if emptyField != nil { + t.Error("Empty field should be omitted") + } + + // Check original names + if idField != nil && idField.OriginalName != "ID" { + t.Errorf("Expected OriginalName 'ID', got '%s'", idField.OriginalName) + } + if nameField != nil && nameField.OriginalName != "Name" { + t.Errorf("Expected OriginalName 'Name', got '%s'", nameField.OriginalName) + } + } + }) + } +} + +// TestStructQueries tests the struct-based query methods using the test database +func TestStructQueries(t *testing.T) { + // Setup test database + database, err := db.NewTestDB(&mockSecrets{}) + if err != nil { + t.Fatalf("Failed to create test database: %v", err) + } + defer database.Close() + + if err := database.Migrate(); err != nil { + t.Fatalf("Failed to run migrations: %v", err) + } + + // Define test data + user := &models.User{ + Email: "structquery@example.com", + DisplayName: "Struct Query Test", + PasswordHash: "hashed_password", + Role: models.RoleEditor, + } + + t.Run("InsertStructQuery", func(t *testing.T) { + // Insert user with struct query + createdUser, err := database.CreateUser(user) + if err != nil { + t.Fatalf("Failed to create user with struct query: %v", err) + } + + // Verify user was created with proper values + if createdUser.ID == 0 { + t.Error("Expected non-zero user ID") + } + + if createdUser.Email != user.Email { + t.Errorf("Email = %v, want %v", createdUser.Email, user.Email) + } + + if createdUser.DisplayName != user.DisplayName { + t.Errorf("DisplayName = %v, want %v", createdUser.DisplayName, user.DisplayName) + } + + if createdUser.Role != user.Role { + t.Errorf("Role = %v, want %v", createdUser.Role, user.Role) + } + + // We will use this user for the next test cases + user = createdUser + }) + + t.Run("SelectStructQuery", func(t *testing.T) { + // Get the created user + fetchedUser, err := database.GetUserByID(user.ID) + if err != nil { + t.Fatalf("Failed to get user with struct query: %v", err) + } + + // Verify fetched user matches the original + if fetchedUser.ID != user.ID { + t.Errorf("ID = %v, want %v", fetchedUser.ID, user.ID) + } + + if fetchedUser.Email != user.Email { + t.Errorf("Email = %v, want %v", fetchedUser.Email, user.Email) + } + + if fetchedUser.DisplayName != user.DisplayName { + t.Errorf("DisplayName = %v, want %v", fetchedUser.DisplayName, user.DisplayName) + } + + if fetchedUser.Role != user.Role { + t.Errorf("Role = %v, want %v", fetchedUser.Role, user.Role) + } + }) + + t.Run("UpdateStructQuery", func(t *testing.T) { + // Update the user + user.DisplayName = "Updated Display Name" + user.Role = models.RoleAdmin + + err := database.UpdateUser(user) + if err != nil { + t.Fatalf("Failed to update user with struct query: %v", err) + } + + // Verify update worked + updatedUser, err := database.GetUserByID(user.ID) + if err != nil { + t.Fatalf("Failed to get updated user: %v", err) + } + + if updatedUser.DisplayName != "Updated Display Name" { + t.Errorf("DisplayName = %v, want %v", updatedUser.DisplayName, "Updated Display Name") + } + + if updatedUser.Role != models.RoleAdmin { + t.Errorf("Role = %v, want %v", updatedUser.Role, models.RoleAdmin) + } + }) + + t.Run("ScanStructs", func(t *testing.T) { + // Create another user to test multiple rows + secondUser := &models.User{ + Email: "structquery2@example.com", + DisplayName: "Struct Query Test 2", + PasswordHash: "hashed_password2", + Role: models.RoleViewer, + } + + createdUser2, err := database.CreateUser(secondUser) + if err != nil { + t.Fatalf("Failed to create second user: %v", err) + } + + // Get all users + users, err := database.GetAllUsers() + if err != nil { + t.Fatalf("Failed to get all users: %v", err) + } + + // Verify we have at least the two users we created + if len(users) < 2 { + t.Errorf("Expected at least 2 users, got %d", len(users)) + } + + // Check if both our users are in the result + foundUser1 := false + foundUser2 := false + + for _, u := range users { + if u.ID == user.ID { + foundUser1 = true + if u.DisplayName != user.DisplayName { + t.Errorf("DisplayName = %v, want %v", u.DisplayName, user.DisplayName) + } + } + if u.ID == createdUser2.ID { + foundUser2 = true + if u.DisplayName != secondUser.DisplayName { + t.Errorf("DisplayName = %v, want %v", u.DisplayName, secondUser.DisplayName) + } + } + } + + if !foundUser1 { + t.Errorf("First user (ID: %d) not found in results", user.ID) + } + if !foundUser2 { + t.Errorf("Second user (ID: %d) not found in results", createdUser2.ID) + } + }) + + t.Run("ScanStruct with null values", func(t *testing.T) { + // Test handling of NULL values by creating a workspace with null values + workspace := &models.Workspace{ + UserID: user.ID, + Name: "Null Test Workspace", + // Leave all optional fields as zero values + } + workspace.SetDefaultSettings() // This will set default values + + err := database.CreateWorkspace(workspace) + if err != nil { + t.Fatalf("Failed to create test workspace: %v", err) + } + + // Clear the GitToken to test NULL handling + testDB := database.TestDB() + _, err = testDB.Exec("UPDATE workspaces SET git_token = NULL WHERE id = ?", workspace.ID) + if err != nil { + t.Fatalf("Failed to set git_token to NULL: %v", err) + } + + // Fetch the workspace with NULL field + fetchedWorkspace, err := database.GetWorkspaceByID(workspace.ID) + if err != nil { + t.Fatalf("Failed to get workspace with NULL field: %v", err) + } + + // Verify the NULL field is empty + if fetchedWorkspace.GitToken != "" { + t.Errorf("Expected empty GitToken, got '%s'", fetchedWorkspace.GitToken) + } + }) + + t.Run("ScanStructErrors", func(t *testing.T) { + // Test error handling in ScanStruct + testDB := database.TestDB() + + // Attempt to scan too many columns into a struct with fewer fields + row := testDB.QueryRow("SELECT 1, 2, 3") + var singleField struct { + One int `db:"one"` + } + + err := database.ScanStruct(row, &singleField) + if err == nil { + t.Error("Expected error when scanning too many columns, got nil") + } + + // Test scanning into a non-struct + var notAStruct int + row = testDB.QueryRow("SELECT 1") + err = database.ScanStruct(row, ¬AStruct) + if err == nil { + t.Error("Expected error when scanning into non-struct, got nil") + } + + // Test scanning into nil + var nilPtr *struct{} + row = testDB.QueryRow("SELECT 1") + err = database.ScanStruct(row, nilPtr) + if err == nil { + t.Error("Expected error when scanning into nil pointer, got nil") + } + }) +} + +// TestScanStructsErrors tests error handling for ScanStructs +func TestScanStructsErrors(t *testing.T) { + database, err := db.NewTestDB(&mockSecrets{}) + if err != nil { + t.Fatalf("Failed to create test database: %v", err) + } + defer database.Close() + + if err := database.Migrate(); err != nil { + t.Fatalf("Failed to run migrations: %v", err) + } + + testDB := database.TestDB() + + t.Run("ScanStructsWithNilRows", func(t *testing.T) { + var users []*models.User + err := database.ScanStructs(nil, &users) + if err == nil { + t.Error("Expected error with nil rows, got nil") + } + }) + + t.Run("ScanStructsWithNilDest", func(t *testing.T) { + rows, err := testDB.Query("SELECT 1") + if err != nil { + t.Fatalf("Failed to execute query: %v", err) + } + defer rows.Close() + + var nilSlice *[]*models.User + err = database.ScanStructs(rows, nilSlice) + if err == nil { + t.Error("Expected error with nil destination, got nil") + } + }) + + t.Run("ScanStructsWithNonSlice", func(t *testing.T) { + rows, err := testDB.Query("SELECT 1") + if err != nil { + t.Fatalf("Failed to execute query: %v", err) + } + defer rows.Close() + + var nonSlice int + err = database.ScanStructs(rows, &nonSlice) + if err == nil { + t.Error("Expected error with non-slice destination, got nil") + } + }) + + t.Run("ScanStructsWithNonStructSlice", func(t *testing.T) { + rows, err := testDB.Query("SELECT 1") + if err != nil { + t.Fatalf("Failed to execute query: %v", err) + } + defer rows.Close() + + var intSlice []int + err = database.ScanStructs(rows, &intSlice) + if err == nil { + t.Error("Expected error with non-struct slice, got nil") + } + }) +} + +// TestEncryptedFields tests handling of encrypted fields +func TestEncryptedFields(t *testing.T) { + database, err := db.NewTestDB(&mockSecrets{}) + if err != nil { + t.Fatalf("Failed to create test database: %v", err) + } + defer database.Close() + + if err := database.Migrate(); err != nil { + t.Fatalf("Failed to run migrations: %v", err) + } + + // Create user with workspace that has encrypted token + user, err := database.CreateUser(&models.User{ + Email: "encrypted@example.com", + DisplayName: "Encryption Test", + PasswordHash: "hash", + Role: models.RoleEditor, + }) + if err != nil { + t.Fatalf("Failed to create test user: %v", err) + } + + // Create workspace with encrypted field + workspace := &models.Workspace{ + UserID: user.ID, + Name: "Encryption Test", + Theme: "dark", + GitEnabled: true, + GitURL: "https://github.com/user/repo", + GitUser: "username", + GitToken: "secret-token", // This field is encrypted + GitCommitName: "Test User", + GitCommitEmail: "test@example.com", + } + + if err := database.CreateWorkspace(workspace); err != nil { + t.Fatalf("Failed to create test workspace: %v", err) + } + + // Verify our mock secrets service passed the token through unmodified + // In a real application, the token would be encrypted in the database + testDB := database.TestDB() + var rawToken string + err = testDB.QueryRow("SELECT git_token FROM workspaces WHERE id = ?", workspace.ID).Scan(&rawToken) + if err != nil { + t.Fatalf("Failed to query raw token: %v", err) + } + + // With the mock secrets service, encryption is a no-op so the token is stored as-is + if rawToken != "secret-token" { + t.Errorf("Expected raw token 'secret-token', got '%s'", rawToken) + } + + // Verify the fetched workspace has the correct token + fetchedWorkspace, err := database.GetWorkspaceByID(workspace.ID) + if err != nil { + t.Fatalf("Failed to get workspace: %v", err) + } + + if fetchedWorkspace.GitToken != "secret-token" { + t.Errorf("Expected GitToken 'secret-token', got '%s'", fetchedWorkspace.GitToken) + } +} + +// Helper function to compare slices of DBFields +func compareDBFields(t *testing.T, got, want []db.DBField) { + t.Helper() + + if len(got) != len(want) { + t.Errorf("Got %d fields, want %d", len(got), len(want)) + return + } + + for i := range got { + if got[i].Name != want[i].Name { + t.Errorf("Field %d name: got %s, want %s", i, got[i].Name, want[i].Name) + } + if got[i].OriginalName != want[i].OriginalName { + t.Errorf("Field %d original name: got %s, want %s", i, got[i].OriginalName, want[i].OriginalName) + } + if !reflect.DeepEqual(got[i].Value, want[i].Value) { + t.Errorf("Field %d value: got %v, want %v", i, got[i].Value, want[i].Value) + } + } +}