Merge pull request #36 from lordmathis/feat/postgres

Add support for Postgres
This commit is contained in:
2025-03-07 22:42:08 +01:00
committed by GitHub
45 changed files with 3130 additions and 708 deletions

View File

@@ -13,6 +13,21 @@ jobs:
name: Run Tests
runs-on: ubuntu-latest
services:
postgres:
image: postgres:17
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: lemma_test
ports:
- 5432:5432
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
defaults:
run:
working-directory: ./server
@@ -29,6 +44,8 @@ jobs:
- name: Run Tests
run: go test -tags=test,integration ./... -v
env:
LEMMA_TEST_POSTGRES_URL: "postgres://postgres:postgres@localhost:5432/lemma_test?sslmode=disable"
- name: Run Tests with Race Detector
run: go test -tags=test,integration -race ./... -v

View File

@@ -15,6 +15,9 @@
"go.lintOnSave": "package",
"go.formatTool": "goimports",
"go.testFlags": ["-tags=test,integration"],
"go.testEnvVars": {
"LEMMA_TEST_POSTGRES_URL": "postgres://postgres:postgres@localhost:5432/lemma_test?sslmode=disable"
},
"[go]": {
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {

View File

@@ -33,8 +33,8 @@ Lemma can be configured using environment variables. Here are the available conf
### Optional Environment Variables
- `LEMMA_ENV`: Set to "development" to enable development mode
- `LEMMA_DB_PATH`: Path to the SQLite database file (default: "./lemma.db")
- `LEMMA_WORKDIR`: Working directory for application data (default: "./data")
- `LEMMA_DB_URL`: URL (Connection string) to the database. Supported databases are sqlite and postgres a (default: "./lemma.db")
- `LEMMA_WORKDIR`: Working directory for application data (default: "sqlite://lemma.db")
- `LEMMA_STATIC_PATH`: Path to static files (default: "../app/dist")
- `LEMMA_PORT`: Port to run the server on (default: "8080")
- `LEMMA_DOMAIN`: Domain name where the application is hosted for cookie authentication

View File

@@ -0,0 +1,35 @@
version: "3.8"
services:
postgres:
image: postgres:17
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: lemma_test
ports:
- "5432:5432"
volumes:
- postgres-data:/var/lib/postgresql/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U postgres"]
interval: 5s
timeout: 5s
retries: 5
pgadmin:
image: dpage/pgadmin4
environment:
PGADMIN_DEFAULT_EMAIL: admin@admin.com
PGADMIN_DEFAULT_PASSWORD: admin
PGADMIN_CONFIG_SERVER_MODE: "False"
ports:
- "8080:80"
volumes:
- pgadmin-data:/var/lib/pgadmin
depends_on:
- postgres
volumes:
postgres-data:
pgadmin-data:

View File

@@ -57,14 +57,20 @@ package app // import "lemma/internal/app"
Package app provides application-level functionality for initializing and
running the server
FUNCTIONS
func ParseDBURL(dbURL string) (db.DBType, string, error)
ParseDBURL parses a database URL and returns the driver name and data source
TYPES
type Config struct {
DBPath string
DBURL string
DBType db.DBType
WorkDir string
StaticPath string
Port string
RootURL string
Domain string
CORSOrigins []string
AdminEmail string
@@ -281,6 +287,24 @@ const (
TYPES
type DBField struct {
Name string
Value any
Type reflect.Type
OriginalName string
// Has unexported fields.
}
func StructTagsToFields(s any) ([]DBField, error)
StructTagsToFields converts a struct to a slice of DBField instances
type DBType string
const (
DBTypeSQLite DBType = "sqlite3"
DBTypePostgres DBType = "postgres"
)
type Database interface {
UserStore
WorkspaceStore
@@ -292,14 +316,115 @@ type Database interface {
}
Database defines the methods for interacting with the database
func Init(dbPath string, secretsService secrets.Service) (Database, error)
func Init(dbType DBType, dbURL string, secretsService secrets.Service) (Database, error)
Init initializes the database connection
type Migration struct {
Version int
SQL string
type JoinType string
const (
InnerJoin JoinType = "INNER JOIN"
LeftJoin JoinType = "LEFT JOIN"
RightJoin JoinType = "RIGHT JOIN"
)
type Query struct {
// Has unexported fields.
}
Migration represents a database migration
Query represents a SQL query with its parameters
func NewQuery(dbType DBType, secretsService secrets.Service) *Query
NewQuery creates a new Query instance
func (q *Query) AddArgs(args ...any) *Query
AddArgs adds arguments to the query
func (q *Query) And(condition string) *Query
And adds an AND condition
func (q *Query) Args() []any
Args returns the query arguments
func (q *Query) Delete() *Query
Delete starts a DELETE statement
func (q *Query) EndGroup() *Query
EndGroup ends a parenthetical group
func (q *Query) From(table string) *Query
From adds a FROM clause
func (q *Query) GroupBy(columns ...string) *Query
GroupBy adds a GROUP BY clause
func (q *Query) Having(condition string) *Query
Having adds a HAVING clause for filtering groups
func (q *Query) Insert(table string, columns ...string) *Query
Insert starts an INSERT statement
func (q *Query) InsertStruct(s any, table string) (*Query, error)
InsertStruct creates an INSERT query from a struct
func (q *Query) Join(joinType JoinType, table, condition string) *Query
Join adds a JOIN clause
func (q *Query) Limit(limit int) *Query
Limit adds a LIMIT clause
func (q *Query) Offset(offset int) *Query
Offset adds an OFFSET clause
func (q *Query) Or(condition string) *Query
Or adds an OR condition
func (q *Query) OrderBy(columns ...string) *Query
OrderBy adds an ORDER BY clause
func (q *Query) Placeholder(arg any) *Query
Placeholder adds a placeholder for a single argument
func (q *Query) Placeholders(n int) *Query
Placeholders adds n placeholders separated by commas
func (q *Query) Returning(columns ...string) *Query
Returning adds a RETURNING clause for both PostgreSQL and SQLite (3.35.0+)
func (q *Query) Select(columns ...string) *Query
Select adds a SELECT clause
func (q *Query) SelectStruct(s any, table string) (*Query, error)
SelectStruct creates a SELECT query from a struct
func (q *Query) Set(column string) *Query
Set adds a SET clause for updates
func (q *Query) StartGroup() *Query
StartGroup starts a parenthetical group
func (q *Query) String() string
String returns the formatted query string
func (q *Query) Update(table string) *Query
Update starts an UPDATE statement
func (q *Query) UpdateStruct(s any, table string) (*Query, error)
UpdateStruct creates an UPDATE query from a struct
func (q *Query) Values(count int) *Query
Values adds a VALUES clause
func (q *Query) Where(condition string) *Query
Where adds a WHERE clause
func (q *Query) WhereIn(column string, count int) *Query
WhereIn adds a WHERE IN clause
func (q *Query) Write(s string) *Query
Write adds a string to the query
type Scanner interface {
Scan(dest ...any) error
}
Scanner is an interface that both sql.Row and sql.Rows satisfy
type SessionStore interface {
CreateSession(session *models.Session) error
@@ -897,22 +1022,22 @@ and serialize data in the application.
TYPES
type Session struct {
ID string // Unique session identifier
UserID int // ID of the user this session belongs to
RefreshToken string // The refresh token associated with this session
ExpiresAt time.Time // When this session expires
CreatedAt time.Time // When this session was created
ID string `db:"id"` // Unique session identifier
UserID int `db:"user_id"` // ID of the user this session belongs to
RefreshToken string `db:"refresh_token"` // The refresh token associated with this session
ExpiresAt time.Time `db:"expires_at"` // When this session expires
CreatedAt time.Time `db:"created_at,default"` // When this session was created
}
Session represents a user session in the database
type User struct {
ID int `json:"id" validate:"required,min=1"`
Email string `json:"email" validate:"required,email"`
DisplayName string `json:"displayName"`
PasswordHash string `json:"-"`
Role UserRole `json:"role" validate:"required,oneof=admin editor viewer"`
CreatedAt time.Time `json:"createdAt"`
LastWorkspaceID int `json:"lastWorkspaceId"`
ID int `json:"id" db:"id,default" validate:"required,min=1"`
Email string `json:"email" db:"email" validate:"required,email"`
DisplayName string `json:"displayName" db:"display_name"`
PasswordHash string `json:"-" db:"password_hash"`
Role UserRole `json:"role" db:"role" validate:"required,oneof=admin editor viewer"`
CreatedAt time.Time `json:"createdAt" db:"created_at,default"`
LastWorkspaceID int `json:"lastWorkspaceId" db:"last_workspace_id"`
}
User represents a user in the system
@@ -930,24 +1055,24 @@ const (
User roles
type Workspace struct {
ID int `json:"id" validate:"required,min=1"`
UserID int `json:"userId" validate:"required,min=1"`
Name string `json:"name" validate:"required"`
CreatedAt time.Time `json:"createdAt"`
LastOpenedFilePath string `json:"lastOpenedFilePath"`
ID int `json:"id" db:"id,default" validate:"required,min=1"`
UserID int `json:"userId" db:"user_id" validate:"required,min=1"`
Name string `json:"name" db:"name" validate:"required"`
CreatedAt time.Time `json:"createdAt" db:"created_at,default"`
LastOpenedFilePath string `json:"lastOpenedFilePath" db:"last_opened_file_path"`
// Integrated settings
Theme string `json:"theme" validate:"oneof=light dark"`
AutoSave bool `json:"autoSave"`
ShowHiddenFiles bool `json:"showHiddenFiles"`
GitEnabled bool `json:"gitEnabled"`
GitURL string `json:"gitUrl" validate:"required_if=GitEnabled true"`
GitUser string `json:"gitUser" validate:"required_if=GitEnabled true"`
GitToken string `json:"gitToken" validate:"required_if=GitEnabled true"`
GitAutoCommit bool `json:"gitAutoCommit"`
GitCommitMsgTemplate string `json:"gitCommitMsgTemplate"`
GitCommitName string `json:"gitCommitName"`
GitCommitEmail string `json:"gitCommitEmail" validate:"omitempty,required_if=GitEnabled true,email"`
Theme string `json:"theme" db:"theme" validate:"oneof=light dark"`
AutoSave bool `json:"autoSave" db:"auto_save"`
ShowHiddenFiles bool `json:"showHiddenFiles" db:"show_hidden_files"`
GitEnabled bool `json:"gitEnabled" db:"git_enabled"`
GitURL string `json:"gitUrl" db:"git_url,ommitempty" validate:"required_if=GitEnabled true"`
GitUser string `json:"gitUser" db:"git_user,ommitempty" validate:"required_if=GitEnabled true"`
GitToken string `json:"gitToken" db:"git_token,ommitempty,encrypted" validate:"required_if=GitEnabled true"`
GitAutoCommit bool `json:"gitAutoCommit" db:"git_auto_commit"`
GitCommitMsgTemplate string `json:"gitCommitMsgTemplate" db:"git_commit_msg_template"`
GitCommitName string `json:"gitCommitName" db:"git_commit_name"`
GitCommitEmail string `json:"gitCommitEmail" db:"git_commit_email" validate:"omitempty,required_if=GitEnabled true,email"`
}
Workspace represents a user's workspace in the system

View File

@@ -9,6 +9,7 @@ require (
github.com/go-git/go-git/v5 v5.13.1
github.com/go-playground/validator/v10 v10.22.1
github.com/golang-jwt/jwt/v5 v5.2.1
github.com/golang-migrate/migrate/v4 v4.18.2
github.com/google/uuid v1.6.0
github.com/mattn/go-sqlite3 v1.14.23
github.com/stretchr/testify v1.10.0
@@ -21,7 +22,7 @@ require (
require (
dario.cat/mergo v1.0.0 // indirect
github.com/KyleBanks/depth v1.2.1 // indirect
github.com/Microsoft/go-winio v0.6.1 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/ProtonMail/go-crypto v1.1.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/cloudflare/circl v1.3.7 // indirect
@@ -38,10 +39,13 @@ require (
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect
github.com/josharian/intern v1.0.0 // indirect
github.com/kevinburke/ssh_config v1.2.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/mailru/easyjson v0.7.6 // indirect
github.com/pjbgf/sha1cd v0.3.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
@@ -49,12 +53,11 @@ require (
github.com/skeema/knownhosts v1.3.0 // indirect
github.com/swaggo/files v0.0.0-20220610200504-28940afbdbfe // indirect
github.com/xanzy/ssh-agent v0.3.3 // indirect
golang.org/x/mod v0.17.0 // indirect
go.uber.org/atomic v1.7.0 // indirect
golang.org/x/net v0.33.0 // indirect
golang.org/x/sync v0.10.0 // indirect
golang.org/x/sys v0.28.0 // indirect
golang.org/x/text v0.21.0 // indirect
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect
golang.org/x/tools v0.24.0 // indirect
gopkg.in/warnings.v0 v0.1.2 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect

View File

@@ -1,10 +1,12 @@
dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk=
dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0=
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/KyleBanks/depth v1.2.1 h1:5h8fQADFrWtarTdtDudMmGsC7GPbOAu6RVB3ffsVFHc=
github.com/KyleBanks/depth v1.2.1/go.mod h1:jzSb9d0L43HxTQfT+oSA1EEp2q+ne2uh6XgeJcm8brE=
github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY=
github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow=
github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/ProtonMail/go-crypto v1.1.3 h1:nRBOetoydLeUb4nHajyO2bKqMLfWQ/ZPwkXqXxPxCFk=
github.com/ProtonMail/go-crypto v1.1.3/go.mod h1:rA3QumHc/FZ8pAHreoekgiAbzpNsfQAosU5td4SnOrE=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
@@ -21,10 +23,22 @@ github.com/cyphar/filepath-securejoin v0.3.6/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGL
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dhui/dktest v0.4.4 h1:+I4s6JRE1yGuqflzwqG+aIaMdgXIorCf5P98JnaAWa8=
github.com/dhui/dktest v0.4.4/go.mod h1:4+22R4lgsdAXrDyaH4Nqx2JEz2hLp49MqQmm9HLCQhM=
github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk=
github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E=
github.com/docker/docker v27.2.0+incompatible h1:Rk9nIVdfH3+Vz4cyI/uhbINhEZ/oLmc+CBXmH6fbNk4=
github.com/docker/docker v27.2.0+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk=
github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c=
github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc=
github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4=
github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/elazarl/goproxy v1.2.3 h1:xwIyKHbaP5yfT6O9KIeYJR5549MXRQkoQMRXGztz8YQ=
github.com/elazarl/goproxy v1.2.3/go.mod h1:YfEbZtqP4AetfO6d40vWchF3znWX7C7Vd6ZMfdL8z64=
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c=
@@ -43,6 +57,10 @@ github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399 h1:eMj
github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399/go.mod h1:1OCfN199q1Jm3HZlxleg+Dw/mwps2Wbk9frAWm+4FII=
github.com/go-git/go-git/v5 v5.13.1 h1:DAQ9APonnlvSWpvolXWIuV6Q6zXy2wHbN4cVlNR5Q+M=
github.com/go-git/go-git/v5 v5.13.1/go.mod h1:qryJB4cSBoq3FRoBRf5A77joojuBcmPJ0qu3XXXVixc=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg=
github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY=
github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg=
@@ -61,14 +79,23 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.22.1 h1:40JcKH+bBNGFczGuoBYgX4I6m/i27HYW8P9FDk5PbgA=
github.com/go-playground/validator/v10 v10.22.1/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=
github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q=
github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk=
github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
github.com/golang-migrate/migrate/v4 v4.18.2 h1:2VSCMz7x7mjyTXx3m2zPokOY82LTRgxK1yQYKo6wWQ8=
github.com/golang-migrate/migrate/v4 v4.18.2/go.mod h1:2CM6tJvn2kqPXwnXO/d3rAQYiyoIm180VsO8PRX6Rpk=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE=
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I=
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A=
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
@@ -84,15 +111,27 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc=
github.com/mailru/easyjson v0.7.6 h1:8yTIVnZgCoiM1TgqoeTl+LfU5Jg6/xL3QhGQnimLYnA=
github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/mattn/go-sqlite3 v1.14.23 h1:gbShiuAP1W5j9UOksQ06aiiqPMxYecovVGwmTxWtuw0=
github.com/mattn/go-sqlite3 v1.14.23/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0=
github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo=
github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0=
github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y=
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno=
github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k=
github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM=
github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug=
github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM=
github.com/pjbgf/sha1cd v0.3.0 h1:4D5XXmUUBUl/xQ6IjCkEAbqXskkq/4O7LmGn0AqMDs4=
github.com/pjbgf/sha1cd v0.3.0/go.mod h1:nZ1rrWOcGJ5uZgEEVL1VUM9iRQiZvWdbZjkKyFzPPsI=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
@@ -123,13 +162,23 @@ github.com/unrolled/secure v1.17.0 h1:Io7ifFgo99Bnh0J7+Q+qcMzWM6kaDPCA5FroFZEdbW
github.com/unrolled/secure v1.17.0/go.mod h1:BmF5hyM6tXczk3MpQkFf1hpKSRqCyhqcbiQtiAF7+40=
github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM=
github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 h1:TT4fX+nBOA/+LUkobKGW1ydGcn+G3vRw9+g5HwCphpk=
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0/go.mod h1:L7UH0GbB0p47T4Rri3uHjbpCFYrVrwc1I25QhNPiGK8=
go.opentelemetry.io/otel v1.29.0 h1:PdomN/Al4q/lN6iBJEN3AwPvUiHPMlt93c8bqTG5Llw=
go.opentelemetry.io/otel v1.29.0/go.mod h1:N/WtXPs1CNCUEx+Agz5uouwCba+i+bJGFicT8SR4NP8=
go.opentelemetry.io/otel/metric v1.29.0 h1:vPf/HFWTNkPu1aYeIsc98l4ktOQaL6LeSoeV2g+8YLc=
go.opentelemetry.io/otel/metric v1.29.0/go.mod h1:auu/QWieFVWx+DmQOUMgj0F8LHWdgalxXqvp7BII/W8=
go.opentelemetry.io/otel/trace v1.29.0 h1:J/8ZNK4XgR7a21DZUAsbF8pZ5Jcw1VhACmnYt39JTi4=
go.opentelemetry.io/otel/trace v1.29.0/go.mod h1:eHl3w0sp3paPkYstJOmAimxhiFXPg+MMTlEh3nsQgWQ=
go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw=
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8=
golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY=
golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
@@ -151,8 +200,8 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24=
golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=

View File

@@ -2,9 +2,11 @@ package app
import (
"fmt"
"lemma/internal/db"
"lemma/internal/logging"
"lemma/internal/secrets"
"os"
"path/filepath"
"strconv"
"strings"
"time"
@@ -12,7 +14,8 @@ import (
// Config holds the configuration for the application
type Config struct {
DBPath string
DBURL string
DBType db.DBType
WorkDir string
StaticPath string
Port string
@@ -31,7 +34,7 @@ type Config struct {
// DefaultConfig returns a new Config instance with default values
func DefaultConfig() *Config {
return &Config{
DBPath: "./lemma.db",
DBURL: "sqlite://lemma.db",
WorkDir: "./data",
StaticPath: "../app/dist",
Port: "8080",
@@ -65,6 +68,31 @@ func (c *Config) Redact() *Config {
return &redacted
}
// ParseDBURL parses a database URL and returns the driver name and data source
func ParseDBURL(dbURL string) (db.DBType, string, error) {
if strings.HasPrefix(dbURL, "sqlite://") || strings.HasPrefix(dbURL, "sqlite3://") {
path := strings.TrimPrefix(dbURL, "sqlite://")
path = strings.TrimPrefix(path, "sqlite3://")
if path == ":memory:" {
return db.DBTypeSQLite, path, nil
}
if !filepath.IsAbs(path) {
path = filepath.Clean(path)
}
return db.DBTypeSQLite, path, nil
}
// Try to parse as postgres URL
if strings.HasPrefix(dbURL, "postgres://") || strings.HasPrefix(dbURL, "postgresql://") {
return db.DBTypePostgres, dbURL, nil
}
return "", "", fmt.Errorf("unsupported database URL format: %s", dbURL)
}
// LoadConfig creates a new Config instance with values from environment variables
func LoadConfig() (*Config, error) {
config := DefaultConfig()
@@ -73,8 +101,13 @@ func LoadConfig() (*Config, error) {
config.IsDevelopment = env == "development"
}
if dbPath := os.Getenv("LEMMA_DB_PATH"); dbPath != "" {
config.DBPath = dbPath
if dbURL := os.Getenv("LEMMA_DB_URL"); dbURL != "" {
dbType, dataSource, err := ParseDBURL(dbURL)
if err != nil {
return nil, err
}
config.DBURL = dataSource
config.DBType = dbType
}
if workDir := os.Getenv("LEMMA_WORKDIR"); workDir != "" {

View File

@@ -2,6 +2,7 @@ package app_test
import (
"lemma/internal/app"
"lemma/internal/db"
"os"
"testing"
"time"
@@ -14,10 +15,10 @@ func TestDefaultConfig(t *testing.T) {
tests := []struct {
name string
got interface{}
expected interface{}
got any
expected any
}{
{"DBPath", cfg.DBPath, "./lemma.db"},
{"DBPath", cfg.DBURL, "sqlite://lemma.db"},
{"WorkDir", cfg.WorkDir, "./data"},
{"StaticPath", cfg.StaticPath, "../app/dist"},
{"Port", cfg.Port, "8080"},
@@ -47,7 +48,7 @@ func TestLoad(t *testing.T) {
cleanup := func() {
envVars := []string{
"LEMMA_ENV",
"LEMMA_DB_PATH",
"LEMMA_DB_URL",
"LEMMA_WORKDIR",
"LEMMA_STATIC_PATH",
"LEMMA_PORT",
@@ -81,8 +82,8 @@ func TestLoad(t *testing.T) {
t.Fatalf("Load() error = %v", err)
}
if cfg.DBPath != "./lemma.db" {
t.Errorf("default DBPath = %v, want %v", cfg.DBPath, "./lemma.db")
if cfg.DBURL != "sqlite://lemma.db" {
t.Errorf("default DBPath = %v, want %v", cfg.DBURL, "sqlite://lemma.db")
}
})
@@ -93,7 +94,7 @@ func TestLoad(t *testing.T) {
// Set all environment variables
envs := map[string]string{
"LEMMA_ENV": "development",
"LEMMA_DB_PATH": "/custom/db/path.db",
"LEMMA_DB_URL": "sqlite:///custom/db/path.db",
"LEMMA_WORKDIR": "/custom/work/dir",
"LEMMA_STATIC_PATH": "/custom/static/path",
"LEMMA_PORT": "3000",
@@ -118,11 +119,12 @@ func TestLoad(t *testing.T) {
tests := []struct {
name string
got interface{}
expected interface{}
got any
expected any
}{
{"IsDevelopment", cfg.IsDevelopment, true},
{"DBPath", cfg.DBPath, "/custom/db/path.db"},
{"DBURL", cfg.DBURL, "/custom/db/path.db"},
{"DBType", cfg.DBType, db.DBTypeSQLite},
{"WorkDir", cfg.WorkDir, "/custom/work/dir"},
{"StaticPath", cfg.StaticPath, "/custom/static/path"},
{"Port", cfg.Port, "3000"},

View File

@@ -28,9 +28,9 @@ func initSecretsService(cfg *Config) (secrets.Service, error) {
// initDatabase initializes and migrates the database
func initDatabase(cfg *Config, secretsService secrets.Service) (db.Database, error) {
logging.Debug("initializing database", "path", cfg.DBPath)
logging.Debug("initializing database", "path", cfg.DBURL)
database, err := db.Init(cfg.DBPath, secretsService)
database, err := db.Init(cfg.DBType, cfg.DBURL, secretsService)
if err != nil {
return nil, fmt.Errorf("failed to initialize database: %w", err)
}

View File

@@ -112,7 +112,7 @@ func (s *jwtService) generateToken(userID int, role string, sessionID string, to
func (s *jwtService) ValidateToken(tokenString string) (*Claims, error) {
log := getJWTLogger()
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (any, error) {
// Validate the signing method
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])

View File

@@ -9,9 +9,17 @@ import (
"lemma/internal/models"
"lemma/internal/secrets"
_ "github.com/lib/pq" // Postgres driver
_ "github.com/mattn/go-sqlite3" // SQLite driver
)
type DBType string
const (
DBTypeSQLite DBType = "sqlite3"
DBTypePostgres DBType = "postgres"
)
// UserStore defines the methods for interacting with user data in the database
type UserStore interface {
CreateUser(user *models.User) (*models.User, error)
@@ -68,12 +76,18 @@ type SystemStore interface {
SetSystemSetting(key, value string) error
}
type StructScanner interface {
ScanStruct(row *sql.Row, dest interface{}) error
ScanStructs(rows *sql.Rows, dest interface{}) error
}
// Database defines the methods for interacting with the database
type Database interface {
UserStore
WorkspaceStore
SessionStore
SystemStore
StructScanner
Begin() (*sql.Tx, error)
Close() error
Migrate() error
@@ -93,6 +107,7 @@ var (
// Sub-interfaces
_ WorkspaceReader = (*database)(nil)
_ WorkspaceWriter = (*database)(nil)
_ StructScanner = (*database)(nil)
)
var logger logging.Logger
@@ -108,13 +123,45 @@ func getLogger() logging.Logger {
type database struct {
*sql.DB
secretsService secrets.Service
dbType DBType
}
// Init initializes the database connection
func Init(dbPath string, secretsService secrets.Service) (Database, error) {
log := getLogger()
func Init(dbType DBType, dbURL string, secretsService secrets.Service) (Database, error) {
db, err := sql.Open("sqlite3", dbPath)
switch dbType {
case DBTypeSQLite:
db, err := initSQLite(dbURL)
if err != nil {
return nil, fmt.Errorf("failed to initialize SQLite database: %w", err)
}
database := &database{
DB: db,
secretsService: secretsService,
dbType: dbType,
}
return database, nil
case DBTypePostgres:
db, err := initPostgres(dbURL)
if err != nil {
return nil, fmt.Errorf("failed to initialize Postgres database: %w", err)
}
database := &database{
DB: db,
secretsService: secretsService,
dbType: dbType,
}
return database, nil
}
return nil, fmt.Errorf("unsupported database type: %s", dbType)
}
func initSQLite(dbURL string) (*sql.DB, error) {
log := getLogger()
db, err := sql.Open("sqlite3", dbURL)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
@@ -128,13 +175,20 @@ func Init(dbPath string, secretsService secrets.Service) (Database, error) {
return nil, fmt.Errorf("failed to enable foreign keys: %w", err)
}
log.Debug("foreign keys enabled")
database := &database{
DB: db,
secretsService: secretsService,
return db, nil
}
return database, nil
func initPostgres(dbURL string) (*sql.DB, error) {
db, err := sql.Open("postgres", dbURL)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("failed to ping database: %w", err)
}
return db, nil
}
// Close closes the database connection
@@ -148,29 +202,6 @@ func (db *database) Close() error {
return nil
}
// Helper methods for token encryption/decryption
func (db *database) encryptToken(token string) (string, error) {
if token == "" {
return "", nil
}
encrypted, err := db.secretsService.Encrypt(token)
if err != nil {
return "", fmt.Errorf("failed to encrypt token: %w", err)
}
return encrypted, nil
}
func (db *database) decryptToken(token string) (string, error) {
if token == "" {
return "", nil
}
decrypted, err := db.secretsService.Decrypt(token)
if err != nil {
return "", fmt.Errorf("failed to decrypt token: %w", err)
}
return decrypted, nil
func (db *database) NewQuery() *Query {
return NewQuery(db.dbType, db.secretsService)
}

View File

@@ -1,141 +1,71 @@
package db
import (
"embed"
"fmt"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/postgres"
"github.com/golang-migrate/migrate/v4/database/sqlite3"
"github.com/golang-migrate/migrate/v4/source/iofs"
)
// Migration represents a database migration
type Migration struct {
Version int
SQL string
}
var migrations = []Migration{
{
Version: 1,
SQL: `
-- Create users table
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
email TEXT NOT NULL UNIQUE,
display_name TEXT,
password_hash TEXT NOT NULL,
role TEXT NOT NULL CHECK(role IN ('admin', 'editor', 'viewer')),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_workspace_id INTEGER
);
-- Create workspaces table with integrated settings
CREATE TABLE IF NOT EXISTS workspaces (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
name TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_opened_file_path TEXT,
-- Settings fields
theme TEXT NOT NULL DEFAULT 'light' CHECK(theme IN ('light', 'dark')),
auto_save BOOLEAN NOT NULL DEFAULT 0,
git_enabled BOOLEAN NOT NULL DEFAULT 0,
git_url TEXT,
git_user TEXT,
git_token TEXT,
git_auto_commit BOOLEAN NOT NULL DEFAULT 0,
git_commit_msg_template TEXT DEFAULT '${action} ${filename}',
git_commit_name TEXT,
git_commit_email TEXT,
show_hidden_files BOOLEAN NOT NULL DEFAULT 0,
created_by INTEGER REFERENCES users(id),
updated_by INTEGER REFERENCES users(id),
updated_at TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id)
);
-- Create sessions table for authentication
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
user_id INTEGER NOT NULL,
refresh_token TEXT NOT NULL,
expires_at TIMESTAMP NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
);
-- Create system_settings table for application settings
CREATE TABLE IF NOT EXISTS system_settings (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
-- Create indexes for performance
CREATE INDEX idx_sessions_user_id ON sessions(user_id);
CREATE INDEX idx_sessions_expires_at ON sessions(expires_at);
CREATE INDEX idx_sessions_refresh_token ON sessions(refresh_token);
`,
},
}
//go:embed migrations/sqlite/*.sql migrations/postgres/*.sql
var migrationsFS embed.FS
// Migrate applies all database migrations
func (db *database) Migrate() error {
log := getLogger().WithGroup("migrations")
log.Info("starting database migration")
// Create migrations table if it doesn't exist
_, err := db.Exec(`CREATE TABLE IF NOT EXISTS migrations (
version INTEGER PRIMARY KEY
)`)
var migrationPath string
switch db.dbType {
case DBTypePostgres:
migrationPath = "migrations/postgres"
case DBTypeSQLite:
migrationPath = "migrations/sqlite"
default:
return fmt.Errorf("unsupported database driver: %s", db.dbType)
}
log.Debug("using migration path", "path", migrationPath)
sourceInstance, err := iofs.New(migrationsFS, migrationPath)
if err != nil {
return fmt.Errorf("failed to create migrations table: %w", err)
return fmt.Errorf("failed to create source instance: %w", err)
}
// Get current version
var currentVersion int
err = db.QueryRow("SELECT COALESCE(MAX(version), 0) FROM migrations").Scan(&currentVersion)
var m *migrate.Migrate
switch db.dbType {
case DBTypePostgres:
driver, err := postgres.WithInstance(db.DB, &postgres.Config{})
if err != nil {
return fmt.Errorf("failed to get current migration version: %w", err)
return fmt.Errorf("failed to create postgres driver: %w", err)
}
// Apply new migrations
for _, migration := range migrations {
if migration.Version > currentVersion {
log := log.With("migration_version", migration.Version)
tx, err := db.Begin()
m, err = migrate.NewWithInstance("iofs", sourceInstance, "postgres", driver)
if err != nil {
return fmt.Errorf("failed to begin transaction for migration %d: %w", migration.Version, err)
return fmt.Errorf("failed to create migrate instance: %w", err)
}
// Execute migration SQL
_, err = tx.Exec(migration.SQL)
case DBTypeSQLite:
driver, err := sqlite3.WithInstance(db.DB, &sqlite3.Config{})
if err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
return fmt.Errorf("migration %d failed: %v, rollback failed: %v",
migration.Version, err, rbErr)
return fmt.Errorf("failed to create sqlite driver: %w", err)
}
return fmt.Errorf("migration %d failed: %w", migration.Version, err)
}
// Update migrations table
_, err = tx.Exec("INSERT INTO migrations (version) VALUES (?)", migration.Version)
m, err = migrate.NewWithInstance("iofs", sourceInstance, "sqlite3", driver)
if err != nil {
if rbErr := tx.Rollback(); rbErr != nil {
return fmt.Errorf("failed to update migration version: %v, rollback failed: %v",
err, rbErr)
}
return fmt.Errorf("failed to update migration version: %w", err)
return fmt.Errorf("failed to create migrate instance: %w", err)
}
// Commit transaction
err = tx.Commit()
if err != nil {
return fmt.Errorf("failed to commit migration %d: %w", migration.Version, err)
default:
return fmt.Errorf("unsupported database driver: %s", db.dbType)
}
currentVersion = migration.Version
log.Debug("migration applied", "new_version", currentVersion)
}
if err := m.Up(); err != nil && err != migrate.ErrNoChange {
return fmt.Errorf("failed to run migrations: %w", err)
}
log.Info("database migration completed", "final_version", currentVersion)
log.Info("database migration completed")
return nil
}

View File

@@ -0,0 +1,9 @@
-- 001_initial_schema.down.sql (PostgreSQL version)
DROP INDEX IF EXISTS idx_sessions_refresh_token;
DROP INDEX IF EXISTS idx_sessions_expires_at;
DROP INDEX IF EXISTS idx_sessions_user_id;
DROP INDEX IF EXISTS idx_workspaces_user_id;
DROP TABLE IF EXISTS sessions;
DROP TABLE IF EXISTS workspaces;
DROP TABLE IF EXISTS system_settings;
DROP TABLE IF EXISTS users;

View File

@@ -0,0 +1,61 @@
-- 001_initial_schema.up.sql (PostgreSQL version)
-- Create users table
CREATE TABLE IF NOT EXISTS users (
id SERIAL PRIMARY KEY,
email TEXT NOT NULL UNIQUE,
display_name TEXT,
password_hash TEXT NOT NULL,
role TEXT NOT NULL CHECK(role IN ('admin', 'editor', 'viewer')),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_workspace_id INTEGER
);
-- Create workspaces table with integrated settings
CREATE TABLE IF NOT EXISTS workspaces (
id SERIAL PRIMARY KEY,
user_id INTEGER NOT NULL,
name TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_opened_file_path TEXT,
-- Settings fields
theme TEXT NOT NULL DEFAULT 'light' CHECK(theme IN ('light', 'dark')),
auto_save BOOLEAN NOT NULL DEFAULT FALSE,
git_enabled BOOLEAN NOT NULL DEFAULT FALSE,
git_url TEXT,
git_user TEXT,
git_token TEXT,
git_auto_commit BOOLEAN NOT NULL DEFAULT FALSE,
git_commit_msg_template TEXT DEFAULT '${action} ${filename}',
git_commit_name TEXT,
git_commit_email TEXT,
show_hidden_files BOOLEAN NOT NULL DEFAULT FALSE,
created_by INTEGER REFERENCES users(id),
updated_by INTEGER REFERENCES users(id),
updated_at TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
);
-- Create sessions table for authentication
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
user_id INTEGER NOT NULL,
refresh_token TEXT NOT NULL,
expires_at TIMESTAMP NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
);
-- Create system_settings table for application settings
CREATE TABLE IF NOT EXISTS system_settings (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
-- Create indexes for performance
CREATE INDEX idx_sessions_user_id ON sessions(user_id);
CREATE INDEX idx_sessions_expires_at ON sessions(expires_at);
CREATE INDEX idx_sessions_refresh_token ON sessions(refresh_token);
CREATE INDEX idx_workspaces_user_id ON workspaces(user_id);

View File

@@ -0,0 +1,9 @@
-- 001_initial_schema.down.sql
DROP INDEX IF EXISTS idx_sessions_refresh_token;
DROP INDEX IF EXISTS idx_sessions_expires_at;
DROP INDEX IF EXISTS idx_sessions_user_id;
DROP INDEX IF EXISTS idx_workspaces_user_id;
DROP TABLE IF EXISTS sessions;
DROP TABLE IF EXISTS workspaces;
DROP TABLE IF EXISTS system_settings;
DROP TABLE IF EXISTS users;

View File

@@ -0,0 +1,60 @@
-- 001_initial_schema.up.sql
-- Create users table
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
email TEXT NOT NULL UNIQUE,
display_name TEXT,
password_hash TEXT NOT NULL,
role TEXT NOT NULL CHECK(role IN ('admin', 'editor', 'viewer')),
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_workspace_id INTEGER
);
-- Create workspaces table with integrated settings
CREATE TABLE IF NOT EXISTS workspaces (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
name TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_opened_file_path TEXT,
-- Settings fields
theme TEXT NOT NULL DEFAULT 'light' CHECK(theme IN ('light', 'dark')),
auto_save BOOLEAN NOT NULL DEFAULT 0,
git_enabled BOOLEAN NOT NULL DEFAULT 0,
git_url TEXT,
git_user TEXT,
git_token TEXT,
git_auto_commit BOOLEAN NOT NULL DEFAULT 0,
git_commit_msg_template TEXT DEFAULT '${action} ${filename}',
git_commit_name TEXT,
git_commit_email TEXT,
show_hidden_files BOOLEAN NOT NULL DEFAULT 0,
created_by INTEGER REFERENCES users(id),
updated_by INTEGER REFERENCES users(id),
updated_at TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id)
);
-- Create sessions table for authentication
CREATE TABLE IF NOT EXISTS sessions (
id TEXT PRIMARY KEY,
user_id INTEGER NOT NULL,
refresh_token TEXT NOT NULL,
expires_at TIMESTAMP NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE
);
-- Create system_settings table for application settings
CREATE TABLE IF NOT EXISTS system_settings (
key TEXT PRIMARY KEY,
value TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
-- Create indexes for performance
CREATE INDEX idx_sessions_user_id ON sessions(user_id);
CREATE INDEX idx_sessions_expires_at ON sessions(expires_at);
CREATE INDEX idx_sessions_refresh_token ON sessions(refresh_token);
CREATE INDEX idx_workspaces_user_id ON workspaces(user_id);

View File

@@ -1,53 +1,35 @@
package db_test
import (
"testing"
"lemma/internal/db"
_ "lemma/internal/testenv"
"testing"
_ "github.com/mattn/go-sqlite3"
)
func TestMigrate(t *testing.T) {
database, err := db.NewTestDB(":memory:", &mockSecrets{})
database, err := db.NewTestSQLiteDB(&mockSecrets{})
if err != nil {
t.Fatalf("failed to initialize database: %v", err)
}
defer database.Close()
t.Run("migrations are applied in order", func(t *testing.T) {
if err := database.Migrate(); err != nil {
t.Fatalf("failed to run initial migrations: %v", err)
}
// Check migration version
var version int
err := database.TestDB().QueryRow("SELECT MAX(version) FROM migrations").Scan(&version)
if err != nil {
t.Fatalf("failed to get migration version: %v", err)
}
if version != 1 { // Current number of migrations in production code
t.Errorf("expected migration version 1, got %d", version)
}
// Verify number of migration entries matches versions applied
var count int
err = database.TestDB().QueryRow("SELECT COUNT(*) FROM migrations").Scan(&count)
if err != nil {
t.Fatalf("failed to count migrations: %v", err)
}
if count != 1 {
t.Errorf("expected 1 migration entries, got %d", count)
}
})
t.Run("migrations create expected schema", func(t *testing.T) {
// Run migrations
if err := database.Migrate(); err != nil {
t.Fatalf("failed to run migrations: %v", err)
}
// Verify tables exist
tables := []string{"users", "workspaces", "sessions", "system_settings", "migrations"}
tables := []string{
"users",
"workspaces",
"sessions",
"system_settings",
"schema_migrations",
}
for _, table := range tables {
if !tableExists(t, database, table) {
t.Errorf("table %q does not exist", table)
@@ -62,92 +44,34 @@ func TestMigrate(t *testing.T) {
{"sessions", "idx_sessions_user_id"},
{"sessions", "idx_sessions_expires_at"},
{"sessions", "idx_sessions_refresh_token"},
{"workspaces", "idx_workspaces_user_id"},
}
for _, idx := range indexes {
if !indexExists(t, database, idx.table, idx.name) {
t.Errorf("index %q on table %q does not exist", idx.name, idx.table)
}
}
})
t.Run("migrations are idempotent", func(t *testing.T) {
// Run migrations again
if err := database.Migrate(); err != nil {
t.Fatalf("failed to re-run migrations: %v", err)
}
// Verify migration count hasn't changed
var count int
err = database.TestDB().QueryRow("SELECT COUNT(*) FROM migrations").Scan(&count)
if err != nil {
t.Fatalf("failed to count migrations: %v", err)
}
if count != 1 {
t.Errorf("expected 1 migration entries, got %d", count)
}
})
t.Run("rollback on migration failure", func(t *testing.T) {
// Create a test table that would conflict with a failing migration
_, err := database.TestDB().Exec("CREATE TABLE test_rollback (id INTEGER PRIMARY KEY)")
if err != nil {
t.Fatalf("failed to create test table: %v", err)
}
// Start transaction
tx, err := database.Begin()
if err != nil {
t.Fatalf("failed to start transaction: %v", err)
}
// Try operations that should fail and rollback
_, err = tx.Exec(`
CREATE TABLE test_rollback (id INTEGER PRIMARY KEY);
INSERT INTO nonexistent_table VALUES (1);
`)
if err == nil {
tx.Rollback()
t.Fatal("expected migration to fail")
}
tx.Rollback()
// Verify the migration version hasn't changed
var version int
err = database.TestDB().QueryRow("SELECT MAX(version) FROM migrations").Scan(&version)
if err != nil {
t.Fatalf("failed to get migration version: %v", err)
}
if version != 1 {
t.Errorf("expected migration version to remain at 1, got %d", version)
}
})
}
func tableExists(t *testing.T, database db.TestDatabase, tableName string) bool {
t.Helper()
var name string
err := database.TestDB().QueryRow(`
SELECT name FROM sqlite_master
WHERE type='table' AND name=?`,
tableName,
).Scan(&name)
return err == nil
}
func indexExists(t *testing.T, database db.TestDatabase, tableName, indexName string) bool {
t.Helper()
var name string
err := database.TestDB().QueryRow(`
SELECT name FROM sqlite_master
WHERE type='index' AND tbl_name=? AND name=?`,
tableName, indexName,
).Scan(&name)
return err == nil
}

306
server/internal/db/query.go Normal file
View File

@@ -0,0 +1,306 @@
package db
import (
"fmt"
"lemma/internal/secrets"
"strings"
)
type JoinType string
const (
InnerJoin JoinType = "INNER JOIN"
LeftJoin JoinType = "LEFT JOIN"
RightJoin JoinType = "RIGHT JOIN"
)
// Query represents a SQL query with its parameters
type Query struct {
builder strings.Builder
args []any
dbType DBType
secretsService secrets.Service
pos int // tracks the current placeholder position
hasSelect bool
hasFrom bool
hasWhere bool
hasOrderBy bool
hasGroupBy bool
hasHaving bool
hasLimit bool
hasOffset bool
isInParens bool
parensDepth int
}
// NewQuery creates a new Query instance
func NewQuery(dbType DBType, secretsService secrets.Service) *Query {
return &Query{
dbType: dbType,
secretsService: secretsService,
args: make([]any, 0),
}
}
// Select adds a SELECT clause
func (q *Query) Select(columns ...string) *Query {
if !q.hasSelect {
q.Write("SELECT ")
q.Write(strings.Join(columns, ", "))
q.hasSelect = true
}
return q
}
// From adds a FROM clause
func (q *Query) From(table string) *Query {
if !q.hasFrom {
q.Write(" FROM ")
q.Write(table)
q.hasFrom = true
}
return q
}
// Where adds a WHERE clause
func (q *Query) Where(condition string) *Query {
if !q.hasWhere {
q.Write(" WHERE ")
q.hasWhere = true
} else {
q.Write(" AND ")
}
q.Write(condition)
return q
}
// WhereIn adds a WHERE IN clause
func (q *Query) WhereIn(column string, count int) *Query {
if !q.hasWhere {
q.Write(" WHERE ")
q.hasWhere = true
} else {
q.Write(" AND ")
}
q.Write(column)
q.Write(" IN (")
q.Placeholders(count)
q.Write(")")
return q
}
// And adds an AND condition
func (q *Query) And(condition string) *Query {
q.Write(" AND ")
q.Write(condition)
return q
}
// Or adds an OR condition
func (q *Query) Or(condition string) *Query {
q.Write(" OR ")
q.Write(condition)
return q
}
// Join adds a JOIN clause
func (q *Query) Join(joinType JoinType, table, condition string) *Query {
q.Write(" ")
q.Write(string(joinType))
q.Write(" ")
q.Write(table)
q.Write(" ON ")
q.Write(condition)
return q
}
// OrderBy adds an ORDER BY clause
func (q *Query) OrderBy(columns ...string) *Query {
if !q.hasOrderBy {
q.Write(" ORDER BY ")
q.Write(strings.Join(columns, ", "))
q.hasOrderBy = true
}
return q
}
// GroupBy adds a GROUP BY clause
func (q *Query) GroupBy(columns ...string) *Query {
if !q.hasGroupBy {
q.Write(" GROUP BY ")
q.Write(strings.Join(columns, ", "))
q.hasGroupBy = true
}
return q
}
// Having adds a HAVING clause for filtering groups
func (q *Query) Having(condition string) *Query {
if !q.hasHaving {
q.Write(" HAVING ")
q.hasHaving = true
} else {
q.Write(" AND ")
}
q.Write(condition)
return q
}
// Limit adds a LIMIT clause
func (q *Query) Limit(limit int) *Query {
if !q.hasLimit {
q.Write(" LIMIT ")
q.Write(fmt.Sprintf("%d", limit))
q.hasLimit = true
}
return q
}
// Offset adds an OFFSET clause
func (q *Query) Offset(offset int) *Query {
if !q.hasOffset {
q.Write(" OFFSET ")
q.Write(fmt.Sprintf("%d", offset))
q.hasOffset = true
}
return q
}
// Insert starts an INSERT statement
func (q *Query) Insert(table string, columns ...string) *Query {
q.Write("INSERT INTO ")
q.Write(table)
q.Write(" (")
q.Write(strings.Join(columns, ", "))
q.Write(") VALUES ")
return q
}
// Values adds a VALUES clause
func (q *Query) Values(count int) *Query {
q.Write("(")
q.Placeholders(count)
q.Write(")")
return q
}
// Update starts an UPDATE statement
func (q *Query) Update(table string) *Query {
q.Write("UPDATE ")
q.Write(table)
q.Write(" SET ")
return q
}
// Set adds a SET clause for updates
func (q *Query) Set(column string) *Query {
if strings.Contains(q.builder.String(), "SET ") &&
!strings.HasSuffix(q.builder.String(), "SET ") {
q.Write(", ")
}
q.Write(column)
q.Write(" = ")
return q
}
// Delete starts a DELETE statement
func (q *Query) Delete() *Query {
q.Write("DELETE")
return q
}
// StartGroup starts a parenthetical group
func (q *Query) StartGroup() *Query {
if q.hasWhere {
q.Write(" AND (")
} else {
q.Write(" WHERE (")
q.hasWhere = true
}
q.parensDepth++
return q
}
// EndGroup ends a parenthetical group
func (q *Query) EndGroup() *Query {
if q.parensDepth > 0 {
q.Write(")")
q.parensDepth--
}
return q
}
// Returning adds a RETURNING clause for both PostgreSQL and SQLite (3.35.0+)
func (q *Query) Returning(columns ...string) *Query {
q.Write(" RETURNING ")
if len(columns) == 1 && columns[0] == "*" {
q.Write("*")
} else {
q.Write(strings.Join(columns, ", "))
}
return q
}
// Write adds a string to the query
func (q *Query) Write(s string) *Query {
q.builder.WriteString(s)
return q
}
// Placeholder adds a placeholder for a single argument
func (q *Query) Placeholder(arg any) *Query {
q.pos++
q.args = append(q.args, arg)
if q.dbType == DBTypePostgres {
q.builder.WriteString(fmt.Sprintf("$%d", q.pos))
} else {
q.builder.WriteString("?")
}
return q
}
// Placeholders adds n placeholders separated by commas
func (q *Query) Placeholders(n int) *Query {
placeholders := make([]string, n)
for i := range n {
q.pos++
if q.dbType == DBTypePostgres {
placeholders[i] = fmt.Sprintf("$%d", q.pos)
} else {
placeholders[i] = "?"
}
}
q.builder.WriteString(strings.Join(placeholders, ", "))
return q
}
func (q *Query) TimeSince(days int) *Query {
if q.dbType == DBTypePostgres {
q.builder.WriteString(fmt.Sprintf("NOW() - INTERVAL '%d days'", days))
} else {
q.builder.WriteString(fmt.Sprintf("datetime('now', '-%d days')", days))
}
return q
}
// AddArgs adds arguments to the query
func (q *Query) AddArgs(args ...any) *Query {
q.args = append(q.args, args...)
return q
}
// String returns the formatted query string
func (q *Query) String() string {
return q.builder.String()
}
// Args returns the query arguments
func (q *Query) Args() []any {
return q.args
}

View File

@@ -0,0 +1,856 @@
package db_test
import (
"reflect"
"testing"
"lemma/internal/db"
)
func TestNewQuery(t *testing.T) {
tests := []struct {
name string
dbType db.DBType
}{
{
name: "SQLite query",
dbType: db.DBTypeSQLite,
},
{
name: "Postgres query",
dbType: db.DBTypePostgres,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType, &mockSecrets{})
// Test that a new query is empty
if q.String() != "" {
t.Errorf("NewQuery() should return empty string, got %q", q.String())
}
if len(q.Args()) != 0 {
t.Errorf("NewQuery() should return empty args, got %v", q.Args())
}
// Test placeholder behavior - SQLite uses ? and Postgres uses $1
q.Write("test").Placeholder(1)
expectedPlaceholder := "?"
if tt.dbType == db.DBTypePostgres {
expectedPlaceholder = "$1"
}
if q.String() != "test"+expectedPlaceholder {
t.Errorf("Expected placeholder format %q for %s, got %q",
"test"+expectedPlaceholder, tt.name, q.String())
}
})
}
}
func TestBasicQueryBuilding(t *testing.T) {
tests := []struct {
name string
dbType db.DBType
buildFn func(*db.Query) *db.Query
wantSQL string
wantArgs []any
}{
{
name: "Simple select SQLite",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("id", "name").From("users")
},
wantSQL: "SELECT id, name FROM users",
wantArgs: []any{},
},
{
name: "Simple select Postgres",
dbType: db.DBTypePostgres,
buildFn: func(q *db.Query) *db.Query {
return q.Select("id", "name").From("users")
},
wantSQL: "SELECT id, name FROM users",
wantArgs: []any{},
},
{
name: "Select with where SQLite",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("id", "name").From("users").Where("id = ").Placeholder(1)
},
wantSQL: "SELECT id, name FROM users WHERE id = ?",
wantArgs: []any{1},
},
{
name: "Select with where Postgres",
dbType: db.DBTypePostgres,
buildFn: func(q *db.Query) *db.Query {
return q.Select("id", "name").From("users").Where("id = ").Placeholder(1)
},
wantSQL: "SELECT id, name FROM users WHERE id = $1",
wantArgs: []any{1},
},
{
name: "Multiple where conditions SQLite",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").From("users").
Where("active = ").Placeholder(true).
And("role = ").Placeholder("admin")
},
wantSQL: "SELECT * FROM users WHERE active = ? AND role = ?",
wantArgs: []any{true, "admin"},
},
{
name: "Multiple where conditions Postgres",
dbType: db.DBTypePostgres,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").From("users").
Where("active = ").Placeholder(true).
And("role = ").Placeholder("admin")
},
wantSQL: "SELECT * FROM users WHERE active = $1 AND role = $2",
wantArgs: []any{true, "admin"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q)
gotSQL := q.String()
gotArgs := q.Args()
if gotSQL != tt.wantSQL {
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
}
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
}
})
}
}
func TestPlaceholders(t *testing.T) {
tests := []struct {
name string
dbType db.DBType
buildFn func(*db.Query) *db.Query
wantSQL string
wantArgs []any
}{
{
name: "Single placeholder SQLite",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Write("SELECT * FROM users WHERE id = ").Placeholder(42)
},
wantSQL: "SELECT * FROM users WHERE id = ?",
wantArgs: []any{42},
},
{
name: "Single placeholder Postgres",
dbType: db.DBTypePostgres,
buildFn: func(q *db.Query) *db.Query {
return q.Write("SELECT * FROM users WHERE id = ").Placeholder(42)
},
wantSQL: "SELECT * FROM users WHERE id = $1",
wantArgs: []any{42},
},
{
name: "Multiple placeholders SQLite",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Write("SELECT * FROM users WHERE id = ").
Placeholder(42).
Write(" AND name = ").
Placeholder("John")
},
wantSQL: "SELECT * FROM users WHERE id = ? AND name = ?",
wantArgs: []any{42, "John"},
},
{
name: "Multiple placeholders Postgres",
dbType: db.DBTypePostgres,
buildFn: func(q *db.Query) *db.Query {
return q.Write("SELECT * FROM users WHERE id = ").
Placeholder(42).
Write(" AND name = ").
Placeholder("John")
},
wantSQL: "SELECT * FROM users WHERE id = $1 AND name = $2",
wantArgs: []any{42, "John"},
},
{
name: "Placeholders for IN SQLite",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Write("SELECT * FROM users WHERE id IN (").
Placeholders(3).
Write(")").
AddArgs(1, 2, 3)
},
wantSQL: "SELECT * FROM users WHERE id IN (?, ?, ?)",
wantArgs: []any{1, 2, 3},
},
{
name: "Placeholders for IN Postgres",
dbType: db.DBTypePostgres,
buildFn: func(q *db.Query) *db.Query {
return q.Write("SELECT * FROM users WHERE id IN (").
Placeholders(3).
Write(")").
AddArgs(1, 2, 3)
},
wantSQL: "SELECT * FROM users WHERE id IN ($1, $2, $3)",
wantArgs: []any{1, 2, 3},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q)
gotSQL := q.String()
gotArgs := q.Args()
if gotSQL != tt.wantSQL {
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
}
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
}
})
}
}
func TestWhereClauseBuilding(t *testing.T) {
tests := []struct {
name string
dbType db.DBType
buildFn func(*db.Query) *db.Query
wantSQL string
wantArgs []any
}{
{
name: "Simple where",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").From("users").Where("id = ").Placeholder(1)
},
wantSQL: "SELECT * FROM users WHERE id = ?",
wantArgs: []any{1},
},
{
name: "Where with And",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").From("users").
Where("id = ").Placeholder(1).
And("active = ").Placeholder(true)
},
wantSQL: "SELECT * FROM users WHERE id = ? AND active = ?",
wantArgs: []any{1, true},
},
{
name: "Where with Or",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").From("users").
Where("id = ").Placeholder(1).
Or("id = ").Placeholder(2)
},
wantSQL: "SELECT * FROM users WHERE id = ? OR id = ?",
wantArgs: []any{1, 2},
},
{
name: "Where with parentheses",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").From("users").
Where("active = ").Placeholder(true).
And("(").
Write("id = ").Placeholder(1).
Or("id = ").Placeholder(2).
Write(")")
},
wantSQL: "SELECT * FROM users WHERE active = ? AND (id = ? OR id = ?)",
wantArgs: []any{true, 1, 2},
},
{
name: "Where with StartGroup and EndGroup",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").From("users").
Where("active = ").Placeholder(true).
Write(" AND (").
Write("id = ").Placeholder(1).
Or("id = ").Placeholder(2).
Write(")")
},
wantSQL: "SELECT * FROM users WHERE active = ? AND (id = ? OR id = ?)",
wantArgs: []any{true, 1, 2},
},
{
name: "Where with nested groups",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").From("users").
Where("(").
Write("active = ").Placeholder(true).
Or("role = ").Placeholder("admin").
Write(")").
And("created_at > ").Placeholder("2020-01-01")
},
wantSQL: "SELECT * FROM users WHERE (active = ? OR role = ?) AND created_at > ?",
wantArgs: []any{true, "admin", "2020-01-01"},
},
{
name: "WhereIn",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").From("users").
WhereIn("id", 3).
AddArgs(1, 2, 3)
},
wantSQL: "SELECT * FROM users WHERE id IN (?, ?, ?)",
wantArgs: []any{1, 2, 3},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q)
gotSQL := q.String()
gotArgs := q.Args()
if gotSQL != tt.wantSQL {
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
}
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
}
})
}
}
func TestJoinClauseBuilding(t *testing.T) {
tests := []struct {
name string
dbType db.DBType
buildFn func(*db.Query) *db.Query
wantSQL string
wantArgs []any
}{
{
name: "Inner join",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("u.*", "w.name").
From("users u").
Join(db.InnerJoin, "workspaces w", "w.user_id = u.id")
},
wantSQL: "SELECT u.*, w.name FROM users u INNER JOIN workspaces w ON w.user_id = u.id",
wantArgs: []any{},
},
{
name: "Left join",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("u.*", "w.name").
From("users u").
Join(db.LeftJoin, "workspaces w", "w.user_id = u.id")
},
wantSQL: "SELECT u.*, w.name FROM users u LEFT JOIN workspaces w ON w.user_id = u.id",
wantArgs: []any{},
},
{
name: "Multiple joins",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("u.*", "w.name", "s.role").
From("users u").
Join(db.InnerJoin, "workspaces w", "w.user_id = u.id").
Join(db.LeftJoin, "settings s", "s.user_id = u.id")
},
wantSQL: "SELECT u.*, w.name, s.role FROM users u INNER JOIN workspaces w ON w.user_id = u.id LEFT JOIN settings s ON s.user_id = u.id",
wantArgs: []any{},
},
{
name: "Join with where",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("u.*", "w.name").
From("users u").
Join(db.InnerJoin, "workspaces w", "w.user_id = u.id").
Where("u.active = ").Placeholder(true)
},
wantSQL: "SELECT u.*, w.name FROM users u INNER JOIN workspaces w ON w.user_id = u.id WHERE u.active = ?",
wantArgs: []any{true},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q)
gotSQL := q.String()
gotArgs := q.Args()
if gotSQL != tt.wantSQL {
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
}
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
}
})
}
}
func TestOrderLimitOffset(t *testing.T) {
tests := []struct {
name string
dbType db.DBType
buildFn func(*db.Query) *db.Query
wantSQL string
wantArgs []any
}{
{
name: "Order by",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").From("users").OrderBy("name ASC")
},
wantSQL: "SELECT * FROM users ORDER BY name ASC",
wantArgs: []any{},
},
{
name: "Order by multiple columns",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").From("users").OrderBy("name ASC", "created_at DESC")
},
wantSQL: "SELECT * FROM users ORDER BY name ASC, created_at DESC",
wantArgs: []any{},
},
{
name: "Limit",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").From("users").Limit(10)
},
wantSQL: "SELECT * FROM users LIMIT 10",
wantArgs: []any{},
},
{
name: "Limit and offset",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").From("users").Limit(10).Offset(20)
},
wantSQL: "SELECT * FROM users LIMIT 10 OFFSET 20",
wantArgs: []any{},
},
{
name: "Complete query with all clauses",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("*").
From("users").
Where("active = ").Placeholder(true).
OrderBy("name ASC").
Limit(10).
Offset(20)
},
wantSQL: "SELECT * FROM users WHERE active = ? ORDER BY name ASC LIMIT 10 OFFSET 20",
wantArgs: []any{true},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q)
gotSQL := q.String()
gotArgs := q.Args()
if gotSQL != tt.wantSQL {
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
}
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
}
})
}
}
func TestInsertUpdateDelete(t *testing.T) {
tests := []struct {
name string
dbType db.DBType
buildFn func(*db.Query) *db.Query
wantSQL string
wantArgs []any
}{
{
name: "Insert SQLite",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Insert("users", "name", "email").
Values(2).
AddArgs("John", "john@example.com")
},
wantSQL: "INSERT INTO users (name, email) VALUES (?, ?)",
wantArgs: []any{"John", "john@example.com"},
},
{
name: "Insert Postgres",
dbType: db.DBTypePostgres,
buildFn: func(q *db.Query) *db.Query {
return q.Insert("users", "name", "email").
Values(2).
AddArgs("John", "john@example.com")
},
wantSQL: "INSERT INTO users (name, email) VALUES ($1, $2)",
wantArgs: []any{"John", "john@example.com"},
},
{
name: "Update SQLite",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Update("users").
Set("name").Placeholder("John").
Set("email").Placeholder("john@example.com").
Where("id = ").Placeholder(1)
},
wantSQL: "UPDATE users SET name = ?, email = ? WHERE id = ?",
wantArgs: []any{"John", "john@example.com", 1},
},
{
name: "Update Postgres",
dbType: db.DBTypePostgres,
buildFn: func(q *db.Query) *db.Query {
return q.Update("users").
Set("name").Placeholder("John").
Set("email").Placeholder("john@example.com").
Where("id = ").Placeholder(1)
},
wantSQL: "UPDATE users SET name = $1, email = $2 WHERE id = $3",
wantArgs: []any{"John", "john@example.com", 1},
},
{
name: "Delete SQLite",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Delete().From("users").Where("id = ").Placeholder(1)
},
wantSQL: "DELETE FROM users WHERE id = ?",
wantArgs: []any{1},
},
{
name: "Delete Postgres",
dbType: db.DBTypePostgres,
buildFn: func(q *db.Query) *db.Query {
return q.Delete().From("users").Where("id = ").Placeholder(1)
},
wantSQL: "DELETE FROM users WHERE id = $1",
wantArgs: []any{1},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q)
gotSQL := q.String()
gotArgs := q.Args()
if gotSQL != tt.wantSQL {
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
}
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
}
})
}
}
func TestHavingClause(t *testing.T) {
tests := []struct {
name string
dbType db.DBType
buildFn func(*db.Query) *db.Query
wantSQL string
wantArgs []any
}{
{
name: "Simple having",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("department", "COUNT(*) as count").
From("employees").
GroupBy("department").
Having("count > ").Placeholder(5)
},
wantSQL: "SELECT department, COUNT(*) as count FROM employees GROUP BY department HAVING count > ?",
wantArgs: []any{5},
},
{
name: "Having with multiple conditions",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("department", "AVG(salary) as avg_salary").
From("employees").
GroupBy("department").
Having("avg_salary > ").Placeholder(50000).
And("COUNT(*) > ").Placeholder(3)
},
wantSQL: "SELECT department, AVG(salary) as avg_salary FROM employees GROUP BY department HAVING avg_salary > ? AND COUNT(*) > ?",
wantArgs: []any{50000, 3},
},
{
name: "Having with postgres placeholders",
dbType: db.DBTypePostgres,
buildFn: func(q *db.Query) *db.Query {
return q.Select("department", "COUNT(*) as count").
From("employees").
GroupBy("department").
Having("count > ").Placeholder(5)
},
wantSQL: "SELECT department, COUNT(*) as count FROM employees GROUP BY department HAVING count > $1",
wantArgs: []any{5},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q)
gotSQL := q.String()
gotArgs := q.Args()
if gotSQL != tt.wantSQL {
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
}
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
}
})
}
}
func TestQueryReturning(t *testing.T) {
testCases := []struct {
name string
dbType db.DBType
buildQuery func(q *db.Query) *db.Query
expectedSQL string
expectedArgs []any
}{
{
name: "SQLite INSERT with RETURNING single column",
dbType: db.DBTypeSQLite,
buildQuery: func(q *db.Query) *db.Query {
return q.Insert("users", "name", "email").
Values(2).
AddArgs("John Doe", "john@example.com").
Returning("id")
},
expectedSQL: "INSERT INTO users (name, email) VALUES (?, ?) RETURNING id",
expectedArgs: []any{"John Doe", "john@example.com"},
},
{
name: "PostgreSQL INSERT with RETURNING single column",
dbType: db.DBTypePostgres,
buildQuery: func(q *db.Query) *db.Query {
return q.Insert("users", "name", "email").
Values(2).
AddArgs("John Doe", "john@example.com").
Returning("id")
},
expectedSQL: "INSERT INTO users (name, email) VALUES ($1, $2) RETURNING id",
expectedArgs: []any{"John Doe", "john@example.com"},
},
{
name: "SQLite INSERT with RETURNING multiple columns",
dbType: db.DBTypeSQLite,
buildQuery: func(q *db.Query) *db.Query {
return q.Insert("users", "name", "email").
Values(2).
AddArgs("John Doe", "john@example.com").
Returning("id", "created_at")
},
expectedSQL: "INSERT INTO users (name, email) VALUES (?, ?) RETURNING id, created_at",
expectedArgs: []any{"John Doe", "john@example.com"},
},
{
name: "PostgreSQL INSERT with RETURNING multiple columns",
dbType: db.DBTypePostgres,
buildQuery: func(q *db.Query) *db.Query {
return q.Insert("users", "name", "email").
Values(2).
AddArgs("John Doe", "john@example.com").
Returning("id", "created_at")
},
expectedSQL: "INSERT INTO users (name, email) VALUES ($1, $2) RETURNING id, created_at",
expectedArgs: []any{"John Doe", "john@example.com"},
},
{
name: "SQLite UPDATE with RETURNING",
dbType: db.DBTypeSQLite,
buildQuery: func(q *db.Query) *db.Query {
return q.Update("users").
Set("name").Placeholder("Jane Doe").
Where("id = ").Placeholder(1).
Returning("name", "updated_at")
},
expectedSQL: "UPDATE users SET name = ? WHERE id = ? RETURNING name, updated_at",
expectedArgs: []any{"Jane Doe", 1},
},
{
name: "PostgreSQL UPDATE with RETURNING",
dbType: db.DBTypePostgres,
buildQuery: func(q *db.Query) *db.Query {
return q.Update("users").
Set("name").Placeholder("Jane Doe").
Where("id = ").Placeholder(1).
Returning("name", "updated_at")
},
expectedSQL: "UPDATE users SET name = $1 WHERE id = $2 RETURNING name, updated_at",
expectedArgs: []any{"Jane Doe", 1},
},
{
name: "SQLite DELETE with RETURNING",
dbType: db.DBTypeSQLite,
buildQuery: func(q *db.Query) *db.Query {
return q.Delete().
From("users").
Where("id = ").Placeholder(1).
Returning("id", "name", "email")
},
expectedSQL: "DELETE FROM users WHERE id = ? RETURNING id, name, email",
expectedArgs: []any{1},
},
{
name: "PostgreSQL DELETE with RETURNING",
dbType: db.DBTypePostgres,
buildQuery: func(q *db.Query) *db.Query {
return q.Delete().
From("users").
Where("id = ").Placeholder(1).
Returning("id", "name", "email")
},
expectedSQL: "DELETE FROM users WHERE id = $1 RETURNING id, name, email",
expectedArgs: []any{1},
},
{
name: "SQLite INSERT with RETURNING *",
dbType: db.DBTypeSQLite,
buildQuery: func(q *db.Query) *db.Query {
return q.Insert("users", "name", "email").
Values(2).
AddArgs("John Doe", "john@example.com").
Returning("*")
},
expectedSQL: "INSERT INTO users (name, email) VALUES (?, ?) RETURNING *",
expectedArgs: []any{"John Doe", "john@example.com"},
},
{
name: "PostgreSQL INSERT with RETURNING *",
dbType: db.DBTypePostgres,
buildQuery: func(q *db.Query) *db.Query {
return q.Insert("users", "name", "email").
Values(2).
AddArgs("John Doe", "john@example.com").
Returning("*")
},
expectedSQL: "INSERT INTO users (name, email) VALUES ($1, $2) RETURNING *",
expectedArgs: []any{"John Doe", "john@example.com"},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
query := db.NewQuery(tc.dbType, &mockSecrets{})
result := tc.buildQuery(query)
if result.String() != tc.expectedSQL {
t.Errorf("Expected SQL: %s, got: %s", tc.expectedSQL, result.String())
}
if len(result.Args()) != len(tc.expectedArgs) {
t.Errorf("Expected %d args, got %d", len(tc.expectedArgs), len(result.Args()))
}
for i, arg := range result.Args() {
if arg != tc.expectedArgs[i] {
t.Errorf("Expected arg %d to be %v, got %v", i, tc.expectedArgs[i], arg)
}
}
})
}
}
func TestComplexQueries(t *testing.T) {
tests := []struct {
name string
dbType db.DBType
buildFn func(*db.Query) *db.Query
wantSQL string
wantArgs []any
}{
{
name: "Complex select with join and where",
dbType: db.DBTypeSQLite,
buildFn: func(q *db.Query) *db.Query {
return q.Select("u.id", "u.name", "COUNT(w.id) as workspace_count").
From("users u").
Join(db.LeftJoin, "workspaces w", "w.user_id = u.id").
Where("u.active = ").Placeholder(true).
GroupBy("u.id", "u.name").
Having("COUNT(w.id) > ").Placeholder(0).
OrderBy("workspace_count DESC").
Limit(10)
},
wantSQL: "SELECT u.id, u.name, COUNT(w.id) as workspace_count FROM users u LEFT JOIN workspaces w ON w.user_id = u.id WHERE u.active = ? GROUP BY u.id, u.name HAVING COUNT(w.id) > ? ORDER BY workspace_count DESC LIMIT 10",
wantArgs: []any{true, 0},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
q := db.NewQuery(tt.dbType, &mockSecrets{})
q = tt.buildFn(q)
gotSQL := q.String()
gotArgs := q.Args()
if gotSQL != tt.wantSQL {
t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL)
}
if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs)
}
})
}
}

