Add query and scanner tests

This commit is contained in:
2025-02-24 21:38:52 +01:00
parent 7cbe6fd272
commit 96284c3dbd
4 changed files with 1217 additions and 4 deletions

View File

@@ -4,6 +4,8 @@ import (
"database/sql"
"fmt"
"reflect"
"regexp"
"strings"
)
// Scanner provides methods for scanning rows into structs
@@ -23,7 +25,24 @@ func NewScanner(db *sql.DB, dbType DBType) *Scanner {
// QueryRow executes a query and scans the result into a struct
func (s *Scanner) QueryRow(dest interface{}, q *Query) error {
row := s.db.QueryRow(q.String(), q.Args()...)
return scanStruct(row, dest)
// Handle primitive types
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr {
return fmt.Errorf("dest must be a pointer")
}
elem := v.Elem()
switch elem.Kind() {
case reflect.Struct:
return scanStruct(row, dest)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64, reflect.Bool, reflect.String:
return row.Scan(dest)
default:
return fmt.Errorf("unsupported dest type: %T", dest)
}
}
// Query executes a query and scans multiple results into a slice of structs
@@ -41,7 +60,7 @@ func (s *Scanner) Query(dest interface{}, q *Query) error {
func scanStruct(row *sql.Row, dest interface{}) error {
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr {
return fmt.Errorf("dest must be a pointer to a struct")
return fmt.Errorf("dest must be a pointer")
}
v = v.Elem()
if v.Kind() != reflect.Struct {
@@ -64,8 +83,9 @@ func scanStruct(row *sql.Row, dest interface{}) error {
func scanStructs(rows *sql.Rows, dest interface{}) error {
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr {
return fmt.Errorf("dest must be a pointer to a slice")
return fmt.Errorf("dest must be a pointer")
}
sliceVal := v.Elem()
if sliceVal.Kind() != reflect.Slice {
return fmt.Errorf("dest must be a pointer to a slice")
@@ -93,3 +113,176 @@ func scanStructs(rows *sql.Rows, dest interface{}) error {
return rows.Err()
}
// ScannerEx is an extended version of Scanner with more features
type ScannerEx struct {
db *sql.DB
dbType DBType
}
// NewScannerEx creates a new ScannerEx instance
func NewScannerEx(db *sql.DB, dbType DBType) *ScannerEx {
return &ScannerEx{
db: db,
dbType: dbType,
}
}
// QueryRow executes a query and scans the result into a struct
func (s *ScannerEx) QueryRow(dest interface{}, q *Query) error {
row := s.db.QueryRow(q.String(), q.Args()...)
// Get column names
// Note: This is a workaround since sql.Row doesn't expose column names.
// In a real implementation, you'd likely need to execute the query to get columns first.
// For simplicity, we'll infer them from the struct tags.
return scanStructTags(row, dest)
}
// Query executes a query and scans multiple results into a slice of structs
func (s *ScannerEx) Query(dest interface{}, q *Query) error {
rows, err := s.db.Query(q.String(), q.Args()...)
if err != nil {
return err
}
defer rows.Close()
return scanStructsTags(rows, dest)
}
// getFieldMap builds a map of db column names to struct fields using struct tags
func getFieldMap(t reflect.Type) map[string]int {
fieldMap := make(map[string]int)
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
// Check db tag first
tag := field.Tag.Get("db")
if tag != "" && tag != "-" {
fieldMap[tag] = i
continue
}
// Check json tag next
tag = field.Tag.Get("json")
if tag != "" && tag != "-" {
// Handle json tag options like omitempty
parts := strings.Split(tag, ",")
fieldMap[parts[0]] = i
continue
}
// Default to field name with snake_case conversion
fieldMap[toSnakeCase(field.Name)] = i
}
return fieldMap
}
var camelRegex = regexp.MustCompile(`([a-z0-9])([A-Z])`)
// toSnakeCase converts a camelCase string to snake_case
func toSnakeCase(s string) string {
return strings.ToLower(camelRegex.ReplaceAllString(s, "${1}_${2}"))
}
// scanStructTags scans a single row into a struct using field tags
func scanStructTags(row *sql.Row, dest interface{}) error {
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr {
return fmt.Errorf("dest must be a pointer")
}
v = v.Elem()
if v.Kind() != reflect.Struct {
return fmt.Errorf("dest must be a pointer to a struct")
}
fields := make([]interface{}, 0, v.NumField())
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
if field.CanSet() {
fields = append(fields, field.Addr().Interface())
}
}
return row.Scan(fields...)
}
// scanStructsTags scans multiple rows into a slice of structs using field tags
func scanStructsTags(rows *sql.Rows, dest interface{}) error {
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr {
return fmt.Errorf("dest must be a pointer")
}
sliceVal := v.Elem()
if sliceVal.Kind() != reflect.Slice {
return fmt.Errorf("dest must be a pointer to a slice")
}
elemType := sliceVal.Type().Elem()
isPtr := elemType.Kind() == reflect.Ptr
if isPtr {
elemType = elemType.Elem()
}
if elemType.Kind() != reflect.Struct {
return fmt.Errorf("dest must be a pointer to a slice of structs")
}
// Get column names
columns, err := rows.Columns()
if err != nil {
return err
}
// Build field mapping
fieldMap := getFieldMap(elemType)
// Prepare values slice for each scan
values := make([]interface{}, len(columns))
scanFields := make([]interface{}, len(columns))
for i := range values {
scanFields[i] = &values[i]
}
for rows.Next() {
// Create a new struct instance
newElem := reflect.New(elemType).Elem()
// Scan row into values
if err := rows.Scan(scanFields...); err != nil {
return err
}
// Map values to struct fields
for i, colName := range columns {
if fieldIndex, ok := fieldMap[colName]; ok {
field := newElem.Field(fieldIndex)
if field.CanSet() {
val := reflect.ValueOf(values[i])
if val.Elem().Kind() == reflect.Interface {
val = val.Elem()
}
if val.Kind() == reflect.Ptr && !val.IsNil() {
field.Set(val.Elem())
} else if !val.IsNil() {
field.Set(val)
}
}
}
}
// Append to result slice
if isPtr {
sliceVal.Set(reflect.Append(sliceVal, newElem.Addr()))
} else {
sliceVal.Set(reflect.Append(sliceVal, newElem))
}
}
return rows.Err()
}