Fix tests for db type

This commit is contained in:
2025-02-22 22:32:38 +01:00
parent d47f7d7fb0
commit 25defa5b65
12 changed files with 68 additions and 128 deletions

View File

@@ -2,6 +2,7 @@ package app_test
import ( import (
"lemma/internal/app" "lemma/internal/app"
"lemma/internal/db"
"os" "os"
"testing" "testing"
"time" "time"
@@ -17,7 +18,7 @@ func TestDefaultConfig(t *testing.T) {
got interface{} got interface{}
expected interface{} expected interface{}
}{ }{
{"DBPath", cfg.DBURL, "./lemma.db"}, {"DBPath", cfg.DBURL, "sqlite://lemma.db"},
{"WorkDir", cfg.WorkDir, "./data"}, {"WorkDir", cfg.WorkDir, "./data"},
{"StaticPath", cfg.StaticPath, "../app/dist"}, {"StaticPath", cfg.StaticPath, "../app/dist"},
{"Port", cfg.Port, "8080"}, {"Port", cfg.Port, "8080"},
@@ -47,7 +48,7 @@ func TestLoad(t *testing.T) {
cleanup := func() { cleanup := func() {
envVars := []string{ envVars := []string{
"LEMMA_ENV", "LEMMA_ENV",
"LEMMA_DB_PATH", "LEMMA_DB_URL",
"LEMMA_WORKDIR", "LEMMA_WORKDIR",
"LEMMA_STATIC_PATH", "LEMMA_STATIC_PATH",
"LEMMA_PORT", "LEMMA_PORT",
@@ -81,8 +82,8 @@ func TestLoad(t *testing.T) {
t.Fatalf("Load() error = %v", err) t.Fatalf("Load() error = %v", err)
} }
if cfg.DBURL != "./lemma.db" { if cfg.DBURL != "sqlite://lemma.db" {
t.Errorf("default DBPath = %v, want %v", cfg.DBURL, "./lemma.db") t.Errorf("default DBPath = %v, want %v", cfg.DBURL, "sqlite://lemma.db")
} }
}) })
@@ -93,7 +94,7 @@ func TestLoad(t *testing.T) {
// Set all environment variables // Set all environment variables
envs := map[string]string{ envs := map[string]string{
"LEMMA_ENV": "development", "LEMMA_ENV": "development",
"LEMMA_DB_PATH": "/custom/db/path.db", "LEMMA_DB_URL": "sqlite:///custom/db/path.db",
"LEMMA_WORKDIR": "/custom/work/dir", "LEMMA_WORKDIR": "/custom/work/dir",
"LEMMA_STATIC_PATH": "/custom/static/path", "LEMMA_STATIC_PATH": "/custom/static/path",
"LEMMA_PORT": "3000", "LEMMA_PORT": "3000",
@@ -122,7 +123,8 @@ func TestLoad(t *testing.T) {
expected interface{} expected interface{}
}{ }{
{"IsDevelopment", cfg.IsDevelopment, true}, {"IsDevelopment", cfg.IsDevelopment, true},
{"DBPath", cfg.DBURL, "/custom/db/path.db"}, {"DBURL", cfg.DBURL, "/custom/db/path.db"},
{"DBType", cfg.DBType, db.DBTypeSQLite},
{"WorkDir", cfg.WorkDir, "/custom/work/dir"}, {"WorkDir", cfg.WorkDir, "/custom/work/dir"},
{"StaticPath", cfg.StaticPath, "/custom/static/path"}, {"StaticPath", cfg.StaticPath, "/custom/static/path"},
{"Port", cfg.Port, "3000"}, {"Port", cfg.Port, "3000"},

View File

@@ -30,7 +30,7 @@ func initSecretsService(cfg *Config) (secrets.Service, error) {
func initDatabase(cfg *Config, secretsService secrets.Service) (db.Database, error) { func initDatabase(cfg *Config, secretsService secrets.Service) (db.Database, error) {
logging.Debug("initializing database", "path", cfg.DBURL) logging.Debug("initializing database", "path", cfg.DBURL)
database, err := db.Init(cfg.DBURL, secretsService) database, err := db.Init(cfg.DBType, cfg.DBURL, secretsService)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialize database: %w", err) return nil, fmt.Errorf("failed to initialize database: %w", err)
} }

View File

@@ -119,10 +119,31 @@ type database struct {
} }
// Init initializes the database connection // Init initializes the database connection
func Init(dbPath string, secretsService secrets.Service) (Database, error) { func Init(dbType DBType, dbURL string, secretsService secrets.Service) (Database, error) {
log := getLogger()
db, err := sql.Open("sqlite3", dbPath) switch dbType {
case DBTypeSQLite:
db, err := initSQLite(dbURL)
if err != nil {
return nil, fmt.Errorf("failed to initialize SQLite database: %w", err)
}
database := &database{
DB: db,
secretsService: secretsService,
dbType: dbType,
}
return database, nil
case DBTypePostgres:
return nil, fmt.Errorf("postgres database not supported yet")
}
return nil, fmt.Errorf("unsupported database type: %s", dbType)
}
func initSQLite(dbURL string) (*sql.DB, error) {
log := getLogger()
db, err := sql.Open("sqlite3", dbURL)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err) return nil, fmt.Errorf("failed to open database: %w", err)
} }
@@ -136,13 +157,7 @@ func Init(dbPath string, secretsService secrets.Service) (Database, error) {
return nil, fmt.Errorf("failed to enable foreign keys: %w", err) return nil, fmt.Errorf("failed to enable foreign keys: %w", err)
} }
log.Debug("foreign keys enabled") log.Debug("foreign keys enabled")
return db, nil
database := &database{
DB: db,
secretsService: secretsService,
}
return database, nil
} }
// Close closes the database connection // Close closes the database connection

View File

@@ -25,9 +25,8 @@ func (db *database) Migrate() error {
var m *migrate.Migrate var m *migrate.Migrate
driverName := db.dbType switch db.dbType {
switch driverName { case DBTypePostgres:
case "postgres":
driver, err := postgres.WithInstance(db.DB, &postgres.Config{}) driver, err := postgres.WithInstance(db.DB, &postgres.Config{})
if err != nil { if err != nil {
return fmt.Errorf("failed to create postgres driver: %w", err) return fmt.Errorf("failed to create postgres driver: %w", err)
@@ -37,7 +36,7 @@ func (db *database) Migrate() error {
return fmt.Errorf("failed to create migrate instance: %w", err) return fmt.Errorf("failed to create migrate instance: %w", err)
} }
case "sqlite3": case DBTypeSQLite:
driver, err := sqlite3.WithInstance(db.DB, &sqlite3.Config{}) driver, err := sqlite3.WithInstance(db.DB, &sqlite3.Config{})
if err != nil { if err != nil {
return fmt.Errorf("failed to create sqlite driver: %w", err) return fmt.Errorf("failed to create sqlite driver: %w", err)
@@ -48,7 +47,7 @@ func (db *database) Migrate() error {
} }
default: default:
return fmt.Errorf("unsupported database driver: %s", driverName) return fmt.Errorf("unsupported database driver: %s", db.dbType)
} }
if err := m.Up(); err != nil && err != migrate.ErrNoChange { if err := m.Up(); err != nil && err != migrate.ErrNoChange {

View File

@@ -1,7 +1,7 @@
-- 001_initial_schema.up.sql -- 001_initial_schema.up.sql
-- Create users table -- Create users table
CREATE TABLE IF NOT EXISTS users ( CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTO_INCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
email TEXT NOT NULL UNIQUE, email TEXT NOT NULL UNIQUE,
display_name TEXT, display_name TEXT,
password_hash TEXT NOT NULL, password_hash TEXT NOT NULL,
@@ -12,7 +12,7 @@ CREATE TABLE IF NOT EXISTS users (
-- Create workspaces table with integrated settings -- Create workspaces table with integrated settings
CREATE TABLE IF NOT EXISTS workspaces ( CREATE TABLE IF NOT EXISTS workspaces (
id INTEGER PRIMARY KEY AUTO_INCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL, user_id INTEGER NOT NULL,
name TEXT NOT NULL, name TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,

View File

@@ -1,53 +1,36 @@
package db_test package db_test
import ( import (
"testing"
"lemma/internal/db" "lemma/internal/db"
_ "lemma/internal/testenv" _ "lemma/internal/testenv"
"testing"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )
func TestMigrate(t *testing.T) { func TestMigrate(t *testing.T) {
database, err := db.NewTestDB(":memory:", &mockSecrets{}) database, err := db.NewTestDB(&mockSecrets{})
if err != nil { if err != nil {
t.Fatalf("failed to initialize database: %v", err) t.Fatalf("failed to initialize database: %v", err)
} }
defer database.Close() defer database.Close()
t.Run("migrations are applied in order", func(t *testing.T) {
if err := database.Migrate(); err != nil {
t.Fatalf("failed to run initial migrations: %v", err)
}
// Check migration version
var version int
err := database.TestDB().QueryRow("SELECT MAX(version) FROM migrations").Scan(&version)
if err != nil {
t.Fatalf("failed to get migration version: %v", err)
}
if version != 1 { // Current number of migrations in production code
t.Errorf("expected migration version 1, got %d", version)
}
// Verify number of migration entries matches versions applied
var count int
err = database.TestDB().QueryRow("SELECT COUNT(*) FROM migrations").Scan(&count)
if err != nil {
t.Fatalf("failed to count migrations: %v", err)
}
if count != 1 {
t.Errorf("expected 1 migration entries, got %d", count)
}
})
t.Run("migrations create expected schema", func(t *testing.T) { t.Run("migrations create expected schema", func(t *testing.T) {
// Run migrations
if err := database.Migrate(); err != nil {
t.Fatalf("failed to run migrations: %v", err)
}
// Verify tables exist // Verify tables exist
tables := []string{"users", "workspaces", "sessions", "system_settings", "migrations"} tables := []string{
"users",
"workspaces",
"sessions",
"system_settings",
// Note: golang-migrate uses its own migrations table
"schema_migrations",
}
for _, table := range tables { for _, table := range tables {
if !tableExists(t, database, table) { if !tableExists(t, database, table) {
t.Errorf("table %q does not exist", table) t.Errorf("table %q does not exist", table)
@@ -63,91 +46,32 @@ func TestMigrate(t *testing.T) {
{"sessions", "idx_sessions_expires_at"}, {"sessions", "idx_sessions_expires_at"},
{"sessions", "idx_sessions_refresh_token"}, {"sessions", "idx_sessions_refresh_token"},
} }
for _, idx := range indexes { for _, idx := range indexes {
if !indexExists(t, database, idx.table, idx.name) { if !indexExists(t, database, idx.table, idx.name) {
t.Errorf("index %q on table %q does not exist", idx.name, idx.table) t.Errorf("index %q on table %q does not exist", idx.name, idx.table)
} }
} }
}) })
t.Run("migrations are idempotent", func(t *testing.T) {
// Run migrations again
if err := database.Migrate(); err != nil {
t.Fatalf("failed to re-run migrations: %v", err)
}
// Verify migration count hasn't changed
var count int
err = database.TestDB().QueryRow("SELECT COUNT(*) FROM migrations").Scan(&count)
if err != nil {
t.Fatalf("failed to count migrations: %v", err)
}
if count != 1 {
t.Errorf("expected 1 migration entries, got %d", count)
}
})
t.Run("rollback on migration failure", func(t *testing.T) {
// Create a test table that would conflict with a failing migration
_, err := database.TestDB().Exec("CREATE TABLE test_rollback (id INTEGER PRIMARY KEY)")
if err != nil {
t.Fatalf("failed to create test table: %v", err)
}
// Start transaction
tx, err := database.Begin()
if err != nil {
t.Fatalf("failed to start transaction: %v", err)
}
// Try operations that should fail and rollback
_, err = tx.Exec(`
CREATE TABLE test_rollback (id INTEGER PRIMARY KEY);
INSERT INTO nonexistent_table VALUES (1);
`)
if err == nil {
tx.Rollback()
t.Fatal("expected migration to fail")
}
tx.Rollback()
// Verify the migration version hasn't changed
var version int
err = database.TestDB().QueryRow("SELECT MAX(version) FROM migrations").Scan(&version)
if err != nil {
t.Fatalf("failed to get migration version: %v", err)
}
if version != 1 {
t.Errorf("expected migration version to remain at 1, got %d", version)
}
})
} }
func tableExists(t *testing.T, database db.TestDatabase, tableName string) bool { func tableExists(t *testing.T, database db.TestDatabase, tableName string) bool {
t.Helper() t.Helper()
var name string var name string
err := database.TestDB().QueryRow(` err := database.TestDB().QueryRow(`
SELECT name FROM sqlite_master SELECT name FROM sqlite_master
WHERE type='table' AND name=?`, WHERE type='table' AND name=?`,
tableName, tableName,
).Scan(&name) ).Scan(&name)
return err == nil return err == nil
} }
func indexExists(t *testing.T, database db.TestDatabase, tableName, indexName string) bool { func indexExists(t *testing.T, database db.TestDatabase, tableName, indexName string) bool {
t.Helper() t.Helper()
var name string var name string
err := database.TestDB().QueryRow(` err := database.TestDB().QueryRow(`
SELECT name FROM sqlite_master SELECT name FROM sqlite_master
WHERE type='index' AND tbl_name=? AND name=?`, WHERE type='index' AND tbl_name=? AND name=?`,
tableName, indexName, tableName, indexName,
).Scan(&name) ).Scan(&name)
return err == nil return err == nil
} }

View File

@@ -13,7 +13,7 @@ import (
) )
func TestSessionOperations(t *testing.T) { func TestSessionOperations(t *testing.T) {
database, err := db.NewTestDB(":memory:", &mockSecrets{}) database, err := db.NewTestDB(&mockSecrets{})
if err != nil { if err != nil {
t.Fatalf("failed to create test database: %v", err) t.Fatalf("failed to create test database: %v", err)
} }

View File

@@ -15,7 +15,7 @@ import (
) )
func TestSystemOperations(t *testing.T) { func TestSystemOperations(t *testing.T) {
database, err := db.NewTestDB(":memory:", &mockSecrets{}) database, err := db.NewTestDB(&mockSecrets{})
if err != nil { if err != nil {
t.Fatalf("failed to create test database: %v", err) t.Fatalf("failed to create test database: %v", err)
} }

View File

@@ -12,8 +12,8 @@ type TestDatabase interface {
TestDB() *sql.DB TestDB() *sql.DB
} }
func NewTestDB(dbPath string, secretsService secrets.Service) (TestDatabase, error) { func NewTestDB(secretsService secrets.Service) (TestDatabase, error) {
db, err := Init(dbPath, secretsService) db, err := Init(DBTypeSQLite, ":memory:", secretsService)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -10,7 +10,7 @@ import (
) )
func TestUserOperations(t *testing.T) { func TestUserOperations(t *testing.T) {
database, err := db.NewTestDB(":memory:", &mockSecrets{}) database, err := db.NewTestDB(&mockSecrets{})
if err != nil { if err != nil {
t.Fatalf("failed to create test database: %v", err) t.Fatalf("failed to create test database: %v", err)
} }

View File

@@ -10,7 +10,7 @@ import (
) )
func TestWorkspaceOperations(t *testing.T) { func TestWorkspaceOperations(t *testing.T) {
database, err := db.NewTestDB(":memory:", &mockSecrets{}) database, err := db.NewTestDB(&mockSecrets{})
if err != nil { if err != nil {
t.Fatalf("failed to create test database: %v", err) t.Fatalf("failed to create test database: %v", err)
} }

View File

@@ -61,7 +61,7 @@ func setupTestHarness(t *testing.T) *testHarness {
t.Fatalf("Failed to initialize secrets service: %v", err) t.Fatalf("Failed to initialize secrets service: %v", err)
} }
database, err := db.NewTestDB(":memory:", secretsSvc) database, err := db.NewTestDB(secretsSvc)
if err != nil { if err != nil {
t.Fatalf("Failed to initialize test database: %v", err) t.Fatalf("Failed to initialize test database: %v", err)
} }
@@ -99,7 +99,7 @@ func setupTestHarness(t *testing.T) *testHarness {
// Create test config // Create test config
testConfig := &app.Config{ testConfig := &app.Config{
DBURL: ":memory:", DBURL: "sqlite://:memory:",
WorkDir: tempDir, WorkDir: tempDir,
StaticPath: "../testdata", StaticPath: "../testdata",
Port: "8081", Port: "8081",