View File

@@ -10,11 +10,12 @@ import (
// CreateSession inserts a new session record into the database
func (db *database) CreateSession(session *models.Session) error {
_, err := db.Exec(`
INSERT INTO sessions (id, user_id, refresh_token, expires_at, created_at)
VALUES (?, ?, ?, ?, ?)`,
session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt,
)
query, err := db.NewQuery().
InsertStruct(session, "sessions")
if err != nil {
return fmt.Errorf("failed to create query: %w", err)
}
_, err = db.Exec(query.String(), query.Args()...)
if err != nil {
return fmt.Errorf("failed to store session: %w", err)
}
@@ -25,13 +26,18 @@ func (db *database) CreateSession(session *models.Session) error {
// GetSessionByRefreshToken retrieves a session by its refresh token
func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) {
session := &models.Session{}
err := db.QueryRow(`
SELECT id, user_id, refresh_token, expires_at, created_at
FROM sessions
WHERE refresh_token = ? AND expires_at > ?`,
refreshToken, time.Now(),
).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt)
query := db.NewQuery()
query, err := query.SelectStruct(session, "sessions")
if err != nil {
return nil, fmt.Errorf("failed to create query: %w", err)
}
query = query.Where("refresh_token = ").
Placeholder(refreshToken).
And("expires_at >").
Placeholder(time.Now())
row := db.QueryRow(query.String(), query.Args()...)
err = db.ScanStruct(row, session)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("session not found or expired")
}
@@ -45,13 +51,18 @@ func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Sessi
// GetSessionByID retrieves a session by its ID
func (db *database) GetSessionByID(sessionID string) (*models.Session, error) {
session := &models.Session{}
err := db.QueryRow(`
SELECT id, user_id, refresh_token, expires_at, created_at
FROM sessions
WHERE id = ? AND expires_at > ?`,
sessionID, time.Now(),
).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt)
query := db.NewQuery()
query, err := query.SelectStruct(session, "sessions")
if err != nil {
return nil, fmt.Errorf("failed to create query: %w", err)
}
query = query.Where("id = ").
Placeholder(sessionID).
And("expires_at >").
Placeholder(time.Now())
row := db.QueryRow(query.String(), query.Args()...)
err = db.ScanStruct(row, session)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("session not found")
}
@@ -64,7 +75,13 @@ func (db *database) GetSessionByID(sessionID string) (*models.Session, error) {
// DeleteSession removes a session from the database
func (db *database) DeleteSession(sessionID string) error {
result, err := db.Exec("DELETE FROM sessions WHERE id = ?", sessionID)
query := db.NewQuery().
Delete().
From("sessions").
Where("id = ").
Placeholder(sessionID)
result, err := db.Exec(query.String(), query.Args()...)
if err != nil {
return fmt.Errorf("failed to delete session: %w", err)
}
@@ -84,7 +101,12 @@ func (db *database) DeleteSession(sessionID string) error {
// CleanExpiredSessions removes all expired sessions from the database
func (db *database) CleanExpiredSessions() error {
log := getLogger().WithGroup("sessions")
result, err := db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now())
query := db.NewQuery().
Delete().
From("sessions").
Where("expires_at <=").
Placeholder(time.Now())
result, err := db.Exec(query.String(), query.Args()...)
if err != nil {
return fmt.Errorf("failed to clean expired sessions: %w", err)
}

View File

@@ -13,7 +13,7 @@ import (
)
func TestSessionOperations(t *testing.T) {
database, err := db.NewTestDB(":memory:", &mockSecrets{})
database, err := db.NewTestSQLiteDB(&mockSecrets{})
if err != nil {
t.Fatalf("failed to create test database: %v", err)
}

View File

@@ -0,0 +1,341 @@
package db
import (
"database/sql"
"fmt"
"reflect"
"sort"
"strings"
"unicode"
)
type DBField struct {
Name string
Value any
Type reflect.Type
OriginalName string
useDefault bool
encrypted bool
}
// StructTagsToFields converts a struct to a slice of DBField instances
func StructTagsToFields(s any) ([]DBField, error) {
v := reflect.ValueOf(s)
if v.Kind() == reflect.Ptr {
if v.IsNil() {
return nil, fmt.Errorf("nil pointer provided")
}
v = v.Elem()
}
if v.Kind() != reflect.Struct {
return nil, fmt.Errorf("provided value is %s, expected struct", v.Kind())
}
t := v.Type()
fields := make([]DBField, 0, t.NumField())
for i := range t.NumField() {
f := t.Field(i)
if !f.IsExported() {
continue
}
tag := f.Tag.Get("db")
if tag == "-" {
continue
}
if tag == "" {
tag = toSnakeCase(f.Name)
}
useDefault := false
encrypted := false
ommit := false
if strings.Contains(tag, ",") {
parts := strings.Split(tag, ",")
tag = parts[0]
for _, opt := range parts[1:] {
switch opt {
case "omitempty":
if reflect.DeepEqual(v.Field(i).Interface(), reflect.Zero(f.Type).Interface()) {
ommit = true
}
case "default":
useDefault = true
case "encrypted":
encrypted = true
}
}
}
if ommit {
continue
}
fields = append(fields, DBField{
Name: tag,
Value: v.Field(i).Interface(),
Type: f.Type,
OriginalName: f.Name,
useDefault: useDefault,
encrypted: encrypted,
})
}
sort.Slice(fields, func(i, j int) bool {
return fields[i].Name < fields[j].Name
})
return fields, nil
}
func toSnakeCase(s string) string {
var res string
for i, r := range s {
if unicode.IsUpper(r) {
if i > 0 {
res += "_"
}
res += string(unicode.ToLower(r))
} else {
res += string(r)
}
}
return res
}
// InsertStruct creates an INSERT query from a struct
func (q *Query) InsertStruct(s any, table string) (*Query, error) {
fields, err := StructTagsToFields(s)
if err != nil {
return nil, err
}
columns := make([]string, 0, len(fields))
values := make([]any, 0, len(fields))
for _, f := range fields {
value := f.Value
if f.useDefault {
continue
}
if f.encrypted {
encValue, err := q.secretsService.Encrypt(value.(string))
if err != nil {
return nil, err
}
value = encValue
}
columns = append(columns, f.Name)
values = append(values, value)
}
if len(columns) == 0 {
return nil, fmt.Errorf("no columns to insert")
}
q.Insert(table, columns...).Values(len(columns)).AddArgs(values...)
return q, nil
}
// UpdateStruct creates an UPDATE query from a struct
func (q *Query) UpdateStruct(s any, table string) (*Query, error) {
fields, err := StructTagsToFields(s)
if err != nil {
return nil, err
}
q = q.Update(table)
for _, f := range fields {
value := f.Value
if f.useDefault {
continue
}
if f.encrypted {
encValue, err := q.secretsService.Encrypt(value.(string))
if err != nil {
return nil, err
}
value = encValue
}
q = q.Set(f.Name).Placeholder(value)
}
return q, nil
}
// SelectStruct creates a SELECT query from a struct
func (q *Query) SelectStruct(s any, table string) (*Query, error) {
fields, err := StructTagsToFields(s)
if err != nil {
return nil, err
}
columns := make([]string, 0, len(fields))
for _, f := range fields {
columns = append(columns, f.Name)
}
q = q.Select(columns...).From(table)
return q, nil
}
// Scanner is an interface that both sql.Row and sql.Rows satisfy
type Scanner interface {
Scan(dest ...any) error
}
// scanStructInstance is an internal function that handles the scanning logic for a single instance
func (db *database) scanStructInstance(destVal reflect.Value, scanner Scanner) error {
fields, err := StructTagsToFields(destVal.Interface())
if err != nil {
return fmt.Errorf("failed to extract struct fields: %w", err)
}
scanDest := make([]any, len(fields))
var fieldsToDecrypt []string
nullStringIndexes := make(map[int]reflect.Value)
for i, field := range fields {
// Find the field in the struct
structField := destVal.FieldByName(field.OriginalName)
if !structField.IsValid() {
return fmt.Errorf("struct field %s not found", field.OriginalName)
}
if field.encrypted {
fieldsToDecrypt = append(fieldsToDecrypt, field.OriginalName)
}
if structField.Kind() == reflect.String {
// Handle null strings separately
nullStringIndexes[i] = structField
var ns sql.NullString
scanDest[i] = &ns
} else {
scanDest[i] = structField.Addr().Interface()
}
}
// Scan using the scanner interface
if err := scanner.Scan(scanDest...); err != nil {
return err
}
// Set null strings to their values if they are valid
for i, field := range nullStringIndexes {
ns := scanDest[i].(*sql.NullString)
if ns.Valid {
field.SetString(ns.String)
}
}
// Decrypt encrypted fields
for _, fieldName := range fieldsToDecrypt {
field := destVal.FieldByName(fieldName)
if !field.IsZero() {
decValue, err := db.secretsService.Decrypt(field.Interface().(string))
if err != nil {
return err
}
field.SetString(decValue)
}
}
return nil
}
// ScanStruct scans a single row into a struct
func (db *database) ScanStruct(row *sql.Row, dest any) error {
if row == nil {
return fmt.Errorf("row cannot be nil")
}
if row.Err() != nil {
return row.Err()
}
// Get the destination value
destVal := reflect.ValueOf(dest)
if destVal.Kind() != reflect.Ptr || destVal.IsNil() {
return fmt.Errorf("destination must be a non-nil pointer to a struct, got %T", dest)
}
destVal = destVal.Elem()
if destVal.Kind() != reflect.Struct {
return fmt.Errorf("destination must be a pointer to a struct, got pointer to %s", destVal.Kind())
}
return db.scanStructInstance(destVal, row)
}
// ScanStructs scans multiple rows into a slice of structs
func (db *database) ScanStructs(rows *sql.Rows, destSlice any) error {
if rows == nil {
return fmt.Errorf("rows cannot be nil")
}
// Get the slice value and element type
sliceVal := reflect.ValueOf(destSlice)
if sliceVal.Kind() != reflect.Ptr || sliceVal.IsNil() {
return fmt.Errorf("destination must be a non-nil pointer to a slice, got %T", destSlice)
}
sliceVal = sliceVal.Elem()
if sliceVal.Kind() != reflect.Slice {
return fmt.Errorf("destination must be a pointer to a slice, got pointer to %s", sliceVal.Kind())
}
// Get the element type of the slice
elemType := sliceVal.Type().Elem()
// Check if we have a direct struct type or a pointer to struct
isPtr := elemType.Kind() == reflect.Ptr
structType := elemType
if isPtr {
structType = elemType.Elem()
}
if structType.Kind() != reflect.Struct {
return fmt.Errorf("slice element type must be a struct or pointer to struct, got %s", elemType.String())
}
// Process each row
for rows.Next() {
// Create a new instance of the struct for each row
newElem := reflect.New(structType).Elem()
// Scan this row into the new element
if err := db.scanStructInstance(newElem, rows); err != nil {
return err
}
// Add the new element to the result slice
if isPtr {
sliceVal.Set(reflect.Append(sliceVal, newElem.Addr()))
} else {
sliceVal.Set(reflect.Append(sliceVal, newElem))
}
}
// Check for errors from iterating over rows
if err := rows.Err(); err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,507 @@
package db_test
import (
"reflect"
"testing"
"time"
"lemma/internal/db"
"lemma/internal/models"
_ "lemma/internal/testenv"
)
// TestStructTagsToFields tests the exported StructTagsToFields function
func TestStructTagsToFields(t *testing.T) {
type testStruct struct {
ID int `db:"id"`
Name string `db:"custom_name"`
CreatedAt time.Time `db:"created_at,default"`
Skip string `db:"-"`
Empty string `db:"empty,omitempty"`
Secret string `db:"secret,encrypted"`
NoTag string
}
tests := []struct {
name string
input interface{}
wantFields int
wantErr bool
}{
{
name: "valid struct",
input: testStruct{
ID: 1,
Name: "Test",
CreatedAt: time.Now(),
Skip: "skip me",
Secret: "secret value",
NoTag: "no tag",
},
wantFields: 5, // ID, Name, CreatedAt, Secret, NoTag (Empty is omitted)
wantErr: false,
},
{
name: "nil pointer",
input: (*testStruct)(nil),
wantFields: 0,
wantErr: true,
},
{
name: "non-struct",
input: "not a struct",
wantFields: 0,
wantErr: true,
},
{
name: "struct pointer",
input: &testStruct{
ID: 2,
Name: "Test Pointer",
},
wantFields: 5, // Same fields as above
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
fields, err := db.StructTagsToFields(tt.input)
if tt.wantErr {
if err == nil {
t.Error("Expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if len(fields) != tt.wantFields {
t.Errorf("Expected %d fields, got %d", tt.wantFields, len(fields))
}
// Check specific field handling for valid struct test
if tt.name == "valid struct" {
// Find fields by name
var idField, nameField, createdAtField, secretField, emptyField, noTagField *db.DBField
for i := range fields {
f := &fields[i]
switch f.Name {
case "id":
idField = f
case "custom_name":
nameField = f
case "created_at":
createdAtField = f
case "secret":
secretField = f
case "empty":
emptyField = f
case "no_tag":
noTagField = f
}
}
// Check fields exist
if idField == nil {
t.Error("ID field not found")
}
if nameField == nil {
t.Error("Name field not found")
}
if createdAtField == nil {
t.Error("CreatedAt field not found")
}
if secretField == nil {
t.Error("Secret field not found")
}
if noTagField == nil {
t.Error("NoTag field not found")
}
if emptyField != nil {
t.Error("Empty field should be omitted")
}
// Check original names
if idField != nil && idField.OriginalName != "ID" {
t.Errorf("Expected OriginalName 'ID', got '%s'", idField.OriginalName)
}
if nameField != nil && nameField.OriginalName != "Name" {
t.Errorf("Expected OriginalName 'Name', got '%s'", nameField.OriginalName)
}
}
})
}
}
// TestStructQueries tests the struct-based query methods using the test database
func TestStructQueries(t *testing.T) {
// Setup test database
database, err := db.NewTestSQLiteDB(&mockSecrets{})
if err != nil {
t.Fatalf("Failed to create test database: %v", err)
}
defer database.Close()
if err := database.Migrate(); err != nil {
t.Fatalf("Failed to run migrations: %v", err)
}
// Define test data
user := &models.User{
Email: "structquery@example.com",
DisplayName: "Struct Query Test",
PasswordHash: "hashed_password",
Role: models.RoleEditor,
}
t.Run("InsertStructQuery", func(t *testing.T) {
// Insert user with struct query
createdUser, err := database.CreateUser(user)
if err != nil {
t.Fatalf("Failed to create user with struct query: %v", err)
}
// Verify user was created with proper values
if createdUser.ID == 0 {
t.Error("Expected non-zero user ID")
}
if createdUser.Email != user.Email {
t.Errorf("Email = %v, want %v", createdUser.Email, user.Email)
}
if createdUser.DisplayName != user.DisplayName {
t.Errorf("DisplayName = %v, want %v", createdUser.DisplayName, user.DisplayName)
}
if createdUser.Role != user.Role {
t.Errorf("Role = %v, want %v", createdUser.Role, user.Role)
}
// We will use this user for the next test cases
user = createdUser
})
t.Run("SelectStructQuery", func(t *testing.T) {
// Get the created user
fetchedUser, err := database.GetUserByID(user.ID)
if err != nil {
t.Fatalf("Failed to get user with struct query: %v", err)
}
// Verify fetched user matches the original
if fetchedUser.ID != user.ID {
t.Errorf("ID = %v, want %v", fetchedUser.ID, user.ID)
}
if fetchedUser.Email != user.Email {
t.Errorf("Email = %v, want %v", fetchedUser.Email, user.Email)
}
if fetchedUser.DisplayName != user.DisplayName {
t.Errorf("DisplayName = %v, want %v", fetchedUser.DisplayName, user.DisplayName)
}
if fetchedUser.Role != user.Role {
t.Errorf("Role = %v, want %v", fetchedUser.Role, user.Role)
}
})
t.Run("UpdateStructQuery", func(t *testing.T) {
// Update the user
user.DisplayName = "Updated Display Name"
user.Role = models.RoleAdmin
err := database.UpdateUser(user)
if err != nil {
t.Fatalf("Failed to update user with struct query: %v", err)
}
// Verify update worked
updatedUser, err := database.GetUserByID(user.ID)
if err != nil {
t.Fatalf("Failed to get updated user: %v", err)
}
if updatedUser.DisplayName != "Updated Display Name" {
t.Errorf("DisplayName = %v, want %v", updatedUser.DisplayName, "Updated Display Name")
}
if updatedUser.Role != models.RoleAdmin {
t.Errorf("Role = %v, want %v", updatedUser.Role, models.RoleAdmin)
}
})
t.Run("ScanStructs", func(t *testing.T) {
// Create another user to test multiple rows
secondUser := &models.User{
Email: "structquery2@example.com",
DisplayName: "Struct Query Test 2",
PasswordHash: "hashed_password2",
Role: models.RoleViewer,
}
createdUser2, err := database.CreateUser(secondUser)
if err != nil {
t.Fatalf("Failed to create second user: %v", err)
}
// Get all users
users, err := database.GetAllUsers()
if err != nil {
t.Fatalf("Failed to get all users: %v", err)
}
// Verify we have at least the two users we created
if len(users) < 2 {
t.Errorf("Expected at least 2 users, got %d", len(users))
}
// Check if both our users are in the result
foundUser1 := false
foundUser2 := false
for _, u := range users {
if u.ID == user.ID {
foundUser1 = true
if u.DisplayName != user.DisplayName {
t.Errorf("DisplayName = %v, want %v", u.DisplayName, user.DisplayName)
}
}
if u.ID == createdUser2.ID {
foundUser2 = true
if u.DisplayName != secondUser.DisplayName {
t.Errorf("DisplayName = %v, want %v", u.DisplayName, secondUser.DisplayName)
}
}
}
if !foundUser1 {
t.Errorf("First user (ID: %d) not found in results", user.ID)
}
if !foundUser2 {
t.Errorf("Second user (ID: %d) not found in results", createdUser2.ID)
}
})
t.Run("ScanStruct with null values", func(t *testing.T) {
// Test handling of NULL values by creating a workspace with null values
workspace := &models.Workspace{
UserID: user.ID,
Name: "Null Test Workspace",
// Leave all optional fields as zero values
}
workspace.SetDefaultSettings() // This will set default values
err := database.CreateWorkspace(workspace)
if err != nil {
t.Fatalf("Failed to create test workspace: %v", err)
}
// Clear the GitToken to test NULL handling
testDB := database.TestDB()
_, err = testDB.Exec("UPDATE workspaces SET git_token = NULL WHERE id = ?", workspace.ID)
if err != nil {
t.Fatalf("Failed to set git_token to NULL: %v", err)
}
// Fetch the workspace with NULL field
fetchedWorkspace, err := database.GetWorkspaceByID(workspace.ID)
if err != nil {
t.Fatalf("Failed to get workspace with NULL field: %v", err)
}
// Verify the NULL field is empty
if fetchedWorkspace.GitToken != "" {
t.Errorf("Expected empty GitToken, got '%s'", fetchedWorkspace.GitToken)
}
})
t.Run("ScanStructErrors", func(t *testing.T) {
// Test error handling in ScanStruct
testDB := database.TestDB()
// Attempt to scan too many columns into a struct with fewer fields
row := testDB.QueryRow("SELECT 1, 2, 3")
var singleField struct {
One int `db:"one"`
}
err := database.ScanStruct(row, &singleField)
if err == nil {
t.Error("Expected error when scanning too many columns, got nil")
}
// Test scanning into a non-struct
var notAStruct int
row = testDB.QueryRow("SELECT 1")
err = database.ScanStruct(row, &notAStruct)
if err == nil {
t.Error("Expected error when scanning into non-struct, got nil")
}
// Test scanning into nil
var nilPtr *struct{}
row = testDB.QueryRow("SELECT 1")
err = database.ScanStruct(row, nilPtr)
if err == nil {
t.Error("Expected error when scanning into nil pointer, got nil")
}
})
}
// TestScanStructsErrors tests error handling for ScanStructs
func TestScanStructsErrors(t *testing.T) {
database, err := db.NewTestSQLiteDB(&mockSecrets{})
if err != nil {
t.Fatalf("Failed to create test database: %v", err)
}
defer database.Close()
if err := database.Migrate(); err != nil {
t.Fatalf("Failed to run migrations: %v", err)
}
testDB := database.TestDB()
t.Run("ScanStructsWithNilRows", func(t *testing.T) {
var users []*models.User
err := database.ScanStructs(nil, &users)
if err == nil {
t.Error("Expected error with nil rows, got nil")
}
})
t.Run("ScanStructsWithNilDest", func(t *testing.T) {
rows, err := testDB.Query("SELECT 1")
if err != nil {
t.Fatalf("Failed to execute query: %v", err)
}
defer rows.Close()
var nilSlice *[]*models.User
err = database.ScanStructs(rows, nilSlice)
if err == nil {
t.Error("Expected error with nil destination, got nil")
}
})
t.Run("ScanStructsWithNonSlice", func(t *testing.T) {
rows, err := testDB.Query("SELECT 1")
if err != nil {
t.Fatalf("Failed to execute query: %v", err)
}
defer rows.Close()
var nonSlice int
err = database.ScanStructs(rows, &nonSlice)
if err == nil {
t.Error("Expected error with non-slice destination, got nil")
}
})
t.Run("ScanStructsWithNonStructSlice", func(t *testing.T) {
rows, err := testDB.Query("SELECT 1")
if err != nil {
t.Fatalf("Failed to execute query: %v", err)
}
defer rows.Close()
var intSlice []int
err = database.ScanStructs(rows, &intSlice)
if err == nil {
t.Error("Expected error with non-struct slice, got nil")
}
})
}
// TestEncryptedFields tests handling of encrypted fields
func TestEncryptedFields(t *testing.T) {
database, err := db.NewTestSQLiteDB(&mockSecrets{})
if err != nil {
t.Fatalf("Failed to create test database: %v", err)
}
defer database.Close()
if err := database.Migrate(); err != nil {
t.Fatalf("Failed to run migrations: %v", err)
}
// Create user with workspace that has encrypted token
user, err := database.CreateUser(&models.User{
Email: "encrypted@example.com",
DisplayName: "Encryption Test",
PasswordHash: "hash",
Role: models.RoleEditor,
})
if err != nil {
t.Fatalf("Failed to create test user: %v", err)
}
// Create workspace with encrypted field
workspace := &models.Workspace{
UserID: user.ID,
Name: "Encryption Test",
Theme: "dark",
GitEnabled: true,
GitURL: "https://github.com/user/repo",
GitUser: "username",
GitToken: "secret-token", // This field is encrypted
GitCommitName: "Test User",
GitCommitEmail: "test@example.com",
}
if err := database.CreateWorkspace(workspace); err != nil {
t.Fatalf("Failed to create test workspace: %v", err)
}
// Verify our mock secrets service passed the token through unmodified
// In a real application, the token would be encrypted in the database
testDB := database.TestDB()
var rawToken string
err = testDB.QueryRow("SELECT git_token FROM workspaces WHERE id = ?", workspace.ID).Scan(&rawToken)
if err != nil {
t.Fatalf("Failed to query raw token: %v", err)
}
// With the mock secrets service, encryption is a no-op so the token is stored as-is
if rawToken != "secret-token" {
t.Errorf("Expected raw token 'secret-token', got '%s'", rawToken)
}
// Verify the fetched workspace has the correct token
fetchedWorkspace, err := database.GetWorkspaceByID(workspace.ID)
if err != nil {
t.Fatalf("Failed to get workspace: %v", err)
}
if fetchedWorkspace.GitToken != "secret-token" {
t.Errorf("Expected GitToken 'secret-token', got '%s'", fetchedWorkspace.GitToken)
}
}
// Helper function to compare slices of DBFields
func compareDBFields(t *testing.T, got, want []db.DBField) {
t.Helper()
if len(got) != len(want) {
t.Errorf("Got %d fields, want %d", len(got), len(want))
return
}
for i := range got {
if got[i].Name != want[i].Name {
t.Errorf("Field %d name: got %s, want %s", i, got[i].Name, want[i].Name)
}
if got[i].OriginalName != want[i].OriginalName {
t.Errorf("Field %d original name: got %s, want %s", i, got[i].OriginalName, want[i].OriginalName)
}
if !reflect.DeepEqual(got[i].Value, want[i].Value) {
t.Errorf("Field %d value: got %v, want %v", i, got[i].Value, want[i].Value)
}
}
}

View File

@@ -49,7 +49,12 @@ func (db *database) EnsureJWTSecret() (string, error) {
// GetSystemSetting retrieves a system setting by key
func (db *database) GetSystemSetting(key string) (string, error) {
var value string
err := db.QueryRow("SELECT value FROM system_settings WHERE key = ?", key).Scan(&value)
query := db.NewQuery().
Select("value").
From("system_settings").
Where("key = ").
Placeholder(key)
err := db.QueryRow(query.String(), query.args...).Scan(&value)
if err != nil {
return "", err
}
@@ -59,11 +64,14 @@ func (db *database) GetSystemSetting(key string) (string, error) {
// SetSystemSetting stores or updates a system setting
func (db *database) SetSystemSetting(key, value string) error {
_, err := db.Exec(`
INSERT INTO system_settings (key, value)
VALUES (?, ?)
ON CONFLICT(key) DO UPDATE SET value = ?`,
key, value, value)
query := db.NewQuery().
Insert("system_settings", "key", "value").
Values(2).
AddArgs(key, value).
Write("ON CONFLICT(key) DO UPDATE SET value = ").
Placeholder(value)
_, err := db.Exec(query.String(), query.args...)
if err != nil {
return fmt.Errorf("failed to store system setting: %w", err)
@@ -92,22 +100,30 @@ func (db *database) GetSystemStats() (*UserStats, error) {
stats := &UserStats{}
// Get total users
err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&stats.TotalUsers)
query := db.NewQuery().
Select("COUNT(*)").
From("users")
err := db.QueryRow(query.String()).Scan(&stats.TotalUsers)
if err != nil {
return nil, fmt.Errorf("failed to get total users count: %w", err)
}
// Get total workspaces
err = db.QueryRow("SELECT COUNT(*) FROM workspaces").Scan(&stats.TotalWorkspaces)
query = db.NewQuery().
Select("COUNT(*)").
From("workspaces")
err = db.QueryRow(query.String()).Scan(&stats.TotalWorkspaces)
if err != nil {
return nil, fmt.Errorf("failed to get total workspaces count: %w", err)
}
// Get active users (users with activity in last 30 days)
err = db.QueryRow(`
SELECT COUNT(DISTINCT user_id)
FROM sessions
WHERE created_at > datetime('now', '-30 days')`).
query = db.NewQuery().
Select("COUNT(DISTINCT user_id)").
From("sessions").
Where("created_at >").
TimeSince(30)
err = db.QueryRow(query.String()).
Scan(&stats.ActiveUsers)
if err != nil {
return nil, fmt.Errorf("failed to get active users count: %w", err)

View File

@@ -15,7 +15,7 @@ import (
)
func TestSystemOperations(t *testing.T) {
database, err := db.NewTestDB(":memory:", &mockSecrets{})
database, err := db.NewTestSQLiteDB(&mockSecrets{})
if err != nil {
t.Fatalf("failed to create test database: %v", err)
}

View File

@@ -4,7 +4,11 @@ package db
import (
"database/sql"
"fmt"
"lemma/internal/secrets"
"log"
"strings"
"time"
)
type TestDatabase interface {
@@ -12,19 +16,102 @@ type TestDatabase interface {
TestDB() *sql.DB
}
func NewTestDB(dbPath string, secretsService secrets.Service) (TestDatabase, error) {
db, err := Init(dbPath, secretsService)
func NewTestSQLiteDB(secretsService secrets.Service) (TestDatabase, error) {
db, err := Init(DBTypeSQLite, ":memory:", secretsService)
if err != nil {
return nil, err
}
return &testDatabase{db.(*database)}, nil
return &testSQLiteDatabase{db.(*database)}, nil
}
type testDatabase struct {
type testSQLiteDatabase struct {
*database
}
func (td *testDatabase) TestDB() *sql.DB {
func (td *testSQLiteDatabase) TestDB() *sql.DB {
return td.DB
}
// NewPostgresTestDB creates a test database using PostgreSQL
func NewPostgresTestDB(dbURL string, secretsSvc secrets.Service) (TestDatabase, error) {
if dbURL == "" {
return nil, fmt.Errorf("postgres URL cannot be empty")
}
initialDB, err := sql.Open("postgres", dbURL)
if err != nil {
return nil, fmt.Errorf("failed to open postgres database: %w", err)
}
if err := initialDB.Ping(); err != nil {
initialDB.Close()
return nil, fmt.Errorf("failed to ping postgres database: %w", err)
}
// Create a unique schema name for this test run to avoid conflicts
schemaName := fmt.Sprintf("lemma_test_%d", time.Now().UnixNano())
_, err = initialDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schemaName))
if err != nil {
initialDB.Close()
return nil, fmt.Errorf("failed to create schema: %w", err)
}
// Close the initial connection and create a new one with the schema set
initialDB.Close()
var newDBURL string
if strings.Contains(dbURL, "?") {
// URL already has parameters
newDBURL = fmt.Sprintf("%s&search_path=%s", dbURL, schemaName)
} else {
// URL has no parameters yet
newDBURL = fmt.Sprintf("%s?search_path=%s", dbURL, schemaName)
}
db, err := sql.Open("postgres", newDBURL)
if err != nil {
return nil, fmt.Errorf("failed to open postgres database: %w", err)
}
if err := db.Ping(); err != nil {
db.Close()
return nil, fmt.Errorf("failed to ping postgres database: %w", err)
}
// Set search path to use our schema
_, err = db.Exec(fmt.Sprintf("SET search_path TO %s", schemaName))
if err != nil {
db.Close()
return nil, fmt.Errorf("failed to set search path: %w", err)
}
// Create database instance
database := &postgresTestDatabase{
database: &database{DB: db, secretsService: secretsSvc, dbType: DBTypePostgres},
schemaName: schemaName,
}
return database, nil
}
// postgresTestDatabase extends the regular postgres database to add test-specific cleanup
type postgresTestDatabase struct {
*database
schemaName string
}
// Close closes the database connection and drops the test schema
func (db *postgresTestDatabase) Close() error {
_, err := db.TestDB().Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", db.schemaName))
if err != nil {
log.Printf("Failed to drop schema %s: %v", db.schemaName, err)
}
return db.TestDB().Close()
}
// TestDB returns the underlying *sql.DB instance
func (db *postgresTestDatabase) TestDB() *sql.DB {
return db.DB
}

View File

@@ -17,26 +17,21 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) {
}
defer tx.Rollback()
result, err := tx.Exec(`
INSERT INTO users (email, display_name, password_hash, role)
VALUES (?, ?, ?, ?)`,
user.Email, user.DisplayName, user.PasswordHash, user.Role)
query, err := db.NewQuery().
InsertStruct(user, "users")
if err != nil {
return nil, fmt.Errorf("failed to create query: %w", err)
}
query.Returning("id", "created_at")
err = tx.QueryRow(query.String(), query.Args()...).
Scan(&user.ID, &user.CreatedAt)
if err != nil {
return nil, fmt.Errorf("failed to insert user: %w", err)
}
userID, err := result.LastInsertId()
if err != nil {
return nil, fmt.Errorf("failed to get last insert ID: %w", err)
}
user.ID = int(userID)
// Retrieve the created_at timestamp
err = tx.QueryRow("SELECT created_at FROM users WHERE id = ?", user.ID).Scan(&user.CreatedAt)
if err != nil {
return nil, fmt.Errorf("failed to get created timestamp: %w", err)
}
// Create default workspace with default settings
defaultWorkspace := &models.Workspace{
UserID: user.ID,
@@ -51,7 +46,13 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) {
}
// Update user's last workspace ID
_, err = tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?", defaultWorkspace.ID, user.ID)
query = db.NewQuery().
Update("users").
Set("last_workspace_id").
Placeholder(defaultWorkspace.ID).
Where("id = ").
Placeholder(user.ID)
_, err = tx.Exec(query.String(), query.Args()...)
if err != nil {
return nil, fmt.Errorf("failed to update last workspace ID: %w", err)
}
@@ -70,47 +71,39 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) {
// Helper function to create a workspace in a transaction
func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error {
log := getLogger().WithGroup("users")
result, err := tx.Exec(`
INSERT INTO workspaces (
user_id, name,
theme, auto_save, show_hidden_files,
git_enabled, git_url, git_user, git_token,
git_auto_commit, git_commit_msg_template,
git_commit_name, git_commit_email
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
workspace.UserID, workspace.Name,
workspace.Theme, workspace.AutoSave, workspace.ShowHiddenFiles,
workspace.GitEnabled, workspace.GitURL, workspace.GitUser, workspace.GitToken,
workspace.GitAutoCommit, workspace.GitCommitMsgTemplate,
workspace.GitCommitName, workspace.GitCommitEmail,
)
insertQuery, err := db.NewQuery().
InsertStruct(workspace, "workspaces")
if err != nil {
return fmt.Errorf("failed to create query: %w", err)
}
insertQuery.Returning("id")
err = tx.QueryRow(insertQuery.String(), insertQuery.Args()...).Scan(&workspace.ID)
if err != nil {
return fmt.Errorf("failed to insert workspace: %w", err)
}
id, err := result.LastInsertId()
if err != nil {
return fmt.Errorf("failed to get workspace ID: %w", err)
}
workspace.ID = int(id)
log.Debug("created user workspace",
"workspace_id", workspace.ID,
"user_id", workspace.UserID)
return nil
}
// GetUserByID retrieves a user by its ID
func (db *database) GetUserByID(id int) (*models.User, error) {
user := &models.User{}
err := db.QueryRow(`
SELECT
id, email, display_name, password_hash, role, created_at,
last_workspace_id
FROM users
WHERE id = ?`, id).
Scan(&user.ID, &user.Email, &user.DisplayName, &user.PasswordHash,
&user.Role, &user.CreatedAt, &user.LastWorkspaceID)
query := db.NewQuery()
query, err := query.SelectStruct(user, "users")
if err != nil {
return nil, fmt.Errorf("failed to create query: %w", err)
}
query = query.Where("id = ").Placeholder(id)
row := db.QueryRow(query.String(), query.Args()...)
err = db.ScanStruct(row, user)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("user not found")
}
@@ -120,16 +113,18 @@ func (db *database) GetUserByID(id int) (*models.User, error) {
return user, nil
}
// GetUserByEmail retrieves a user by its email
func (db *database) GetUserByEmail(email string) (*models.User, error) {
user := &models.User{}
err := db.QueryRow(`
SELECT
id, email, display_name, password_hash, role, created_at,
last_workspace_id
FROM users
WHERE email = ?`, email).
Scan(&user.ID, &user.Email, &user.DisplayName, &user.PasswordHash,
&user.Role, &user.CreatedAt, &user.LastWorkspaceID)
query := db.NewQuery()
query, err := query.SelectStruct(user, "users")
if err != nil {
return nil, fmt.Errorf("failed to create query: %w", err)
}
query = query.Where("email = ").Placeholder(email)
row := db.QueryRow(query.String(), query.Args()...)
err = db.ScanStruct(row, user)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("user not found")
@@ -141,14 +136,16 @@ func (db *database) GetUserByEmail(email string) (*models.User, error) {
return user, nil
}
// UpdateUser updates an existing user record in the database
func (db *database) UpdateUser(user *models.User) error {
result, err := db.Exec(`
UPDATE users
SET email = ?, display_name = ?, password_hash = ?, role = ?, last_workspace_id = ?
WHERE id = ?`,
user.Email, user.DisplayName, user.PasswordHash, user.Role,
user.LastWorkspaceID, user.ID)
query := db.NewQuery()
query, err := query.UpdateStruct(user, "users")
if err != nil {
return fmt.Errorf("failed to create query: %w", err)
}
query = query.Where("id = ").Placeholder(user.ID)
result, err := db.Exec(query.String(), query.Args()...)
if err != nil {
return fmt.Errorf("failed to update user: %w", err)
}
@@ -165,29 +162,25 @@ func (db *database) UpdateUser(user *models.User) error {
return nil
}
// GetAllUsers retrieves all users from the database
func (db *database) GetAllUsers() ([]*models.User, error) {
rows, err := db.Query(`
SELECT
id, email, display_name, role, created_at,
last_workspace_id
FROM users
ORDER BY id ASC`)
query := db.NewQuery()
query, err := query.SelectStruct(&models.User{}, "users")
if err != nil {
return nil, fmt.Errorf("failed to create query: %w", err)
}
query = query.OrderBy("id ASC")
rows, err := db.Query(query.String(), query.Args()...)
if err != nil {
return nil, fmt.Errorf("failed to query users: %w", err)
}
defer rows.Close()
var users []*models.User
for rows.Next() {
user := &models.User{}
err := rows.Scan(
&user.ID, &user.Email, &user.DisplayName, &user.Role,
&user.CreatedAt, &user.LastWorkspaceID,
)
users := []*models.User{}
err = db.ScanStructs(rows, &users)
if err != nil {
return nil, fmt.Errorf("failed to scan user row: %w", err)
}
users = append(users, user)
return nil, fmt.Errorf("failed to scan users: %w", err)
}
return users, nil
@@ -200,15 +193,26 @@ func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error
}
defer tx.Rollback()
// Find workspace ID from name
workspaceQuery := db.NewQuery().
Select("id").
From("workspaces").
Where("user_id = ").Placeholder(userID).
And("name = ").Placeholder(workspaceName)
var workspaceID int
err = tx.QueryRow("SELECT id FROM workspaces WHERE user_id = ? AND name = ?",
userID, workspaceName).Scan(&workspaceID)
err = tx.QueryRow(workspaceQuery.String(), workspaceQuery.Args()...).Scan(&workspaceID)
if err != nil {
return fmt.Errorf("failed to find workspace: %w", err)
}
_, err = tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?",
workspaceID, userID)
// Update user's last workspace
updateQuery := db.NewQuery().
Update("users").
Set("last_workspace_id").Placeholder(workspaceID).
Where("id = ").Placeholder(userID)
_, err = tx.Exec(updateQuery.String(), updateQuery.Args()...)
if err != nil {
return fmt.Errorf("failed to update last workspace: %w", err)
}
@@ -221,6 +225,7 @@ func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error
return nil
}
// DeleteUser deletes a user and all their workspaces
func (db *database) DeleteUser(id int) error {
log := getLogger().WithGroup("users")
log.Debug("deleting user", "user_id", id)
@@ -233,13 +238,24 @@ func (db *database) DeleteUser(id int) error {
// Delete all user's workspaces first
log.Debug("deleting user workspaces", "user_id", id)
_, err = tx.Exec("DELETE FROM workspaces WHERE user_id = ?", id)
deleteWorkspacesQuery := db.NewQuery().
Delete().
From("workspaces").
Where("user_id = ").Placeholder(id)
_, err = tx.Exec(deleteWorkspacesQuery.String(), deleteWorkspacesQuery.Args()...)
if err != nil {
return fmt.Errorf("failed to delete workspaces: %w", err)
}
// Delete the user
_, err = tx.Exec("DELETE FROM users WHERE id = ?", id)
deleteUserQuery := db.NewQuery().
Delete().
From("users").
Where("id = ").Placeholder(id)
_, err = tx.Exec(deleteUserQuery.String(), deleteUserQuery.Args()...)
if err != nil {
return fmt.Errorf("failed to delete user: %w", err)
}
@@ -253,15 +269,16 @@ func (db *database) DeleteUser(id int) error {
return nil
}
// GetLastWorkspaceName retrieves the name of the last workspace accessed by a user
func (db *database) GetLastWorkspaceName(userID int) (string, error) {
query := db.NewQuery().
Select("w.name").
From("workspaces w").
Join(InnerJoin, "users u", "u.last_workspace_id = w.id").
Where("u.id = ").Placeholder(userID)
var workspaceName string
err := db.QueryRow(`
SELECT
w.name
FROM workspaces w
JOIN users u ON u.last_workspace_id = w.id
WHERE u.id = ?`, userID).
Scan(&workspaceName)
err := db.QueryRow(query.String(), query.Args()...).Scan(&workspaceName)
if err == sql.ErrNoRows {
return "", fmt.Errorf("no last workspace found")
@@ -275,8 +292,13 @@ func (db *database) GetLastWorkspaceName(userID int) (string, error) {
// CountAdminUsers returns the number of admin users in the system
func (db *database) CountAdminUsers() (int, error) {
query := db.NewQuery().
Select("COUNT(*)").
From("users").
Where("role = ").Placeholder(models.RoleAdmin)
var count int
err := db.QueryRow("SELECT COUNT(*) FROM users WHERE role = 'admin'").Scan(&count)
err := db.QueryRow(query.String(), query.Args()...).Scan(&count)
if err != nil {
return 0, fmt.Errorf("failed to count admin users: %w", err)
}

View File

@@ -10,7 +10,7 @@ import (
)
func TestUserOperations(t *testing.T) {
database, err := db.NewTestDB(":memory:", &mockSecrets{})
database, err := db.NewTestSQLiteDB(&mockSecrets{})
if err != nil {
t.Fatalf("failed to create test database: %v", err)
}

View File

@@ -19,58 +19,36 @@ func (db *database) CreateWorkspace(workspace *models.Workspace) error {
workspace.SetDefaultSettings()
}
// Encrypt token if present
encryptedToken, err := db.encryptToken(workspace.GitToken)
query, err := db.NewQuery().
InsertStruct(workspace, "workspaces")
if err != nil {
return fmt.Errorf("failed to encrypt token: %w", err)
return fmt.Errorf("failed to create query: %w", err)
}
result, err := db.Exec(`
INSERT INTO workspaces (
user_id, name, theme, auto_save, show_hidden_files,
git_enabled, git_url, git_user, git_token,
git_auto_commit, git_commit_msg_template,
git_commit_name, git_commit_email
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
workspace.UserID, workspace.Name, workspace.Theme, workspace.AutoSave, workspace.ShowHiddenFiles,
workspace.GitEnabled, workspace.GitURL, workspace.GitUser, encryptedToken,
workspace.GitAutoCommit, workspace.GitCommitMsgTemplate, workspace.GitCommitName, workspace.GitCommitEmail,
)
query.Returning("id", "created_at")
err = db.QueryRow(query.String(), query.Args()...).
Scan(&workspace.ID, &workspace.CreatedAt)
if err != nil {
return fmt.Errorf("failed to insert workspace: %w", err)
}
id, err := result.LastInsertId()
if err != nil {
return fmt.Errorf("failed to get workspace ID: %w", err)
}
workspace.ID = int(id)
return nil
}
// GetWorkspaceByID retrieves a workspace by its ID
func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) {
workspace := &models.Workspace{}
var encryptedToken string
query := db.NewQuery()
query, err := query.SelectStruct(workspace, "workspaces")
if err != nil {
return nil, fmt.Errorf("failed to create query: %w", err)
}
query = query.Where("id = ").Placeholder(id)
err := db.QueryRow(`
SELECT
id, user_id, name, created_at,
theme, auto_save, show_hidden_files,
git_enabled, git_url, git_user, git_token,
git_auto_commit, git_commit_msg_template,
git_commit_name, git_commit_email
FROM workspaces
WHERE id = ?`,
id,
).Scan(
&workspace.ID, &workspace.UserID, &workspace.Name, &workspace.CreatedAt,
&workspace.Theme, &workspace.AutoSave, &workspace.ShowHiddenFiles,
&workspace.GitEnabled, &workspace.GitURL, &workspace.GitUser, &encryptedToken,
&workspace.GitAutoCommit, &workspace.GitCommitMsgTemplate,
&workspace.GitCommitName, &workspace.GitCommitEmail,
)
row := db.QueryRow(query.String(), query.Args()...)
err = db.ScanStruct(row, workspace)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("workspace not found")
@@ -79,37 +57,22 @@ func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) {
return nil, fmt.Errorf("failed to fetch workspace: %w", err)
}
// Decrypt token
workspace.GitToken, err = db.decryptToken(encryptedToken)
if err != nil {
return nil, fmt.Errorf("failed to decrypt token: %w", err)
}
return workspace, nil
}
// GetWorkspaceByName retrieves a workspace by its name and user ID
func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) {
workspace := &models.Workspace{}
var encryptedToken string
query := db.NewQuery()
query, err := query.SelectStruct(workspace, "workspaces")
if err != nil {
return nil, fmt.Errorf("failed to create query: %w", err)
}
query = query.Where("user_id = ").Placeholder(userID).
And("name = ").Placeholder(workspaceName)
err := db.QueryRow(`
SELECT
id, user_id, name, created_at,
theme, auto_save, show_hidden_files,
git_enabled, git_url, git_user, git_token,
git_auto_commit, git_commit_msg_template,
git_commit_name, git_commit_email
FROM workspaces
WHERE user_id = ? AND name = ?`,
userID, workspaceName,
).Scan(
&workspace.ID, &workspace.UserID, &workspace.Name, &workspace.CreatedAt,
&workspace.Theme, &workspace.AutoSave, &workspace.ShowHiddenFiles,
&workspace.GitEnabled, &workspace.GitURL, &workspace.GitUser, &encryptedToken,
&workspace.GitAutoCommit, &workspace.GitCommitMsgTemplate,
&workspace.GitCommitName, &workspace.GitCommitEmail,
)
row := db.QueryRow(query.String(), query.Args()...)
err = db.ScanStruct(row, workspace)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("workspace not found")
@@ -118,54 +81,22 @@ func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*model
return nil, fmt.Errorf("failed to fetch workspace: %w", err)
}
// Decrypt token
workspace.GitToken, err = db.decryptToken(encryptedToken)
if err != nil {
return nil, fmt.Errorf("failed to decrypt token: %w", err)
}
return workspace, nil
}
// UpdateWorkspace updates a workspace record in the database
func (db *database) UpdateWorkspace(workspace *models.Workspace) error {
// Encrypt token before storing
encryptedToken, err := db.encryptToken(workspace.GitToken)
query := db.NewQuery()
query, err := query.
UpdateStruct(workspace, "workspaces")
query = query.Where("id =").Placeholder(workspace.ID).And("user_id =").Placeholder(workspace.UserID)
if err != nil {
return fmt.Errorf("failed to encrypt token: %w", err)
return fmt.Errorf("failed to create query: %w", err)
}
_, err = db.Exec(`
UPDATE workspaces
SET
name = ?,
theme = ?,
auto_save = ?,
show_hidden_files = ?,
git_enabled = ?,
git_url = ?,
git_user = ?,
git_token = ?,
git_auto_commit = ?,
git_commit_msg_template = ?,
git_commit_name = ?,
git_commit_email = ?
WHERE id = ? AND user_id = ?`,
workspace.Name,
workspace.Theme,
workspace.AutoSave,
workspace.ShowHiddenFiles,
workspace.GitEnabled,
workspace.GitURL,
workspace.GitUser,
encryptedToken,
workspace.GitAutoCommit,
workspace.GitCommitMsgTemplate,
workspace.GitCommitName,
workspace.GitCommitEmail,
workspace.ID,
workspace.UserID,
)
_, err = db.Exec(query.String(), query.Args()...)
if err != nil {
return fmt.Errorf("failed to update workspace: %w", err)
}
@@ -175,48 +106,24 @@ func (db *database) UpdateWorkspace(workspace *models.Workspace) error {
// GetWorkspacesByUserID retrieves all workspaces for a user
func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) {
rows, err := db.Query(`
SELECT
id, user_id, name, created_at,
theme, auto_save, show_hidden_files,
git_enabled, git_url, git_user, git_token,
git_auto_commit, git_commit_msg_template,
git_commit_name, git_commit_email
FROM workspaces
WHERE user_id = ?`,
userID,
)
workspace := &models.Workspace{}
query := db.NewQuery()
query, err := query.SelectStruct(workspace, "workspaces")
if err != nil {
return nil, fmt.Errorf("failed to create query: %w", err)
}
query = query.Where("user_id = ").Placeholder(userID)
rows, err := db.Query(query.String(), query.Args()...)
if err != nil {
return nil, fmt.Errorf("failed to query workspaces: %w", err)
}
defer rows.Close()
var workspaces []*models.Workspace
for rows.Next() {
workspace := &models.Workspace{}
var encryptedToken string
err := rows.Scan(
&workspace.ID, &workspace.UserID, &workspace.Name, &workspace.CreatedAt,
&workspace.Theme, &workspace.AutoSave, &workspace.ShowHiddenFiles,
&workspace.GitEnabled, &workspace.GitURL, &workspace.GitUser, &encryptedToken,
&workspace.GitAutoCommit, &workspace.GitCommitMsgTemplate,
&workspace.GitCommitName, &workspace.GitCommitEmail,
)
err = db.ScanStructs(rows, &workspaces)
if err != nil {
return nil, fmt.Errorf("failed to scan workspace row: %w", err)
}
// Decrypt token
workspace.GitToken, err = db.decryptToken(encryptedToken)
if err != nil {
return nil, fmt.Errorf("failed to decrypt token: %w", err)
}
workspaces = append(workspaces, workspace)
}
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating workspace rows: %w", err)
return nil, fmt.Errorf("failed to scan workspaces: %w", err)
}
return workspaces, nil
@@ -224,34 +131,17 @@ func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, erro
// UpdateWorkspaceSettings updates only the settings portion of a workspace
func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error {
_, err := db.Exec(`
UPDATE workspaces
SET
theme = ?,
auto_save = ?,
show_hidden_files = ?,
git_enabled = ?,
git_url = ?,
git_user = ?,
git_token = ?,
git_auto_commit = ?,
git_commit_msg_template = ?,
git_commit_name = ?,
git_commit_email = ?
WHERE id = ?`,
workspace.Theme,
workspace.AutoSave,
workspace.ShowHiddenFiles,
workspace.GitEnabled,
workspace.GitURL,
workspace.GitUser,
workspace.GitToken,
workspace.GitAutoCommit,
workspace.GitCommitMsgTemplate,
workspace.GitCommitName,
workspace.GitCommitEmail,
workspace.ID,
)
query := db.NewQuery()
query, err := query.
UpdateStruct(workspace, "workspaces")
query = query.Where("id =").Placeholder(workspace.ID)
if err != nil {
return fmt.Errorf("failed to create query: %w", err)
}
_, err = db.Exec(query.String(), query.Args()...)
if err != nil {
return fmt.Errorf("failed to update workspace settings: %w", err)
}
@@ -263,7 +153,12 @@ func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error {
func (db *database) DeleteWorkspace(id int) error {
log := getLogger().WithGroup("workspaces")
_, err := db.Exec("DELETE FROM workspaces WHERE id = ?", id)
query := db.NewQuery().
Delete().
From("workspaces").
Where("id = ").Placeholder(id)
_, err := db.Exec(query.String(), query.Args()...)
if err != nil {
return fmt.Errorf("failed to delete workspace: %w", err)
}
@@ -275,7 +170,13 @@ func (db *database) DeleteWorkspace(id int) error {
// DeleteWorkspaceTx removes a workspace record from the database within a transaction
func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
log := getLogger().WithGroup("workspaces")
result, err := tx.Exec("DELETE FROM workspaces WHERE id = ?", id)
query := db.NewQuery().
Delete().
From("workspaces").
Where("id = ").Placeholder(id)
result, err := tx.Exec(query.String(), query.Args()...)
if err != nil {
return fmt.Errorf("failed to delete workspace in transaction: %w", err)
}
@@ -285,15 +186,18 @@ func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error {
return fmt.Errorf("failed to get rows affected in transaction: %w", err)
}
log.Debug("workspace deleted",
"workspace_id", id)
log.Debug("workspace deleted", "workspace_id", id)
return nil
}
// UpdateLastWorkspaceTx sets the last workspace for a user in a transaction
func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error {
result, err := tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?",
workspaceID, userID)
query := db.NewQuery().
Update("users").
Set("last_workspace_id").Placeholder(workspaceID).
Where("id = ").Placeholder(userID)
result, err := tx.Exec(query.String(), query.Args()...)
if err != nil {
return fmt.Errorf("failed to update last workspace in transaction: %w", err)
}
@@ -308,8 +212,12 @@ func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) e
// UpdateLastOpenedFile updates the last opened file path for a workspace
func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error {
_, err := db.Exec("UPDATE workspaces SET last_opened_file_path = ? WHERE id = ?",
filePath, workspaceID)
query := db.NewQuery().
Update("workspaces").
Set("last_opened_file_path").Placeholder(filePath).
Where("id = ").Placeholder(workspaceID)
_, err := db.Exec(query.String(), query.Args()...)
if err != nil {
return fmt.Errorf("failed to update last opened file: %w", err)
}
@@ -319,9 +227,13 @@ func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error
// GetLastOpenedFile retrieves the last opened file path for a workspace
func (db *database) GetLastOpenedFile(workspaceID int) (string, error) {
query := db.NewQuery().
Select("last_opened_file_path").
From("workspaces").
Where("id = ").Placeholder(workspaceID)
var filePath sql.NullString
err := db.QueryRow("SELECT last_opened_file_path FROM workspaces WHERE id = ?",
workspaceID).Scan(&filePath)
err := db.QueryRow(query.String(), query.Args()...).Scan(&filePath)
if err == sql.ErrNoRows {
return "", fmt.Errorf("workspace not found")
@@ -339,46 +251,22 @@ func (db *database) GetLastOpenedFile(workspaceID int) (string, error) {
// GetAllWorkspaces retrieves all workspaces in the database
func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) {
rows, err := db.Query(`
SELECT
id, user_id, name, created_at,
theme, auto_save, show_hidden_files,
git_enabled, git_url, git_user, git_token,
git_auto_commit, git_commit_msg_template,
git_commit_name, git_commit_email
FROM workspaces`,
)
query := db.NewQuery()
query, err := query.SelectStruct(&models.Workspace{}, "workspaces")
if err != nil {
return nil, fmt.Errorf("failed to create query: %w", err)
}
rows, err := db.Query(query.String(), query.Args()...)
if err != nil {
return nil, fmt.Errorf("failed to query workspaces: %w", err)
}
defer rows.Close()
var workspaces []*models.Workspace
for rows.Next() {
workspace := &models.Workspace{}
var encryptedToken string
err := rows.Scan(
&workspace.ID, &workspace.UserID, &workspace.Name, &workspace.CreatedAt,
&workspace.Theme, &workspace.AutoSave, &workspace.ShowHiddenFiles,
&workspace.GitEnabled, &workspace.GitURL, &workspace.GitUser, &encryptedToken,
&workspace.GitAutoCommit, &workspace.GitCommitMsgTemplate,
&workspace.GitCommitName, &workspace.GitCommitEmail,
)
err = db.ScanStructs(rows, &workspaces)
if err != nil {
return nil, fmt.Errorf("failed to scan workspace row: %w", err)
}
// Decrypt token
workspace.GitToken, err = db.decryptToken(encryptedToken)
if err != nil {
return nil, fmt.Errorf("failed to decrypt token: %w", err)
}
workspaces = append(workspaces, workspace)
}
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating workspace rows: %w", err)
return nil, fmt.Errorf("failed to scan workspaces: %w", err)
}
return workspaces, nil

View File

@@ -10,7 +10,7 @@ import (
)
func TestWorkspaceOperations(t *testing.T) {
database, err := db.NewTestDB(":memory:", &mockSecrets{})
database, err := db.NewTestSQLiteDB(&mockSecrets{})
if err != nil {
t.Fatalf("failed to create test database: %v", err)
}

View File

@@ -308,7 +308,7 @@ func (h *Handler) AdminUpdateUser() http.HandlerFunc {
}
// Track what's being updated for logging
updates := make(map[string]interface{})
updates := make(map[string]any)
if req.Email != "" {
user.Email = req.Email

View File

@@ -15,21 +15,12 @@ import (
"github.com/stretchr/testify/require"
)
// Helper function to check if a user exists in a slice of users
func containsUser(users []*models.User, searchUser *models.User) bool {
for _, u := range users {
if u.ID == searchUser.ID &&
u.Email == searchUser.Email &&
u.DisplayName == searchUser.DisplayName &&
u.Role == searchUser.Role {
return true
}
}
return false
func TestAdminHandlers_Integration(t *testing.T) {
runWithDatabases(t, testAdminHandlers)
}
func TestAdminHandlers_Integration(t *testing.T) {
h := setupTestHarness(t)
func testAdminHandlers(t *testing.T, dbConfig DatabaseConfig) {
h := setupTestHarness(t, dbConfig)
defer h.teardown(t)
t.Run("user management", func(t *testing.T) {
@@ -241,3 +232,16 @@ func TestAdminHandlers_Integration(t *testing.T) {
assert.Equal(t, http.StatusForbidden, rr.Code)
})
}
// Helper function to check if a user exists in a slice of users
func containsUser(users []*models.User, searchUser *models.User) bool {
for _, u := range users {
if u.ID == searchUser.ID &&
u.Email == searchUser.Email &&
u.DisplayName == searchUser.DisplayName &&
u.Role == searchUser.Role {
return true
}
}
return false
}

View File

@@ -19,7 +19,11 @@ import (
)
func TestAuthHandlers_Integration(t *testing.T) {
h := setupTestHarness(t)
runWithDatabases(t, testAuthHandlers)
}
func testAuthHandlers(t *testing.T, dbConfig DatabaseConfig) {
h := setupTestHarness(t, dbConfig)
defer h.teardown(t)
t.Run("login", func(t *testing.T) {

View File

@@ -18,7 +18,11 @@ import (
)
func TestFileHandlers_Integration(t *testing.T) {
h := setupTestHarness(t)
runWithDatabases(t, testFileHandlers)
}
func testFileHandlers(t *testing.T, dbConfig DatabaseConfig) {
h := setupTestHarness(t, dbConfig)
defer h.teardown(t)
t.Run("file operations", func(t *testing.T) {
@@ -192,7 +196,7 @@ func TestFileHandlers_Integration(t *testing.T) {
name string
method string
path string
body interface{}
body any
}{
{"list files", http.MethodGet, baseURL, nil},
{"get file", http.MethodGet, baseURL + "/test.md", nil},

View File

@@ -16,7 +16,11 @@ import (
)
func TestGitHandlers_Integration(t *testing.T) {
h := setupTestHarness(t)
runWithDatabases(t, testGitHandlers)
}
func testGitHandlers(t *testing.T, dbConfig DatabaseConfig) {
h := setupTestHarness(t, dbConfig)
defer h.teardown(t)
t.Run("git operations", func(t *testing.T) {
@@ -123,7 +127,7 @@ func TestGitHandlers_Integration(t *testing.T) {
name string
method string
path string
body interface{}
body any
}{
{
name: "commit without token",

View File

@@ -37,7 +37,7 @@ func NewHandler(db db.Database, s storage.Manager) *Handler {
}
// respondJSON is a helper to send JSON responses
func respondJSON(w http.ResponseWriter, data interface{}) {
func respondJSON(w http.ResponseWriter, data any) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(data); err != nil {
respondError(w, "Failed to encode response", http.StatusInternalServerError)

View File

@@ -45,8 +45,13 @@ type testUser struct {
session *models.Session
}
type DatabaseConfig struct {
Type db.DBType
URL string
}
// setupTestHarness creates a new test environment
func setupTestHarness(t *testing.T) *testHarness {
func setupTestHarness(t *testing.T, dbConfig DatabaseConfig) *testHarness {
t.Helper()
// Create temporary directory for test files
@@ -61,10 +66,21 @@ func setupTestHarness(t *testing.T) *testHarness {
t.Fatalf("Failed to initialize secrets service: %v", err)
}
database, err := db.NewTestDB(":memory:", secretsSvc)
var database db.TestDatabase
switch dbConfig.Type {
case db.DBTypeSQLite:
database, err = db.NewTestSQLiteDB(secretsSvc)
if err != nil {
t.Fatalf("Failed to initialize test database: %v", err)
}
case db.DBTypePostgres:
database, err = db.NewPostgresTestDB(dbConfig.URL, secretsSvc)
if err != nil {
t.Fatalf("Failed to initialize test database: %v", err)
}
default:
t.Fatalf("Unsupported database type: %s", dbConfig.Type)
}
if err := database.Migrate(); err != nil {
t.Fatalf("Failed to run migrations: %v", err)
@@ -99,7 +115,7 @@ func setupTestHarness(t *testing.T) *testHarness {
// Create test config
testConfig := &app.Config{
DBPath: ":memory:",
DBURL: "sqlite://:memory:",
WorkDir: tempDir,
StaticPath: "../testdata",
Port: "8081",
@@ -156,6 +172,32 @@ func (h *testHarness) teardown(t *testing.T) {
}
}
// runWithDatabases runs a test function with both SQLite and PostgreSQL databases
func runWithDatabases(t *testing.T, testFn func(*testing.T, DatabaseConfig)) {
// Get PostgreSQL connection URL from environment variable
postgresURL := os.Getenv("LEMMA_TEST_POSTGRES_URL")
// Always run with SQLite in-memory
t.Run("SQLite", func(t *testing.T) {
testFn(t, DatabaseConfig{
Type: db.DBTypeSQLite,
URL: "sqlite://:memory:",
})
})
// Run with PostgreSQL if connection URL is provided
if postgresURL != "" {
t.Run("PostgreSQL", func(t *testing.T) {
testFn(t, DatabaseConfig{
Type: db.DBTypePostgres,
URL: postgresURL,
})
})
} else {
t.Log("Skipping PostgreSQL tests, LEMMA_TEST_POSTGRES_URL environment variable not set")
}
}
// createTestUser creates a test user and returns the user and access token
func (h *testHarness) createTestUser(t *testing.T, email, password string, role models.UserRole) *testUser {
t.Helper()
@@ -195,7 +237,7 @@ func (h *testHarness) createTestUser(t *testing.T, email, password string, role
}
}
func (h *testHarness) newRequest(t *testing.T, method, path string, body interface{}) *http.Request {
func (h *testHarness) newRequest(t *testing.T, method, path string, body any) *http.Request {
t.Helper()
var reqBody []byte
@@ -246,7 +288,7 @@ func (h *testHarness) addCSRFCookie(t *testing.T, req *http.Request) string {
}
// makeRequest is the main helper for making JSON requests
func (h *testHarness) makeRequest(t *testing.T, method, path string, body interface{}, testUser *testUser) *httptest.ResponseRecorder {
func (h *testHarness) makeRequest(t *testing.T, method, path string, body any, testUser *testUser) *httptest.ResponseRecorder {
t.Helper()
req := h.newRequest(t, method, path, body)

View File

@@ -15,7 +15,11 @@ import (
)
func TestUserHandlers_Integration(t *testing.T) {
h := setupTestHarness(t)
runWithDatabases(t, testUserHandlers)
}
func testUserHandlers(t *testing.T, dbConfig DatabaseConfig) {
h := setupTestHarness(t, dbConfig)
defer h.teardown(t)
currentEmail := h.RegularTestUser.userModel.Email

View File

@@ -15,7 +15,11 @@ import (
)
func TestWorkspaceHandlers_Integration(t *testing.T) {
h := setupTestHarness(t)
runWithDatabases(t, testWorkspaceHandlers)
}
func testWorkspaceHandlers(t *testing.T, dbConfig DatabaseConfig) {
h := setupTestHarness(t, dbConfig)
defer h.teardown(t)
t.Run("list workspaces", func(t *testing.T) {

View File

@@ -5,9 +5,9 @@ import "time"
// Session represents a user session in the database
type Session struct {
ID string // Unique session identifier
UserID int // ID of the user this session belongs to
RefreshToken string // The refresh token associated with this session
ExpiresAt time.Time // When this session expires
CreatedAt time.Time // When this session was created
ID string `db:"id"` // Unique session identifier
UserID int `db:"user_id"` // ID of the user this session belongs to
RefreshToken string `db:"refresh_token"` // The refresh token associated with this session
ExpiresAt time.Time `db:"expires_at"` // When this session expires
CreatedAt time.Time `db:"created_at,default"` // When this session was created
}

View File

@@ -20,13 +20,13 @@ const (
// User represents a user in the system
type User struct {
ID int `json:"id" validate:"required,min=1"`
Email string `json:"email" validate:"required,email"`
DisplayName string `json:"displayName"`
PasswordHash string `json:"-"`
Role UserRole `json:"role" validate:"required,oneof=admin editor viewer"`
CreatedAt time.Time `json:"createdAt"`
LastWorkspaceID int `json:"lastWorkspaceId"`
ID int `json:"id" db:"id,default" validate:"required,min=1"`
Email string `json:"email" db:"email" validate:"required,email"`
DisplayName string `json:"displayName" db:"display_name"`
PasswordHash string `json:"-" db:"password_hash"`
Role UserRole `json:"role" db:"role" validate:"required,oneof=admin editor viewer"`
CreatedAt time.Time `json:"createdAt" db:"created_at,default"`
LastWorkspaceID int `json:"lastWorkspaceId" db:"last_workspace_id"`
}
// Validate validates the user struct

View File

@@ -6,24 +6,24 @@ import (
// Workspace represents a user's workspace in the system
type Workspace struct {
ID int `json:"id" validate:"required,min=1"`
UserID int `json:"userId" validate:"required,min=1"`
Name string `json:"name" validate:"required"`
CreatedAt time.Time `json:"createdAt"`
LastOpenedFilePath string `json:"lastOpenedFilePath"`
ID int `json:"id" db:"id,default" validate:"required,min=1"`
UserID int `json:"userId" db:"user_id" validate:"required,min=1"`
Name string `json:"name" db:"name" validate:"required"`
CreatedAt time.Time `json:"createdAt" db:"created_at,default"`
LastOpenedFilePath string `json:"lastOpenedFilePath" db:"last_opened_file_path"`
// Integrated settings
Theme string `json:"theme" validate:"oneof=light dark"`
AutoSave bool `json:"autoSave"`
ShowHiddenFiles bool `json:"showHiddenFiles"`
GitEnabled bool `json:"gitEnabled"`
GitURL string `json:"gitUrl" validate:"required_if=GitEnabled true"`
GitUser string `json:"gitUser" validate:"required_if=GitEnabled true"`
GitToken string `json:"gitToken" validate:"required_if=GitEnabled true"`
GitAutoCommit bool `json:"gitAutoCommit"`
GitCommitMsgTemplate string `json:"gitCommitMsgTemplate"`
GitCommitName string `json:"gitCommitName"`
GitCommitEmail string `json:"gitCommitEmail" validate:"omitempty,required_if=GitEnabled true,email"`
Theme string `json:"theme" db:"theme" validate:"oneof=light dark"`
AutoSave bool `json:"autoSave" db:"auto_save"`
ShowHiddenFiles bool `json:"showHiddenFiles" db:"show_hidden_files"`
GitEnabled bool `json:"gitEnabled" db:"git_enabled"`
GitURL string `json:"gitUrl" db:"git_url,ommitempty" validate:"required_if=GitEnabled true"`
GitUser string `json:"gitUser" db:"git_user,ommitempty" validate:"required_if=GitEnabled true"`
GitToken string `json:"gitToken" db:"git_token,ommitempty,encrypted" validate:"required_if=GitEnabled true"`
GitAutoCommit bool `json:"gitAutoCommit" db:"git_auto_commit"`
GitCommitMsgTemplate string `json:"gitCommitMsgTemplate" db:"git_commit_msg_template"`
GitCommitName string `json:"gitCommitName" db:"git_commit_name"`
GitCommitEmail string `json:"gitCommitEmail" db:"git_commit_email" validate:"omitempty,required_if=GitEnabled true,email"`
}
// Validate validates the workspace struct

View File

@@ -37,7 +37,7 @@ func (m MockDirInfo) Size() int64 { return m.size }
func (m MockDirInfo) Mode() fs.FileMode { return m.mode }
func (m MockDirInfo) ModTime() time.Time { return m.modTime }
func (m MockDirInfo) IsDir() bool { return m.isDir }
func (m MockDirInfo) Sys() interface{} { return nil }
func (m MockDirInfo) Sys() any { return nil }
type mockFS struct {
// Record operations for verification

20
server/run_integration_tests.sh Executable file
View File

@@ -0,0 +1,20 @@
#!/bin/bash
set -e
COMPOSE_FILE=docker-compose.test.yaml
if ! docker compose -f $COMPOSE_FILE ps postgres | grep -q "running"; then
docker compose -f $COMPOSE_FILE up -d
until docker compose -f $COMPOSE_FILE exec postgres pg_isready -U postgres; do
sleep 1
done
echo "PostgreSQL is ready!"
fi
export LEMMA_TEST_POSTGRES_URL="postgres://postgres:postgres@localhost:5432/lemma_test?sslmode=disable"
echo "Running integration tests..."
go test -v -tags=test,integration ./...
docker compose -f $COMPOSE_FILE down