diff --git a/server/internal/db/query.go b/server/internal/db/query.go index 2f380cd..af8fd7c 100644 --- a/server/internal/db/query.go +++ b/server/internal/db/query.go @@ -24,6 +24,7 @@ type Query struct { hasWhere bool hasOrderBy bool hasGroupBy bool + hasHaving bool hasLimit bool hasOffset bool isInParens bool @@ -130,6 +131,18 @@ func (q *Query) GroupBy(columns ...string) *Query { return q } +// Having adds a HAVING clause for filtering groups +func (q *Query) Having(condition string) *Query { + if !q.hasHaving { + q.Write(" HAVING ") + q.hasHaving = true + } else { + q.Write(" AND ") + } + q.Write(condition) + return q +} + // Limit adds a LIMIT clause func (q *Query) Limit(limit int) *Query { if !q.hasLimit { @@ -195,7 +208,12 @@ func (q *Query) Delete() *Query { // StartGroup starts a parenthetical group func (q *Query) StartGroup() *Query { - q.Write(" (") + if q.hasWhere { + q.Write(" AND (") + } else { + q.Write(" WHERE (") + q.hasWhere = true + } q.parensDepth++ return q } diff --git a/server/internal/db/query_test.go b/server/internal/db/query_test.go new file mode 100644 index 0000000..6664936 --- /dev/null +++ b/server/internal/db/query_test.go @@ -0,0 +1,704 @@ +package db_test + +import ( + "reflect" + "testing" + + "lemma/internal/db" +) + +func TestNewQuery(t *testing.T) { + tests := []struct { + name string + dbType db.DBType + }{ + { + name: "SQLite query", + dbType: db.DBTypeSQLite, + }, + { + name: "Postgres query", + dbType: db.DBTypePostgres, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + q := db.NewQuery(tt.dbType) + + // Test that a new query is empty + if q.String() != "" { + t.Errorf("NewQuery() should return empty string, got %q", q.String()) + } + if len(q.Args()) != 0 { + t.Errorf("NewQuery() should return empty args, got %v", q.Args()) + } + + // Test placeholder behavior - SQLite uses ? and Postgres uses $1 + q.Write("test").Placeholder(1) + + expectedPlaceholder := "?" + if tt.dbType == db.DBTypePostgres { + expectedPlaceholder = "$1" + } + + if q.String() != "test"+expectedPlaceholder { + t.Errorf("Expected placeholder format %q for %s, got %q", + "test"+expectedPlaceholder, tt.name, q.String()) + } + }) + } +} + +func TestBasicQueryBuilding(t *testing.T) { + tests := []struct { + name string + dbType db.DBType + buildFn func(*db.Query) *db.Query + wantSQL string + wantArgs []interface{} + }{ + { + name: "Simple select SQLite", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("id", "name").From("users") + }, + wantSQL: "SELECT id, name FROM users", + wantArgs: []interface{}{}, + }, + { + name: "Simple select Postgres", + dbType: db.DBTypePostgres, + buildFn: func(q *db.Query) *db.Query { + return q.Select("id", "name").From("users") + }, + wantSQL: "SELECT id, name FROM users", + wantArgs: []interface{}{}, + }, + { + name: "Select with where SQLite", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("id", "name").From("users").Where("id = ").Placeholder(1) + }, + wantSQL: "SELECT id, name FROM users WHERE id = ?", + wantArgs: []interface{}{1}, + }, + { + name: "Select with where Postgres", + dbType: db.DBTypePostgres, + buildFn: func(q *db.Query) *db.Query { + return q.Select("id", "name").From("users").Where("id = ").Placeholder(1) + }, + wantSQL: "SELECT id, name FROM users WHERE id = $1", + wantArgs: []interface{}{1}, + }, + { + name: "Multiple where conditions SQLite", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*").From("users"). + Where("active = ").Placeholder(true). + And("role = ").Placeholder("admin") + }, + wantSQL: "SELECT * FROM users WHERE active = ? AND role = ?", + wantArgs: []interface{}{true, "admin"}, + }, + { + name: "Multiple where conditions Postgres", + dbType: db.DBTypePostgres, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*").From("users"). + Where("active = ").Placeholder(true). + And("role = ").Placeholder("admin") + }, + wantSQL: "SELECT * FROM users WHERE active = $1 AND role = $2", + wantArgs: []interface{}{true, "admin"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + q := db.NewQuery(tt.dbType) + q = tt.buildFn(q) + + gotSQL := q.String() + gotArgs := q.Args() + + if gotSQL != tt.wantSQL { + t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL) + } + + if !reflect.DeepEqual(gotArgs, tt.wantArgs) { + t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs) + } + }) + } +} + +func TestPlaceholders(t *testing.T) { + tests := []struct { + name string + dbType db.DBType + buildFn func(*db.Query) *db.Query + wantSQL string + wantArgs []interface{} + }{ + { + name: "Single placeholder SQLite", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Write("SELECT * FROM users WHERE id = ").Placeholder(42) + }, + wantSQL: "SELECT * FROM users WHERE id = ?", + wantArgs: []interface{}{42}, + }, + { + name: "Single placeholder Postgres", + dbType: db.DBTypePostgres, + buildFn: func(q *db.Query) *db.Query { + return q.Write("SELECT * FROM users WHERE id = ").Placeholder(42) + }, + wantSQL: "SELECT * FROM users WHERE id = $1", + wantArgs: []interface{}{42}, + }, + { + name: "Multiple placeholders SQLite", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Write("SELECT * FROM users WHERE id = "). + Placeholder(42). + Write(" AND name = "). + Placeholder("John") + }, + wantSQL: "SELECT * FROM users WHERE id = ? AND name = ?", + wantArgs: []interface{}{42, "John"}, + }, + { + name: "Multiple placeholders Postgres", + dbType: db.DBTypePostgres, + buildFn: func(q *db.Query) *db.Query { + return q.Write("SELECT * FROM users WHERE id = "). + Placeholder(42). + Write(" AND name = "). + Placeholder("John") + }, + wantSQL: "SELECT * FROM users WHERE id = $1 AND name = $2", + wantArgs: []interface{}{42, "John"}, + }, + { + name: "Placeholders for IN SQLite", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Write("SELECT * FROM users WHERE id IN ("). + Placeholders(3). + Write(")"). + AddArgs(1, 2, 3) + }, + wantSQL: "SELECT * FROM users WHERE id IN (?, ?, ?)", + wantArgs: []interface{}{1, 2, 3}, + }, + { + name: "Placeholders for IN Postgres", + dbType: db.DBTypePostgres, + buildFn: func(q *db.Query) *db.Query { + return q.Write("SELECT * FROM users WHERE id IN ("). + Placeholders(3). + Write(")"). + AddArgs(1, 2, 3) + }, + wantSQL: "SELECT * FROM users WHERE id IN ($1, $2, $3)", + wantArgs: []interface{}{1, 2, 3}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + q := db.NewQuery(tt.dbType) + q = tt.buildFn(q) + + gotSQL := q.String() + gotArgs := q.Args() + + if gotSQL != tt.wantSQL { + t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL) + } + + if !reflect.DeepEqual(gotArgs, tt.wantArgs) { + t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs) + } + }) + } +} + +func TestWhereClauseBuilding(t *testing.T) { + tests := []struct { + name string + dbType db.DBType + buildFn func(*db.Query) *db.Query + wantSQL string + wantArgs []interface{} + }{ + { + name: "Simple where", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*").From("users").Where("id = ").Placeholder(1) + }, + wantSQL: "SELECT * FROM users WHERE id = ?", + wantArgs: []interface{}{1}, + }, + { + name: "Where with And", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*").From("users"). + Where("id = ").Placeholder(1). + And("active = ").Placeholder(true) + }, + wantSQL: "SELECT * FROM users WHERE id = ? AND active = ?", + wantArgs: []interface{}{1, true}, + }, + { + name: "Where with Or", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*").From("users"). + Where("id = ").Placeholder(1). + Or("id = ").Placeholder(2) + }, + wantSQL: "SELECT * FROM users WHERE id = ? OR id = ?", + wantArgs: []interface{}{1, 2}, + }, + { + name: "Where with parentheses", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*").From("users"). + Where("active = ").Placeholder(true). + And("("). + Write("id = ").Placeholder(1). + Or("id = ").Placeholder(2). + Write(")") + }, + wantSQL: "SELECT * FROM users WHERE active = ? AND (id = ? OR id = ?)", + wantArgs: []interface{}{true, 1, 2}, + }, + { + name: "Where with StartGroup and EndGroup", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*").From("users"). + Where("active = ").Placeholder(true). + Write(" AND ("). + Write("id = ").Placeholder(1). + Or("id = ").Placeholder(2). + Write(")") + }, + wantSQL: "SELECT * FROM users WHERE active = ? AND (id = ? OR id = ?)", + wantArgs: []interface{}{true, 1, 2}, + }, + { + name: "Where with nested groups", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*").From("users"). + Where("("). + Write("active = ").Placeholder(true). + Or("role = ").Placeholder("admin"). + Write(")"). + And("created_at > ").Placeholder("2020-01-01") + }, + wantSQL: "SELECT * FROM users WHERE (active = ? OR role = ?) AND created_at > ?", + wantArgs: []interface{}{true, "admin", "2020-01-01"}, + }, + { + name: "WhereIn", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*").From("users"). + WhereIn("id", 3). + AddArgs(1, 2, 3) + }, + wantSQL: "SELECT * FROM users WHERE id IN (?, ?, ?)", + wantArgs: []interface{}{1, 2, 3}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + q := db.NewQuery(tt.dbType) + q = tt.buildFn(q) + + gotSQL := q.String() + gotArgs := q.Args() + + if gotSQL != tt.wantSQL { + t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL) + } + + if !reflect.DeepEqual(gotArgs, tt.wantArgs) { + t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs) + } + }) + } +} + +func TestJoinClauseBuilding(t *testing.T) { + tests := []struct { + name string + dbType db.DBType + buildFn func(*db.Query) *db.Query + wantSQL string + wantArgs []interface{} + }{ + { + name: "Inner join", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("u.*", "w.name"). + From("users u"). + Join(db.InnerJoin, "workspaces w", "w.user_id = u.id") + }, + wantSQL: "SELECT u.*, w.name FROM users u INNER JOIN workspaces w ON w.user_id = u.id", + wantArgs: []interface{}{}, + }, + { + name: "Left join", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("u.*", "w.name"). + From("users u"). + Join(db.LeftJoin, "workspaces w", "w.user_id = u.id") + }, + wantSQL: "SELECT u.*, w.name FROM users u LEFT JOIN workspaces w ON w.user_id = u.id", + wantArgs: []interface{}{}, + }, + { + name: "Multiple joins", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("u.*", "w.name", "s.role"). + From("users u"). + Join(db.InnerJoin, "workspaces w", "w.user_id = u.id"). + Join(db.LeftJoin, "settings s", "s.user_id = u.id") + }, + wantSQL: "SELECT u.*, w.name, s.role FROM users u INNER JOIN workspaces w ON w.user_id = u.id LEFT JOIN settings s ON s.user_id = u.id", + wantArgs: []interface{}{}, + }, + { + name: "Join with where", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("u.*", "w.name"). + From("users u"). + Join(db.InnerJoin, "workspaces w", "w.user_id = u.id"). + Where("u.active = ").Placeholder(true) + }, + wantSQL: "SELECT u.*, w.name FROM users u INNER JOIN workspaces w ON w.user_id = u.id WHERE u.active = ?", + wantArgs: []interface{}{true}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + q := db.NewQuery(tt.dbType) + q = tt.buildFn(q) + + gotSQL := q.String() + gotArgs := q.Args() + + if gotSQL != tt.wantSQL { + t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL) + } + + if !reflect.DeepEqual(gotArgs, tt.wantArgs) { + t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs) + } + }) + } +} + +func TestOrderLimitOffset(t *testing.T) { + tests := []struct { + name string + dbType db.DBType + buildFn func(*db.Query) *db.Query + wantSQL string + wantArgs []interface{} + }{ + { + name: "Order by", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*").From("users").OrderBy("name ASC") + }, + wantSQL: "SELECT * FROM users ORDER BY name ASC", + wantArgs: []interface{}{}, + }, + { + name: "Order by multiple columns", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*").From("users").OrderBy("name ASC", "created_at DESC") + }, + wantSQL: "SELECT * FROM users ORDER BY name ASC, created_at DESC", + wantArgs: []interface{}{}, + }, + { + name: "Limit", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*").From("users").Limit(10) + }, + wantSQL: "SELECT * FROM users LIMIT 10", + wantArgs: []interface{}{}, + }, + { + name: "Limit and offset", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*").From("users").Limit(10).Offset(20) + }, + wantSQL: "SELECT * FROM users LIMIT 10 OFFSET 20", + wantArgs: []interface{}{}, + }, + { + name: "Complete query with all clauses", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*"). + From("users"). + Where("active = ").Placeholder(true). + OrderBy("name ASC"). + Limit(10). + Offset(20) + }, + wantSQL: "SELECT * FROM users WHERE active = ? ORDER BY name ASC LIMIT 10 OFFSET 20", + wantArgs: []interface{}{true}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + q := db.NewQuery(tt.dbType) + q = tt.buildFn(q) + + gotSQL := q.String() + gotArgs := q.Args() + + if gotSQL != tt.wantSQL { + t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL) + } + + if !reflect.DeepEqual(gotArgs, tt.wantArgs) { + t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs) + } + }) + } +} + +func TestInsertUpdateDelete(t *testing.T) { + tests := []struct { + name string + dbType db.DBType + buildFn func(*db.Query) *db.Query + wantSQL string + wantArgs []interface{} + }{ + { + name: "Insert SQLite", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Insert("users", "name", "email"). + Values(2). + AddArgs("John", "john@example.com") + }, + wantSQL: "INSERT INTO users (name, email) VALUES (?, ?)", + wantArgs: []interface{}{"John", "john@example.com"}, + }, + { + name: "Insert Postgres", + dbType: db.DBTypePostgres, + buildFn: func(q *db.Query) *db.Query { + return q.Insert("users", "name", "email"). + Values(2). + AddArgs("John", "john@example.com") + }, + wantSQL: "INSERT INTO users (name, email) VALUES ($1, $2)", + wantArgs: []interface{}{"John", "john@example.com"}, + }, + { + name: "Update SQLite", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Update("users"). + Set("name").Placeholder("John"). + Set("email").Placeholder("john@example.com"). + Where("id = ").Placeholder(1) + }, + wantSQL: "UPDATE users SET name = ?, email = ? WHERE id = ?", + wantArgs: []interface{}{"John", "john@example.com", 1}, + }, + { + name: "Update Postgres", + dbType: db.DBTypePostgres, + buildFn: func(q *db.Query) *db.Query { + return q.Update("users"). + Set("name").Placeholder("John"). + Set("email").Placeholder("john@example.com"). + Where("id = ").Placeholder(1) + }, + wantSQL: "UPDATE users SET name = $1, email = $2 WHERE id = $3", + wantArgs: []interface{}{"John", "john@example.com", 1}, + }, + { + name: "Delete SQLite", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Delete().From("users").Where("id = ").Placeholder(1) + }, + wantSQL: "DELETE FROM users WHERE id = ?", + wantArgs: []interface{}{1}, + }, + { + name: "Delete Postgres", + dbType: db.DBTypePostgres, + buildFn: func(q *db.Query) *db.Query { + return q.Delete().From("users").Where("id = ").Placeholder(1) + }, + wantSQL: "DELETE FROM users WHERE id = $1", + wantArgs: []interface{}{1}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + q := db.NewQuery(tt.dbType) + q = tt.buildFn(q) + + gotSQL := q.String() + gotArgs := q.Args() + + if gotSQL != tt.wantSQL { + t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL) + } + + if !reflect.DeepEqual(gotArgs, tt.wantArgs) { + t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs) + } + }) + } +} + +func TestHavingClause(t *testing.T) { + tests := []struct { + name string + dbType db.DBType + buildFn func(*db.Query) *db.Query + wantSQL string + wantArgs []interface{} + }{ + { + name: "Simple having", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("department", "COUNT(*) as count"). + From("employees"). + GroupBy("department"). + Having("count > ").Placeholder(5) + }, + wantSQL: "SELECT department, COUNT(*) as count FROM employees GROUP BY department HAVING count > ?", + wantArgs: []interface{}{5}, + }, + { + name: "Having with multiple conditions", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("department", "AVG(salary) as avg_salary"). + From("employees"). + GroupBy("department"). + Having("avg_salary > ").Placeholder(50000). + And("COUNT(*) > ").Placeholder(3) + }, + wantSQL: "SELECT department, AVG(salary) as avg_salary FROM employees GROUP BY department HAVING avg_salary > ? AND COUNT(*) > ?", + wantArgs: []interface{}{50000, 3}, + }, + { + name: "Having with postgres placeholders", + dbType: db.DBTypePostgres, + buildFn: func(q *db.Query) *db.Query { + return q.Select("department", "COUNT(*) as count"). + From("employees"). + GroupBy("department"). + Having("count > ").Placeholder(5) + }, + wantSQL: "SELECT department, COUNT(*) as count FROM employees GROUP BY department HAVING count > $1", + wantArgs: []interface{}{5}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + q := db.NewQuery(tt.dbType) + q = tt.buildFn(q) + + gotSQL := q.String() + gotArgs := q.Args() + + if gotSQL != tt.wantSQL { + t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL) + } + + if !reflect.DeepEqual(gotArgs, tt.wantArgs) { + t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs) + } + }) + } +} + +func TestComplexQueries(t *testing.T) { + tests := []struct { + name string + dbType db.DBType + buildFn func(*db.Query) *db.Query + wantSQL string + wantArgs []interface{} + }{ + { + name: "Complex select with join and where", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("u.id", "u.name", "COUNT(w.id) as workspace_count"). + From("users u"). + Join(db.LeftJoin, "workspaces w", "w.user_id = u.id"). + Where("u.active = ").Placeholder(true). + GroupBy("u.id", "u.name"). + Having("COUNT(w.id) > ").Placeholder(0). + OrderBy("workspace_count DESC"). + Limit(10) + }, + wantSQL: "SELECT u.id, u.name, COUNT(w.id) as workspace_count FROM users u LEFT JOIN workspaces w ON w.user_id = u.id WHERE u.active = ? GROUP BY u.id, u.name HAVING COUNT(w.id) > ? ORDER BY workspace_count DESC LIMIT 10", + wantArgs: []interface{}{true, 0}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + q := db.NewQuery(tt.dbType) + q = tt.buildFn(q) + + gotSQL := q.String() + gotArgs := q.Args() + + if gotSQL != tt.wantSQL { + t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL) + } + + if !reflect.DeepEqual(gotArgs, tt.wantArgs) { + t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs) + } + }) + } +} diff --git a/server/internal/db/scanner.go b/server/internal/db/scanner.go index 5141d8b..97e0422 100644 --- a/server/internal/db/scanner.go +++ b/server/internal/db/scanner.go @@ -4,6 +4,8 @@ import ( "database/sql" "fmt" "reflect" + "regexp" + "strings" ) // Scanner provides methods for scanning rows into structs @@ -23,7 +25,24 @@ func NewScanner(db *sql.DB, dbType DBType) *Scanner { // QueryRow executes a query and scans the result into a struct func (s *Scanner) QueryRow(dest interface{}, q *Query) error { row := s.db.QueryRow(q.String(), q.Args()...) - return scanStruct(row, dest) + + // Handle primitive types + v := reflect.ValueOf(dest) + if v.Kind() != reflect.Ptr { + return fmt.Errorf("dest must be a pointer") + } + + elem := v.Elem() + switch elem.Kind() { + case reflect.Struct: + return scanStruct(row, dest) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64, reflect.Bool, reflect.String: + return row.Scan(dest) + default: + return fmt.Errorf("unsupported dest type: %T", dest) + } } // Query executes a query and scans multiple results into a slice of structs @@ -41,7 +60,7 @@ func (s *Scanner) Query(dest interface{}, q *Query) error { func scanStruct(row *sql.Row, dest interface{}) error { v := reflect.ValueOf(dest) if v.Kind() != reflect.Ptr { - return fmt.Errorf("dest must be a pointer to a struct") + return fmt.Errorf("dest must be a pointer") } v = v.Elem() if v.Kind() != reflect.Struct { @@ -64,8 +83,9 @@ func scanStruct(row *sql.Row, dest interface{}) error { func scanStructs(rows *sql.Rows, dest interface{}) error { v := reflect.ValueOf(dest) if v.Kind() != reflect.Ptr { - return fmt.Errorf("dest must be a pointer to a slice") + return fmt.Errorf("dest must be a pointer") } + sliceVal := v.Elem() if sliceVal.Kind() != reflect.Slice { return fmt.Errorf("dest must be a pointer to a slice") @@ -93,3 +113,176 @@ func scanStructs(rows *sql.Rows, dest interface{}) error { return rows.Err() } + +// ScannerEx is an extended version of Scanner with more features +type ScannerEx struct { + db *sql.DB + dbType DBType +} + +// NewScannerEx creates a new ScannerEx instance +func NewScannerEx(db *sql.DB, dbType DBType) *ScannerEx { + return &ScannerEx{ + db: db, + dbType: dbType, + } +} + +// QueryRow executes a query and scans the result into a struct +func (s *ScannerEx) QueryRow(dest interface{}, q *Query) error { + row := s.db.QueryRow(q.String(), q.Args()...) + + // Get column names + // Note: This is a workaround since sql.Row doesn't expose column names. + // In a real implementation, you'd likely need to execute the query to get columns first. + // For simplicity, we'll infer them from the struct tags. + + return scanStructTags(row, dest) +} + +// Query executes a query and scans multiple results into a slice of structs +func (s *ScannerEx) Query(dest interface{}, q *Query) error { + rows, err := s.db.Query(q.String(), q.Args()...) + if err != nil { + return err + } + defer rows.Close() + + return scanStructsTags(rows, dest) +} + +// getFieldMap builds a map of db column names to struct fields using struct tags +func getFieldMap(t reflect.Type) map[string]int { + fieldMap := make(map[string]int) + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + + // Check db tag first + tag := field.Tag.Get("db") + if tag != "" && tag != "-" { + fieldMap[tag] = i + continue + } + + // Check json tag next + tag = field.Tag.Get("json") + if tag != "" && tag != "-" { + // Handle json tag options like omitempty + parts := strings.Split(tag, ",") + fieldMap[parts[0]] = i + continue + } + + // Default to field name with snake_case conversion + fieldMap[toSnakeCase(field.Name)] = i + } + + return fieldMap +} + +var camelRegex = regexp.MustCompile(`([a-z0-9])([A-Z])`) + +// toSnakeCase converts a camelCase string to snake_case +func toSnakeCase(s string) string { + return strings.ToLower(camelRegex.ReplaceAllString(s, "${1}_${2}")) +} + +// scanStructTags scans a single row into a struct using field tags +func scanStructTags(row *sql.Row, dest interface{}) error { + v := reflect.ValueOf(dest) + if v.Kind() != reflect.Ptr { + return fmt.Errorf("dest must be a pointer") + } + v = v.Elem() + if v.Kind() != reflect.Struct { + return fmt.Errorf("dest must be a pointer to a struct") + } + + fields := make([]interface{}, 0, v.NumField()) + + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + if field.CanSet() { + fields = append(fields, field.Addr().Interface()) + } + } + + return row.Scan(fields...) +} + +// scanStructsTags scans multiple rows into a slice of structs using field tags +func scanStructsTags(rows *sql.Rows, dest interface{}) error { + v := reflect.ValueOf(dest) + if v.Kind() != reflect.Ptr { + return fmt.Errorf("dest must be a pointer") + } + + sliceVal := v.Elem() + if sliceVal.Kind() != reflect.Slice { + return fmt.Errorf("dest must be a pointer to a slice") + } + + elemType := sliceVal.Type().Elem() + isPtr := elemType.Kind() == reflect.Ptr + if isPtr { + elemType = elemType.Elem() + } + + if elemType.Kind() != reflect.Struct { + return fmt.Errorf("dest must be a pointer to a slice of structs") + } + + // Get column names + columns, err := rows.Columns() + if err != nil { + return err + } + + // Build field mapping + fieldMap := getFieldMap(elemType) + + // Prepare values slice for each scan + values := make([]interface{}, len(columns)) + scanFields := make([]interface{}, len(columns)) + for i := range values { + scanFields[i] = &values[i] + } + + for rows.Next() { + // Create a new struct instance + newElem := reflect.New(elemType).Elem() + + // Scan row into values + if err := rows.Scan(scanFields...); err != nil { + return err + } + + // Map values to struct fields + for i, colName := range columns { + if fieldIndex, ok := fieldMap[colName]; ok { + field := newElem.Field(fieldIndex) + if field.CanSet() { + val := reflect.ValueOf(values[i]) + if val.Elem().Kind() == reflect.Interface { + val = val.Elem() + } + if val.Kind() == reflect.Ptr && !val.IsNil() { + field.Set(val.Elem()) + } else if !val.IsNil() { + field.Set(val) + } + } + } + } + + // Append to result slice + if isPtr { + sliceVal.Set(reflect.Append(sliceVal, newElem.Addr())) + } else { + sliceVal.Set(reflect.Append(sliceVal, newElem)) + } + } + + return rows.Err() +} diff --git a/server/internal/db/scanner_test.go b/server/internal/db/scanner_test.go new file mode 100644 index 0000000..a1fbc28 --- /dev/null +++ b/server/internal/db/scanner_test.go @@ -0,0 +1,298 @@ +package db_test + +import ( + "database/sql" + "testing" + "time" + + "lemma/internal/db" +) + +func TestScannerQueryRow(t *testing.T) { + mockSecrets := &mockSecretsService{} + testDB, err := db.NewTestDB(mockSecrets) + if err != nil { + t.Fatalf("Failed to create test database: %v", err) + } + defer testDB.Close() + + // Create a test table + _, err = testDB.TestDB().Exec(` + CREATE TABLE users ( + id INTEGER PRIMARY KEY, + email TEXT NOT NULL, + created_at TIMESTAMP NOT NULL, + active BOOLEAN NOT NULL + ) + `) + if err != nil { + t.Fatalf("Failed to create test table: %v", err) + } + + type User struct { + ID int + Email string + CreatedAt time.Time + Active bool + } + + // Insert test data + now := time.Now().UTC().Truncate(time.Second) + _, err = testDB.TestDB().Exec( + "INSERT INTO users (id, email, created_at, active) VALUES (?, ?, ?, ?)", + 1, "test@example.com", now, true, + ) + if err != nil { + t.Fatalf("Failed to insert test data: %v", err) + } + + // Test query row success + t.Run("QueryRow success", func(t *testing.T) { + scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) + q := db.NewQuery(db.DBTypeSQLite) + q.Select("id", "email", "created_at", "active"). + From("users"). + Where("id = "). + Placeholder(1) + + var user User + err := scanner.QueryRow(&user, q) + + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + if user.ID != 1 { + t.Errorf("Expected ID 1, got %d", user.ID) + } + if user.Email != "test@example.com" { + t.Errorf("Expected Email test@example.com, got %s", user.Email) + } + if !user.CreatedAt.Equal(now) { + t.Errorf("Expected CreatedAt %v, got %v", now, user.CreatedAt) + } + if !user.Active { + t.Errorf("Expected Active true, got %v", user.Active) + } + }) + + // Test query row no results + t.Run("QueryRow no results", func(t *testing.T) { + scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) + q := db.NewQuery(db.DBTypeSQLite) + q.Select("id", "email", "created_at", "active"). + From("users"). + Where("id = "). + Placeholder(999) + + var user User + err := scanner.QueryRow(&user, q) + + if err != sql.ErrNoRows { + t.Errorf("Expected ErrNoRows, got %v", err) + } + }) + + // Test scanning a single value + t.Run("QueryRow single value", func(t *testing.T) { + scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) + q := db.NewQuery(db.DBTypeSQLite) + q.Select("COUNT(*)").From("users") + + var count int + err := scanner.QueryRow(&count, q) + + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + if count != 1 { + t.Errorf("Expected count 1, got %d", count) + } + }) +} + +func TestScannerQuery(t *testing.T) { + mockSecrets := &mockSecretsService{} + testDB, err := db.NewTestDB(mockSecrets) + if err != nil { + t.Fatalf("Failed to create test database: %v", err) + } + defer testDB.Close() + + // Create a test table + _, err = testDB.TestDB().Exec(` + CREATE TABLE users ( + id INTEGER PRIMARY KEY, + email TEXT NOT NULL, + created_at TIMESTAMP NOT NULL, + active BOOLEAN NOT NULL + ) + `) + if err != nil { + t.Fatalf("Failed to create test table: %v", err) + } + + type User struct { + ID int + Email string + CreatedAt time.Time + Active bool + } + + // Insert test data + now := time.Now().UTC().Truncate(time.Second) + testUsers := []User{ + {ID: 1, Email: "user1@example.com", CreatedAt: now, Active: true}, + {ID: 2, Email: "user2@example.com", CreatedAt: now, Active: false}, + {ID: 3, Email: "user3@example.com", CreatedAt: now, Active: true}, + } + + for _, user := range testUsers { + _, err = testDB.TestDB().Exec( + "INSERT INTO users (id, email, created_at, active) VALUES (?, ?, ?, ?)", + user.ID, user.Email, user.CreatedAt, user.Active, + ) + if err != nil { + t.Fatalf("Failed to insert test data: %v", err) + } + } + + // Test query multiple rows + t.Run("Query multiple rows", func(t *testing.T) { + scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) + q := db.NewQuery(db.DBTypeSQLite) + q.Select("id", "email", "created_at", "active"). + From("users"). + OrderBy("id ASC") + + var users []User + err := scanner.Query(&users, q) + + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + if len(users) != len(testUsers) { + t.Errorf("Expected %d users, got %d", len(testUsers), len(users)) + } + + for i, u := range users { + if u.ID != testUsers[i].ID { + t.Errorf("Expected user[%d].ID %d, got %d", i, testUsers[i].ID, u.ID) + } + if u.Email != testUsers[i].Email { + t.Errorf("Expected user[%d].Email %s, got %s", i, testUsers[i].Email, u.Email) + } + if !u.CreatedAt.Equal(testUsers[i].CreatedAt) { + t.Errorf("Expected user[%d].CreatedAt %v, got %v", i, testUsers[i].CreatedAt, u.CreatedAt) + } + if u.Active != testUsers[i].Active { + t.Errorf("Expected user[%d].Active %v, got %v", i, testUsers[i].Active, u.Active) + } + } + }) + + // Test query with filter + t.Run("Query with filter", func(t *testing.T) { + scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) + q := db.NewQuery(db.DBTypeSQLite) + q.Select("id", "email", "created_at", "active"). + From("users"). + Where("active = "). + Placeholder(true). + OrderBy("id ASC") + + var users []User + err := scanner.Query(&users, q) + + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + if len(users) != 2 { + t.Errorf("Expected 2 users, got %d", len(users)) + } + + for _, u := range users { + if !u.Active { + t.Errorf("Expected only active users, got inactive user: %+v", u) + } + } + }) + + // Test query empty result + t.Run("Query empty result", func(t *testing.T) { + scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) + q := db.NewQuery(db.DBTypeSQLite) + q.Select("id", "email", "created_at", "active"). + From("users"). + Where("id > "). + Placeholder(100) + + var users []User + err := scanner.Query(&users, q) + + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + if len(users) != 0 { + t.Errorf("Expected 0 users, got %d", len(users)) + } + }) +} + +func TestScanErrors(t *testing.T) { + mockSecrets := &mockSecretsService{} + testDB, err := db.NewTestDB(mockSecrets) + if err != nil { + t.Fatalf("Failed to create test database: %v", err) + } + defer testDB.Close() + + scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) + q := db.NewQuery(db.DBTypeSQLite) + q.Select("1") + + // Test non-pointer + t.Run("QueryRow non-pointer", func(t *testing.T) { + var user struct{} + err := scanner.QueryRow(user, q) // Passing non-pointer + + if err == nil { + t.Error("Expected error for non-pointer, got nil") + } + }) + + // Test pointer to non-slice for Query + t.Run("Query pointer to non-slice", func(t *testing.T) { + var user struct{} + err := scanner.Query(&user, q) // Passing pointer to struct, not slice + + if err == nil { + t.Error("Expected error for non-slice pointer, got nil") + } + }) + + // Test non-pointer for Query + t.Run("Query non-pointer", func(t *testing.T) { + var users []struct{} + err := scanner.Query(users, q) // Passing non-pointer + + if err == nil { + t.Error("Expected error for non-pointer, got nil") + } + }) +} + +// Mock secrets service for testing +type mockSecretsService struct{} + +func (m *mockSecretsService) Encrypt(plaintext string) (string, error) { + return plaintext, nil +} + +func (m *mockSecretsService) Decrypt(ciphertext string) (string, error) { + return ciphertext, nil +}