diff --git a/server/internal/config/config.go b/server/internal/config/config.go index 0cfa5f7..41c21a6 100644 --- a/server/internal/config/config.go +++ b/server/internal/config/config.go @@ -3,12 +3,11 @@ package config import ( "fmt" + "novamd/internal/secrets" "os" "strconv" "strings" "time" - - "novamd/internal/crypto" ) // Config holds the configuration for the application @@ -48,7 +47,7 @@ func (c *Config) Validate() error { } // Validate encryption key - if err := crypto.ValidateKey(c.EncryptionKey); err != nil { + if err := secrets.ValidateKey(c.EncryptionKey); err != nil { return fmt.Errorf("invalid NOVAMD_ENCRYPTION_KEY: %w", err) } diff --git a/server/internal/crypto/crypto.go b/server/internal/crypto/crypto.go deleted file mode 100644 index 76cf338..0000000 --- a/server/internal/crypto/crypto.go +++ /dev/null @@ -1,114 +0,0 @@ -package crypto - -import ( - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "encoding/base64" - "fmt" - "io" -) - -var ( - ErrKeyRequired = fmt.Errorf("encryption key is required") - ErrInvalidKeySize = fmt.Errorf("encryption key must be 32 bytes (256 bits) when decoded") -) - -type Crypto struct { - key []byte -} - -// ValidateKey checks if the provided key is suitable for AES-256 -func ValidateKey(key string) error { - if key == "" { - return ErrKeyRequired - } - - // Attempt to decode base64 - keyBytes, err := base64.StdEncoding.DecodeString(key) - if err != nil { - return fmt.Errorf("invalid base64 encoding: %w", err) - } - - if len(keyBytes) != 32 { - return fmt.Errorf("%w: got %d bytes", ErrInvalidKeySize, len(keyBytes)) - } - - // Verify the key can be used for AES - _, err = aes.NewCipher(keyBytes) - if err != nil { - return fmt.Errorf("invalid encryption key: %w", err) - } - - return nil -} - -// New creates a new Crypto instance with the provided base64-encoded key -func New(key string) (*Crypto, error) { - if err := ValidateKey(key); err != nil { - return nil, err - } - - keyBytes, _ := base64.StdEncoding.DecodeString(key) - return &Crypto{key: keyBytes}, nil -} - -// Encrypt encrypts the plaintext using AES-256-GCM -func (c *Crypto) Encrypt(plaintext string) (string, error) { - if plaintext == "" { - return "", nil - } - - block, err := aes.NewCipher(c.key) - if err != nil { - return "", err - } - - gcm, err := cipher.NewGCM(block) - if err != nil { - return "", err - } - - nonce := make([]byte, gcm.NonceSize()) - if _, err := io.ReadFull(rand.Reader, nonce); err != nil { - return "", err - } - - ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil) - return base64.StdEncoding.EncodeToString(ciphertext), nil -} - -// Decrypt decrypts the ciphertext using AES-256-GCM -func (c *Crypto) Decrypt(ciphertext string) (string, error) { - if ciphertext == "" { - return "", nil - } - - data, err := base64.StdEncoding.DecodeString(ciphertext) - if err != nil { - return "", err - } - - block, err := aes.NewCipher(c.key) - if err != nil { - return "", err - } - - gcm, err := cipher.NewGCM(block) - if err != nil { - return "", err - } - - nonceSize := gcm.NonceSize() - if len(data) < nonceSize { - return "", fmt.Errorf("ciphertext too short") - } - - nonce, ciphertextBytes := data[:nonceSize], data[nonceSize:] - plaintext, err := gcm.Open(nil, nonce, ciphertextBytes, nil) - if err != nil { - return "", err - } - - return string(plaintext), nil -} diff --git a/server/internal/db/db.go b/server/internal/db/db.go index 8cd04e9..a2624ee 100644 --- a/server/internal/db/db.go +++ b/server/internal/db/db.go @@ -5,8 +5,8 @@ import ( "database/sql" "fmt" - "novamd/internal/crypto" "novamd/internal/models" + "novamd/internal/secrets" _ "github.com/mattn/go-sqlite3" // SQLite driver ) @@ -70,7 +70,7 @@ type Database interface { // database represents the database connection type database struct { *sql.DB - crypto *crypto.Crypto + secretsService secrets.Encryptor } // Init initializes the database connection @@ -85,14 +85,14 @@ func Init(dbPath string, encryptionKey string) (Database, error) { } // Initialize crypto service - cryptoService, err := crypto.New(encryptionKey) + secretsService, err := secrets.New(encryptionKey) if err != nil { return nil, fmt.Errorf("failed to initialize encryption: %w", err) } database := &database{ - DB: db, - crypto: cryptoService, + DB: db, + secretsService: secretsService, } if err := database.Migrate(); err != nil { @@ -112,12 +112,12 @@ func (db *database) encryptToken(token string) (string, error) { if token == "" { return "", nil } - return db.crypto.Encrypt(token) + return db.secretsService.Encrypt(token) } func (db *database) decryptToken(token string) (string, error) { if token == "" { return "", nil } - return db.crypto.Decrypt(token) + return db.secretsService.Decrypt(token) } diff --git a/server/internal/secrets/secrets.go b/server/internal/secrets/secrets.go new file mode 100644 index 0000000..abb81d4 --- /dev/null +++ b/server/internal/secrets/secrets.go @@ -0,0 +1,112 @@ +// Package secrets provides an Encryptor interface for encrypting and decrypting strings using AES-256-GCM. +package secrets + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "fmt" + "io" +) + +// Encryptor is an interface for encrypting and decrypting strings +type Encryptor interface { + Encrypt(plaintext string) (string, error) + Decrypt(ciphertext string) (string, error) +} + +type encryptor struct { + gcm cipher.AEAD +} + +// ValidateKey checks if the provided base64-encoded key is suitable for AES-256 +func ValidateKey(key string) error { + _, err := decodeAndValidateKey(key) + return err +} + +// decodeAndValidateKey validates and decodes the base64-encoded key +// Returns the decoded key bytes if valid +func decodeAndValidateKey(key string) ([]byte, error) { + if key == "" { + return nil, fmt.Errorf("encryption key is required") + } + + keyBytes, err := base64.StdEncoding.DecodeString(key) + if err != nil { + return nil, fmt.Errorf("invalid base64 encoding: %w", err) + } + + if len(keyBytes) != 32 { + return nil, fmt.Errorf("encryption key must be 32 bytes (256 bits): got %d bytes", len(keyBytes)) + } + + // Verify the key can be used for AES + _, err = aes.NewCipher(keyBytes) + if err != nil { + return nil, fmt.Errorf("invalid encryption key: %w", err) + } + + return keyBytes, nil +} + +// New creates a new Crypto instance with the provided base64-encoded key +func New(key string) (Encryptor, error) { + keyBytes, err := decodeAndValidateKey(key) + if err != nil { + return nil, err + } + + block, err := aes.NewCipher(keyBytes) + if err != nil { + return nil, fmt.Errorf("failed to create cipher: %w", err) + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, fmt.Errorf("failed to create GCM: %w", err) + } + + return &encryptor{gcm: gcm}, nil +} + +// Encrypt encrypts the plaintext using AES-256-GCM +func (e *encryptor) Encrypt(plaintext string) (string, error) { + if plaintext == "" { + return "", nil + } + + nonce := make([]byte, e.gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return "", err + } + + ciphertext := e.gcm.Seal(nonce, nonce, []byte(plaintext), nil) + return base64.StdEncoding.EncodeToString(ciphertext), nil +} + +// Decrypt decrypts the ciphertext using AES-256-GCM +func (e *encryptor) Decrypt(ciphertext string) (string, error) { + if ciphertext == "" { + return "", nil + } + + data, err := base64.StdEncoding.DecodeString(ciphertext) + if err != nil { + return "", fmt.Errorf("invalid base64 encoding: %w", err) + } + + nonceSize := e.gcm.NonceSize() + if len(data) < nonceSize { + return "", fmt.Errorf("invalid ciphertext: too short") + } + + nonce, ciphertextBytes := data[:nonceSize], data[nonceSize:] + plaintext, err := e.gcm.Open(nil, nonce, ciphertextBytes, nil) + if err != nil { + return "", err + } + + return string(plaintext), nil +} diff --git a/server/internal/secrets/secrets_test.go b/server/internal/secrets/secrets_test.go new file mode 100644 index 0000000..ed09d98 --- /dev/null +++ b/server/internal/secrets/secrets_test.go @@ -0,0 +1,257 @@ +package secrets_test + +import ( + "encoding/base64" + "strings" + "testing" + + "novamd/internal/secrets" +) + +func TestValidateKey(t *testing.T) { + testCases := []struct { + name string + key string + wantErr bool + errContains string + }{ + { + name: "valid 32-byte base64 key", + key: base64.StdEncoding.EncodeToString(make([]byte, 32)), + wantErr: false, + }, + { + name: "empty key", + key: "", + wantErr: true, + errContains: "encryption key is required", + }, + { + name: "invalid base64", + key: "not-base64!@#$", + wantErr: true, + errContains: "invalid base64 encoding", + }, + { + name: "wrong key size (16 bytes)", + key: base64.StdEncoding.EncodeToString(make([]byte, 16)), + wantErr: true, + errContains: "encryption key must be 32 bytes", + }, + { + name: "wrong key size (64 bytes)", + key: base64.StdEncoding.EncodeToString(make([]byte, 64)), + wantErr: true, + errContains: "encryption key must be 32 bytes", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := secrets.ValidateKey(tc.key) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + return + } + if !strings.Contains(err.Error(), tc.errContains) { + t.Errorf("error = %v, want error containing %q", err, tc.errContains) + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + } + }) + } +} + +func TestNew(t *testing.T) { + testCases := []struct { + name string + key string + wantErr bool + errContains string + }{ + { + name: "valid key", + key: base64.StdEncoding.EncodeToString(make([]byte, 32)), + wantErr: false, + }, + { + name: "empty key", + key: "", + wantErr: true, + errContains: "encryption key is required", + }, + { + name: "invalid key", + key: "invalid", + wantErr: true, + errContains: "invalid base64 encoding", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e, err := secrets.New(tc.key) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + return + } + if !strings.Contains(err.Error(), tc.errContains) { + t.Errorf("error = %v, want error containing %q", err, tc.errContains) + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if e == nil { + t.Error("expected Encryptor instance, got nil") + } + }) + } +} + +func TestEncryptDecrypt(t *testing.T) { + // Generate a valid key for testing + key := base64.StdEncoding.EncodeToString(make([]byte, 32)) + e, err := secrets.New(key) + if err != nil { + t.Fatalf("failed to create Encryptor instance: %v", err) + } + + testCases := []struct { + name string + plaintext string + wantErr bool + }{ + { + name: "normal text", + plaintext: "Hello, World!", + wantErr: false, + }, + { + name: "empty string", + plaintext: "", + wantErr: false, + }, + { + name: "long text", + plaintext: strings.Repeat("Long text with lots of content. ", 100), + wantErr: false, + }, + { + name: "special characters", + plaintext: "!@#$%^&*()_+-=[]{}|;:,.<>?", + wantErr: false, + }, + { + name: "unicode characters", + plaintext: "Hello, 世界! नमस्ते", + wantErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Test encryption + ciphertext, err := e.Encrypt(tc.plaintext) + if tc.wantErr { + if err == nil { + t.Error("expected encryption error, got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected encryption error: %v", err) + } + + // Verify ciphertext is different from plaintext + if tc.plaintext != "" && ciphertext == tc.plaintext { + t.Error("ciphertext matches plaintext") + } + + // Test decryption + decrypted, err := e.Decrypt(ciphertext) + if err != nil { + t.Fatalf("unexpected decryption error: %v", err) + } + + // Verify decrypted text matches original + if decrypted != tc.plaintext { + t.Errorf("decrypted text = %q, want %q", decrypted, tc.plaintext) + } + }) + } +} + +func TestDecryptInvalidCiphertext(t *testing.T) { + key := base64.StdEncoding.EncodeToString(make([]byte, 32)) + e, err := secrets.New(key) + if err != nil { + t.Fatalf("failed to create Encryptor instance: %v", err) + } + + testCases := []struct { + name string + ciphertext string + wantErr bool + errContains string + }{ + { + name: "empty ciphertext", + ciphertext: "", + wantErr: false, + }, + { + name: "invalid base64", + ciphertext: "not-base64!@#$", + wantErr: true, + errContains: "invalid base64 encoding", + }, + { + name: "invalid ciphertext (too short)", + ciphertext: base64.StdEncoding.EncodeToString(make([]byte, 10)), + wantErr: true, + errContains: "invalid ciphertext: too short", + }, + { + name: "tampered ciphertext", + ciphertext: base64.StdEncoding.EncodeToString(make([]byte, 50)), + wantErr: true, + errContains: "message authentication failed", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + decrypted, err := e.Decrypt(tc.ciphertext) + + if tc.wantErr { + if err == nil { + t.Error("expected error, got nil") + return + } + if !strings.Contains(err.Error(), tc.errContains) { + t.Errorf("error = %v, want error containing %q", err, tc.errContains) + } + return + } + + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if decrypted != "" { + t.Errorf("expected empty string, got %q", decrypted) + } + }) + } +}