mirror of
https://github.com/lordmathis/lemma.git
synced 2025-11-05 15:44:21 +00:00
Test secrets package
This commit is contained in:
@@ -3,12 +3,11 @@ package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"novamd/internal/secrets"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"novamd/internal/crypto"
|
||||
)
|
||||
|
||||
// Config holds the configuration for the application
|
||||
@@ -48,7 +47,7 @@ func (c *Config) Validate() error {
|
||||
}
|
||||
|
||||
// Validate encryption key
|
||||
if err := crypto.ValidateKey(c.EncryptionKey); err != nil {
|
||||
if err := secrets.ValidateKey(c.EncryptionKey); err != nil {
|
||||
return fmt.Errorf("invalid NOVAMD_ENCRYPTION_KEY: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrKeyRequired = fmt.Errorf("encryption key is required")
|
||||
ErrInvalidKeySize = fmt.Errorf("encryption key must be 32 bytes (256 bits) when decoded")
|
||||
)
|
||||
|
||||
type Crypto struct {
|
||||
key []byte
|
||||
}
|
||||
|
||||
// ValidateKey checks if the provided key is suitable for AES-256
|
||||
func ValidateKey(key string) error {
|
||||
if key == "" {
|
||||
return ErrKeyRequired
|
||||
}
|
||||
|
||||
// Attempt to decode base64
|
||||
keyBytes, err := base64.StdEncoding.DecodeString(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid base64 encoding: %w", err)
|
||||
}
|
||||
|
||||
if len(keyBytes) != 32 {
|
||||
return fmt.Errorf("%w: got %d bytes", ErrInvalidKeySize, len(keyBytes))
|
||||
}
|
||||
|
||||
// Verify the key can be used for AES
|
||||
_, err = aes.NewCipher(keyBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid encryption key: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// New creates a new Crypto instance with the provided base64-encoded key
|
||||
func New(key string) (*Crypto, error) {
|
||||
if err := ValidateKey(key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keyBytes, _ := base64.StdEncoding.DecodeString(key)
|
||||
return &Crypto{key: keyBytes}, nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts the plaintext using AES-256-GCM
|
||||
func (c *Crypto) Encrypt(plaintext string) (string, error) {
|
||||
if plaintext == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(c.key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts the ciphertext using AES-256-GCM
|
||||
func (c *Crypto) Decrypt(ciphertext string) (string, error) {
|
||||
if ciphertext == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
data, err := base64.StdEncoding.DecodeString(ciphertext)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(c.key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(data) < nonceSize {
|
||||
return "", fmt.Errorf("ciphertext too short")
|
||||
}
|
||||
|
||||
nonce, ciphertextBytes := data[:nonceSize], data[nonceSize:]
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertextBytes, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"novamd/internal/crypto"
|
||||
"novamd/internal/models"
|
||||
"novamd/internal/secrets"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
)
|
||||
@@ -70,7 +70,7 @@ type Database interface {
|
||||
// database represents the database connection
|
||||
type database struct {
|
||||
*sql.DB
|
||||
crypto *crypto.Crypto
|
||||
secretsService secrets.Encryptor
|
||||
}
|
||||
|
||||
// Init initializes the database connection
|
||||
@@ -85,14 +85,14 @@ func Init(dbPath string, encryptionKey string) (Database, error) {
|
||||
}
|
||||
|
||||
// Initialize crypto service
|
||||
cryptoService, err := crypto.New(encryptionKey)
|
||||
secretsService, err := secrets.New(encryptionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize encryption: %w", err)
|
||||
}
|
||||
|
||||
database := &database{
|
||||
DB: db,
|
||||
crypto: cryptoService,
|
||||
DB: db,
|
||||
secretsService: secretsService,
|
||||
}
|
||||
|
||||
if err := database.Migrate(); err != nil {
|
||||
@@ -112,12 +112,12 @@ func (db *database) encryptToken(token string) (string, error) {
|
||||
if token == "" {
|
||||
return "", nil
|
||||
}
|
||||
return db.crypto.Encrypt(token)
|
||||
return db.secretsService.Encrypt(token)
|
||||
}
|
||||
|
||||
func (db *database) decryptToken(token string) (string, error) {
|
||||
if token == "" {
|
||||
return "", nil
|
||||
}
|
||||
return db.crypto.Decrypt(token)
|
||||
return db.secretsService.Decrypt(token)
|
||||
}
|
||||
|
||||
112
server/internal/secrets/secrets.go
Normal file
112
server/internal/secrets/secrets.go
Normal file
@@ -0,0 +1,112 @@
|
||||
// Package secrets provides an Encryptor interface for encrypting and decrypting strings using AES-256-GCM.
|
||||
package secrets
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// Encryptor is an interface for encrypting and decrypting strings
|
||||
type Encryptor interface {
|
||||
Encrypt(plaintext string) (string, error)
|
||||
Decrypt(ciphertext string) (string, error)
|
||||
}
|
||||
|
||||
type encryptor struct {
|
||||
gcm cipher.AEAD
|
||||
}
|
||||
|
||||
// ValidateKey checks if the provided base64-encoded key is suitable for AES-256
|
||||
func ValidateKey(key string) error {
|
||||
_, err := decodeAndValidateKey(key)
|
||||
return err
|
||||
}
|
||||
|
||||
// decodeAndValidateKey validates and decodes the base64-encoded key
|
||||
// Returns the decoded key bytes if valid
|
||||
func decodeAndValidateKey(key string) ([]byte, error) {
|
||||
if key == "" {
|
||||
return nil, fmt.Errorf("encryption key is required")
|
||||
}
|
||||
|
||||
keyBytes, err := base64.StdEncoding.DecodeString(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid base64 encoding: %w", err)
|
||||
}
|
||||
|
||||
if len(keyBytes) != 32 {
|
||||
return nil, fmt.Errorf("encryption key must be 32 bytes (256 bits): got %d bytes", len(keyBytes))
|
||||
}
|
||||
|
||||
// Verify the key can be used for AES
|
||||
_, err = aes.NewCipher(keyBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid encryption key: %w", err)
|
||||
}
|
||||
|
||||
return keyBytes, nil
|
||||
}
|
||||
|
||||
// New creates a new Crypto instance with the provided base64-encoded key
|
||||
func New(key string) (Encryptor, error) {
|
||||
keyBytes, err := decodeAndValidateKey(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(keyBytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create GCM: %w", err)
|
||||
}
|
||||
|
||||
return &encryptor{gcm: gcm}, nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts the plaintext using AES-256-GCM
|
||||
func (e *encryptor) Encrypt(plaintext string) (string, error) {
|
||||
if plaintext == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
nonce := make([]byte, e.gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ciphertext := e.gcm.Seal(nonce, nonce, []byte(plaintext), nil)
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts the ciphertext using AES-256-GCM
|
||||
func (e *encryptor) Decrypt(ciphertext string) (string, error) {
|
||||
if ciphertext == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
data, err := base64.StdEncoding.DecodeString(ciphertext)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid base64 encoding: %w", err)
|
||||
}
|
||||
|
||||
nonceSize := e.gcm.NonceSize()
|
||||
if len(data) < nonceSize {
|
||||
return "", fmt.Errorf("invalid ciphertext: too short")
|
||||
}
|
||||
|
||||
nonce, ciphertextBytes := data[:nonceSize], data[nonceSize:]
|
||||
plaintext, err := e.gcm.Open(nil, nonce, ciphertextBytes, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
257
server/internal/secrets/secrets_test.go
Normal file
257
server/internal/secrets/secrets_test.go
Normal file
@@ -0,0 +1,257 @@
|
||||
package secrets_test
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"novamd/internal/secrets"
|
||||
)
|
||||
|
||||
func TestValidateKey(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
key string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "valid 32-byte base64 key",
|
||||
key: base64.StdEncoding.EncodeToString(make([]byte, 32)),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty key",
|
||||
key: "",
|
||||
wantErr: true,
|
||||
errContains: "encryption key is required",
|
||||
},
|
||||
{
|
||||
name: "invalid base64",
|
||||
key: "not-base64!@#$",
|
||||
wantErr: true,
|
||||
errContains: "invalid base64 encoding",
|
||||
},
|
||||
{
|
||||
name: "wrong key size (16 bytes)",
|
||||
key: base64.StdEncoding.EncodeToString(make([]byte, 16)),
|
||||
wantErr: true,
|
||||
errContains: "encryption key must be 32 bytes",
|
||||
},
|
||||
{
|
||||
name: "wrong key size (64 bytes)",
|
||||
key: base64.StdEncoding.EncodeToString(make([]byte, 64)),
|
||||
wantErr: true,
|
||||
errContains: "encryption key must be 32 bytes",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err := secrets.ValidateKey(tc.key)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), tc.errContains) {
|
||||
t.Errorf("error = %v, want error containing %q", err, tc.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
key string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "valid key",
|
||||
key: base64.StdEncoding.EncodeToString(make([]byte, 32)),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty key",
|
||||
key: "",
|
||||
wantErr: true,
|
||||
errContains: "encryption key is required",
|
||||
},
|
||||
{
|
||||
name: "invalid key",
|
||||
key: "invalid",
|
||||
wantErr: true,
|
||||
errContains: "invalid base64 encoding",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
e, err := secrets.New(tc.key)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), tc.errContains) {
|
||||
t.Errorf("error = %v, want error containing %q", err, tc.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
if e == nil {
|
||||
t.Error("expected Encryptor instance, got nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecrypt(t *testing.T) {
|
||||
// Generate a valid key for testing
|
||||
key := base64.StdEncoding.EncodeToString(make([]byte, 32))
|
||||
e, err := secrets.New(key)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create Encryptor instance: %v", err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
plaintext string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "normal text",
|
||||
plaintext: "Hello, World!",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
plaintext: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "long text",
|
||||
plaintext: strings.Repeat("Long text with lots of content. ", 100),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "special characters",
|
||||
plaintext: "!@#$%^&*()_+-=[]{}|;:,.<>?",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "unicode characters",
|
||||
plaintext: "Hello, 世界! नमस्ते",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Test encryption
|
||||
ciphertext, err := e.Encrypt(tc.plaintext)
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected encryption error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected encryption error: %v", err)
|
||||
}
|
||||
|
||||
// Verify ciphertext is different from plaintext
|
||||
if tc.plaintext != "" && ciphertext == tc.plaintext {
|
||||
t.Error("ciphertext matches plaintext")
|
||||
}
|
||||
|
||||
// Test decryption
|
||||
decrypted, err := e.Decrypt(ciphertext)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected decryption error: %v", err)
|
||||
}
|
||||
|
||||
// Verify decrypted text matches original
|
||||
if decrypted != tc.plaintext {
|
||||
t.Errorf("decrypted text = %q, want %q", decrypted, tc.plaintext)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptInvalidCiphertext(t *testing.T) {
|
||||
key := base64.StdEncoding.EncodeToString(make([]byte, 32))
|
||||
e, err := secrets.New(key)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create Encryptor instance: %v", err)
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
ciphertext string
|
||||
wantErr bool
|
||||
errContains string
|
||||
}{
|
||||
{
|
||||
name: "empty ciphertext",
|
||||
ciphertext: "",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid base64",
|
||||
ciphertext: "not-base64!@#$",
|
||||
wantErr: true,
|
||||
errContains: "invalid base64 encoding",
|
||||
},
|
||||
{
|
||||
name: "invalid ciphertext (too short)",
|
||||
ciphertext: base64.StdEncoding.EncodeToString(make([]byte, 10)),
|
||||
wantErr: true,
|
||||
errContains: "invalid ciphertext: too short",
|
||||
},
|
||||
{
|
||||
name: "tampered ciphertext",
|
||||
ciphertext: base64.StdEncoding.EncodeToString(make([]byte, 50)),
|
||||
wantErr: true,
|
||||
errContains: "message authentication failed",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
decrypted, err := e.Decrypt(tc.ciphertext)
|
||||
|
||||
if tc.wantErr {
|
||||
if err == nil {
|
||||
t.Error("expected error, got nil")
|
||||
return
|
||||
}
|
||||
if !strings.Contains(err.Error(), tc.errContains) {
|
||||
t.Errorf("error = %v, want error containing %q", err, tc.errContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if decrypted != "" {
|
||||
t.Errorf("expected empty string, got %q", decrypted)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user