From 9eee42c673fe88a55c7fc276067915e8e09a708c Mon Sep 17 00:00:00 2001 From: LordMathis Date: Wed, 3 Dec 2025 20:31:38 +0100 Subject: [PATCH] Initial api key store implementation --- cmd/server/main.go | 4 +- cmd/server/migrate_json.go | 4 +- go.mod | 9 +- go.sum | 33 +- pkg/auth/hash.go | 73 +++++ pkg/auth/key.go | 51 ++++ pkg/database/apikeys.go | 211 +++++++++++++ pkg/database/database.go | 30 +- .../migrations/001_initial_schema.down.sql | 12 +- .../migrations/001_initial_schema.up.sql | 36 +++ pkg/database/permissions.go | 57 ++++ pkg/instance/instance.go | 10 +- pkg/manager/manager.go | 4 +- pkg/server/handlers.go | 14 +- pkg/server/handlers_auth.go | 284 ++++++++++++++++++ pkg/server/middleware.go | 231 +++++++++++--- pkg/server/middleware_test.go | 20 +- pkg/server/routes.go | 19 +- 18 files changed, 986 insertions(+), 116 deletions(-) create mode 100644 pkg/auth/hash.go create mode 100644 pkg/auth/key.go create mode 100644 pkg/database/apikeys.go create mode 100644 pkg/database/permissions.go create mode 100644 pkg/server/handlers_auth.go diff --git a/cmd/server/main.go b/cmd/server/main.go index 0eeca45..9431b59 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -48,7 +48,7 @@ func main() { cfg.CommitHash = commitHash cfg.BuildTime = buildTime - // Create the data directory if it doesn't exist + // Create data directory if it doesn't exist if cfg.Instances.AutoCreateDirs { // Create the main data directory if err := os.MkdirAll(cfg.DataDir, 0755); err != nil { @@ -91,7 +91,7 @@ func main() { instanceManager := manager.New(&cfg, db) // Create a new handler with the instance manager - handler := server.NewHandler(instanceManager, cfg) + handler := server.NewHandler(instanceManager, cfg, db) // Setup the router with the handler r := server.SetupRouter(handler) diff --git a/cmd/server/migrate_json.go b/cmd/server/migrate_json.go index 7ee6a2b..eb14781 100644 --- a/cmd/server/migrate_json.go +++ b/cmd/server/migrate_json.go @@ -13,7 +13,7 @@ import ( // migrateFromJSON migrates instances from JSON files to SQLite database // This is a one-time migration that runs on first startup with existing JSON files. -func migrateFromJSON(cfg *config.AppConfig, db database.DB) error { +func migrateFromJSON(cfg *config.AppConfig, db database.InstanceStore) error { instancesDir := cfg.Instances.InstancesDir if instancesDir == "" { return nil // No instances directory configured @@ -76,7 +76,7 @@ func migrateFromJSON(cfg *config.AppConfig, db database.DB) error { } // migrateJSONFile migrates a single JSON file to the database -func migrateJSONFile(filename string, db database.DB) error { +func migrateJSONFile(filename string, db database.InstanceStore) error { data, err := os.ReadFile(filename) if err != nil { return fmt.Errorf("failed to read file: %w", err) diff --git a/go.mod b/go.mod index 73b77f8..de69ecf 100644 --- a/go.mod +++ b/go.mod @@ -5,8 +5,11 @@ go 1.24.5 require ( github.com/go-chi/chi/v5 v5.2.2 github.com/go-chi/cors v1.2.2 + github.com/golang-migrate/migrate/v4 v4.19.1 + github.com/mattn/go-sqlite3 v1.14.24 github.com/swaggo/http-swagger v1.3.4 github.com/swaggo/swag v1.16.5 + golang.org/x/crypto v0.45.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -16,16 +19,12 @@ require ( github.com/go-openapi/jsonreference v0.21.0 // indirect github.com/go-openapi/spec v0.21.0 // indirect github.com/go-openapi/swag v0.23.1 // indirect - github.com/golang-migrate/migrate/v4 v4.19.1 // indirect - github.com/hashicorp/errwrap v1.1.0 // indirect - github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/mailru/easyjson v0.9.0 // indirect - github.com/mattn/go-sqlite3 v1.14.24 // indirect github.com/swaggo/files v1.0.1 // indirect - go.uber.org/atomic v1.7.0 // indirect golang.org/x/mod v0.29.0 // indirect golang.org/x/net v0.47.0 // indirect golang.org/x/sync v0.18.0 // indirect + golang.org/x/sys v0.38.0 // indirect golang.org/x/tools v0.38.0 // indirect ) diff --git a/go.sum b/go.sum index 8924797..7431b29 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,7 @@ github.com/KyleBanks/depth v1.2.1 h1:5h8fQADFrWtarTdtDudMmGsC7GPbOAu6RVB3ffsVFHc= github.com/KyleBanks/depth v1.2.1/go.mod h1:jzSb9d0L43HxTQfT+oSA1EEp2q+ne2uh6XgeJcm8brE= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= -github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/go-chi/chi/v5 v5.2.2 h1:CMwsvRVTbXVytCk1Wd72Zy1LAsAh9GxMmSNWLHCG618= github.com/go-chi/chi/v5 v5.2.2/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= github.com/go-chi/cors v1.2.2 h1:Jmey33TE+b+rB7fT8MUy1u0I4L+NARQlK6LhzKPSyQE= @@ -16,35 +14,26 @@ github.com/go-openapi/spec v0.21.0 h1:LTVzPc3p/RzRnkQqLRndbAzjY0d0BCL72A6j3CdL9Z github.com/go-openapi/spec v0.21.0/go.mod h1:78u6VdPw81XU44qEWGhtr982gJ5BWg2c0I5XwVMotYk= github.com/go-openapi/swag v0.23.1 h1:lpsStH0n2ittzTnbaSloVZLuB5+fvSY/+hnagBjSNZU= github.com/go-openapi/swag v0.23.1/go.mod h1:STZs8TbRvEQQKUA+JZNAm3EWlgaOBGpyFDqQnDHMef0= -github.com/golang-migrate/migrate/v4 v4.18.1 h1:JML/k+t4tpHCpQTCAD62Nu43NUFzHY4CV3uAuvHGC+Y= -github.com/golang-migrate/migrate/v4 v4.18.1/go.mod h1:HAX6m3sQgcdO81tdjn5exv20+3Kb13cmGli1hrD6hks= github.com/golang-migrate/migrate/v4 v4.19.1 h1:OCyb44lFuQfYXYLx1SCxPZQGU7mcaZ7gH9yH4jSFbBA= github.com/golang-migrate/migrate/v4 v4.19.1/go.mod h1:CTcgfjxhaUtsLipnLoQRWCrjYXycRz/g5+RWDuYgPrE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= -github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= -github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= -github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= -github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= -github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/swaggo/files v1.0.1 h1:J1bVJ4XHZNq0I46UU90611i9/YzdrF7x92oX1ig5IdE= @@ -54,27 +43,21 @@ github.com/swaggo/http-swagger v1.3.4/go.mod h1:9dAh0unqMBAlbp1uE2Uc2mQTxNMU/ha4 github.com/swaggo/swag v1.16.5 h1:nMf2fEV1TetMTJb4XzD0Lz7jFfKJmJKGTygEey8NSxM= github.com/swaggo/swag v1.16.5/go.mod h1:ngP2etMK5a0P3QBizic5MEwpRmluJZPHjXcMoj4Xesg= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= -go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= +golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= -golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA= golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= -golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= -golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I= golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -83,6 +66,8 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc= +golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= @@ -93,8 +78,6 @@ golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0= -golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ= golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/pkg/auth/hash.go b/pkg/auth/hash.go new file mode 100644 index 0000000..776b851 --- /dev/null +++ b/pkg/auth/hash.go @@ -0,0 +1,73 @@ +package auth + +import ( + "crypto/rand" + "crypto/subtle" + "encoding/base64" + "fmt" + "strings" + + "golang.org/x/crypto/argon2" +) + +const ( + // Argon2 parameters + time uint32 = 1 + memory uint32 = 64 * 1024 // 64 MB + threads uint8 = 4 + keyLen uint32 = 32 + saltLen uint32 = 16 +) + +// HashKey hashes an API key using Argon2id +func HashKey(plainTextKey string) (string, error) { + // Generate random salt + salt := make([]byte, saltLen) + if _, err := rand.Read(salt); err != nil { + return "", fmt.Errorf("failed to generate salt: %w", err) + } + + // Derive key using Argon2id + hash := argon2.IDKey([]byte(plainTextKey), salt, time, memory, threads, keyLen) + + // Format: $argon2id$v=19$m=65536,t=1,p=4$$ + saltB64 := base64.RawStdEncoding.EncodeToString(salt) + hashB64 := base64.RawStdEncoding.EncodeToString(hash) + + return fmt.Sprintf("$argon2id$v=19$m=%d,t=%d,p=%d$%s$%s", memory, time, threads, saltB64, hashB64), nil +} + +// VerifyKey verifies a plain-text key against an Argon2id hash +func VerifyKey(plainTextKey, hash string) bool { + // Parse the hash format + parts := strings.Split(hash, "$") + if len(parts) != 6 || parts[1] != "argon2id" { + return false + } + + // Extract parameters + var version, time, memory, threads int + if _, err := fmt.Sscanf(parts[2], "v=%d", &version); err != nil || version != 19 { + return false + } + if _, err := fmt.Sscanf(parts[3], "m=%d,t=%d,p=%d", &memory, &time, &threads); err != nil { + return false + } + + // Decode salt and hash + salt, err := base64.RawStdEncoding.DecodeString(parts[4]) + if err != nil { + return false + } + + expectedHash, err := base64.RawStdEncoding.DecodeString(parts[5]) + if err != nil { + return false + } + + // Compute hash of the provided key + computedHash := argon2.IDKey([]byte(plainTextKey), salt, uint32(time), uint32(memory), uint8(threads), uint32(len(expectedHash))) + + // Compare hashes using constant-time comparison + return subtle.ConstantTimeCompare(computedHash, expectedHash) == 1 +} diff --git a/pkg/auth/key.go b/pkg/auth/key.go new file mode 100644 index 0000000..211647c --- /dev/null +++ b/pkg/auth/key.go @@ -0,0 +1,51 @@ +package auth + +import ( + "crypto/rand" + "encoding/hex" + "fmt" +) + +type PermissionMode string + +const ( + PermissionModeAllowAll PermissionMode = "allow_all" + PermissionModePerInstance PermissionMode = "per_instance" +) + +type APIKey struct { + ID int + KeyHash string + Name string + UserID string + PermissionMode PermissionMode + ExpiresAt *int64 + Enabled bool + CreatedAt int64 + UpdatedAt int64 + LastUsedAt *int64 +} + +type KeyPermission struct { + KeyID int + InstanceID int + CanInfer bool + CanViewLogs bool +} + +// GenerateKey generates a cryptographically secure inference API key +// Format: sk-inference-<64-hex-chars> +func GenerateKey() (string, error) { + // Generate 32 random bytes + bytes := make([]byte, 32) + _, err := rand.Read(bytes) + if err != nil { + return "", fmt.Errorf("failed to generate random bytes: %w", err) + } + + // Convert to hex (64 characters) + hexStr := hex.EncodeToString(bytes) + + // Prefix with "sk-inference-" + return fmt.Sprintf("sk-inference-%s", hexStr), nil +} diff --git a/pkg/database/apikeys.go b/pkg/database/apikeys.go new file mode 100644 index 0000000..940bc62 --- /dev/null +++ b/pkg/database/apikeys.go @@ -0,0 +1,211 @@ +package database + +import ( + "context" + "database/sql" + "fmt" + "llamactl/pkg/auth" + "time" +) + +// CreateKey inserts a new API key with permissions (transactional) +func (db *sqliteDB) CreateKey(ctx context.Context, key *auth.APIKey, permissions []auth.KeyPermission) error { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("failed to begin transaction: %w", err) + } + defer tx.Rollback() + + // Insert the API key + query := ` + INSERT INTO api_keys (key_hash, name, user_id, permission_mode, expires_at, enabled, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + ` + + var expiresAt sql.NullInt64 + if key.ExpiresAt != nil { + expiresAt = sql.NullInt64{Int64: *key.ExpiresAt, Valid: true} + } + + result, err := tx.ExecContext(ctx, query, + key.KeyHash, key.Name, key.UserID, key.PermissionMode, + expiresAt, key.Enabled, key.CreatedAt, key.UpdatedAt, + ) + if err != nil { + return fmt.Errorf("failed to insert API key: %w", err) + } + + keyID, err := result.LastInsertId() + if err != nil { + return fmt.Errorf("failed to get last insert ID: %w", err) + } + key.ID = int(keyID) + + // Insert permissions if per-instance mode + if key.PermissionMode == auth.PermissionModePerInstance { + for _, perm := range permissions { + query := ` + INSERT INTO key_permissions (key_id, instance_id, can_infer, can_view_logs) + VALUES (?, ?, ?, ?) + ` + _, err := tx.ExecContext(ctx, query, perm.KeyID, perm.InstanceID, perm.CanInfer, perm.CanViewLogs) + if err != nil { + return fmt.Errorf("failed to insert permission for instance %d: %w", perm.InstanceID, err) + } + } + } + + return tx.Commit() +} + +// GetKeyByID retrieves an API key by ID +func (db *sqliteDB) GetKeyByID(ctx context.Context, id int) (*auth.APIKey, error) { + query := ` + SELECT id, key_hash, name, user_id, permission_mode, expires_at, enabled, created_at, updated_at, last_used_at + FROM api_keys + WHERE id = ? + ` + + var key auth.APIKey + var expiresAt sql.NullInt64 + var lastUsedAt sql.NullInt64 + + err := db.QueryRowContext(ctx, query, id).Scan( + &key.ID, &key.KeyHash, &key.Name, &key.UserID, &key.PermissionMode, + &expiresAt, &key.Enabled, &key.CreatedAt, &key.UpdatedAt, &lastUsedAt, + ) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("API key not found") + } + return nil, fmt.Errorf("failed to query API key: %w", err) + } + + if expiresAt.Valid { + key.ExpiresAt = &expiresAt.Int64 + } + if lastUsedAt.Valid { + key.LastUsedAt = &lastUsedAt.Int64 + } + + return &key, nil +} + +// GetUserKeys retrieves all API keys for a user +func (db *sqliteDB) GetUserKeys(ctx context.Context, userID string) ([]*auth.APIKey, error) { + query := ` + SELECT id, key_hash, name, user_id, permission_mode, expires_at, enabled, created_at, updated_at, last_used_at + FROM api_keys + WHERE user_id = ? + ORDER BY created_at DESC + ` + + rows, err := db.QueryContext(ctx, query, userID) + if err != nil { + return nil, fmt.Errorf("failed to query API keys: %w", err) + } + defer rows.Close() + + var keys []*auth.APIKey + for rows.Next() { + var key auth.APIKey + var expiresAt sql.NullInt64 + var lastUsedAt sql.NullInt64 + + err := rows.Scan( + &key.ID, &key.KeyHash, &key.Name, &key.UserID, &key.PermissionMode, + &expiresAt, &key.Enabled, &key.CreatedAt, &key.UpdatedAt, &lastUsedAt, + ) + if err != nil { + return nil, fmt.Errorf("failed to scan API key: %w", err) + } + + if expiresAt.Valid { + key.ExpiresAt = &expiresAt.Int64 + } + if lastUsedAt.Valid { + key.LastUsedAt = &lastUsedAt.Int64 + } + + keys = append(keys, &key) + } + + return keys, nil +} + +// GetActiveKeys retrieves all enabled, non-expired API keys +func (db *sqliteDB) GetActiveKeys(ctx context.Context) ([]*auth.APIKey, error) { + query := ` + SELECT id, key_hash, name, user_id, permission_mode, expires_at, enabled, created_at, updated_at, last_used_at + FROM api_keys + WHERE enabled = 1 AND (expires_at IS NULL OR expires_at > ?) + ORDER BY created_at DESC + ` + + now := time.Now().Unix() + rows, err := db.QueryContext(ctx, query, now) + if err != nil { + return nil, fmt.Errorf("failed to query active API keys: %w", err) + } + defer rows.Close() + + var keys []*auth.APIKey + for rows.Next() { + var key auth.APIKey + var expiresAt sql.NullInt64 + var lastUsedAt sql.NullInt64 + + err := rows.Scan( + &key.ID, &key.KeyHash, &key.Name, &key.UserID, &key.PermissionMode, + &expiresAt, &key.Enabled, &key.CreatedAt, &key.UpdatedAt, &lastUsedAt, + ) + if err != nil { + return nil, fmt.Errorf("failed to scan API key: %w", err) + } + + if expiresAt.Valid { + key.ExpiresAt = &expiresAt.Int64 + } + if lastUsedAt.Valid { + key.LastUsedAt = &lastUsedAt.Int64 + } + + keys = append(keys, &key) + } + + return keys, nil +} + +// DeleteKey removes an API key (cascades to permissions) +func (db *sqliteDB) DeleteKey(ctx context.Context, id int) error { + query := `DELETE FROM api_keys WHERE id = ?` + + result, err := db.ExecContext(ctx, query, id) + if err != nil { + return fmt.Errorf("failed to delete API key: %w", err) + } + + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + + if rowsAffected == 0 { + return fmt.Errorf("API key not found") + } + + return nil +} + +// TouchKey updates the last_used_at timestamp +func (db *sqliteDB) TouchKey(ctx context.Context, id int) error { + query := `UPDATE api_keys SET last_used_at = ?, updated_at = ? WHERE id = ?` + + now := time.Now().Unix() + _, err := db.ExecContext(ctx, query, now, now, id) + if err != nil { + return fmt.Errorf("failed to update last used timestamp: %w", err) + } + + return nil +} diff --git a/pkg/database/database.go b/pkg/database/database.go index 957fac7..98793ff 100644 --- a/pkg/database/database.go +++ b/pkg/database/database.go @@ -1,8 +1,10 @@ package database import ( + "context" "database/sql" "fmt" + "llamactl/pkg/auth" "llamactl/pkg/instance" "log" "path/filepath" @@ -11,14 +13,26 @@ import ( _ "github.com/mattn/go-sqlite3" ) -// DB defines the interface for instance persistence operations -type DB interface { +// InstanceStore defines interface for instance persistence operations +type InstanceStore interface { Save(inst *instance.Instance) error Delete(name string) error LoadAll() ([]*instance.Instance, error) Close() error } +// AuthStore defines the interface for authentication operations +type AuthStore interface { + CreateKey(ctx context.Context, key *auth.APIKey, permissions []auth.KeyPermission) error + GetUserKeys(ctx context.Context, userID string) ([]*auth.APIKey, error) + GetActiveKeys(ctx context.Context) ([]*auth.APIKey, error) + GetKeyByID(ctx context.Context, id int) (*auth.APIKey, error) + DeleteKey(ctx context.Context, id int) error + TouchKey(ctx context.Context, id int) error + GetPermissions(ctx context.Context, keyID int) ([]auth.KeyPermission, error) + HasPermission(ctx context.Context, keyID, instanceID int) (bool, error) +} + // Config contains database configuration settings type Config struct { // Database file path (relative to data_dir or absolute) @@ -30,13 +44,13 @@ type Config struct { ConnMaxLifetime time.Duration } -// sqliteDB wraps the database connection with configuration +// sqliteDB wraps database connection with configuration type sqliteDB struct { *sql.DB config *Config } -// Open creates a new database connection with the provided configuration +// Open creates a new database connection with provided configuration func Open(config *Config) (*sqliteDB, error) { if config == nil { return nil, fmt.Errorf("database config cannot be nil") @@ -46,10 +60,10 @@ func Open(config *Config) (*sqliteDB, error) { return nil, fmt.Errorf("database path cannot be empty") } - // Ensure the database directory exists + // Ensure that database directory exists dbDir := filepath.Dir(config.Path) if dbDir != "." && dbDir != "/" { - // Directory will be created by the manager if auto_create_dirs is enabled + // Directory will be created by manager if auto_create_dirs is enabled log.Printf("Database will be created at: %s", config.Path) } @@ -89,7 +103,7 @@ func Open(config *Config) (*sqliteDB, error) { }, nil } -// Close closes the database connection +// Close closes database connection func (db *sqliteDB) Close() error { if db.DB != nil { log.Println("Closing database connection") @@ -98,7 +112,7 @@ func (db *sqliteDB) Close() error { return nil } -// HealthCheck verifies the database is accessible +// HealthCheck verifies that database is accessible func (db *sqliteDB) HealthCheck() error { if db.DB == nil { return fmt.Errorf("database connection is nil") diff --git a/pkg/database/migrations/001_initial_schema.down.sql b/pkg/database/migrations/001_initial_schema.down.sql index 08b26e0..633814b 100644 --- a/pkg/database/migrations/001_initial_schema.down.sql +++ b/pkg/database/migrations/001_initial_schema.down.sql @@ -1,7 +1,11 @@ --- Drop indexes first -DROP INDEX IF EXISTS idx_instances_backend_type; +-- Drop API key related indexes and tables first +DROP INDEX IF EXISTS idx_key_permissions_instance_id; +DROP INDEX IF EXISTS idx_api_keys_expires_at; +DROP INDEX IF EXISTS idx_api_keys_user_id; +DROP TABLE IF EXISTS key_permissions; +DROP TABLE IF EXISTS api_keys; + +-- Drop instance related indexes and tables DROP INDEX IF EXISTS idx_instances_status; DROP INDEX IF EXISTS idx_instances_name; - --- Drop tables DROP TABLE IF EXISTS instances; diff --git a/pkg/database/migrations/001_initial_schema.up.sql b/pkg/database/migrations/001_initial_schema.up.sql index 89eac83..2338a82 100644 --- a/pkg/database/migrations/001_initial_schema.up.sql +++ b/pkg/database/migrations/001_initial_schema.up.sql @@ -25,3 +25,39 @@ CREATE TABLE IF NOT EXISTS instances ( -- ----------------------------------------------------------------------------- CREATE UNIQUE INDEX IF NOT EXISTS idx_instances_name ON instances(name); CREATE INDEX IF NOT EXISTS idx_instances_status ON instances(status); + +-- ----------------------------------------------------------------------------- +-- API Keys Table: Database-backed inference API keys +-- ----------------------------------------------------------------------------- +CREATE TABLE IF NOT EXISTS api_keys ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + key_hash TEXT NOT NULL, + name TEXT NOT NULL, + user_id TEXT NOT NULL, + permission_mode TEXT NOT NULL CHECK(permission_mode IN ('allow_all', 'per_instance')) DEFAULT 'per_instance', + expires_at INTEGER NULL, + enabled INTEGER NOT NULL DEFAULT 1, + created_at INTEGER NOT NULL, + updated_at INTEGER NOT NULL, + last_used_at INTEGER NULL +); + +-- ----------------------------------------------------------------------------- +-- Key Permissions Table: Per-instance permissions for API keys +-- ----------------------------------------------------------------------------- +CREATE TABLE IF NOT EXISTS key_permissions ( + key_id INTEGER NOT NULL, + instance_id INTEGER NOT NULL, + can_infer INTEGER NOT NULL DEFAULT 0, + can_view_logs INTEGER NOT NULL DEFAULT 0, + PRIMARY KEY (key_id, instance_id), + FOREIGN KEY (key_id) REFERENCES api_keys (id) ON DELETE CASCADE, + FOREIGN KEY (instance_id) REFERENCES instances (id) ON DELETE CASCADE +); + +-- ----------------------------------------------------------------------------- +-- Indexes for API keys and permissions +-- ----------------------------------------------------------------------------- +CREATE INDEX IF NOT EXISTS idx_api_keys_user_id ON api_keys(user_id); +CREATE INDEX IF NOT EXISTS idx_api_keys_expires_at ON api_keys(expires_at); +CREATE INDEX IF NOT EXISTS idx_key_permissions_instance_id ON key_permissions(instance_id); diff --git a/pkg/database/permissions.go b/pkg/database/permissions.go new file mode 100644 index 0000000..afd746b --- /dev/null +++ b/pkg/database/permissions.go @@ -0,0 +1,57 @@ +package database + +import ( + "context" + "database/sql" + "fmt" + "llamactl/pkg/auth" +) + +// GetPermissions retrieves all permissions for a key +func (db *sqliteDB) GetPermissions(ctx context.Context, keyID int) ([]auth.KeyPermission, error) { + query := ` + SELECT key_id, instance_id, can_infer, can_view_logs + FROM key_permissions + WHERE key_id = ? + ORDER BY instance_id + ` + + rows, err := db.QueryContext(ctx, query, keyID) + if err != nil { + return nil, fmt.Errorf("failed to query key permissions: %w", err) + } + defer rows.Close() + + var permissions []auth.KeyPermission + for rows.Next() { + var perm auth.KeyPermission + err := rows.Scan(&perm.KeyID, &perm.InstanceID, &perm.CanInfer, &perm.CanViewLogs) + if err != nil { + return nil, fmt.Errorf("failed to scan key permission: %w", err) + } + permissions = append(permissions, perm) + } + + return permissions, nil +} + +// HasPermission checks if key has inference permission for instance +func (db *sqliteDB) HasPermission(ctx context.Context, keyID, instanceID int) (bool, error) { + query := ` + SELECT can_infer + FROM key_permissions + WHERE key_id = ? AND instance_id = ? + ` + + var canInfer bool + err := db.QueryRowContext(ctx, query, keyID, instanceID).Scan(&canInfer) + if err != nil { + if err == sql.ErrNoRows { + // No permission record found, deny access + return false, nil + } + return false, fmt.Errorf("failed to check key permission: %w", err) + } + + return canInfer, nil +} diff --git a/pkg/instance/instance.go b/pkg/instance/instance.go index 376cc0c..465cd5e 100644 --- a/pkg/instance/instance.go +++ b/pkg/instance/instance.go @@ -9,10 +9,11 @@ import ( "time" ) -// Instance represents a running instance of the llama server +// Instance represents a running instance of llama server type Instance struct { + ID int `json:"id"` Name string `json:"name"` - Created int64 `json:"created,omitempty"` // Unix timestamp when the instance was created + Created int64 `json:"created,omitempty"` // Unix timestamp when instance was created // Global configuration globalInstanceSettings *config.InstancesConfig @@ -48,6 +49,7 @@ func New(name string, globalConfig *config.AppConfig, opts *Options, onStatusCha options := newOptions(opts) instance := &Instance{ + ID: 0, // Will be set by database Name: name, options: options, globalInstanceSettings: globalInstanceSettings, @@ -279,11 +281,13 @@ func (i *Instance) buildEnvironment() map[string]string { // MarshalJSON implements json.Marshaler for Instance func (i *Instance) MarshalJSON() ([]byte, error) { return json.Marshal(&struct { + ID int `json:"id"` Name string `json:"name"` Status *status `json:"status"` Created int64 `json:"created,omitempty"` Options *options `json:"options,omitempty"` }{ + ID: i.ID, Name: i.Name, Status: i.status, Created: i.Created, @@ -295,6 +299,7 @@ func (i *Instance) MarshalJSON() ([]byte, error) { func (i *Instance) UnmarshalJSON(data []byte) error { // Explicitly deserialize to match MarshalJSON format aux := &struct { + ID int `json:"id"` Name string `json:"name"` Status *status `json:"status"` Created int64 `json:"created,omitempty"` @@ -306,6 +311,7 @@ func (i *Instance) UnmarshalJSON(data []byte) error { } // Set the fields + i.ID = aux.ID i.Name = aux.Name i.Created = aux.Created i.status = aux.Status diff --git a/pkg/manager/manager.go b/pkg/manager/manager.go index 5aca037..fc23a2b 100644 --- a/pkg/manager/manager.go +++ b/pkg/manager/manager.go @@ -31,7 +31,7 @@ type instanceManager struct { // Components (each with own synchronization) registry *instanceRegistry ports *portAllocator - db database.DB + db database.InstanceStore remote *remoteManager lifecycle *lifecycleManager @@ -44,7 +44,7 @@ type instanceManager struct { } // New creates a new instance of InstanceManager with dependency injection. -func New(globalConfig *config.AppConfig, db database.DB) InstanceManager { +func New(globalConfig *config.AppConfig, db database.InstanceStore) InstanceManager { if globalConfig.Instances.TimeoutCheckInterval <= 0 { globalConfig.Instances.TimeoutCheckInterval = 5 // Default to 5 minutes if not set diff --git a/pkg/server/handlers.go b/pkg/server/handlers.go index 78b83c5..3e232ee 100644 --- a/pkg/server/handlers.go +++ b/pkg/server/handlers.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "llamactl/pkg/config" + "llamactl/pkg/database" "llamactl/pkg/instance" "llamactl/pkg/manager" "llamactl/pkg/validation" @@ -52,20 +53,25 @@ type Handler struct { InstanceManager manager.InstanceManager cfg config.AppConfig httpClient *http.Client + authStore database.AuthStore + authMiddleware *APIAuthMiddleware } // NewHandler creates a new Handler instance with the provided instance manager and configuration -func NewHandler(im manager.InstanceManager, cfg config.AppConfig) *Handler { - return &Handler{ +func NewHandler(im manager.InstanceManager, cfg config.AppConfig, authStore database.AuthStore) *Handler { + handler := &Handler{ InstanceManager: im, cfg: cfg, httpClient: &http.Client{ Timeout: 30 * time.Second, }, + authStore: authStore, } + handler.authMiddleware = NewAPIAuthMiddleware(cfg.Auth, authStore) + return handler } -// getInstance retrieves an instance by name from the request query parameters +// getInstance retrieves an instance by name from request query parameters func (h *Handler) getInstance(r *http.Request) (*instance.Instance, error) { name := chi.URLParam(r, "name") validatedName, err := validation.ValidateInstanceName(name) @@ -81,7 +87,7 @@ func (h *Handler) getInstance(r *http.Request) (*instance.Instance, error) { return inst, nil } -// ensureInstanceRunning ensures the instance is running by starting it if on-demand start is enabled +// ensureInstanceRunning ensures that an instance is running by starting it if on-demand start is enabled // It handles LRU eviction when the maximum number of running instances is reached func (h *Handler) ensureInstanceRunning(inst *instance.Instance) error { options := inst.GetOptions() diff --git a/pkg/server/handlers_auth.go b/pkg/server/handlers_auth.go new file mode 100644 index 0000000..3971711 --- /dev/null +++ b/pkg/server/handlers_auth.go @@ -0,0 +1,284 @@ +package server + +import ( + "encoding/json" + "fmt" + "llamactl/pkg/auth" + "net/http" + "strconv" + "time" + + "github.com/go-chi/chi/v5" +) + +type InstancePermission struct { + InstanceID int `json:"instance_id"` + CanInfer bool `json:"can_infer"` + CanViewLogs bool `json:"can_view_logs"` +} + +type CreateKeyRequest struct { + Name string + PermissionMode auth.PermissionMode + ExpiresAt *int64 + InstancePermissions []InstancePermission +} + +// CreateInferenceKey handles POST /api/v1/keys +func (h *Handler) CreateInferenceKey() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var req CreateKeyRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid_json", "Invalid JSON in request body") + return + } + + // Validate request + if req.Name == "" { + writeError(w, http.StatusBadRequest, "invalid_name", "Name is required") + return + } + if len(req.Name) > 100 { + writeError(w, http.StatusBadRequest, "invalid_name", "Name must be 100 characters or less") + return + } + if req.PermissionMode != auth.PermissionModeAllowAll && req.PermissionMode != auth.PermissionModePerInstance { + writeError(w, http.StatusBadRequest, "invalid_permission_mode", "Permission mode must be 'allow_all' or 'per_instance'") + return + } + if req.PermissionMode == auth.PermissionModePerInstance && len(req.InstancePermissions) == 0 { + writeError(w, http.StatusBadRequest, "missing_permissions", "Instance permissions required when permission mode is 'per_instance'") + return + } + if req.ExpiresAt != nil && *req.ExpiresAt <= time.Now().Unix() { + writeError(w, http.StatusBadRequest, "invalid_expires_at", "Expiration time must be in future") + return + } + + // Validate instance IDs exist + if req.PermissionMode == auth.PermissionModePerInstance { + instances, err := h.InstanceManager.ListInstances() + if err != nil { + writeError(w, http.StatusInternalServerError, "fetch_instances_failed", fmt.Sprintf("Failed to fetch instances: %v", err)) + return + } + instanceIDMap := make(map[int]bool) + for _, inst := range instances { + instanceIDMap[inst.ID] = true + } + + for _, perm := range req.InstancePermissions { + if !instanceIDMap[perm.InstanceID] { + writeError(w, http.StatusBadRequest, "invalid_instance_id", fmt.Sprintf("Instance ID %d does not exist", perm.InstanceID)) + return + } + } + } + + // Generate plain-text key + plainTextKey, err := auth.GenerateKey() + if err != nil { + writeError(w, http.StatusInternalServerError, "key_generation_failed", "Failed to generate API key") + return + } + + // Hash key + keyHash, err := auth.HashKey(plainTextKey) + if err != nil { + writeError(w, http.StatusInternalServerError, "key_hashing_failed", "Failed to hash API key") + return + } + + // Create APIKey struct + now := time.Now().Unix() + apiKey := &auth.APIKey{ + KeyHash: keyHash, + Name: req.Name, + UserID: "system", + PermissionMode: req.PermissionMode, + ExpiresAt: req.ExpiresAt, + Enabled: true, + CreatedAt: now, + UpdatedAt: now, + } + + // Convert InstancePermissions to KeyPermissions + var keyPermissions []auth.KeyPermission + for _, perm := range req.InstancePermissions { + keyPermissions = append(keyPermissions, auth.KeyPermission{ + KeyID: 0, // Will be set by database after key creation + InstanceID: perm.InstanceID, + CanInfer: perm.CanInfer, + CanViewLogs: perm.CanViewLogs, + }) + } + + // Create in database + err = h.authStore.CreateKey(r.Context(), apiKey, keyPermissions) + if err != nil { + writeError(w, http.StatusInternalServerError, "creation_failed", fmt.Sprintf("Failed to create API key: %v", err)) + return + } // Return response with plain-text key (only shown once) + response := map[string]interface{}{ + "id": apiKey.ID, + "name": apiKey.Name, + "user_id": apiKey.UserID, + "permission_mode": apiKey.PermissionMode, + "expires_at": apiKey.ExpiresAt, + "enabled": apiKey.Enabled, + "created_at": apiKey.CreatedAt, + "updated_at": apiKey.UpdatedAt, + "last_used_at": apiKey.LastUsedAt, + "key": plainTextKey, // Only returned on creation + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(response) + } +} + +// ListInferenceKeys handles GET /api/v1/keys +func (h *Handler) ListInferenceKeys() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + keys, err := h.authStore.GetUserKeys(r.Context(), "system") + if err != nil { + writeError(w, http.StatusInternalServerError, "fetch_failed", fmt.Sprintf("Failed to fetch API keys: %v", err)) + return + } + + // Remove key_hash from all keys + var response []map[string]interface{} + for _, key := range keys { + response = append(response, map[string]interface{}{ + "id": key.ID, + "name": key.Name, + "user_id": key.UserID, + "permission_mode": key.PermissionMode, + "expires_at": key.ExpiresAt, + "enabled": key.Enabled, + "created_at": key.CreatedAt, + "updated_at": key.UpdatedAt, + "last_used_at": key.LastUsedAt, + }) + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + } +} + +// GetInferenceKey handles GET /api/v1/keys/{id} +func (h *Handler) GetInferenceKey() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + idStr := chi.URLParam(r, "id") + id, err := strconv.Atoi(idStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_id", "Invalid key ID") + return + } + + key, err := h.authStore.GetKeyByID(r.Context(), id) + if err != nil { + if err.Error() == "API key not found" { + writeError(w, http.StatusNotFound, "not_found", "API key not found") + return + } + writeError(w, http.StatusInternalServerError, "fetch_failed", fmt.Sprintf("Failed to fetch API key: %v", err)) + return + } + + // Remove key_hash from response + response := map[string]interface{}{ + "id": key.ID, + "name": key.Name, + "user_id": key.UserID, + "permission_mode": key.PermissionMode, + "expires_at": key.ExpiresAt, + "enabled": key.Enabled, + "created_at": key.CreatedAt, + "updated_at": key.UpdatedAt, + "last_used_at": key.LastUsedAt, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + } +} + +// DeleteInferenceKey handles DELETE /api/v1/keys/{id} +func (h *Handler) DeleteInferenceKey() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + idStr := chi.URLParam(r, "id") + id, err := strconv.Atoi(idStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_id", "Invalid key ID") + return + } + + err = h.authStore.DeleteKey(r.Context(), id) + if err != nil { + if err.Error() == "API key not found" { + writeError(w, http.StatusNotFound, "not_found", "API key not found") + return + } + writeError(w, http.StatusInternalServerError, "deletion_failed", fmt.Sprintf("Failed to delete API key: %v", err)) + return + } + + w.WriteHeader(http.StatusNoContent) + } +} + +// GetInferenceKeyPermissions handles GET /api/v1/keys/{id}/permissions +func (h *Handler) GetInferenceKeyPermissions() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + idStr := chi.URLParam(r, "id") + id, err := strconv.Atoi(idStr) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid_id", "Invalid key ID") + return + } + + // Verify key exists + _, err = h.authStore.GetKeyByID(r.Context(), id) + if err != nil { + if err.Error() == "API key not found" { + writeError(w, http.StatusNotFound, "not_found", "API key not found") + return + } + writeError(w, http.StatusInternalServerError, "fetch_failed", fmt.Sprintf("Failed to fetch API key: %v", err)) + return + } + + permissions, err := h.authStore.GetPermissions(r.Context(), id) + if err != nil { + writeError(w, http.StatusInternalServerError, "fetch_failed", fmt.Sprintf("Failed to fetch permissions: %v", err)) + return + } + + // Get instance names for the permissions + instances, err := h.InstanceManager.ListInstances() + if err != nil { + writeError(w, http.StatusInternalServerError, "fetch_instances_failed", fmt.Sprintf("Failed to fetch instances: %v", err)) + return + } + instanceNameMap := make(map[int]string) + for _, inst := range instances { + instanceNameMap[inst.ID] = inst.Name + } + + var response []map[string]interface{} + for _, perm := range permissions { + response = append(response, map[string]interface{}{ + "instance_id": perm.InstanceID, + "instance_name": instanceNameMap[perm.InstanceID], + "can_infer": perm.CanInfer, + "can_view_logs": perm.CanViewLogs, + }) + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(response) + } +} diff --git a/pkg/server/middleware.go b/pkg/server/middleware.go index 0654be6..0c1797c 100644 --- a/pkg/server/middleware.go +++ b/pkg/server/middleware.go @@ -1,15 +1,19 @@ package server import ( + "context" "crypto/rand" "crypto/subtle" "encoding/hex" "fmt" + "llamactl/pkg/auth" "llamactl/pkg/config" + "llamactl/pkg/database" "log" "net/http" "os" "strings" + "time" ) type KeyType int @@ -19,58 +23,59 @@ const ( KeyTypeManagement ) +// contextKey is a custom type for context keys to avoid collisions +type contextKey string + +const ( + apiKeyContextKey contextKey = "apiKey" +) + type APIAuthMiddleware struct { + authStore database.AuthStore requireInferenceAuth bool - inferenceKeys map[string]bool requireManagementAuth bool - managementKeys map[string]bool + managementKeys map[string]bool // Config-based management keys } // NewAPIAuthMiddleware creates a new APIAuthMiddleware with the given configuration -func NewAPIAuthMiddleware(authCfg config.AuthConfig) *APIAuthMiddleware { +func NewAPIAuthMiddleware(authCfg config.AuthConfig, authStore database.AuthStore) *APIAuthMiddleware { + // Load management keys from config into managementKeys map + managementKeys := make(map[string]bool) + for _, key := range authCfg.ManagementKeys { + managementKeys[key] = true + } + // If len(authCfg.InferenceKeys) > 0, log warning + if len(authCfg.InferenceKeys) > 0 { + log.Println("⚠️ Config-based inference keys are no longer supported and will be ignored.") + log.Println(" Please create inference keys in web UI or via management API.") + } + + // Handle legacy auto-generation for management keys if none provided and auth is required var generated bool = false - - inferenceAPIKeys := make(map[string]bool) - managementAPIKeys := make(map[string]bool) - const banner = "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━" if authCfg.RequireManagementAuth && len(authCfg.ManagementKeys) == 0 { key := generateAPIKey(KeyTypeManagement) - managementAPIKeys[key] = true + managementKeys[key] = true generated = true fmt.Printf("%s\n⚠️ MANAGEMENT AUTHENTICATION REQUIRED\n%s\n", banner, banner) fmt.Printf("🔑 Generated Management API Key:\n\n %s\n\n", key) } - for _, key := range authCfg.ManagementKeys { - managementAPIKeys[key] = true - } - - if authCfg.RequireInferenceAuth && len(authCfg.InferenceKeys) == 0 { - key := generateAPIKey(KeyTypeInference) - inferenceAPIKeys[key] = true - generated = true - fmt.Printf("%s\n⚠️ INFERENCE AUTHENTICATION REQUIRED\n%s\n", banner, banner) - fmt.Printf("🔑 Generated Inference API Key:\n\n %s\n\n", key) - } - for _, key := range authCfg.InferenceKeys { - inferenceAPIKeys[key] = true - } if generated { fmt.Printf("%s\n⚠️ IMPORTANT\n%s\n", banner, banner) - fmt.Println("• These keys are auto-generated and will change on restart") + fmt.Println("• This key is auto-generated and will change on restart") fmt.Println("• For production, add explicit keys to your configuration") - fmt.Println("• Copy these keys before they disappear from the terminal") + fmt.Println("• Copy this key before it disappears from the terminal") fmt.Println(banner) } return &APIAuthMiddleware{ + authStore: authStore, requireInferenceAuth: authCfg.RequireInferenceAuth, - inferenceKeys: inferenceAPIKeys, requireManagementAuth: authCfg.RequireManagementAuth, - managementKeys: managementAPIKeys, + managementKeys: managementKeys, } } @@ -100,7 +105,120 @@ func generateAPIKey(keyType KeyType) string { return fmt.Sprintf("%s-%s", prefix, hex.EncodeToString(randomBytes)) } -// AuthMiddleware returns a middleware that checks API keys for the given key type +// InferenceAuthMiddleware returns middleware for inference endpoints +func (a *APIAuthMiddleware) InferenceAuthMiddleware() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "OPTIONS" { + next.ServeHTTP(w, r) + return + } + + // Extract API key from request + apiKey := a.extractAPIKey(r) + if apiKey == "" { + a.unauthorized(w, "Missing API key") + return + } + + // Try database authentication first + var foundKey *auth.APIKey + if a.requireInferenceAuth { + activeKeys, err := a.authStore.GetActiveKeys(r.Context()) + if err != nil { + log.Printf("Failed to get active inference keys: %v", err) + // Continue to management key fallback + } else { + for _, key := range activeKeys { + if auth.VerifyKey(apiKey, key.KeyHash) { + foundKey = key + // Async update last_used_at + go func(keyID int) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := a.authStore.TouchKey(ctx, keyID); err != nil { + log.Printf("Failed to update last used timestamp for key %d: %v", keyID, err) + } + }(key.ID) + break + } + } + } + } + + // If no database key found, try management key authentication (config-based) + if foundKey == nil { + if !a.isValidManagementKey(apiKey) { + a.unauthorized(w, "Invalid API key") + return + } + // Management key was used, continue without adding APIKey to context + } else { + // Add APIKey to context for permission checking + ctx := context.WithValue(r.Context(), apiKeyContextKey, foundKey) + r = r.WithContext(ctx) + } + + next.ServeHTTP(w, r) + }) + } +} + +// ManagementAuthMiddleware returns middleware for management endpoints +func (a *APIAuthMiddleware) ManagementAuthMiddleware() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "OPTIONS" { + next.ServeHTTP(w, r) + return + } + + // Extract API key from request + apiKey := a.extractAPIKey(r) + if apiKey == "" { + a.unauthorized(w, "Missing API key") + return + } + + // Check if key exists in managementKeys map using constant-time comparison + if !a.isValidManagementKey(apiKey) { + a.unauthorized(w, "Invalid API key") + return + } + + next.ServeHTTP(w, r) + }) + } +} + +// CheckInstancePermission checks if the authenticated key has permission for the instance +func (a *APIAuthMiddleware) CheckInstancePermission(ctx context.Context, instanceID int) error { + // Extract APIKey from context + apiKey, ok := ctx.Value(apiKeyContextKey).(*auth.APIKey) + if !ok { + // APIKey is nil, management key was used, allow all + return nil + } + + // If permission_mode == "allow_all", allow all + if apiKey.PermissionMode == auth.PermissionModeAllowAll { + return nil + } + + // Check per-instance permissions + canInfer, err := a.authStore.HasPermission(ctx, apiKey.ID, instanceID) + if err != nil { + return err + } + + if !canInfer { + return http.ErrBodyNotAllowed // Use this as a generic error to indicate permission denied + } + + return nil +} + +// AuthMiddleware returns a middleware that checks API keys for the given key type (legacy support) func (a *APIAuthMiddleware) AuthMiddleware(keyType KeyType) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -118,10 +236,38 @@ func (a *APIAuthMiddleware) AuthMiddleware(keyType KeyType) func(http.Handler) h var isValid bool switch keyType { case KeyTypeInference: - // Management keys also work for OpenAI endpoints (higher privilege) - isValid = a.isValidKey(apiKey, KeyTypeInference) || a.isValidKey(apiKey, KeyTypeManagement) + // Try database authentication first + if a.requireInferenceAuth { + activeKeys, err := a.authStore.GetActiveKeys(r.Context()) + if err == nil { + for _, key := range activeKeys { + if auth.VerifyKey(apiKey, key.KeyHash) { + foundKey := key + // Async update last_used_at + go func(keyID int) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := a.authStore.TouchKey(ctx, keyID); err != nil { + log.Printf("Failed to update last used timestamp for key %d: %v", keyID, err) + } + }(key.ID) + + // Add APIKey to context for permission checking + ctx := context.WithValue(r.Context(), apiKeyContextKey, foundKey) + r = r.WithContext(ctx) + isValid = true + break + } + } + } + } + + // If no database key found, try management key (higher privilege) + if !isValid { + isValid = a.isValidManagementKey(apiKey) + } case KeyTypeManagement: - isValid = a.isValidKey(apiKey, KeyTypeManagement) + isValid = a.isValidManagementKey(apiKey) default: isValid = false } @@ -158,20 +304,9 @@ func (a *APIAuthMiddleware) extractAPIKey(r *http.Request) string { return "" } -// isValidKey checks if the provided API key is valid for the given key type -func (a *APIAuthMiddleware) isValidKey(providedKey string, keyType KeyType) bool { - var validKeys map[string]bool - - switch keyType { - case KeyTypeInference: - validKeys = a.inferenceKeys - case KeyTypeManagement: - validKeys = a.managementKeys - default: - return false - } - - for validKey := range validKeys { +// isValidManagementKey checks if the provided API key is a valid management key +func (a *APIAuthMiddleware) isValidManagementKey(providedKey string) bool { + for validKey := range a.managementKeys { if len(providedKey) == len(validKey) && subtle.ConstantTimeCompare([]byte(providedKey), []byte(validKey)) == 1 { return true @@ -187,3 +322,11 @@ func (a *APIAuthMiddleware) unauthorized(w http.ResponseWriter, message string) response := fmt.Sprintf(`{"error": {"message": "%s", "type": "authentication_error"}}`, message) w.Write([]byte(response)) } + +// forbidden sends a forbidden response +func (a *APIAuthMiddleware) forbidden(w http.ResponseWriter, message string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusForbidden) + response := fmt.Sprintf(`{"error": {"message": "%s", "type": "permission_denied"}}`, message) + w.Write([]byte(response)) +} diff --git a/pkg/server/middleware_test.go b/pkg/server/middleware_test.go index 8a1e7fc..720362f 100644 --- a/pkg/server/middleware_test.go +++ b/pkg/server/middleware_test.go @@ -19,15 +19,7 @@ func TestAuthMiddleware(t *testing.T) { method string expectedStatus int }{ - // Valid key tests - { - name: "valid inference key for inference", - keyType: server.KeyTypeInference, - inferenceKeys: []string{"sk-inference-valid123"}, - requestKey: "sk-inference-valid123", - method: "GET", - expectedStatus: http.StatusOK, - }, + // Valid key tests - using management keys only since config-based inference keys are deprecated { name: "valid management key for inference", // Management keys work for inference keyType: server.KeyTypeInference, @@ -123,7 +115,7 @@ func TestAuthMiddleware(t *testing.T) { InferenceKeys: tt.inferenceKeys, ManagementKeys: tt.managementKeys, } - middleware := server.NewAPIAuthMiddleware(cfg) + middleware := server.NewAPIAuthMiddleware(cfg, nil) // Create test request req := httptest.NewRequest(tt.method, "/test", nil) @@ -131,7 +123,7 @@ func TestAuthMiddleware(t *testing.T) { req.Header.Set("Authorization", "Bearer "+tt.requestKey) } - // Create test handler using the appropriate middleware + // Create test handler using appropriate middleware var handler http.Handler if tt.keyType == server.KeyTypeInference { handler = middleware.AuthMiddleware(server.KeyTypeInference)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -189,7 +181,7 @@ func TestGenerateAPIKey(t *testing.T) { } // Create middleware - this should trigger key generation - middleware := server.NewAPIAuthMiddleware(config) + middleware := server.NewAPIAuthMiddleware(config, nil) // Test that auth is required (meaning a key was generated) req := httptest.NewRequest("GET", "/", nil) @@ -214,7 +206,7 @@ func TestGenerateAPIKey(t *testing.T) { } // Test uniqueness by creating another middleware instance - middleware2 := server.NewAPIAuthMiddleware(config) + middleware2 := server.NewAPIAuthMiddleware(config, nil) req2 := httptest.NewRequest("GET", "/", nil) recorder2 := httptest.NewRecorder() @@ -314,7 +306,7 @@ func TestAutoGeneration(t *testing.T) { ManagementKeys: tt.providedManagement, } - middleware := server.NewAPIAuthMiddleware(cfg) + middleware := server.NewAPIAuthMiddleware(cfg, nil) // Test inference behavior if inference auth is required if tt.requireInference { diff --git a/pkg/server/routes.go b/pkg/server/routes.go index b159968..36a6081 100644 --- a/pkg/server/routes.go +++ b/pkg/server/routes.go @@ -27,7 +27,7 @@ func SetupRouter(handler *Handler) *chi.Mux { })) // Add API authentication middleware - authMiddleware := NewAPIAuthMiddleware(handler.cfg.Auth) + authMiddleware := NewAPIAuthMiddleware(handler.cfg.Auth, handler.authStore) if handler.cfg.Server.EnableSwagger { r.Get("/swagger/*", httpSwagger.Handler( @@ -46,6 +46,17 @@ func SetupRouter(handler *Handler) *chi.Mux { r.Get("/config", handler.ConfigHandler()) + // API key management endpoints + r.Route("/auth", func(r chi.Router) { + r.Route("/keys", func(r chi.Router) { + r.Post("/", handler.CreateInferenceKey()) // Create API key + r.Get("/", handler.ListInferenceKeys()) // List API keys + r.Get("/{id}", handler.GetInferenceKey()) // Get API key details + r.Delete("/{id}", handler.DeleteInferenceKey()) // Delete API key + r.Get("/{id}/permissions", handler.GetInferenceKeyPermissions()) // Get key permissions + }) + }) + // Backend-specific endpoints r.Route("/backends", func(r chi.Router) { r.Route("/llama-cpp", func(r chi.Router) { @@ -94,13 +105,13 @@ func SetupRouter(handler *Handler) *chi.Mux { }) }) - r.Route(("/v1"), func(r chi.Router) { + r.Route("/v1", func(r chi.Router) { if authMiddleware != nil && handler.cfg.Auth.RequireInferenceAuth { r.Use(authMiddleware.AuthMiddleware(KeyTypeInference)) } - r.Get(("/models"), handler.OpenAIListInstances()) // List instances in OpenAI-compatible format + r.Get("/models", handler.OpenAIListInstances()) // List instances in OpenAI-compatible format // OpenAI-compatible proxy endpoint // Handles all POST requests to /v1/*, including: @@ -128,7 +139,7 @@ func SetupRouter(handler *Handler) *chi.Mux { r.Use(authMiddleware.AuthMiddleware(KeyTypeInference)) } - // This handler auto start the server if it's not running + // This handler auto starts the server if it's not running llamaCppHandler := handler.LlamaCppProxy() // llama.cpp server specific proxy endpoints