Implement scan struct

This commit is contained in:
2025-03-02 21:54:10 +01:00
parent ccac439465
commit 5fd9755f12

View File

@@ -1,6 +1,7 @@
package db
import (
"database/sql"
"fmt"
"reflect"
"strings"
@@ -11,6 +12,7 @@ type DBField struct {
Name string
Value any
Type reflect.Type
OriginalName string
useDefault bool
encrypted bool
}
@@ -74,6 +76,7 @@ func StructTagsToFields(s any) ([]DBField, error) {
Name: tag,
Value: v.Field(i).Interface(),
Type: f.Type,
OriginalName: f.Name,
useDefault: useDefault,
encrypted: encrypted,
})
@@ -172,3 +175,64 @@ func (q *Query) UpdateStruct(s any, table string, where []string, args []any) (*
return q, nil
}
func (db *database) ScanStruct(row *sql.Row, dest any) error {
// Get the fields of the destination struct
fields, err := StructTagsToFields(dest)
if err != nil {
return fmt.Errorf("failed to extract struct fields: %w", err)
}
// Create a slice of pointers to hold the scan destinations
scanDest := make([]interface{}, len(fields))
destVal := reflect.ValueOf(dest).Elem()
var fieldsToDecrypt []string
nullStringIndexes := make(map[int]reflect.Value)
for i, field := range fields {
// Find the field in the struct
structField := destVal.FieldByName(field.OriginalName)
if !structField.IsValid() {
return fmt.Errorf("struct field %s not found", field.OriginalName)
}
if field.encrypted {
fieldsToDecrypt = append(fieldsToDecrypt, field.OriginalName)
}
if structField.Kind() == reflect.String {
nullStringIndexes[i] = structField
var ns sql.NullString
scanDest[i] = &ns
} else {
scanDest[i] = structField.Addr().Interface()
}
}
// Scan the row into the destination pointers
if err := row.Scan(scanDest...); err != nil {
return err
}
// Set null strings to nil if they are null
for i, field := range nullStringIndexes {
ns := scanDest[i].(*sql.NullString)
if ns.Valid {
field.SetString(ns.String)
}
}
// Decrypt encrypted fields
for _, fieldName := range fieldsToDecrypt {
field := destVal.FieldByName(fieldName)
decValue, err := db.secretsService.Decrypt(field.Interface().(string))
if err != nil {
return err
}
field.SetString(decValue)
}
return nil
}