diff --git a/server/internal/db/migrations_test.go b/server/internal/db/migrations_test.go index 80c9661..4998db2 100644 --- a/server/internal/db/migrations_test.go +++ b/server/internal/db/migrations_test.go @@ -9,7 +9,7 @@ import ( ) func TestMigrate(t *testing.T) { - database, err := db.NewTestDB(&mockSecrets{}) + database, err := db.NewTestSQLiteDB(&mockSecrets{}) if err != nil { t.Fatalf("failed to initialize database: %v", err) } diff --git a/server/internal/db/sessions_test.go b/server/internal/db/sessions_test.go index 67b2526..208d1ae 100644 --- a/server/internal/db/sessions_test.go +++ b/server/internal/db/sessions_test.go @@ -13,7 +13,7 @@ import ( ) func TestSessionOperations(t *testing.T) { - database, err := db.NewTestDB(&mockSecrets{}) + database, err := db.NewTestSQLiteDB(&mockSecrets{}) if err != nil { t.Fatalf("failed to create test database: %v", err) } diff --git a/server/internal/db/struct_query_test.go b/server/internal/db/struct_query_test.go index 7c28799..71e2c64 100644 --- a/server/internal/db/struct_query_test.go +++ b/server/internal/db/struct_query_test.go @@ -140,7 +140,7 @@ func TestStructTagsToFields(t *testing.T) { // TestStructQueries tests the struct-based query methods using the test database func TestStructQueries(t *testing.T) { // Setup test database - database, err := db.NewTestDB(&mockSecrets{}) + database, err := db.NewTestSQLiteDB(&mockSecrets{}) if err != nil { t.Fatalf("Failed to create test database: %v", err) } @@ -356,7 +356,7 @@ func TestStructQueries(t *testing.T) { // TestScanStructsErrors tests error handling for ScanStructs func TestScanStructsErrors(t *testing.T) { - database, err := db.NewTestDB(&mockSecrets{}) + database, err := db.NewTestSQLiteDB(&mockSecrets{}) if err != nil { t.Fatalf("Failed to create test database: %v", err) } @@ -421,7 +421,7 @@ func TestScanStructsErrors(t *testing.T) { // TestEncryptedFields tests handling of encrypted fields func TestEncryptedFields(t *testing.T) { - database, err := db.NewTestDB(&mockSecrets{}) + database, err := db.NewTestSQLiteDB(&mockSecrets{}) if err != nil { t.Fatalf("Failed to create test database: %v", err) } diff --git a/server/internal/db/system_test.go b/server/internal/db/system_test.go index 8e65023..9c54dd9 100644 --- a/server/internal/db/system_test.go +++ b/server/internal/db/system_test.go @@ -15,7 +15,7 @@ import ( ) func TestSystemOperations(t *testing.T) { - database, err := db.NewTestDB(&mockSecrets{}) + database, err := db.NewTestSQLiteDB(&mockSecrets{}) if err != nil { t.Fatalf("failed to create test database: %v", err) } diff --git a/server/internal/db/testdb.go b/server/internal/db/testdb.go index 14e13e1..0cdd2b8 100644 --- a/server/internal/db/testdb.go +++ b/server/internal/db/testdb.go @@ -4,7 +4,10 @@ package db import ( "database/sql" + "fmt" "lemma/internal/secrets" + "log" + "time" ) type TestDatabase interface { @@ -12,19 +15,80 @@ type TestDatabase interface { TestDB() *sql.DB } -func NewTestDB(secretsService secrets.Service) (TestDatabase, error) { +func NewTestSQLiteDB(secretsService secrets.Service) (TestDatabase, error) { db, err := Init(DBTypeSQLite, ":memory:", secretsService) if err != nil { return nil, err } - return &testDatabase{db.(*database)}, nil + return &testSQLiteDatabase{db.(*database)}, nil } -type testDatabase struct { +type testSQLiteDatabase struct { *database } -func (td *testDatabase) TestDB() *sql.DB { +func (td *testSQLiteDatabase) TestDB() *sql.DB { return td.DB } + +// NewPostgresTestDB creates a test database using PostgreSQL +func NewPostgresTestDB(dbURL string, secretsSvc secrets.Service) (TestDatabase, error) { + if dbURL == "" { + return nil, fmt.Errorf("postgres URL cannot be empty") + } + + db, err := sql.Open("postgres", dbURL) + if err != nil { + return nil, fmt.Errorf("failed to open postgres database: %w", err) + } + + if err := db.Ping(); err != nil { + db.Close() + return nil, fmt.Errorf("failed to ping postgres database: %w", err) + } + + // Create a unique schema name for this test run to avoid conflicts + schemaName := fmt.Sprintf("lemma_test_%d", time.Now().UnixNano()) + _, err = db.Exec(fmt.Sprintf("CREATE SCHEMA %s", schemaName)) + if err != nil { + db.Close() + return nil, fmt.Errorf("failed to create schema: %w", err) + } + + // Set search path to use our schema + _, err = db.Exec(fmt.Sprintf("SET search_path TO %s", schemaName)) + if err != nil { + db.Close() + return nil, fmt.Errorf("failed to set search path: %w", err) + } + + // Create database instance + database := &postgresTestDatabase{ + database: &database{DB: db, secretsService: secretsSvc, dbType: DBTypePostgres}, + schemaName: schemaName, + } + + return database, nil +} + +// postgresTestDatabase extends the regular postgres database to add test-specific cleanup +type postgresTestDatabase struct { + *database + schemaName string +} + +// Close closes the database connection and drops the test schema +func (db *postgresTestDatabase) Close() error { + _, err := db.TestDB().Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", db.schemaName)) + if err != nil { + log.Printf("Failed to drop schema %s: %v", db.schemaName, err) + } + + return db.TestDB().Close() +} + +// TestDB returns the underlying *sql.DB instance +func (db *postgresTestDatabase) TestDB() *sql.DB { + return db.DB +} diff --git a/server/internal/db/users_test.go b/server/internal/db/users_test.go index 5709ab4..8dfe4f0 100644 --- a/server/internal/db/users_test.go +++ b/server/internal/db/users_test.go @@ -10,7 +10,7 @@ import ( ) func TestUserOperations(t *testing.T) { - database, err := db.NewTestDB(&mockSecrets{}) + database, err := db.NewTestSQLiteDB(&mockSecrets{}) if err != nil { t.Fatalf("failed to create test database: %v", err) } diff --git a/server/internal/db/workspaces_test.go b/server/internal/db/workspaces_test.go index 3d6fd9f..fda1c98 100644 --- a/server/internal/db/workspaces_test.go +++ b/server/internal/db/workspaces_test.go @@ -10,7 +10,7 @@ import ( ) func TestWorkspaceOperations(t *testing.T) { - database, err := db.NewTestDB(&mockSecrets{}) + database, err := db.NewTestSQLiteDB(&mockSecrets{}) if err != nil { t.Fatalf("failed to create test database: %v", err) } diff --git a/server/internal/handlers/integration_test.go b/server/internal/handlers/integration_test.go index 0bb8bd3..43e2f64 100644 --- a/server/internal/handlers/integration_test.go +++ b/server/internal/handlers/integration_test.go @@ -61,7 +61,7 @@ func setupTestHarness(t *testing.T) *testHarness { t.Fatalf("Failed to initialize secrets service: %v", err) } - database, err := db.NewTestDB(secretsSvc) + database, err := db.NewTestSQLiteDB(secretsSvc) if err != nil { t.Fatalf("Failed to initialize test database: %v", err) }