mirror of
https://github.com/lordmathis/lemma.git
synced 2025-11-05 15:44:21 +00:00
Add query and scanner tests
This commit is contained in:
@@ -24,6 +24,7 @@ type Query struct {
|
||||
hasWhere bool
|
||||
hasOrderBy bool
|
||||
hasGroupBy bool
|
||||
hasHaving bool
|
||||
hasLimit bool
|
||||
hasOffset bool
|
||||
isInParens bool
|
||||
@@ -130,6 +131,18 @@ func (q *Query) GroupBy(columns ...string) *Query {
|
||||
return q
|
||||
}
|
||||
|
||||
// Having adds a HAVING clause for filtering groups
|
||||
func (q *Query) Having(condition string) *Query {
|
||||
if !q.hasHaving {
|
||||
q.Write(" HAVING ")
|
||||
q.hasHaving = true
|
||||
} else {
|
||||
q.Write(" AND ")
|
||||
}
|
||||
q.Write(condition)
|
||||
return q
|
||||
}
|
||||
|
||||
// Limit adds a LIMIT clause
|
||||
func (q *Query) Limit(limit int) *Query {
|
||||
if !q.hasLimit {
|
||||
@@ -195,7 +208,12 @@ func (q *Query) Delete() *Query {
|
||||
|
||||
// StartGroup starts a parenthetical group
|
||||
func (q *Query) StartGroup() *Query {
|
||||
q.Write(" (")
|
||||
if q.hasWhere {
|
||||
q.Write(" AND (")
|
||||
} else {
|
||||
q.Write(" WHERE (")
|
||||
q.hasWhere = true
|
||||
}
|
||||
q.parensDepth++
|
||||
return q
|
||||
}
|
||||
|
||||
704
server/internal/db/query_test.go
Normal file
704
server/internal/db/query_test.go
Normal file
@@ -0,0 +1,704 @@
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"lemma/internal/db"
|
||||
)
|
||||
|
||||
func TestNewQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dbType db.DBType
|
||||
}{
|
||||
{
|
||||
name: "SQLite query",
|
||||
dbType: db.DBTypeSQLite,
|
||||
},
|
||||
{
|
||||
name: "Postgres query",
|
||||
dbType: db.DBTypePostgres,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
q := db.NewQuery(tt.dbType)
|
||||
|
||||
// Test that a new query is empty
|
||||
if q.String() != "" {
|
||||
t.Errorf("NewQuery() should return empty string, got %q", q.String())
|
||||
}
|
||||
if len(q.Args()) != 0 {
|
||||
t.Errorf("NewQuery() should return empty args, got %v", q.Args())
|
||||
}
|
||||
|
||||
// Test placeholder behavior - SQLite uses ? and Postgres uses $1
|
||||
q.Write("test").Placeholder(1)
|
||||
|
||||
expectedPlaceholder := "?"
|
||||
if tt.dbType == db.DBTypePostgres {
|
||||
expectedPlaceholder = "$1"
|
||||
}
|
||||
|
||||
if q.String() != "test"+expectedPlaceholder {
|
||||
t.Errorf("Expected placeholder format %q for %s, got %q",
|
||||
"test"+expectedPlaceholder, tt.name, q.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBasicQueryBuilding(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dbType db.DBType
|
||||
buildFn func(*db.Query) *db.Query
|
||||
wantSQL string
|
||||
wantArgs []interface{}
|
||||
}{
|
||||
{
|
||||
name: "Simple select SQLite",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("id", "name").From("users")
|
||||
},
|
||||
wantSQL: "SELECT id, name FROM users",
|
||||
wantArgs: []interface{}{},
|
||||
},
|
||||
{
|
||||
name: "Simple select Postgres",
|
||||
dbType: db.DBTypePostgres,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("id", "name").From("users")
|
||||
},
|
||||
wantSQL: "SELECT id, name FROM users",
|
||||
wantArgs: []interface{}{},
|
||||
},
|
||||
{
|
||||
name: "Select with where SQLite",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("id", "name").From("users").Where("id = ").Placeholder(1)
|
||||
},
|
||||
wantSQL: "SELECT id, name FROM users WHERE id = ?",
|
||||
wantArgs: []interface{}{1},
|
||||
},
|
||||
{
|
||||
name: "Select with where Postgres",
|
||||
dbType: db.DBTypePostgres,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("id", "name").From("users").Where("id = ").Placeholder(1)
|
||||
},
|
||||
wantSQL: "SELECT id, name FROM users WHERE id = $1",
|
||||
wantArgs: []interface{}{1},
|
||||
},
|
||||
{
|
||||
name: "Multiple where conditions SQLite",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").From("users").
|
||||
Where("active = ").Placeholder(true).
|
||||
And("role = ").Placeholder("admin")
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE active = ? AND role = ?",
|
||||
wantArgs: []interface{}{true, "admin"},
|
||||
},
|
||||
{
|
||||
name: "Multiple where conditions Postgres",
|
||||
dbType: db.DBTypePostgres,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").From("users").
|
||||
Where("active = ").Placeholder(true).
|
||||
And("role = ").Placeholder("admin")
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE active = $1 AND role = $2",
|
||||
wantArgs: []interface{}{true, "admin"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
q := db.NewQuery(tt.dbType)
|
||||
q = tt.buildFn(q)
|
||||
|
||||
gotSQL := q.String()
|
||||
gotArgs := q.Args()
|
||||
|
||||
if gotSQL != tt.wantSQL {
|
||||
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
|
||||
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlaceholders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dbType db.DBType
|
||||
buildFn func(*db.Query) *db.Query
|
||||
wantSQL string
|
||||
wantArgs []interface{}
|
||||
}{
|
||||
{
|
||||
name: "Single placeholder SQLite",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Write("SELECT * FROM users WHERE id = ").Placeholder(42)
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE id = ?",
|
||||
wantArgs: []interface{}{42},
|
||||
},
|
||||
{
|
||||
name: "Single placeholder Postgres",
|
||||
dbType: db.DBTypePostgres,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Write("SELECT * FROM users WHERE id = ").Placeholder(42)
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE id = $1",
|
||||
wantArgs: []interface{}{42},
|
||||
},
|
||||
{
|
||||
name: "Multiple placeholders SQLite",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Write("SELECT * FROM users WHERE id = ").
|
||||
Placeholder(42).
|
||||
Write(" AND name = ").
|
||||
Placeholder("John")
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE id = ? AND name = ?",
|
||||
wantArgs: []interface{}{42, "John"},
|
||||
},
|
||||
{
|
||||
name: "Multiple placeholders Postgres",
|
||||
dbType: db.DBTypePostgres,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Write("SELECT * FROM users WHERE id = ").
|
||||
Placeholder(42).
|
||||
Write(" AND name = ").
|
||||
Placeholder("John")
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE id = $1 AND name = $2",
|
||||
wantArgs: []interface{}{42, "John"},
|
||||
},
|
||||
{
|
||||
name: "Placeholders for IN SQLite",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Write("SELECT * FROM users WHERE id IN (").
|
||||
Placeholders(3).
|
||||
Write(")").
|
||||
AddArgs(1, 2, 3)
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE id IN (?, ?, ?)",
|
||||
wantArgs: []interface{}{1, 2, 3},
|
||||
},
|
||||
{
|
||||
name: "Placeholders for IN Postgres",
|
||||
dbType: db.DBTypePostgres,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Write("SELECT * FROM users WHERE id IN (").
|
||||
Placeholders(3).
|
||||
Write(")").
|
||||
AddArgs(1, 2, 3)
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE id IN ($1, $2, $3)",
|
||||
wantArgs: []interface{}{1, 2, 3},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
q := db.NewQuery(tt.dbType)
|
||||
q = tt.buildFn(q)
|
||||
|
||||
gotSQL := q.String()
|
||||
gotArgs := q.Args()
|
||||
|
||||
if gotSQL != tt.wantSQL {
|
||||
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
|
||||
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhereClauseBuilding(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dbType db.DBType
|
||||
buildFn func(*db.Query) *db.Query
|
||||
wantSQL string
|
||||
wantArgs []interface{}
|
||||
}{
|
||||
{
|
||||
name: "Simple where",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").From("users").Where("id = ").Placeholder(1)
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE id = ?",
|
||||
wantArgs: []interface{}{1},
|
||||
},
|
||||
{
|
||||
name: "Where with And",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").From("users").
|
||||
Where("id = ").Placeholder(1).
|
||||
And("active = ").Placeholder(true)
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE id = ? AND active = ?",
|
||||
wantArgs: []interface{}{1, true},
|
||||
},
|
||||
{
|
||||
name: "Where with Or",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").From("users").
|
||||
Where("id = ").Placeholder(1).
|
||||
Or("id = ").Placeholder(2)
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE id = ? OR id = ?",
|
||||
wantArgs: []interface{}{1, 2},
|
||||
},
|
||||
{
|
||||
name: "Where with parentheses",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").From("users").
|
||||
Where("active = ").Placeholder(true).
|
||||
And("(").
|
||||
Write("id = ").Placeholder(1).
|
||||
Or("id = ").Placeholder(2).
|
||||
Write(")")
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE active = ? AND (id = ? OR id = ?)",
|
||||
wantArgs: []interface{}{true, 1, 2},
|
||||
},
|
||||
{
|
||||
name: "Where with StartGroup and EndGroup",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").From("users").
|
||||
Where("active = ").Placeholder(true).
|
||||
Write(" AND (").
|
||||
Write("id = ").Placeholder(1).
|
||||
Or("id = ").Placeholder(2).
|
||||
Write(")")
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE active = ? AND (id = ? OR id = ?)",
|
||||
wantArgs: []interface{}{true, 1, 2},
|
||||
},
|
||||
{
|
||||
name: "Where with nested groups",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").From("users").
|
||||
Where("(").
|
||||
Write("active = ").Placeholder(true).
|
||||
Or("role = ").Placeholder("admin").
|
||||
Write(")").
|
||||
And("created_at > ").Placeholder("2020-01-01")
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE (active = ? OR role = ?) AND created_at > ?",
|
||||
wantArgs: []interface{}{true, "admin", "2020-01-01"},
|
||||
},
|
||||
{
|
||||
name: "WhereIn",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").From("users").
|
||||
WhereIn("id", 3).
|
||||
AddArgs(1, 2, 3)
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE id IN (?, ?, ?)",
|
||||
wantArgs: []interface{}{1, 2, 3},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
q := db.NewQuery(tt.dbType)
|
||||
q = tt.buildFn(q)
|
||||
|
||||
gotSQL := q.String()
|
||||
gotArgs := q.Args()
|
||||
|
||||
if gotSQL != tt.wantSQL {
|
||||
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
|
||||
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJoinClauseBuilding(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dbType db.DBType
|
||||
buildFn func(*db.Query) *db.Query
|
||||
wantSQL string
|
||||
wantArgs []interface{}
|
||||
}{
|
||||
{
|
||||
name: "Inner join",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("u.*", "w.name").
|
||||
From("users u").
|
||||
Join(db.InnerJoin, "workspaces w", "w.user_id = u.id")
|
||||
},
|
||||
wantSQL: "SELECT u.*, w.name FROM users u INNER JOIN workspaces w ON w.user_id = u.id",
|
||||
wantArgs: []interface{}{},
|
||||
},
|
||||
{
|
||||
name: "Left join",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("u.*", "w.name").
|
||||
From("users u").
|
||||
Join(db.LeftJoin, "workspaces w", "w.user_id = u.id")
|
||||
},
|
||||
wantSQL: "SELECT u.*, w.name FROM users u LEFT JOIN workspaces w ON w.user_id = u.id",
|
||||
wantArgs: []interface{}{},
|
||||
},
|
||||
{
|
||||
name: "Multiple joins",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("u.*", "w.name", "s.role").
|
||||
From("users u").
|
||||
Join(db.InnerJoin, "workspaces w", "w.user_id = u.id").
|
||||
Join(db.LeftJoin, "settings s", "s.user_id = u.id")
|
||||
},
|
||||
wantSQL: "SELECT u.*, w.name, s.role FROM users u INNER JOIN workspaces w ON w.user_id = u.id LEFT JOIN settings s ON s.user_id = u.id",
|
||||
wantArgs: []interface{}{},
|
||||
},
|
||||
{
|
||||
name: "Join with where",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("u.*", "w.name").
|
||||
From("users u").
|
||||
Join(db.InnerJoin, "workspaces w", "w.user_id = u.id").
|
||||
Where("u.active = ").Placeholder(true)
|
||||
},
|
||||
wantSQL: "SELECT u.*, w.name FROM users u INNER JOIN workspaces w ON w.user_id = u.id WHERE u.active = ?",
|
||||
wantArgs: []interface{}{true},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
q := db.NewQuery(tt.dbType)
|
||||
q = tt.buildFn(q)
|
||||
|
||||
gotSQL := q.String()
|
||||
gotArgs := q.Args()
|
||||
|
||||
if gotSQL != tt.wantSQL {
|
||||
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
|
||||
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrderLimitOffset(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dbType db.DBType
|
||||
buildFn func(*db.Query) *db.Query
|
||||
wantSQL string
|
||||
wantArgs []interface{}
|
||||
}{
|
||||
{
|
||||
name: "Order by",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").From("users").OrderBy("name ASC")
|
||||
},
|
||||
wantSQL: "SELECT * FROM users ORDER BY name ASC",
|
||||
wantArgs: []interface{}{},
|
||||
},
|
||||
{
|
||||
name: "Order by multiple columns",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").From("users").OrderBy("name ASC", "created_at DESC")
|
||||
},
|
||||
wantSQL: "SELECT * FROM users ORDER BY name ASC, created_at DESC",
|
||||
wantArgs: []interface{}{},
|
||||
},
|
||||
{
|
||||
name: "Limit",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").From("users").Limit(10)
|
||||
},
|
||||
wantSQL: "SELECT * FROM users LIMIT 10",
|
||||
wantArgs: []interface{}{},
|
||||
},
|
||||
{
|
||||
name: "Limit and offset",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").From("users").Limit(10).Offset(20)
|
||||
},
|
||||
wantSQL: "SELECT * FROM users LIMIT 10 OFFSET 20",
|
||||
wantArgs: []interface{}{},
|
||||
},
|
||||
{
|
||||
name: "Complete query with all clauses",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").
|
||||
From("users").
|
||||
Where("active = ").Placeholder(true).
|
||||
OrderBy("name ASC").
|
||||
Limit(10).
|
||||
Offset(20)
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE active = ? ORDER BY name ASC LIMIT 10 OFFSET 20",
|
||||
wantArgs: []interface{}{true},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
q := db.NewQuery(tt.dbType)
|
||||
q = tt.buildFn(q)
|
||||
|
||||
gotSQL := q.String()
|
||||
gotArgs := q.Args()
|
||||
|
||||
if gotSQL != tt.wantSQL {
|
||||
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
|
||||
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInsertUpdateDelete(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dbType db.DBType
|
||||
buildFn func(*db.Query) *db.Query
|
||||
wantSQL string
|
||||
wantArgs []interface{}
|
||||
}{
|
||||
{
|
||||
name: "Insert SQLite",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Insert("users", "name", "email").
|
||||
Values(2).
|
||||
AddArgs("John", "john@example.com")
|
||||
},
|
||||
wantSQL: "INSERT INTO users (name, email) VALUES (?, ?)",
|
||||
wantArgs: []interface{}{"John", "john@example.com"},
|
||||
},
|
||||
{
|
||||
name: "Insert Postgres",
|
||||
dbType: db.DBTypePostgres,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Insert("users", "name", "email").
|
||||
Values(2).
|
||||
AddArgs("John", "john@example.com")
|
||||
},
|
||||
wantSQL: "INSERT INTO users (name, email) VALUES ($1, $2)",
|
||||
wantArgs: []interface{}{"John", "john@example.com"},
|
||||
},
|
||||
{
|
||||
name: "Update SQLite",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Update("users").
|
||||
Set("name").Placeholder("John").
|
||||
Set("email").Placeholder("john@example.com").
|
||||
Where("id = ").Placeholder(1)
|
||||
},
|
||||
wantSQL: "UPDATE users SET name = ?, email = ? WHERE id = ?",
|
||||
wantArgs: []interface{}{"John", "john@example.com", 1},
|
||||
},
|
||||
{
|
||||
name: "Update Postgres",
|
||||
dbType: db.DBTypePostgres,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Update("users").
|
||||
Set("name").Placeholder("John").
|
||||
Set("email").Placeholder("john@example.com").
|
||||
Where("id = ").Placeholder(1)
|
||||
},
|
||||
wantSQL: "UPDATE users SET name = $1, email = $2 WHERE id = $3",
|
||||
wantArgs: []interface{}{"John", "john@example.com", 1},
|
||||
},
|
||||
{
|
||||
name: "Delete SQLite",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Delete().From("users").Where("id = ").Placeholder(1)
|
||||
},
|
||||
wantSQL: "DELETE FROM users WHERE id = ?",
|
||||
wantArgs: []interface{}{1},
|
||||
},
|
||||
{
|
||||
name: "Delete Postgres",
|
||||
dbType: db.DBTypePostgres,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Delete().From("users").Where("id = ").Placeholder(1)
|
||||
},
|
||||
wantSQL: "DELETE FROM users WHERE id = $1",
|
||||
wantArgs: []interface{}{1},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
q := db.NewQuery(tt.dbType)
|
||||
q = tt.buildFn(q)
|
||||
|
||||
gotSQL := q.String()
|
||||
gotArgs := q.Args()
|
||||
|
||||
if gotSQL != tt.wantSQL {
|
||||
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
|
||||
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHavingClause(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dbType db.DBType
|
||||
buildFn func(*db.Query) *db.Query
|
||||
wantSQL string
|
||||
wantArgs []interface{}
|
||||
}{
|
||||
{
|
||||
name: "Simple having",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("department", "COUNT(*) as count").
|
||||
From("employees").
|
||||
GroupBy("department").
|
||||
Having("count > ").Placeholder(5)
|
||||
},
|
||||
wantSQL: "SELECT department, COUNT(*) as count FROM employees GROUP BY department HAVING count > ?",
|
||||
wantArgs: []interface{}{5},
|
||||
},
|
||||
{
|
||||
name: "Having with multiple conditions",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("department", "AVG(salary) as avg_salary").
|
||||
From("employees").
|
||||
GroupBy("department").
|
||||
Having("avg_salary > ").Placeholder(50000).
|
||||
And("COUNT(*) > ").Placeholder(3)
|
||||
},
|
||||
wantSQL: "SELECT department, AVG(salary) as avg_salary FROM employees GROUP BY department HAVING avg_salary > ? AND COUNT(*) > ?",
|
||||
wantArgs: []interface{}{50000, 3},
|
||||
},
|
||||
{
|
||||
name: "Having with postgres placeholders",
|
||||
dbType: db.DBTypePostgres,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("department", "COUNT(*) as count").
|
||||
From("employees").
|
||||
GroupBy("department").
|
||||
Having("count > ").Placeholder(5)
|
||||
},
|
||||
wantSQL: "SELECT department, COUNT(*) as count FROM employees GROUP BY department HAVING count > $1",
|
||||
wantArgs: []interface{}{5},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
q := db.NewQuery(tt.dbType)
|
||||
q = tt.buildFn(q)
|
||||
|
||||
gotSQL := q.String()
|
||||
gotArgs := q.Args()
|
||||
|
||||
if gotSQL != tt.wantSQL {
|
||||
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
|
||||
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplexQueries(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dbType db.DBType
|
||||
buildFn func(*db.Query) *db.Query
|
||||
wantSQL string
|
||||
wantArgs []interface{}
|
||||
}{
|
||||
{
|
||||
name: "Complex select with join and where",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("u.id", "u.name", "COUNT(w.id) as workspace_count").
|
||||
From("users u").
|
||||
Join(db.LeftJoin, "workspaces w", "w.user_id = u.id").
|
||||
Where("u.active = ").Placeholder(true).
|
||||
GroupBy("u.id", "u.name").
|
||||
Having("COUNT(w.id) > ").Placeholder(0).
|
||||
OrderBy("workspace_count DESC").
|
||||
Limit(10)
|
||||
},
|
||||
wantSQL: "SELECT u.id, u.name, COUNT(w.id) as workspace_count FROM users u LEFT JOIN workspaces w ON w.user_id = u.id WHERE u.active = ? GROUP BY u.id, u.name HAVING COUNT(w.id) > ? ORDER BY workspace_count DESC LIMIT 10",
|
||||
wantArgs: []interface{}{true, 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
q := db.NewQuery(tt.dbType)
|
||||
q = tt.buildFn(q)
|
||||
|
||||
gotSQL := q.String()
|
||||
gotArgs := q.Args()
|
||||
|
||||
if gotSQL != tt.wantSQL {
|
||||
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
|
||||
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Scanner provides methods for scanning rows into structs
|
||||
@@ -23,7 +25,24 @@ func NewScanner(db *sql.DB, dbType DBType) *Scanner {
|
||||
// QueryRow executes a query and scans the result into a struct
|
||||
func (s *Scanner) QueryRow(dest interface{}, q *Query) error {
|
||||
row := s.db.QueryRow(q.String(), q.Args()...)
|
||||
return scanStruct(row, dest)
|
||||
|
||||
// Handle primitive types
|
||||
v := reflect.ValueOf(dest)
|
||||
if v.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("dest must be a pointer")
|
||||
}
|
||||
|
||||
elem := v.Elem()
|
||||
switch elem.Kind() {
|
||||
case reflect.Struct:
|
||||
return scanStruct(row, dest)
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
|
||||
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
|
||||
reflect.Float32, reflect.Float64, reflect.Bool, reflect.String:
|
||||
return row.Scan(dest)
|
||||
default:
|
||||
return fmt.Errorf("unsupported dest type: %T", dest)
|
||||
}
|
||||
}
|
||||
|
||||
// Query executes a query and scans multiple results into a slice of structs
|
||||
@@ -41,7 +60,7 @@ func (s *Scanner) Query(dest interface{}, q *Query) error {
|
||||
func scanStruct(row *sql.Row, dest interface{}) error {
|
||||
v := reflect.ValueOf(dest)
|
||||
if v.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("dest must be a pointer to a struct")
|
||||
return fmt.Errorf("dest must be a pointer")
|
||||
}
|
||||
v = v.Elem()
|
||||
if v.Kind() != reflect.Struct {
|
||||
@@ -64,8 +83,9 @@ func scanStruct(row *sql.Row, dest interface{}) error {
|
||||
func scanStructs(rows *sql.Rows, dest interface{}) error {
|
||||
v := reflect.ValueOf(dest)
|
||||
if v.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("dest must be a pointer to a slice")
|
||||
return fmt.Errorf("dest must be a pointer")
|
||||
}
|
||||
|
||||
sliceVal := v.Elem()
|
||||
if sliceVal.Kind() != reflect.Slice {
|
||||
return fmt.Errorf("dest must be a pointer to a slice")
|
||||
@@ -93,3 +113,176 @@ func scanStructs(rows *sql.Rows, dest interface{}) error {
|
||||
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
// ScannerEx is an extended version of Scanner with more features
|
||||
type ScannerEx struct {
|
||||
db *sql.DB
|
||||
dbType DBType
|
||||
}
|
||||
|
||||
// NewScannerEx creates a new ScannerEx instance
|
||||
func NewScannerEx(db *sql.DB, dbType DBType) *ScannerEx {
|
||||
return &ScannerEx{
|
||||
db: db,
|
||||
dbType: dbType,
|
||||
}
|
||||
}
|
||||
|
||||
// QueryRow executes a query and scans the result into a struct
|
||||
func (s *ScannerEx) QueryRow(dest interface{}, q *Query) error {
|
||||
row := s.db.QueryRow(q.String(), q.Args()...)
|
||||
|
||||
// Get column names
|
||||
// Note: This is a workaround since sql.Row doesn't expose column names.
|
||||
// In a real implementation, you'd likely need to execute the query to get columns first.
|
||||
// For simplicity, we'll infer them from the struct tags.
|
||||
|
||||
return scanStructTags(row, dest)
|
||||
}
|
||||
|
||||
// Query executes a query and scans multiple results into a slice of structs
|
||||
func (s *ScannerEx) Query(dest interface{}, q *Query) error {
|
||||
rows, err := s.db.Query(q.String(), q.Args()...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
return scanStructsTags(rows, dest)
|
||||
}
|
||||
|
||||
// getFieldMap builds a map of db column names to struct fields using struct tags
|
||||
func getFieldMap(t reflect.Type) map[string]int {
|
||||
fieldMap := make(map[string]int)
|
||||
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
|
||||
// Check db tag first
|
||||
tag := field.Tag.Get("db")
|
||||
if tag != "" && tag != "-" {
|
||||
fieldMap[tag] = i
|
||||
continue
|
||||
}
|
||||
|
||||
// Check json tag next
|
||||
tag = field.Tag.Get("json")
|
||||
if tag != "" && tag != "-" {
|
||||
// Handle json tag options like omitempty
|
||||
parts := strings.Split(tag, ",")
|
||||
fieldMap[parts[0]] = i
|
||||
continue
|
||||
}
|
||||
|
||||
// Default to field name with snake_case conversion
|
||||
fieldMap[toSnakeCase(field.Name)] = i
|
||||
}
|
||||
|
||||
return fieldMap
|
||||
}
|
||||
|
||||
var camelRegex = regexp.MustCompile(`([a-z0-9])([A-Z])`)
|
||||
|
||||
// toSnakeCase converts a camelCase string to snake_case
|
||||
func toSnakeCase(s string) string {
|
||||
return strings.ToLower(camelRegex.ReplaceAllString(s, "${1}_${2}"))
|
||||
}
|
||||
|
||||
// scanStructTags scans a single row into a struct using field tags
|
||||
func scanStructTags(row *sql.Row, dest interface{}) error {
|
||||
v := reflect.ValueOf(dest)
|
||||
if v.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("dest must be a pointer")
|
||||
}
|
||||
v = v.Elem()
|
||||
if v.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("dest must be a pointer to a struct")
|
||||
}
|
||||
|
||||
fields := make([]interface{}, 0, v.NumField())
|
||||
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
field := v.Field(i)
|
||||
if field.CanSet() {
|
||||
fields = append(fields, field.Addr().Interface())
|
||||
}
|
||||
}
|
||||
|
||||
return row.Scan(fields...)
|
||||
}
|
||||
|
||||
// scanStructsTags scans multiple rows into a slice of structs using field tags
|
||||
func scanStructsTags(rows *sql.Rows, dest interface{}) error {
|
||||
v := reflect.ValueOf(dest)
|
||||
if v.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("dest must be a pointer")
|
||||
}
|
||||
|
||||
sliceVal := v.Elem()
|
||||
if sliceVal.Kind() != reflect.Slice {
|
||||
return fmt.Errorf("dest must be a pointer to a slice")
|
||||
}
|
||||
|
||||
elemType := sliceVal.Type().Elem()
|
||||
isPtr := elemType.Kind() == reflect.Ptr
|
||||
if isPtr {
|
||||
elemType = elemType.Elem()
|
||||
}
|
||||
|
||||
if elemType.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("dest must be a pointer to a slice of structs")
|
||||
}
|
||||
|
||||
// Get column names
|
||||
columns, err := rows.Columns()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build field mapping
|
||||
fieldMap := getFieldMap(elemType)
|
||||
|
||||
// Prepare values slice for each scan
|
||||
values := make([]interface{}, len(columns))
|
||||
scanFields := make([]interface{}, len(columns))
|
||||
for i := range values {
|
||||
scanFields[i] = &values[i]
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
// Create a new struct instance
|
||||
newElem := reflect.New(elemType).Elem()
|
||||
|
||||
// Scan row into values
|
||||
if err := rows.Scan(scanFields...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Map values to struct fields
|
||||
for i, colName := range columns {
|
||||
if fieldIndex, ok := fieldMap[colName]; ok {
|
||||
field := newElem.Field(fieldIndex)
|
||||
if field.CanSet() {
|
||||
val := reflect.ValueOf(values[i])
|
||||
if val.Elem().Kind() == reflect.Interface {
|
||||
val = val.Elem()
|
||||
}
|
||||
if val.Kind() == reflect.Ptr && !val.IsNil() {
|
||||
field.Set(val.Elem())
|
||||
} else if !val.IsNil() {
|
||||
field.Set(val)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Append to result slice
|
||||
if isPtr {
|
||||
sliceVal.Set(reflect.Append(sliceVal, newElem.Addr()))
|
||||
} else {
|
||||
sliceVal.Set(reflect.Append(sliceVal, newElem))
|
||||
}
|
||||
}
|
||||
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
298
server/internal/db/scanner_test.go
Normal file
298
server/internal/db/scanner_test.go
Normal file
@@ -0,0 +1,298 @@
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"lemma/internal/db"
|
||||
)
|
||||
|
||||
func TestScannerQueryRow(t *testing.T) {
|
||||
mockSecrets := &mockSecretsService{}
|
||||
testDB, err := db.NewTestDB(mockSecrets)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test database: %v", err)
|
||||
}
|
||||
defer testDB.Close()
|
||||
|
||||
// Create a test table
|
||||
_, err = testDB.TestDB().Exec(`
|
||||
CREATE TABLE users (
|
||||
id INTEGER PRIMARY KEY,
|
||||
email TEXT NOT NULL,
|
||||
created_at TIMESTAMP NOT NULL,
|
||||
active BOOLEAN NOT NULL
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test table: %v", err)
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID int
|
||||
Email string
|
||||
CreatedAt time.Time
|
||||
Active bool
|
||||
}
|
||||
|
||||
// Insert test data
|
||||
now := time.Now().UTC().Truncate(time.Second)
|
||||
_, err = testDB.TestDB().Exec(
|
||||
"INSERT INTO users (id, email, created_at, active) VALUES (?, ?, ?, ?)",
|
||||
1, "test@example.com", now, true,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert test data: %v", err)
|
||||
}
|
||||
|
||||
// Test query row success
|
||||
t.Run("QueryRow success", func(t *testing.T) {
|
||||
scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite)
|
||||
q := db.NewQuery(db.DBTypeSQLite)
|
||||
q.Select("id", "email", "created_at", "active").
|
||||
From("users").
|
||||
Where("id = ").
|
||||
Placeholder(1)
|
||||
|
||||
var user User
|
||||
err := scanner.QueryRow(&user, q)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if user.ID != 1 {
|
||||
t.Errorf("Expected ID 1, got %d", user.ID)
|
||||
}
|
||||
if user.Email != "test@example.com" {
|
||||
t.Errorf("Expected Email test@example.com, got %s", user.Email)
|
||||
}
|
||||
if !user.CreatedAt.Equal(now) {
|
||||
t.Errorf("Expected CreatedAt %v, got %v", now, user.CreatedAt)
|
||||
}
|
||||
if !user.Active {
|
||||
t.Errorf("Expected Active true, got %v", user.Active)
|
||||
}
|
||||
})
|
||||
|
||||
// Test query row no results
|
||||
t.Run("QueryRow no results", func(t *testing.T) {
|
||||
scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite)
|
||||
q := db.NewQuery(db.DBTypeSQLite)
|
||||
q.Select("id", "email", "created_at", "active").
|
||||
From("users").
|
||||
Where("id = ").
|
||||
Placeholder(999)
|
||||
|
||||
var user User
|
||||
err := scanner.QueryRow(&user, q)
|
||||
|
||||
if err != sql.ErrNoRows {
|
||||
t.Errorf("Expected ErrNoRows, got %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Test scanning a single value
|
||||
t.Run("QueryRow single value", func(t *testing.T) {
|
||||
scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite)
|
||||
q := db.NewQuery(db.DBTypeSQLite)
|
||||
q.Select("COUNT(*)").From("users")
|
||||
|
||||
var count int
|
||||
err := scanner.QueryRow(&count, q)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if count != 1 {
|
||||
t.Errorf("Expected count 1, got %d", count)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestScannerQuery(t *testing.T) {
|
||||
mockSecrets := &mockSecretsService{}
|
||||
testDB, err := db.NewTestDB(mockSecrets)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test database: %v", err)
|
||||
}
|
||||
defer testDB.Close()
|
||||
|
||||
// Create a test table
|
||||
_, err = testDB.TestDB().Exec(`
|
||||
CREATE TABLE users (
|
||||
id INTEGER PRIMARY KEY,
|
||||
email TEXT NOT NULL,
|
||||
created_at TIMESTAMP NOT NULL,
|
||||
active BOOLEAN NOT NULL
|
||||
)
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test table: %v", err)
|
||||
}
|
||||
|
||||
type User struct {
|
||||
ID int
|
||||
Email string
|
||||
CreatedAt time.Time
|
||||
Active bool
|
||||
}
|
||||
|
||||
// Insert test data
|
||||
now := time.Now().UTC().Truncate(time.Second)
|
||||
testUsers := []User{
|
||||
{ID: 1, Email: "user1@example.com", CreatedAt: now, Active: true},
|
||||
{ID: 2, Email: "user2@example.com", CreatedAt: now, Active: false},
|
||||
{ID: 3, Email: "user3@example.com", CreatedAt: now, Active: true},
|
||||
}
|
||||
|
||||
for _, user := range testUsers {
|
||||
_, err = testDB.TestDB().Exec(
|
||||
"INSERT INTO users (id, email, created_at, active) VALUES (?, ?, ?, ?)",
|
||||
user.ID, user.Email, user.CreatedAt, user.Active,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to insert test data: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Test query multiple rows
|
||||
t.Run("Query multiple rows", func(t *testing.T) {
|
||||
scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite)
|
||||
q := db.NewQuery(db.DBTypeSQLite)
|
||||
q.Select("id", "email", "created_at", "active").
|
||||
From("users").
|
||||
OrderBy("id ASC")
|
||||
|
||||
var users []User
|
||||
err := scanner.Query(&users, q)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if len(users) != len(testUsers) {
|
||||
t.Errorf("Expected %d users, got %d", len(testUsers), len(users))
|
||||
}
|
||||
|
||||
for i, u := range users {
|
||||
if u.ID != testUsers[i].ID {
|
||||
t.Errorf("Expected user[%d].ID %d, got %d", i, testUsers[i].ID, u.ID)
|
||||
}
|
||||
if u.Email != testUsers[i].Email {
|
||||
t.Errorf("Expected user[%d].Email %s, got %s", i, testUsers[i].Email, u.Email)
|
||||
}
|
||||
if !u.CreatedAt.Equal(testUsers[i].CreatedAt) {
|
||||
t.Errorf("Expected user[%d].CreatedAt %v, got %v", i, testUsers[i].CreatedAt, u.CreatedAt)
|
||||
}
|
||||
if u.Active != testUsers[i].Active {
|
||||
t.Errorf("Expected user[%d].Active %v, got %v", i, testUsers[i].Active, u.Active)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Test query with filter
|
||||
t.Run("Query with filter", func(t *testing.T) {
|
||||
scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite)
|
||||
q := db.NewQuery(db.DBTypeSQLite)
|
||||
q.Select("id", "email", "created_at", "active").
|
||||
From("users").
|
||||
Where("active = ").
|
||||
Placeholder(true).
|
||||
OrderBy("id ASC")
|
||||
|
||||
var users []User
|
||||
err := scanner.Query(&users, q)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if len(users) != 2 {
|
||||
t.Errorf("Expected 2 users, got %d", len(users))
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
if !u.Active {
|
||||
t.Errorf("Expected only active users, got inactive user: %+v", u)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Test query empty result
|
||||
t.Run("Query empty result", func(t *testing.T) {
|
||||
scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite)
|
||||
q := db.NewQuery(db.DBTypeSQLite)
|
||||
q.Select("id", "email", "created_at", "active").
|
||||
From("users").
|
||||
Where("id > ").
|
||||
Placeholder(100)
|
||||
|
||||
var users []User
|
||||
err := scanner.Query(&users, q)
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("Expected no error, got %v", err)
|
||||
}
|
||||
|
||||
if len(users) != 0 {
|
||||
t.Errorf("Expected 0 users, got %d", len(users))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestScanErrors(t *testing.T) {
|
||||
mockSecrets := &mockSecretsService{}
|
||||
testDB, err := db.NewTestDB(mockSecrets)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test database: %v", err)
|
||||
}
|
||||
defer testDB.Close()
|
||||
|
||||
scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite)
|
||||
q := db.NewQuery(db.DBTypeSQLite)
|
||||
q.Select("1")
|
||||
|
||||
// Test non-pointer
|
||||
t.Run("QueryRow non-pointer", func(t *testing.T) {
|
||||
var user struct{}
|
||||
err := scanner.QueryRow(user, q) // Passing non-pointer
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error for non-pointer, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
// Test pointer to non-slice for Query
|
||||
t.Run("Query pointer to non-slice", func(t *testing.T) {
|
||||
var user struct{}
|
||||
err := scanner.Query(&user, q) // Passing pointer to struct, not slice
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error for non-slice pointer, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
// Test non-pointer for Query
|
||||
t.Run("Query non-pointer", func(t *testing.T) {
|
||||
var users []struct{}
|
||||
err := scanner.Query(users, q) // Passing non-pointer
|
||||
|
||||
if err == nil {
|
||||
t.Error("Expected error for non-pointer, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Mock secrets service for testing
|
||||
type mockSecretsService struct{}
|
||||
|
||||
func (m *mockSecretsService) Encrypt(plaintext string) (string, error) {
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
func (m *mockSecretsService) Decrypt(ciphertext string) (string, error) {
|
||||
return ciphertext, nil
|
||||
}
|
||||
Reference in New Issue
Block a user