From c0de3538dcb7ebfeb6dd617ec4743e5b68ee1dd1 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Thu, 27 Feb 2025 21:16:43 +0100 Subject: [PATCH] Implement insert struct --- server/internal/db/scanner.go | 288 ---------------------------- server/internal/db/scanner_test.go | 298 ----------------------------- server/internal/db/struct_query.go | 99 ++++++++++ 3 files changed, 99 insertions(+), 586 deletions(-) delete mode 100644 server/internal/db/scanner.go delete mode 100644 server/internal/db/scanner_test.go create mode 100644 server/internal/db/struct_query.go diff --git a/server/internal/db/scanner.go b/server/internal/db/scanner.go deleted file mode 100644 index f50c296..0000000 --- a/server/internal/db/scanner.go +++ /dev/null @@ -1,288 +0,0 @@ -package db - -import ( - "database/sql" - "fmt" - "reflect" - "regexp" - "strings" -) - -// Scanner provides methods for scanning rows into structs -type Scanner struct { - db *sql.DB - dbType DBType -} - -// NewScanner creates a new Scanner instance -func NewScanner(db *sql.DB, dbType DBType) *Scanner { - return &Scanner{ - db: db, - dbType: dbType, - } -} - -// QueryRow executes a query and scans the result into a struct -func (s *Scanner) QueryRow(dest any, q *Query) error { - row := s.db.QueryRow(q.String(), q.Args()...) - - // Handle primitive types - v := reflect.ValueOf(dest) - if v.Kind() != reflect.Ptr { - return fmt.Errorf("dest must be a pointer") - } - - elem := v.Elem() - switch elem.Kind() { - case reflect.Struct: - return scanStruct(row, dest) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, - reflect.Float32, reflect.Float64, reflect.Bool, reflect.String: - return row.Scan(dest) - default: - return fmt.Errorf("unsupported dest type: %T", dest) - } -} - -// Query executes a query and scans multiple results into a slice of structs -func (s *Scanner) Query(dest any, q *Query) error { - rows, err := s.db.Query(q.String(), q.Args()...) - if err != nil { - return err - } - defer rows.Close() - - return scanStructs(rows, dest) -} - -// scanStruct scans a single row into a struct -func scanStruct(row *sql.Row, dest any) error { - v := reflect.ValueOf(dest) - if v.Kind() != reflect.Ptr { - return fmt.Errorf("dest must be a pointer") - } - v = v.Elem() - if v.Kind() != reflect.Struct { - return fmt.Errorf("dest must be a pointer to a struct") - } - - fields := make([]any, 0, v.NumField()) - - for i := 0; i < v.NumField(); i++ { - field := v.Field(i) - if field.CanSet() { - fields = append(fields, field.Addr().Interface()) - } - } - - return row.Scan(fields...) -} - -// scanStructs scans multiple rows into a slice of structs -func scanStructs(rows *sql.Rows, dest any) error { - v := reflect.ValueOf(dest) - if v.Kind() != reflect.Ptr { - return fmt.Errorf("dest must be a pointer") - } - - sliceVal := v.Elem() - if sliceVal.Kind() != reflect.Slice { - return fmt.Errorf("dest must be a pointer to a slice") - } - - elemType := sliceVal.Type().Elem() - - for rows.Next() { - newElem := reflect.New(elemType).Elem() - fields := make([]any, 0, newElem.NumField()) - - for i := 0; i < newElem.NumField(); i++ { - field := newElem.Field(i) - if field.CanSet() { - fields = append(fields, field.Addr().Interface()) - } - } - - if err := rows.Scan(fields...); err != nil { - return err - } - - sliceVal.Set(reflect.Append(sliceVal, newElem)) - } - - return rows.Err() -} - -// ScannerEx is an extended version of Scanner with more features -type ScannerEx struct { - db *sql.DB - dbType DBType -} - -// NewScannerEx creates a new ScannerEx instance -func NewScannerEx(db *sql.DB, dbType DBType) *ScannerEx { - return &ScannerEx{ - db: db, - dbType: dbType, - } -} - -// QueryRow executes a query and scans the result into a struct -func (s *ScannerEx) QueryRow(dest any, q *Query) error { - row := s.db.QueryRow(q.String(), q.Args()...) - - // Get column names - // Note: This is a workaround since sql.Row doesn't expose column names. - // In a real implementation, you'd likely need to execute the query to get columns first. - // For simplicity, we'll infer them from the struct tags. - - return scanStructTags(row, dest) -} - -// Query executes a query and scans multiple results into a slice of structs -func (s *ScannerEx) Query(dest any, q *Query) error { - rows, err := s.db.Query(q.String(), q.Args()...) - if err != nil { - return err - } - defer rows.Close() - - return scanStructsTags(rows, dest) -} - -// getFieldMap builds a map of db column names to struct fields using struct tags -func getFieldMap(t reflect.Type) map[string]int { - fieldMap := make(map[string]int) - - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - - // Check db tag first - tag := field.Tag.Get("db") - if tag != "" && tag != "-" { - fieldMap[tag] = i - continue - } - - // Check json tag next - tag = field.Tag.Get("json") - if tag != "" && tag != "-" { - // Handle json tag options like omitempty - parts := strings.Split(tag, ",") - fieldMap[parts[0]] = i - continue - } - - // Default to field name with snake_case conversion - fieldMap[toSnakeCase(field.Name)] = i - } - - return fieldMap -} - -var camelRegex = regexp.MustCompile(`([a-z0-9])([A-Z])`) - -// toSnakeCase converts a camelCase string to snake_case -func toSnakeCase(s string) string { - return strings.ToLower(camelRegex.ReplaceAllString(s, "${1}_${2}")) -} - -// scanStructTags scans a single row into a struct using field tags -func scanStructTags(row *sql.Row, dest any) error { - v := reflect.ValueOf(dest) - if v.Kind() != reflect.Ptr { - return fmt.Errorf("dest must be a pointer") - } - v = v.Elem() - if v.Kind() != reflect.Struct { - return fmt.Errorf("dest must be a pointer to a struct") - } - - fields := make([]any, 0, v.NumField()) - - for i := 0; i < v.NumField(); i++ { - field := v.Field(i) - if field.CanSet() { - fields = append(fields, field.Addr().Interface()) - } - } - - return row.Scan(fields...) -} - -// scanStructsTags scans multiple rows into a slice of structs using field tags -func scanStructsTags(rows *sql.Rows, dest any) error { - v := reflect.ValueOf(dest) - if v.Kind() != reflect.Ptr { - return fmt.Errorf("dest must be a pointer") - } - - sliceVal := v.Elem() - if sliceVal.Kind() != reflect.Slice { - return fmt.Errorf("dest must be a pointer to a slice") - } - - elemType := sliceVal.Type().Elem() - isPtr := elemType.Kind() == reflect.Ptr - if isPtr { - elemType = elemType.Elem() - } - - if elemType.Kind() != reflect.Struct { - return fmt.Errorf("dest must be a pointer to a slice of structs") - } - - // Get column names - columns, err := rows.Columns() - if err != nil { - return err - } - - // Build field mapping - fieldMap := getFieldMap(elemType) - - // Prepare values slice for each scan - values := make([]any, len(columns)) - scanFields := make([]any, len(columns)) - for i := range values { - scanFields[i] = &values[i] - } - - for rows.Next() { - // Create a new struct instance - newElem := reflect.New(elemType).Elem() - - // Scan row into values - if err := rows.Scan(scanFields...); err != nil { - return err - } - - // Map values to struct fields - for i, colName := range columns { - if fieldIndex, ok := fieldMap[colName]; ok { - field := newElem.Field(fieldIndex) - if field.CanSet() { - val := reflect.ValueOf(values[i]) - if val.Elem().Kind() == reflect.Interface { - val = val.Elem() - } - if val.Kind() == reflect.Ptr && !val.IsNil() { - field.Set(val.Elem()) - } else if !val.IsNil() { - field.Set(val) - } - } - } - } - - // Append to result slice - if isPtr { - sliceVal.Set(reflect.Append(sliceVal, newElem.Addr())) - } else { - sliceVal.Set(reflect.Append(sliceVal, newElem)) - } - } - - return rows.Err() -} diff --git a/server/internal/db/scanner_test.go b/server/internal/db/scanner_test.go deleted file mode 100644 index a1fbc28..0000000 --- a/server/internal/db/scanner_test.go +++ /dev/null @@ -1,298 +0,0 @@ -package db_test - -import ( - "database/sql" - "testing" - "time" - - "lemma/internal/db" -) - -func TestScannerQueryRow(t *testing.T) { - mockSecrets := &mockSecretsService{} - testDB, err := db.NewTestDB(mockSecrets) - if err != nil { - t.Fatalf("Failed to create test database: %v", err) - } - defer testDB.Close() - - // Create a test table - _, err = testDB.TestDB().Exec(` - CREATE TABLE users ( - id INTEGER PRIMARY KEY, - email TEXT NOT NULL, - created_at TIMESTAMP NOT NULL, - active BOOLEAN NOT NULL - ) - `) - if err != nil { - t.Fatalf("Failed to create test table: %v", err) - } - - type User struct { - ID int - Email string - CreatedAt time.Time - Active bool - } - - // Insert test data - now := time.Now().UTC().Truncate(time.Second) - _, err = testDB.TestDB().Exec( - "INSERT INTO users (id, email, created_at, active) VALUES (?, ?, ?, ?)", - 1, "test@example.com", now, true, - ) - if err != nil { - t.Fatalf("Failed to insert test data: %v", err) - } - - // Test query row success - t.Run("QueryRow success", func(t *testing.T) { - scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) - q := db.NewQuery(db.DBTypeSQLite) - q.Select("id", "email", "created_at", "active"). - From("users"). - Where("id = "). - Placeholder(1) - - var user User - err := scanner.QueryRow(&user, q) - - if err != nil { - t.Errorf("Expected no error, got %v", err) - } - - if user.ID != 1 { - t.Errorf("Expected ID 1, got %d", user.ID) - } - if user.Email != "test@example.com" { - t.Errorf("Expected Email test@example.com, got %s", user.Email) - } - if !user.CreatedAt.Equal(now) { - t.Errorf("Expected CreatedAt %v, got %v", now, user.CreatedAt) - } - if !user.Active { - t.Errorf("Expected Active true, got %v", user.Active) - } - }) - - // Test query row no results - t.Run("QueryRow no results", func(t *testing.T) { - scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) - q := db.NewQuery(db.DBTypeSQLite) - q.Select("id", "email", "created_at", "active"). - From("users"). - Where("id = "). - Placeholder(999) - - var user User - err := scanner.QueryRow(&user, q) - - if err != sql.ErrNoRows { - t.Errorf("Expected ErrNoRows, got %v", err) - } - }) - - // Test scanning a single value - t.Run("QueryRow single value", func(t *testing.T) { - scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) - q := db.NewQuery(db.DBTypeSQLite) - q.Select("COUNT(*)").From("users") - - var count int - err := scanner.QueryRow(&count, q) - - if err != nil { - t.Errorf("Expected no error, got %v", err) - } - - if count != 1 { - t.Errorf("Expected count 1, got %d", count) - } - }) -} - -func TestScannerQuery(t *testing.T) { - mockSecrets := &mockSecretsService{} - testDB, err := db.NewTestDB(mockSecrets) - if err != nil { - t.Fatalf("Failed to create test database: %v", err) - } - defer testDB.Close() - - // Create a test table - _, err = testDB.TestDB().Exec(` - CREATE TABLE users ( - id INTEGER PRIMARY KEY, - email TEXT NOT NULL, - created_at TIMESTAMP NOT NULL, - active BOOLEAN NOT NULL - ) - `) - if err != nil { - t.Fatalf("Failed to create test table: %v", err) - } - - type User struct { - ID int - Email string - CreatedAt time.Time - Active bool - } - - // Insert test data - now := time.Now().UTC().Truncate(time.Second) - testUsers := []User{ - {ID: 1, Email: "user1@example.com", CreatedAt: now, Active: true}, - {ID: 2, Email: "user2@example.com", CreatedAt: now, Active: false}, - {ID: 3, Email: "user3@example.com", CreatedAt: now, Active: true}, - } - - for _, user := range testUsers { - _, err = testDB.TestDB().Exec( - "INSERT INTO users (id, email, created_at, active) VALUES (?, ?, ?, ?)", - user.ID, user.Email, user.CreatedAt, user.Active, - ) - if err != nil { - t.Fatalf("Failed to insert test data: %v", err) - } - } - - // Test query multiple rows - t.Run("Query multiple rows", func(t *testing.T) { - scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) - q := db.NewQuery(db.DBTypeSQLite) - q.Select("id", "email", "created_at", "active"). - From("users"). - OrderBy("id ASC") - - var users []User - err := scanner.Query(&users, q) - - if err != nil { - t.Errorf("Expected no error, got %v", err) - } - - if len(users) != len(testUsers) { - t.Errorf("Expected %d users, got %d", len(testUsers), len(users)) - } - - for i, u := range users { - if u.ID != testUsers[i].ID { - t.Errorf("Expected user[%d].ID %d, got %d", i, testUsers[i].ID, u.ID) - } - if u.Email != testUsers[i].Email { - t.Errorf("Expected user[%d].Email %s, got %s", i, testUsers[i].Email, u.Email) - } - if !u.CreatedAt.Equal(testUsers[i].CreatedAt) { - t.Errorf("Expected user[%d].CreatedAt %v, got %v", i, testUsers[i].CreatedAt, u.CreatedAt) - } - if u.Active != testUsers[i].Active { - t.Errorf("Expected user[%d].Active %v, got %v", i, testUsers[i].Active, u.Active) - } - } - }) - - // Test query with filter - t.Run("Query with filter", func(t *testing.T) { - scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) - q := db.NewQuery(db.DBTypeSQLite) - q.Select("id", "email", "created_at", "active"). - From("users"). - Where("active = "). - Placeholder(true). - OrderBy("id ASC") - - var users []User - err := scanner.Query(&users, q) - - if err != nil { - t.Errorf("Expected no error, got %v", err) - } - - if len(users) != 2 { - t.Errorf("Expected 2 users, got %d", len(users)) - } - - for _, u := range users { - if !u.Active { - t.Errorf("Expected only active users, got inactive user: %+v", u) - } - } - }) - - // Test query empty result - t.Run("Query empty result", func(t *testing.T) { - scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) - q := db.NewQuery(db.DBTypeSQLite) - q.Select("id", "email", "created_at", "active"). - From("users"). - Where("id > "). - Placeholder(100) - - var users []User - err := scanner.Query(&users, q) - - if err != nil { - t.Errorf("Expected no error, got %v", err) - } - - if len(users) != 0 { - t.Errorf("Expected 0 users, got %d", len(users)) - } - }) -} - -func TestScanErrors(t *testing.T) { - mockSecrets := &mockSecretsService{} - testDB, err := db.NewTestDB(mockSecrets) - if err != nil { - t.Fatalf("Failed to create test database: %v", err) - } - defer testDB.Close() - - scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) - q := db.NewQuery(db.DBTypeSQLite) - q.Select("1") - - // Test non-pointer - t.Run("QueryRow non-pointer", func(t *testing.T) { - var user struct{} - err := scanner.QueryRow(user, q) // Passing non-pointer - - if err == nil { - t.Error("Expected error for non-pointer, got nil") - } - }) - - // Test pointer to non-slice for Query - t.Run("Query pointer to non-slice", func(t *testing.T) { - var user struct{} - err := scanner.Query(&user, q) // Passing pointer to struct, not slice - - if err == nil { - t.Error("Expected error for non-slice pointer, got nil") - } - }) - - // Test non-pointer for Query - t.Run("Query non-pointer", func(t *testing.T) { - var users []struct{} - err := scanner.Query(users, q) // Passing non-pointer - - if err == nil { - t.Error("Expected error for non-pointer, got nil") - } - }) -} - -// Mock secrets service for testing -type mockSecretsService struct{} - -func (m *mockSecretsService) Encrypt(plaintext string) (string, error) { - return plaintext, nil -} - -func (m *mockSecretsService) Decrypt(ciphertext string) (string, error) { - return ciphertext, nil -} diff --git a/server/internal/db/struct_query.go b/server/internal/db/struct_query.go new file mode 100644 index 0000000..b465b52 --- /dev/null +++ b/server/internal/db/struct_query.go @@ -0,0 +1,99 @@ +package db + +import ( + "fmt" + "reflect" + "strings" + "unicode" +) + +type DBField struct { + Name string + Value any + Type reflect.Type +} + +func StructTagsToFields(s any) ([]DBField, error) { + v := reflect.ValueOf(s) + + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return nil, fmt.Errorf("nil pointer provided") + } + + v = v.Elem() + } + + if v.Kind() != reflect.Struct { + return nil, fmt.Errorf("provided value is %s, expected struct", v.Kind()) + } + + t := v.Type() + fields := make([]DBField, 0, t.NumField()) + + for i := range t.NumField() { + f := t.Field(i) + + if !f.IsExported() { + continue + } + + tag := f.Tag.Get("db") + if tag == "-" { + continue + } + + if tag == "" { + tag = toSnakeCase(f.Name) + } + + if strings.Contains(tag, "omitempty") && reflect.DeepEqual(v.Field(i).Interface(), reflect.Zero(f.Type).Interface()) { + continue + } + + fields = append(fields, DBField{ + Name: tag, + Value: v.Field(i).Interface(), + Type: f.Type, + }) + } + return fields, nil +} + +func toSnakeCase(s string) string { + var res string + + for i, r := range s { + if unicode.IsUpper(r) { + if i > 0 { + res += "_" + } + res += string(unicode.ToLower(r)) + } else { + res += string(r) + } + } + return res +} + +func (q *Query) InsertStruct(s any, table string) (*Query, error) { + fields, err := StructTagsToFields(s) + if err != nil { + return nil, err + } + + columns := make([]string, 0, len(fields)) + values := make([]any, 0, len(fields)) + + for _, f := range fields { + columns = append(columns, f.Name) + values = append(values, f.Value) + } + + if len(columns) == 0 { + return nil, fmt.Errorf("no columns to insert") + } + + q.Insert(table, columns...).Values(len(columns)).AddArgs(values...) + return q, nil +}