From 5fd9755f120293d90f352dbd654c40e66fda4500 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sun, 2 Mar 2025 21:54:10 +0100 Subject: [PATCH] Implement scan struct --- server/internal/db/struct_query.go | 84 ++++++++++++++++++++++++++---- 1 file changed, 74 insertions(+), 10 deletions(-) diff --git a/server/internal/db/struct_query.go b/server/internal/db/struct_query.go index efa6321..3484307 100644 --- a/server/internal/db/struct_query.go +++ b/server/internal/db/struct_query.go @@ -1,6 +1,7 @@ package db import ( + "database/sql" "fmt" "reflect" "strings" @@ -8,11 +9,12 @@ import ( ) type DBField struct { - Name string - Value any - Type reflect.Type - useDefault bool - encrypted bool + Name string + Value any + Type reflect.Type + OriginalName string + useDefault bool + encrypted bool } func StructTagsToFields(s any) ([]DBField, error) { @@ -71,11 +73,12 @@ func StructTagsToFields(s any) ([]DBField, error) { } fields = append(fields, DBField{ - Name: tag, - Value: v.Field(i).Interface(), - Type: f.Type, - useDefault: useDefault, - encrypted: encrypted, + Name: tag, + Value: v.Field(i).Interface(), + Type: f.Type, + OriginalName: f.Name, + useDefault: useDefault, + encrypted: encrypted, }) } return fields, nil @@ -172,3 +175,64 @@ func (q *Query) UpdateStruct(s any, table string, where []string, args []any) (* return q, nil } + +func (db *database) ScanStruct(row *sql.Row, dest any) error { + // Get the fields of the destination struct + fields, err := StructTagsToFields(dest) + if err != nil { + return fmt.Errorf("failed to extract struct fields: %w", err) + } + + // Create a slice of pointers to hold the scan destinations + scanDest := make([]interface{}, len(fields)) + destVal := reflect.ValueOf(dest).Elem() + + var fieldsToDecrypt []string + nullStringIndexes := make(map[int]reflect.Value) + + for i, field := range fields { + // Find the field in the struct + structField := destVal.FieldByName(field.OriginalName) + if !structField.IsValid() { + return fmt.Errorf("struct field %s not found", field.OriginalName) + } + + if field.encrypted { + fieldsToDecrypt = append(fieldsToDecrypt, field.OriginalName) + } + + if structField.Kind() == reflect.String { + nullStringIndexes[i] = structField + + var ns sql.NullString + scanDest[i] = &ns + } else { + scanDest[i] = structField.Addr().Interface() + } + } + + // Scan the row into the destination pointers + if err := row.Scan(scanDest...); err != nil { + return err + } + + // Set null strings to nil if they are null + for i, field := range nullStringIndexes { + ns := scanDest[i].(*sql.NullString) + if ns.Valid { + field.SetString(ns.String) + } + } + + // Decrypt encrypted fields + for _, fieldName := range fieldsToDecrypt { + field := destVal.FieldByName(fieldName) + decValue, err := db.secretsService.Decrypt(field.Interface().(string)) + if err != nil { + return err + } + field.SetString(decValue) + } + + return nil +}