mirror of
https://github.com/lordmathis/lemma.git
synced 2025-11-05 23:44:22 +00:00
Fix tests for db type
This commit is contained in:
@@ -2,6 +2,7 @@ package app_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"lemma/internal/app"
|
"lemma/internal/app"
|
||||||
|
"lemma/internal/db"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -17,7 +18,7 @@ func TestDefaultConfig(t *testing.T) {
|
|||||||
got interface{}
|
got interface{}
|
||||||
expected interface{}
|
expected interface{}
|
||||||
}{
|
}{
|
||||||
{"DBPath", cfg.DBURL, "./lemma.db"},
|
{"DBPath", cfg.DBURL, "sqlite://lemma.db"},
|
||||||
{"WorkDir", cfg.WorkDir, "./data"},
|
{"WorkDir", cfg.WorkDir, "./data"},
|
||||||
{"StaticPath", cfg.StaticPath, "../app/dist"},
|
{"StaticPath", cfg.StaticPath, "../app/dist"},
|
||||||
{"Port", cfg.Port, "8080"},
|
{"Port", cfg.Port, "8080"},
|
||||||
@@ -47,7 +48,7 @@ func TestLoad(t *testing.T) {
|
|||||||
cleanup := func() {
|
cleanup := func() {
|
||||||
envVars := []string{
|
envVars := []string{
|
||||||
"LEMMA_ENV",
|
"LEMMA_ENV",
|
||||||
"LEMMA_DB_PATH",
|
"LEMMA_DB_URL",
|
||||||
"LEMMA_WORKDIR",
|
"LEMMA_WORKDIR",
|
||||||
"LEMMA_STATIC_PATH",
|
"LEMMA_STATIC_PATH",
|
||||||
"LEMMA_PORT",
|
"LEMMA_PORT",
|
||||||
@@ -81,8 +82,8 @@ func TestLoad(t *testing.T) {
|
|||||||
t.Fatalf("Load() error = %v", err)
|
t.Fatalf("Load() error = %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.DBURL != "./lemma.db" {
|
if cfg.DBURL != "sqlite://lemma.db" {
|
||||||
t.Errorf("default DBPath = %v, want %v", cfg.DBURL, "./lemma.db")
|
t.Errorf("default DBPath = %v, want %v", cfg.DBURL, "sqlite://lemma.db")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -93,7 +94,7 @@ func TestLoad(t *testing.T) {
|
|||||||
// Set all environment variables
|
// Set all environment variables
|
||||||
envs := map[string]string{
|
envs := map[string]string{
|
||||||
"LEMMA_ENV": "development",
|
"LEMMA_ENV": "development",
|
||||||
"LEMMA_DB_PATH": "/custom/db/path.db",
|
"LEMMA_DB_URL": "sqlite:///custom/db/path.db",
|
||||||
"LEMMA_WORKDIR": "/custom/work/dir",
|
"LEMMA_WORKDIR": "/custom/work/dir",
|
||||||
"LEMMA_STATIC_PATH": "/custom/static/path",
|
"LEMMA_STATIC_PATH": "/custom/static/path",
|
||||||
"LEMMA_PORT": "3000",
|
"LEMMA_PORT": "3000",
|
||||||
@@ -122,7 +123,8 @@ func TestLoad(t *testing.T) {
|
|||||||
expected interface{}
|
expected interface{}
|
||||||
}{
|
}{
|
||||||
{"IsDevelopment", cfg.IsDevelopment, true},
|
{"IsDevelopment", cfg.IsDevelopment, true},
|
||||||
{"DBPath", cfg.DBURL, "/custom/db/path.db"},
|
{"DBURL", cfg.DBURL, "/custom/db/path.db"},
|
||||||
|
{"DBType", cfg.DBType, db.DBTypeSQLite},
|
||||||
{"WorkDir", cfg.WorkDir, "/custom/work/dir"},
|
{"WorkDir", cfg.WorkDir, "/custom/work/dir"},
|
||||||
{"StaticPath", cfg.StaticPath, "/custom/static/path"},
|
{"StaticPath", cfg.StaticPath, "/custom/static/path"},
|
||||||
{"Port", cfg.Port, "3000"},
|
{"Port", cfg.Port, "3000"},
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ func initSecretsService(cfg *Config) (secrets.Service, error) {
|
|||||||
func initDatabase(cfg *Config, secretsService secrets.Service) (db.Database, error) {
|
func initDatabase(cfg *Config, secretsService secrets.Service) (db.Database, error) {
|
||||||
logging.Debug("initializing database", "path", cfg.DBURL)
|
logging.Debug("initializing database", "path", cfg.DBURL)
|
||||||
|
|
||||||
database, err := db.Init(cfg.DBURL, secretsService)
|
database, err := db.Init(cfg.DBType, cfg.DBURL, secretsService)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to initialize database: %w", err)
|
return nil, fmt.Errorf("failed to initialize database: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -119,10 +119,31 @@ type database struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Init initializes the database connection
|
// Init initializes the database connection
|
||||||
func Init(dbPath string, secretsService secrets.Service) (Database, error) {
|
func Init(dbType DBType, dbURL string, secretsService secrets.Service) (Database, error) {
|
||||||
log := getLogger()
|
|
||||||
|
|
||||||
db, err := sql.Open("sqlite3", dbPath)
|
switch dbType {
|
||||||
|
case DBTypeSQLite:
|
||||||
|
db, err := initSQLite(dbURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to initialize SQLite database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
database := &database{
|
||||||
|
DB: db,
|
||||||
|
secretsService: secretsService,
|
||||||
|
dbType: dbType,
|
||||||
|
}
|
||||||
|
return database, nil
|
||||||
|
case DBTypePostgres:
|
||||||
|
return nil, fmt.Errorf("postgres database not supported yet")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("unsupported database type: %s", dbType)
|
||||||
|
}
|
||||||
|
|
||||||
|
func initSQLite(dbURL string) (*sql.DB, error) {
|
||||||
|
log := getLogger()
|
||||||
|
db, err := sql.Open("sqlite3", dbURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||||
}
|
}
|
||||||
@@ -136,13 +157,7 @@ func Init(dbPath string, secretsService secrets.Service) (Database, error) {
|
|||||||
return nil, fmt.Errorf("failed to enable foreign keys: %w", err)
|
return nil, fmt.Errorf("failed to enable foreign keys: %w", err)
|
||||||
}
|
}
|
||||||
log.Debug("foreign keys enabled")
|
log.Debug("foreign keys enabled")
|
||||||
|
return db, nil
|
||||||
database := &database{
|
|
||||||
DB: db,
|
|
||||||
secretsService: secretsService,
|
|
||||||
}
|
|
||||||
|
|
||||||
return database, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the database connection
|
// Close closes the database connection
|
||||||
|
|||||||
@@ -25,9 +25,8 @@ func (db *database) Migrate() error {
|
|||||||
|
|
||||||
var m *migrate.Migrate
|
var m *migrate.Migrate
|
||||||
|
|
||||||
driverName := db.dbType
|
switch db.dbType {
|
||||||
switch driverName {
|
case DBTypePostgres:
|
||||||
case "postgres":
|
|
||||||
driver, err := postgres.WithInstance(db.DB, &postgres.Config{})
|
driver, err := postgres.WithInstance(db.DB, &postgres.Config{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create postgres driver: %w", err)
|
return fmt.Errorf("failed to create postgres driver: %w", err)
|
||||||
@@ -37,7 +36,7 @@ func (db *database) Migrate() error {
|
|||||||
return fmt.Errorf("failed to create migrate instance: %w", err)
|
return fmt.Errorf("failed to create migrate instance: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
case "sqlite3":
|
case DBTypeSQLite:
|
||||||
driver, err := sqlite3.WithInstance(db.DB, &sqlite3.Config{})
|
driver, err := sqlite3.WithInstance(db.DB, &sqlite3.Config{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create sqlite driver: %w", err)
|
return fmt.Errorf("failed to create sqlite driver: %w", err)
|
||||||
@@ -48,7 +47,7 @@ func (db *database) Migrate() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("unsupported database driver: %s", driverName)
|
return fmt.Errorf("unsupported database driver: %s", db.dbType)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.Up(); err != nil && err != migrate.ErrNoChange {
|
if err := m.Up(); err != nil && err != migrate.ErrNoChange {
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
-- 001_initial_schema.up.sql
|
-- 001_initial_schema.up.sql
|
||||||
-- Create users table
|
-- Create users table
|
||||||
CREATE TABLE IF NOT EXISTS users (
|
CREATE TABLE IF NOT EXISTS users (
|
||||||
id INTEGER PRIMARY KEY AUTO_INCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
email TEXT NOT NULL UNIQUE,
|
email TEXT NOT NULL UNIQUE,
|
||||||
display_name TEXT,
|
display_name TEXT,
|
||||||
password_hash TEXT NOT NULL,
|
password_hash TEXT NOT NULL,
|
||||||
@@ -12,7 +12,7 @@ CREATE TABLE IF NOT EXISTS users (
|
|||||||
|
|
||||||
-- Create workspaces table with integrated settings
|
-- Create workspaces table with integrated settings
|
||||||
CREATE TABLE IF NOT EXISTS workspaces (
|
CREATE TABLE IF NOT EXISTS workspaces (
|
||||||
id INTEGER PRIMARY KEY AUTO_INCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
user_id INTEGER NOT NULL,
|
user_id INTEGER NOT NULL,
|
||||||
name TEXT NOT NULL,
|
name TEXT NOT NULL,
|
||||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
|||||||
@@ -1,53 +1,36 @@
|
|||||||
package db_test
|
package db_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
|
||||||
|
|
||||||
"lemma/internal/db"
|
"lemma/internal/db"
|
||||||
|
|
||||||
_ "lemma/internal/testenv"
|
_ "lemma/internal/testenv"
|
||||||
|
"testing"
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestMigrate(t *testing.T) {
|
func TestMigrate(t *testing.T) {
|
||||||
database, err := db.NewTestDB(":memory:", &mockSecrets{})
|
database, err := db.NewTestDB(&mockSecrets{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to initialize database: %v", err)
|
t.Fatalf("failed to initialize database: %v", err)
|
||||||
}
|
}
|
||||||
defer database.Close()
|
defer database.Close()
|
||||||
|
|
||||||
t.Run("migrations are applied in order", func(t *testing.T) {
|
|
||||||
if err := database.Migrate(); err != nil {
|
|
||||||
t.Fatalf("failed to run initial migrations: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check migration version
|
|
||||||
var version int
|
|
||||||
err := database.TestDB().QueryRow("SELECT MAX(version) FROM migrations").Scan(&version)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to get migration version: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if version != 1 { // Current number of migrations in production code
|
|
||||||
t.Errorf("expected migration version 1, got %d", version)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify number of migration entries matches versions applied
|
|
||||||
var count int
|
|
||||||
err = database.TestDB().QueryRow("SELECT COUNT(*) FROM migrations").Scan(&count)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to count migrations: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if count != 1 {
|
|
||||||
t.Errorf("expected 1 migration entries, got %d", count)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("migrations create expected schema", func(t *testing.T) {
|
t.Run("migrations create expected schema", func(t *testing.T) {
|
||||||
|
// Run migrations
|
||||||
|
if err := database.Migrate(); err != nil {
|
||||||
|
t.Fatalf("failed to run migrations: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Verify tables exist
|
// Verify tables exist
|
||||||
tables := []string{"users", "workspaces", "sessions", "system_settings", "migrations"}
|
tables := []string{
|
||||||
|
"users",
|
||||||
|
"workspaces",
|
||||||
|
"sessions",
|
||||||
|
"system_settings",
|
||||||
|
// Note: golang-migrate uses its own migrations table
|
||||||
|
"schema_migrations",
|
||||||
|
}
|
||||||
|
|
||||||
for _, table := range tables {
|
for _, table := range tables {
|
||||||
if !tableExists(t, database, table) {
|
if !tableExists(t, database, table) {
|
||||||
t.Errorf("table %q does not exist", table)
|
t.Errorf("table %q does not exist", table)
|
||||||
@@ -63,91 +46,32 @@ func TestMigrate(t *testing.T) {
|
|||||||
{"sessions", "idx_sessions_expires_at"},
|
{"sessions", "idx_sessions_expires_at"},
|
||||||
{"sessions", "idx_sessions_refresh_token"},
|
{"sessions", "idx_sessions_refresh_token"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, idx := range indexes {
|
for _, idx := range indexes {
|
||||||
if !indexExists(t, database, idx.table, idx.name) {
|
if !indexExists(t, database, idx.table, idx.name) {
|
||||||
t.Errorf("index %q on table %q does not exist", idx.name, idx.table)
|
t.Errorf("index %q on table %q does not exist", idx.name, idx.table)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("migrations are idempotent", func(t *testing.T) {
|
|
||||||
// Run migrations again
|
|
||||||
if err := database.Migrate(); err != nil {
|
|
||||||
t.Fatalf("failed to re-run migrations: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify migration count hasn't changed
|
|
||||||
var count int
|
|
||||||
err = database.TestDB().QueryRow("SELECT COUNT(*) FROM migrations").Scan(&count)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to count migrations: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if count != 1 {
|
|
||||||
t.Errorf("expected 1 migration entries, got %d", count)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("rollback on migration failure", func(t *testing.T) {
|
|
||||||
// Create a test table that would conflict with a failing migration
|
|
||||||
_, err := database.TestDB().Exec("CREATE TABLE test_rollback (id INTEGER PRIMARY KEY)")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to create test table: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start transaction
|
|
||||||
tx, err := database.Begin()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to start transaction: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try operations that should fail and rollback
|
|
||||||
_, err = tx.Exec(`
|
|
||||||
CREATE TABLE test_rollback (id INTEGER PRIMARY KEY);
|
|
||||||
INSERT INTO nonexistent_table VALUES (1);
|
|
||||||
`)
|
|
||||||
if err == nil {
|
|
||||||
tx.Rollback()
|
|
||||||
t.Fatal("expected migration to fail")
|
|
||||||
}
|
|
||||||
tx.Rollback()
|
|
||||||
|
|
||||||
// Verify the migration version hasn't changed
|
|
||||||
var version int
|
|
||||||
err = database.TestDB().QueryRow("SELECT MAX(version) FROM migrations").Scan(&version)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to get migration version: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if version != 1 {
|
|
||||||
t.Errorf("expected migration version to remain at 1, got %d", version)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func tableExists(t *testing.T, database db.TestDatabase, tableName string) bool {
|
func tableExists(t *testing.T, database db.TestDatabase, tableName string) bool {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
var name string
|
var name string
|
||||||
err := database.TestDB().QueryRow(`
|
err := database.TestDB().QueryRow(`
|
||||||
SELECT name FROM sqlite_master
|
SELECT name FROM sqlite_master
|
||||||
WHERE type='table' AND name=?`,
|
WHERE type='table' AND name=?`,
|
||||||
tableName,
|
tableName,
|
||||||
).Scan(&name)
|
).Scan(&name)
|
||||||
|
|
||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func indexExists(t *testing.T, database db.TestDatabase, tableName, indexName string) bool {
|
func indexExists(t *testing.T, database db.TestDatabase, tableName, indexName string) bool {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
var name string
|
var name string
|
||||||
err := database.TestDB().QueryRow(`
|
err := database.TestDB().QueryRow(`
|
||||||
SELECT name FROM sqlite_master
|
SELECT name FROM sqlite_master
|
||||||
WHERE type='index' AND tbl_name=? AND name=?`,
|
WHERE type='index' AND tbl_name=? AND name=?`,
|
||||||
tableName, indexName,
|
tableName, indexName,
|
||||||
).Scan(&name)
|
).Scan(&name)
|
||||||
|
|
||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestSessionOperations(t *testing.T) {
|
func TestSessionOperations(t *testing.T) {
|
||||||
database, err := db.NewTestDB(":memory:", &mockSecrets{})
|
database, err := db.NewTestDB(&mockSecrets{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create test database: %v", err)
|
t.Fatalf("failed to create test database: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestSystemOperations(t *testing.T) {
|
func TestSystemOperations(t *testing.T) {
|
||||||
database, err := db.NewTestDB(":memory:", &mockSecrets{})
|
database, err := db.NewTestDB(&mockSecrets{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create test database: %v", err)
|
t.Fatalf("failed to create test database: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,8 +12,8 @@ type TestDatabase interface {
|
|||||||
TestDB() *sql.DB
|
TestDB() *sql.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTestDB(dbPath string, secretsService secrets.Service) (TestDatabase, error) {
|
func NewTestDB(secretsService secrets.Service) (TestDatabase, error) {
|
||||||
db, err := Init(dbPath, secretsService)
|
db, err := Init(DBTypeSQLite, ":memory:", secretsService)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestUserOperations(t *testing.T) {
|
func TestUserOperations(t *testing.T) {
|
||||||
database, err := db.NewTestDB(":memory:", &mockSecrets{})
|
database, err := db.NewTestDB(&mockSecrets{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create test database: %v", err)
|
t.Fatalf("failed to create test database: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestWorkspaceOperations(t *testing.T) {
|
func TestWorkspaceOperations(t *testing.T) {
|
||||||
database, err := db.NewTestDB(":memory:", &mockSecrets{})
|
database, err := db.NewTestDB(&mockSecrets{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create test database: %v", err)
|
t.Fatalf("failed to create test database: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ func setupTestHarness(t *testing.T) *testHarness {
|
|||||||
t.Fatalf("Failed to initialize secrets service: %v", err)
|
t.Fatalf("Failed to initialize secrets service: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
database, err := db.NewTestDB(":memory:", secretsSvc)
|
database, err := db.NewTestDB(secretsSvc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to initialize test database: %v", err)
|
t.Fatalf("Failed to initialize test database: %v", err)
|
||||||
}
|
}
|
||||||
@@ -99,7 +99,7 @@ func setupTestHarness(t *testing.T) *testHarness {
|
|||||||
|
|
||||||
// Create test config
|
// Create test config
|
||||||
testConfig := &app.Config{
|
testConfig := &app.Config{
|
||||||
DBURL: ":memory:",
|
DBURL: "sqlite://:memory:",
|
||||||
WorkDir: tempDir,
|
WorkDir: tempDir,
|
||||||
StaticPath: "../testdata",
|
StaticPath: "../testdata",
|
||||||
Port: "8081",
|
Port: "8081",
|
||||||
|
|||||||
Reference in New Issue
Block a user