mirror of
https://github.com/lordmathis/lemma.git
synced 2025-11-05 15:44:21 +00:00
Implement insert struct
This commit is contained in:
@@ -1,288 +0,0 @@
|
|||||||
package db
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Scanner provides methods for scanning rows into structs
|
|
||||||
type Scanner struct {
|
|
||||||
db *sql.DB
|
|
||||||
dbType DBType
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewScanner creates a new Scanner instance
|
|
||||||
func NewScanner(db *sql.DB, dbType DBType) *Scanner {
|
|
||||||
return &Scanner{
|
|
||||||
db: db,
|
|
||||||
dbType: dbType,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryRow executes a query and scans the result into a struct
|
|
||||||
func (s *Scanner) QueryRow(dest any, q *Query) error {
|
|
||||||
row := s.db.QueryRow(q.String(), q.Args()...)
|
|
||||||
|
|
||||||
// Handle primitive types
|
|
||||||
v := reflect.ValueOf(dest)
|
|
||||||
if v.Kind() != reflect.Ptr {
|
|
||||||
return fmt.Errorf("dest must be a pointer")
|
|
||||||
}
|
|
||||||
|
|
||||||
elem := v.Elem()
|
|
||||||
switch elem.Kind() {
|
|
||||||
case reflect.Struct:
|
|
||||||
return scanStruct(row, dest)
|
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
|
||||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
|
||||||
reflect.Float32, reflect.Float64, reflect.Bool, reflect.String:
|
|
||||||
return row.Scan(dest)
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unsupported dest type: %T", dest)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Query executes a query and scans multiple results into a slice of structs
|
|
||||||
func (s *Scanner) Query(dest any, q *Query) error {
|
|
||||||
rows, err := s.db.Query(q.String(), q.Args()...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
return scanStructs(rows, dest)
|
|
||||||
}
|
|
||||||
|
|
||||||
// scanStruct scans a single row into a struct
|
|
||||||
func scanStruct(row *sql.Row, dest any) error {
|
|
||||||
v := reflect.ValueOf(dest)
|
|
||||||
if v.Kind() != reflect.Ptr {
|
|
||||||
return fmt.Errorf("dest must be a pointer")
|
|
||||||
}
|
|
||||||
v = v.Elem()
|
|
||||||
if v.Kind() != reflect.Struct {
|
|
||||||
return fmt.Errorf("dest must be a pointer to a struct")
|
|
||||||
}
|
|
||||||
|
|
||||||
fields := make([]any, 0, v.NumField())
|
|
||||||
|
|
||||||
for i := 0; i < v.NumField(); i++ {
|
|
||||||
field := v.Field(i)
|
|
||||||
if field.CanSet() {
|
|
||||||
fields = append(fields, field.Addr().Interface())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return row.Scan(fields...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// scanStructs scans multiple rows into a slice of structs
|
|
||||||
func scanStructs(rows *sql.Rows, dest any) error {
|
|
||||||
v := reflect.ValueOf(dest)
|
|
||||||
if v.Kind() != reflect.Ptr {
|
|
||||||
return fmt.Errorf("dest must be a pointer")
|
|
||||||
}
|
|
||||||
|
|
||||||
sliceVal := v.Elem()
|
|
||||||
if sliceVal.Kind() != reflect.Slice {
|
|
||||||
return fmt.Errorf("dest must be a pointer to a slice")
|
|
||||||
}
|
|
||||||
|
|
||||||
elemType := sliceVal.Type().Elem()
|
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
newElem := reflect.New(elemType).Elem()
|
|
||||||
fields := make([]any, 0, newElem.NumField())
|
|
||||||
|
|
||||||
for i := 0; i < newElem.NumField(); i++ {
|
|
||||||
field := newElem.Field(i)
|
|
||||||
if field.CanSet() {
|
|
||||||
fields = append(fields, field.Addr().Interface())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := rows.Scan(fields...); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
sliceVal.Set(reflect.Append(sliceVal, newElem))
|
|
||||||
}
|
|
||||||
|
|
||||||
return rows.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ScannerEx is an extended version of Scanner with more features
|
|
||||||
type ScannerEx struct {
|
|
||||||
db *sql.DB
|
|
||||||
dbType DBType
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewScannerEx creates a new ScannerEx instance
|
|
||||||
func NewScannerEx(db *sql.DB, dbType DBType) *ScannerEx {
|
|
||||||
return &ScannerEx{
|
|
||||||
db: db,
|
|
||||||
dbType: dbType,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryRow executes a query and scans the result into a struct
|
|
||||||
func (s *ScannerEx) QueryRow(dest any, q *Query) error {
|
|
||||||
row := s.db.QueryRow(q.String(), q.Args()...)
|
|
||||||
|
|
||||||
// Get column names
|
|
||||||
// Note: This is a workaround since sql.Row doesn't expose column names.
|
|
||||||
// In a real implementation, you'd likely need to execute the query to get columns first.
|
|
||||||
// For simplicity, we'll infer them from the struct tags.
|
|
||||||
|
|
||||||
return scanStructTags(row, dest)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Query executes a query and scans multiple results into a slice of structs
|
|
||||||
func (s *ScannerEx) Query(dest any, q *Query) error {
|
|
||||||
rows, err := s.db.Query(q.String(), q.Args()...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
return scanStructsTags(rows, dest)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getFieldMap builds a map of db column names to struct fields using struct tags
|
|
||||||
func getFieldMap(t reflect.Type) map[string]int {
|
|
||||||
fieldMap := make(map[string]int)
|
|
||||||
|
|
||||||
for i := 0; i < t.NumField(); i++ {
|
|
||||||
field := t.Field(i)
|
|
||||||
|
|
||||||
// Check db tag first
|
|
||||||
tag := field.Tag.Get("db")
|
|
||||||
if tag != "" && tag != "-" {
|
|
||||||
fieldMap[tag] = i
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check json tag next
|
|
||||||
tag = field.Tag.Get("json")
|
|
||||||
if tag != "" && tag != "-" {
|
|
||||||
// Handle json tag options like omitempty
|
|
||||||
parts := strings.Split(tag, ",")
|
|
||||||
fieldMap[parts[0]] = i
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default to field name with snake_case conversion
|
|
||||||
fieldMap[toSnakeCase(field.Name)] = i
|
|
||||||
}
|
|
||||||
|
|
||||||
return fieldMap
|
|
||||||
}
|
|
||||||
|
|
||||||
var camelRegex = regexp.MustCompile(`([a-z0-9])([A-Z])`)
|
|
||||||
|
|
||||||
// toSnakeCase converts a camelCase string to snake_case
|
|
||||||
func toSnakeCase(s string) string {
|
|
||||||
return strings.ToLower(camelRegex.ReplaceAllString(s, "${1}_${2}"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// scanStructTags scans a single row into a struct using field tags
|
|
||||||
func scanStructTags(row *sql.Row, dest any) error {
|
|
||||||
v := reflect.ValueOf(dest)
|
|
||||||
if v.Kind() != reflect.Ptr {
|
|
||||||
return fmt.Errorf("dest must be a pointer")
|
|
||||||
}
|
|
||||||
v = v.Elem()
|
|
||||||
if v.Kind() != reflect.Struct {
|
|
||||||
return fmt.Errorf("dest must be a pointer to a struct")
|
|
||||||
}
|
|
||||||
|
|
||||||
fields := make([]any, 0, v.NumField())
|
|
||||||
|
|
||||||
for i := 0; i < v.NumField(); i++ {
|
|
||||||
field := v.Field(i)
|
|
||||||
if field.CanSet() {
|
|
||||||
fields = append(fields, field.Addr().Interface())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return row.Scan(fields...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// scanStructsTags scans multiple rows into a slice of structs using field tags
|
|
||||||
func scanStructsTags(rows *sql.Rows, dest any) error {
|
|
||||||
v := reflect.ValueOf(dest)
|
|
||||||
if v.Kind() != reflect.Ptr {
|
|
||||||
return fmt.Errorf("dest must be a pointer")
|
|
||||||
}
|
|
||||||
|
|
||||||
sliceVal := v.Elem()
|
|
||||||
if sliceVal.Kind() != reflect.Slice {
|
|
||||||
return fmt.Errorf("dest must be a pointer to a slice")
|
|
||||||
}
|
|
||||||
|
|
||||||
elemType := sliceVal.Type().Elem()
|
|
||||||
isPtr := elemType.Kind() == reflect.Ptr
|
|
||||||
if isPtr {
|
|
||||||
elemType = elemType.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if elemType.Kind() != reflect.Struct {
|
|
||||||
return fmt.Errorf("dest must be a pointer to a slice of structs")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get column names
|
|
||||||
columns, err := rows.Columns()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build field mapping
|
|
||||||
fieldMap := getFieldMap(elemType)
|
|
||||||
|
|
||||||
// Prepare values slice for each scan
|
|
||||||
values := make([]any, len(columns))
|
|
||||||
scanFields := make([]any, len(columns))
|
|
||||||
for i := range values {
|
|
||||||
scanFields[i] = &values[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
// Create a new struct instance
|
|
||||||
newElem := reflect.New(elemType).Elem()
|
|
||||||
|
|
||||||
// Scan row into values
|
|
||||||
if err := rows.Scan(scanFields...); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Map values to struct fields
|
|
||||||
for i, colName := range columns {
|
|
||||||
if fieldIndex, ok := fieldMap[colName]; ok {
|
|
||||||
field := newElem.Field(fieldIndex)
|
|
||||||
if field.CanSet() {
|
|
||||||
val := reflect.ValueOf(values[i])
|
|
||||||
if val.Elem().Kind() == reflect.Interface {
|
|
||||||
val = val.Elem()
|
|
||||||
}
|
|
||||||
if val.Kind() == reflect.Ptr && !val.IsNil() {
|
|
||||||
field.Set(val.Elem())
|
|
||||||
} else if !val.IsNil() {
|
|
||||||
field.Set(val)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Append to result slice
|
|
||||||
if isPtr {
|
|
||||||
sliceVal.Set(reflect.Append(sliceVal, newElem.Addr()))
|
|
||||||
} else {
|
|
||||||
sliceVal.Set(reflect.Append(sliceVal, newElem))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return rows.Err()
|
|
||||||
}
|
|
||||||
@@ -1,298 +0,0 @@
|
|||||||
package db_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"lemma/internal/db"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestScannerQueryRow(t *testing.T) {
|
|
||||||
mockSecrets := &mockSecretsService{}
|
|
||||||
testDB, err := db.NewTestDB(mockSecrets)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create test database: %v", err)
|
|
||||||
}
|
|
||||||
defer testDB.Close()
|
|
||||||
|
|
||||||
// Create a test table
|
|
||||||
_, err = testDB.TestDB().Exec(`
|
|
||||||
CREATE TABLE users (
|
|
||||||
id INTEGER PRIMARY KEY,
|
|
||||||
email TEXT NOT NULL,
|
|
||||||
created_at TIMESTAMP NOT NULL,
|
|
||||||
active BOOLEAN NOT NULL
|
|
||||||
)
|
|
||||||
`)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create test table: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type User struct {
|
|
||||||
ID int
|
|
||||||
Email string
|
|
||||||
CreatedAt time.Time
|
|
||||||
Active bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert test data
|
|
||||||
now := time.Now().UTC().Truncate(time.Second)
|
|
||||||
_, err = testDB.TestDB().Exec(
|
|
||||||
"INSERT INTO users (id, email, created_at, active) VALUES (?, ?, ?, ?)",
|
|
||||||
1, "test@example.com", now, true,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to insert test data: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test query row success
|
|
||||||
t.Run("QueryRow success", func(t *testing.T) {
|
|
||||||
scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite)
|
|
||||||
q := db.NewQuery(db.DBTypeSQLite)
|
|
||||||
q.Select("id", "email", "created_at", "active").
|
|
||||||
From("users").
|
|
||||||
Where("id = ").
|
|
||||||
Placeholder(1)
|
|
||||||
|
|
||||||
var user User
|
|
||||||
err := scanner.QueryRow(&user, q)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Expected no error, got %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if user.ID != 1 {
|
|
||||||
t.Errorf("Expected ID 1, got %d", user.ID)
|
|
||||||
}
|
|
||||||
if user.Email != "test@example.com" {
|
|
||||||
t.Errorf("Expected Email test@example.com, got %s", user.Email)
|
|
||||||
}
|
|
||||||
if !user.CreatedAt.Equal(now) {
|
|
||||||
t.Errorf("Expected CreatedAt %v, got %v", now, user.CreatedAt)
|
|
||||||
}
|
|
||||||
if !user.Active {
|
|
||||||
t.Errorf("Expected Active true, got %v", user.Active)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Test query row no results
|
|
||||||
t.Run("QueryRow no results", func(t *testing.T) {
|
|
||||||
scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite)
|
|
||||||
q := db.NewQuery(db.DBTypeSQLite)
|
|
||||||
q.Select("id", "email", "created_at", "active").
|
|
||||||
From("users").
|
|
||||||
Where("id = ").
|
|
||||||
Placeholder(999)
|
|
||||||
|
|
||||||
var user User
|
|
||||||
err := scanner.QueryRow(&user, q)
|
|
||||||
|
|
||||||
if err != sql.ErrNoRows {
|
|
||||||
t.Errorf("Expected ErrNoRows, got %v", err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Test scanning a single value
|
|
||||||
t.Run("QueryRow single value", func(t *testing.T) {
|
|
||||||
scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite)
|
|
||||||
q := db.NewQuery(db.DBTypeSQLite)
|
|
||||||
q.Select("COUNT(*)").From("users")
|
|
||||||
|
|
||||||
var count int
|
|
||||||
err := scanner.QueryRow(&count, q)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Expected no error, got %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if count != 1 {
|
|
||||||
t.Errorf("Expected count 1, got %d", count)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestScannerQuery(t *testing.T) {
|
|
||||||
mockSecrets := &mockSecretsService{}
|
|
||||||
testDB, err := db.NewTestDB(mockSecrets)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create test database: %v", err)
|
|
||||||
}
|
|
||||||
defer testDB.Close()
|
|
||||||
|
|
||||||
// Create a test table
|
|
||||||
_, err = testDB.TestDB().Exec(`
|
|
||||||
CREATE TABLE users (
|
|
||||||
id INTEGER PRIMARY KEY,
|
|
||||||
email TEXT NOT NULL,
|
|
||||||
created_at TIMESTAMP NOT NULL,
|
|
||||||
active BOOLEAN NOT NULL
|
|
||||||
)
|
|
||||||
`)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create test table: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type User struct {
|
|
||||||
ID int
|
|
||||||
Email string
|
|
||||||
CreatedAt time.Time
|
|
||||||
Active bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert test data
|
|
||||||
now := time.Now().UTC().Truncate(time.Second)
|
|
||||||
testUsers := []User{
|
|
||||||
{ID: 1, Email: "user1@example.com", CreatedAt: now, Active: true},
|
|
||||||
{ID: 2, Email: "user2@example.com", CreatedAt: now, Active: false},
|
|
||||||
{ID: 3, Email: "user3@example.com", CreatedAt: now, Active: true},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, user := range testUsers {
|
|
||||||
_, err = testDB.TestDB().Exec(
|
|
||||||
"INSERT INTO users (id, email, created_at, active) VALUES (?, ?, ?, ?)",
|
|
||||||
user.ID, user.Email, user.CreatedAt, user.Active,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to insert test data: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test query multiple rows
|
|
||||||
t.Run("Query multiple rows", func(t *testing.T) {
|
|
||||||
scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite)
|
|
||||||
q := db.NewQuery(db.DBTypeSQLite)
|
|
||||||
q.Select("id", "email", "created_at", "active").
|
|
||||||
From("users").
|
|
||||||
OrderBy("id ASC")
|
|
||||||
|
|
||||||
var users []User
|
|
||||||
err := scanner.Query(&users, q)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Expected no error, got %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(users) != len(testUsers) {
|
|
||||||
t.Errorf("Expected %d users, got %d", len(testUsers), len(users))
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, u := range users {
|
|
||||||
if u.ID != testUsers[i].ID {
|
|
||||||
t.Errorf("Expected user[%d].ID %d, got %d", i, testUsers[i].ID, u.ID)
|
|
||||||
}
|
|
||||||
if u.Email != testUsers[i].Email {
|
|
||||||
t.Errorf("Expected user[%d].Email %s, got %s", i, testUsers[i].Email, u.Email)
|
|
||||||
}
|
|
||||||
if !u.CreatedAt.Equal(testUsers[i].CreatedAt) {
|
|
||||||
t.Errorf("Expected user[%d].CreatedAt %v, got %v", i, testUsers[i].CreatedAt, u.CreatedAt)
|
|
||||||
}
|
|
||||||
if u.Active != testUsers[i].Active {
|
|
||||||
t.Errorf("Expected user[%d].Active %v, got %v", i, testUsers[i].Active, u.Active)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Test query with filter
|
|
||||||
t.Run("Query with filter", func(t *testing.T) {
|
|
||||||
scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite)
|
|
||||||
q := db.NewQuery(db.DBTypeSQLite)
|
|
||||||
q.Select("id", "email", "created_at", "active").
|
|
||||||
From("users").
|
|
||||||
Where("active = ").
|
|
||||||
Placeholder(true).
|
|
||||||
OrderBy("id ASC")
|
|
||||||
|
|
||||||
var users []User
|
|
||||||
err := scanner.Query(&users, q)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Expected no error, got %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(users) != 2 {
|
|
||||||
t.Errorf("Expected 2 users, got %d", len(users))
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, u := range users {
|
|
||||||
if !u.Active {
|
|
||||||
t.Errorf("Expected only active users, got inactive user: %+v", u)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Test query empty result
|
|
||||||
t.Run("Query empty result", func(t *testing.T) {
|
|
||||||
scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite)
|
|
||||||
q := db.NewQuery(db.DBTypeSQLite)
|
|
||||||
q.Select("id", "email", "created_at", "active").
|
|
||||||
From("users").
|
|
||||||
Where("id > ").
|
|
||||||
Placeholder(100)
|
|
||||||
|
|
||||||
var users []User
|
|
||||||
err := scanner.Query(&users, q)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Expected no error, got %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(users) != 0 {
|
|
||||||
t.Errorf("Expected 0 users, got %d", len(users))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestScanErrors(t *testing.T) {
|
|
||||||
mockSecrets := &mockSecretsService{}
|
|
||||||
testDB, err := db.NewTestDB(mockSecrets)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create test database: %v", err)
|
|
||||||
}
|
|
||||||
defer testDB.Close()
|
|
||||||
|
|
||||||
scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite)
|
|
||||||
q := db.NewQuery(db.DBTypeSQLite)
|
|
||||||
q.Select("1")
|
|
||||||
|
|
||||||
// Test non-pointer
|
|
||||||
t.Run("QueryRow non-pointer", func(t *testing.T) {
|
|
||||||
var user struct{}
|
|
||||||
err := scanner.QueryRow(user, q) // Passing non-pointer
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
t.Error("Expected error for non-pointer, got nil")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Test pointer to non-slice for Query
|
|
||||||
t.Run("Query pointer to non-slice", func(t *testing.T) {
|
|
||||||
var user struct{}
|
|
||||||
err := scanner.Query(&user, q) // Passing pointer to struct, not slice
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
t.Error("Expected error for non-slice pointer, got nil")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Test non-pointer for Query
|
|
||||||
t.Run("Query non-pointer", func(t *testing.T) {
|
|
||||||
var users []struct{}
|
|
||||||
err := scanner.Query(users, q) // Passing non-pointer
|
|
||||||
|
|
||||||
if err == nil {
|
|
||||||
t.Error("Expected error for non-pointer, got nil")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Mock secrets service for testing
|
|
||||||
type mockSecretsService struct{}
|
|
||||||
|
|
||||||
func (m *mockSecretsService) Encrypt(plaintext string) (string, error) {
|
|
||||||
return plaintext, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockSecretsService) Decrypt(ciphertext string) (string, error) {
|
|
||||||
return ciphertext, nil
|
|
||||||
}
|
|
||||||
99
server/internal/db/struct_query.go
Normal file
99
server/internal/db/struct_query.go
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DBField struct {
|
||||||
|
Name string
|
||||||
|
Value any
|
||||||
|
Type reflect.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
func StructTagsToFields(s any) ([]DBField, error) {
|
||||||
|
v := reflect.ValueOf(s)
|
||||||
|
|
||||||
|
if v.Kind() == reflect.Ptr {
|
||||||
|
if v.IsNil() {
|
||||||
|
return nil, fmt.Errorf("nil pointer provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
v = v.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if v.Kind() != reflect.Struct {
|
||||||
|
return nil, fmt.Errorf("provided value is %s, expected struct", v.Kind())
|
||||||
|
}
|
||||||
|
|
||||||
|
t := v.Type()
|
||||||
|
fields := make([]DBField, 0, t.NumField())
|
||||||
|
|
||||||
|
for i := range t.NumField() {
|
||||||
|
f := t.Field(i)
|
||||||
|
|
||||||
|
if !f.IsExported() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
tag := f.Tag.Get("db")
|
||||||
|
if tag == "-" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if tag == "" {
|
||||||
|
tag = toSnakeCase(f.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(tag, "omitempty") && reflect.DeepEqual(v.Field(i).Interface(), reflect.Zero(f.Type).Interface()) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
fields = append(fields, DBField{
|
||||||
|
Name: tag,
|
||||||
|
Value: v.Field(i).Interface(),
|
||||||
|
Type: f.Type,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return fields, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func toSnakeCase(s string) string {
|
||||||
|
var res string
|
||||||
|
|
||||||
|
for i, r := range s {
|
||||||
|
if unicode.IsUpper(r) {
|
||||||
|
if i > 0 {
|
||||||
|
res += "_"
|
||||||
|
}
|
||||||
|
res += string(unicode.ToLower(r))
|
||||||
|
} else {
|
||||||
|
res += string(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Query) InsertStruct(s any, table string) (*Query, error) {
|
||||||
|
fields, err := StructTagsToFields(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
columns := make([]string, 0, len(fields))
|
||||||
|
values := make([]any, 0, len(fields))
|
||||||
|
|
||||||
|
for _, f := range fields {
|
||||||
|
columns = append(columns, f.Name)
|
||||||
|
values = append(values, f.Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(columns) == 0 {
|
||||||
|
return nil, fmt.Errorf("no columns to insert")
|
||||||
|
}
|
||||||
|
|
||||||
|
q.Insert(table, columns...).Values(len(columns)).AddArgs(values...)
|
||||||
|
return q, nil
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user