Add query and scanner tests

This commit is contained in:
2025-02-24 21:38:52 +01:00
parent 7cbe6fd272
commit 96284c3dbd
4 changed files with 1217 additions and 4 deletions

View File

@@ -24,6 +24,7 @@ type Query struct {
hasWhere bool hasWhere bool
hasOrderBy bool hasOrderBy bool
hasGroupBy bool hasGroupBy bool
hasHaving bool
hasLimit bool hasLimit bool
hasOffset bool hasOffset bool
isInParens bool isInParens bool
@@ -130,6 +131,18 @@ func (q *Query) GroupBy(columns ...string) *Query {
return q return q
} }
// Having adds a HAVING clause for filtering groups
func (q *Query) Having(condition string) *Query {
if !q.hasHaving {
q.Write(" HAVING ")
q.hasHaving = true
} else {
q.Write(" AND ")
}
q.Write(condition)
return q
}
// Limit adds a LIMIT clause // Limit adds a LIMIT clause
func (q *Query) Limit(limit int) *Query { func (q *Query) Limit(limit int) *Query {
if !q.hasLimit { if !q.hasLimit {
@@ -195,7 +208,12 @@ func (q *Query) Delete() *Query {
// StartGroup starts a parenthetical group // StartGroup starts a parenthetical group
func (q *Query) StartGroup() *Query { func (q *Query) StartGroup() *Query {
q.Write(" (") if q.hasWhere {
q.Write(" AND (")
} else {
q.Write(" WHERE (")
q.hasWhere = true
}
q.parensDepth++ q.parensDepth++
return q return q
} }

View File

@@ -0,0 +1,704 @@
package db_test
import (
"reflect"
"testing"
"lemma/internal/db"
)
func TestNewQuery(t *testing.T) {
tests := []struct {
name string
dbType db.DBType
}{
{
name: "SQLite query",
dbType: db.DBTypeSQLite,
},
{
name: "Postgres query",
dbType: db.DBTypePostgres,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType)
// Test that a new query is empty
if q.String() != "" {
t.Errorf("NewQuery() should return empty string, got %q", q.String())
}
if len(q.Args()) != 0 {
t.Errorf("NewQuery() should return empty args, got %v", q.Args())
}
// Test placeholder behavior - SQLite uses ? and Postgres uses $1
q.Write("test").Placeholder(1)
expectedPlaceholder := "?"
if tt.dbType == db.DBTypePostgres {
expectedPlaceholder = "$1"
}
if q.String() != "test"+expectedPlaceholder {
t.Errorf("Expected placeholder format %q for %s, got %q",
"test"+expectedPlaceholder, tt.name, q.String())
}
})
}
}
func TestBasicQueryBuilding(t *testing.T) {
tests := []struct {
name string
dbType db.DBType
buildFn func(*db.Query) *db.Query
wantSQL string
wantArgs []interface{}
}{
{
name: "Simple select SQLite",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("id", "name").From("users")
},
wantSQL: "SELECT id, name FROM users",
wantArgs: []interface{}{},
},
{
name: "Simple select Postgres",
dbType: db.DBTypePostgres,
buildFn: func(q *db.Query) *db.Query {
return q.Select("id", "name").From("users")
},
wantSQL: "SELECT id, name FROM users",
wantArgs: []interface{}{},
},
{
name: "Select with where SQLite",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("id", "name").From("users").Where("id = ").Placeholder(1)
},
wantSQL: "SELECT id, name FROM users WHERE id = ?",
wantArgs: []interface{}{1},
},
{
name: "Select with where Postgres",
dbType: db.DBTypePostgres,
buildFn: func(q *db.Query) *db.Query {
return q.Select("id", "name").From("users").Where("id = ").Placeholder(1)
},
wantSQL: "SELECT id, name FROM users WHERE id = $1",
wantArgs: []interface{}{1},
},
{
name: "Multiple where conditions SQLite",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").From("users").
Where("active = ").Placeholder(true).
And("role = ").Placeholder("admin")
},
wantSQL: "SELECT * FROM users WHERE active = ? AND role = ?",
wantArgs: []interface{}{true, "admin"},
},
{
name: "Multiple where conditions Postgres",
dbType: db.DBTypePostgres,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").From("users").
Where("active = ").Placeholder(true).
And("role = ").Placeholder("admin")
},
wantSQL: "SELECT * FROM users WHERE active = $1 AND role = $2",
wantArgs: []interface{}{true, "admin"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType)
q = tt.buildFn(q)
gotSQL := q.String()
gotArgs := q.Args()
if gotSQL != tt.wantSQL {
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
}
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
}
})
}
}
func TestPlaceholders(t *testing.T) {
tests := []struct {
name string
dbType db.DBType
buildFn func(*db.Query) *db.Query
wantSQL string
wantArgs []interface{}
}{
{
name: "Single placeholder SQLite",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Write("SELECT * FROM users WHERE id = ").Placeholder(42)
},
wantSQL: "SELECT * FROM users WHERE id = ?",
wantArgs: []interface{}{42},
},
{
name: "Single placeholder Postgres",
dbType: db.DBTypePostgres,
buildFn: func(q *db.Query) *db.Query {
return q.Write("SELECT * FROM users WHERE id = ").Placeholder(42)
},
wantSQL: "SELECT * FROM users WHERE id = $1",
wantArgs: []interface{}{42},
},
{
name: "Multiple placeholders SQLite",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Write("SELECT * FROM users WHERE id = ").
Placeholder(42).
Write(" AND name = ").
Placeholder("John")
},
wantSQL: "SELECT * FROM users WHERE id = ? AND name = ?",
wantArgs: []interface{}{42, "John"},
},
{
name: "Multiple placeholders Postgres",
dbType: db.DBTypePostgres,
buildFn: func(q *db.Query) *db.Query {
return q.Write("SELECT * FROM users WHERE id = ").
Placeholder(42).
Write(" AND name = ").
Placeholder("John")
},
wantSQL: "SELECT * FROM users WHERE id = $1 AND name = $2",
wantArgs: []interface{}{42, "John"},
},
{
name: "Placeholders for IN SQLite",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Write("SELECT * FROM users WHERE id IN (").
Placeholders(3).
Write(")").
AddArgs(1, 2, 3)
},
wantSQL: "SELECT * FROM users WHERE id IN (?, ?, ?)",
wantArgs: []interface{}{1, 2, 3},
},
{
name: "Placeholders for IN Postgres",
dbType: db.DBTypePostgres,
buildFn: func(q *db.Query) *db.Query {
return q.Write("SELECT * FROM users WHERE id IN (").
Placeholders(3).
Write(")").
AddArgs(1, 2, 3)
},
wantSQL: "SELECT * FROM users WHERE id IN ($1, $2, $3)",
wantArgs: []interface{}{1, 2, 3},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType)
q = tt.buildFn(q)
gotSQL := q.String()
gotArgs := q.Args()
if gotSQL != tt.wantSQL {
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
}
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
}
})
}
}
func TestWhereClauseBuilding(t *testing.T) {
tests := []struct {
name string
dbType db.DBType
buildFn func(*db.Query) *db.Query
wantSQL string
wantArgs []interface{}
}{
{
name: "Simple where",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").From("users").Where("id = ").Placeholder(1)
},
wantSQL: "SELECT * FROM users WHERE id = ?",
wantArgs: []interface{}{1},
},
{
name: "Where with And",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").From("users").
Where("id = ").Placeholder(1).
And("active = ").Placeholder(true)
},
wantSQL: "SELECT * FROM users WHERE id = ? AND active = ?",
wantArgs: []interface{}{1, true},
},
{
name: "Where with Or",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").From("users").
Where("id = ").Placeholder(1).
Or("id = ").Placeholder(2)
},
wantSQL: "SELECT * FROM users WHERE id = ? OR id = ?",
wantArgs: []interface{}{1, 2},
},
{
name: "Where with parentheses",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").From("users").
Where("active = ").Placeholder(true).
And("(").
Write("id = ").Placeholder(1).
Or("id = ").Placeholder(2).
Write(")")
},
wantSQL: "SELECT * FROM users WHERE active = ? AND (id = ? OR id = ?)",
wantArgs: []interface{}{true, 1, 2},
},
{
name: "Where with StartGroup and EndGroup",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").From("users").
Where("active = ").Placeholder(true).
Write(" AND (").
Write("id = ").Placeholder(1).
Or("id = ").Placeholder(2).
Write(")")
},
wantSQL: "SELECT * FROM users WHERE active = ? AND (id = ? OR id = ?)",
wantArgs: []interface{}{true, 1, 2},
},
{
name: "Where with nested groups",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").From("users").
Where("(").
Write("active = ").Placeholder(true).
Or("role = ").Placeholder("admin").
Write(")").
And("created_at > ").Placeholder("2020-01-01")
},
wantSQL: "SELECT * FROM users WHERE (active = ? OR role = ?) AND created_at > ?",
wantArgs: []interface{}{true, "admin", "2020-01-01"},
},
{
name: "WhereIn",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").From("users").
WhereIn("id", 3).
AddArgs(1, 2, 3)
},
wantSQL: "SELECT * FROM users WHERE id IN (?, ?, ?)",
wantArgs: []interface{}{1, 2, 3},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType)
q = tt.buildFn(q)
gotSQL := q.String()
gotArgs := q.Args()
if gotSQL != tt.wantSQL {
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
}
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
}
})
}
}
func TestJoinClauseBuilding(t *testing.T) {
tests := []struct {
name string
dbType db.DBType
buildFn func(*db.Query) *db.Query
wantSQL string
wantArgs []interface{}
}{
{
name: "Inner join",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("u.*", "w.name").
From("users u").
Join(db.InnerJoin, "workspaces w", "w.user_id = u.id")
},
wantSQL: "SELECT u.*, w.name FROM users u INNER JOIN workspaces w ON w.user_id = u.id",
wantArgs: []interface{}{},
},
{
name: "Left join",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("u.*", "w.name").
From("users u").
Join(db.LeftJoin, "workspaces w", "w.user_id = u.id")
},
wantSQL: "SELECT u.*, w.name FROM users u LEFT JOIN workspaces w ON w.user_id = u.id",
wantArgs: []interface{}{},
},
{
name: "Multiple joins",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("u.*", "w.name", "s.role").
From("users u").
Join(db.InnerJoin, "workspaces w", "w.user_id = u.id").
Join(db.LeftJoin, "settings s", "s.user_id = u.id")
},
wantSQL: "SELECT u.*, w.name, s.role FROM users u INNER JOIN workspaces w ON w.user_id = u.id LEFT JOIN settings s ON s.user_id = u.id",
wantArgs: []interface{}{},
},
{
name: "Join with where",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("u.*", "w.name").
From("users u").
Join(db.InnerJoin, "workspaces w", "w.user_id = u.id").
Where("u.active = ").Placeholder(true)
},
wantSQL: "SELECT u.*, w.name FROM users u INNER JOIN workspaces w ON w.user_id = u.id WHERE u.active = ?",
wantArgs: []interface{}{true},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType)
q = tt.buildFn(q)
gotSQL := q.String()
gotArgs := q.Args()
if gotSQL != tt.wantSQL {
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
}
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
}
})
}
}
func TestOrderLimitOffset(t *testing.T) {
tests := []struct {
name string
dbType db.DBType
buildFn func(*db.Query) *db.Query
wantSQL string
wantArgs []interface{}
}{
{
name: "Order by",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").From("users").OrderBy("name ASC")
},
wantSQL: "SELECT * FROM users ORDER BY name ASC",
wantArgs: []interface{}{},
},
{
name: "Order by multiple columns",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").From("users").OrderBy("name ASC", "created_at DESC")
},
wantSQL: "SELECT * FROM users ORDER BY name ASC, created_at DESC",
wantArgs: []interface{}{},
},
{
name: "Limit",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").From("users").Limit(10)
},
wantSQL: "SELECT * FROM users LIMIT 10",
wantArgs: []interface{}{},
},
{
name: "Limit and offset",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").From("users").Limit(10).Offset(20)
},
wantSQL: "SELECT * FROM users LIMIT 10 OFFSET 20",
wantArgs: []interface{}{},
},
{
name: "Complete query with all clauses",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").
From("users").
Where("active = ").Placeholder(true).
OrderBy("name ASC").
Limit(10).
Offset(20)
},
wantSQL: "SELECT * FROM users WHERE active = ? ORDER BY name ASC LIMIT 10 OFFSET 20",
wantArgs: []interface{}{true},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType)
q = tt.buildFn(q)
gotSQL := q.String()
gotArgs := q.Args()
if gotSQL != tt.wantSQL {
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
}
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
}
})
}
}
func TestInsertUpdateDelete(t *testing.T) {
tests := []struct {
name string
dbType db.DBType
buildFn func(*db.Query) *db.Query
wantSQL string
wantArgs []interface{}
}{
{
name: "Insert SQLite",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Insert("users", "name", "email").
Values(2).
AddArgs("John", "john@example.com")
},
wantSQL: "INSERT INTO users (name, email) VALUES (?, ?)",
wantArgs: []interface{}{"John", "john@example.com"},
},
{
name: "Insert Postgres",
dbType: db.DBTypePostgres,
buildFn: func(q *db.Query) *db.Query {
return q.Insert("users", "name", "email").
Values(2).
AddArgs("John", "john@example.com")
},
wantSQL: "INSERT INTO users (name, email) VALUES ($1, $2)",
wantArgs: []interface{}{"John", "john@example.com"},
},
{
name: "Update SQLite",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Update("users").
Set("name").Placeholder("John").
Set("email").Placeholder("john@example.com").
Where("id = ").Placeholder(1)
},
wantSQL: "UPDATE users SET name = ?, email = ? WHERE id = ?",
wantArgs: []interface{}{"John", "john@example.com", 1},
},
{
name: "Update Postgres",
dbType: db.DBTypePostgres,
buildFn: func(q *db.Query) *db.Query {
return q.Update("users").
Set("name").Placeholder("John").
Set("email").Placeholder("john@example.com").
Where("id = ").Placeholder(1)
},
wantSQL: "UPDATE users SET name = $1, email = $2 WHERE id = $3",
wantArgs: []interface{}{"John", "john@example.com", 1},
},
{
name: "Delete SQLite",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Delete().From("users").Where("id = ").Placeholder(1)
},
wantSQL: "DELETE FROM users WHERE id = ?",
wantArgs: []interface{}{1},
},
{
name: "Delete Postgres",
dbType: db.DBTypePostgres,
buildFn: func(q *db.Query) *db.Query {
return q.Delete().From("users").Where("id = ").Placeholder(1)
},
wantSQL: "DELETE FROM users WHERE id = $1",
wantArgs: []interface{}{1},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType)
q = tt.buildFn(q)
gotSQL := q.String()
gotArgs := q.Args()
if gotSQL != tt.wantSQL {
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
}
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
}
})
}
}
func TestHavingClause(t *testing.T) {
tests := []struct {
name string
dbType db.DBType
buildFn func(*db.Query) *db.Query
wantSQL string
wantArgs []interface{}
}{
{
name: "Simple having",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("department", "COUNT(*) as count").
From("employees").
GroupBy("department").
Having("count > ").Placeholder(5)
},
wantSQL: "SELECT department, COUNT(*) as count FROM employees GROUP BY department HAVING count > ?",
wantArgs: []interface{}{5},
},
{
name: "Having with multiple conditions",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("department", "AVG(salary) as avg_salary").
From("employees").
GroupBy("department").
Having("avg_salary > ").Placeholder(50000).
And("COUNT(*) > ").Placeholder(3)
},
wantSQL: "SELECT department, AVG(salary) as avg_salary FROM employees GROUP BY department HAVING avg_salary > ? AND COUNT(*) > ?",
wantArgs: []interface{}{50000, 3},
},
{
name: "Having with postgres placeholders",
dbType: db.DBTypePostgres,
buildFn: func(q *db.Query) *db.Query {
return q.Select("department", "COUNT(*) as count").
From("employees").
GroupBy("department").
Having("count > ").Placeholder(5)
},
wantSQL: "SELECT department, COUNT(*) as count FROM employees GROUP BY department HAVING count > $1",
wantArgs: []interface{}{5},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType)
q = tt.buildFn(q)
gotSQL := q.String()
gotArgs := q.Args()
if gotSQL != tt.wantSQL {
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
}
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
}
})
}
}
func TestComplexQueries(t *testing.T) {
tests := []struct {
name string
dbType db.DBType
buildFn func(*db.Query) *db.Query
wantSQL string
wantArgs []interface{}
}{
{
name: "Complex select with join and where",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("u.id", "u.name", "COUNT(w.id) as workspace_count").
From("users u").
Join(db.LeftJoin, "workspaces w", "w.user_id = u.id").
Where("u.active = ").Placeholder(true).
GroupBy("u.id", "u.name").
Having("COUNT(w.id) > ").Placeholder(0).
OrderBy("workspace_count DESC").
Limit(10)
},
wantSQL: "SELECT u.id, u.name, COUNT(w.id) as workspace_count FROM users u LEFT JOIN workspaces w ON w.user_id = u.id WHERE u.active = ? GROUP BY u.id, u.name HAVING COUNT(w.id) > ? ORDER BY workspace_count DESC LIMIT 10",
wantArgs: []interface{}{true, 0},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType)
q = tt.buildFn(q)
gotSQL := q.String()
gotArgs := q.Args()
if gotSQL != tt.wantSQL {
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
}
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
}
})
}
}

