mirror of
https://github.com/lordmathis/lemma.git
synced 2025-11-05 15:44:21 +00:00
Merge pull request #36 from lordmathis/feat/postgres
Add support for Postgres
This commit is contained in:
17
.github/workflows/go-test.yml
vendored
17
.github/workflows/go-test.yml
vendored
@@ -13,6 +13,21 @@ jobs:
|
||||
name: Run Tests
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:17
|
||||
env:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: lemma_test
|
||||
ports:
|
||||
- 5432:5432
|
||||
options: >-
|
||||
--health-cmd pg_isready
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
|
||||
defaults:
|
||||
run:
|
||||
working-directory: ./server
|
||||
@@ -29,6 +44,8 @@ jobs:
|
||||
|
||||
- name: Run Tests
|
||||
run: go test -tags=test,integration ./... -v
|
||||
env:
|
||||
LEMMA_TEST_POSTGRES_URL: "postgres://postgres:postgres@localhost:5432/lemma_test?sslmode=disable"
|
||||
|
||||
- name: Run Tests with Race Detector
|
||||
run: go test -tags=test,integration -race ./... -v
|
||||
|
||||
3
.vscode/settings.json
vendored
3
.vscode/settings.json
vendored
@@ -15,6 +15,9 @@
|
||||
"go.lintOnSave": "package",
|
||||
"go.formatTool": "goimports",
|
||||
"go.testFlags": ["-tags=test,integration"],
|
||||
"go.testEnvVars": {
|
||||
"LEMMA_TEST_POSTGRES_URL": "postgres://postgres:postgres@localhost:5432/lemma_test?sslmode=disable"
|
||||
},
|
||||
"[go]": {
|
||||
"editor.formatOnSave": true,
|
||||
"editor.codeActionsOnSave": {
|
||||
|
||||
@@ -33,8 +33,8 @@ Lemma can be configured using environment variables. Here are the available conf
|
||||
### Optional Environment Variables
|
||||
|
||||
- `LEMMA_ENV`: Set to "development" to enable development mode
|
||||
- `LEMMA_DB_PATH`: Path to the SQLite database file (default: "./lemma.db")
|
||||
- `LEMMA_WORKDIR`: Working directory for application data (default: "./data")
|
||||
- `LEMMA_DB_URL`: URL (Connection string) to the database. Supported databases are sqlite and postgres a (default: "./lemma.db")
|
||||
- `LEMMA_WORKDIR`: Working directory for application data (default: "sqlite://lemma.db")
|
||||
- `LEMMA_STATIC_PATH`: Path to static files (default: "../app/dist")
|
||||
- `LEMMA_PORT`: Port to run the server on (default: "8080")
|
||||
- `LEMMA_DOMAIN`: Domain name where the application is hosted for cookie authentication
|
||||
|
||||
35
server/docker-compose.test.yaml
Normal file
35
server/docker-compose.test.yaml
Normal file
@@ -0,0 +1,35 @@
|
||||
version: "3.8"
|
||||
|
||||
services:
|
||||
postgres:
|
||||
image: postgres:17
|
||||
environment:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: lemma_test
|
||||
ports:
|
||||
- "5432:5432"
|
||||
volumes:
|
||||
- postgres-data:/var/lib/postgresql/data
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U postgres"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
|
||||
pgadmin:
|
||||
image: dpage/pgadmin4
|
||||
environment:
|
||||
PGADMIN_DEFAULT_EMAIL: admin@admin.com
|
||||
PGADMIN_DEFAULT_PASSWORD: admin
|
||||
PGADMIN_CONFIG_SERVER_MODE: "False"
|
||||
ports:
|
||||
- "8080:80"
|
||||
volumes:
|
||||
- pgadmin-data:/var/lib/pgadmin
|
||||
depends_on:
|
||||
- postgres
|
||||
|
||||
volumes:
|
||||
postgres-data:
|
||||
pgadmin-data:
|
||||
@@ -57,14 +57,20 @@ package app // import "lemma/internal/app"
|
||||
Package app provides application-level functionality for initializing and
|
||||
running the server
|
||||
|
||||
FUNCTIONS
|
||||
|
||||
func ParseDBURL(dbURL string) (db.DBType, string, error)
|
||||
ParseDBURL parses a database URL and returns the driver name and data source
|
||||
|
||||
|
||||
TYPES
|
||||
|
||||
type Config struct {
|
||||
DBPath string
|
||||
DBURL string
|
||||
DBType db.DBType
|
||||
WorkDir string
|
||||
StaticPath string
|
||||
Port string
|
||||
RootURL string
|
||||
Domain string
|
||||
CORSOrigins []string
|
||||
AdminEmail string
|
||||
@@ -281,6 +287,24 @@ const (
|
||||
|
||||
TYPES
|
||||
|
||||
type DBField struct {
|
||||
Name string
|
||||
Value any
|
||||
Type reflect.Type
|
||||
OriginalName string
|
||||
|
||||
// Has unexported fields.
|
||||
}
|
||||
|
||||
func StructTagsToFields(s any) ([]DBField, error)
|
||||
StructTagsToFields converts a struct to a slice of DBField instances
|
||||
|
||||
type DBType string
|
||||
|
||||
const (
|
||||
DBTypeSQLite DBType = "sqlite3"
|
||||
DBTypePostgres DBType = "postgres"
|
||||
)
|
||||
type Database interface {
|
||||
UserStore
|
||||
WorkspaceStore
|
||||
@@ -292,14 +316,115 @@ type Database interface {
|
||||
}
|
||||
Database defines the methods for interacting with the database
|
||||
|
||||
func Init(dbPath string, secretsService secrets.Service) (Database, error)
|
||||
func Init(dbType DBType, dbURL string, secretsService secrets.Service) (Database, error)
|
||||
Init initializes the database connection
|
||||
|
||||
type Migration struct {
|
||||
Version int
|
||||
SQL string
|
||||
type JoinType string
|
||||
|
||||
const (
|
||||
InnerJoin JoinType = "INNER JOIN"
|
||||
LeftJoin JoinType = "LEFT JOIN"
|
||||
RightJoin JoinType = "RIGHT JOIN"
|
||||
)
|
||||
type Query struct {
|
||||
// Has unexported fields.
|
||||
}
|
||||
Migration represents a database migration
|
||||
Query represents a SQL query with its parameters
|
||||
|
||||
func NewQuery(dbType DBType, secretsService secrets.Service) *Query
|
||||
NewQuery creates a new Query instance
|
||||
|
||||
func (q *Query) AddArgs(args ...any) *Query
|
||||
AddArgs adds arguments to the query
|
||||
|
||||
func (q *Query) And(condition string) *Query
|
||||
And adds an AND condition
|
||||
|
||||
func (q *Query) Args() []any
|
||||
Args returns the query arguments
|
||||
|
||||
func (q *Query) Delete() *Query
|
||||
Delete starts a DELETE statement
|
||||
|
||||
func (q *Query) EndGroup() *Query
|
||||
EndGroup ends a parenthetical group
|
||||
|
||||
func (q *Query) From(table string) *Query
|
||||
From adds a FROM clause
|
||||
|
||||
func (q *Query) GroupBy(columns ...string) *Query
|
||||
GroupBy adds a GROUP BY clause
|
||||
|
||||
func (q *Query) Having(condition string) *Query
|
||||
Having adds a HAVING clause for filtering groups
|
||||
|
||||
func (q *Query) Insert(table string, columns ...string) *Query
|
||||
Insert starts an INSERT statement
|
||||
|
||||
func (q *Query) InsertStruct(s any, table string) (*Query, error)
|
||||
InsertStruct creates an INSERT query from a struct
|
||||
|
||||
func (q *Query) Join(joinType JoinType, table, condition string) *Query
|
||||
Join adds a JOIN clause
|
||||
|
||||
func (q *Query) Limit(limit int) *Query
|
||||
Limit adds a LIMIT clause
|
||||
|
||||
func (q *Query) Offset(offset int) *Query
|
||||
Offset adds an OFFSET clause
|
||||
|
||||
func (q *Query) Or(condition string) *Query
|
||||
Or adds an OR condition
|
||||
|
||||
func (q *Query) OrderBy(columns ...string) *Query
|
||||
OrderBy adds an ORDER BY clause
|
||||
|
||||
func (q *Query) Placeholder(arg any) *Query
|
||||
Placeholder adds a placeholder for a single argument
|
||||
|
||||
func (q *Query) Placeholders(n int) *Query
|
||||
Placeholders adds n placeholders separated by commas
|
||||
|
||||
func (q *Query) Returning(columns ...string) *Query
|
||||
Returning adds a RETURNING clause for both PostgreSQL and SQLite (3.35.0+)
|
||||
|
||||
func (q *Query) Select(columns ...string) *Query
|
||||
Select adds a SELECT clause
|
||||
|
||||
func (q *Query) SelectStruct(s any, table string) (*Query, error)
|
||||
SelectStruct creates a SELECT query from a struct
|
||||
|
||||
func (q *Query) Set(column string) *Query
|
||||
Set adds a SET clause for updates
|
||||
|
||||
func (q *Query) StartGroup() *Query
|
||||
StartGroup starts a parenthetical group
|
||||
|
||||
func (q *Query) String() string
|
||||
String returns the formatted query string
|
||||
|
||||
func (q *Query) Update(table string) *Query
|
||||
Update starts an UPDATE statement
|
||||
|
||||
func (q *Query) UpdateStruct(s any, table string) (*Query, error)
|
||||
UpdateStruct creates an UPDATE query from a struct
|
||||
|
||||
func (q *Query) Values(count int) *Query
|
||||
Values adds a VALUES clause
|
||||
|
||||
func (q *Query) Where(condition string) *Query
|
||||
Where adds a WHERE clause
|
||||
|
||||
func (q *Query) WhereIn(column string, count int) *Query
|
||||
WhereIn adds a WHERE IN clause
|
||||
|
||||
func (q *Query) Write(s string) *Query
|
||||
Write adds a string to the query
|
||||
|
||||
type Scanner interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
Scanner is an interface that both sql.Row and sql.Rows satisfy
|
||||
|
||||
type SessionStore interface {
|
||||
CreateSession(session *models.Session) error
|
||||
@@ -897,22 +1022,22 @@ and serialize data in the application.
|
||||
TYPES
|
||||
|
||||
type Session struct {
|
||||
ID string // Unique session identifier
|
||||
UserID int // ID of the user this session belongs to
|
||||
RefreshToken string // The refresh token associated with this session
|
||||
ExpiresAt time.Time // When this session expires
|
||||
CreatedAt time.Time // When this session was created
|
||||
ID string `db:"id"` // Unique session identifier
|
||||
UserID int `db:"user_id"` // ID of the user this session belongs to
|
||||
RefreshToken string `db:"refresh_token"` // The refresh token associated with this session
|
||||
ExpiresAt time.Time `db:"expires_at"` // When this session expires
|
||||
CreatedAt time.Time `db:"created_at,default"` // When this session was created
|
||||
}
|
||||
Session represents a user session in the database
|
||||
|
||||
type User struct {
|
||||
ID int `json:"id" validate:"required,min=1"`
|
||||
Email string `json:"email" validate:"required,email"`
|
||||
DisplayName string `json:"displayName"`
|
||||
PasswordHash string `json:"-"`
|
||||
Role UserRole `json:"role" validate:"required,oneof=admin editor viewer"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
LastWorkspaceID int `json:"lastWorkspaceId"`
|
||||
ID int `json:"id" db:"id,default" validate:"required,min=1"`
|
||||
Email string `json:"email" db:"email" validate:"required,email"`
|
||||
DisplayName string `json:"displayName" db:"display_name"`
|
||||
PasswordHash string `json:"-" db:"password_hash"`
|
||||
Role UserRole `json:"role" db:"role" validate:"required,oneof=admin editor viewer"`
|
||||
CreatedAt time.Time `json:"createdAt" db:"created_at,default"`
|
||||
LastWorkspaceID int `json:"lastWorkspaceId" db:"last_workspace_id"`
|
||||
}
|
||||
User represents a user in the system
|
||||
|
||||
@@ -930,24 +1055,24 @@ const (
|
||||
User roles
|
||||
|
||||
type Workspace struct {
|
||||
ID int `json:"id" validate:"required,min=1"`
|
||||
UserID int `json:"userId" validate:"required,min=1"`
|
||||
Name string `json:"name" validate:"required"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
LastOpenedFilePath string `json:"lastOpenedFilePath"`
|
||||
ID int `json:"id" db:"id,default" validate:"required,min=1"`
|
||||
UserID int `json:"userId" db:"user_id" validate:"required,min=1"`
|
||||
Name string `json:"name" db:"name" validate:"required"`
|
||||
CreatedAt time.Time `json:"createdAt" db:"created_at,default"`
|
||||
LastOpenedFilePath string `json:"lastOpenedFilePath" db:"last_opened_file_path"`
|
||||
|
||||
// Integrated settings
|
||||
Theme string `json:"theme" validate:"oneof=light dark"`
|
||||
AutoSave bool `json:"autoSave"`
|
||||
ShowHiddenFiles bool `json:"showHiddenFiles"`
|
||||
GitEnabled bool `json:"gitEnabled"`
|
||||
GitURL string `json:"gitUrl" validate:"required_if=GitEnabled true"`
|
||||
GitUser string `json:"gitUser" validate:"required_if=GitEnabled true"`
|
||||
GitToken string `json:"gitToken" validate:"required_if=GitEnabled true"`
|
||||
GitAutoCommit bool `json:"gitAutoCommit"`
|
||||
GitCommitMsgTemplate string `json:"gitCommitMsgTemplate"`
|
||||
GitCommitName string `json:"gitCommitName"`
|
||||
GitCommitEmail string `json:"gitCommitEmail" validate:"omitempty,required_if=GitEnabled true,email"`
|
||||
Theme string `json:"theme" db:"theme" validate:"oneof=light dark"`
|
||||
AutoSave bool `json:"autoSave" db:"auto_save"`
|
||||
ShowHiddenFiles bool `json:"showHiddenFiles" db:"show_hidden_files"`
|
||||
GitEnabled bool `json:"gitEnabled" db:"git_enabled"`
|
||||
GitURL string `json:"gitUrl" db:"git_url,ommitempty" validate:"required_if=GitEnabled true"`
|
||||
GitUser string `json:"gitUser" db:"git_user,ommitempty" validate:"required_if=GitEnabled true"`
|
||||
GitToken string `json:"gitToken" db:"git_token,ommitempty,encrypted" validate:"required_if=GitEnabled true"`
|
||||
GitAutoCommit bool `json:"gitAutoCommit" db:"git_auto_commit"`
|
||||
GitCommitMsgTemplate string `json:"gitCommitMsgTemplate" db:"git_commit_msg_template"`
|
||||
GitCommitName string `json:"gitCommitName" db:"git_commit_name"`
|
||||
GitCommitEmail string `json:"gitCommitEmail" db:"git_commit_email" validate:"omitempty,required_if=GitEnabled true,email"`
|
||||
}
|
||||
Workspace represents a user's workspace in the system
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ require (
|
||||
github.com/go-git/go-git/v5 v5.13.1
|
||||
github.com/go-playground/validator/v10 v10.22.1
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1
|
||||
github.com/golang-migrate/migrate/v4 v4.18.2
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/mattn/go-sqlite3 v1.14.23
|
||||
github.com/stretchr/testify v1.10.0
|
||||
@@ -21,7 +22,7 @@ require (
|
||||
require (
|
||||
dario.cat/mergo v1.0.0 // indirect
|
||||
github.com/KyleBanks/depth v1.2.1 // indirect
|
||||
github.com/Microsoft/go-winio v0.6.1 // indirect
|
||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||
github.com/ProtonMail/go-crypto v1.1.3 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/cloudflare/circl v1.3.7 // indirect
|
||||
@@ -38,10 +39,13 @@ require (
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
|
||||
github.com/hashicorp/errwrap v1.1.0 // indirect
|
||||
github.com/hashicorp/go-multierror v1.1.1 // indirect
|
||||
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect
|
||||
github.com/josharian/intern v1.0.0 // indirect
|
||||
github.com/kevinburke/ssh_config v1.2.0 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/lib/pq v1.10.9 // indirect
|
||||
github.com/mailru/easyjson v0.7.6 // indirect
|
||||
github.com/pjbgf/sha1cd v0.3.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
@@ -49,12 +53,11 @@ require (
|
||||
github.com/skeema/knownhosts v1.3.0 // indirect
|
||||
github.com/swaggo/files v0.0.0-20220610200504-28940afbdbfe // indirect
|
||||
github.com/xanzy/ssh-agent v0.3.3 // indirect
|
||||
golang.org/x/mod v0.17.0 // indirect
|
||||
go.uber.org/atomic v1.7.0 // indirect
|
||||
golang.org/x/net v0.33.0 // indirect
|
||||
golang.org/x/sync v0.10.0 // indirect
|
||||
golang.org/x/sys v0.28.0 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
|
||||
golang.org/x/tools v0.24.0 // indirect
|
||||
gopkg.in/warnings.v0 v0.1.2 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk=
|
||||
dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
|
||||
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
|
||||
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
|
||||
github.com/KyleBanks/depth v1.2.1 h1:5h8fQADFrWtarTdtDudMmGsC7GPbOAu6RVB3ffsVFHc=
|
||||
github.com/KyleBanks/depth v1.2.1/go.mod h1:jzSb9d0L43HxTQfT+oSA1EEp2q+ne2uh6XgeJcm8brE=
|
||||
github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY=
|
||||
github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow=
|
||||
github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM=
|
||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/ProtonMail/go-crypto v1.1.3 h1:nRBOetoydLeUb4nHajyO2bKqMLfWQ/ZPwkXqXxPxCFk=
|
||||
github.com/ProtonMail/go-crypto v1.1.3/go.mod h1:rA3QumHc/FZ8pAHreoekgiAbzpNsfQAosU5td4SnOrE=
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
|
||||
@@ -21,10 +23,22 @@ github.com/cyphar/filepath-securejoin v0.3.6/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGL
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dhui/dktest v0.4.4 h1:+I4s6JRE1yGuqflzwqG+aIaMdgXIorCf5P98JnaAWa8=
|
||||
github.com/dhui/dktest v0.4.4/go.mod h1:4+22R4lgsdAXrDyaH4Nqx2JEz2hLp49MqQmm9HLCQhM=
|
||||
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
|
||||
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
|
||||
github.com/docker/docker v27.2.0+incompatible h1:Rk9nIVdfH3+Vz4cyI/uhbINhEZ/oLmc+CBXmH6fbNk4=
|
||||
github.com/docker/docker v27.2.0+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
|
||||
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
|
||||
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
|
||||
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
|
||||
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
|
||||
github.com/elazarl/goproxy v1.2.3 h1:xwIyKHbaP5yfT6O9KIeYJR5549MXRQkoQMRXGztz8YQ=
|
||||
github.com/elazarl/goproxy v1.2.3/go.mod h1:YfEbZtqP4AetfO6d40vWchF3znWX7C7Vd6ZMfdL8z64=
|
||||
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
|
||||
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
||||
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
|
||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||
github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c=
|
||||
@@ -43,6 +57,10 @@ github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399 h1:eMj
|
||||
github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399/go.mod h1:1OCfN199q1Jm3HZlxleg+Dw/mwps2Wbk9frAWm+4FII=
|
||||
github.com/go-git/go-git/v5 v5.13.1 h1:DAQ9APonnlvSWpvolXWIuV6Q6zXy2wHbN4cVlNR5Q+M=
|
||||
github.com/go-git/go-git/v5 v5.13.1/go.mod h1:qryJB4cSBoq3FRoBRf5A77joojuBcmPJ0qu3XXXVixc=
|
||||
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
|
||||
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
|
||||
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
|
||||
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
|
||||
github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg=
|
||||
github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY=
|
||||
github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg=
|
||||
@@ -61,14 +79,23 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.22.1 h1:40JcKH+bBNGFczGuoBYgX4I6m/i27HYW8P9FDk5PbgA=
|
||||
github.com/go-playground/validator/v10 v10.22.1/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
|
||||
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
|
||||
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
|
||||
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
|
||||
github.com/golang-migrate/migrate/v4 v4.18.2 h1:2VSCMz7x7mjyTXx3m2zPokOY82LTRgxK1yQYKo6wWQ8=
|
||||
github.com/golang-migrate/migrate/v4 v4.18.2/go.mod h1:2CM6tJvn2kqPXwnXO/d3rAQYiyoIm180VsO8PRX6Rpk=
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
|
||||
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
|
||||
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
|
||||
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
|
||||
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A=
|
||||
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo=
|
||||
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
|
||||
@@ -84,15 +111,27 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
|
||||
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
|
||||
github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
|
||||
github.com/mailru/easyjson v0.7.6 h1:8yTIVnZgCoiM1TgqoeTl+LfU5Jg6/xL3QhGQnimLYnA=
|
||||
github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
|
||||
github.com/mattn/go-sqlite3 v1.14.23 h1:gbShiuAP1W5j9UOksQ06aiiqPMxYecovVGwmTxWtuw0=
|
||||
github.com/mattn/go-sqlite3 v1.14.23/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
||||
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
|
||||
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
|
||||
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
|
||||
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
|
||||
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
|
||||
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
|
||||
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
|
||||
github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k=
|
||||
github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY=
|
||||
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
|
||||
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
|
||||
github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug=
|
||||
github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM=
|
||||
github.com/pjbgf/sha1cd v0.3.0 h1:4D5XXmUUBUl/xQ6IjCkEAbqXskkq/4O7LmGn0AqMDs4=
|
||||
github.com/pjbgf/sha1cd v0.3.0/go.mod h1:nZ1rrWOcGJ5uZgEEVL1VUM9iRQiZvWdbZjkKyFzPPsI=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
@@ -123,13 +162,23 @@ github.com/unrolled/secure v1.17.0 h1:Io7ifFgo99Bnh0J7+Q+qcMzWM6kaDPCA5FroFZEdbW
|
||||
github.com/unrolled/secure v1.17.0/go.mod h1:BmF5hyM6tXczk3MpQkFf1hpKSRqCyhqcbiQtiAF7+40=
|
||||
github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM=
|
||||
github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 h1:TT4fX+nBOA/+LUkobKGW1ydGcn+G3vRw9+g5HwCphpk=
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0/go.mod h1:L7UH0GbB0p47T4Rri3uHjbpCFYrVrwc1I25QhNPiGK8=
|
||||
go.opentelemetry.io/otel v1.29.0 h1:PdomN/Al4q/lN6iBJEN3AwPvUiHPMlt93c8bqTG5Llw=
|
||||
go.opentelemetry.io/otel v1.29.0/go.mod h1:N/WtXPs1CNCUEx+Agz5uouwCba+i+bJGFicT8SR4NP8=
|
||||
go.opentelemetry.io/otel/metric v1.29.0 h1:vPf/HFWTNkPu1aYeIsc98l4ktOQaL6LeSoeV2g+8YLc=
|
||||
go.opentelemetry.io/otel/metric v1.29.0/go.mod h1:auu/QWieFVWx+DmQOUMgj0F8LHWdgalxXqvp7BII/W8=
|
||||
go.opentelemetry.io/otel/trace v1.29.0 h1:J/8ZNK4XgR7a21DZUAsbF8pZ5Jcw1VhACmnYt39JTi4=
|
||||
go.opentelemetry.io/otel/trace v1.29.0/go.mod h1:eHl3w0sp3paPkYstJOmAimxhiFXPg+MMTlEh3nsQgWQ=
|
||||
go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw=
|
||||
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8=
|
||||
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY=
|
||||
golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
|
||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
|
||||
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
|
||||
golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
|
||||
@@ -151,8 +200,8 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg=
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||
golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24=
|
||||
golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
|
||||
@@ -2,9 +2,11 @@ package app
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"lemma/internal/db"
|
||||
"lemma/internal/logging"
|
||||
"lemma/internal/secrets"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -12,7 +14,8 @@ import (
|
||||
|
||||
// Config holds the configuration for the application
|
||||
type Config struct {
|
||||
DBPath string
|
||||
DBURL string
|
||||
DBType db.DBType
|
||||
WorkDir string
|
||||
StaticPath string
|
||||
Port string
|
||||
@@ -31,7 +34,7 @@ type Config struct {
|
||||
// DefaultConfig returns a new Config instance with default values
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
DBPath: "./lemma.db",
|
||||
DBURL: "sqlite://lemma.db",
|
||||
WorkDir: "./data",
|
||||
StaticPath: "../app/dist",
|
||||
Port: "8080",
|
||||
@@ -65,6 +68,31 @@ func (c *Config) Redact() *Config {
|
||||
return &redacted
|
||||
}
|
||||
|
||||
// ParseDBURL parses a database URL and returns the driver name and data source
|
||||
func ParseDBURL(dbURL string) (db.DBType, string, error) {
|
||||
if strings.HasPrefix(dbURL, "sqlite://") || strings.HasPrefix(dbURL, "sqlite3://") {
|
||||
|
||||
path := strings.TrimPrefix(dbURL, "sqlite://")
|
||||
path = strings.TrimPrefix(path, "sqlite3://")
|
||||
|
||||
if path == ":memory:" {
|
||||
return db.DBTypeSQLite, path, nil
|
||||
}
|
||||
|
||||
if !filepath.IsAbs(path) {
|
||||
path = filepath.Clean(path)
|
||||
}
|
||||
return db.DBTypeSQLite, path, nil
|
||||
}
|
||||
|
||||
// Try to parse as postgres URL
|
||||
if strings.HasPrefix(dbURL, "postgres://") || strings.HasPrefix(dbURL, "postgresql://") {
|
||||
return db.DBTypePostgres, dbURL, nil
|
||||
}
|
||||
|
||||
return "", "", fmt.Errorf("unsupported database URL format: %s", dbURL)
|
||||
}
|
||||
|
||||
// LoadConfig creates a new Config instance with values from environment variables
|
||||
func LoadConfig() (*Config, error) {
|
||||
config := DefaultConfig()
|
||||
@@ -73,8 +101,13 @@ func LoadConfig() (*Config, error) {
|
||||
config.IsDevelopment = env == "development"
|
||||
}
|
||||
|
||||
if dbPath := os.Getenv("LEMMA_DB_PATH"); dbPath != "" {
|
||||
config.DBPath = dbPath
|
||||
if dbURL := os.Getenv("LEMMA_DB_URL"); dbURL != "" {
|
||||
dbType, dataSource, err := ParseDBURL(dbURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.DBURL = dataSource
|
||||
config.DBType = dbType
|
||||
}
|
||||
|
||||
if workDir := os.Getenv("LEMMA_WORKDIR"); workDir != "" {
|
||||
|
||||
@@ -2,6 +2,7 @@ package app_test
|
||||
|
||||
import (
|
||||
"lemma/internal/app"
|
||||
"lemma/internal/db"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -14,10 +15,10 @@ func TestDefaultConfig(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
got interface{}
|
||||
expected interface{}
|
||||
got any
|
||||
expected any
|
||||
}{
|
||||
{"DBPath", cfg.DBPath, "./lemma.db"},
|
||||
{"DBPath", cfg.DBURL, "sqlite://lemma.db"},
|
||||
{"WorkDir", cfg.WorkDir, "./data"},
|
||||
{"StaticPath", cfg.StaticPath, "../app/dist"},
|
||||
{"Port", cfg.Port, "8080"},
|
||||
@@ -47,7 +48,7 @@ func TestLoad(t *testing.T) {
|
||||
cleanup := func() {
|
||||
envVars := []string{
|
||||
"LEMMA_ENV",
|
||||
"LEMMA_DB_PATH",
|
||||
"LEMMA_DB_URL",
|
||||
"LEMMA_WORKDIR",
|
||||
"LEMMA_STATIC_PATH",
|
||||
"LEMMA_PORT",
|
||||
@@ -81,8 +82,8 @@ func TestLoad(t *testing.T) {
|
||||
t.Fatalf("Load() error = %v", err)
|
||||
}
|
||||
|
||||
if cfg.DBPath != "./lemma.db" {
|
||||
t.Errorf("default DBPath = %v, want %v", cfg.DBPath, "./lemma.db")
|
||||
if cfg.DBURL != "sqlite://lemma.db" {
|
||||
t.Errorf("default DBPath = %v, want %v", cfg.DBURL, "sqlite://lemma.db")
|
||||
}
|
||||
})
|
||||
|
||||
@@ -93,7 +94,7 @@ func TestLoad(t *testing.T) {
|
||||
// Set all environment variables
|
||||
envs := map[string]string{
|
||||
"LEMMA_ENV": "development",
|
||||
"LEMMA_DB_PATH": "/custom/db/path.db",
|
||||
"LEMMA_DB_URL": "sqlite:///custom/db/path.db",
|
||||
"LEMMA_WORKDIR": "/custom/work/dir",
|
||||
"LEMMA_STATIC_PATH": "/custom/static/path",
|
||||
"LEMMA_PORT": "3000",
|
||||
@@ -118,11 +119,12 @@ func TestLoad(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
got interface{}
|
||||
expected interface{}
|
||||
got any
|
||||
expected any
|
||||
}{
|
||||
{"IsDevelopment", cfg.IsDevelopment, true},
|
||||
{"DBPath", cfg.DBPath, "/custom/db/path.db"},
|
||||
{"DBURL", cfg.DBURL, "/custom/db/path.db"},
|
||||
{"DBType", cfg.DBType, db.DBTypeSQLite},
|
||||
{"WorkDir", cfg.WorkDir, "/custom/work/dir"},
|
||||
{"StaticPath", cfg.StaticPath, "/custom/static/path"},
|
||||
{"Port", cfg.Port, "3000"},
|
||||
|
||||
@@ -28,9 +28,9 @@ func initSecretsService(cfg *Config) (secrets.Service, error) {
|
||||
|
||||
// initDatabase initializes and migrates the database
|
||||
func initDatabase(cfg *Config, secretsService secrets.Service) (db.Database, error) {
|
||||
logging.Debug("initializing database", "path", cfg.DBPath)
|
||||
logging.Debug("initializing database", "path", cfg.DBURL)
|
||||
|
||||
database, err := db.Init(cfg.DBPath, secretsService)
|
||||
database, err := db.Init(cfg.DBType, cfg.DBURL, secretsService)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize database: %w", err)
|
||||
}
|
||||
|
||||
@@ -112,7 +112,7 @@ func (s *jwtService) generateToken(userID int, role string, sessionID string, to
|
||||
func (s *jwtService) ValidateToken(tokenString string) (*Claims, error) {
|
||||
log := getJWTLogger()
|
||||
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (any, error) {
|
||||
// Validate the signing method
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
|
||||
@@ -9,9 +9,17 @@ import (
|
||||
"lemma/internal/models"
|
||||
"lemma/internal/secrets"
|
||||
|
||||
_ "github.com/lib/pq" // Postgres driver
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
)
|
||||
|
||||
type DBType string
|
||||
|
||||
const (
|
||||
DBTypeSQLite DBType = "sqlite3"
|
||||
DBTypePostgres DBType = "postgres"
|
||||
)
|
||||
|
||||
// UserStore defines the methods for interacting with user data in the database
|
||||
type UserStore interface {
|
||||
CreateUser(user *models.User) (*models.User, error)
|
||||
@@ -68,12 +76,18 @@ type SystemStore interface {
|
||||
SetSystemSetting(key, value string) error
|
||||
}
|
||||
|
||||
type StructScanner interface {
|
||||
ScanStruct(row *sql.Row, dest interface{}) error
|
||||
ScanStructs(rows *sql.Rows, dest interface{}) error
|
||||
}
|
||||
|
||||
// Database defines the methods for interacting with the database
|
||||
type Database interface {
|
||||
UserStore
|
||||
WorkspaceStore
|
||||
SessionStore
|
||||
SystemStore
|
||||
StructScanner
|
||||
Begin() (*sql.Tx, error)
|
||||
Close() error
|
||||
Migrate() error
|
||||
@@ -93,6 +107,7 @@ var (
|
||||
// Sub-interfaces
|
||||
_ WorkspaceReader = (*database)(nil)
|
||||
_ WorkspaceWriter = (*database)(nil)
|
||||
_ StructScanner = (*database)(nil)
|
||||
)
|
||||
|
||||
var logger logging.Logger
|
||||
@@ -108,13 +123,45 @@ func getLogger() logging.Logger {
|
||||
type database struct {
|
||||
*sql.DB
|
||||
secretsService secrets.Service
|
||||
dbType DBType
|
||||
}
|
||||
|
||||
// Init initializes the database connection
|
||||
func Init(dbPath string, secretsService secrets.Service) (Database, error) {
|
||||
log := getLogger()
|
||||
func Init(dbType DBType, dbURL string, secretsService secrets.Service) (Database, error) {
|
||||
|
||||
db, err := sql.Open("sqlite3", dbPath)
|
||||
switch dbType {
|
||||
case DBTypeSQLite:
|
||||
db, err := initSQLite(dbURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize SQLite database: %w", err)
|
||||
}
|
||||
|
||||
database := &database{
|
||||
DB: db,
|
||||
secretsService: secretsService,
|
||||
dbType: dbType,
|
||||
}
|
||||
return database, nil
|
||||
case DBTypePostgres:
|
||||
db, err := initPostgres(dbURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize Postgres database: %w", err)
|
||||
}
|
||||
|
||||
database := &database{
|
||||
DB: db,
|
||||
secretsService: secretsService,
|
||||
dbType: dbType,
|
||||
}
|
||||
return database, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported database type: %s", dbType)
|
||||
}
|
||||
|
||||
func initSQLite(dbURL string) (*sql.DB, error) {
|
||||
log := getLogger()
|
||||
db, err := sql.Open("sqlite3", dbURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
@@ -128,13 +175,20 @@ func Init(dbPath string, secretsService secrets.Service) (Database, error) {
|
||||
return nil, fmt.Errorf("failed to enable foreign keys: %w", err)
|
||||
}
|
||||
log.Debug("foreign keys enabled")
|
||||
return db, nil
|
||||
}
|
||||
|
||||
database := &database{
|
||||
DB: db,
|
||||
secretsService: secretsService,
|
||||
func initPostgres(dbURL string) (*sql.DB, error) {
|
||||
db, err := sql.Open("postgres", dbURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
return database, nil
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// Close closes the database connection
|
||||
@@ -148,29 +202,6 @@ func (db *database) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Helper methods for token encryption/decryption
|
||||
func (db *database) encryptToken(token string) (string, error) {
|
||||
if token == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
encrypted, err := db.secretsService.Encrypt(token)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to encrypt token: %w", err)
|
||||
}
|
||||
|
||||
return encrypted, nil
|
||||
}
|
||||
|
||||
func (db *database) decryptToken(token string) (string, error) {
|
||||
if token == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
decrypted, err := db.secretsService.Decrypt(token)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to decrypt token: %w", err)
|
||||
}
|
||||
|
||||
return decrypted, nil
|
||||
func (db *database) NewQuery() *Query {
|
||||
return NewQuery(db.dbType, db.secretsService)
|
||||
}
|
||||
|
||||
@@ -1,141 +1,71 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"fmt"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database/postgres"
|
||||
"github.com/golang-migrate/migrate/v4/database/sqlite3"
|
||||
"github.com/golang-migrate/migrate/v4/source/iofs"
|
||||
)
|
||||
|
||||
// Migration represents a database migration
|
||||
type Migration struct {
|
||||
Version int
|
||||
SQL string
|
||||
}
|
||||
|
||||
var migrations = []Migration{
|
||||
{
|
||||
Version: 1,
|
||||
SQL: `
|
||||
-- Create users table
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
email TEXT NOT NULL UNIQUE,
|
||||
display_name TEXT,
|
||||
password_hash TEXT NOT NULL,
|
||||
role TEXT NOT NULL CHECK(role IN ('admin', 'editor', 'viewer')),
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
last_workspace_id INTEGER
|
||||
);
|
||||
|
||||
-- Create workspaces table with integrated settings
|
||||
CREATE TABLE IF NOT EXISTS workspaces (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
last_opened_file_path TEXT,
|
||||
-- Settings fields
|
||||
theme TEXT NOT NULL DEFAULT 'light' CHECK(theme IN ('light', 'dark')),
|
||||
auto_save BOOLEAN NOT NULL DEFAULT 0,
|
||||
git_enabled BOOLEAN NOT NULL DEFAULT 0,
|
||||
git_url TEXT,
|
||||
git_user TEXT,
|
||||
git_token TEXT,
|
||||
git_auto_commit BOOLEAN NOT NULL DEFAULT 0,
|
||||
git_commit_msg_template TEXT DEFAULT '${action} ${filename}',
|
||||
git_commit_name TEXT,
|
||||
git_commit_email TEXT,
|
||||
show_hidden_files BOOLEAN NOT NULL DEFAULT 0,
|
||||
created_by INTEGER REFERENCES users(id),
|
||||
updated_by INTEGER REFERENCES users(id),
|
||||
updated_at TIMESTAMP,
|
||||
FOREIGN KEY (user_id) REFERENCES users (id)
|
||||
);
|
||||
|
||||
-- Create sessions table for authentication
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
refresh_token TEXT NOT NULL,
|
||||
expires_at TIMESTAMP NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- Create system_settings table for application settings
|
||||
CREATE TABLE IF NOT EXISTS system_settings (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- Create indexes for performance
|
||||
CREATE INDEX idx_sessions_user_id ON sessions(user_id);
|
||||
CREATE INDEX idx_sessions_expires_at ON sessions(expires_at);
|
||||
CREATE INDEX idx_sessions_refresh_token ON sessions(refresh_token);
|
||||
`,
|
||||
},
|
||||
}
|
||||
//go:embed migrations/sqlite/*.sql migrations/postgres/*.sql
|
||||
var migrationsFS embed.FS
|
||||
|
||||
// Migrate applies all database migrations
|
||||
func (db *database) Migrate() error {
|
||||
log := getLogger().WithGroup("migrations")
|
||||
log.Info("starting database migration")
|
||||
|
||||
// Create migrations table if it doesn't exist
|
||||
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS migrations (
|
||||
version INTEGER PRIMARY KEY
|
||||
)`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create migrations table: %w", err)
|
||||
var migrationPath string
|
||||
switch db.dbType {
|
||||
case DBTypePostgres:
|
||||
migrationPath = "migrations/postgres"
|
||||
case DBTypeSQLite:
|
||||
migrationPath = "migrations/sqlite"
|
||||
default:
|
||||
return fmt.Errorf("unsupported database driver: %s", db.dbType)
|
||||
}
|
||||
|
||||
// Get current version
|
||||
var currentVersion int
|
||||
err = db.QueryRow("SELECT COALESCE(MAX(version), 0) FROM migrations").Scan(¤tVersion)
|
||||
log.Debug("using migration path", "path", migrationPath)
|
||||
|
||||
sourceInstance, err := iofs.New(migrationsFS, migrationPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get current migration version: %w", err)
|
||||
return fmt.Errorf("failed to create source instance: %w", err)
|
||||
}
|
||||
|
||||
// Apply new migrations
|
||||
for _, migration := range migrations {
|
||||
if migration.Version > currentVersion {
|
||||
log := log.With("migration_version", migration.Version)
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to begin transaction for migration %d: %w", migration.Version, err)
|
||||
}
|
||||
var m *migrate.Migrate
|
||||
|
||||
// Execute migration SQL
|
||||
_, err = tx.Exec(migration.SQL)
|
||||
if err != nil {
|
||||
if rbErr := tx.Rollback(); rbErr != nil {
|
||||
return fmt.Errorf("migration %d failed: %v, rollback failed: %v",
|
||||
migration.Version, err, rbErr)
|
||||
}
|
||||
return fmt.Errorf("migration %d failed: %w", migration.Version, err)
|
||||
}
|
||||
|
||||
// Update migrations table
|
||||
_, err = tx.Exec("INSERT INTO migrations (version) VALUES (?)", migration.Version)
|
||||
if err != nil {
|
||||
if rbErr := tx.Rollback(); rbErr != nil {
|
||||
return fmt.Errorf("failed to update migration version: %v, rollback failed: %v",
|
||||
err, rbErr)
|
||||
}
|
||||
return fmt.Errorf("failed to update migration version: %w", err)
|
||||
}
|
||||
|
||||
// Commit transaction
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to commit migration %d: %w", migration.Version, err)
|
||||
}
|
||||
|
||||
currentVersion = migration.Version
|
||||
log.Debug("migration applied", "new_version", currentVersion)
|
||||
switch db.dbType {
|
||||
case DBTypePostgres:
|
||||
driver, err := postgres.WithInstance(db.DB, &postgres.Config{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create postgres driver: %w", err)
|
||||
}
|
||||
m, err = migrate.NewWithInstance("iofs", sourceInstance, "postgres", driver)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create migrate instance: %w", err)
|
||||
}
|
||||
|
||||
case DBTypeSQLite:
|
||||
driver, err := sqlite3.WithInstance(db.DB, &sqlite3.Config{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create sqlite driver: %w", err)
|
||||
}
|
||||
m, err = migrate.NewWithInstance("iofs", sourceInstance, "sqlite3", driver)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create migrate instance: %w", err)
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unsupported database driver: %s", db.dbType)
|
||||
}
|
||||
|
||||
log.Info("database migration completed", "final_version", currentVersion)
|
||||
if err := m.Up(); err != nil && err != migrate.ErrNoChange {
|
||||
return fmt.Errorf("failed to run migrations: %w", err)
|
||||
}
|
||||
|
||||
log.Info("database migration completed")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,9 @@
|
||||
-- 001_initial_schema.down.sql (PostgreSQL version)
|
||||
DROP INDEX IF EXISTS idx_sessions_refresh_token;
|
||||
DROP INDEX IF EXISTS idx_sessions_expires_at;
|
||||
DROP INDEX IF EXISTS idx_sessions_user_id;
|
||||
DROP INDEX IF EXISTS idx_workspaces_user_id;
|
||||
DROP TABLE IF EXISTS sessions;
|
||||
DROP TABLE IF EXISTS workspaces;
|
||||
DROP TABLE IF EXISTS system_settings;
|
||||
DROP TABLE IF EXISTS users;
|
||||
@@ -0,0 +1,61 @@
|
||||
-- 001_initial_schema.up.sql (PostgreSQL version)
|
||||
|
||||
-- Create users table
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id SERIAL PRIMARY KEY,
|
||||
email TEXT NOT NULL UNIQUE,
|
||||
display_name TEXT,
|
||||
password_hash TEXT NOT NULL,
|
||||
role TEXT NOT NULL CHECK(role IN ('admin', 'editor', 'viewer')),
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
last_workspace_id INTEGER
|
||||
);
|
||||
|
||||
-- Create workspaces table with integrated settings
|
||||
CREATE TABLE IF NOT EXISTS workspaces (
|
||||
id SERIAL PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
last_opened_file_path TEXT,
|
||||
-- Settings fields
|
||||
theme TEXT NOT NULL DEFAULT 'light' CHECK(theme IN ('light', 'dark')),
|
||||
auto_save BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
git_enabled BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
git_url TEXT,
|
||||
git_user TEXT,
|
||||
git_token TEXT,
|
||||
git_auto_commit BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
git_commit_msg_template TEXT DEFAULT '${action} ${filename}',
|
||||
git_commit_name TEXT,
|
||||
git_commit_email TEXT,
|
||||
show_hidden_files BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
created_by INTEGER REFERENCES users(id),
|
||||
updated_by INTEGER REFERENCES users(id),
|
||||
updated_at TIMESTAMP,
|
||||
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- Create sessions table for authentication
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
refresh_token TEXT NOT NULL,
|
||||
expires_at TIMESTAMP NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- Create system_settings table for application settings
|
||||
CREATE TABLE IF NOT EXISTS system_settings (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- Create indexes for performance
|
||||
CREATE INDEX idx_sessions_user_id ON sessions(user_id);
|
||||
CREATE INDEX idx_sessions_expires_at ON sessions(expires_at);
|
||||
CREATE INDEX idx_sessions_refresh_token ON sessions(refresh_token);
|
||||
CREATE INDEX idx_workspaces_user_id ON workspaces(user_id);
|
||||
@@ -0,0 +1,9 @@
|
||||
-- 001_initial_schema.down.sql
|
||||
DROP INDEX IF EXISTS idx_sessions_refresh_token;
|
||||
DROP INDEX IF EXISTS idx_sessions_expires_at;
|
||||
DROP INDEX IF EXISTS idx_sessions_user_id;
|
||||
DROP INDEX IF EXISTS idx_workspaces_user_id;
|
||||
DROP TABLE IF EXISTS sessions;
|
||||
DROP TABLE IF EXISTS workspaces;
|
||||
DROP TABLE IF EXISTS system_settings;
|
||||
DROP TABLE IF EXISTS users;
|
||||
@@ -0,0 +1,60 @@
|
||||
-- 001_initial_schema.up.sql
|
||||
-- Create users table
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
email TEXT NOT NULL UNIQUE,
|
||||
display_name TEXT,
|
||||
password_hash TEXT NOT NULL,
|
||||
role TEXT NOT NULL CHECK(role IN ('admin', 'editor', 'viewer')),
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
last_workspace_id INTEGER
|
||||
);
|
||||
|
||||
-- Create workspaces table with integrated settings
|
||||
CREATE TABLE IF NOT EXISTS workspaces (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id INTEGER NOT NULL,
|
||||
name TEXT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
last_opened_file_path TEXT,
|
||||
-- Settings fields
|
||||
theme TEXT NOT NULL DEFAULT 'light' CHECK(theme IN ('light', 'dark')),
|
||||
auto_save BOOLEAN NOT NULL DEFAULT 0,
|
||||
git_enabled BOOLEAN NOT NULL DEFAULT 0,
|
||||
git_url TEXT,
|
||||
git_user TEXT,
|
||||
git_token TEXT,
|
||||
git_auto_commit BOOLEAN NOT NULL DEFAULT 0,
|
||||
git_commit_msg_template TEXT DEFAULT '${action} ${filename}',
|
||||
git_commit_name TEXT,
|
||||
git_commit_email TEXT,
|
||||
show_hidden_files BOOLEAN NOT NULL DEFAULT 0,
|
||||
created_by INTEGER REFERENCES users(id),
|
||||
updated_by INTEGER REFERENCES users(id),
|
||||
updated_at TIMESTAMP,
|
||||
FOREIGN KEY (user_id) REFERENCES users (id)
|
||||
);
|
||||
|
||||
-- Create sessions table for authentication
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id TEXT PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL,
|
||||
refresh_token TEXT NOT NULL,
|
||||
expires_at TIMESTAMP NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- Create system_settings table for application settings
|
||||
CREATE TABLE IF NOT EXISTS system_settings (
|
||||
key TEXT PRIMARY KEY,
|
||||
value TEXT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
-- Create indexes for performance
|
||||
CREATE INDEX idx_sessions_user_id ON sessions(user_id);
|
||||
CREATE INDEX idx_sessions_expires_at ON sessions(expires_at);
|
||||
CREATE INDEX idx_sessions_refresh_token ON sessions(refresh_token);
|
||||
CREATE INDEX idx_workspaces_user_id ON workspaces(user_id);
|
||||
@@ -1,53 +1,35 @@
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"lemma/internal/db"
|
||||
|
||||
_ "lemma/internal/testenv"
|
||||
"testing"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
database, err := db.NewTestDB(":memory:", &mockSecrets{})
|
||||
database, err := db.NewTestSQLiteDB(&mockSecrets{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to initialize database: %v", err)
|
||||
}
|
||||
defer database.Close()
|
||||
|
||||
t.Run("migrations are applied in order", func(t *testing.T) {
|
||||
if err := database.Migrate(); err != nil {
|
||||
t.Fatalf("failed to run initial migrations: %v", err)
|
||||
}
|
||||
|
||||
// Check migration version
|
||||
var version int
|
||||
err := database.TestDB().QueryRow("SELECT MAX(version) FROM migrations").Scan(&version)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get migration version: %v", err)
|
||||
}
|
||||
|
||||
if version != 1 { // Current number of migrations in production code
|
||||
t.Errorf("expected migration version 1, got %d", version)
|
||||
}
|
||||
|
||||
// Verify number of migration entries matches versions applied
|
||||
var count int
|
||||
err = database.TestDB().QueryRow("SELECT COUNT(*) FROM migrations").Scan(&count)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to count migrations: %v", err)
|
||||
}
|
||||
|
||||
if count != 1 {
|
||||
t.Errorf("expected 1 migration entries, got %d", count)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("migrations create expected schema", func(t *testing.T) {
|
||||
// Run migrations
|
||||
if err := database.Migrate(); err != nil {
|
||||
t.Fatalf("failed to run migrations: %v", err)
|
||||
}
|
||||
|
||||
// Verify tables exist
|
||||
tables := []string{"users", "workspaces", "sessions", "system_settings", "migrations"}
|
||||
tables := []string{
|
||||
"users",
|
||||
"workspaces",
|
||||
"sessions",
|
||||
"system_settings",
|
||||
"schema_migrations",
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
if !tableExists(t, database, table) {
|
||||
t.Errorf("table %q does not exist", table)
|
||||
@@ -62,92 +44,34 @@ func TestMigrate(t *testing.T) {
|
||||
{"sessions", "idx_sessions_user_id"},
|
||||
{"sessions", "idx_sessions_expires_at"},
|
||||
{"sessions", "idx_sessions_refresh_token"},
|
||||
{"workspaces", "idx_workspaces_user_id"},
|
||||
}
|
||||
|
||||
for _, idx := range indexes {
|
||||
if !indexExists(t, database, idx.table, idx.name) {
|
||||
t.Errorf("index %q on table %q does not exist", idx.name, idx.table)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("migrations are idempotent", func(t *testing.T) {
|
||||
// Run migrations again
|
||||
if err := database.Migrate(); err != nil {
|
||||
t.Fatalf("failed to re-run migrations: %v", err)
|
||||
}
|
||||
|
||||
// Verify migration count hasn't changed
|
||||
var count int
|
||||
err = database.TestDB().QueryRow("SELECT COUNT(*) FROM migrations").Scan(&count)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to count migrations: %v", err)
|
||||
}
|
||||
|
||||
if count != 1 {
|
||||
t.Errorf("expected 1 migration entries, got %d", count)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rollback on migration failure", func(t *testing.T) {
|
||||
// Create a test table that would conflict with a failing migration
|
||||
_, err := database.TestDB().Exec("CREATE TABLE test_rollback (id INTEGER PRIMARY KEY)")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test table: %v", err)
|
||||
}
|
||||
|
||||
// Start transaction
|
||||
tx, err := database.Begin()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to start transaction: %v", err)
|
||||
}
|
||||
|
||||
// Try operations that should fail and rollback
|
||||
_, err = tx.Exec(`
|
||||
CREATE TABLE test_rollback (id INTEGER PRIMARY KEY);
|
||||
INSERT INTO nonexistent_table VALUES (1);
|
||||
`)
|
||||
if err == nil {
|
||||
tx.Rollback()
|
||||
t.Fatal("expected migration to fail")
|
||||
}
|
||||
tx.Rollback()
|
||||
|
||||
// Verify the migration version hasn't changed
|
||||
var version int
|
||||
err = database.TestDB().QueryRow("SELECT MAX(version) FROM migrations").Scan(&version)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get migration version: %v", err)
|
||||
}
|
||||
|
||||
if version != 1 {
|
||||
t.Errorf("expected migration version to remain at 1, got %d", version)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func tableExists(t *testing.T, database db.TestDatabase, tableName string) bool {
|
||||
t.Helper()
|
||||
|
||||
var name string
|
||||
err := database.TestDB().QueryRow(`
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type='table' AND name=?`,
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type='table' AND name=?`,
|
||||
tableName,
|
||||
).Scan(&name)
|
||||
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func indexExists(t *testing.T, database db.TestDatabase, tableName, indexName string) bool {
|
||||
t.Helper()
|
||||
|
||||
var name string
|
||||
err := database.TestDB().QueryRow(`
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type='index' AND tbl_name=? AND name=?`,
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type='index' AND tbl_name=? AND name=?`,
|
||||
tableName, indexName,
|
||||
).Scan(&name)
|
||||
|
||||
return err == nil
|
||||
}
|
||||
|
||||
306
server/internal/db/query.go
Normal file
306
server/internal/db/query.go
Normal file
@@ -0,0 +1,306 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"lemma/internal/secrets"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type JoinType string
|
||||
|
||||
const (
|
||||
InnerJoin JoinType = "INNER JOIN"
|
||||
LeftJoin JoinType = "LEFT JOIN"
|
||||
RightJoin JoinType = "RIGHT JOIN"
|
||||
)
|
||||
|
||||
// Query represents a SQL query with its parameters
|
||||
type Query struct {
|
||||
builder strings.Builder
|
||||
args []any
|
||||
dbType DBType
|
||||
secretsService secrets.Service
|
||||
pos int // tracks the current placeholder position
|
||||
hasSelect bool
|
||||
hasFrom bool
|
||||
hasWhere bool
|
||||
hasOrderBy bool
|
||||
hasGroupBy bool
|
||||
hasHaving bool
|
||||
hasLimit bool
|
||||
hasOffset bool
|
||||
isInParens bool
|
||||
parensDepth int
|
||||
}
|
||||
|
||||
// NewQuery creates a new Query instance
|
||||
func NewQuery(dbType DBType, secretsService secrets.Service) *Query {
|
||||
return &Query{
|
||||
dbType: dbType,
|
||||
secretsService: secretsService,
|
||||
args: make([]any, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// Select adds a SELECT clause
|
||||
func (q *Query) Select(columns ...string) *Query {
|
||||
if !q.hasSelect {
|
||||
q.Write("SELECT ")
|
||||
q.Write(strings.Join(columns, ", "))
|
||||
q.hasSelect = true
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
// From adds a FROM clause
|
||||
func (q *Query) From(table string) *Query {
|
||||
if !q.hasFrom {
|
||||
q.Write(" FROM ")
|
||||
q.Write(table)
|
||||
q.hasFrom = true
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
// Where adds a WHERE clause
|
||||
func (q *Query) Where(condition string) *Query {
|
||||
if !q.hasWhere {
|
||||
q.Write(" WHERE ")
|
||||
q.hasWhere = true
|
||||
} else {
|
||||
q.Write(" AND ")
|
||||
}
|
||||
q.Write(condition)
|
||||
return q
|
||||
}
|
||||
|
||||
// WhereIn adds a WHERE IN clause
|
||||
func (q *Query) WhereIn(column string, count int) *Query {
|
||||
if !q.hasWhere {
|
||||
q.Write(" WHERE ")
|
||||
q.hasWhere = true
|
||||
} else {
|
||||
q.Write(" AND ")
|
||||
}
|
||||
q.Write(column)
|
||||
q.Write(" IN (")
|
||||
q.Placeholders(count)
|
||||
q.Write(")")
|
||||
return q
|
||||
}
|
||||
|
||||
// And adds an AND condition
|
||||
func (q *Query) And(condition string) *Query {
|
||||
q.Write(" AND ")
|
||||
q.Write(condition)
|
||||
return q
|
||||
}
|
||||
|
||||
// Or adds an OR condition
|
||||
func (q *Query) Or(condition string) *Query {
|
||||
q.Write(" OR ")
|
||||
q.Write(condition)
|
||||
return q
|
||||
}
|
||||
|
||||
// Join adds a JOIN clause
|
||||
func (q *Query) Join(joinType JoinType, table, condition string) *Query {
|
||||
q.Write(" ")
|
||||
q.Write(string(joinType))
|
||||
q.Write(" ")
|
||||
q.Write(table)
|
||||
q.Write(" ON ")
|
||||
q.Write(condition)
|
||||
return q
|
||||
}
|
||||
|
||||
// OrderBy adds an ORDER BY clause
|
||||
func (q *Query) OrderBy(columns ...string) *Query {
|
||||
if !q.hasOrderBy {
|
||||
q.Write(" ORDER BY ")
|
||||
q.Write(strings.Join(columns, ", "))
|
||||
q.hasOrderBy = true
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
// GroupBy adds a GROUP BY clause
|
||||
func (q *Query) GroupBy(columns ...string) *Query {
|
||||
if !q.hasGroupBy {
|
||||
q.Write(" GROUP BY ")
|
||||
q.Write(strings.Join(columns, ", "))
|
||||
q.hasGroupBy = true
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
// Having adds a HAVING clause for filtering groups
|
||||
func (q *Query) Having(condition string) *Query {
|
||||
if !q.hasHaving {
|
||||
q.Write(" HAVING ")
|
||||
q.hasHaving = true
|
||||
} else {
|
||||
q.Write(" AND ")
|
||||
}
|
||||
q.Write(condition)
|
||||
return q
|
||||
}
|
||||
|
||||
// Limit adds a LIMIT clause
|
||||
func (q *Query) Limit(limit int) *Query {
|
||||
if !q.hasLimit {
|
||||
q.Write(" LIMIT ")
|
||||
q.Write(fmt.Sprintf("%d", limit))
|
||||
q.hasLimit = true
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
// Offset adds an OFFSET clause
|
||||
func (q *Query) Offset(offset int) *Query {
|
||||
if !q.hasOffset {
|
||||
q.Write(" OFFSET ")
|
||||
q.Write(fmt.Sprintf("%d", offset))
|
||||
q.hasOffset = true
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
// Insert starts an INSERT statement
|
||||
func (q *Query) Insert(table string, columns ...string) *Query {
|
||||
q.Write("INSERT INTO ")
|
||||
q.Write(table)
|
||||
q.Write(" (")
|
||||
q.Write(strings.Join(columns, ", "))
|
||||
q.Write(") VALUES ")
|
||||
return q
|
||||
}
|
||||
|
||||
// Values adds a VALUES clause
|
||||
func (q *Query) Values(count int) *Query {
|
||||
q.Write("(")
|
||||
q.Placeholders(count)
|
||||
q.Write(")")
|
||||
return q
|
||||
}
|
||||
|
||||
// Update starts an UPDATE statement
|
||||
func (q *Query) Update(table string) *Query {
|
||||
q.Write("UPDATE ")
|
||||
q.Write(table)
|
||||
q.Write(" SET ")
|
||||
return q
|
||||
}
|
||||
|
||||
// Set adds a SET clause for updates
|
||||
func (q *Query) Set(column string) *Query {
|
||||
if strings.Contains(q.builder.String(), "SET ") &&
|
||||
!strings.HasSuffix(q.builder.String(), "SET ") {
|
||||
q.Write(", ")
|
||||
}
|
||||
q.Write(column)
|
||||
q.Write(" = ")
|
||||
return q
|
||||
}
|
||||
|
||||
// Delete starts a DELETE statement
|
||||
func (q *Query) Delete() *Query {
|
||||
q.Write("DELETE")
|
||||
return q
|
||||
}
|
||||
|
||||
// StartGroup starts a parenthetical group
|
||||
func (q *Query) StartGroup() *Query {
|
||||
if q.hasWhere {
|
||||
q.Write(" AND (")
|
||||
} else {
|
||||
q.Write(" WHERE (")
|
||||
q.hasWhere = true
|
||||
}
|
||||
q.parensDepth++
|
||||
return q
|
||||
}
|
||||
|
||||
// EndGroup ends a parenthetical group
|
||||
func (q *Query) EndGroup() *Query {
|
||||
if q.parensDepth > 0 {
|
||||
q.Write(")")
|
||||
q.parensDepth--
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
// Returning adds a RETURNING clause for both PostgreSQL and SQLite (3.35.0+)
|
||||
func (q *Query) Returning(columns ...string) *Query {
|
||||
q.Write(" RETURNING ")
|
||||
if len(columns) == 1 && columns[0] == "*" {
|
||||
q.Write("*")
|
||||
} else {
|
||||
q.Write(strings.Join(columns, ", "))
|
||||
}
|
||||
|
||||
return q
|
||||
}
|
||||
|
||||
// Write adds a string to the query
|
||||
func (q *Query) Write(s string) *Query {
|
||||
q.builder.WriteString(s)
|
||||
return q
|
||||
}
|
||||
|
||||
// Placeholder adds a placeholder for a single argument
|
||||
func (q *Query) Placeholder(arg any) *Query {
|
||||
q.pos++
|
||||
q.args = append(q.args, arg)
|
||||
|
||||
if q.dbType == DBTypePostgres {
|
||||
q.builder.WriteString(fmt.Sprintf("$%d", q.pos))
|
||||
} else {
|
||||
q.builder.WriteString("?")
|
||||
}
|
||||
|
||||
return q
|
||||
}
|
||||
|
||||
// Placeholders adds n placeholders separated by commas
|
||||
func (q *Query) Placeholders(n int) *Query {
|
||||
placeholders := make([]string, n)
|
||||
|
||||
for i := range n {
|
||||
q.pos++
|
||||
if q.dbType == DBTypePostgres {
|
||||
placeholders[i] = fmt.Sprintf("$%d", q.pos)
|
||||
} else {
|
||||
placeholders[i] = "?"
|
||||
}
|
||||
}
|
||||
|
||||
q.builder.WriteString(strings.Join(placeholders, ", "))
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *Query) TimeSince(days int) *Query {
|
||||
if q.dbType == DBTypePostgres {
|
||||
q.builder.WriteString(fmt.Sprintf("NOW() - INTERVAL '%d days'", days))
|
||||
} else {
|
||||
q.builder.WriteString(fmt.Sprintf("datetime('now', '-%d days')", days))
|
||||
}
|
||||
|
||||
return q
|
||||
}
|
||||
|
||||
// AddArgs adds arguments to the query
|
||||
func (q *Query) AddArgs(args ...any) *Query {
|
||||
q.args = append(q.args, args...)
|
||||
return q
|
||||
}
|
||||
|
||||
// String returns the formatted query string
|
||||
func (q *Query) String() string {
|
||||
return q.builder.String()
|
||||
}
|
||||
|
||||
// Args returns the query arguments
|
||||
func (q *Query) Args() []any {
|
||||
return q.args
|
||||
}
|
||||
856
server/internal/db/query_test.go
Normal file
856
server/internal/db/query_test.go
Normal file
@@ -0,0 +1,856 @@
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"lemma/internal/db"
|
||||
)
|
||||
|
||||
func TestNewQuery(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dbType db.DBType
|
||||
}{
|
||||
{
|
||||
name: "SQLite query",
|
||||
dbType: db.DBTypeSQLite,
|
||||
},
|
||||
{
|
||||
name: "Postgres query",
|
||||
dbType: db.DBTypePostgres,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
q := db.NewQuery(tt.dbType, &mockSecrets{})
|
||||
|
||||
// Test that a new query is empty
|
||||
if q.String() != "" {
|
||||
t.Errorf("NewQuery() should return empty string, got %q", q.String())
|
||||
}
|
||||
if len(q.Args()) != 0 {
|
||||
t.Errorf("NewQuery() should return empty args, got %v", q.Args())
|
||||
}
|
||||
|
||||
// Test placeholder behavior - SQLite uses ? and Postgres uses $1
|
||||
q.Write("test").Placeholder(1)
|
||||
|
||||
expectedPlaceholder := "?"
|
||||
if tt.dbType == db.DBTypePostgres {
|
||||
expectedPlaceholder = "$1"
|
||||
}
|
||||
|
||||
if q.String() != "test"+expectedPlaceholder {
|
||||
t.Errorf("Expected placeholder format %q for %s, got %q",
|
||||
"test"+expectedPlaceholder, tt.name, q.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBasicQueryBuilding(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dbType db.DBType
|
||||
buildFn func(*db.Query) *db.Query
|
||||
wantSQL string
|
||||
wantArgs []any
|
||||
}{
|
||||
{
|
||||
name: "Simple select SQLite",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("id", "name").From("users")
|
||||
},
|
||||
wantSQL: "SELECT id, name FROM users",
|
||||
wantArgs: []any{},
|
||||
},
|
||||
{
|
||||
name: "Simple select Postgres",
|
||||
dbType: db.DBTypePostgres,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("id", "name").From("users")
|
||||
},
|
||||
wantSQL: "SELECT id, name FROM users",
|
||||
wantArgs: []any{},
|
||||
},
|
||||
{
|
||||
name: "Select with where SQLite",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("id", "name").From("users").Where("id = ").Placeholder(1)
|
||||
},
|
||||
wantSQL: "SELECT id, name FROM users WHERE id = ?",
|
||||
wantArgs: []any{1},
|
||||
},
|
||||
{
|
||||
name: "Select with where Postgres",
|
||||
dbType: db.DBTypePostgres,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("id", "name").From("users").Where("id = ").Placeholder(1)
|
||||
},
|
||||
wantSQL: "SELECT id, name FROM users WHERE id = $1",
|
||||
wantArgs: []any{1},
|
||||
},
|
||||
{
|
||||
name: "Multiple where conditions SQLite",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").From("users").
|
||||
Where("active = ").Placeholder(true).
|
||||
And("role = ").Placeholder("admin")
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE active = ? AND role = ?",
|
||||
wantArgs: []any{true, "admin"},
|
||||
},
|
||||
{
|
||||
name: "Multiple where conditions Postgres",
|
||||
dbType: db.DBTypePostgres,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").From("users").
|
||||
Where("active = ").Placeholder(true).
|
||||
And("role = ").Placeholder("admin")
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE active = $1 AND role = $2",
|
||||
wantArgs: []any{true, "admin"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
q := db.NewQuery(tt.dbType, &mockSecrets{})
|
||||
q = tt.buildFn(q)
|
||||
|
||||
gotSQL := q.String()
|
||||
gotArgs := q.Args()
|
||||
|
||||
if gotSQL != tt.wantSQL {
|
||||
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
|
||||
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPlaceholders(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dbType db.DBType
|
||||
buildFn func(*db.Query) *db.Query
|
||||
wantSQL string
|
||||
wantArgs []any
|
||||
}{
|
||||
{
|
||||
name: "Single placeholder SQLite",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Write("SELECT * FROM users WHERE id = ").Placeholder(42)
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE id = ?",
|
||||
wantArgs: []any{42},
|
||||
},
|
||||
{
|
||||
name: "Single placeholder Postgres",
|
||||
dbType: db.DBTypePostgres,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Write("SELECT * FROM users WHERE id = ").Placeholder(42)
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE id = $1",
|
||||
wantArgs: []any{42},
|
||||
},
|
||||
{
|
||||
name: "Multiple placeholders SQLite",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Write("SELECT * FROM users WHERE id = ").
|
||||
Placeholder(42).
|
||||
Write(" AND name = ").
|
||||
Placeholder("John")
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE id = ? AND name = ?",
|
||||
wantArgs: []any{42, "John"},
|
||||
},
|
||||
{
|
||||
name: "Multiple placeholders Postgres",
|
||||
dbType: db.DBTypePostgres,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Write("SELECT * FROM users WHERE id = ").
|
||||
Placeholder(42).
|
||||
Write(" AND name = ").
|
||||
Placeholder("John")
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE id = $1 AND name = $2",
|
||||
wantArgs: []any{42, "John"},
|
||||
},
|
||||
{
|
||||
name: "Placeholders for IN SQLite",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Write("SELECT * FROM users WHERE id IN (").
|
||||
Placeholders(3).
|
||||
Write(")").
|
||||
AddArgs(1, 2, 3)
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE id IN (?, ?, ?)",
|
||||
wantArgs: []any{1, 2, 3},
|
||||
},
|
||||
{
|
||||
name: "Placeholders for IN Postgres",
|
||||
dbType: db.DBTypePostgres,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Write("SELECT * FROM users WHERE id IN (").
|
||||
Placeholders(3).
|
||||
Write(")").
|
||||
AddArgs(1, 2, 3)
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE id IN ($1, $2, $3)",
|
||||
wantArgs: []any{1, 2, 3},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
q := db.NewQuery(tt.dbType, &mockSecrets{})
|
||||
q = tt.buildFn(q)
|
||||
|
||||
gotSQL := q.String()
|
||||
gotArgs := q.Args()
|
||||
|
||||
if gotSQL != tt.wantSQL {
|
||||
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
|
||||
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhereClauseBuilding(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dbType db.DBType
|
||||
buildFn func(*db.Query) *db.Query
|
||||
wantSQL string
|
||||
wantArgs []any
|
||||
}{
|
||||
{
|
||||
name: "Simple where",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").From("users").Where("id = ").Placeholder(1)
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE id = ?",
|
||||
wantArgs: []any{1},
|
||||
},
|
||||
{
|
||||
name: "Where with And",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").From("users").
|
||||
Where("id = ").Placeholder(1).
|
||||
And("active = ").Placeholder(true)
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE id = ? AND active = ?",
|
||||
wantArgs: []any{1, true},
|
||||
},
|
||||
{
|
||||
name: "Where with Or",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").From("users").
|
||||
Where("id = ").Placeholder(1).
|
||||
Or("id = ").Placeholder(2)
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE id = ? OR id = ?",
|
||||
wantArgs: []any{1, 2},
|
||||
},
|
||||
{
|
||||
name: "Where with parentheses",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").From("users").
|
||||
Where("active = ").Placeholder(true).
|
||||
And("(").
|
||||
Write("id = ").Placeholder(1).
|
||||
Or("id = ").Placeholder(2).
|
||||
Write(")")
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE active = ? AND (id = ? OR id = ?)",
|
||||
wantArgs: []any{true, 1, 2},
|
||||
},
|
||||
{
|
||||
name: "Where with StartGroup and EndGroup",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").From("users").
|
||||
Where("active = ").Placeholder(true).
|
||||
Write(" AND (").
|
||||
Write("id = ").Placeholder(1).
|
||||
Or("id = ").Placeholder(2).
|
||||
Write(")")
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE active = ? AND (id = ? OR id = ?)",
|
||||
wantArgs: []any{true, 1, 2},
|
||||
},
|
||||
{
|
||||
name: "Where with nested groups",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").From("users").
|
||||
Where("(").
|
||||
Write("active = ").Placeholder(true).
|
||||
Or("role = ").Placeholder("admin").
|
||||
Write(")").
|
||||
And("created_at > ").Placeholder("2020-01-01")
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE (active = ? OR role = ?) AND created_at > ?",
|
||||
wantArgs: []any{true, "admin", "2020-01-01"},
|
||||
},
|
||||
{
|
||||
name: "WhereIn",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").From("users").
|
||||
WhereIn("id", 3).
|
||||
AddArgs(1, 2, 3)
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE id IN (?, ?, ?)",
|
||||
wantArgs: []any{1, 2, 3},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
q := db.NewQuery(tt.dbType, &mockSecrets{})
|
||||
q = tt.buildFn(q)
|
||||
|
||||
gotSQL := q.String()
|
||||
gotArgs := q.Args()
|
||||
|
||||
if gotSQL != tt.wantSQL {
|
||||
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
|
||||
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJoinClauseBuilding(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dbType db.DBType
|
||||
buildFn func(*db.Query) *db.Query
|
||||
wantSQL string
|
||||
wantArgs []any
|
||||
}{
|
||||
{
|
||||
name: "Inner join",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("u.*", "w.name").
|
||||
From("users u").
|
||||
Join(db.InnerJoin, "workspaces w", "w.user_id = u.id")
|
||||
},
|
||||
wantSQL: "SELECT u.*, w.name FROM users u INNER JOIN workspaces w ON w.user_id = u.id",
|
||||
wantArgs: []any{},
|
||||
},
|
||||
{
|
||||
name: "Left join",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("u.*", "w.name").
|
||||
From("users u").
|
||||
Join(db.LeftJoin, "workspaces w", "w.user_id = u.id")
|
||||
},
|
||||
wantSQL: "SELECT u.*, w.name FROM users u LEFT JOIN workspaces w ON w.user_id = u.id",
|
||||
wantArgs: []any{},
|
||||
},
|
||||
{
|
||||
name: "Multiple joins",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("u.*", "w.name", "s.role").
|
||||
From("users u").
|
||||
Join(db.InnerJoin, "workspaces w", "w.user_id = u.id").
|
||||
Join(db.LeftJoin, "settings s", "s.user_id = u.id")
|
||||
},
|
||||
wantSQL: "SELECT u.*, w.name, s.role FROM users u INNER JOIN workspaces w ON w.user_id = u.id LEFT JOIN settings s ON s.user_id = u.id",
|
||||
wantArgs: []any{},
|
||||
},
|
||||
{
|
||||
name: "Join with where",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("u.*", "w.name").
|
||||
From("users u").
|
||||
Join(db.InnerJoin, "workspaces w", "w.user_id = u.id").
|
||||
Where("u.active = ").Placeholder(true)
|
||||
},
|
||||
wantSQL: "SELECT u.*, w.name FROM users u INNER JOIN workspaces w ON w.user_id = u.id WHERE u.active = ?",
|
||||
wantArgs: []any{true},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
q := db.NewQuery(tt.dbType, &mockSecrets{})
|
||||
q = tt.buildFn(q)
|
||||
|
||||
gotSQL := q.String()
|
||||
gotArgs := q.Args()
|
||||
|
||||
if gotSQL != tt.wantSQL {
|
||||
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
|
||||
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrderLimitOffset(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dbType db.DBType
|
||||
buildFn func(*db.Query) *db.Query
|
||||
wantSQL string
|
||||
wantArgs []any
|
||||
}{
|
||||
{
|
||||
name: "Order by",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").From("users").OrderBy("name ASC")
|
||||
},
|
||||
wantSQL: "SELECT * FROM users ORDER BY name ASC",
|
||||
wantArgs: []any{},
|
||||
},
|
||||
{
|
||||
name: "Order by multiple columns",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").From("users").OrderBy("name ASC", "created_at DESC")
|
||||
},
|
||||
wantSQL: "SELECT * FROM users ORDER BY name ASC, created_at DESC",
|
||||
wantArgs: []any{},
|
||||
},
|
||||
{
|
||||
name: "Limit",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").From("users").Limit(10)
|
||||
},
|
||||
wantSQL: "SELECT * FROM users LIMIT 10",
|
||||
wantArgs: []any{},
|
||||
},
|
||||
{
|
||||
name: "Limit and offset",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").From("users").Limit(10).Offset(20)
|
||||
},
|
||||
wantSQL: "SELECT * FROM users LIMIT 10 OFFSET 20",
|
||||
wantArgs: []any{},
|
||||
},
|
||||
{
|
||||
name: "Complete query with all clauses",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("*").
|
||||
From("users").
|
||||
Where("active = ").Placeholder(true).
|
||||
OrderBy("name ASC").
|
||||
Limit(10).
|
||||
Offset(20)
|
||||
},
|
||||
wantSQL: "SELECT * FROM users WHERE active = ? ORDER BY name ASC LIMIT 10 OFFSET 20",
|
||||
wantArgs: []any{true},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
q := db.NewQuery(tt.dbType, &mockSecrets{})
|
||||
q = tt.buildFn(q)
|
||||
|
||||
gotSQL := q.String()
|
||||
gotArgs := q.Args()
|
||||
|
||||
if gotSQL != tt.wantSQL {
|
||||
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
|
||||
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInsertUpdateDelete(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dbType db.DBType
|
||||
buildFn func(*db.Query) *db.Query
|
||||
wantSQL string
|
||||
wantArgs []any
|
||||
}{
|
||||
{
|
||||
name: "Insert SQLite",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Insert("users", "name", "email").
|
||||
Values(2).
|
||||
AddArgs("John", "john@example.com")
|
||||
},
|
||||
wantSQL: "INSERT INTO users (name, email) VALUES (?, ?)",
|
||||
wantArgs: []any{"John", "john@example.com"},
|
||||
},
|
||||
{
|
||||
name: "Insert Postgres",
|
||||
dbType: db.DBTypePostgres,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Insert("users", "name", "email").
|
||||
Values(2).
|
||||
AddArgs("John", "john@example.com")
|
||||
},
|
||||
wantSQL: "INSERT INTO users (name, email) VALUES ($1, $2)",
|
||||
wantArgs: []any{"John", "john@example.com"},
|
||||
},
|
||||
{
|
||||
name: "Update SQLite",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Update("users").
|
||||
Set("name").Placeholder("John").
|
||||
Set("email").Placeholder("john@example.com").
|
||||
Where("id = ").Placeholder(1)
|
||||
},
|
||||
wantSQL: "UPDATE users SET name = ?, email = ? WHERE id = ?",
|
||||
wantArgs: []any{"John", "john@example.com", 1},
|
||||
},
|
||||
{
|
||||
name: "Update Postgres",
|
||||
dbType: db.DBTypePostgres,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Update("users").
|
||||
Set("name").Placeholder("John").
|
||||
Set("email").Placeholder("john@example.com").
|
||||
Where("id = ").Placeholder(1)
|
||||
},
|
||||
wantSQL: "UPDATE users SET name = $1, email = $2 WHERE id = $3",
|
||||
wantArgs: []any{"John", "john@example.com", 1},
|
||||
},
|
||||
{
|
||||
name: "Delete SQLite",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Delete().From("users").Where("id = ").Placeholder(1)
|
||||
},
|
||||
wantSQL: "DELETE FROM users WHERE id = ?",
|
||||
wantArgs: []any{1},
|
||||
},
|
||||
{
|
||||
name: "Delete Postgres",
|
||||
dbType: db.DBTypePostgres,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Delete().From("users").Where("id = ").Placeholder(1)
|
||||
},
|
||||
wantSQL: "DELETE FROM users WHERE id = $1",
|
||||
wantArgs: []any{1},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
q := db.NewQuery(tt.dbType, &mockSecrets{})
|
||||
q = tt.buildFn(q)
|
||||
|
||||
gotSQL := q.String()
|
||||
gotArgs := q.Args()
|
||||
|
||||
if gotSQL != tt.wantSQL {
|
||||
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
|
||||
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHavingClause(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dbType db.DBType
|
||||
buildFn func(*db.Query) *db.Query
|
||||
wantSQL string
|
||||
wantArgs []any
|
||||
}{
|
||||
{
|
||||
name: "Simple having",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("department", "COUNT(*) as count").
|
||||
From("employees").
|
||||
GroupBy("department").
|
||||
Having("count > ").Placeholder(5)
|
||||
},
|
||||
wantSQL: "SELECT department, COUNT(*) as count FROM employees GROUP BY department HAVING count > ?",
|
||||
wantArgs: []any{5},
|
||||
},
|
||||
{
|
||||
name: "Having with multiple conditions",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("department", "AVG(salary) as avg_salary").
|
||||
From("employees").
|
||||
GroupBy("department").
|
||||
Having("avg_salary > ").Placeholder(50000).
|
||||
And("COUNT(*) > ").Placeholder(3)
|
||||
},
|
||||
wantSQL: "SELECT department, AVG(salary) as avg_salary FROM employees GROUP BY department HAVING avg_salary > ? AND COUNT(*) > ?",
|
||||
wantArgs: []any{50000, 3},
|
||||
},
|
||||
{
|
||||
name: "Having with postgres placeholders",
|
||||
dbType: db.DBTypePostgres,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("department", "COUNT(*) as count").
|
||||
From("employees").
|
||||
GroupBy("department").
|
||||
Having("count > ").Placeholder(5)
|
||||
},
|
||||
wantSQL: "SELECT department, COUNT(*) as count FROM employees GROUP BY department HAVING count > $1",
|
||||
wantArgs: []any{5},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
q := db.NewQuery(tt.dbType, &mockSecrets{})
|
||||
q = tt.buildFn(q)
|
||||
|
||||
gotSQL := q.String()
|
||||
gotArgs := q.Args()
|
||||
|
||||
if gotSQL != tt.wantSQL {
|
||||
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
|
||||
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryReturning(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
dbType db.DBType
|
||||
buildQuery func(q *db.Query) *db.Query
|
||||
expectedSQL string
|
||||
expectedArgs []any
|
||||
}{
|
||||
{
|
||||
name: "SQLite INSERT with RETURNING single column",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildQuery: func(q *db.Query) *db.Query {
|
||||
return q.Insert("users", "name", "email").
|
||||
Values(2).
|
||||
AddArgs("John Doe", "john@example.com").
|
||||
Returning("id")
|
||||
},
|
||||
expectedSQL: "INSERT INTO users (name, email) VALUES (?, ?) RETURNING id",
|
||||
expectedArgs: []any{"John Doe", "john@example.com"},
|
||||
},
|
||||
{
|
||||
name: "PostgreSQL INSERT with RETURNING single column",
|
||||
dbType: db.DBTypePostgres,
|
||||
buildQuery: func(q *db.Query) *db.Query {
|
||||
return q.Insert("users", "name", "email").
|
||||
Values(2).
|
||||
AddArgs("John Doe", "john@example.com").
|
||||
Returning("id")
|
||||
},
|
||||
expectedSQL: "INSERT INTO users (name, email) VALUES ($1, $2) RETURNING id",
|
||||
expectedArgs: []any{"John Doe", "john@example.com"},
|
||||
},
|
||||
{
|
||||
name: "SQLite INSERT with RETURNING multiple columns",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildQuery: func(q *db.Query) *db.Query {
|
||||
return q.Insert("users", "name", "email").
|
||||
Values(2).
|
||||
AddArgs("John Doe", "john@example.com").
|
||||
Returning("id", "created_at")
|
||||
},
|
||||
expectedSQL: "INSERT INTO users (name, email) VALUES (?, ?) RETURNING id, created_at",
|
||||
expectedArgs: []any{"John Doe", "john@example.com"},
|
||||
},
|
||||
{
|
||||
name: "PostgreSQL INSERT with RETURNING multiple columns",
|
||||
dbType: db.DBTypePostgres,
|
||||
buildQuery: func(q *db.Query) *db.Query {
|
||||
return q.Insert("users", "name", "email").
|
||||
Values(2).
|
||||
AddArgs("John Doe", "john@example.com").
|
||||
Returning("id", "created_at")
|
||||
},
|
||||
expectedSQL: "INSERT INTO users (name, email) VALUES ($1, $2) RETURNING id, created_at",
|
||||
expectedArgs: []any{"John Doe", "john@example.com"},
|
||||
},
|
||||
{
|
||||
name: "SQLite UPDATE with RETURNING",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildQuery: func(q *db.Query) *db.Query {
|
||||
return q.Update("users").
|
||||
Set("name").Placeholder("Jane Doe").
|
||||
Where("id = ").Placeholder(1).
|
||||
Returning("name", "updated_at")
|
||||
},
|
||||
expectedSQL: "UPDATE users SET name = ? WHERE id = ? RETURNING name, updated_at",
|
||||
expectedArgs: []any{"Jane Doe", 1},
|
||||
},
|
||||
{
|
||||
name: "PostgreSQL UPDATE with RETURNING",
|
||||
dbType: db.DBTypePostgres,
|
||||
buildQuery: func(q *db.Query) *db.Query {
|
||||
return q.Update("users").
|
||||
Set("name").Placeholder("Jane Doe").
|
||||
Where("id = ").Placeholder(1).
|
||||
Returning("name", "updated_at")
|
||||
},
|
||||
expectedSQL: "UPDATE users SET name = $1 WHERE id = $2 RETURNING name, updated_at",
|
||||
expectedArgs: []any{"Jane Doe", 1},
|
||||
},
|
||||
{
|
||||
name: "SQLite DELETE with RETURNING",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildQuery: func(q *db.Query) *db.Query {
|
||||
return q.Delete().
|
||||
From("users").
|
||||
Where("id = ").Placeholder(1).
|
||||
Returning("id", "name", "email")
|
||||
},
|
||||
expectedSQL: "DELETE FROM users WHERE id = ? RETURNING id, name, email",
|
||||
expectedArgs: []any{1},
|
||||
},
|
||||
{
|
||||
name: "PostgreSQL DELETE with RETURNING",
|
||||
dbType: db.DBTypePostgres,
|
||||
buildQuery: func(q *db.Query) *db.Query {
|
||||
return q.Delete().
|
||||
From("users").
|
||||
Where("id = ").Placeholder(1).
|
||||
Returning("id", "name", "email")
|
||||
},
|
||||
expectedSQL: "DELETE FROM users WHERE id = $1 RETURNING id, name, email",
|
||||
expectedArgs: []any{1},
|
||||
},
|
||||
{
|
||||
name: "SQLite INSERT with RETURNING *",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildQuery: func(q *db.Query) *db.Query {
|
||||
return q.Insert("users", "name", "email").
|
||||
Values(2).
|
||||
AddArgs("John Doe", "john@example.com").
|
||||
Returning("*")
|
||||
},
|
||||
expectedSQL: "INSERT INTO users (name, email) VALUES (?, ?) RETURNING *",
|
||||
expectedArgs: []any{"John Doe", "john@example.com"},
|
||||
},
|
||||
{
|
||||
name: "PostgreSQL INSERT with RETURNING *",
|
||||
dbType: db.DBTypePostgres,
|
||||
buildQuery: func(q *db.Query) *db.Query {
|
||||
return q.Insert("users", "name", "email").
|
||||
Values(2).
|
||||
AddArgs("John Doe", "john@example.com").
|
||||
Returning("*")
|
||||
},
|
||||
expectedSQL: "INSERT INTO users (name, email) VALUES ($1, $2) RETURNING *",
|
||||
expectedArgs: []any{"John Doe", "john@example.com"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
query := db.NewQuery(tc.dbType, &mockSecrets{})
|
||||
result := tc.buildQuery(query)
|
||||
|
||||
if result.String() != tc.expectedSQL {
|
||||
t.Errorf("Expected SQL: %s, got: %s", tc.expectedSQL, result.String())
|
||||
}
|
||||
|
||||
if len(result.Args()) != len(tc.expectedArgs) {
|
||||
t.Errorf("Expected %d args, got %d", len(tc.expectedArgs), len(result.Args()))
|
||||
}
|
||||
|
||||
for i, arg := range result.Args() {
|
||||
if arg != tc.expectedArgs[i] {
|
||||
t.Errorf("Expected arg %d to be %v, got %v", i, tc.expectedArgs[i], arg)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplexQueries(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dbType db.DBType
|
||||
buildFn func(*db.Query) *db.Query
|
||||
wantSQL string
|
||||
wantArgs []any
|
||||
}{
|
||||
{
|
||||
name: "Complex select with join and where",
|
||||
dbType: db.DBTypeSQLite,
|
||||
buildFn: func(q *db.Query) *db.Query {
|
||||
return q.Select("u.id", "u.name", "COUNT(w.id) as workspace_count").
|
||||
From("users u").
|
||||
Join(db.LeftJoin, "workspaces w", "w.user_id = u.id").
|
||||
Where("u.active = ").Placeholder(true).
|
||||
GroupBy("u.id", "u.name").
|
||||
Having("COUNT(w.id) > ").Placeholder(0).
|
||||
OrderBy("workspace_count DESC").
|
||||
Limit(10)
|
||||
},
|
||||
wantSQL: "SELECT u.id, u.name, COUNT(w.id) as workspace_count FROM users u LEFT JOIN workspaces w ON w.user_id = u.id WHERE u.active = ? GROUP BY u.id, u.name HAVING COUNT(w.id) > ? ORDER BY workspace_count DESC LIMIT 10",
|
||||
wantArgs: []any{true, 0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
q := db.NewQuery(tt.dbType, &mockSecrets{})
|
||||
q = tt.buildFn(q)
|
||||
|
||||
gotSQL := q.String()
|
||||
gotArgs := q.Args()
|
||||
|
||||
if gotSQL != tt.wantSQL {
|
||||
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
|
||||
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -10,11 +10,12 @@ import (
|
||||
|
||||
// CreateSession inserts a new session record into the database
|
||||
func (db *database) CreateSession(session *models.Session) error {
|
||||
_, err := db.Exec(`
|
||||
INSERT INTO sessions (id, user_id, refresh_token, expires_at, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)`,
|
||||
session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt,
|
||||
)
|
||||
query, err := db.NewQuery().
|
||||
InsertStruct(session, "sessions")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create query: %w", err)
|
||||
}
|
||||
_, err = db.Exec(query.String(), query.Args()...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store session: %w", err)
|
||||
}
|
||||
@@ -25,13 +26,18 @@ func (db *database) CreateSession(session *models.Session) error {
|
||||
// GetSessionByRefreshToken retrieves a session by its refresh token
|
||||
func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) {
|
||||
session := &models.Session{}
|
||||
err := db.QueryRow(`
|
||||
SELECT id, user_id, refresh_token, expires_at, created_at
|
||||
FROM sessions
|
||||
WHERE refresh_token = ? AND expires_at > ?`,
|
||||
refreshToken, time.Now(),
|
||||
).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt)
|
||||
query := db.NewQuery()
|
||||
query, err := query.SelectStruct(session, "sessions")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create query: %w", err)
|
||||
}
|
||||
query = query.Where("refresh_token = ").
|
||||
Placeholder(refreshToken).
|
||||
And("expires_at >").
|
||||
Placeholder(time.Now())
|
||||
|
||||
row := db.QueryRow(query.String(), query.Args()...)
|
||||
err = db.ScanStruct(row, session)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("session not found or expired")
|
||||
}
|
||||
@@ -45,13 +51,18 @@ func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Sessi
|
||||
// GetSessionByID retrieves a session by its ID
|
||||
func (db *database) GetSessionByID(sessionID string) (*models.Session, error) {
|
||||
session := &models.Session{}
|
||||
err := db.QueryRow(`
|
||||
SELECT id, user_id, refresh_token, expires_at, created_at
|
||||
FROM sessions
|
||||
WHERE id = ? AND expires_at > ?`,
|
||||
sessionID, time.Now(),
|
||||
).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt)
|
||||
query := db.NewQuery()
|
||||
query, err := query.SelectStruct(session, "sessions")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create query: %w", err)
|
||||
}
|
||||
query = query.Where("id = ").
|
||||
Placeholder(sessionID).
|
||||
And("expires_at >").
|
||||
Placeholder(time.Now())
|
||||
|
||||
row := db.QueryRow(query.String(), query.Args()...)
|
||||
err = db.ScanStruct(row, session)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("session not found")
|
||||
}
|
||||
@@ -64,7 +75,13 @@ func (db *database) GetSessionByID(sessionID string) (*models.Session, error) {
|
||||
|
||||
// DeleteSession removes a session from the database
|
||||
func (db *database) DeleteSession(sessionID string) error {
|
||||
result, err := db.Exec("DELETE FROM sessions WHERE id = ?", sessionID)
|
||||
query := db.NewQuery().
|
||||
Delete().
|
||||
From("sessions").
|
||||
Where("id = ").
|
||||
Placeholder(sessionID)
|
||||
|
||||
result, err := db.Exec(query.String(), query.Args()...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete session: %w", err)
|
||||
}
|
||||
@@ -84,7 +101,12 @@ func (db *database) DeleteSession(sessionID string) error {
|
||||
// CleanExpiredSessions removes all expired sessions from the database
|
||||
func (db *database) CleanExpiredSessions() error {
|
||||
log := getLogger().WithGroup("sessions")
|
||||
result, err := db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now())
|
||||
query := db.NewQuery().
|
||||
Delete().
|
||||
From("sessions").
|
||||
Where("expires_at <=").
|
||||
Placeholder(time.Now())
|
||||
result, err := db.Exec(query.String(), query.Args()...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to clean expired sessions: %w", err)
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
)
|
||||
|
||||
func TestSessionOperations(t *testing.T) {
|
||||
database, err := db.NewTestDB(":memory:", &mockSecrets{})
|
||||
database, err := db.NewTestSQLiteDB(&mockSecrets{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test database: %v", err)
|
||||
}
|
||||
|
||||
341
server/internal/db/struct_query.go
Normal file
341
server/internal/db/struct_query.go
Normal file
@@ -0,0 +1,341 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
type DBField struct {
|
||||
Name string
|
||||
Value any
|
||||
Type reflect.Type
|
||||
OriginalName string
|
||||
useDefault bool
|
||||
encrypted bool
|
||||
}
|
||||
|
||||
// StructTagsToFields converts a struct to a slice of DBField instances
|
||||
func StructTagsToFields(s any) ([]DBField, error) {
|
||||
v := reflect.ValueOf(s)
|
||||
|
||||
if v.Kind() == reflect.Ptr {
|
||||
if v.IsNil() {
|
||||
return nil, fmt.Errorf("nil pointer provided")
|
||||
}
|
||||
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
if v.Kind() != reflect.Struct {
|
||||
return nil, fmt.Errorf("provided value is %s, expected struct", v.Kind())
|
||||
}
|
||||
|
||||
t := v.Type()
|
||||
fields := make([]DBField, 0, t.NumField())
|
||||
|
||||
for i := range t.NumField() {
|
||||
f := t.Field(i)
|
||||
|
||||
if !f.IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
tag := f.Tag.Get("db")
|
||||
if tag == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
if tag == "" {
|
||||
tag = toSnakeCase(f.Name)
|
||||
}
|
||||
|
||||
useDefault := false
|
||||
encrypted := false
|
||||
ommit := false
|
||||
|
||||
if strings.Contains(tag, ",") {
|
||||
parts := strings.Split(tag, ",")
|
||||
tag = parts[0]
|
||||
|
||||
for _, opt := range parts[1:] {
|
||||
switch opt {
|
||||
case "omitempty":
|
||||
if reflect.DeepEqual(v.Field(i).Interface(), reflect.Zero(f.Type).Interface()) {
|
||||
ommit = true
|
||||
}
|
||||
case "default":
|
||||
useDefault = true
|
||||
case "encrypted":
|
||||
encrypted = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if ommit {
|
||||
continue
|
||||
}
|
||||
|
||||
fields = append(fields, DBField{
|
||||
Name: tag,
|
||||
Value: v.Field(i).Interface(),
|
||||
Type: f.Type,
|
||||
OriginalName: f.Name,
|
||||
useDefault: useDefault,
|
||||
encrypted: encrypted,
|
||||
})
|
||||
}
|
||||
|
||||
sort.Slice(fields, func(i, j int) bool {
|
||||
return fields[i].Name < fields[j].Name
|
||||
})
|
||||
|
||||
return fields, nil
|
||||
}
|
||||
|
||||
func toSnakeCase(s string) string {
|
||||
var res string
|
||||
|
||||
for i, r := range s {
|
||||
if unicode.IsUpper(r) {
|
||||
if i > 0 {
|
||||
res += "_"
|
||||
}
|
||||
res += string(unicode.ToLower(r))
|
||||
} else {
|
||||
res += string(r)
|
||||
}
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// InsertStruct creates an INSERT query from a struct
|
||||
func (q *Query) InsertStruct(s any, table string) (*Query, error) {
|
||||
fields, err := StructTagsToFields(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columns := make([]string, 0, len(fields))
|
||||
values := make([]any, 0, len(fields))
|
||||
|
||||
for _, f := range fields {
|
||||
value := f.Value
|
||||
|
||||
if f.useDefault {
|
||||
continue
|
||||
}
|
||||
|
||||
if f.encrypted {
|
||||
encValue, err := q.secretsService.Encrypt(value.(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
value = encValue
|
||||
}
|
||||
|
||||
columns = append(columns, f.Name)
|
||||
values = append(values, value)
|
||||
}
|
||||
|
||||
if len(columns) == 0 {
|
||||
return nil, fmt.Errorf("no columns to insert")
|
||||
}
|
||||
|
||||
q.Insert(table, columns...).Values(len(columns)).AddArgs(values...)
|
||||
return q, nil
|
||||
}
|
||||
|
||||
// UpdateStruct creates an UPDATE query from a struct
|
||||
func (q *Query) UpdateStruct(s any, table string) (*Query, error) {
|
||||
fields, err := StructTagsToFields(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q = q.Update(table)
|
||||
|
||||
for _, f := range fields {
|
||||
value := f.Value
|
||||
|
||||
if f.useDefault {
|
||||
continue
|
||||
}
|
||||
|
||||
if f.encrypted {
|
||||
encValue, err := q.secretsService.Encrypt(value.(string))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
value = encValue
|
||||
}
|
||||
|
||||
q = q.Set(f.Name).Placeholder(value)
|
||||
}
|
||||
|
||||
return q, nil
|
||||
}
|
||||
|
||||
// SelectStruct creates a SELECT query from a struct
|
||||
func (q *Query) SelectStruct(s any, table string) (*Query, error) {
|
||||
fields, err := StructTagsToFields(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columns := make([]string, 0, len(fields))
|
||||
for _, f := range fields {
|
||||
columns = append(columns, f.Name)
|
||||
}
|
||||
|
||||
q = q.Select(columns...).From(table)
|
||||
return q, nil
|
||||
}
|
||||
|
||||
// Scanner is an interface that both sql.Row and sql.Rows satisfy
|
||||
type Scanner interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
// scanStructInstance is an internal function that handles the scanning logic for a single instance
|
||||
func (db *database) scanStructInstance(destVal reflect.Value, scanner Scanner) error {
|
||||
fields, err := StructTagsToFields(destVal.Interface())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to extract struct fields: %w", err)
|
||||
}
|
||||
|
||||
scanDest := make([]any, len(fields))
|
||||
var fieldsToDecrypt []string
|
||||
nullStringIndexes := make(map[int]reflect.Value)
|
||||
|
||||
for i, field := range fields {
|
||||
// Find the field in the struct
|
||||
structField := destVal.FieldByName(field.OriginalName)
|
||||
if !structField.IsValid() {
|
||||
return fmt.Errorf("struct field %s not found", field.OriginalName)
|
||||
}
|
||||
|
||||
if field.encrypted {
|
||||
fieldsToDecrypt = append(fieldsToDecrypt, field.OriginalName)
|
||||
}
|
||||
|
||||
if structField.Kind() == reflect.String {
|
||||
// Handle null strings separately
|
||||
nullStringIndexes[i] = structField
|
||||
var ns sql.NullString
|
||||
scanDest[i] = &ns
|
||||
} else {
|
||||
scanDest[i] = structField.Addr().Interface()
|
||||
}
|
||||
}
|
||||
|
||||
// Scan using the scanner interface
|
||||
if err := scanner.Scan(scanDest...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set null strings to their values if they are valid
|
||||
for i, field := range nullStringIndexes {
|
||||
ns := scanDest[i].(*sql.NullString)
|
||||
if ns.Valid {
|
||||
field.SetString(ns.String)
|
||||
}
|
||||
}
|
||||
|
||||
// Decrypt encrypted fields
|
||||
for _, fieldName := range fieldsToDecrypt {
|
||||
field := destVal.FieldByName(fieldName)
|
||||
if !field.IsZero() {
|
||||
decValue, err := db.secretsService.Decrypt(field.Interface().(string))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
field.SetString(decValue)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ScanStruct scans a single row into a struct
|
||||
func (db *database) ScanStruct(row *sql.Row, dest any) error {
|
||||
if row == nil {
|
||||
return fmt.Errorf("row cannot be nil")
|
||||
}
|
||||
|
||||
if row.Err() != nil {
|
||||
return row.Err()
|
||||
}
|
||||
|
||||
// Get the destination value
|
||||
destVal := reflect.ValueOf(dest)
|
||||
if destVal.Kind() != reflect.Ptr || destVal.IsNil() {
|
||||
return fmt.Errorf("destination must be a non-nil pointer to a struct, got %T", dest)
|
||||
}
|
||||
|
||||
destVal = destVal.Elem()
|
||||
if destVal.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("destination must be a pointer to a struct, got pointer to %s", destVal.Kind())
|
||||
}
|
||||
|
||||
return db.scanStructInstance(destVal, row)
|
||||
}
|
||||
|
||||
// ScanStructs scans multiple rows into a slice of structs
|
||||
func (db *database) ScanStructs(rows *sql.Rows, destSlice any) error {
|
||||
if rows == nil {
|
||||
return fmt.Errorf("rows cannot be nil")
|
||||
}
|
||||
|
||||
// Get the slice value and element type
|
||||
sliceVal := reflect.ValueOf(destSlice)
|
||||
if sliceVal.Kind() != reflect.Ptr || sliceVal.IsNil() {
|
||||
return fmt.Errorf("destination must be a non-nil pointer to a slice, got %T", destSlice)
|
||||
}
|
||||
|
||||
sliceVal = sliceVal.Elem()
|
||||
if sliceVal.Kind() != reflect.Slice {
|
||||
return fmt.Errorf("destination must be a pointer to a slice, got pointer to %s", sliceVal.Kind())
|
||||
}
|
||||
|
||||
// Get the element type of the slice
|
||||
elemType := sliceVal.Type().Elem()
|
||||
|
||||
// Check if we have a direct struct type or a pointer to struct
|
||||
isPtr := elemType.Kind() == reflect.Ptr
|
||||
structType := elemType
|
||||
if isPtr {
|
||||
structType = elemType.Elem()
|
||||
}
|
||||
if structType.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("slice element type must be a struct or pointer to struct, got %s", elemType.String())
|
||||
}
|
||||
|
||||
// Process each row
|
||||
for rows.Next() {
|
||||
// Create a new instance of the struct for each row
|
||||
newElem := reflect.New(structType).Elem()
|
||||
|
||||
// Scan this row into the new element
|
||||
if err := db.scanStructInstance(newElem, rows); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add the new element to the result slice
|
||||
if isPtr {
|
||||
sliceVal.Set(reflect.Append(sliceVal, newElem.Addr()))
|
||||
} else {
|
||||
sliceVal.Set(reflect.Append(sliceVal, newElem))
|
||||
}
|
||||
}
|
||||
|
||||
// Check for errors from iterating over rows
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
507
server/internal/db/struct_query_test.go
Normal file
507
server/internal/db/struct_query_test.go
Normal file
@@ -0,0 +1,507 @@
|
||||
package db_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"lemma/internal/db"
|
||||
"lemma/internal/models"
|
||||
_ "lemma/internal/testenv"
|
||||
)
|
||||
|
||||
// TestStructTagsToFields tests the exported StructTagsToFields function
|
||||
func TestStructTagsToFields(t *testing.T) {
|
||||
type testStruct struct {
|
||||
ID int `db:"id"`
|
||||
Name string `db:"custom_name"`
|
||||
CreatedAt time.Time `db:"created_at,default"`
|
||||
Skip string `db:"-"`
|
||||
Empty string `db:"empty,omitempty"`
|
||||
Secret string `db:"secret,encrypted"`
|
||||
NoTag string
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input interface{}
|
||||
wantFields int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid struct",
|
||||
input: testStruct{
|
||||
ID: 1,
|
||||
Name: "Test",
|
||||
CreatedAt: time.Now(),
|
||||
Skip: "skip me",
|
||||
Secret: "secret value",
|
||||
NoTag: "no tag",
|
||||
},
|
||||
wantFields: 5, // ID, Name, CreatedAt, Secret, NoTag (Empty is omitted)
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "nil pointer",
|
||||
input: (*testStruct)(nil),
|
||||
wantFields: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "non-struct",
|
||||
input: "not a struct",
|
||||
wantFields: 0,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "struct pointer",
|
||||
input: &testStruct{
|
||||
ID: 2,
|
||||
Name: "Test Pointer",
|
||||
},
|
||||
wantFields: 5, // Same fields as above
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
fields, err := db.StructTagsToFields(tt.input)
|
||||
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Error("Expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(fields) != tt.wantFields {
|
||||
t.Errorf("Expected %d fields, got %d", tt.wantFields, len(fields))
|
||||
}
|
||||
|
||||
// Check specific field handling for valid struct test
|
||||
if tt.name == "valid struct" {
|
||||
// Find fields by name
|
||||
var idField, nameField, createdAtField, secretField, emptyField, noTagField *db.DBField
|
||||
for i := range fields {
|
||||
f := &fields[i]
|
||||
switch f.Name {
|
||||
case "id":
|
||||
idField = f
|
||||
case "custom_name":
|
||||
nameField = f
|
||||
case "created_at":
|
||||
createdAtField = f
|
||||
case "secret":
|
||||
secretField = f
|
||||
case "empty":
|
||||
emptyField = f
|
||||
case "no_tag":
|
||||
noTagField = f
|
||||
}
|
||||
}
|
||||
|
||||
// Check fields exist
|
||||
if idField == nil {
|
||||
t.Error("ID field not found")
|
||||
}
|
||||
if nameField == nil {
|
||||
t.Error("Name field not found")
|
||||
}
|
||||
if createdAtField == nil {
|
||||
t.Error("CreatedAt field not found")
|
||||
}
|
||||
if secretField == nil {
|
||||
t.Error("Secret field not found")
|
||||
}
|
||||
if noTagField == nil {
|
||||
t.Error("NoTag field not found")
|
||||
}
|
||||
if emptyField != nil {
|
||||
t.Error("Empty field should be omitted")
|
||||
}
|
||||
|
||||
// Check original names
|
||||
if idField != nil && idField.OriginalName != "ID" {
|
||||
t.Errorf("Expected OriginalName 'ID', got '%s'", idField.OriginalName)
|
||||
}
|
||||
if nameField != nil && nameField.OriginalName != "Name" {
|
||||
t.Errorf("Expected OriginalName 'Name', got '%s'", nameField.OriginalName)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestStructQueries tests the struct-based query methods using the test database
|
||||
func TestStructQueries(t *testing.T) {
|
||||
// Setup test database
|
||||
database, err := db.NewTestSQLiteDB(&mockSecrets{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test database: %v", err)
|
||||
}
|
||||
defer database.Close()
|
||||
|
||||
if err := database.Migrate(); err != nil {
|
||||
t.Fatalf("Failed to run migrations: %v", err)
|
||||
}
|
||||
|
||||
// Define test data
|
||||
user := &models.User{
|
||||
Email: "structquery@example.com",
|
||||
DisplayName: "Struct Query Test",
|
||||
PasswordHash: "hashed_password",
|
||||
Role: models.RoleEditor,
|
||||
}
|
||||
|
||||
t.Run("InsertStructQuery", func(t *testing.T) {
|
||||
// Insert user with struct query
|
||||
createdUser, err := database.CreateUser(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create user with struct query: %v", err)
|
||||
}
|
||||
|
||||
// Verify user was created with proper values
|
||||
if createdUser.ID == 0 {
|
||||
t.Error("Expected non-zero user ID")
|
||||
}
|
||||
|
||||
if createdUser.Email != user.Email {
|
||||
t.Errorf("Email = %v, want %v", createdUser.Email, user.Email)
|
||||
}
|
||||
|
||||
if createdUser.DisplayName != user.DisplayName {
|
||||
t.Errorf("DisplayName = %v, want %v", createdUser.DisplayName, user.DisplayName)
|
||||
}
|
||||
|
||||
if createdUser.Role != user.Role {
|
||||
t.Errorf("Role = %v, want %v", createdUser.Role, user.Role)
|
||||
}
|
||||
|
||||
// We will use this user for the next test cases
|
||||
user = createdUser
|
||||
})
|
||||
|
||||
t.Run("SelectStructQuery", func(t *testing.T) {
|
||||
// Get the created user
|
||||
fetchedUser, err := database.GetUserByID(user.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get user with struct query: %v", err)
|
||||
}
|
||||
|
||||
// Verify fetched user matches the original
|
||||
if fetchedUser.ID != user.ID {
|
||||
t.Errorf("ID = %v, want %v", fetchedUser.ID, user.ID)
|
||||
}
|
||||
|
||||
if fetchedUser.Email != user.Email {
|
||||
t.Errorf("Email = %v, want %v", fetchedUser.Email, user.Email)
|
||||
}
|
||||
|
||||
if fetchedUser.DisplayName != user.DisplayName {
|
||||
t.Errorf("DisplayName = %v, want %v", fetchedUser.DisplayName, user.DisplayName)
|
||||
}
|
||||
|
||||
if fetchedUser.Role != user.Role {
|
||||
t.Errorf("Role = %v, want %v", fetchedUser.Role, user.Role)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("UpdateStructQuery", func(t *testing.T) {
|
||||
// Update the user
|
||||
user.DisplayName = "Updated Display Name"
|
||||
user.Role = models.RoleAdmin
|
||||
|
||||
err := database.UpdateUser(user)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update user with struct query: %v", err)
|
||||
}
|
||||
|
||||
// Verify update worked
|
||||
updatedUser, err := database.GetUserByID(user.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get updated user: %v", err)
|
||||
}
|
||||
|
||||
if updatedUser.DisplayName != "Updated Display Name" {
|
||||
t.Errorf("DisplayName = %v, want %v", updatedUser.DisplayName, "Updated Display Name")
|
||||
}
|
||||
|
||||
if updatedUser.Role != models.RoleAdmin {
|
||||
t.Errorf("Role = %v, want %v", updatedUser.Role, models.RoleAdmin)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ScanStructs", func(t *testing.T) {
|
||||
// Create another user to test multiple rows
|
||||
secondUser := &models.User{
|
||||
Email: "structquery2@example.com",
|
||||
DisplayName: "Struct Query Test 2",
|
||||
PasswordHash: "hashed_password2",
|
||||
Role: models.RoleViewer,
|
||||
}
|
||||
|
||||
createdUser2, err := database.CreateUser(secondUser)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create second user: %v", err)
|
||||
}
|
||||
|
||||
// Get all users
|
||||
users, err := database.GetAllUsers()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get all users: %v", err)
|
||||
}
|
||||
|
||||
// Verify we have at least the two users we created
|
||||
if len(users) < 2 {
|
||||
t.Errorf("Expected at least 2 users, got %d", len(users))
|
||||
}
|
||||
|
||||
// Check if both our users are in the result
|
||||
foundUser1 := false
|
||||
foundUser2 := false
|
||||
|
||||
for _, u := range users {
|
||||
if u.ID == user.ID {
|
||||
foundUser1 = true
|
||||
if u.DisplayName != user.DisplayName {
|
||||
t.Errorf("DisplayName = %v, want %v", u.DisplayName, user.DisplayName)
|
||||
}
|
||||
}
|
||||
if u.ID == createdUser2.ID {
|
||||
foundUser2 = true
|
||||
if u.DisplayName != secondUser.DisplayName {
|
||||
t.Errorf("DisplayName = %v, want %v", u.DisplayName, secondUser.DisplayName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !foundUser1 {
|
||||
t.Errorf("First user (ID: %d) not found in results", user.ID)
|
||||
}
|
||||
if !foundUser2 {
|
||||
t.Errorf("Second user (ID: %d) not found in results", createdUser2.ID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ScanStruct with null values", func(t *testing.T) {
|
||||
// Test handling of NULL values by creating a workspace with null values
|
||||
workspace := &models.Workspace{
|
||||
UserID: user.ID,
|
||||
Name: "Null Test Workspace",
|
||||
// Leave all optional fields as zero values
|
||||
}
|
||||
workspace.SetDefaultSettings() // This will set default values
|
||||
|
||||
err := database.CreateWorkspace(workspace)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test workspace: %v", err)
|
||||
}
|
||||
|
||||
// Clear the GitToken to test NULL handling
|
||||
testDB := database.TestDB()
|
||||
_, err = testDB.Exec("UPDATE workspaces SET git_token = NULL WHERE id = ?", workspace.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set git_token to NULL: %v", err)
|
||||
}
|
||||
|
||||
// Fetch the workspace with NULL field
|
||||
fetchedWorkspace, err := database.GetWorkspaceByID(workspace.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get workspace with NULL field: %v", err)
|
||||
}
|
||||
|
||||
// Verify the NULL field is empty
|
||||
if fetchedWorkspace.GitToken != "" {
|
||||
t.Errorf("Expected empty GitToken, got '%s'", fetchedWorkspace.GitToken)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ScanStructErrors", func(t *testing.T) {
|
||||
// Test error handling in ScanStruct
|
||||
testDB := database.TestDB()
|
||||
|
||||
// Attempt to scan too many columns into a struct with fewer fields
|
||||
row := testDB.QueryRow("SELECT 1, 2, 3")
|
||||
var singleField struct {
|
||||
One int `db:"one"`
|
||||
}
|
||||
|
||||
err := database.ScanStruct(row, &singleField)
|
||||
if err == nil {
|
||||
t.Error("Expected error when scanning too many columns, got nil")
|
||||
}
|
||||
|
||||
// Test scanning into a non-struct
|
||||
var notAStruct int
|
||||
row = testDB.QueryRow("SELECT 1")
|
||||
err = database.ScanStruct(row, ¬AStruct)
|
||||
if err == nil {
|
||||
t.Error("Expected error when scanning into non-struct, got nil")
|
||||
}
|
||||
|
||||
// Test scanning into nil
|
||||
var nilPtr *struct{}
|
||||
row = testDB.QueryRow("SELECT 1")
|
||||
err = database.ScanStruct(row, nilPtr)
|
||||
if err == nil {
|
||||
t.Error("Expected error when scanning into nil pointer, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestScanStructsErrors tests error handling for ScanStructs
|
||||
func TestScanStructsErrors(t *testing.T) {
|
||||
database, err := db.NewTestSQLiteDB(&mockSecrets{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test database: %v", err)
|
||||
}
|
||||
defer database.Close()
|
||||
|
||||
if err := database.Migrate(); err != nil {
|
||||
t.Fatalf("Failed to run migrations: %v", err)
|
||||
}
|
||||
|
||||
testDB := database.TestDB()
|
||||
|
||||
t.Run("ScanStructsWithNilRows", func(t *testing.T) {
|
||||
var users []*models.User
|
||||
err := database.ScanStructs(nil, &users)
|
||||
if err == nil {
|
||||
t.Error("Expected error with nil rows, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ScanStructsWithNilDest", func(t *testing.T) {
|
||||
rows, err := testDB.Query("SELECT 1")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute query: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var nilSlice *[]*models.User
|
||||
err = database.ScanStructs(rows, nilSlice)
|
||||
if err == nil {
|
||||
t.Error("Expected error with nil destination, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ScanStructsWithNonSlice", func(t *testing.T) {
|
||||
rows, err := testDB.Query("SELECT 1")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute query: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var nonSlice int
|
||||
err = database.ScanStructs(rows, &nonSlice)
|
||||
if err == nil {
|
||||
t.Error("Expected error with non-slice destination, got nil")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("ScanStructsWithNonStructSlice", func(t *testing.T) {
|
||||
rows, err := testDB.Query("SELECT 1")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to execute query: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var intSlice []int
|
||||
err = database.ScanStructs(rows, &intSlice)
|
||||
if err == nil {
|
||||
t.Error("Expected error with non-struct slice, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestEncryptedFields tests handling of encrypted fields
|
||||
func TestEncryptedFields(t *testing.T) {
|
||||
database, err := db.NewTestSQLiteDB(&mockSecrets{})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test database: %v", err)
|
||||
}
|
||||
defer database.Close()
|
||||
|
||||
if err := database.Migrate(); err != nil {
|
||||
t.Fatalf("Failed to run migrations: %v", err)
|
||||
}
|
||||
|
||||
// Create user with workspace that has encrypted token
|
||||
user, err := database.CreateUser(&models.User{
|
||||
Email: "encrypted@example.com",
|
||||
DisplayName: "Encryption Test",
|
||||
PasswordHash: "hash",
|
||||
Role: models.RoleEditor,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test user: %v", err)
|
||||
}
|
||||
|
||||
// Create workspace with encrypted field
|
||||
workspace := &models.Workspace{
|
||||
UserID: user.ID,
|
||||
Name: "Encryption Test",
|
||||
Theme: "dark",
|
||||
GitEnabled: true,
|
||||
GitURL: "https://github.com/user/repo",
|
||||
GitUser: "username",
|
||||
GitToken: "secret-token", // This field is encrypted
|
||||
GitCommitName: "Test User",
|
||||
GitCommitEmail: "test@example.com",
|
||||
}
|
||||
|
||||
if err := database.CreateWorkspace(workspace); err != nil {
|
||||
t.Fatalf("Failed to create test workspace: %v", err)
|
||||
}
|
||||
|
||||
// Verify our mock secrets service passed the token through unmodified
|
||||
// In a real application, the token would be encrypted in the database
|
||||
testDB := database.TestDB()
|
||||
var rawToken string
|
||||
err = testDB.QueryRow("SELECT git_token FROM workspaces WHERE id = ?", workspace.ID).Scan(&rawToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to query raw token: %v", err)
|
||||
}
|
||||
|
||||
// With the mock secrets service, encryption is a no-op so the token is stored as-is
|
||||
if rawToken != "secret-token" {
|
||||
t.Errorf("Expected raw token 'secret-token', got '%s'", rawToken)
|
||||
}
|
||||
|
||||
// Verify the fetched workspace has the correct token
|
||||
fetchedWorkspace, err := database.GetWorkspaceByID(workspace.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get workspace: %v", err)
|
||||
}
|
||||
|
||||
if fetchedWorkspace.GitToken != "secret-token" {
|
||||
t.Errorf("Expected GitToken 'secret-token', got '%s'", fetchedWorkspace.GitToken)
|
||||
}
|
||||
}
|
||||
|
||||
// Helper function to compare slices of DBFields
|
||||
func compareDBFields(t *testing.T, got, want []db.DBField) {
|
||||
t.Helper()
|
||||
|
||||
if len(got) != len(want) {
|
||||
t.Errorf("Got %d fields, want %d", len(got), len(want))
|
||||
return
|
||||
}
|
||||
|
||||
for i := range got {
|
||||
if got[i].Name != want[i].Name {
|
||||
t.Errorf("Field %d name: got %s, want %s", i, got[i].Name, want[i].Name)
|
||||
}
|
||||
if got[i].OriginalName != want[i].OriginalName {
|
||||
t.Errorf("Field %d original name: got %s, want %s", i, got[i].OriginalName, want[i].OriginalName)
|
||||
}
|
||||
if !reflect.DeepEqual(got[i].Value, want[i].Value) {
|
||||
t.Errorf("Field %d value: got %v, want %v", i, got[i].Value, want[i].Value)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -49,7 +49,12 @@ func (db *database) EnsureJWTSecret() (string, error) {
|
||||
// GetSystemSetting retrieves a system setting by key
|
||||
func (db *database) GetSystemSetting(key string) (string, error) {
|
||||
var value string
|
||||
err := db.QueryRow("SELECT value FROM system_settings WHERE key = ?", key).Scan(&value)
|
||||
query := db.NewQuery().
|
||||
Select("value").
|
||||
From("system_settings").
|
||||
Where("key = ").
|
||||
Placeholder(key)
|
||||
err := db.QueryRow(query.String(), query.args...).Scan(&value)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -59,11 +64,14 @@ func (db *database) GetSystemSetting(key string) (string, error) {
|
||||
|
||||
// SetSystemSetting stores or updates a system setting
|
||||
func (db *database) SetSystemSetting(key, value string) error {
|
||||
_, err := db.Exec(`
|
||||
INSERT INTO system_settings (key, value)
|
||||
VALUES (?, ?)
|
||||
ON CONFLICT(key) DO UPDATE SET value = ?`,
|
||||
key, value, value)
|
||||
query := db.NewQuery().
|
||||
Insert("system_settings", "key", "value").
|
||||
Values(2).
|
||||
AddArgs(key, value).
|
||||
Write("ON CONFLICT(key) DO UPDATE SET value = ").
|
||||
Placeholder(value)
|
||||
|
||||
_, err := db.Exec(query.String(), query.args...)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store system setting: %w", err)
|
||||
@@ -92,22 +100,30 @@ func (db *database) GetSystemStats() (*UserStats, error) {
|
||||
stats := &UserStats{}
|
||||
|
||||
// Get total users
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&stats.TotalUsers)
|
||||
query := db.NewQuery().
|
||||
Select("COUNT(*)").
|
||||
From("users")
|
||||
err := db.QueryRow(query.String()).Scan(&stats.TotalUsers)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get total users count: %w", err)
|
||||
}
|
||||
|
||||
// Get total workspaces
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM workspaces").Scan(&stats.TotalWorkspaces)
|
||||
query = db.NewQuery().
|
||||
Select("COUNT(*)").
|
||||
From("workspaces")
|
||||
err = db.QueryRow(query.String()).Scan(&stats.TotalWorkspaces)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get total workspaces count: %w", err)
|
||||
}
|
||||
|
||||
// Get active users (users with activity in last 30 days)
|
||||
err = db.QueryRow(`
|
||||
SELECT COUNT(DISTINCT user_id)
|
||||
FROM sessions
|
||||
WHERE created_at > datetime('now', '-30 days')`).
|
||||
query = db.NewQuery().
|
||||
Select("COUNT(DISTINCT user_id)").
|
||||
From("sessions").
|
||||
Where("created_at >").
|
||||
TimeSince(30)
|
||||
err = db.QueryRow(query.String()).
|
||||
Scan(&stats.ActiveUsers)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get active users count: %w", err)
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
)
|
||||
|
||||
func TestSystemOperations(t *testing.T) {
|
||||
database, err := db.NewTestDB(":memory:", &mockSecrets{})
|
||||
database, err := db.NewTestSQLiteDB(&mockSecrets{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test database: %v", err)
|
||||
}
|
||||
|
||||
@@ -4,7 +4,11 @@ package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"lemma/internal/secrets"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type TestDatabase interface {
|
||||
@@ -12,19 +16,102 @@ type TestDatabase interface {
|
||||
TestDB() *sql.DB
|
||||
}
|
||||
|
||||
func NewTestDB(dbPath string, secretsService secrets.Service) (TestDatabase, error) {
|
||||
db, err := Init(dbPath, secretsService)
|
||||
func NewTestSQLiteDB(secretsService secrets.Service) (TestDatabase, error) {
|
||||
db, err := Init(DBTypeSQLite, ":memory:", secretsService)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &testDatabase{db.(*database)}, nil
|
||||
return &testSQLiteDatabase{db.(*database)}, nil
|
||||
}
|
||||
|
||||
type testDatabase struct {
|
||||
type testSQLiteDatabase struct {
|
||||
*database
|
||||
}
|
||||
|
||||
func (td *testDatabase) TestDB() *sql.DB {
|
||||
func (td *testSQLiteDatabase) TestDB() *sql.DB {
|
||||
return td.DB
|
||||
}
|
||||
|
||||
// NewPostgresTestDB creates a test database using PostgreSQL
|
||||
func NewPostgresTestDB(dbURL string, secretsSvc secrets.Service) (TestDatabase, error) {
|
||||
if dbURL == "" {
|
||||
return nil, fmt.Errorf("postgres URL cannot be empty")
|
||||
}
|
||||
|
||||
initialDB, err := sql.Open("postgres", dbURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open postgres database: %w", err)
|
||||
}
|
||||
|
||||
if err := initialDB.Ping(); err != nil {
|
||||
initialDB.Close()
|
||||
return nil, fmt.Errorf("failed to ping postgres database: %w", err)
|
||||
}
|
||||
|
||||
// Create a unique schema name for this test run to avoid conflicts
|
||||
schemaName := fmt.Sprintf("lemma_test_%d", time.Now().UnixNano())
|
||||
_, err = initialDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schemaName))
|
||||
if err != nil {
|
||||
initialDB.Close()
|
||||
return nil, fmt.Errorf("failed to create schema: %w", err)
|
||||
}
|
||||
|
||||
// Close the initial connection and create a new one with the schema set
|
||||
initialDB.Close()
|
||||
|
||||
var newDBURL string
|
||||
if strings.Contains(dbURL, "?") {
|
||||
// URL already has parameters
|
||||
newDBURL = fmt.Sprintf("%s&search_path=%s", dbURL, schemaName)
|
||||
} else {
|
||||
// URL has no parameters yet
|
||||
newDBURL = fmt.Sprintf("%s?search_path=%s", dbURL, schemaName)
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", newDBURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open postgres database: %w", err)
|
||||
}
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("failed to ping postgres database: %w", err)
|
||||
}
|
||||
|
||||
// Set search path to use our schema
|
||||
_, err = db.Exec(fmt.Sprintf("SET search_path TO %s", schemaName))
|
||||
if err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("failed to set search path: %w", err)
|
||||
}
|
||||
|
||||
// Create database instance
|
||||
database := &postgresTestDatabase{
|
||||
database: &database{DB: db, secretsService: secretsSvc, dbType: DBTypePostgres},
|
||||
schemaName: schemaName,
|
||||
}
|
||||
|
||||
return database, nil
|
||||
}
|
||||
|
||||
// postgresTestDatabase extends the regular postgres database to add test-specific cleanup
|
||||
type postgresTestDatabase struct {
|
||||
*database
|
||||
schemaName string
|
||||
}
|
||||
|
||||
// Close closes the database connection and drops the test schema
|
||||
func (db *postgresTestDatabase) Close() error {
|
||||
_, err := db.TestDB().Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", db.schemaName))
|
||||
if err != nil {
|
||||
log.Printf("Failed to drop schema %s: %v", db.schemaName, err)
|
||||
}
|
||||
|
||||
return db.TestDB().Close()
|
||||
}
|
||||
|
||||
// TestDB returns the underlying *sql.DB instance
|
||||
func (db *postgresTestDatabase) TestDB() *sql.DB {
|
||||
return db.DB
|
||||
}
|
||||
|
||||
@@ -17,26 +17,21 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) {
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
result, err := tx.Exec(`
|
||||
INSERT INTO users (email, display_name, password_hash, role)
|
||||
VALUES (?, ?, ?, ?)`,
|
||||
user.Email, user.DisplayName, user.PasswordHash, user.Role)
|
||||
query, err := db.NewQuery().
|
||||
InsertStruct(user, "users")
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create query: %w", err)
|
||||
}
|
||||
|
||||
query.Returning("id", "created_at")
|
||||
|
||||
err = tx.QueryRow(query.String(), query.Args()...).
|
||||
Scan(&user.ID, &user.CreatedAt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to insert user: %w", err)
|
||||
}
|
||||
|
||||
userID, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get last insert ID: %w", err)
|
||||
}
|
||||
user.ID = int(userID)
|
||||
|
||||
// Retrieve the created_at timestamp
|
||||
err = tx.QueryRow("SELECT created_at FROM users WHERE id = ?", user.ID).Scan(&user.CreatedAt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get created timestamp: %w", err)
|
||||
}
|
||||
|
||||
// Create default workspace with default settings
|
||||
defaultWorkspace := &models.Workspace{
|
||||
UserID: user.ID,
|
||||
@@ -51,7 +46,13 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) {
|
||||
}
|
||||
|
||||
// Update user's last workspace ID
|
||||
_, err = tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?", defaultWorkspace.ID, user.ID)
|
||||
query = db.NewQuery().
|
||||
Update("users").
|
||||
Set("last_workspace_id").
|
||||
Placeholder(defaultWorkspace.ID).
|
||||
Where("id = ").
|
||||
Placeholder(user.ID)
|
||||
_, err = tx.Exec(query.String(), query.Args()...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update last workspace ID: %w", err)
|
||||
}
|
||||
@@ -70,47 +71,39 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) {
|
||||
// Helper function to create a workspace in a transaction
|
||||
func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
|
||||
log := getLogger().WithGroup("users")
|
||||
result, err := tx.Exec(`
|
||||
INSERT INTO workspaces (
|
||||
user_id, name,
|
||||
theme, auto_save, show_hidden_files,
|
||||
git_enabled, git_url, git_user, git_token,
|
||||
git_auto_commit, git_commit_msg_template,
|
||||
git_commit_name, git_commit_email
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
workspace.UserID, workspace.Name,
|
||||
workspace.Theme, workspace.AutoSave, workspace.ShowHiddenFiles,
|
||||
workspace.GitEnabled, workspace.GitURL, workspace.GitUser, workspace.GitToken,
|
||||
workspace.GitAutoCommit, workspace.GitCommitMsgTemplate,
|
||||
workspace.GitCommitName, workspace.GitCommitEmail,
|
||||
)
|
||||
|
||||
insertQuery, err := db.NewQuery().
|
||||
InsertStruct(workspace, "workspaces")
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create query: %w", err)
|
||||
}
|
||||
|
||||
insertQuery.Returning("id")
|
||||
|
||||
err = tx.QueryRow(insertQuery.String(), insertQuery.Args()...).Scan(&workspace.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to insert workspace: %w", err)
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get workspace ID: %w", err)
|
||||
}
|
||||
workspace.ID = int(id)
|
||||
|
||||
log.Debug("created user workspace",
|
||||
"workspace_id", workspace.ID,
|
||||
"user_id", workspace.UserID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserByID retrieves a user by its ID
|
||||
func (db *database) GetUserByID(id int) (*models.User, error) {
|
||||
user := &models.User{}
|
||||
err := db.QueryRow(`
|
||||
SELECT
|
||||
id, email, display_name, password_hash, role, created_at,
|
||||
last_workspace_id
|
||||
FROM users
|
||||
WHERE id = ?`, id).
|
||||
Scan(&user.ID, &user.Email, &user.DisplayName, &user.PasswordHash,
|
||||
&user.Role, &user.CreatedAt, &user.LastWorkspaceID)
|
||||
query := db.NewQuery()
|
||||
query, err := query.SelectStruct(user, "users")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create query: %w", err)
|
||||
}
|
||||
|
||||
query = query.Where("id = ").Placeholder(id)
|
||||
row := db.QueryRow(query.String(), query.Args()...)
|
||||
err = db.ScanStruct(row, user)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("user not found")
|
||||
}
|
||||
@@ -120,16 +113,18 @@ func (db *database) GetUserByID(id int) (*models.User, error) {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// GetUserByEmail retrieves a user by its email
|
||||
func (db *database) GetUserByEmail(email string) (*models.User, error) {
|
||||
user := &models.User{}
|
||||
err := db.QueryRow(`
|
||||
SELECT
|
||||
id, email, display_name, password_hash, role, created_at,
|
||||
last_workspace_id
|
||||
FROM users
|
||||
WHERE email = ?`, email).
|
||||
Scan(&user.ID, &user.Email, &user.DisplayName, &user.PasswordHash,
|
||||
&user.Role, &user.CreatedAt, &user.LastWorkspaceID)
|
||||
query := db.NewQuery()
|
||||
query, err := query.SelectStruct(user, "users")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create query: %w", err)
|
||||
}
|
||||
|
||||
query = query.Where("email = ").Placeholder(email)
|
||||
row := db.QueryRow(query.String(), query.Args()...)
|
||||
err = db.ScanStruct(row, user)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("user not found")
|
||||
@@ -141,14 +136,16 @@ func (db *database) GetUserByEmail(email string) (*models.User, error) {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// UpdateUser updates an existing user record in the database
|
||||
func (db *database) UpdateUser(user *models.User) error {
|
||||
result, err := db.Exec(`
|
||||
UPDATE users
|
||||
SET email = ?, display_name = ?, password_hash = ?, role = ?, last_workspace_id = ?
|
||||
WHERE id = ?`,
|
||||
user.Email, user.DisplayName, user.PasswordHash, user.Role,
|
||||
user.LastWorkspaceID, user.ID)
|
||||
query := db.NewQuery()
|
||||
query, err := query.UpdateStruct(user, "users")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create query: %w", err)
|
||||
}
|
||||
query = query.Where("id = ").Placeholder(user.ID)
|
||||
|
||||
result, err := db.Exec(query.String(), query.Args()...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update user: %w", err)
|
||||
}
|
||||
@@ -165,29 +162,25 @@ func (db *database) UpdateUser(user *models.User) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllUsers retrieves all users from the database
|
||||
func (db *database) GetAllUsers() ([]*models.User, error) {
|
||||
rows, err := db.Query(`
|
||||
SELECT
|
||||
id, email, display_name, role, created_at,
|
||||
last_workspace_id
|
||||
FROM users
|
||||
ORDER BY id ASC`)
|
||||
query := db.NewQuery()
|
||||
query, err := query.SelectStruct(&models.User{}, "users")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create query: %w", err)
|
||||
}
|
||||
query = query.OrderBy("id ASC")
|
||||
|
||||
rows, err := db.Query(query.String(), query.Args()...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query users: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var users []*models.User
|
||||
for rows.Next() {
|
||||
user := &models.User{}
|
||||
err := rows.Scan(
|
||||
&user.ID, &user.Email, &user.DisplayName, &user.Role,
|
||||
&user.CreatedAt, &user.LastWorkspaceID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan user row: %w", err)
|
||||
}
|
||||
users = append(users, user)
|
||||
users := []*models.User{}
|
||||
err = db.ScanStructs(rows, &users)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan users: %w", err)
|
||||
}
|
||||
|
||||
return users, nil
|
||||
@@ -200,15 +193,26 @@ func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Find workspace ID from name
|
||||
workspaceQuery := db.NewQuery().
|
||||
Select("id").
|
||||
From("workspaces").
|
||||
Where("user_id = ").Placeholder(userID).
|
||||
And("name = ").Placeholder(workspaceName)
|
||||
|
||||
var workspaceID int
|
||||
err = tx.QueryRow("SELECT id FROM workspaces WHERE user_id = ? AND name = ?",
|
||||
userID, workspaceName).Scan(&workspaceID)
|
||||
err = tx.QueryRow(workspaceQuery.String(), workspaceQuery.Args()...).Scan(&workspaceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find workspace: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?",
|
||||
workspaceID, userID)
|
||||
// Update user's last workspace
|
||||
updateQuery := db.NewQuery().
|
||||
Update("users").
|
||||
Set("last_workspace_id").Placeholder(workspaceID).
|
||||
Where("id = ").Placeholder(userID)
|
||||
|
||||
_, err = tx.Exec(updateQuery.String(), updateQuery.Args()...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update last workspace: %w", err)
|
||||
}
|
||||
@@ -221,6 +225,7 @@ func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteUser deletes a user and all their workspaces
|
||||
func (db *database) DeleteUser(id int) error {
|
||||
log := getLogger().WithGroup("users")
|
||||
log.Debug("deleting user", "user_id", id)
|
||||
@@ -233,13 +238,24 @@ func (db *database) DeleteUser(id int) error {
|
||||
|
||||
// Delete all user's workspaces first
|
||||
log.Debug("deleting user workspaces", "user_id", id)
|
||||
_, err = tx.Exec("DELETE FROM workspaces WHERE user_id = ?", id)
|
||||
|
||||
deleteWorkspacesQuery := db.NewQuery().
|
||||
Delete().
|
||||
From("workspaces").
|
||||
Where("user_id = ").Placeholder(id)
|
||||
|
||||
_, err = tx.Exec(deleteWorkspacesQuery.String(), deleteWorkspacesQuery.Args()...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete workspaces: %w", err)
|
||||
}
|
||||
|
||||
// Delete the user
|
||||
_, err = tx.Exec("DELETE FROM users WHERE id = ?", id)
|
||||
deleteUserQuery := db.NewQuery().
|
||||
Delete().
|
||||
From("users").
|
||||
Where("id = ").Placeholder(id)
|
||||
|
||||
_, err = tx.Exec(deleteUserQuery.String(), deleteUserQuery.Args()...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete user: %w", err)
|
||||
}
|
||||
@@ -253,15 +269,16 @@ func (db *database) DeleteUser(id int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetLastWorkspaceName retrieves the name of the last workspace accessed by a user
|
||||
func (db *database) GetLastWorkspaceName(userID int) (string, error) {
|
||||
query := db.NewQuery().
|
||||
Select("w.name").
|
||||
From("workspaces w").
|
||||
Join(InnerJoin, "users u", "u.last_workspace_id = w.id").
|
||||
Where("u.id = ").Placeholder(userID)
|
||||
|
||||
var workspaceName string
|
||||
err := db.QueryRow(`
|
||||
SELECT
|
||||
w.name
|
||||
FROM workspaces w
|
||||
JOIN users u ON u.last_workspace_id = w.id
|
||||
WHERE u.id = ?`, userID).
|
||||
Scan(&workspaceName)
|
||||
err := db.QueryRow(query.String(), query.Args()...).Scan(&workspaceName)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return "", fmt.Errorf("no last workspace found")
|
||||
@@ -275,8 +292,13 @@ func (db *database) GetLastWorkspaceName(userID int) (string, error) {
|
||||
|
||||
// CountAdminUsers returns the number of admin users in the system
|
||||
func (db *database) CountAdminUsers() (int, error) {
|
||||
query := db.NewQuery().
|
||||
Select("COUNT(*)").
|
||||
From("users").
|
||||
Where("role = ").Placeholder(models.RoleAdmin)
|
||||
|
||||
var count int
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM users WHERE role = 'admin'").Scan(&count)
|
||||
err := db.QueryRow(query.String(), query.Args()...).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to count admin users: %w", err)
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
func TestUserOperations(t *testing.T) {
|
||||
database, err := db.NewTestDB(":memory:", &mockSecrets{})
|
||||
database, err := db.NewTestSQLiteDB(&mockSecrets{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test database: %v", err)
|
||||
}
|
||||
|
||||
@@ -19,58 +19,36 @@ func (db *database) CreateWorkspace(workspace *models.Workspace) error {
|
||||
workspace.SetDefaultSettings()
|
||||
}
|
||||
|
||||
// Encrypt token if present
|
||||
encryptedToken, err := db.encryptToken(workspace.GitToken)
|
||||
query, err := db.NewQuery().
|
||||
InsertStruct(workspace, "workspaces")
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt token: %w", err)
|
||||
return fmt.Errorf("failed to create query: %w", err)
|
||||
}
|
||||
|
||||
result, err := db.Exec(`
|
||||
INSERT INTO workspaces (
|
||||
user_id, name, theme, auto_save, show_hidden_files,
|
||||
git_enabled, git_url, git_user, git_token,
|
||||
git_auto_commit, git_commit_msg_template,
|
||||
git_commit_name, git_commit_email
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
workspace.UserID, workspace.Name, workspace.Theme, workspace.AutoSave, workspace.ShowHiddenFiles,
|
||||
workspace.GitEnabled, workspace.GitURL, workspace.GitUser, encryptedToken,
|
||||
workspace.GitAutoCommit, workspace.GitCommitMsgTemplate, workspace.GitCommitName, workspace.GitCommitEmail,
|
||||
)
|
||||
query.Returning("id", "created_at")
|
||||
|
||||
err = db.QueryRow(query.String(), query.Args()...).
|
||||
Scan(&workspace.ID, &workspace.CreatedAt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to insert workspace: %w", err)
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get workspace ID: %w", err)
|
||||
}
|
||||
workspace.ID = int(id)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetWorkspaceByID retrieves a workspace by its ID
|
||||
func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) {
|
||||
workspace := &models.Workspace{}
|
||||
var encryptedToken string
|
||||
query := db.NewQuery()
|
||||
query, err := query.SelectStruct(workspace, "workspaces")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create query: %w", err)
|
||||
}
|
||||
query = query.Where("id = ").Placeholder(id)
|
||||
|
||||
err := db.QueryRow(`
|
||||
SELECT
|
||||
id, user_id, name, created_at,
|
||||
theme, auto_save, show_hidden_files,
|
||||
git_enabled, git_url, git_user, git_token,
|
||||
git_auto_commit, git_commit_msg_template,
|
||||
git_commit_name, git_commit_email
|
||||
FROM workspaces
|
||||
WHERE id = ?`,
|
||||
id,
|
||||
).Scan(
|
||||
&workspace.ID, &workspace.UserID, &workspace.Name, &workspace.CreatedAt,
|
||||
&workspace.Theme, &workspace.AutoSave, &workspace.ShowHiddenFiles,
|
||||
&workspace.GitEnabled, &workspace.GitURL, &workspace.GitUser, &encryptedToken,
|
||||
&workspace.GitAutoCommit, &workspace.GitCommitMsgTemplate,
|
||||
&workspace.GitCommitName, &workspace.GitCommitEmail,
|
||||
)
|
||||
row := db.QueryRow(query.String(), query.Args()...)
|
||||
err = db.ScanStruct(row, workspace)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("workspace not found")
|
||||
@@ -79,37 +57,22 @@ func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) {
|
||||
return nil, fmt.Errorf("failed to fetch workspace: %w", err)
|
||||
}
|
||||
|
||||
// Decrypt token
|
||||
workspace.GitToken, err = db.decryptToken(encryptedToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt token: %w", err)
|
||||
}
|
||||
|
||||
return workspace, nil
|
||||
}
|
||||
|
||||
// GetWorkspaceByName retrieves a workspace by its name and user ID
|
||||
func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) {
|
||||
workspace := &models.Workspace{}
|
||||
var encryptedToken string
|
||||
query := db.NewQuery()
|
||||
query, err := query.SelectStruct(workspace, "workspaces")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create query: %w", err)
|
||||
}
|
||||
query = query.Where("user_id = ").Placeholder(userID).
|
||||
And("name = ").Placeholder(workspaceName)
|
||||
|
||||
err := db.QueryRow(`
|
||||
SELECT
|
||||
id, user_id, name, created_at,
|
||||
theme, auto_save, show_hidden_files,
|
||||
git_enabled, git_url, git_user, git_token,
|
||||
git_auto_commit, git_commit_msg_template,
|
||||
git_commit_name, git_commit_email
|
||||
FROM workspaces
|
||||
WHERE user_id = ? AND name = ?`,
|
||||
userID, workspaceName,
|
||||
).Scan(
|
||||
&workspace.ID, &workspace.UserID, &workspace.Name, &workspace.CreatedAt,
|
||||
&workspace.Theme, &workspace.AutoSave, &workspace.ShowHiddenFiles,
|
||||
&workspace.GitEnabled, &workspace.GitURL, &workspace.GitUser, &encryptedToken,
|
||||
&workspace.GitAutoCommit, &workspace.GitCommitMsgTemplate,
|
||||
&workspace.GitCommitName, &workspace.GitCommitEmail,
|
||||
)
|
||||
row := db.QueryRow(query.String(), query.Args()...)
|
||||
err = db.ScanStruct(row, workspace)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("workspace not found")
|
||||
@@ -118,54 +81,22 @@ func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*model
|
||||
return nil, fmt.Errorf("failed to fetch workspace: %w", err)
|
||||
}
|
||||
|
||||
// Decrypt token
|
||||
workspace.GitToken, err = db.decryptToken(encryptedToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt token: %w", err)
|
||||
}
|
||||
|
||||
return workspace, nil
|
||||
}
|
||||
|
||||
// UpdateWorkspace updates a workspace record in the database
|
||||
func (db *database) UpdateWorkspace(workspace *models.Workspace) error {
|
||||
// Encrypt token before storing
|
||||
encryptedToken, err := db.encryptToken(workspace.GitToken)
|
||||
|
||||
query := db.NewQuery()
|
||||
query, err := query.
|
||||
UpdateStruct(workspace, "workspaces")
|
||||
query = query.Where("id =").Placeholder(workspace.ID).And("user_id =").Placeholder(workspace.UserID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt token: %w", err)
|
||||
return fmt.Errorf("failed to create query: %w", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(`
|
||||
UPDATE workspaces
|
||||
SET
|
||||
name = ?,
|
||||
theme = ?,
|
||||
auto_save = ?,
|
||||
show_hidden_files = ?,
|
||||
git_enabled = ?,
|
||||
git_url = ?,
|
||||
git_user = ?,
|
||||
git_token = ?,
|
||||
git_auto_commit = ?,
|
||||
git_commit_msg_template = ?,
|
||||
git_commit_name = ?,
|
||||
git_commit_email = ?
|
||||
WHERE id = ? AND user_id = ?`,
|
||||
workspace.Name,
|
||||
workspace.Theme,
|
||||
workspace.AutoSave,
|
||||
workspace.ShowHiddenFiles,
|
||||
workspace.GitEnabled,
|
||||
workspace.GitURL,
|
||||
workspace.GitUser,
|
||||
encryptedToken,
|
||||
workspace.GitAutoCommit,
|
||||
workspace.GitCommitMsgTemplate,
|
||||
workspace.GitCommitName,
|
||||
workspace.GitCommitEmail,
|
||||
workspace.ID,
|
||||
workspace.UserID,
|
||||
)
|
||||
_, err = db.Exec(query.String(), query.Args()...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update workspace: %w", err)
|
||||
}
|
||||
@@ -175,48 +106,24 @@ func (db *database) UpdateWorkspace(workspace *models.Workspace) error {
|
||||
|
||||
// GetWorkspacesByUserID retrieves all workspaces for a user
|
||||
func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
|
||||
rows, err := db.Query(`
|
||||
SELECT
|
||||
id, user_id, name, created_at,
|
||||
theme, auto_save, show_hidden_files,
|
||||
git_enabled, git_url, git_user, git_token,
|
||||
git_auto_commit, git_commit_msg_template,
|
||||
git_commit_name, git_commit_email
|
||||
FROM workspaces
|
||||
WHERE user_id = ?`,
|
||||
userID,
|
||||
)
|
||||
workspace := &models.Workspace{}
|
||||
query := db.NewQuery()
|
||||
query, err := query.SelectStruct(workspace, "workspaces")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create query: %w", err)
|
||||
}
|
||||
query = query.Where("user_id = ").Placeholder(userID)
|
||||
|
||||
rows, err := db.Query(query.String(), query.Args()...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query workspaces: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var workspaces []*models.Workspace
|
||||
for rows.Next() {
|
||||
workspace := &models.Workspace{}
|
||||
var encryptedToken string
|
||||
err := rows.Scan(
|
||||
&workspace.ID, &workspace.UserID, &workspace.Name, &workspace.CreatedAt,
|
||||
&workspace.Theme, &workspace.AutoSave, &workspace.ShowHiddenFiles,
|
||||
&workspace.GitEnabled, &workspace.GitURL, &workspace.GitUser, &encryptedToken,
|
||||
&workspace.GitAutoCommit, &workspace.GitCommitMsgTemplate,
|
||||
&workspace.GitCommitName, &workspace.GitCommitEmail,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan workspace row: %w", err)
|
||||
}
|
||||
|
||||
// Decrypt token
|
||||
workspace.GitToken, err = db.decryptToken(encryptedToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt token: %w", err)
|
||||
}
|
||||
|
||||
workspaces = append(workspaces, workspace)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating workspace rows: %w", err)
|
||||
err = db.ScanStructs(rows, &workspaces)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan workspaces: %w", err)
|
||||
}
|
||||
|
||||
return workspaces, nil
|
||||
@@ -224,34 +131,17 @@ func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, erro
|
||||
|
||||
// UpdateWorkspaceSettings updates only the settings portion of a workspace
|
||||
func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error {
|
||||
_, err := db.Exec(`
|
||||
UPDATE workspaces
|
||||
SET
|
||||
theme = ?,
|
||||
auto_save = ?,
|
||||
show_hidden_files = ?,
|
||||
git_enabled = ?,
|
||||
git_url = ?,
|
||||
git_user = ?,
|
||||
git_token = ?,
|
||||
git_auto_commit = ?,
|
||||
git_commit_msg_template = ?,
|
||||
git_commit_name = ?,
|
||||
git_commit_email = ?
|
||||
WHERE id = ?`,
|
||||
workspace.Theme,
|
||||
workspace.AutoSave,
|
||||
workspace.ShowHiddenFiles,
|
||||
workspace.GitEnabled,
|
||||
workspace.GitURL,
|
||||
workspace.GitUser,
|
||||
workspace.GitToken,
|
||||
workspace.GitAutoCommit,
|
||||
workspace.GitCommitMsgTemplate,
|
||||
workspace.GitCommitName,
|
||||
workspace.GitCommitEmail,
|
||||
workspace.ID,
|
||||
)
|
||||
|
||||
query := db.NewQuery()
|
||||
query, err := query.
|
||||
UpdateStruct(workspace, "workspaces")
|
||||
query = query.Where("id =").Placeholder(workspace.ID)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create query: %w", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(query.String(), query.Args()...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update workspace settings: %w", err)
|
||||
}
|
||||
@@ -263,7 +153,12 @@ func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error {
|
||||
func (db *database) DeleteWorkspace(id int) error {
|
||||
log := getLogger().WithGroup("workspaces")
|
||||
|
||||
_, err := db.Exec("DELETE FROM workspaces WHERE id = ?", id)
|
||||
query := db.NewQuery().
|
||||
Delete().
|
||||
From("workspaces").
|
||||
Where("id = ").Placeholder(id)
|
||||
|
||||
_, err := db.Exec(query.String(), query.Args()...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete workspace: %w", err)
|
||||
}
|
||||
@@ -275,7 +170,13 @@ func (db *database) DeleteWorkspace(id int) error {
|
||||
// DeleteWorkspaceTx removes a workspace record from the database within a transaction
|
||||
func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
|
||||
log := getLogger().WithGroup("workspaces")
|
||||
result, err := tx.Exec("DELETE FROM workspaces WHERE id = ?", id)
|
||||
|
||||
query := db.NewQuery().
|
||||
Delete().
|
||||
From("workspaces").
|
||||
Where("id = ").Placeholder(id)
|
||||
|
||||
result, err := tx.Exec(query.String(), query.Args()...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete workspace in transaction: %w", err)
|
||||
}
|
||||
@@ -285,15 +186,18 @@ func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
|
||||
return fmt.Errorf("failed to get rows affected in transaction: %w", err)
|
||||
}
|
||||
|
||||
log.Debug("workspace deleted",
|
||||
"workspace_id", id)
|
||||
log.Debug("workspace deleted", "workspace_id", id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateLastWorkspaceTx sets the last workspace for a user in a transaction
|
||||
func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error {
|
||||
result, err := tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?",
|
||||
workspaceID, userID)
|
||||
query := db.NewQuery().
|
||||
Update("users").
|
||||
Set("last_workspace_id").Placeholder(workspaceID).
|
||||
Where("id = ").Placeholder(userID)
|
||||
|
||||
result, err := tx.Exec(query.String(), query.Args()...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update last workspace in transaction: %w", err)
|
||||
}
|
||||
@@ -308,8 +212,12 @@ func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) e
|
||||
|
||||
// UpdateLastOpenedFile updates the last opened file path for a workspace
|
||||
func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error {
|
||||
_, err := db.Exec("UPDATE workspaces SET last_opened_file_path = ? WHERE id = ?",
|
||||
filePath, workspaceID)
|
||||
query := db.NewQuery().
|
||||
Update("workspaces").
|
||||
Set("last_opened_file_path").Placeholder(filePath).
|
||||
Where("id = ").Placeholder(workspaceID)
|
||||
|
||||
_, err := db.Exec(query.String(), query.Args()...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update last opened file: %w", err)
|
||||
}
|
||||
@@ -319,9 +227,13 @@ func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error
|
||||
|
||||
// GetLastOpenedFile retrieves the last opened file path for a workspace
|
||||
func (db *database) GetLastOpenedFile(workspaceID int) (string, error) {
|
||||
query := db.NewQuery().
|
||||
Select("last_opened_file_path").
|
||||
From("workspaces").
|
||||
Where("id = ").Placeholder(workspaceID)
|
||||
|
||||
var filePath sql.NullString
|
||||
err := db.QueryRow("SELECT last_opened_file_path FROM workspaces WHERE id = ?",
|
||||
workspaceID).Scan(&filePath)
|
||||
err := db.QueryRow(query.String(), query.Args()...).Scan(&filePath)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return "", fmt.Errorf("workspace not found")
|
||||
@@ -339,46 +251,22 @@ func (db *database) GetLastOpenedFile(workspaceID int) (string, error) {
|
||||
|
||||
// GetAllWorkspaces retrieves all workspaces in the database
|
||||
func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) {
|
||||
rows, err := db.Query(`
|
||||
SELECT
|
||||
id, user_id, name, created_at,
|
||||
theme, auto_save, show_hidden_files,
|
||||
git_enabled, git_url, git_user, git_token,
|
||||
git_auto_commit, git_commit_msg_template,
|
||||
git_commit_name, git_commit_email
|
||||
FROM workspaces`,
|
||||
)
|
||||
query := db.NewQuery()
|
||||
query, err := query.SelectStruct(&models.Workspace{}, "workspaces")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create query: %w", err)
|
||||
}
|
||||
|
||||
rows, err := db.Query(query.String(), query.Args()...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query workspaces: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var workspaces []*models.Workspace
|
||||
for rows.Next() {
|
||||
workspace := &models.Workspace{}
|
||||
var encryptedToken string
|
||||
err := rows.Scan(
|
||||
&workspace.ID, &workspace.UserID, &workspace.Name, &workspace.CreatedAt,
|
||||
&workspace.Theme, &workspace.AutoSave, &workspace.ShowHiddenFiles,
|
||||
&workspace.GitEnabled, &workspace.GitURL, &workspace.GitUser, &encryptedToken,
|
||||
&workspace.GitAutoCommit, &workspace.GitCommitMsgTemplate,
|
||||
&workspace.GitCommitName, &workspace.GitCommitEmail,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan workspace row: %w", err)
|
||||
}
|
||||
|
||||
// Decrypt token
|
||||
workspace.GitToken, err = db.decryptToken(encryptedToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt token: %w", err)
|
||||
}
|
||||
|
||||
workspaces = append(workspaces, workspace)
|
||||
}
|
||||
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("error iterating workspace rows: %w", err)
|
||||
err = db.ScanStructs(rows, &workspaces)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan workspaces: %w", err)
|
||||
}
|
||||
|
||||
return workspaces, nil
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
func TestWorkspaceOperations(t *testing.T) {
|
||||
database, err := db.NewTestDB(":memory:", &mockSecrets{})
|
||||
database, err := db.NewTestSQLiteDB(&mockSecrets{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create test database: %v", err)
|
||||
}
|
||||
|
||||
@@ -308,7 +308,7 @@ func (h *Handler) AdminUpdateUser() http.HandlerFunc {
|
||||
}
|
||||
|
||||
// Track what's being updated for logging
|
||||
updates := make(map[string]interface{})
|
||||
updates := make(map[string]any)
|
||||
|
||||
if req.Email != "" {
|
||||
user.Email = req.Email
|
||||
|
||||
@@ -15,21 +15,12 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// Helper function to check if a user exists in a slice of users
|
||||
func containsUser(users []*models.User, searchUser *models.User) bool {
|
||||
for _, u := range users {
|
||||
if u.ID == searchUser.ID &&
|
||||
u.Email == searchUser.Email &&
|
||||
u.DisplayName == searchUser.DisplayName &&
|
||||
u.Role == searchUser.Role {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
func TestAdminHandlers_Integration(t *testing.T) {
|
||||
runWithDatabases(t, testAdminHandlers)
|
||||
}
|
||||
|
||||
func TestAdminHandlers_Integration(t *testing.T) {
|
||||
h := setupTestHarness(t)
|
||||
func testAdminHandlers(t *testing.T, dbConfig DatabaseConfig) {
|
||||
h := setupTestHarness(t, dbConfig)
|
||||
defer h.teardown(t)
|
||||
|
||||
t.Run("user management", func(t *testing.T) {
|
||||
@@ -241,3 +232,16 @@ func TestAdminHandlers_Integration(t *testing.T) {
|
||||
assert.Equal(t, http.StatusForbidden, rr.Code)
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function to check if a user exists in a slice of users
|
||||
func containsUser(users []*models.User, searchUser *models.User) bool {
|
||||
for _, u := range users {
|
||||
if u.ID == searchUser.ID &&
|
||||
u.Email == searchUser.Email &&
|
||||
u.DisplayName == searchUser.DisplayName &&
|
||||
u.Role == searchUser.Role {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -19,7 +19,11 @@ import (
|
||||
)
|
||||
|
||||
func TestAuthHandlers_Integration(t *testing.T) {
|
||||
h := setupTestHarness(t)
|
||||
runWithDatabases(t, testAuthHandlers)
|
||||
}
|
||||
|
||||
func testAuthHandlers(t *testing.T, dbConfig DatabaseConfig) {
|
||||
h := setupTestHarness(t, dbConfig)
|
||||
defer h.teardown(t)
|
||||
|
||||
t.Run("login", func(t *testing.T) {
|
||||
|
||||
@@ -18,7 +18,11 @@ import (
|
||||
)
|
||||
|
||||
func TestFileHandlers_Integration(t *testing.T) {
|
||||
h := setupTestHarness(t)
|
||||
runWithDatabases(t, testFileHandlers)
|
||||
}
|
||||
|
||||
func testFileHandlers(t *testing.T, dbConfig DatabaseConfig) {
|
||||
h := setupTestHarness(t, dbConfig)
|
||||
defer h.teardown(t)
|
||||
|
||||
t.Run("file operations", func(t *testing.T) {
|
||||
@@ -192,7 +196,7 @@ func TestFileHandlers_Integration(t *testing.T) {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
body interface{}
|
||||
body any
|
||||
}{
|
||||
{"list files", http.MethodGet, baseURL, nil},
|
||||
{"get file", http.MethodGet, baseURL + "/test.md", nil},
|
||||
|
||||
@@ -16,7 +16,11 @@ import (
|
||||
)
|
||||
|
||||
func TestGitHandlers_Integration(t *testing.T) {
|
||||
h := setupTestHarness(t)
|
||||
runWithDatabases(t, testGitHandlers)
|
||||
}
|
||||
|
||||
func testGitHandlers(t *testing.T, dbConfig DatabaseConfig) {
|
||||
h := setupTestHarness(t, dbConfig)
|
||||
defer h.teardown(t)
|
||||
|
||||
t.Run("git operations", func(t *testing.T) {
|
||||
@@ -123,7 +127,7 @@ func TestGitHandlers_Integration(t *testing.T) {
|
||||
name string
|
||||
method string
|
||||
path string
|
||||
body interface{}
|
||||
body any
|
||||
}{
|
||||
{
|
||||
name: "commit without token",
|
||||
|
||||
@@ -37,7 +37,7 @@ func NewHandler(db db.Database, s storage.Manager) *Handler {
|
||||
}
|
||||
|
||||
// respondJSON is a helper to send JSON responses
|
||||
func respondJSON(w http.ResponseWriter, data interface{}) {
|
||||
func respondJSON(w http.ResponseWriter, data any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(data); err != nil {
|
||||
respondError(w, "Failed to encode response", http.StatusInternalServerError)
|
||||
|
||||
@@ -45,8 +45,13 @@ type testUser struct {
|
||||
session *models.Session
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
Type db.DBType
|
||||
URL string
|
||||
}
|
||||
|
||||
// setupTestHarness creates a new test environment
|
||||
func setupTestHarness(t *testing.T) *testHarness {
|
||||
func setupTestHarness(t *testing.T, dbConfig DatabaseConfig) *testHarness {
|
||||
t.Helper()
|
||||
|
||||
// Create temporary directory for test files
|
||||
@@ -61,9 +66,20 @@ func setupTestHarness(t *testing.T) *testHarness {
|
||||
t.Fatalf("Failed to initialize secrets service: %v", err)
|
||||
}
|
||||
|
||||
database, err := db.NewTestDB(":memory:", secretsSvc)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize test database: %v", err)
|
||||
var database db.TestDatabase
|
||||
switch dbConfig.Type {
|
||||
case db.DBTypeSQLite:
|
||||
database, err = db.NewTestSQLiteDB(secretsSvc)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize test database: %v", err)
|
||||
}
|
||||
case db.DBTypePostgres:
|
||||
database, err = db.NewPostgresTestDB(dbConfig.URL, secretsSvc)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to initialize test database: %v", err)
|
||||
}
|
||||
default:
|
||||
t.Fatalf("Unsupported database type: %s", dbConfig.Type)
|
||||
}
|
||||
|
||||
if err := database.Migrate(); err != nil {
|
||||
@@ -99,7 +115,7 @@ func setupTestHarness(t *testing.T) *testHarness {
|
||||
|
||||
// Create test config
|
||||
testConfig := &app.Config{
|
||||
DBPath: ":memory:",
|
||||
DBURL: "sqlite://:memory:",
|
||||
WorkDir: tempDir,
|
||||
StaticPath: "../testdata",
|
||||
Port: "8081",
|
||||
@@ -156,6 +172,32 @@ func (h *testHarness) teardown(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// runWithDatabases runs a test function with both SQLite and PostgreSQL databases
|
||||
func runWithDatabases(t *testing.T, testFn func(*testing.T, DatabaseConfig)) {
|
||||
// Get PostgreSQL connection URL from environment variable
|
||||
postgresURL := os.Getenv("LEMMA_TEST_POSTGRES_URL")
|
||||
|
||||
// Always run with SQLite in-memory
|
||||
t.Run("SQLite", func(t *testing.T) {
|
||||
testFn(t, DatabaseConfig{
|
||||
Type: db.DBTypeSQLite,
|
||||
URL: "sqlite://:memory:",
|
||||
})
|
||||
})
|
||||
|
||||
// Run with PostgreSQL if connection URL is provided
|
||||
if postgresURL != "" {
|
||||
t.Run("PostgreSQL", func(t *testing.T) {
|
||||
testFn(t, DatabaseConfig{
|
||||
Type: db.DBTypePostgres,
|
||||
URL: postgresURL,
|
||||
})
|
||||
})
|
||||
} else {
|
||||
t.Log("Skipping PostgreSQL tests, LEMMA_TEST_POSTGRES_URL environment variable not set")
|
||||
}
|
||||
}
|
||||
|
||||
// createTestUser creates a test user and returns the user and access token
|
||||
func (h *testHarness) createTestUser(t *testing.T, email, password string, role models.UserRole) *testUser {
|
||||
t.Helper()
|
||||
@@ -195,7 +237,7 @@ func (h *testHarness) createTestUser(t *testing.T, email, password string, role
|
||||
}
|
||||
}
|
||||
|
||||
func (h *testHarness) newRequest(t *testing.T, method, path string, body interface{}) *http.Request {
|
||||
func (h *testHarness) newRequest(t *testing.T, method, path string, body any) *http.Request {
|
||||
t.Helper()
|
||||
|
||||
var reqBody []byte
|
||||
@@ -246,7 +288,7 @@ func (h *testHarness) addCSRFCookie(t *testing.T, req *http.Request) string {
|
||||
}
|
||||
|
||||
// makeRequest is the main helper for making JSON requests
|
||||
func (h *testHarness) makeRequest(t *testing.T, method, path string, body interface{}, testUser *testUser) *httptest.ResponseRecorder {
|
||||
func (h *testHarness) makeRequest(t *testing.T, method, path string, body any, testUser *testUser) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
|
||||
req := h.newRequest(t, method, path, body)
|
||||
|
||||
@@ -15,7 +15,11 @@ import (
|
||||
)
|
||||
|
||||
func TestUserHandlers_Integration(t *testing.T) {
|
||||
h := setupTestHarness(t)
|
||||
runWithDatabases(t, testUserHandlers)
|
||||
}
|
||||
|
||||
func testUserHandlers(t *testing.T, dbConfig DatabaseConfig) {
|
||||
h := setupTestHarness(t, dbConfig)
|
||||
defer h.teardown(t)
|
||||
|
||||
currentEmail := h.RegularTestUser.userModel.Email
|
||||
|
||||
@@ -15,7 +15,11 @@ import (
|
||||
)
|
||||
|
||||
func TestWorkspaceHandlers_Integration(t *testing.T) {
|
||||
h := setupTestHarness(t)
|
||||
runWithDatabases(t, testWorkspaceHandlers)
|
||||
}
|
||||
|
||||
func testWorkspaceHandlers(t *testing.T, dbConfig DatabaseConfig) {
|
||||
h := setupTestHarness(t, dbConfig)
|
||||
defer h.teardown(t)
|
||||
|
||||
t.Run("list workspaces", func(t *testing.T) {
|
||||
|
||||
@@ -5,9 +5,9 @@ import "time"
|
||||
|
||||
// Session represents a user session in the database
|
||||
type Session struct {
|
||||
ID string // Unique session identifier
|
||||
UserID int // ID of the user this session belongs to
|
||||
RefreshToken string // The refresh token associated with this session
|
||||
ExpiresAt time.Time // When this session expires
|
||||
CreatedAt time.Time // When this session was created
|
||||
ID string `db:"id"` // Unique session identifier
|
||||
UserID int `db:"user_id"` // ID of the user this session belongs to
|
||||
RefreshToken string `db:"refresh_token"` // The refresh token associated with this session
|
||||
ExpiresAt time.Time `db:"expires_at"` // When this session expires
|
||||
CreatedAt time.Time `db:"created_at,default"` // When this session was created
|
||||
}
|
||||
|
||||
@@ -20,13 +20,13 @@ const (
|
||||
|
||||
// User represents a user in the system
|
||||
type User struct {
|
||||
ID int `json:"id" validate:"required,min=1"`
|
||||
Email string `json:"email" validate:"required,email"`
|
||||
DisplayName string `json:"displayName"`
|
||||
PasswordHash string `json:"-"`
|
||||
Role UserRole `json:"role" validate:"required,oneof=admin editor viewer"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
LastWorkspaceID int `json:"lastWorkspaceId"`
|
||||
ID int `json:"id" db:"id,default" validate:"required,min=1"`
|
||||
Email string `json:"email" db:"email" validate:"required,email"`
|
||||
DisplayName string `json:"displayName" db:"display_name"`
|
||||
PasswordHash string `json:"-" db:"password_hash"`
|
||||
Role UserRole `json:"role" db:"role" validate:"required,oneof=admin editor viewer"`
|
||||
CreatedAt time.Time `json:"createdAt" db:"created_at,default"`
|
||||
LastWorkspaceID int `json:"lastWorkspaceId" db:"last_workspace_id"`
|
||||
}
|
||||
|
||||
// Validate validates the user struct
|
||||
|
||||
@@ -6,24 +6,24 @@ import (
|
||||
|
||||
// Workspace represents a user's workspace in the system
|
||||
type Workspace struct {
|
||||
ID int `json:"id" validate:"required,min=1"`
|
||||
UserID int `json:"userId" validate:"required,min=1"`
|
||||
Name string `json:"name" validate:"required"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
LastOpenedFilePath string `json:"lastOpenedFilePath"`
|
||||
ID int `json:"id" db:"id,default" validate:"required,min=1"`
|
||||
UserID int `json:"userId" db:"user_id" validate:"required,min=1"`
|
||||
Name string `json:"name" db:"name" validate:"required"`
|
||||
CreatedAt time.Time `json:"createdAt" db:"created_at,default"`
|
||||
LastOpenedFilePath string `json:"lastOpenedFilePath" db:"last_opened_file_path"`
|
||||
|
||||
// Integrated settings
|
||||
Theme string `json:"theme" validate:"oneof=light dark"`
|
||||
AutoSave bool `json:"autoSave"`
|
||||
ShowHiddenFiles bool `json:"showHiddenFiles"`
|
||||
GitEnabled bool `json:"gitEnabled"`
|
||||
GitURL string `json:"gitUrl" validate:"required_if=GitEnabled true"`
|
||||
GitUser string `json:"gitUser" validate:"required_if=GitEnabled true"`
|
||||
GitToken string `json:"gitToken" validate:"required_if=GitEnabled true"`
|
||||
GitAutoCommit bool `json:"gitAutoCommit"`
|
||||
GitCommitMsgTemplate string `json:"gitCommitMsgTemplate"`
|
||||
GitCommitName string `json:"gitCommitName"`
|
||||
GitCommitEmail string `json:"gitCommitEmail" validate:"omitempty,required_if=GitEnabled true,email"`
|
||||
Theme string `json:"theme" db:"theme" validate:"oneof=light dark"`
|
||||
AutoSave bool `json:"autoSave" db:"auto_save"`
|
||||
ShowHiddenFiles bool `json:"showHiddenFiles" db:"show_hidden_files"`
|
||||
GitEnabled bool `json:"gitEnabled" db:"git_enabled"`
|
||||
GitURL string `json:"gitUrl" db:"git_url,ommitempty" validate:"required_if=GitEnabled true"`
|
||||
GitUser string `json:"gitUser" db:"git_user,ommitempty" validate:"required_if=GitEnabled true"`
|
||||
GitToken string `json:"gitToken" db:"git_token,ommitempty,encrypted" validate:"required_if=GitEnabled true"`
|
||||
GitAutoCommit bool `json:"gitAutoCommit" db:"git_auto_commit"`
|
||||
GitCommitMsgTemplate string `json:"gitCommitMsgTemplate" db:"git_commit_msg_template"`
|
||||
GitCommitName string `json:"gitCommitName" db:"git_commit_name"`
|
||||
GitCommitEmail string `json:"gitCommitEmail" db:"git_commit_email" validate:"omitempty,required_if=GitEnabled true,email"`
|
||||
}
|
||||
|
||||
// Validate validates the workspace struct
|
||||
|
||||
@@ -37,7 +37,7 @@ func (m MockDirInfo) Size() int64 { return m.size }
|
||||
func (m MockDirInfo) Mode() fs.FileMode { return m.mode }
|
||||
func (m MockDirInfo) ModTime() time.Time { return m.modTime }
|
||||
func (m MockDirInfo) IsDir() bool { return m.isDir }
|
||||
func (m MockDirInfo) Sys() interface{} { return nil }
|
||||
func (m MockDirInfo) Sys() any { return nil }
|
||||
|
||||
type mockFS struct {
|
||||
// Record operations for verification
|
||||
|
||||
20
server/run_integration_tests.sh
Executable file
20
server/run_integration_tests.sh
Executable file
@@ -0,0 +1,20 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
COMPOSE_FILE=docker-compose.test.yaml
|
||||
|
||||
if ! docker compose -f $COMPOSE_FILE ps postgres | grep -q "running"; then
|
||||
docker compose -f $COMPOSE_FILE up -d
|
||||
until docker compose -f $COMPOSE_FILE exec postgres pg_isready -U postgres; do
|
||||
sleep 1
|
||||
done
|
||||
echo "PostgreSQL is ready!"
|
||||
fi
|
||||
|
||||
export LEMMA_TEST_POSTGRES_URL="postgres://postgres:postgres@localhost:5432/lemma_test?sslmode=disable"
|
||||
|
||||
echo "Running integration tests..."
|
||||
go test -v -tags=test,integration ./...
|
||||
|
||||
docker compose -f $COMPOSE_FILE down
|
||||
Reference in New Issue
Block a user