From 9d81b1036d678dc738481bb9f4cd5d6406f67732 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sat, 23 Nov 2024 22:33:55 +0100 Subject: [PATCH] Refactor db init --- server/cmd/server/main.go | 19 +++++++++++++++++-- server/internal/db/db.go | 15 ++------------- server/internal/secrets/secrets.go | 8 ++++---- server/internal/secrets/secrets_test.go | 6 +++--- 4 files changed, 26 insertions(+), 22 deletions(-) diff --git a/server/cmd/server/main.go b/server/cmd/server/main.go index 7fbf006..39758cf 100644 --- a/server/cmd/server/main.go +++ b/server/cmd/server/main.go @@ -19,6 +19,7 @@ import ( "novamd/internal/config" "novamd/internal/db" "novamd/internal/handlers" + "novamd/internal/secrets" "novamd/internal/storage" ) @@ -29,12 +30,26 @@ func main() { log.Fatal("Failed to load configuration:", err) } + // Initialize secrets service + secretsService, err := secrets.NewService(cfg.EncryptionKey) + if err != nil { + log.Fatal("Failed to initialize secrets service:", err) + } + // Initialize database - database, err := db.Init(cfg.DBPath, cfg.EncryptionKey) + database, err := db.Init(cfg.DBPath, secretsService) if err != nil { log.Fatal(err) } - defer database.Close() + err = database.Migrate() + if err != nil { + log.Fatal("Failed to apply database migrations:", err) + } + defer func() { + if err := database.Close(); err != nil { + log.Printf("Error closing database: %v", err) + } + }() // Get or generate JWT signing key signingKey := cfg.JWTSigningKey diff --git a/server/internal/db/db.go b/server/internal/db/db.go index 4c45ddc..04df282 100644 --- a/server/internal/db/db.go +++ b/server/internal/db/db.go @@ -3,7 +3,6 @@ package db import ( "database/sql" - "fmt" "novamd/internal/models" "novamd/internal/secrets" @@ -95,11 +94,11 @@ var ( // database represents the database connection type database struct { *sql.DB - secretsService secrets.Encryptor + secretsService secrets.Service } // Init initializes the database connection -func Init(dbPath string, encryptionKey string) (Database, error) { +func Init(dbPath string, secretsService secrets.Service) (Database, error) { db, err := sql.Open("sqlite3", dbPath) if err != nil { return nil, err @@ -109,21 +108,11 @@ func Init(dbPath string, encryptionKey string) (Database, error) { return nil, err } - // Initialize crypto service - secretsService, err := secrets.New(encryptionKey) - if err != nil { - return nil, fmt.Errorf("failed to initialize encryption: %w", err) - } - database := &database{ DB: db, secretsService: secretsService, } - if err := database.Migrate(); err != nil { - return nil, err - } - return database, nil } diff --git a/server/internal/secrets/secrets.go b/server/internal/secrets/secrets.go index abb81d4..2eab8d1 100644 --- a/server/internal/secrets/secrets.go +++ b/server/internal/secrets/secrets.go @@ -10,8 +10,8 @@ import ( "io" ) -// Encryptor is an interface for encrypting and decrypting strings -type Encryptor interface { +// Service is an interface for encrypting and decrypting strings +type Service interface { Encrypt(plaintext string) (string, error) Decrypt(ciphertext string) (string, error) } @@ -51,8 +51,8 @@ func decodeAndValidateKey(key string) ([]byte, error) { return keyBytes, nil } -// New creates a new Crypto instance with the provided base64-encoded key -func New(key string) (Encryptor, error) { +// NewService creates a new Encryptor instance with the provided base64-encoded key +func NewService(key string) (Service, error) { keyBytes, err := decodeAndValidateKey(key) if err != nil { return nil, err diff --git a/server/internal/secrets/secrets_test.go b/server/internal/secrets/secrets_test.go index ed09d98..f1db818 100644 --- a/server/internal/secrets/secrets_test.go +++ b/server/internal/secrets/secrets_test.go @@ -96,7 +96,7 @@ func TestNew(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - e, err := secrets.New(tc.key) + e, err := secrets.NewService(tc.key) if tc.wantErr { if err == nil { @@ -122,7 +122,7 @@ func TestNew(t *testing.T) { func TestEncryptDecrypt(t *testing.T) { // Generate a valid key for testing key := base64.StdEncoding.EncodeToString(make([]byte, 32)) - e, err := secrets.New(key) + e, err := secrets.NewService(key) if err != nil { t.Fatalf("failed to create Encryptor instance: %v", err) } @@ -194,7 +194,7 @@ func TestEncryptDecrypt(t *testing.T) { func TestDecryptInvalidCiphertext(t *testing.T) { key := base64.StdEncoding.EncodeToString(make([]byte, 32)) - e, err := secrets.New(key) + e, err := secrets.NewService(key) if err != nil { t.Fatalf("failed to create Encryptor instance: %v", err) }