diff --git a/server/internal/db/testdb.go b/server/internal/db/testdb.go index 0cdd2b8..201d3fe 100644 --- a/server/internal/db/testdb.go +++ b/server/internal/db/testdb.go @@ -7,6 +7,7 @@ import ( "fmt" "lemma/internal/secrets" "log" + "strings" "time" ) @@ -38,7 +39,37 @@ func NewPostgresTestDB(dbURL string, secretsSvc secrets.Service) (TestDatabase, return nil, fmt.Errorf("postgres URL cannot be empty") } - db, err := sql.Open("postgres", dbURL) + initialDB, err := sql.Open("postgres", dbURL) + if err != nil { + return nil, fmt.Errorf("failed to open postgres database: %w", err) + } + + if err := initialDB.Ping(); err != nil { + initialDB.Close() + return nil, fmt.Errorf("failed to ping postgres database: %w", err) + } + + // Create a unique schema name for this test run to avoid conflicts + schemaName := fmt.Sprintf("lemma_test_%d", time.Now().UnixNano()) + _, err = initialDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schemaName)) + if err != nil { + initialDB.Close() + return nil, fmt.Errorf("failed to create schema: %w", err) + } + + // Close the initial connection and create a new one with the schema set + initialDB.Close() + + var newDBURL string + if strings.Contains(dbURL, "?") { + // URL already has parameters + newDBURL = fmt.Sprintf("%s&search_path=%s", dbURL, schemaName) + } else { + // URL has no parameters yet + newDBURL = fmt.Sprintf("%s?search_path=%s", dbURL, schemaName) + } + + db, err := sql.Open("postgres", newDBURL) if err != nil { return nil, fmt.Errorf("failed to open postgres database: %w", err) } @@ -48,14 +79,6 @@ func NewPostgresTestDB(dbURL string, secretsSvc secrets.Service) (TestDatabase, return nil, fmt.Errorf("failed to ping postgres database: %w", err) } - // Create a unique schema name for this test run to avoid conflicts - schemaName := fmt.Sprintf("lemma_test_%d", time.Now().UnixNano()) - _, err = db.Exec(fmt.Sprintf("CREATE SCHEMA %s", schemaName)) - if err != nil { - db.Close() - return nil, fmt.Errorf("failed to create schema: %w", err) - } - // Set search path to use our schema _, err = db.Exec(fmt.Sprintf("SET search_path TO %s", schemaName)) if err != nil {