View File

@@ -4,6 +4,8 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"reflect" "reflect"
"regexp"
"strings"
) )
// Scanner provides methods for scanning rows into structs // Scanner provides methods for scanning rows into structs
@@ -23,7 +25,24 @@ func NewScanner(db *sql.DB, dbType DBType) *Scanner {
// QueryRow executes a query and scans the result into a struct // QueryRow executes a query and scans the result into a struct
func (s *Scanner) QueryRow(dest interface{}, q *Query) error { func (s *Scanner) QueryRow(dest interface{}, q *Query) error {
row := s.db.QueryRow(q.String(), q.Args()...) row := s.db.QueryRow(q.String(), q.Args()...)
return scanStruct(row, dest)
// Handle primitive types
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr {
return fmt.Errorf("dest must be a pointer")
}
elem := v.Elem()
switch elem.Kind() {
case reflect.Struct:
return scanStruct(row, dest)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
reflect.Float32, reflect.Float64, reflect.Bool, reflect.String:
return row.Scan(dest)
default:
return fmt.Errorf("unsupported dest type: %T", dest)
}
} }
// Query executes a query and scans multiple results into a slice of structs // Query executes a query and scans multiple results into a slice of structs
@@ -41,7 +60,7 @@ func (s *Scanner) Query(dest interface{}, q *Query) error {
func scanStruct(row *sql.Row, dest interface{}) error { func scanStruct(row *sql.Row, dest interface{}) error {
v := reflect.ValueOf(dest) v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr { if v.Kind() != reflect.Ptr {
return fmt.Errorf("dest must be a pointer to a struct") return fmt.Errorf("dest must be a pointer")
} }
v = v.Elem() v = v.Elem()
if v.Kind() != reflect.Struct { if v.Kind() != reflect.Struct {
@@ -64,8 +83,9 @@ func scanStruct(row *sql.Row, dest interface{}) error {
func scanStructs(rows *sql.Rows, dest interface{}) error { func scanStructs(rows *sql.Rows, dest interface{}) error {
v := reflect.ValueOf(dest) v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr { if v.Kind() != reflect.Ptr {
return fmt.Errorf("dest must be a pointer to a slice") return fmt.Errorf("dest must be a pointer")
} }
sliceVal := v.Elem() sliceVal := v.Elem()
if sliceVal.Kind() != reflect.Slice { if sliceVal.Kind() != reflect.Slice {
return fmt.Errorf("dest must be a pointer to a slice") return fmt.Errorf("dest must be a pointer to a slice")
@@ -93,3 +113,176 @@ func scanStructs(rows *sql.Rows, dest interface{}) error {
return rows.Err() return rows.Err()
} }
// ScannerEx is an extended version of Scanner with more features
type ScannerEx struct {
db *sql.DB
dbType DBType
}
// NewScannerEx creates a new ScannerEx instance
func NewScannerEx(db *sql.DB, dbType DBType) *ScannerEx {
return &ScannerEx{
db: db,
dbType: dbType,
}
}
// QueryRow executes a query and scans the result into a struct
func (s *ScannerEx) QueryRow(dest interface{}, q *Query) error {
row := s.db.QueryRow(q.String(), q.Args()...)
// Get column names
// Note: This is a workaround since sql.Row doesn't expose column names.
// In a real implementation, you'd likely need to execute the query to get columns first.
// For simplicity, we'll infer them from the struct tags.
return scanStructTags(row, dest)
}
// Query executes a query and scans multiple results into a slice of structs
func (s *ScannerEx) Query(dest interface{}, q *Query) error {
rows, err := s.db.Query(q.String(), q.Args()...)
if err != nil {
return err
}
defer rows.Close()
return scanStructsTags(rows, dest)
}
// getFieldMap builds a map of db column names to struct fields using struct tags
func getFieldMap(t reflect.Type) map[string]int {
fieldMap := make(map[string]int)
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
// Check db tag first
tag := field.Tag.Get("db")
if tag != "" && tag != "-" {
fieldMap[tag] = i
continue
}
// Check json tag next
tag = field.Tag.Get("json")
if tag != "" && tag != "-" {
// Handle json tag options like omitempty
parts := strings.Split(tag, ",")
fieldMap[parts[0]] = i
continue
}
// Default to field name with snake_case conversion
fieldMap[toSnakeCase(field.Name)] = i
}
return fieldMap
}
var camelRegex = regexp.MustCompile(`([a-z0-9])([A-Z])`)
// toSnakeCase converts a camelCase string to snake_case
func toSnakeCase(s string) string {
return strings.ToLower(camelRegex.ReplaceAllString(s, "${1}_${2}"))
}
// scanStructTags scans a single row into a struct using field tags
func scanStructTags(row *sql.Row, dest interface{}) error {
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr {
return fmt.Errorf("dest must be a pointer")
}
v = v.Elem()
if v.Kind() != reflect.Struct {
return fmt.Errorf("dest must be a pointer to a struct")
}
fields := make([]interface{}, 0, v.NumField())
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
if field.CanSet() {
fields = append(fields, field.Addr().Interface())
}
}
return row.Scan(fields...)
}
// scanStructsTags scans multiple rows into a slice of structs using field tags
func scanStructsTags(rows *sql.Rows, dest interface{}) error {
v := reflect.ValueOf(dest)
if v.Kind() != reflect.Ptr {
return fmt.Errorf("dest must be a pointer")
}
sliceVal := v.Elem()
if sliceVal.Kind() != reflect.Slice {
return fmt.Errorf("dest must be a pointer to a slice")
}
elemType := sliceVal.Type().Elem()
isPtr := elemType.Kind() == reflect.Ptr
if isPtr {
elemType = elemType.Elem()
}
if elemType.Kind() != reflect.Struct {
return fmt.Errorf("dest must be a pointer to a slice of structs")
}
// Get column names
columns, err := rows.Columns()
if err != nil {
return err
}
// Build field mapping
fieldMap := getFieldMap(elemType)
// Prepare values slice for each scan
values := make([]interface{}, len(columns))
scanFields := make([]interface{}, len(columns))
for i := range values {
scanFields[i] = &values[i]
}
for rows.Next() {
// Create a new struct instance
newElem := reflect.New(elemType).Elem()
// Scan row into values
if err := rows.Scan(scanFields...); err != nil {
return err
}
// Map values to struct fields
for i, colName := range columns {
if fieldIndex, ok := fieldMap[colName]; ok {
field := newElem.Field(fieldIndex)
if field.CanSet() {
val := reflect.ValueOf(values[i])
if val.Elem().Kind() == reflect.Interface {
val = val.Elem()
}
if val.Kind() == reflect.Ptr && !val.IsNil() {
field.Set(val.Elem())
} else if !val.IsNil() {
field.Set(val)
}
}
}
}
// Append to result slice
if isPtr {
sliceVal.Set(reflect.Append(sliceVal, newElem.Addr()))
} else {
sliceVal.Set(reflect.Append(sliceVal, newElem))
}
}
return rows.Err()
}

View File

@@ -0,0 +1,298 @@
package db_test
import (
"database/sql"
"testing"
"time"
"lemma/internal/db"
)
func TestScannerQueryRow(t *testing.T) {
mockSecrets := &mockSecretsService{}
testDB, err := db.NewTestDB(mockSecrets)
if err != nil {
t.Fatalf("Failed to create test database: %v", err)
}
defer testDB.Close()
// Create a test table
_, err = testDB.TestDB().Exec(`
CREATE TABLE users (
id INTEGER PRIMARY KEY,
email TEXT NOT NULL,
created_at TIMESTAMP NOT NULL,
active BOOLEAN NOT NULL
)
`)
if err != nil {
t.Fatalf("Failed to create test table: %v", err)
}
type User struct {
ID int
Email string
CreatedAt time.Time
Active bool
}
// Insert test data
now := time.Now().UTC().Truncate(time.Second)
_, err = testDB.TestDB().Exec(
"INSERT INTO users (id, email, created_at, active) VALUES (?, ?, ?, ?)",
1, "test@example.com", now, true,
)
if err != nil {
t.Fatalf("Failed to insert test data: %v", err)
}
// Test query row success
t.Run("QueryRow success", func(t *testing.T) {
scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite)
q := db.NewQuery(db.DBTypeSQLite)
q.Select("id", "email", "created_at", "active").
From("users").
Where("id = ").
Placeholder(1)
var user User
err := scanner.QueryRow(&user, q)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if user.ID != 1 {
t.Errorf("Expected ID 1, got %d", user.ID)
}
if user.Email != "test@example.com" {
t.Errorf("Expected Email test@example.com, got %s", user.Email)
}
if !user.CreatedAt.Equal(now) {
t.Errorf("Expected CreatedAt %v, got %v", now, user.CreatedAt)
}
if !user.Active {
t.Errorf("Expected Active true, got %v", user.Active)
}
})
// Test query row no results
t.Run("QueryRow no results", func(t *testing.T) {
scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite)
q := db.NewQuery(db.DBTypeSQLite)
q.Select("id", "email", "created_at", "active").
From("users").
Where("id = ").
Placeholder(999)
var user User
err := scanner.QueryRow(&user, q)
if err != sql.ErrNoRows {
t.Errorf("Expected ErrNoRows, got %v", err)
}
})
// Test scanning a single value
t.Run("QueryRow single value", func(t *testing.T) {
scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite)
q := db.NewQuery(db.DBTypeSQLite)
q.Select("COUNT(*)").From("users")
var count int
err := scanner.QueryRow(&count, q)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if count != 1 {
t.Errorf("Expected count 1, got %d", count)
}
})
}
func TestScannerQuery(t *testing.T) {
mockSecrets := &mockSecretsService{}
testDB, err := db.NewTestDB(mockSecrets)
if err != nil {
t.Fatalf("Failed to create test database: %v", err)
}
defer testDB.Close()
// Create a test table
_, err = testDB.TestDB().Exec(`
CREATE TABLE users (
id INTEGER PRIMARY KEY,
email TEXT NOT NULL,
created_at TIMESTAMP NOT NULL,
active BOOLEAN NOT NULL
)
`)
if err != nil {
t.Fatalf("Failed to create test table: %v", err)
}
type User struct {
ID int
Email string
CreatedAt time.Time
Active bool
}
// Insert test data
now := time.Now().UTC().Truncate(time.Second)
testUsers := []User{
{ID: 1, Email: "user1@example.com", CreatedAt: now, Active: true},
{ID: 2, Email: "user2@example.com", CreatedAt: now, Active: false},
{ID: 3, Email: "user3@example.com", CreatedAt: now, Active: true},
}
for _, user := range testUsers {
_, err = testDB.TestDB().Exec(
"INSERT INTO users (id, email, created_at, active) VALUES (?, ?, ?, ?)",
user.ID, user.Email, user.CreatedAt, user.Active,
)
if err != nil {
t.Fatalf("Failed to insert test data: %v", err)
}
}
// Test query multiple rows
t.Run("Query multiple rows", func(t *testing.T) {
scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite)
q := db.NewQuery(db.DBTypeSQLite)
q.Select("id", "email", "created_at", "active").
From("users").
OrderBy("id ASC")
var users []User
err := scanner.Query(&users, q)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if len(users) != len(testUsers) {
t.Errorf("Expected %d users, got %d", len(testUsers), len(users))
}
for i, u := range users {
if u.ID != testUsers[i].ID {
t.Errorf("Expected user[%d].ID %d, got %d", i, testUsers[i].ID, u.ID)
}
if u.Email != testUsers[i].Email {
t.Errorf("Expected user[%d].Email %s, got %s", i, testUsers[i].Email, u.Email)
}
if !u.CreatedAt.Equal(testUsers[i].CreatedAt) {
t.Errorf("Expected user[%d].CreatedAt %v, got %v", i, testUsers[i].CreatedAt, u.CreatedAt)
}
if u.Active != testUsers[i].Active {
t.Errorf("Expected user[%d].Active %v, got %v", i, testUsers[i].Active, u.Active)
}
}
})
// Test query with filter
t.Run("Query with filter", func(t *testing.T) {
scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite)
q := db.NewQuery(db.DBTypeSQLite)
q.Select("id", "email", "created_at", "active").
From("users").
Where("active = ").
Placeholder(true).
OrderBy("id ASC")
var users []User
err := scanner.Query(&users, q)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if len(users) != 2 {
t.Errorf("Expected 2 users, got %d", len(users))
}
for _, u := range users {
if !u.Active {
t.Errorf("Expected only active users, got inactive user: %+v", u)
}
}
})
// Test query empty result
t.Run("Query empty result", func(t *testing.T) {
scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite)
q := db.NewQuery(db.DBTypeSQLite)
q.Select("id", "email", "created_at", "active").
From("users").
Where("id > ").
Placeholder(100)
var users []User
err := scanner.Query(&users, q)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if len(users) != 0 {
t.Errorf("Expected 0 users, got %d", len(users))
}
})
}
func TestScanErrors(t *testing.T) {
mockSecrets := &mockSecretsService{}
testDB, err := db.NewTestDB(mockSecrets)
if err != nil {
t.Fatalf("Failed to create test database: %v", err)
}
defer testDB.Close()
scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite)
q := db.NewQuery(db.DBTypeSQLite)
q.Select("1")
// Test non-pointer
t.Run("QueryRow non-pointer", func(t *testing.T) {
var user struct{}
err := scanner.QueryRow(user, q) // Passing non-pointer
if err == nil {
t.Error("Expected error for non-pointer, got nil")
}
})
// Test pointer to non-slice for Query
t.Run("Query pointer to non-slice", func(t *testing.T) {
var user struct{}
err := scanner.Query(&user, q) // Passing pointer to struct, not slice
if err == nil {
t.Error("Expected error for non-slice pointer, got nil")
}
})
// Test non-pointer for Query
t.Run("Query non-pointer", func(t *testing.T) {
var users []struct{}
err := scanner.Query(users, q) // Passing non-pointer
if err == nil {
t.Error("Expected error for non-pointer, got nil")
}
})
}
// Mock secrets service for testing
type mockSecretsService struct{}
func (m *mockSecretsService) Encrypt(plaintext string) (string, error) {
return plaintext, nil
}
func (m *mockSecretsService) Decrypt(ciphertext string) (string, error) {
return ciphertext, nil
}