From d47f7d7fb0115867c7a3d3bea22e8e18cc6a46af Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sat, 22 Feb 2025 21:53:12 +0100 Subject: [PATCH 01/39] Use golang migrate for migrations --- server/go.mod | 11 +- server/go.sum | 61 ++++++- server/internal/app/config.go | 41 ++++- server/internal/app/config_test.go | 8 +- server/internal/app/init.go | 4 +- server/internal/db/db.go | 8 + server/internal/db/migrations.go | 159 +++++------------- .../db/migrations/001_initial_schema.down.sql | 8 + .../db/migrations/001_initial_schema.up.sql | 59 +++++++ server/internal/handlers/integration_test.go | 2 +- 10 files changed, 220 insertions(+), 141 deletions(-) create mode 100644 server/internal/db/migrations/001_initial_schema.down.sql create mode 100644 server/internal/db/migrations/001_initial_schema.up.sql diff --git a/server/go.mod b/server/go.mod index 8fed5b2..a923917 100644 --- a/server/go.mod +++ b/server/go.mod @@ -9,6 +9,7 @@ require ( github.com/go-git/go-git/v5 v5.13.1 github.com/go-playground/validator/v10 v10.22.1 github.com/golang-jwt/jwt/v5 v5.2.1 + github.com/golang-migrate/migrate/v4 v4.18.2 github.com/google/uuid v1.6.0 github.com/mattn/go-sqlite3 v1.14.23 github.com/stretchr/testify v1.10.0 @@ -21,7 +22,7 @@ require ( require ( dario.cat/mergo v1.0.0 // indirect github.com/KyleBanks/depth v1.2.1 // indirect - github.com/Microsoft/go-winio v0.6.1 // indirect + github.com/Microsoft/go-winio v0.6.2 // indirect github.com/ProtonMail/go-crypto v1.1.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudflare/circl v1.3.7 // indirect @@ -38,10 +39,13 @@ require ( github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect + github.com/hashicorp/errwrap v1.1.0 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/kevinburke/ssh_config v1.2.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/mailru/easyjson v0.7.6 // indirect github.com/pjbgf/sha1cd v0.3.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect @@ -49,12 +53,11 @@ require ( github.com/skeema/knownhosts v1.3.0 // indirect github.com/swaggo/files v0.0.0-20220610200504-28940afbdbfe // indirect github.com/xanzy/ssh-agent v0.3.3 // indirect - golang.org/x/mod v0.17.0 // indirect + go.uber.org/atomic v1.7.0 // indirect golang.org/x/net v0.33.0 // indirect - golang.org/x/sync v0.10.0 // indirect golang.org/x/sys v0.28.0 // indirect golang.org/x/text v0.21.0 // indirect - golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect + golang.org/x/tools v0.24.0 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/server/go.sum b/server/go.sum index 803c82b..912d559 100644 --- a/server/go.sum +++ b/server/go.sum @@ -1,10 +1,12 @@ dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= +github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= +github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/KyleBanks/depth v1.2.1 h1:5h8fQADFrWtarTdtDudMmGsC7GPbOAu6RVB3ffsVFHc= github.com/KyleBanks/depth v1.2.1/go.mod h1:jzSb9d0L43HxTQfT+oSA1EEp2q+ne2uh6XgeJcm8brE= github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY= -github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= -github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= +github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= +github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/ProtonMail/go-crypto v1.1.3 h1:nRBOetoydLeUb4nHajyO2bKqMLfWQ/ZPwkXqXxPxCFk= github.com/ProtonMail/go-crypto v1.1.3/go.mod h1:rA3QumHc/FZ8pAHreoekgiAbzpNsfQAosU5td4SnOrE= github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= @@ -21,10 +23,22 @@ github.com/cyphar/filepath-securejoin v0.3.6/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGL github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dhui/dktest v0.4.4 h1:+I4s6JRE1yGuqflzwqG+aIaMdgXIorCf5P98JnaAWa8= +github.com/dhui/dktest v0.4.4/go.mod h1:4+22R4lgsdAXrDyaH4Nqx2JEz2hLp49MqQmm9HLCQhM= +github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= +github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= +github.com/docker/docker v27.2.0+incompatible h1:Rk9nIVdfH3+Vz4cyI/uhbINhEZ/oLmc+CBXmH6fbNk4= +github.com/docker/docker v27.2.0+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c= +github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= +github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= +github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/elazarl/goproxy v1.2.3 h1:xwIyKHbaP5yfT6O9KIeYJR5549MXRQkoQMRXGztz8YQ= github.com/elazarl/goproxy v1.2.3/go.mod h1:YfEbZtqP4AetfO6d40vWchF3znWX7C7Vd6ZMfdL8z64= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c= @@ -43,6 +57,10 @@ github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399 h1:eMj github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399/go.mod h1:1OCfN199q1Jm3HZlxleg+Dw/mwps2Wbk9frAWm+4FII= github.com/go-git/go-git/v5 v5.13.1 h1:DAQ9APonnlvSWpvolXWIuV6Q6zXy2wHbN4cVlNR5Q+M= github.com/go-git/go-git/v5 v5.13.1/go.mod h1:qryJB4cSBoq3FRoBRf5A77joojuBcmPJ0qu3XXXVixc= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-openapi/jsonpointer v0.19.3/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY= github.com/go-openapi/jsonpointer v0.19.5/go.mod h1:Pl9vOtqEWErmShwVjC8pYs9cog34VGT37dQOVbmoatg= @@ -61,14 +79,23 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.22.1 h1:40JcKH+bBNGFczGuoBYgX4I6m/i27HYW8P9FDk5PbgA= github.com/go-playground/validator/v10 v10.22.1/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= +github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= +github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang-migrate/migrate/v4 v4.18.2 h1:2VSCMz7x7mjyTXx3m2zPokOY82LTRgxK1yQYKo6wWQ8= +github.com/golang-migrate/migrate/v4 v4.18.2/go.mod h1:2CM6tJvn2kqPXwnXO/d3rAQYiyoIm180VsO8PRX6Rpk= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY2I= +github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= @@ -84,15 +111,27 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mailru/easyjson v0.0.0-20190614124828-94de47d64c63/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= github.com/mailru/easyjson v0.7.6 h1:8yTIVnZgCoiM1TgqoeTl+LfU5Jg6/xL3QhGQnimLYnA= github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mattn/go-sqlite3 v1.14.23 h1:gbShiuAP1W5j9UOksQ06aiiqPMxYecovVGwmTxWtuw0= github.com/mattn/go-sqlite3 v1.14.23/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= +github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= +github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0= +github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= +github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= +github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k= github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY= +github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= +github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= +github.com/opencontainers/image-spec v1.1.0 h1:8SG7/vwALn54lVB/0yZ/MMwhFrPYtpEHQb2IpWsCzug= +github.com/opencontainers/image-spec v1.1.0/go.mod h1:W4s4sFTMaBeK1BQLXbG4AdM2szdn85PY75RI83NrTrM= github.com/pjbgf/sha1cd v0.3.0 h1:4D5XXmUUBUl/xQ6IjCkEAbqXskkq/4O7LmGn0AqMDs4= github.com/pjbgf/sha1cd v0.3.0/go.mod h1:nZ1rrWOcGJ5uZgEEVL1VUM9iRQiZvWdbZjkKyFzPPsI= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -123,13 +162,23 @@ github.com/unrolled/secure v1.17.0 h1:Io7ifFgo99Bnh0J7+Q+qcMzWM6kaDPCA5FroFZEdbW github.com/unrolled/secure v1.17.0/go.mod h1:BmF5hyM6tXczk3MpQkFf1hpKSRqCyhqcbiQtiAF7+40= github.com/xanzy/ssh-agent v0.3.3 h1:+/15pJfg/RsTxqYcX6fHqOXZwwMP+2VyYWJeWM2qQFM= github.com/xanzy/ssh-agent v0.3.3/go.mod h1:6dzNDKs0J9rVPHPhaGCukekBHKqfl+L3KghI1Bc68Uw= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 h1:TT4fX+nBOA/+LUkobKGW1ydGcn+G3vRw9+g5HwCphpk= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0/go.mod h1:L7UH0GbB0p47T4Rri3uHjbpCFYrVrwc1I25QhNPiGK8= +go.opentelemetry.io/otel v1.29.0 h1:PdomN/Al4q/lN6iBJEN3AwPvUiHPMlt93c8bqTG5Llw= +go.opentelemetry.io/otel v1.29.0/go.mod h1:N/WtXPs1CNCUEx+Agz5uouwCba+i+bJGFicT8SR4NP8= +go.opentelemetry.io/otel/metric v1.29.0 h1:vPf/HFWTNkPu1aYeIsc98l4ktOQaL6LeSoeV2g+8YLc= +go.opentelemetry.io/otel/metric v1.29.0/go.mod h1:auu/QWieFVWx+DmQOUMgj0F8LHWdgalxXqvp7BII/W8= +go.opentelemetry.io/otel/trace v1.29.0 h1:J/8ZNK4XgR7a21DZUAsbF8pZ5Jcw1VhACmnYt39JTi4= +go.opentelemetry.io/otel/trace v1.29.0/go.mod h1:eHl3w0sp3paPkYstJOmAimxhiFXPg+MMTlEh3nsQgWQ= +go.uber.org/atomic v1.7.0 h1:ADUqmZGgLDDfbSL9ZmPxKTybcoEYHgpYfELNoN+7hsw= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8= golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= -golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= -golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0= +golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= @@ -151,8 +200,8 @@ golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24= +golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/server/internal/app/config.go b/server/internal/app/config.go index 3e0f49c..218b923 100644 --- a/server/internal/app/config.go +++ b/server/internal/app/config.go @@ -2,9 +2,11 @@ package app import ( "fmt" + "lemma/internal/db" "lemma/internal/logging" "lemma/internal/secrets" "os" + "path/filepath" "strconv" "strings" "time" @@ -12,7 +14,8 @@ import ( // Config holds the configuration for the application type Config struct { - DBPath string + DBURL string + DBType db.DBType WorkDir string StaticPath string Port string @@ -31,7 +34,7 @@ type Config struct { // DefaultConfig returns a new Config instance with default values func DefaultConfig() *Config { return &Config{ - DBPath: "./lemma.db", + DBURL: "sqlite://lemma.db", WorkDir: "./data", StaticPath: "../app/dist", Port: "8080", @@ -65,6 +68,31 @@ func (c *Config) Redact() *Config { return &redacted } +// ParseDBURL parses a database URL and returns the driver name and data source +func ParseDBURL(dbURL string) (db.DBType, string, error) { + if strings.HasPrefix(dbURL, "sqlite://") || strings.HasPrefix(dbURL, "sqlite3://") { + + path := strings.TrimPrefix(dbURL, "sqlite://") + path = strings.TrimPrefix(path, "sqlite3://") + + if path == ":memory:" { + return db.DBTypeSQLite, path, nil + } + + if !filepath.IsAbs(path) { + path = filepath.Clean(path) + } + return db.DBTypeSQLite, path, nil + } + + // Try to parse as postgres URL + if strings.HasPrefix(dbURL, "postgres://") || strings.HasPrefix(dbURL, "postgresql://") { + return db.DBTypePostgres, dbURL, nil + } + + return "", "", fmt.Errorf("unsupported database URL format: %s", dbURL) +} + // LoadConfig creates a new Config instance with values from environment variables func LoadConfig() (*Config, error) { config := DefaultConfig() @@ -73,8 +101,13 @@ func LoadConfig() (*Config, error) { config.IsDevelopment = env == "development" } - if dbPath := os.Getenv("LEMMA_DB_PATH"); dbPath != "" { - config.DBPath = dbPath + if dbURL := os.Getenv("LEMMA_DB_URL"); dbURL != "" { + dbType, dataSource, err := ParseDBURL(dbURL) + if err != nil { + return nil, err + } + config.DBURL = dataSource + config.DBType = dbType } if workDir := os.Getenv("LEMMA_WORKDIR"); workDir != "" { diff --git a/server/internal/app/config_test.go b/server/internal/app/config_test.go index c15ca9b..ebf62ef 100644 --- a/server/internal/app/config_test.go +++ b/server/internal/app/config_test.go @@ -17,7 +17,7 @@ func TestDefaultConfig(t *testing.T) { got interface{} expected interface{} }{ - {"DBPath", cfg.DBPath, "./lemma.db"}, + {"DBPath", cfg.DBURL, "./lemma.db"}, {"WorkDir", cfg.WorkDir, "./data"}, {"StaticPath", cfg.StaticPath, "../app/dist"}, {"Port", cfg.Port, "8080"}, @@ -81,8 +81,8 @@ func TestLoad(t *testing.T) { t.Fatalf("Load() error = %v", err) } - if cfg.DBPath != "./lemma.db" { - t.Errorf("default DBPath = %v, want %v", cfg.DBPath, "./lemma.db") + if cfg.DBURL != "./lemma.db" { + t.Errorf("default DBPath = %v, want %v", cfg.DBURL, "./lemma.db") } }) @@ -122,7 +122,7 @@ func TestLoad(t *testing.T) { expected interface{} }{ {"IsDevelopment", cfg.IsDevelopment, true}, - {"DBPath", cfg.DBPath, "/custom/db/path.db"}, + {"DBPath", cfg.DBURL, "/custom/db/path.db"}, {"WorkDir", cfg.WorkDir, "/custom/work/dir"}, {"StaticPath", cfg.StaticPath, "/custom/static/path"}, {"Port", cfg.Port, "3000"}, diff --git a/server/internal/app/init.go b/server/internal/app/init.go index 5bb31b7..ebed735 100644 --- a/server/internal/app/init.go +++ b/server/internal/app/init.go @@ -28,9 +28,9 @@ func initSecretsService(cfg *Config) (secrets.Service, error) { // initDatabase initializes and migrates the database func initDatabase(cfg *Config, secretsService secrets.Service) (db.Database, error) { - logging.Debug("initializing database", "path", cfg.DBPath) + logging.Debug("initializing database", "path", cfg.DBURL) - database, err := db.Init(cfg.DBPath, secretsService) + database, err := db.Init(cfg.DBURL, secretsService) if err != nil { return nil, fmt.Errorf("failed to initialize database: %w", err) } diff --git a/server/internal/db/db.go b/server/internal/db/db.go index d08f8b8..fca938c 100644 --- a/server/internal/db/db.go +++ b/server/internal/db/db.go @@ -12,6 +12,13 @@ import ( _ "github.com/mattn/go-sqlite3" // SQLite driver ) +type DBType string + +const ( + DBTypeSQLite DBType = "sqlite3" + DBTypePostgres DBType = "postgres" +) + // UserStore defines the methods for interacting with user data in the database type UserStore interface { CreateUser(user *models.User) (*models.User, error) @@ -108,6 +115,7 @@ func getLogger() logging.Logger { type database struct { *sql.DB secretsService secrets.Service + dbType DBType } // Init initializes the database connection diff --git a/server/internal/db/migrations.go b/server/internal/db/migrations.go index 5182844..3803f7f 100644 --- a/server/internal/db/migrations.go +++ b/server/internal/db/migrations.go @@ -1,141 +1,60 @@ package db import ( + "embed" "fmt" + + "github.com/golang-migrate/migrate/v4" + "github.com/golang-migrate/migrate/v4/database/postgres" + "github.com/golang-migrate/migrate/v4/database/sqlite3" + "github.com/golang-migrate/migrate/v4/source/iofs" ) -// Migration represents a database migration -type Migration struct { - Version int - SQL string -} - -var migrations = []Migration{ - { - Version: 1, - SQL: ` - -- Create users table - CREATE TABLE IF NOT EXISTS users ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - email TEXT NOT NULL UNIQUE, - display_name TEXT, - password_hash TEXT NOT NULL, - role TEXT NOT NULL CHECK(role IN ('admin', 'editor', 'viewer')), - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - last_workspace_id INTEGER - ); - - -- Create workspaces table with integrated settings - CREATE TABLE IF NOT EXISTS workspaces ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - user_id INTEGER NOT NULL, - name TEXT NOT NULL, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - last_opened_file_path TEXT, - -- Settings fields - theme TEXT NOT NULL DEFAULT 'light' CHECK(theme IN ('light', 'dark')), - auto_save BOOLEAN NOT NULL DEFAULT 0, - git_enabled BOOLEAN NOT NULL DEFAULT 0, - git_url TEXT, - git_user TEXT, - git_token TEXT, - git_auto_commit BOOLEAN NOT NULL DEFAULT 0, - git_commit_msg_template TEXT DEFAULT '${action} ${filename}', - git_commit_name TEXT, - git_commit_email TEXT, - show_hidden_files BOOLEAN NOT NULL DEFAULT 0, - created_by INTEGER REFERENCES users(id), - updated_by INTEGER REFERENCES users(id), - updated_at TIMESTAMP, - FOREIGN KEY (user_id) REFERENCES users (id) - ); - - -- Create sessions table for authentication - CREATE TABLE IF NOT EXISTS sessions ( - id TEXT PRIMARY KEY, - user_id INTEGER NOT NULL, - refresh_token TEXT NOT NULL, - expires_at TIMESTAMP NOT NULL, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE - ); - - -- Create system_settings table for application settings - CREATE TABLE IF NOT EXISTS system_settings ( - key TEXT PRIMARY KEY, - value TEXT NOT NULL, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP - ); - - -- Create indexes for performance - CREATE INDEX idx_sessions_user_id ON sessions(user_id); - CREATE INDEX idx_sessions_expires_at ON sessions(expires_at); - CREATE INDEX idx_sessions_refresh_token ON sessions(refresh_token); - `, - }, -} +//go:embed migrations/*.sql +var migrationsFS embed.FS // Migrate applies all database migrations func (db *database) Migrate() error { log := getLogger().WithGroup("migrations") log.Info("starting database migration") - // Create migrations table if it doesn't exist - _, err := db.Exec(`CREATE TABLE IF NOT EXISTS migrations ( - version INTEGER PRIMARY KEY - )`) + sourceInstance, err := iofs.New(migrationsFS, "migrations") if err != nil { - return fmt.Errorf("failed to create migrations table: %w", err) + return fmt.Errorf("failed to create source instance: %w", err) } - // Get current version - var currentVersion int - err = db.QueryRow("SELECT COALESCE(MAX(version), 0) FROM migrations").Scan(¤tVersion) - if err != nil { - return fmt.Errorf("failed to get current migration version: %w", err) - } + var m *migrate.Migrate - // Apply new migrations - for _, migration := range migrations { - if migration.Version > currentVersion { - log := log.With("migration_version", migration.Version) - tx, err := db.Begin() - if err != nil { - return fmt.Errorf("failed to begin transaction for migration %d: %w", migration.Version, err) - } - - // Execute migration SQL - _, err = tx.Exec(migration.SQL) - if err != nil { - if rbErr := tx.Rollback(); rbErr != nil { - return fmt.Errorf("migration %d failed: %v, rollback failed: %v", - migration.Version, err, rbErr) - } - return fmt.Errorf("migration %d failed: %w", migration.Version, err) - } - - // Update migrations table - _, err = tx.Exec("INSERT INTO migrations (version) VALUES (?)", migration.Version) - if err != nil { - if rbErr := tx.Rollback(); rbErr != nil { - return fmt.Errorf("failed to update migration version: %v, rollback failed: %v", - err, rbErr) - } - return fmt.Errorf("failed to update migration version: %w", err) - } - - // Commit transaction - err = tx.Commit() - if err != nil { - return fmt.Errorf("failed to commit migration %d: %w", migration.Version, err) - } - - currentVersion = migration.Version - log.Debug("migration applied", "new_version", currentVersion) + driverName := db.dbType + switch driverName { + case "postgres": + driver, err := postgres.WithInstance(db.DB, &postgres.Config{}) + if err != nil { + return fmt.Errorf("failed to create postgres driver: %w", err) } + m, err = migrate.NewWithInstance("iofs", sourceInstance, "postgres", driver) + if err != nil { + return fmt.Errorf("failed to create migrate instance: %w", err) + } + + case "sqlite3": + driver, err := sqlite3.WithInstance(db.DB, &sqlite3.Config{}) + if err != nil { + return fmt.Errorf("failed to create sqlite driver: %w", err) + } + m, err = migrate.NewWithInstance("iofs", sourceInstance, "sqlite3", driver) + if err != nil { + return fmt.Errorf("failed to create migrate instance: %w", err) + } + + default: + return fmt.Errorf("unsupported database driver: %s", driverName) } - log.Info("database migration completed", "final_version", currentVersion) + if err := m.Up(); err != nil && err != migrate.ErrNoChange { + return fmt.Errorf("failed to run migrations: %w", err) + } + + log.Info("database migration completed") return nil } diff --git a/server/internal/db/migrations/001_initial_schema.down.sql b/server/internal/db/migrations/001_initial_schema.down.sql new file mode 100644 index 0000000..f32272a --- /dev/null +++ b/server/internal/db/migrations/001_initial_schema.down.sql @@ -0,0 +1,8 @@ +-- 001_initial_schema.down.sql +DROP INDEX IF EXISTS idx_sessions_refresh_token; +DROP INDEX IF EXISTS idx_sessions_expires_at; +DROP INDEX IF EXISTS idx_sessions_user_id; +DROP TABLE IF EXISTS sessions; +DROP TABLE IF EXISTS workspaces; +DROP TABLE IF EXISTS system_settings; +DROP TABLE IF EXISTS users; \ No newline at end of file diff --git a/server/internal/db/migrations/001_initial_schema.up.sql b/server/internal/db/migrations/001_initial_schema.up.sql new file mode 100644 index 0000000..03f1722 --- /dev/null +++ b/server/internal/db/migrations/001_initial_schema.up.sql @@ -0,0 +1,59 @@ +-- 001_initial_schema.up.sql +-- Create users table +CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTO_INCREMENT, + email TEXT NOT NULL UNIQUE, + display_name TEXT, + password_hash TEXT NOT NULL, + role TEXT NOT NULL CHECK(role IN ('admin', 'editor', 'viewer')), + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + last_workspace_id INTEGER +); + +-- Create workspaces table with integrated settings +CREATE TABLE IF NOT EXISTS workspaces ( + id INTEGER PRIMARY KEY AUTO_INCREMENT, + user_id INTEGER NOT NULL, + name TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + last_opened_file_path TEXT, + -- Settings fields + theme TEXT NOT NULL DEFAULT 'light' CHECK(theme IN ('light', 'dark')), + auto_save BOOLEAN NOT NULL DEFAULT 0, + git_enabled BOOLEAN NOT NULL DEFAULT 0, + git_url TEXT, + git_user TEXT, + git_token TEXT, + git_auto_commit BOOLEAN NOT NULL DEFAULT 0, + git_commit_msg_template TEXT DEFAULT '${action} ${filename}', + git_commit_name TEXT, + git_commit_email TEXT, + show_hidden_files BOOLEAN NOT NULL DEFAULT 0, + created_by INTEGER REFERENCES users(id), + updated_by INTEGER REFERENCES users(id), + updated_at TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users (id) +); + +-- Create sessions table for authentication +CREATE TABLE IF NOT EXISTS sessions ( + id TEXT PRIMARY KEY, + user_id INTEGER NOT NULL, + refresh_token TEXT NOT NULL, + expires_at TIMESTAMP NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE +); + +-- Create system_settings table for application settings +CREATE TABLE IF NOT EXISTS system_settings ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +-- Create indexes for performance +CREATE INDEX idx_sessions_user_id ON sessions(user_id); +CREATE INDEX idx_sessions_expires_at ON sessions(expires_at); +CREATE INDEX idx_sessions_refresh_token ON sessions(refresh_token); \ No newline at end of file diff --git a/server/internal/handlers/integration_test.go b/server/internal/handlers/integration_test.go index b726d95..30ae7c6 100644 --- a/server/internal/handlers/integration_test.go +++ b/server/internal/handlers/integration_test.go @@ -99,7 +99,7 @@ func setupTestHarness(t *testing.T) *testHarness { // Create test config testConfig := &app.Config{ - DBPath: ":memory:", + DBURL: ":memory:", WorkDir: tempDir, StaticPath: "../testdata", Port: "8081", From 25defa5b658d8f9af440851150492c64096268dc Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sat, 22 Feb 2025 22:32:38 +0100 Subject: [PATCH 02/39] Fix tests for db type --- server/internal/app/config_test.go | 14 ++- server/internal/app/init.go | 2 +- server/internal/db/db.go | 35 ++++-- server/internal/db/migrations.go | 9 +- .../db/migrations/001_initial_schema.up.sql | 4 +- server/internal/db/migrations_test.go | 116 +++--------------- server/internal/db/sessions_test.go | 2 +- server/internal/db/system_test.go | 2 +- server/internal/db/testdb.go | 4 +- server/internal/db/users_test.go | 2 +- server/internal/db/workspaces_test.go | 2 +- server/internal/handlers/integration_test.go | 4 +- 12 files changed, 68 insertions(+), 128 deletions(-) diff --git a/server/internal/app/config_test.go b/server/internal/app/config_test.go index ebf62ef..17bef80 100644 --- a/server/internal/app/config_test.go +++ b/server/internal/app/config_test.go @@ -2,6 +2,7 @@ package app_test import ( "lemma/internal/app" + "lemma/internal/db" "os" "testing" "time" @@ -17,7 +18,7 @@ func TestDefaultConfig(t *testing.T) { got interface{} expected interface{} }{ - {"DBPath", cfg.DBURL, "./lemma.db"}, + {"DBPath", cfg.DBURL, "sqlite://lemma.db"}, {"WorkDir", cfg.WorkDir, "./data"}, {"StaticPath", cfg.StaticPath, "../app/dist"}, {"Port", cfg.Port, "8080"}, @@ -47,7 +48,7 @@ func TestLoad(t *testing.T) { cleanup := func() { envVars := []string{ "LEMMA_ENV", - "LEMMA_DB_PATH", + "LEMMA_DB_URL", "LEMMA_WORKDIR", "LEMMA_STATIC_PATH", "LEMMA_PORT", @@ -81,8 +82,8 @@ func TestLoad(t *testing.T) { t.Fatalf("Load() error = %v", err) } - if cfg.DBURL != "./lemma.db" { - t.Errorf("default DBPath = %v, want %v", cfg.DBURL, "./lemma.db") + if cfg.DBURL != "sqlite://lemma.db" { + t.Errorf("default DBPath = %v, want %v", cfg.DBURL, "sqlite://lemma.db") } }) @@ -93,7 +94,7 @@ func TestLoad(t *testing.T) { // Set all environment variables envs := map[string]string{ "LEMMA_ENV": "development", - "LEMMA_DB_PATH": "/custom/db/path.db", + "LEMMA_DB_URL": "sqlite:///custom/db/path.db", "LEMMA_WORKDIR": "/custom/work/dir", "LEMMA_STATIC_PATH": "/custom/static/path", "LEMMA_PORT": "3000", @@ -122,7 +123,8 @@ func TestLoad(t *testing.T) { expected interface{} }{ {"IsDevelopment", cfg.IsDevelopment, true}, - {"DBPath", cfg.DBURL, "/custom/db/path.db"}, + {"DBURL", cfg.DBURL, "/custom/db/path.db"}, + {"DBType", cfg.DBType, db.DBTypeSQLite}, {"WorkDir", cfg.WorkDir, "/custom/work/dir"}, {"StaticPath", cfg.StaticPath, "/custom/static/path"}, {"Port", cfg.Port, "3000"}, diff --git a/server/internal/app/init.go b/server/internal/app/init.go index ebed735..13f6030 100644 --- a/server/internal/app/init.go +++ b/server/internal/app/init.go @@ -30,7 +30,7 @@ func initSecretsService(cfg *Config) (secrets.Service, error) { func initDatabase(cfg *Config, secretsService secrets.Service) (db.Database, error) { logging.Debug("initializing database", "path", cfg.DBURL) - database, err := db.Init(cfg.DBURL, secretsService) + database, err := db.Init(cfg.DBType, cfg.DBURL, secretsService) if err != nil { return nil, fmt.Errorf("failed to initialize database: %w", err) } diff --git a/server/internal/db/db.go b/server/internal/db/db.go index fca938c..53af48b 100644 --- a/server/internal/db/db.go +++ b/server/internal/db/db.go @@ -119,10 +119,31 @@ type database struct { } // Init initializes the database connection -func Init(dbPath string, secretsService secrets.Service) (Database, error) { - log := getLogger() +func Init(dbType DBType, dbURL string, secretsService secrets.Service) (Database, error) { - db, err := sql.Open("sqlite3", dbPath) + switch dbType { + case DBTypeSQLite: + db, err := initSQLite(dbURL) + if err != nil { + return nil, fmt.Errorf("failed to initialize SQLite database: %w", err) + } + + database := &database{ + DB: db, + secretsService: secretsService, + dbType: dbType, + } + return database, nil + case DBTypePostgres: + return nil, fmt.Errorf("postgres database not supported yet") + } + + return nil, fmt.Errorf("unsupported database type: %s", dbType) +} + +func initSQLite(dbURL string) (*sql.DB, error) { + log := getLogger() + db, err := sql.Open("sqlite3", dbURL) if err != nil { return nil, fmt.Errorf("failed to open database: %w", err) } @@ -136,13 +157,7 @@ func Init(dbPath string, secretsService secrets.Service) (Database, error) { return nil, fmt.Errorf("failed to enable foreign keys: %w", err) } log.Debug("foreign keys enabled") - - database := &database{ - DB: db, - secretsService: secretsService, - } - - return database, nil + return db, nil } // Close closes the database connection diff --git a/server/internal/db/migrations.go b/server/internal/db/migrations.go index 3803f7f..efa3769 100644 --- a/server/internal/db/migrations.go +++ b/server/internal/db/migrations.go @@ -25,9 +25,8 @@ func (db *database) Migrate() error { var m *migrate.Migrate - driverName := db.dbType - switch driverName { - case "postgres": + switch db.dbType { + case DBTypePostgres: driver, err := postgres.WithInstance(db.DB, &postgres.Config{}) if err != nil { return fmt.Errorf("failed to create postgres driver: %w", err) @@ -37,7 +36,7 @@ func (db *database) Migrate() error { return fmt.Errorf("failed to create migrate instance: %w", err) } - case "sqlite3": + case DBTypeSQLite: driver, err := sqlite3.WithInstance(db.DB, &sqlite3.Config{}) if err != nil { return fmt.Errorf("failed to create sqlite driver: %w", err) @@ -48,7 +47,7 @@ func (db *database) Migrate() error { } default: - return fmt.Errorf("unsupported database driver: %s", driverName) + return fmt.Errorf("unsupported database driver: %s", db.dbType) } if err := m.Up(); err != nil && err != migrate.ErrNoChange { diff --git a/server/internal/db/migrations/001_initial_schema.up.sql b/server/internal/db/migrations/001_initial_schema.up.sql index 03f1722..8c13e9b 100644 --- a/server/internal/db/migrations/001_initial_schema.up.sql +++ b/server/internal/db/migrations/001_initial_schema.up.sql @@ -1,7 +1,7 @@ -- 001_initial_schema.up.sql -- Create users table CREATE TABLE IF NOT EXISTS users ( - id INTEGER PRIMARY KEY AUTO_INCREMENT, + id INTEGER PRIMARY KEY AUTOINCREMENT, email TEXT NOT NULL UNIQUE, display_name TEXT, password_hash TEXT NOT NULL, @@ -12,7 +12,7 @@ CREATE TABLE IF NOT EXISTS users ( -- Create workspaces table with integrated settings CREATE TABLE IF NOT EXISTS workspaces ( - id INTEGER PRIMARY KEY AUTO_INCREMENT, + id INTEGER PRIMARY KEY AUTOINCREMENT, user_id INTEGER NOT NULL, name TEXT NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, diff --git a/server/internal/db/migrations_test.go b/server/internal/db/migrations_test.go index bb8f655..ecce24c 100644 --- a/server/internal/db/migrations_test.go +++ b/server/internal/db/migrations_test.go @@ -1,53 +1,36 @@ package db_test import ( - "testing" - "lemma/internal/db" - _ "lemma/internal/testenv" + "testing" _ "github.com/mattn/go-sqlite3" ) func TestMigrate(t *testing.T) { - database, err := db.NewTestDB(":memory:", &mockSecrets{}) + database, err := db.NewTestDB(&mockSecrets{}) if err != nil { t.Fatalf("failed to initialize database: %v", err) } defer database.Close() - t.Run("migrations are applied in order", func(t *testing.T) { - if err := database.Migrate(); err != nil { - t.Fatalf("failed to run initial migrations: %v", err) - } - - // Check migration version - var version int - err := database.TestDB().QueryRow("SELECT MAX(version) FROM migrations").Scan(&version) - if err != nil { - t.Fatalf("failed to get migration version: %v", err) - } - - if version != 1 { // Current number of migrations in production code - t.Errorf("expected migration version 1, got %d", version) - } - - // Verify number of migration entries matches versions applied - var count int - err = database.TestDB().QueryRow("SELECT COUNT(*) FROM migrations").Scan(&count) - if err != nil { - t.Fatalf("failed to count migrations: %v", err) - } - - if count != 1 { - t.Errorf("expected 1 migration entries, got %d", count) - } - }) - t.Run("migrations create expected schema", func(t *testing.T) { + // Run migrations + if err := database.Migrate(); err != nil { + t.Fatalf("failed to run migrations: %v", err) + } + // Verify tables exist - tables := []string{"users", "workspaces", "sessions", "system_settings", "migrations"} + tables := []string{ + "users", + "workspaces", + "sessions", + "system_settings", + // Note: golang-migrate uses its own migrations table + "schema_migrations", + } + for _, table := range tables { if !tableExists(t, database, table) { t.Errorf("table %q does not exist", table) @@ -63,91 +46,32 @@ func TestMigrate(t *testing.T) { {"sessions", "idx_sessions_expires_at"}, {"sessions", "idx_sessions_refresh_token"}, } - for _, idx := range indexes { if !indexExists(t, database, idx.table, idx.name) { t.Errorf("index %q on table %q does not exist", idx.name, idx.table) } } }) - - t.Run("migrations are idempotent", func(t *testing.T) { - // Run migrations again - if err := database.Migrate(); err != nil { - t.Fatalf("failed to re-run migrations: %v", err) - } - - // Verify migration count hasn't changed - var count int - err = database.TestDB().QueryRow("SELECT COUNT(*) FROM migrations").Scan(&count) - if err != nil { - t.Fatalf("failed to count migrations: %v", err) - } - - if count != 1 { - t.Errorf("expected 1 migration entries, got %d", count) - } - }) - - t.Run("rollback on migration failure", func(t *testing.T) { - // Create a test table that would conflict with a failing migration - _, err := database.TestDB().Exec("CREATE TABLE test_rollback (id INTEGER PRIMARY KEY)") - if err != nil { - t.Fatalf("failed to create test table: %v", err) - } - - // Start transaction - tx, err := database.Begin() - if err != nil { - t.Fatalf("failed to start transaction: %v", err) - } - - // Try operations that should fail and rollback - _, err = tx.Exec(` - CREATE TABLE test_rollback (id INTEGER PRIMARY KEY); - INSERT INTO nonexistent_table VALUES (1); - `) - if err == nil { - tx.Rollback() - t.Fatal("expected migration to fail") - } - tx.Rollback() - - // Verify the migration version hasn't changed - var version int - err = database.TestDB().QueryRow("SELECT MAX(version) FROM migrations").Scan(&version) - if err != nil { - t.Fatalf("failed to get migration version: %v", err) - } - - if version != 1 { - t.Errorf("expected migration version to remain at 1, got %d", version) - } - }) } func tableExists(t *testing.T, database db.TestDatabase, tableName string) bool { t.Helper() - var name string err := database.TestDB().QueryRow(` - SELECT name FROM sqlite_master - WHERE type='table' AND name=?`, + SELECT name FROM sqlite_master + WHERE type='table' AND name=?`, tableName, ).Scan(&name) - return err == nil } func indexExists(t *testing.T, database db.TestDatabase, tableName, indexName string) bool { t.Helper() - var name string err := database.TestDB().QueryRow(` - SELECT name FROM sqlite_master - WHERE type='index' AND tbl_name=? AND name=?`, + SELECT name FROM sqlite_master + WHERE type='index' AND tbl_name=? AND name=?`, tableName, indexName, ).Scan(&name) - return err == nil } diff --git a/server/internal/db/sessions_test.go b/server/internal/db/sessions_test.go index 21f7765..67b2526 100644 --- a/server/internal/db/sessions_test.go +++ b/server/internal/db/sessions_test.go @@ -13,7 +13,7 @@ import ( ) func TestSessionOperations(t *testing.T) { - database, err := db.NewTestDB(":memory:", &mockSecrets{}) + database, err := db.NewTestDB(&mockSecrets{}) if err != nil { t.Fatalf("failed to create test database: %v", err) } diff --git a/server/internal/db/system_test.go b/server/internal/db/system_test.go index b86a667..8e65023 100644 --- a/server/internal/db/system_test.go +++ b/server/internal/db/system_test.go @@ -15,7 +15,7 @@ import ( ) func TestSystemOperations(t *testing.T) { - database, err := db.NewTestDB(":memory:", &mockSecrets{}) + database, err := db.NewTestDB(&mockSecrets{}) if err != nil { t.Fatalf("failed to create test database: %v", err) } diff --git a/server/internal/db/testdb.go b/server/internal/db/testdb.go index 203f561..14e13e1 100644 --- a/server/internal/db/testdb.go +++ b/server/internal/db/testdb.go @@ -12,8 +12,8 @@ type TestDatabase interface { TestDB() *sql.DB } -func NewTestDB(dbPath string, secretsService secrets.Service) (TestDatabase, error) { - db, err := Init(dbPath, secretsService) +func NewTestDB(secretsService secrets.Service) (TestDatabase, error) { + db, err := Init(DBTypeSQLite, ":memory:", secretsService) if err != nil { return nil, err } diff --git a/server/internal/db/users_test.go b/server/internal/db/users_test.go index f8ad7db..5709ab4 100644 --- a/server/internal/db/users_test.go +++ b/server/internal/db/users_test.go @@ -10,7 +10,7 @@ import ( ) func TestUserOperations(t *testing.T) { - database, err := db.NewTestDB(":memory:", &mockSecrets{}) + database, err := db.NewTestDB(&mockSecrets{}) if err != nil { t.Fatalf("failed to create test database: %v", err) } diff --git a/server/internal/db/workspaces_test.go b/server/internal/db/workspaces_test.go index 009f4a9..3d6fd9f 100644 --- a/server/internal/db/workspaces_test.go +++ b/server/internal/db/workspaces_test.go @@ -10,7 +10,7 @@ import ( ) func TestWorkspaceOperations(t *testing.T) { - database, err := db.NewTestDB(":memory:", &mockSecrets{}) + database, err := db.NewTestDB(&mockSecrets{}) if err != nil { t.Fatalf("failed to create test database: %v", err) } diff --git a/server/internal/handlers/integration_test.go b/server/internal/handlers/integration_test.go index 30ae7c6..fb66585 100644 --- a/server/internal/handlers/integration_test.go +++ b/server/internal/handlers/integration_test.go @@ -61,7 +61,7 @@ func setupTestHarness(t *testing.T) *testHarness { t.Fatalf("Failed to initialize secrets service: %v", err) } - database, err := db.NewTestDB(":memory:", secretsSvc) + database, err := db.NewTestDB(secretsSvc) if err != nil { t.Fatalf("Failed to initialize test database: %v", err) } @@ -99,7 +99,7 @@ func setupTestHarness(t *testing.T) *testHarness { // Create test config testConfig := &app.Config{ - DBURL: ":memory:", + DBURL: "sqlite://:memory:", WorkDir: tempDir, StaticPath: "../testdata", Port: "8081", From c76057d605ca4556afd4ddff523dac63dd260a9b Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sun, 23 Feb 2025 14:58:30 +0100 Subject: [PATCH 03/39] Implement sql query builder --- server/internal/db/query.go | 263 ++++++++++++++++++++++++++++++++++++ 1 file changed, 263 insertions(+) create mode 100644 server/internal/db/query.go diff --git a/server/internal/db/query.go b/server/internal/db/query.go new file mode 100644 index 0000000..2f380cd --- /dev/null +++ b/server/internal/db/query.go @@ -0,0 +1,263 @@ +package db + +import ( + "fmt" + "strings" +) + +type JoinType string + +const ( + InnerJoin JoinType = "INNER JOIN" + LeftJoin JoinType = "LEFT JOIN" + RightJoin JoinType = "RIGHT JOIN" +) + +// Query represents a SQL query with its parameters +type Query struct { + builder strings.Builder + args []interface{} + dbType DBType + pos int // tracks the current placeholder position + hasSelect bool + hasFrom bool + hasWhere bool + hasOrderBy bool + hasGroupBy bool + hasLimit bool + hasOffset bool + isInParens bool + parensDepth int +} + +// NewQuery creates a new Query instance +func NewQuery(dbType DBType) *Query { + return &Query{ + dbType: dbType, + args: make([]interface{}, 0), + } +} + +// Select adds a SELECT clause +func (q *Query) Select(columns ...string) *Query { + if !q.hasSelect { + q.Write("SELECT ") + q.Write(strings.Join(columns, ", ")) + q.hasSelect = true + } + return q +} + +// From adds a FROM clause +func (q *Query) From(table string) *Query { + if !q.hasFrom { + q.Write(" FROM ") + q.Write(table) + q.hasFrom = true + } + return q +} + +// Where adds a WHERE clause +func (q *Query) Where(condition string) *Query { + if !q.hasWhere { + q.Write(" WHERE ") + q.hasWhere = true + } else { + q.Write(" AND ") + } + q.Write(condition) + return q +} + +// WhereIn adds a WHERE IN clause +func (q *Query) WhereIn(column string, count int) *Query { + if !q.hasWhere { + q.Write(" WHERE ") + q.hasWhere = true + } else { + q.Write(" AND ") + } + q.Write(column) + q.Write(" IN (") + q.Placeholders(count) + q.Write(")") + return q +} + +// And adds an AND condition +func (q *Query) And(condition string) *Query { + q.Write(" AND ") + q.Write(condition) + return q +} + +// Or adds an OR condition +func (q *Query) Or(condition string) *Query { + q.Write(" OR ") + q.Write(condition) + return q +} + +// Join adds a JOIN clause +func (q *Query) Join(joinType JoinType, table, condition string) *Query { + q.Write(" ") + q.Write(string(joinType)) + q.Write(" ") + q.Write(table) + q.Write(" ON ") + q.Write(condition) + return q +} + +// OrderBy adds an ORDER BY clause +func (q *Query) OrderBy(columns ...string) *Query { + if !q.hasOrderBy { + q.Write(" ORDER BY ") + q.Write(strings.Join(columns, ", ")) + q.hasOrderBy = true + } + return q +} + +// GroupBy adds a GROUP BY clause +func (q *Query) GroupBy(columns ...string) *Query { + if !q.hasGroupBy { + q.Write(" GROUP BY ") + q.Write(strings.Join(columns, ", ")) + q.hasGroupBy = true + } + return q +} + +// Limit adds a LIMIT clause +func (q *Query) Limit(limit int) *Query { + if !q.hasLimit { + q.Write(" LIMIT ") + q.Write(fmt.Sprintf("%d", limit)) + q.hasLimit = true + } + return q +} + +// Offset adds an OFFSET clause +func (q *Query) Offset(offset int) *Query { + if !q.hasOffset { + q.Write(" OFFSET ") + q.Write(fmt.Sprintf("%d", offset)) + q.hasOffset = true + } + return q +} + +// Insert starts an INSERT statement +func (q *Query) Insert(table string, columns ...string) *Query { + q.Write("INSERT INTO ") + q.Write(table) + q.Write(" (") + q.Write(strings.Join(columns, ", ")) + q.Write(") VALUES ") + return q +} + +// Values adds a VALUES clause +func (q *Query) Values(count int) *Query { + q.Write("(") + q.Placeholders(count) + q.Write(")") + return q +} + +// Update starts an UPDATE statement +func (q *Query) Update(table string) *Query { + q.Write("UPDATE ") + q.Write(table) + q.Write(" SET ") + return q +} + +// Set adds a SET clause for updates +func (q *Query) Set(column string) *Query { + if strings.Contains(q.builder.String(), "SET ") && + !strings.HasSuffix(q.builder.String(), "SET ") { + q.Write(", ") + } + q.Write(column) + q.Write(" = ") + return q +} + +// Delete starts a DELETE statement +func (q *Query) Delete() *Query { + q.Write("DELETE") + return q +} + +// StartGroup starts a parenthetical group +func (q *Query) StartGroup() *Query { + q.Write(" (") + q.parensDepth++ + return q +} + +// EndGroup ends a parenthetical group +func (q *Query) EndGroup() *Query { + if q.parensDepth > 0 { + q.Write(")") + q.parensDepth-- + } + return q +} + +// Write adds a string to the query +func (q *Query) Write(s string) *Query { + q.builder.WriteString(s) + return q +} + +// Placeholder adds a placeholder for a single argument +func (q *Query) Placeholder(arg interface{}) *Query { + q.pos++ + q.args = append(q.args, arg) + + if q.dbType == DBTypePostgres { + q.builder.WriteString(fmt.Sprintf("$%d", q.pos)) + } else { + q.builder.WriteString("?") + } + + return q +} + +// Placeholders adds n placeholders separated by commas +func (q *Query) Placeholders(n int) *Query { + placeholders := make([]string, n) + + for i := 0; i < n; i++ { + q.pos++ + if q.dbType == DBTypePostgres { + placeholders[i] = fmt.Sprintf("$%d", q.pos) + } else { + placeholders[i] = "?" + } + } + + q.builder.WriteString(strings.Join(placeholders, ", ")) + return q +} + +// AddArgs adds arguments to the query +func (q *Query) AddArgs(args ...interface{}) *Query { + q.args = append(q.args, args...) + return q +} + +// String returns the formatted query string +func (q *Query) String() string { + return q.builder.String() +} + +// Args returns the query arguments +func (q *Query) Args() []interface{} { + return q.args +} From a946b8ae7673ec1037a919342a88c4eeffbd1db7 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sun, 23 Feb 2025 14:59:00 +0100 Subject: [PATCH 04/39] Implement db struct scanner --- server/internal/db/scanner.go | 95 +++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 server/internal/db/scanner.go diff --git a/server/internal/db/scanner.go b/server/internal/db/scanner.go new file mode 100644 index 0000000..5141d8b --- /dev/null +++ b/server/internal/db/scanner.go @@ -0,0 +1,95 @@ +package db + +import ( + "database/sql" + "fmt" + "reflect" +) + +// Scanner provides methods for scanning rows into structs +type Scanner struct { + db *sql.DB + dbType DBType +} + +// NewScanner creates a new Scanner instance +func NewScanner(db *sql.DB, dbType DBType) *Scanner { + return &Scanner{ + db: db, + dbType: dbType, + } +} + +// QueryRow executes a query and scans the result into a struct +func (s *Scanner) QueryRow(dest interface{}, q *Query) error { + row := s.db.QueryRow(q.String(), q.Args()...) + return scanStruct(row, dest) +} + +// Query executes a query and scans multiple results into a slice of structs +func (s *Scanner) Query(dest interface{}, q *Query) error { + rows, err := s.db.Query(q.String(), q.Args()...) + if err != nil { + return err + } + defer rows.Close() + + return scanStructs(rows, dest) +} + +// scanStruct scans a single row into a struct +func scanStruct(row *sql.Row, dest interface{}) error { + v := reflect.ValueOf(dest) + if v.Kind() != reflect.Ptr { + return fmt.Errorf("dest must be a pointer to a struct") + } + v = v.Elem() + if v.Kind() != reflect.Struct { + return fmt.Errorf("dest must be a pointer to a struct") + } + + fields := make([]interface{}, 0, v.NumField()) + + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + if field.CanSet() { + fields = append(fields, field.Addr().Interface()) + } + } + + return row.Scan(fields...) +} + +// scanStructs scans multiple rows into a slice of structs +func scanStructs(rows *sql.Rows, dest interface{}) error { + v := reflect.ValueOf(dest) + if v.Kind() != reflect.Ptr { + return fmt.Errorf("dest must be a pointer to a slice") + } + sliceVal := v.Elem() + if sliceVal.Kind() != reflect.Slice { + return fmt.Errorf("dest must be a pointer to a slice") + } + + elemType := sliceVal.Type().Elem() + + for rows.Next() { + newElem := reflect.New(elemType).Elem() + fields := make([]interface{}, 0, newElem.NumField()) + + for i := 0; i < newElem.NumField(); i++ { + field := newElem.Field(i) + if field.CanSet() { + fields = append(fields, field.Addr().Interface()) + } + } + + if err := rows.Scan(fields...); err != nil { + return err + } + + sliceVal.Set(reflect.Append(sliceVal, newElem)) + } + + return rows.Err() +} From 7cbe6fd272dc7abbc16c93c551d0a3d03c6849a5 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sun, 23 Feb 2025 15:19:34 +0100 Subject: [PATCH 05/39] Add postgres init --- server/internal/db/db.go | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/server/internal/db/db.go b/server/internal/db/db.go index 53af48b..ff80725 100644 --- a/server/internal/db/db.go +++ b/server/internal/db/db.go @@ -9,6 +9,7 @@ import ( "lemma/internal/models" "lemma/internal/secrets" + _ "github.com/lib/pq" // Postgres driver _ "github.com/mattn/go-sqlite3" // SQLite driver ) @@ -135,7 +136,17 @@ func Init(dbType DBType, dbURL string, secretsService secrets.Service) (Database } return database, nil case DBTypePostgres: - return nil, fmt.Errorf("postgres database not supported yet") + db, err := initPostgres(dbURL) + if err != nil { + return nil, fmt.Errorf("failed to initialize Postgres database: %w", err) + } + + database := &database{ + DB: db, + secretsService: secretsService, + dbType: dbType, + } + return database, nil } return nil, fmt.Errorf("unsupported database type: %s", dbType) @@ -160,6 +171,19 @@ func initSQLite(dbURL string) (*sql.DB, error) { return db, nil } +func initPostgres(dbURL string) (*sql.DB, error) { + db, err := sql.Open("postgres", dbURL) + if err != nil { + return nil, fmt.Errorf("failed to open database: %w", err) + } + + if err := db.Ping(); err != nil { + return nil, fmt.Errorf("failed to ping database: %w", err) + } + + return db, nil +} + // Close closes the database connection func (db *database) Close() error { log := getLogger() From 96284c3dbd460d823be19cdc669e3bf2a21e5413 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Mon, 24 Feb 2025 21:38:52 +0100 Subject: [PATCH 06/39] Add query and scanner tests --- server/internal/db/query.go | 20 +- server/internal/db/query_test.go | 704 +++++++++++++++++++++++++++++ server/internal/db/scanner.go | 199 +++++++- server/internal/db/scanner_test.go | 298 ++++++++++++ 4 files changed, 1217 insertions(+), 4 deletions(-) create mode 100644 server/internal/db/query_test.go create mode 100644 server/internal/db/scanner_test.go diff --git a/server/internal/db/query.go b/server/internal/db/query.go index 2f380cd..af8fd7c 100644 --- a/server/internal/db/query.go +++ b/server/internal/db/query.go @@ -24,6 +24,7 @@ type Query struct { hasWhere bool hasOrderBy bool hasGroupBy bool + hasHaving bool hasLimit bool hasOffset bool isInParens bool @@ -130,6 +131,18 @@ func (q *Query) GroupBy(columns ...string) *Query { return q } +// Having adds a HAVING clause for filtering groups +func (q *Query) Having(condition string) *Query { + if !q.hasHaving { + q.Write(" HAVING ") + q.hasHaving = true + } else { + q.Write(" AND ") + } + q.Write(condition) + return q +} + // Limit adds a LIMIT clause func (q *Query) Limit(limit int) *Query { if !q.hasLimit { @@ -195,7 +208,12 @@ func (q *Query) Delete() *Query { // StartGroup starts a parenthetical group func (q *Query) StartGroup() *Query { - q.Write(" (") + if q.hasWhere { + q.Write(" AND (") + } else { + q.Write(" WHERE (") + q.hasWhere = true + } q.parensDepth++ return q } diff --git a/server/internal/db/query_test.go b/server/internal/db/query_test.go new file mode 100644 index 0000000..6664936 --- /dev/null +++ b/server/internal/db/query_test.go @@ -0,0 +1,704 @@ +package db_test + +import ( + "reflect" + "testing" + + "lemma/internal/db" +) + +func TestNewQuery(t *testing.T) { + tests := []struct { + name string + dbType db.DBType + }{ + { + name: "SQLite query", + dbType: db.DBTypeSQLite, + }, + { + name: "Postgres query", + dbType: db.DBTypePostgres, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + q := db.NewQuery(tt.dbType) + + // Test that a new query is empty + if q.String() != "" { + t.Errorf("NewQuery() should return empty string, got %q", q.String()) + } + if len(q.Args()) != 0 { + t.Errorf("NewQuery() should return empty args, got %v", q.Args()) + } + + // Test placeholder behavior - SQLite uses ? and Postgres uses $1 + q.Write("test").Placeholder(1) + + expectedPlaceholder := "?" + if tt.dbType == db.DBTypePostgres { + expectedPlaceholder = "$1" + } + + if q.String() != "test"+expectedPlaceholder { + t.Errorf("Expected placeholder format %q for %s, got %q", + "test"+expectedPlaceholder, tt.name, q.String()) + } + }) + } +} + +func TestBasicQueryBuilding(t *testing.T) { + tests := []struct { + name string + dbType db.DBType + buildFn func(*db.Query) *db.Query + wantSQL string + wantArgs []interface{} + }{ + { + name: "Simple select SQLite", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("id", "name").From("users") + }, + wantSQL: "SELECT id, name FROM users", + wantArgs: []interface{}{}, + }, + { + name: "Simple select Postgres", + dbType: db.DBTypePostgres, + buildFn: func(q *db.Query) *db.Query { + return q.Select("id", "name").From("users") + }, + wantSQL: "SELECT id, name FROM users", + wantArgs: []interface{}{}, + }, + { + name: "Select with where SQLite", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("id", "name").From("users").Where("id = ").Placeholder(1) + }, + wantSQL: "SELECT id, name FROM users WHERE id = ?", + wantArgs: []interface{}{1}, + }, + { + name: "Select with where Postgres", + dbType: db.DBTypePostgres, + buildFn: func(q *db.Query) *db.Query { + return q.Select("id", "name").From("users").Where("id = ").Placeholder(1) + }, + wantSQL: "SELECT id, name FROM users WHERE id = $1", + wantArgs: []interface{}{1}, + }, + { + name: "Multiple where conditions SQLite", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*").From("users"). + Where("active = ").Placeholder(true). + And("role = ").Placeholder("admin") + }, + wantSQL: "SELECT * FROM users WHERE active = ? AND role = ?", + wantArgs: []interface{}{true, "admin"}, + }, + { + name: "Multiple where conditions Postgres", + dbType: db.DBTypePostgres, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*").From("users"). + Where("active = ").Placeholder(true). + And("role = ").Placeholder("admin") + }, + wantSQL: "SELECT * FROM users WHERE active = $1 AND role = $2", + wantArgs: []interface{}{true, "admin"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + q := db.NewQuery(tt.dbType) + q = tt.buildFn(q) + + gotSQL := q.String() + gotArgs := q.Args() + + if gotSQL != tt.wantSQL { + t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL) + } + + if !reflect.DeepEqual(gotArgs, tt.wantArgs) { + t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs) + } + }) + } +} + +func TestPlaceholders(t *testing.T) { + tests := []struct { + name string + dbType db.DBType + buildFn func(*db.Query) *db.Query + wantSQL string + wantArgs []interface{} + }{ + { + name: "Single placeholder SQLite", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Write("SELECT * FROM users WHERE id = ").Placeholder(42) + }, + wantSQL: "SELECT * FROM users WHERE id = ?", + wantArgs: []interface{}{42}, + }, + { + name: "Single placeholder Postgres", + dbType: db.DBTypePostgres, + buildFn: func(q *db.Query) *db.Query { + return q.Write("SELECT * FROM users WHERE id = ").Placeholder(42) + }, + wantSQL: "SELECT * FROM users WHERE id = $1", + wantArgs: []interface{}{42}, + }, + { + name: "Multiple placeholders SQLite", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Write("SELECT * FROM users WHERE id = "). + Placeholder(42). + Write(" AND name = "). + Placeholder("John") + }, + wantSQL: "SELECT * FROM users WHERE id = ? AND name = ?", + wantArgs: []interface{}{42, "John"}, + }, + { + name: "Multiple placeholders Postgres", + dbType: db.DBTypePostgres, + buildFn: func(q *db.Query) *db.Query { + return q.Write("SELECT * FROM users WHERE id = "). + Placeholder(42). + Write(" AND name = "). + Placeholder("John") + }, + wantSQL: "SELECT * FROM users WHERE id = $1 AND name = $2", + wantArgs: []interface{}{42, "John"}, + }, + { + name: "Placeholders for IN SQLite", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Write("SELECT * FROM users WHERE id IN ("). + Placeholders(3). + Write(")"). + AddArgs(1, 2, 3) + }, + wantSQL: "SELECT * FROM users WHERE id IN (?, ?, ?)", + wantArgs: []interface{}{1, 2, 3}, + }, + { + name: "Placeholders for IN Postgres", + dbType: db.DBTypePostgres, + buildFn: func(q *db.Query) *db.Query { + return q.Write("SELECT * FROM users WHERE id IN ("). + Placeholders(3). + Write(")"). + AddArgs(1, 2, 3) + }, + wantSQL: "SELECT * FROM users WHERE id IN ($1, $2, $3)", + wantArgs: []interface{}{1, 2, 3}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + q := db.NewQuery(tt.dbType) + q = tt.buildFn(q) + + gotSQL := q.String() + gotArgs := q.Args() + + if gotSQL != tt.wantSQL { + t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL) + } + + if !reflect.DeepEqual(gotArgs, tt.wantArgs) { + t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs) + } + }) + } +} + +func TestWhereClauseBuilding(t *testing.T) { + tests := []struct { + name string + dbType db.DBType + buildFn func(*db.Query) *db.Query + wantSQL string + wantArgs []interface{} + }{ + { + name: "Simple where", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*").From("users").Where("id = ").Placeholder(1) + }, + wantSQL: "SELECT * FROM users WHERE id = ?", + wantArgs: []interface{}{1}, + }, + { + name: "Where with And", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*").From("users"). + Where("id = ").Placeholder(1). + And("active = ").Placeholder(true) + }, + wantSQL: "SELECT * FROM users WHERE id = ? AND active = ?", + wantArgs: []interface{}{1, true}, + }, + { + name: "Where with Or", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*").From("users"). + Where("id = ").Placeholder(1). + Or("id = ").Placeholder(2) + }, + wantSQL: "SELECT * FROM users WHERE id = ? OR id = ?", + wantArgs: []interface{}{1, 2}, + }, + { + name: "Where with parentheses", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*").From("users"). + Where("active = ").Placeholder(true). + And("("). + Write("id = ").Placeholder(1). + Or("id = ").Placeholder(2). + Write(")") + }, + wantSQL: "SELECT * FROM users WHERE active = ? AND (id = ? OR id = ?)", + wantArgs: []interface{}{true, 1, 2}, + }, + { + name: "Where with StartGroup and EndGroup", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*").From("users"). + Where("active = ").Placeholder(true). + Write(" AND ("). + Write("id = ").Placeholder(1). + Or("id = ").Placeholder(2). + Write(")") + }, + wantSQL: "SELECT * FROM users WHERE active = ? AND (id = ? OR id = ?)", + wantArgs: []interface{}{true, 1, 2}, + }, + { + name: "Where with nested groups", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*").From("users"). + Where("("). + Write("active = ").Placeholder(true). + Or("role = ").Placeholder("admin"). + Write(")"). + And("created_at > ").Placeholder("2020-01-01") + }, + wantSQL: "SELECT * FROM users WHERE (active = ? OR role = ?) AND created_at > ?", + wantArgs: []interface{}{true, "admin", "2020-01-01"}, + }, + { + name: "WhereIn", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*").From("users"). + WhereIn("id", 3). + AddArgs(1, 2, 3) + }, + wantSQL: "SELECT * FROM users WHERE id IN (?, ?, ?)", + wantArgs: []interface{}{1, 2, 3}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + q := db.NewQuery(tt.dbType) + q = tt.buildFn(q) + + gotSQL := q.String() + gotArgs := q.Args() + + if gotSQL != tt.wantSQL { + t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL) + } + + if !reflect.DeepEqual(gotArgs, tt.wantArgs) { + t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs) + } + }) + } +} + +func TestJoinClauseBuilding(t *testing.T) { + tests := []struct { + name string + dbType db.DBType + buildFn func(*db.Query) *db.Query + wantSQL string + wantArgs []interface{} + }{ + { + name: "Inner join", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("u.*", "w.name"). + From("users u"). + Join(db.InnerJoin, "workspaces w", "w.user_id = u.id") + }, + wantSQL: "SELECT u.*, w.name FROM users u INNER JOIN workspaces w ON w.user_id = u.id", + wantArgs: []interface{}{}, + }, + { + name: "Left join", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("u.*", "w.name"). + From("users u"). + Join(db.LeftJoin, "workspaces w", "w.user_id = u.id") + }, + wantSQL: "SELECT u.*, w.name FROM users u LEFT JOIN workspaces w ON w.user_id = u.id", + wantArgs: []interface{}{}, + }, + { + name: "Multiple joins", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("u.*", "w.name", "s.role"). + From("users u"). + Join(db.InnerJoin, "workspaces w", "w.user_id = u.id"). + Join(db.LeftJoin, "settings s", "s.user_id = u.id") + }, + wantSQL: "SELECT u.*, w.name, s.role FROM users u INNER JOIN workspaces w ON w.user_id = u.id LEFT JOIN settings s ON s.user_id = u.id", + wantArgs: []interface{}{}, + }, + { + name: "Join with where", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("u.*", "w.name"). + From("users u"). + Join(db.InnerJoin, "workspaces w", "w.user_id = u.id"). + Where("u.active = ").Placeholder(true) + }, + wantSQL: "SELECT u.*, w.name FROM users u INNER JOIN workspaces w ON w.user_id = u.id WHERE u.active = ?", + wantArgs: []interface{}{true}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + q := db.NewQuery(tt.dbType) + q = tt.buildFn(q) + + gotSQL := q.String() + gotArgs := q.Args() + + if gotSQL != tt.wantSQL { + t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL) + } + + if !reflect.DeepEqual(gotArgs, tt.wantArgs) { + t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs) + } + }) + } +} + +func TestOrderLimitOffset(t *testing.T) { + tests := []struct { + name string + dbType db.DBType + buildFn func(*db.Query) *db.Query + wantSQL string + wantArgs []interface{} + }{ + { + name: "Order by", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*").From("users").OrderBy("name ASC") + }, + wantSQL: "SELECT * FROM users ORDER BY name ASC", + wantArgs: []interface{}{}, + }, + { + name: "Order by multiple columns", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*").From("users").OrderBy("name ASC", "created_at DESC") + }, + wantSQL: "SELECT * FROM users ORDER BY name ASC, created_at DESC", + wantArgs: []interface{}{}, + }, + { + name: "Limit", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*").From("users").Limit(10) + }, + wantSQL: "SELECT * FROM users LIMIT 10", + wantArgs: []interface{}{}, + }, + { + name: "Limit and offset", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*").From("users").Limit(10).Offset(20) + }, + wantSQL: "SELECT * FROM users LIMIT 10 OFFSET 20", + wantArgs: []interface{}{}, + }, + { + name: "Complete query with all clauses", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("*"). + From("users"). + Where("active = ").Placeholder(true). + OrderBy("name ASC"). + Limit(10). + Offset(20) + }, + wantSQL: "SELECT * FROM users WHERE active = ? ORDER BY name ASC LIMIT 10 OFFSET 20", + wantArgs: []interface{}{true}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + q := db.NewQuery(tt.dbType) + q = tt.buildFn(q) + + gotSQL := q.String() + gotArgs := q.Args() + + if gotSQL != tt.wantSQL { + t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL) + } + + if !reflect.DeepEqual(gotArgs, tt.wantArgs) { + t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs) + } + }) + } +} + +func TestInsertUpdateDelete(t *testing.T) { + tests := []struct { + name string + dbType db.DBType + buildFn func(*db.Query) *db.Query + wantSQL string + wantArgs []interface{} + }{ + { + name: "Insert SQLite", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Insert("users", "name", "email"). + Values(2). + AddArgs("John", "john@example.com") + }, + wantSQL: "INSERT INTO users (name, email) VALUES (?, ?)", + wantArgs: []interface{}{"John", "john@example.com"}, + }, + { + name: "Insert Postgres", + dbType: db.DBTypePostgres, + buildFn: func(q *db.Query) *db.Query { + return q.Insert("users", "name", "email"). + Values(2). + AddArgs("John", "john@example.com") + }, + wantSQL: "INSERT INTO users (name, email) VALUES ($1, $2)", + wantArgs: []interface{}{"John", "john@example.com"}, + }, + { + name: "Update SQLite", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Update("users"). + Set("name").Placeholder("John"). + Set("email").Placeholder("john@example.com"). + Where("id = ").Placeholder(1) + }, + wantSQL: "UPDATE users SET name = ?, email = ? WHERE id = ?", + wantArgs: []interface{}{"John", "john@example.com", 1}, + }, + { + name: "Update Postgres", + dbType: db.DBTypePostgres, + buildFn: func(q *db.Query) *db.Query { + return q.Update("users"). + Set("name").Placeholder("John"). + Set("email").Placeholder("john@example.com"). + Where("id = ").Placeholder(1) + }, + wantSQL: "UPDATE users SET name = $1, email = $2 WHERE id = $3", + wantArgs: []interface{}{"John", "john@example.com", 1}, + }, + { + name: "Delete SQLite", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Delete().From("users").Where("id = ").Placeholder(1) + }, + wantSQL: "DELETE FROM users WHERE id = ?", + wantArgs: []interface{}{1}, + }, + { + name: "Delete Postgres", + dbType: db.DBTypePostgres, + buildFn: func(q *db.Query) *db.Query { + return q.Delete().From("users").Where("id = ").Placeholder(1) + }, + wantSQL: "DELETE FROM users WHERE id = $1", + wantArgs: []interface{}{1}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + q := db.NewQuery(tt.dbType) + q = tt.buildFn(q) + + gotSQL := q.String() + gotArgs := q.Args() + + if gotSQL != tt.wantSQL { + t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL) + } + + if !reflect.DeepEqual(gotArgs, tt.wantArgs) { + t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs) + } + }) + } +} + +func TestHavingClause(t *testing.T) { + tests := []struct { + name string + dbType db.DBType + buildFn func(*db.Query) *db.Query + wantSQL string + wantArgs []interface{} + }{ + { + name: "Simple having", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("department", "COUNT(*) as count"). + From("employees"). + GroupBy("department"). + Having("count > ").Placeholder(5) + }, + wantSQL: "SELECT department, COUNT(*) as count FROM employees GROUP BY department HAVING count > ?", + wantArgs: []interface{}{5}, + }, + { + name: "Having with multiple conditions", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("department", "AVG(salary) as avg_salary"). + From("employees"). + GroupBy("department"). + Having("avg_salary > ").Placeholder(50000). + And("COUNT(*) > ").Placeholder(3) + }, + wantSQL: "SELECT department, AVG(salary) as avg_salary FROM employees GROUP BY department HAVING avg_salary > ? AND COUNT(*) > ?", + wantArgs: []interface{}{50000, 3}, + }, + { + name: "Having with postgres placeholders", + dbType: db.DBTypePostgres, + buildFn: func(q *db.Query) *db.Query { + return q.Select("department", "COUNT(*) as count"). + From("employees"). + GroupBy("department"). + Having("count > ").Placeholder(5) + }, + wantSQL: "SELECT department, COUNT(*) as count FROM employees GROUP BY department HAVING count > $1", + wantArgs: []interface{}{5}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + q := db.NewQuery(tt.dbType) + q = tt.buildFn(q) + + gotSQL := q.String() + gotArgs := q.Args() + + if gotSQL != tt.wantSQL { + t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL) + } + + if !reflect.DeepEqual(gotArgs, tt.wantArgs) { + t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs) + } + }) + } +} + +func TestComplexQueries(t *testing.T) { + tests := []struct { + name string + dbType db.DBType + buildFn func(*db.Query) *db.Query + wantSQL string + wantArgs []interface{} + }{ + { + name: "Complex select with join and where", + dbType: db.DBTypeSQLite, + buildFn: func(q *db.Query) *db.Query { + return q.Select("u.id", "u.name", "COUNT(w.id) as workspace_count"). + From("users u"). + Join(db.LeftJoin, "workspaces w", "w.user_id = u.id"). + Where("u.active = ").Placeholder(true). + GroupBy("u.id", "u.name"). + Having("COUNT(w.id) > ").Placeholder(0). + OrderBy("workspace_count DESC"). + Limit(10) + }, + wantSQL: "SELECT u.id, u.name, COUNT(w.id) as workspace_count FROM users u LEFT JOIN workspaces w ON w.user_id = u.id WHERE u.active = ? GROUP BY u.id, u.name HAVING COUNT(w.id) > ? ORDER BY workspace_count DESC LIMIT 10", + wantArgs: []interface{}{true, 0}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + q := db.NewQuery(tt.dbType) + q = tt.buildFn(q) + + gotSQL := q.String() + gotArgs := q.Args() + + if gotSQL != tt.wantSQL { + t.Errorf("Query.String() = %q, want %q", gotSQL, tt.wantSQL) + } + + if !reflect.DeepEqual(gotArgs, tt.wantArgs) { + t.Errorf("Query.Args() = %v, want %v", gotArgs, tt.wantArgs) + } + }) + } +} diff --git a/server/internal/db/scanner.go b/server/internal/db/scanner.go index 5141d8b..97e0422 100644 --- a/server/internal/db/scanner.go +++ b/server/internal/db/scanner.go @@ -4,6 +4,8 @@ import ( "database/sql" "fmt" "reflect" + "regexp" + "strings" ) // Scanner provides methods for scanning rows into structs @@ -23,7 +25,24 @@ func NewScanner(db *sql.DB, dbType DBType) *Scanner { // QueryRow executes a query and scans the result into a struct func (s *Scanner) QueryRow(dest interface{}, q *Query) error { row := s.db.QueryRow(q.String(), q.Args()...) - return scanStruct(row, dest) + + // Handle primitive types + v := reflect.ValueOf(dest) + if v.Kind() != reflect.Ptr { + return fmt.Errorf("dest must be a pointer") + } + + elem := v.Elem() + switch elem.Kind() { + case reflect.Struct: + return scanStruct(row, dest) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64, reflect.Bool, reflect.String: + return row.Scan(dest) + default: + return fmt.Errorf("unsupported dest type: %T", dest) + } } // Query executes a query and scans multiple results into a slice of structs @@ -41,7 +60,7 @@ func (s *Scanner) Query(dest interface{}, q *Query) error { func scanStruct(row *sql.Row, dest interface{}) error { v := reflect.ValueOf(dest) if v.Kind() != reflect.Ptr { - return fmt.Errorf("dest must be a pointer to a struct") + return fmt.Errorf("dest must be a pointer") } v = v.Elem() if v.Kind() != reflect.Struct { @@ -64,8 +83,9 @@ func scanStruct(row *sql.Row, dest interface{}) error { func scanStructs(rows *sql.Rows, dest interface{}) error { v := reflect.ValueOf(dest) if v.Kind() != reflect.Ptr { - return fmt.Errorf("dest must be a pointer to a slice") + return fmt.Errorf("dest must be a pointer") } + sliceVal := v.Elem() if sliceVal.Kind() != reflect.Slice { return fmt.Errorf("dest must be a pointer to a slice") @@ -93,3 +113,176 @@ func scanStructs(rows *sql.Rows, dest interface{}) error { return rows.Err() } + +// ScannerEx is an extended version of Scanner with more features +type ScannerEx struct { + db *sql.DB + dbType DBType +} + +// NewScannerEx creates a new ScannerEx instance +func NewScannerEx(db *sql.DB, dbType DBType) *ScannerEx { + return &ScannerEx{ + db: db, + dbType: dbType, + } +} + +// QueryRow executes a query and scans the result into a struct +func (s *ScannerEx) QueryRow(dest interface{}, q *Query) error { + row := s.db.QueryRow(q.String(), q.Args()...) + + // Get column names + // Note: This is a workaround since sql.Row doesn't expose column names. + // In a real implementation, you'd likely need to execute the query to get columns first. + // For simplicity, we'll infer them from the struct tags. + + return scanStructTags(row, dest) +} + +// Query executes a query and scans multiple results into a slice of structs +func (s *ScannerEx) Query(dest interface{}, q *Query) error { + rows, err := s.db.Query(q.String(), q.Args()...) + if err != nil { + return err + } + defer rows.Close() + + return scanStructsTags(rows, dest) +} + +// getFieldMap builds a map of db column names to struct fields using struct tags +func getFieldMap(t reflect.Type) map[string]int { + fieldMap := make(map[string]int) + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + + // Check db tag first + tag := field.Tag.Get("db") + if tag != "" && tag != "-" { + fieldMap[tag] = i + continue + } + + // Check json tag next + tag = field.Tag.Get("json") + if tag != "" && tag != "-" { + // Handle json tag options like omitempty + parts := strings.Split(tag, ",") + fieldMap[parts[0]] = i + continue + } + + // Default to field name with snake_case conversion + fieldMap[toSnakeCase(field.Name)] = i + } + + return fieldMap +} + +var camelRegex = regexp.MustCompile(`([a-z0-9])([A-Z])`) + +// toSnakeCase converts a camelCase string to snake_case +func toSnakeCase(s string) string { + return strings.ToLower(camelRegex.ReplaceAllString(s, "${1}_${2}")) +} + +// scanStructTags scans a single row into a struct using field tags +func scanStructTags(row *sql.Row, dest interface{}) error { + v := reflect.ValueOf(dest) + if v.Kind() != reflect.Ptr { + return fmt.Errorf("dest must be a pointer") + } + v = v.Elem() + if v.Kind() != reflect.Struct { + return fmt.Errorf("dest must be a pointer to a struct") + } + + fields := make([]interface{}, 0, v.NumField()) + + for i := 0; i < v.NumField(); i++ { + field := v.Field(i) + if field.CanSet() { + fields = append(fields, field.Addr().Interface()) + } + } + + return row.Scan(fields...) +} + +// scanStructsTags scans multiple rows into a slice of structs using field tags +func scanStructsTags(rows *sql.Rows, dest interface{}) error { + v := reflect.ValueOf(dest) + if v.Kind() != reflect.Ptr { + return fmt.Errorf("dest must be a pointer") + } + + sliceVal := v.Elem() + if sliceVal.Kind() != reflect.Slice { + return fmt.Errorf("dest must be a pointer to a slice") + } + + elemType := sliceVal.Type().Elem() + isPtr := elemType.Kind() == reflect.Ptr + if isPtr { + elemType = elemType.Elem() + } + + if elemType.Kind() != reflect.Struct { + return fmt.Errorf("dest must be a pointer to a slice of structs") + } + + // Get column names + columns, err := rows.Columns() + if err != nil { + return err + } + + // Build field mapping + fieldMap := getFieldMap(elemType) + + // Prepare values slice for each scan + values := make([]interface{}, len(columns)) + scanFields := make([]interface{}, len(columns)) + for i := range values { + scanFields[i] = &values[i] + } + + for rows.Next() { + // Create a new struct instance + newElem := reflect.New(elemType).Elem() + + // Scan row into values + if err := rows.Scan(scanFields...); err != nil { + return err + } + + // Map values to struct fields + for i, colName := range columns { + if fieldIndex, ok := fieldMap[colName]; ok { + field := newElem.Field(fieldIndex) + if field.CanSet() { + val := reflect.ValueOf(values[i]) + if val.Elem().Kind() == reflect.Interface { + val = val.Elem() + } + if val.Kind() == reflect.Ptr && !val.IsNil() { + field.Set(val.Elem()) + } else if !val.IsNil() { + field.Set(val) + } + } + } + } + + // Append to result slice + if isPtr { + sliceVal.Set(reflect.Append(sliceVal, newElem.Addr())) + } else { + sliceVal.Set(reflect.Append(sliceVal, newElem)) + } + } + + return rows.Err() +} diff --git a/server/internal/db/scanner_test.go b/server/internal/db/scanner_test.go new file mode 100644 index 0000000..a1fbc28 --- /dev/null +++ b/server/internal/db/scanner_test.go @@ -0,0 +1,298 @@ +package db_test + +import ( + "database/sql" + "testing" + "time" + + "lemma/internal/db" +) + +func TestScannerQueryRow(t *testing.T) { + mockSecrets := &mockSecretsService{} + testDB, err := db.NewTestDB(mockSecrets) + if err != nil { + t.Fatalf("Failed to create test database: %v", err) + } + defer testDB.Close() + + // Create a test table + _, err = testDB.TestDB().Exec(` + CREATE TABLE users ( + id INTEGER PRIMARY KEY, + email TEXT NOT NULL, + created_at TIMESTAMP NOT NULL, + active BOOLEAN NOT NULL + ) + `) + if err != nil { + t.Fatalf("Failed to create test table: %v", err) + } + + type User struct { + ID int + Email string + CreatedAt time.Time + Active bool + } + + // Insert test data + now := time.Now().UTC().Truncate(time.Second) + _, err = testDB.TestDB().Exec( + "INSERT INTO users (id, email, created_at, active) VALUES (?, ?, ?, ?)", + 1, "test@example.com", now, true, + ) + if err != nil { + t.Fatalf("Failed to insert test data: %v", err) + } + + // Test query row success + t.Run("QueryRow success", func(t *testing.T) { + scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) + q := db.NewQuery(db.DBTypeSQLite) + q.Select("id", "email", "created_at", "active"). + From("users"). + Where("id = "). + Placeholder(1) + + var user User + err := scanner.QueryRow(&user, q) + + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + if user.ID != 1 { + t.Errorf("Expected ID 1, got %d", user.ID) + } + if user.Email != "test@example.com" { + t.Errorf("Expected Email test@example.com, got %s", user.Email) + } + if !user.CreatedAt.Equal(now) { + t.Errorf("Expected CreatedAt %v, got %v", now, user.CreatedAt) + } + if !user.Active { + t.Errorf("Expected Active true, got %v", user.Active) + } + }) + + // Test query row no results + t.Run("QueryRow no results", func(t *testing.T) { + scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) + q := db.NewQuery(db.DBTypeSQLite) + q.Select("id", "email", "created_at", "active"). + From("users"). + Where("id = "). + Placeholder(999) + + var user User + err := scanner.QueryRow(&user, q) + + if err != sql.ErrNoRows { + t.Errorf("Expected ErrNoRows, got %v", err) + } + }) + + // Test scanning a single value + t.Run("QueryRow single value", func(t *testing.T) { + scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) + q := db.NewQuery(db.DBTypeSQLite) + q.Select("COUNT(*)").From("users") + + var count int + err := scanner.QueryRow(&count, q) + + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + if count != 1 { + t.Errorf("Expected count 1, got %d", count) + } + }) +} + +func TestScannerQuery(t *testing.T) { + mockSecrets := &mockSecretsService{} + testDB, err := db.NewTestDB(mockSecrets) + if err != nil { + t.Fatalf("Failed to create test database: %v", err) + } + defer testDB.Close() + + // Create a test table + _, err = testDB.TestDB().Exec(` + CREATE TABLE users ( + id INTEGER PRIMARY KEY, + email TEXT NOT NULL, + created_at TIMESTAMP NOT NULL, + active BOOLEAN NOT NULL + ) + `) + if err != nil { + t.Fatalf("Failed to create test table: %v", err) + } + + type User struct { + ID int + Email string + CreatedAt time.Time + Active bool + } + + // Insert test data + now := time.Now().UTC().Truncate(time.Second) + testUsers := []User{ + {ID: 1, Email: "user1@example.com", CreatedAt: now, Active: true}, + {ID: 2, Email: "user2@example.com", CreatedAt: now, Active: false}, + {ID: 3, Email: "user3@example.com", CreatedAt: now, Active: true}, + } + + for _, user := range testUsers { + _, err = testDB.TestDB().Exec( + "INSERT INTO users (id, email, created_at, active) VALUES (?, ?, ?, ?)", + user.ID, user.Email, user.CreatedAt, user.Active, + ) + if err != nil { + t.Fatalf("Failed to insert test data: %v", err) + } + } + + // Test query multiple rows + t.Run("Query multiple rows", func(t *testing.T) { + scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) + q := db.NewQuery(db.DBTypeSQLite) + q.Select("id", "email", "created_at", "active"). + From("users"). + OrderBy("id ASC") + + var users []User + err := scanner.Query(&users, q) + + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + if len(users) != len(testUsers) { + t.Errorf("Expected %d users, got %d", len(testUsers), len(users)) + } + + for i, u := range users { + if u.ID != testUsers[i].ID { + t.Errorf("Expected user[%d].ID %d, got %d", i, testUsers[i].ID, u.ID) + } + if u.Email != testUsers[i].Email { + t.Errorf("Expected user[%d].Email %s, got %s", i, testUsers[i].Email, u.Email) + } + if !u.CreatedAt.Equal(testUsers[i].CreatedAt) { + t.Errorf("Expected user[%d].CreatedAt %v, got %v", i, testUsers[i].CreatedAt, u.CreatedAt) + } + if u.Active != testUsers[i].Active { + t.Errorf("Expected user[%d].Active %v, got %v", i, testUsers[i].Active, u.Active) + } + } + }) + + // Test query with filter + t.Run("Query with filter", func(t *testing.T) { + scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) + q := db.NewQuery(db.DBTypeSQLite) + q.Select("id", "email", "created_at", "active"). + From("users"). + Where("active = "). + Placeholder(true). + OrderBy("id ASC") + + var users []User + err := scanner.Query(&users, q) + + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + if len(users) != 2 { + t.Errorf("Expected 2 users, got %d", len(users)) + } + + for _, u := range users { + if !u.Active { + t.Errorf("Expected only active users, got inactive user: %+v", u) + } + } + }) + + // Test query empty result + t.Run("Query empty result", func(t *testing.T) { + scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) + q := db.NewQuery(db.DBTypeSQLite) + q.Select("id", "email", "created_at", "active"). + From("users"). + Where("id > "). + Placeholder(100) + + var users []User + err := scanner.Query(&users, q) + + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + if len(users) != 0 { + t.Errorf("Expected 0 users, got %d", len(users)) + } + }) +} + +func TestScanErrors(t *testing.T) { + mockSecrets := &mockSecretsService{} + testDB, err := db.NewTestDB(mockSecrets) + if err != nil { + t.Fatalf("Failed to create test database: %v", err) + } + defer testDB.Close() + + scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) + q := db.NewQuery(db.DBTypeSQLite) + q.Select("1") + + // Test non-pointer + t.Run("QueryRow non-pointer", func(t *testing.T) { + var user struct{} + err := scanner.QueryRow(user, q) // Passing non-pointer + + if err == nil { + t.Error("Expected error for non-pointer, got nil") + } + }) + + // Test pointer to non-slice for Query + t.Run("Query pointer to non-slice", func(t *testing.T) { + var user struct{} + err := scanner.Query(&user, q) // Passing pointer to struct, not slice + + if err == nil { + t.Error("Expected error for non-slice pointer, got nil") + } + }) + + // Test non-pointer for Query + t.Run("Query non-pointer", func(t *testing.T) { + var users []struct{} + err := scanner.Query(users, q) // Passing non-pointer + + if err == nil { + t.Error("Expected error for non-pointer, got nil") + } + }) +} + +// Mock secrets service for testing +type mockSecretsService struct{} + +func (m *mockSecretsService) Encrypt(plaintext string) (string, error) { + return plaintext, nil +} + +func (m *mockSecretsService) Decrypt(ciphertext string) (string, error) { + return ciphertext, nil +} From d3ffcfbb53748255dd654c28aabf79a6607facee Mon Sep 17 00:00:00 2001 From: LordMathis Date: Mon, 24 Feb 2025 21:42:39 +0100 Subject: [PATCH 07/39] Replace interface{} with any --- server/internal/app/config_test.go | 8 +- server/internal/auth/jwt.go | 2 +- server/internal/db/query.go | 10 +- server/internal/db/query_test.go | 92 +++++++++---------- server/internal/db/scanner.go | 26 +++--- server/internal/handlers/admin_handlers.go | 2 +- .../file_handlers_integration_test.go | 2 +- .../handlers/git_handlers_integration_test.go | 2 +- server/internal/handlers/handlers.go | 2 +- server/internal/handlers/integration_test.go | 4 +- server/internal/storage/filesystem_test.go | 2 +- 11 files changed, 76 insertions(+), 76 deletions(-) diff --git a/server/internal/app/config_test.go b/server/internal/app/config_test.go index 17bef80..3205a57 100644 --- a/server/internal/app/config_test.go +++ b/server/internal/app/config_test.go @@ -15,8 +15,8 @@ func TestDefaultConfig(t *testing.T) { tests := []struct { name string - got interface{} - expected interface{} + got any + expected any }{ {"DBPath", cfg.DBURL, "sqlite://lemma.db"}, {"WorkDir", cfg.WorkDir, "./data"}, @@ -119,8 +119,8 @@ func TestLoad(t *testing.T) { tests := []struct { name string - got interface{} - expected interface{} + got any + expected any }{ {"IsDevelopment", cfg.IsDevelopment, true}, {"DBURL", cfg.DBURL, "/custom/db/path.db"}, diff --git a/server/internal/auth/jwt.go b/server/internal/auth/jwt.go index eefedcb..76fe9b5 100644 --- a/server/internal/auth/jwt.go +++ b/server/internal/auth/jwt.go @@ -112,7 +112,7 @@ func (s *jwtService) generateToken(userID int, role string, sessionID string, to func (s *jwtService) ValidateToken(tokenString string) (*Claims, error) { log := getJWTLogger() - token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) { + token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (any, error) { // Validate the signing method if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) diff --git a/server/internal/db/query.go b/server/internal/db/query.go index af8fd7c..e9526a4 100644 --- a/server/internal/db/query.go +++ b/server/internal/db/query.go @@ -16,7 +16,7 @@ const ( // Query represents a SQL query with its parameters type Query struct { builder strings.Builder - args []interface{} + args []any dbType DBType pos int // tracks the current placeholder position hasSelect bool @@ -35,7 +35,7 @@ type Query struct { func NewQuery(dbType DBType) *Query { return &Query{ dbType: dbType, - args: make([]interface{}, 0), + args: make([]any, 0), } } @@ -234,7 +234,7 @@ func (q *Query) Write(s string) *Query { } // Placeholder adds a placeholder for a single argument -func (q *Query) Placeholder(arg interface{}) *Query { +func (q *Query) Placeholder(arg any) *Query { q.pos++ q.args = append(q.args, arg) @@ -265,7 +265,7 @@ func (q *Query) Placeholders(n int) *Query { } // AddArgs adds arguments to the query -func (q *Query) AddArgs(args ...interface{}) *Query { +func (q *Query) AddArgs(args ...any) *Query { q.args = append(q.args, args...) return q } @@ -276,6 +276,6 @@ func (q *Query) String() string { } // Args returns the query arguments -func (q *Query) Args() []interface{} { +func (q *Query) Args() []any { return q.args } diff --git a/server/internal/db/query_test.go b/server/internal/db/query_test.go index 6664936..6c29924 100644 --- a/server/internal/db/query_test.go +++ b/server/internal/db/query_test.go @@ -56,7 +56,7 @@ func TestBasicQueryBuilding(t *testing.T) { dbType db.DBType buildFn func(*db.Query) *db.Query wantSQL string - wantArgs []interface{} + wantArgs []any }{ { name: "Simple select SQLite", @@ -65,7 +65,7 @@ func TestBasicQueryBuilding(t *testing.T) { return q.Select("id", "name").From("users") }, wantSQL: "SELECT id, name FROM users", - wantArgs: []interface{}{}, + wantArgs: []any{}, }, { name: "Simple select Postgres", @@ -74,7 +74,7 @@ func TestBasicQueryBuilding(t *testing.T) { return q.Select("id", "name").From("users") }, wantSQL: "SELECT id, name FROM users", - wantArgs: []interface{}{}, + wantArgs: []any{}, }, { name: "Select with where SQLite", @@ -83,7 +83,7 @@ func TestBasicQueryBuilding(t *testing.T) { return q.Select("id", "name").From("users").Where("id = ").Placeholder(1) }, wantSQL: "SELECT id, name FROM users WHERE id = ?", - wantArgs: []interface{}{1}, + wantArgs: []any{1}, }, { name: "Select with where Postgres", @@ -92,7 +92,7 @@ func TestBasicQueryBuilding(t *testing.T) { return q.Select("id", "name").From("users").Where("id = ").Placeholder(1) }, wantSQL: "SELECT id, name FROM users WHERE id = $1", - wantArgs: []interface{}{1}, + wantArgs: []any{1}, }, { name: "Multiple where conditions SQLite", @@ -103,7 +103,7 @@ func TestBasicQueryBuilding(t *testing.T) { And("role = ").Placeholder("admin") }, wantSQL: "SELECT * FROM users WHERE active = ? AND role = ?", - wantArgs: []interface{}{true, "admin"}, + wantArgs: []any{true, "admin"}, }, { name: "Multiple where conditions Postgres", @@ -114,7 +114,7 @@ func TestBasicQueryBuilding(t *testing.T) { And("role = ").Placeholder("admin") }, wantSQL: "SELECT * FROM users WHERE active = $1 AND role = $2", - wantArgs: []interface{}{true, "admin"}, + wantArgs: []any{true, "admin"}, }, } @@ -143,7 +143,7 @@ func TestPlaceholders(t *testing.T) { dbType db.DBType buildFn func(*db.Query) *db.Query wantSQL string - wantArgs []interface{} + wantArgs []any }{ { name: "Single placeholder SQLite", @@ -152,7 +152,7 @@ func TestPlaceholders(t *testing.T) { return q.Write("SELECT * FROM users WHERE id = ").Placeholder(42) }, wantSQL: "SELECT * FROM users WHERE id = ?", - wantArgs: []interface{}{42}, + wantArgs: []any{42}, }, { name: "Single placeholder Postgres", @@ -161,7 +161,7 @@ func TestPlaceholders(t *testing.T) { return q.Write("SELECT * FROM users WHERE id = ").Placeholder(42) }, wantSQL: "SELECT * FROM users WHERE id = $1", - wantArgs: []interface{}{42}, + wantArgs: []any{42}, }, { name: "Multiple placeholders SQLite", @@ -173,7 +173,7 @@ func TestPlaceholders(t *testing.T) { Placeholder("John") }, wantSQL: "SELECT * FROM users WHERE id = ? AND name = ?", - wantArgs: []interface{}{42, "John"}, + wantArgs: []any{42, "John"}, }, { name: "Multiple placeholders Postgres", @@ -185,7 +185,7 @@ func TestPlaceholders(t *testing.T) { Placeholder("John") }, wantSQL: "SELECT * FROM users WHERE id = $1 AND name = $2", - wantArgs: []interface{}{42, "John"}, + wantArgs: []any{42, "John"}, }, { name: "Placeholders for IN SQLite", @@ -197,7 +197,7 @@ func TestPlaceholders(t *testing.T) { AddArgs(1, 2, 3) }, wantSQL: "SELECT * FROM users WHERE id IN (?, ?, ?)", - wantArgs: []interface{}{1, 2, 3}, + wantArgs: []any{1, 2, 3}, }, { name: "Placeholders for IN Postgres", @@ -209,7 +209,7 @@ func TestPlaceholders(t *testing.T) { AddArgs(1, 2, 3) }, wantSQL: "SELECT * FROM users WHERE id IN ($1, $2, $3)", - wantArgs: []interface{}{1, 2, 3}, + wantArgs: []any{1, 2, 3}, }, } @@ -238,7 +238,7 @@ func TestWhereClauseBuilding(t *testing.T) { dbType db.DBType buildFn func(*db.Query) *db.Query wantSQL string - wantArgs []interface{} + wantArgs []any }{ { name: "Simple where", @@ -247,7 +247,7 @@ func TestWhereClauseBuilding(t *testing.T) { return q.Select("*").From("users").Where("id = ").Placeholder(1) }, wantSQL: "SELECT * FROM users WHERE id = ?", - wantArgs: []interface{}{1}, + wantArgs: []any{1}, }, { name: "Where with And", @@ -258,7 +258,7 @@ func TestWhereClauseBuilding(t *testing.T) { And("active = ").Placeholder(true) }, wantSQL: "SELECT * FROM users WHERE id = ? AND active = ?", - wantArgs: []interface{}{1, true}, + wantArgs: []any{1, true}, }, { name: "Where with Or", @@ -269,7 +269,7 @@ func TestWhereClauseBuilding(t *testing.T) { Or("id = ").Placeholder(2) }, wantSQL: "SELECT * FROM users WHERE id = ? OR id = ?", - wantArgs: []interface{}{1, 2}, + wantArgs: []any{1, 2}, }, { name: "Where with parentheses", @@ -283,7 +283,7 @@ func TestWhereClauseBuilding(t *testing.T) { Write(")") }, wantSQL: "SELECT * FROM users WHERE active = ? AND (id = ? OR id = ?)", - wantArgs: []interface{}{true, 1, 2}, + wantArgs: []any{true, 1, 2}, }, { name: "Where with StartGroup and EndGroup", @@ -297,7 +297,7 @@ func TestWhereClauseBuilding(t *testing.T) { Write(")") }, wantSQL: "SELECT * FROM users WHERE active = ? AND (id = ? OR id = ?)", - wantArgs: []interface{}{true, 1, 2}, + wantArgs: []any{true, 1, 2}, }, { name: "Where with nested groups", @@ -311,7 +311,7 @@ func TestWhereClauseBuilding(t *testing.T) { And("created_at > ").Placeholder("2020-01-01") }, wantSQL: "SELECT * FROM users WHERE (active = ? OR role = ?) AND created_at > ?", - wantArgs: []interface{}{true, "admin", "2020-01-01"}, + wantArgs: []any{true, "admin", "2020-01-01"}, }, { name: "WhereIn", @@ -322,7 +322,7 @@ func TestWhereClauseBuilding(t *testing.T) { AddArgs(1, 2, 3) }, wantSQL: "SELECT * FROM users WHERE id IN (?, ?, ?)", - wantArgs: []interface{}{1, 2, 3}, + wantArgs: []any{1, 2, 3}, }, } @@ -351,7 +351,7 @@ func TestJoinClauseBuilding(t *testing.T) { dbType db.DBType buildFn func(*db.Query) *db.Query wantSQL string - wantArgs []interface{} + wantArgs []any }{ { name: "Inner join", @@ -362,7 +362,7 @@ func TestJoinClauseBuilding(t *testing.T) { Join(db.InnerJoin, "workspaces w", "w.user_id = u.id") }, wantSQL: "SELECT u.*, w.name FROM users u INNER JOIN workspaces w ON w.user_id = u.id", - wantArgs: []interface{}{}, + wantArgs: []any{}, }, { name: "Left join", @@ -373,7 +373,7 @@ func TestJoinClauseBuilding(t *testing.T) { Join(db.LeftJoin, "workspaces w", "w.user_id = u.id") }, wantSQL: "SELECT u.*, w.name FROM users u LEFT JOIN workspaces w ON w.user_id = u.id", - wantArgs: []interface{}{}, + wantArgs: []any{}, }, { name: "Multiple joins", @@ -385,7 +385,7 @@ func TestJoinClauseBuilding(t *testing.T) { Join(db.LeftJoin, "settings s", "s.user_id = u.id") }, wantSQL: "SELECT u.*, w.name, s.role FROM users u INNER JOIN workspaces w ON w.user_id = u.id LEFT JOIN settings s ON s.user_id = u.id", - wantArgs: []interface{}{}, + wantArgs: []any{}, }, { name: "Join with where", @@ -397,7 +397,7 @@ func TestJoinClauseBuilding(t *testing.T) { Where("u.active = ").Placeholder(true) }, wantSQL: "SELECT u.*, w.name FROM users u INNER JOIN workspaces w ON w.user_id = u.id WHERE u.active = ?", - wantArgs: []interface{}{true}, + wantArgs: []any{true}, }, } @@ -426,7 +426,7 @@ func TestOrderLimitOffset(t *testing.T) { dbType db.DBType buildFn func(*db.Query) *db.Query wantSQL string - wantArgs []interface{} + wantArgs []any }{ { name: "Order by", @@ -435,7 +435,7 @@ func TestOrderLimitOffset(t *testing.T) { return q.Select("*").From("users").OrderBy("name ASC") }, wantSQL: "SELECT * FROM users ORDER BY name ASC", - wantArgs: []interface{}{}, + wantArgs: []any{}, }, { name: "Order by multiple columns", @@ -444,7 +444,7 @@ func TestOrderLimitOffset(t *testing.T) { return q.Select("*").From("users").OrderBy("name ASC", "created_at DESC") }, wantSQL: "SELECT * FROM users ORDER BY name ASC, created_at DESC", - wantArgs: []interface{}{}, + wantArgs: []any{}, }, { name: "Limit", @@ -453,7 +453,7 @@ func TestOrderLimitOffset(t *testing.T) { return q.Select("*").From("users").Limit(10) }, wantSQL: "SELECT * FROM users LIMIT 10", - wantArgs: []interface{}{}, + wantArgs: []any{}, }, { name: "Limit and offset", @@ -462,7 +462,7 @@ func TestOrderLimitOffset(t *testing.T) { return q.Select("*").From("users").Limit(10).Offset(20) }, wantSQL: "SELECT * FROM users LIMIT 10 OFFSET 20", - wantArgs: []interface{}{}, + wantArgs: []any{}, }, { name: "Complete query with all clauses", @@ -476,7 +476,7 @@ func TestOrderLimitOffset(t *testing.T) { Offset(20) }, wantSQL: "SELECT * FROM users WHERE active = ? ORDER BY name ASC LIMIT 10 OFFSET 20", - wantArgs: []interface{}{true}, + wantArgs: []any{true}, }, } @@ -505,7 +505,7 @@ func TestInsertUpdateDelete(t *testing.T) { dbType db.DBType buildFn func(*db.Query) *db.Query wantSQL string - wantArgs []interface{} + wantArgs []any }{ { name: "Insert SQLite", @@ -516,7 +516,7 @@ func TestInsertUpdateDelete(t *testing.T) { AddArgs("John", "john@example.com") }, wantSQL: "INSERT INTO users (name, email) VALUES (?, ?)", - wantArgs: []interface{}{"John", "john@example.com"}, + wantArgs: []any{"John", "john@example.com"}, }, { name: "Insert Postgres", @@ -527,7 +527,7 @@ func TestInsertUpdateDelete(t *testing.T) { AddArgs("John", "john@example.com") }, wantSQL: "INSERT INTO users (name, email) VALUES ($1, $2)", - wantArgs: []interface{}{"John", "john@example.com"}, + wantArgs: []any{"John", "john@example.com"}, }, { name: "Update SQLite", @@ -539,7 +539,7 @@ func TestInsertUpdateDelete(t *testing.T) { Where("id = ").Placeholder(1) }, wantSQL: "UPDATE users SET name = ?, email = ? WHERE id = ?", - wantArgs: []interface{}{"John", "john@example.com", 1}, + wantArgs: []any{"John", "john@example.com", 1}, }, { name: "Update Postgres", @@ -551,7 +551,7 @@ func TestInsertUpdateDelete(t *testing.T) { Where("id = ").Placeholder(1) }, wantSQL: "UPDATE users SET name = $1, email = $2 WHERE id = $3", - wantArgs: []interface{}{"John", "john@example.com", 1}, + wantArgs: []any{"John", "john@example.com", 1}, }, { name: "Delete SQLite", @@ -560,7 +560,7 @@ func TestInsertUpdateDelete(t *testing.T) { return q.Delete().From("users").Where("id = ").Placeholder(1) }, wantSQL: "DELETE FROM users WHERE id = ?", - wantArgs: []interface{}{1}, + wantArgs: []any{1}, }, { name: "Delete Postgres", @@ -569,7 +569,7 @@ func TestInsertUpdateDelete(t *testing.T) { return q.Delete().From("users").Where("id = ").Placeholder(1) }, wantSQL: "DELETE FROM users WHERE id = $1", - wantArgs: []interface{}{1}, + wantArgs: []any{1}, }, } @@ -598,7 +598,7 @@ func TestHavingClause(t *testing.T) { dbType db.DBType buildFn func(*db.Query) *db.Query wantSQL string - wantArgs []interface{} + wantArgs []any }{ { name: "Simple having", @@ -610,7 +610,7 @@ func TestHavingClause(t *testing.T) { Having("count > ").Placeholder(5) }, wantSQL: "SELECT department, COUNT(*) as count FROM employees GROUP BY department HAVING count > ?", - wantArgs: []interface{}{5}, + wantArgs: []any{5}, }, { name: "Having with multiple conditions", @@ -623,7 +623,7 @@ func TestHavingClause(t *testing.T) { And("COUNT(*) > ").Placeholder(3) }, wantSQL: "SELECT department, AVG(salary) as avg_salary FROM employees GROUP BY department HAVING avg_salary > ? AND COUNT(*) > ?", - wantArgs: []interface{}{50000, 3}, + wantArgs: []any{50000, 3}, }, { name: "Having with postgres placeholders", @@ -635,7 +635,7 @@ func TestHavingClause(t *testing.T) { Having("count > ").Placeholder(5) }, wantSQL: "SELECT department, COUNT(*) as count FROM employees GROUP BY department HAVING count > $1", - wantArgs: []interface{}{5}, + wantArgs: []any{5}, }, } @@ -664,7 +664,7 @@ func TestComplexQueries(t *testing.T) { dbType db.DBType buildFn func(*db.Query) *db.Query wantSQL string - wantArgs []interface{} + wantArgs []any }{ { name: "Complex select with join and where", @@ -680,7 +680,7 @@ func TestComplexQueries(t *testing.T) { Limit(10) }, wantSQL: "SELECT u.id, u.name, COUNT(w.id) as workspace_count FROM users u LEFT JOIN workspaces w ON w.user_id = u.id WHERE u.active = ? GROUP BY u.id, u.name HAVING COUNT(w.id) > ? ORDER BY workspace_count DESC LIMIT 10", - wantArgs: []interface{}{true, 0}, + wantArgs: []any{true, 0}, }, } diff --git a/server/internal/db/scanner.go b/server/internal/db/scanner.go index 97e0422..f50c296 100644 --- a/server/internal/db/scanner.go +++ b/server/internal/db/scanner.go @@ -23,7 +23,7 @@ func NewScanner(db *sql.DB, dbType DBType) *Scanner { } // QueryRow executes a query and scans the result into a struct -func (s *Scanner) QueryRow(dest interface{}, q *Query) error { +func (s *Scanner) QueryRow(dest any, q *Query) error { row := s.db.QueryRow(q.String(), q.Args()...) // Handle primitive types @@ -46,7 +46,7 @@ func (s *Scanner) QueryRow(dest interface{}, q *Query) error { } // Query executes a query and scans multiple results into a slice of structs -func (s *Scanner) Query(dest interface{}, q *Query) error { +func (s *Scanner) Query(dest any, q *Query) error { rows, err := s.db.Query(q.String(), q.Args()...) if err != nil { return err @@ -57,7 +57,7 @@ func (s *Scanner) Query(dest interface{}, q *Query) error { } // scanStruct scans a single row into a struct -func scanStruct(row *sql.Row, dest interface{}) error { +func scanStruct(row *sql.Row, dest any) error { v := reflect.ValueOf(dest) if v.Kind() != reflect.Ptr { return fmt.Errorf("dest must be a pointer") @@ -67,7 +67,7 @@ func scanStruct(row *sql.Row, dest interface{}) error { return fmt.Errorf("dest must be a pointer to a struct") } - fields := make([]interface{}, 0, v.NumField()) + fields := make([]any, 0, v.NumField()) for i := 0; i < v.NumField(); i++ { field := v.Field(i) @@ -80,7 +80,7 @@ func scanStruct(row *sql.Row, dest interface{}) error { } // scanStructs scans multiple rows into a slice of structs -func scanStructs(rows *sql.Rows, dest interface{}) error { +func scanStructs(rows *sql.Rows, dest any) error { v := reflect.ValueOf(dest) if v.Kind() != reflect.Ptr { return fmt.Errorf("dest must be a pointer") @@ -95,7 +95,7 @@ func scanStructs(rows *sql.Rows, dest interface{}) error { for rows.Next() { newElem := reflect.New(elemType).Elem() - fields := make([]interface{}, 0, newElem.NumField()) + fields := make([]any, 0, newElem.NumField()) for i := 0; i < newElem.NumField(); i++ { field := newElem.Field(i) @@ -129,7 +129,7 @@ func NewScannerEx(db *sql.DB, dbType DBType) *ScannerEx { } // QueryRow executes a query and scans the result into a struct -func (s *ScannerEx) QueryRow(dest interface{}, q *Query) error { +func (s *ScannerEx) QueryRow(dest any, q *Query) error { row := s.db.QueryRow(q.String(), q.Args()...) // Get column names @@ -141,7 +141,7 @@ func (s *ScannerEx) QueryRow(dest interface{}, q *Query) error { } // Query executes a query and scans multiple results into a slice of structs -func (s *ScannerEx) Query(dest interface{}, q *Query) error { +func (s *ScannerEx) Query(dest any, q *Query) error { rows, err := s.db.Query(q.String(), q.Args()...) if err != nil { return err @@ -189,7 +189,7 @@ func toSnakeCase(s string) string { } // scanStructTags scans a single row into a struct using field tags -func scanStructTags(row *sql.Row, dest interface{}) error { +func scanStructTags(row *sql.Row, dest any) error { v := reflect.ValueOf(dest) if v.Kind() != reflect.Ptr { return fmt.Errorf("dest must be a pointer") @@ -199,7 +199,7 @@ func scanStructTags(row *sql.Row, dest interface{}) error { return fmt.Errorf("dest must be a pointer to a struct") } - fields := make([]interface{}, 0, v.NumField()) + fields := make([]any, 0, v.NumField()) for i := 0; i < v.NumField(); i++ { field := v.Field(i) @@ -212,7 +212,7 @@ func scanStructTags(row *sql.Row, dest interface{}) error { } // scanStructsTags scans multiple rows into a slice of structs using field tags -func scanStructsTags(rows *sql.Rows, dest interface{}) error { +func scanStructsTags(rows *sql.Rows, dest any) error { v := reflect.ValueOf(dest) if v.Kind() != reflect.Ptr { return fmt.Errorf("dest must be a pointer") @@ -243,8 +243,8 @@ func scanStructsTags(rows *sql.Rows, dest interface{}) error { fieldMap := getFieldMap(elemType) // Prepare values slice for each scan - values := make([]interface{}, len(columns)) - scanFields := make([]interface{}, len(columns)) + values := make([]any, len(columns)) + scanFields := make([]any, len(columns)) for i := range values { scanFields[i] = &values[i] } diff --git a/server/internal/handlers/admin_handlers.go b/server/internal/handlers/admin_handlers.go index 082edaf..47c9d04 100644 --- a/server/internal/handlers/admin_handlers.go +++ b/server/internal/handlers/admin_handlers.go @@ -308,7 +308,7 @@ func (h *Handler) AdminUpdateUser() http.HandlerFunc { } // Track what's being updated for logging - updates := make(map[string]interface{}) + updates := make(map[string]any) if req.Email != "" { user.Email = req.Email diff --git a/server/internal/handlers/file_handlers_integration_test.go b/server/internal/handlers/file_handlers_integration_test.go index 384d2c3..c9a35a4 100644 --- a/server/internal/handlers/file_handlers_integration_test.go +++ b/server/internal/handlers/file_handlers_integration_test.go @@ -192,7 +192,7 @@ func TestFileHandlers_Integration(t *testing.T) { name string method string path string - body interface{} + body any }{ {"list files", http.MethodGet, baseURL, nil}, {"get file", http.MethodGet, baseURL + "/test.md", nil}, diff --git a/server/internal/handlers/git_handlers_integration_test.go b/server/internal/handlers/git_handlers_integration_test.go index 87f21ec..f458b86 100644 --- a/server/internal/handlers/git_handlers_integration_test.go +++ b/server/internal/handlers/git_handlers_integration_test.go @@ -123,7 +123,7 @@ func TestGitHandlers_Integration(t *testing.T) { name string method string path string - body interface{} + body any }{ { name: "commit without token", diff --git a/server/internal/handlers/handlers.go b/server/internal/handlers/handlers.go index 357f007..5e37571 100644 --- a/server/internal/handlers/handlers.go +++ b/server/internal/handlers/handlers.go @@ -37,7 +37,7 @@ func NewHandler(db db.Database, s storage.Manager) *Handler { } // respondJSON is a helper to send JSON responses -func respondJSON(w http.ResponseWriter, data interface{}) { +func respondJSON(w http.ResponseWriter, data any) { w.Header().Set("Content-Type", "application/json") if err := json.NewEncoder(w).Encode(data); err != nil { respondError(w, "Failed to encode response", http.StatusInternalServerError) diff --git a/server/internal/handlers/integration_test.go b/server/internal/handlers/integration_test.go index fb66585..0bb8bd3 100644 --- a/server/internal/handlers/integration_test.go +++ b/server/internal/handlers/integration_test.go @@ -195,7 +195,7 @@ func (h *testHarness) createTestUser(t *testing.T, email, password string, role } } -func (h *testHarness) newRequest(t *testing.T, method, path string, body interface{}) *http.Request { +func (h *testHarness) newRequest(t *testing.T, method, path string, body any) *http.Request { t.Helper() var reqBody []byte @@ -246,7 +246,7 @@ func (h *testHarness) addCSRFCookie(t *testing.T, req *http.Request) string { } // makeRequest is the main helper for making JSON requests -func (h *testHarness) makeRequest(t *testing.T, method, path string, body interface{}, testUser *testUser) *httptest.ResponseRecorder { +func (h *testHarness) makeRequest(t *testing.T, method, path string, body any, testUser *testUser) *httptest.ResponseRecorder { t.Helper() req := h.newRequest(t, method, path, body) diff --git a/server/internal/storage/filesystem_test.go b/server/internal/storage/filesystem_test.go index a35c692..e717a16 100644 --- a/server/internal/storage/filesystem_test.go +++ b/server/internal/storage/filesystem_test.go @@ -37,7 +37,7 @@ func (m MockDirInfo) Size() int64 { return m.size } func (m MockDirInfo) Mode() fs.FileMode { return m.mode } func (m MockDirInfo) ModTime() time.Time { return m.modTime } func (m MockDirInfo) IsDir() bool { return m.isDir } -func (m MockDirInfo) Sys() interface{} { return nil } +func (m MockDirInfo) Sys() any { return nil } type mockFS struct { // Record operations for verification From 27b81ef43301b3f459d6960dad96db002aa48886 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Tue, 25 Feb 2025 21:36:26 +0100 Subject: [PATCH 08/39] Update sessions to query builder --- server/internal/db/sessions.go | 53 ++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 19 deletions(-) diff --git a/server/internal/db/sessions.go b/server/internal/db/sessions.go index f5b8f81..7df51c6 100644 --- a/server/internal/db/sessions.go +++ b/server/internal/db/sessions.go @@ -10,11 +10,11 @@ import ( // CreateSession inserts a new session record into the database func (db *database) CreateSession(session *models.Session) error { - _, err := db.Exec(` - INSERT INTO sessions (id, user_id, refresh_token, expires_at, created_at) - VALUES (?, ?, ?, ?, ?)`, - session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt, - ) + query := NewQuery(db.dbType). + Insert("sessions", "id", "user_id", "refresh_token", "expires_at", "created_at"). + Values(5). + AddArgs(session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt) + _, err := db.Exec(query.String(), query.Args()...) if err != nil { return fmt.Errorf("failed to store session: %w", err) } @@ -25,12 +25,14 @@ func (db *database) CreateSession(session *models.Session) error { // GetSessionByRefreshToken retrieves a session by its refresh token func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) { session := &models.Session{} - err := db.QueryRow(` - SELECT id, user_id, refresh_token, expires_at, created_at - FROM sessions - WHERE refresh_token = ? AND expires_at > ?`, - refreshToken, time.Now(), - ).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt) + query := NewQuery(db.dbType). + Select("id", "user_id", "refresh_token", "expires_at", "created_at"). + From("sessions"). + Where("refresh_token = "). + Placeholder(refreshToken). + And("expires_at >"). + Placeholder(time.Now()) + err := db.QueryRow(query.String(), query.Args()...).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt) if err == sql.ErrNoRows { return nil, fmt.Errorf("session not found or expired") @@ -45,12 +47,14 @@ func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Sessi // GetSessionByID retrieves a session by its ID func (db *database) GetSessionByID(sessionID string) (*models.Session, error) { session := &models.Session{} - err := db.QueryRow(` - SELECT id, user_id, refresh_token, expires_at, created_at - FROM sessions - WHERE id = ? AND expires_at > ?`, - sessionID, time.Now(), - ).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt) + query := NewQuery(db.dbType). + Select("id", "user_id", "refresh_token", "expires_at", "created_at"). + From("sessions"). + Where("id = "). + Placeholder(sessionID). + And("expires_at >"). + Placeholder(time.Now()) + err := db.QueryRow(query.String(), query.Args()...).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt) if err == sql.ErrNoRows { return nil, fmt.Errorf("session not found") @@ -64,7 +68,13 @@ func (db *database) GetSessionByID(sessionID string) (*models.Session, error) { // DeleteSession removes a session from the database func (db *database) DeleteSession(sessionID string) error { - result, err := db.Exec("DELETE FROM sessions WHERE id = ?", sessionID) + query := NewQuery(db.dbType). + Delete(). + From("sessions"). + Where("id = "). + Placeholder(sessionID) + + result, err := db.Exec(query.String(), query.Args()...) if err != nil { return fmt.Errorf("failed to delete session: %w", err) } @@ -84,7 +94,12 @@ func (db *database) DeleteSession(sessionID string) error { // CleanExpiredSessions removes all expired sessions from the database func (db *database) CleanExpiredSessions() error { log := getLogger().WithGroup("sessions") - result, err := db.Exec("DELETE FROM sessions WHERE expires_at <= ?", time.Now()) + query := NewQuery(db.dbType). + Delete(). + From("sessions"). + Where("expires_at <="). + Placeholder(time.Now()) + result, err := db.Exec(query.String(), query.Args()...) if err != nil { return fmt.Errorf("failed to clean expired sessions: %w", err) } From 7d05c8aaccf992b6d7a1f1426097134e078b9d96 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Tue, 25 Feb 2025 21:36:42 +0100 Subject: [PATCH 09/39] Update system to query builder --- server/internal/db/system.go | 39 +++++++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/server/internal/db/system.go b/server/internal/db/system.go index 447d421..d81aef3 100644 --- a/server/internal/db/system.go +++ b/server/internal/db/system.go @@ -49,7 +49,12 @@ func (db *database) EnsureJWTSecret() (string, error) { // GetSystemSetting retrieves a system setting by key func (db *database) GetSystemSetting(key string) (string, error) { var value string - err := db.QueryRow("SELECT value FROM system_settings WHERE key = ?", key).Scan(&value) + query := NewQuery(db.dbType). + Select("value"). + From("system_settings"). + Where("key = "). + Placeholder(key) + err := db.QueryRow(query.String(), query.args...).Scan(&value) if err != nil { return "", err } @@ -59,11 +64,14 @@ func (db *database) GetSystemSetting(key string) (string, error) { // SetSystemSetting stores or updates a system setting func (db *database) SetSystemSetting(key, value string) error { - _, err := db.Exec(` - INSERT INTO system_settings (key, value) - VALUES (?, ?) - ON CONFLICT(key) DO UPDATE SET value = ?`, - key, value, value) + query := NewQuery(db.dbType). + Insert("system_settings", "key", "value"). + Values(2). + AddArgs(key, value). + Write("ON CONFLICT(key) DO UPDATE SET value = "). + Placeholder(value) + + _, err := db.Exec(query.String(), query.args...) if err != nil { return fmt.Errorf("failed to store system setting: %w", err) @@ -92,22 +100,29 @@ func (db *database) GetSystemStats() (*UserStats, error) { stats := &UserStats{} // Get total users - err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&stats.TotalUsers) + query := NewQuery(db.dbType). + Select("COUNT(*)"). + From("users") + err := db.QueryRow(query.String()).Scan(&stats.TotalUsers) if err != nil { return nil, fmt.Errorf("failed to get total users count: %w", err) } // Get total workspaces - err = db.QueryRow("SELECT COUNT(*) FROM workspaces").Scan(&stats.TotalWorkspaces) + query = NewQuery(db.dbType). + Select("COUNT(*)"). + From("workspaces") + err = db.QueryRow(query.String()).Scan(&stats.TotalWorkspaces) if err != nil { return nil, fmt.Errorf("failed to get total workspaces count: %w", err) } // Get active users (users with activity in last 30 days) - err = db.QueryRow(` - SELECT COUNT(DISTINCT user_id) - FROM sessions - WHERE created_at > datetime('now', '-30 days')`). + query = NewQuery(db.dbType). + Select("COUNT(DISTINCT user_id)"). + From("sessions"). + Where("created_at > datetime('now', '-30 days')") + err = db.QueryRow(query.String()). Scan(&stats.ActiveUsers) if err != nil { return nil, fmt.Errorf("failed to get active users count: %w", err) From 802f192dc0f9eaa54979d156a377d168a4b04238 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Tue, 25 Feb 2025 21:36:55 +0100 Subject: [PATCH 10/39] Implement returning clause --- server/internal/db/query.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/server/internal/db/query.go b/server/internal/db/query.go index e9526a4..a552a8e 100644 --- a/server/internal/db/query.go +++ b/server/internal/db/query.go @@ -227,6 +227,18 @@ func (q *Query) EndGroup() *Query { return q } +// Returning adds a RETURNING clause for both PostgreSQL and SQLite (3.35.0+) +func (q *Query) Returning(columns ...string) *Query { + q.Write(" RETURNING ") + if len(columns) == 1 && columns[0] == "*" { + q.Write("*") + } else { + q.Write(strings.Join(columns, ", ")) + } + + return q +} + // Write adds a string to the query func (q *Query) Write(s string) *Query { q.builder.WriteString(s) From 9da51aeb5e3b0222a06b6261de04232e39bfaab0 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Tue, 25 Feb 2025 21:37:06 +0100 Subject: [PATCH 11/39] Update users to query builder --- server/internal/db/users.go | 182 +++++++++++++++++++++--------------- 1 file changed, 107 insertions(+), 75 deletions(-) diff --git a/server/internal/db/users.go b/server/internal/db/users.go index d8fd3c7..5260a19 100644 --- a/server/internal/db/users.go +++ b/server/internal/db/users.go @@ -17,26 +17,18 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) { } defer tx.Rollback() - result, err := tx.Exec(` - INSERT INTO users (email, display_name, password_hash, role) - VALUES (?, ?, ?, ?)`, - user.Email, user.DisplayName, user.PasswordHash, user.Role) + query := NewQuery(db.dbType). + Insert("users", "email", "display_name", "password_hash", "role"). + Values(4). + AddArgs(user.Email, user.DisplayName, user.PasswordHash, user.Role). + Returning("id", "created_at") + + err = tx.QueryRow(query.String(), query.Args()...). + Scan(&user.ID, &user.CreatedAt) if err != nil { return nil, fmt.Errorf("failed to insert user: %w", err) } - userID, err := result.LastInsertId() - if err != nil { - return nil, fmt.Errorf("failed to get last insert ID: %w", err) - } - user.ID = int(userID) - - // Retrieve the created_at timestamp - err = tx.QueryRow("SELECT created_at FROM users WHERE id = ?", user.ID).Scan(&user.CreatedAt) - if err != nil { - return nil, fmt.Errorf("failed to get created timestamp: %w", err) - } - // Create default workspace with default settings defaultWorkspace := &models.Workspace{ UserID: user.ID, @@ -51,7 +43,13 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) { } // Update user's last workspace ID - _, err = tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?", defaultWorkspace.ID, user.ID) + query = NewQuery(db.dbType). + Update("users"). + Set("last_workspace_id"). + Placeholder(defaultWorkspace.ID). + Where("id = "). + Placeholder(user.ID) + _, err = tx.Exec(query.String(), query.Args()...) if err != nil { return nil, fmt.Errorf("failed to update last workspace ID: %w", err) } @@ -70,44 +68,43 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) { // Helper function to create a workspace in a transaction func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error { log := getLogger().WithGroup("users") - result, err := tx.Exec(` - INSERT INTO workspaces ( - user_id, name, - theme, auto_save, show_hidden_files, - git_enabled, git_url, git_user, git_token, - git_auto_commit, git_commit_msg_template, - git_commit_name, git_commit_email - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, - workspace.UserID, workspace.Name, - workspace.Theme, workspace.AutoSave, workspace.ShowHiddenFiles, - workspace.GitEnabled, workspace.GitURL, workspace.GitUser, workspace.GitToken, - workspace.GitAutoCommit, workspace.GitCommitMsgTemplate, - workspace.GitCommitName, workspace.GitCommitEmail, - ) + + insertQuery := NewQuery(db.dbType). + Insert("workspaces", + "user_id", "name", + "theme", "auto_save", "show_hidden_files", + "git_enabled", "git_url", "git_user", "git_token", + "git_auto_commit", "git_commit_msg_template", + "git_commit_name", "git_commit_email"). + Values(13). + AddArgs( + workspace.UserID, workspace.Name, + workspace.Theme, workspace.AutoSave, workspace.ShowHiddenFiles, + workspace.GitEnabled, workspace.GitURL, workspace.GitUser, workspace.GitToken, + workspace.GitAutoCommit, workspace.GitCommitMsgTemplate, + workspace.GitCommitName, workspace.GitCommitEmail). + Returning("id") + + err := tx.QueryRow(insertQuery.String(), insertQuery.Args()...).Scan(&workspace.ID) if err != nil { return fmt.Errorf("failed to insert workspace: %w", err) } - id, err := result.LastInsertId() - if err != nil { - return fmt.Errorf("failed to get workspace ID: %w", err) - } - workspace.ID = int(id) - log.Debug("created user workspace", "workspace_id", workspace.ID, "user_id", workspace.UserID) return nil } +// GetUserByID retrieves a user by its ID func (db *database) GetUserByID(id int) (*models.User, error) { + query := NewQuery(db.dbType). + Select("id", "email", "display_name", "password_hash", "role", "created_at", "last_workspace_id"). + From("users"). + Where("id = ").Placeholder(id) + user := &models.User{} - err := db.QueryRow(` - SELECT - id, email, display_name, password_hash, role, created_at, - last_workspace_id - FROM users - WHERE id = ?`, id). + err := db.QueryRow(query.String(), query.Args()...). Scan(&user.ID, &user.Email, &user.DisplayName, &user.PasswordHash, &user.Role, &user.CreatedAt, &user.LastWorkspaceID) @@ -120,14 +117,15 @@ func (db *database) GetUserByID(id int) (*models.User, error) { return user, nil } +// GetUserByEmail retrieves a user by its email func (db *database) GetUserByEmail(email string) (*models.User, error) { + query := NewQuery(db.dbType). + Select("id", "email", "display_name", "password_hash", "role", "created_at", "last_workspace_id"). + From("users"). + Where("email = ").Placeholder(email) + user := &models.User{} - err := db.QueryRow(` - SELECT - id, email, display_name, password_hash, role, created_at, - last_workspace_id - FROM users - WHERE email = ?`, email). + err := db.QueryRow(query.String(), query.Args()...). Scan(&user.ID, &user.Email, &user.DisplayName, &user.PasswordHash, &user.Role, &user.CreatedAt, &user.LastWorkspaceID) @@ -141,14 +139,18 @@ func (db *database) GetUserByEmail(email string) (*models.User, error) { return user, nil } +// UpdateUser updates an existing user record in the database func (db *database) UpdateUser(user *models.User) error { - result, err := db.Exec(` - UPDATE users - SET email = ?, display_name = ?, password_hash = ?, role = ?, last_workspace_id = ? - WHERE id = ?`, - user.Email, user.DisplayName, user.PasswordHash, user.Role, - user.LastWorkspaceID, user.ID) + query := NewQuery(db.dbType). + Update("users"). + Set("email").Placeholder(user.Email). + Set("display_name").Placeholder(user.DisplayName). + Set("password_hash").Placeholder(user.PasswordHash). + Set("role").Placeholder(user.Role). + Set("last_workspace_id").Placeholder(user.LastWorkspaceID). + Where("id = ").Placeholder(user.ID) + result, err := db.Exec(query.String(), query.Args()...) if err != nil { return fmt.Errorf("failed to update user: %w", err) } @@ -165,13 +167,14 @@ func (db *database) UpdateUser(user *models.User) error { return nil } +// GetAllUsers retrieves all users from the database func (db *database) GetAllUsers() ([]*models.User, error) { - rows, err := db.Query(` - SELECT - id, email, display_name, role, created_at, - last_workspace_id - FROM users - ORDER BY id ASC`) + query := NewQuery(db.dbType). + Select("id", "email", "display_name", "role", "created_at", "last_workspace_id"). + From("users"). + OrderBy("id ASC") + + rows, err := db.Query(query.String(), query.Args()...) if err != nil { return nil, fmt.Errorf("failed to query users: %w", err) } @@ -200,15 +203,26 @@ func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error } defer tx.Rollback() + // Find workspace ID from name + workspaceQuery := NewQuery(db.dbType). + Select("id"). + From("workspaces"). + Where("user_id = ").Placeholder(userID). + And("name = ").Placeholder(workspaceName) + var workspaceID int - err = tx.QueryRow("SELECT id FROM workspaces WHERE user_id = ? AND name = ?", - userID, workspaceName).Scan(&workspaceID) + err = tx.QueryRow(workspaceQuery.String(), workspaceQuery.Args()...).Scan(&workspaceID) if err != nil { return fmt.Errorf("failed to find workspace: %w", err) } - _, err = tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?", - workspaceID, userID) + // Update user's last workspace + updateQuery := NewQuery(db.dbType). + Update("users"). + Set("last_workspace_id").Placeholder(workspaceID). + Where("id = ").Placeholder(userID) + + _, err = tx.Exec(updateQuery.String(), updateQuery.Args()...) if err != nil { return fmt.Errorf("failed to update last workspace: %w", err) } @@ -221,6 +235,7 @@ func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error return nil } +// DeleteUser deletes a user and all their workspaces func (db *database) DeleteUser(id int) error { log := getLogger().WithGroup("users") log.Debug("deleting user", "user_id", id) @@ -233,13 +248,24 @@ func (db *database) DeleteUser(id int) error { // Delete all user's workspaces first log.Debug("deleting user workspaces", "user_id", id) - _, err = tx.Exec("DELETE FROM workspaces WHERE user_id = ?", id) + + deleteWorkspacesQuery := NewQuery(db.dbType). + Delete(). + From("workspaces"). + Where("user_id = ").Placeholder(id) + + _, err = tx.Exec(deleteWorkspacesQuery.String(), deleteWorkspacesQuery.Args()...) if err != nil { return fmt.Errorf("failed to delete workspaces: %w", err) } // Delete the user - _, err = tx.Exec("DELETE FROM users WHERE id = ?", id) + deleteUserQuery := NewQuery(db.dbType). + Delete(). + From("users"). + Where("id = ").Placeholder(id) + + _, err = tx.Exec(deleteUserQuery.String(), deleteUserQuery.Args()...) if err != nil { return fmt.Errorf("failed to delete user: %w", err) } @@ -253,15 +279,16 @@ func (db *database) DeleteUser(id int) error { return nil } +// GetLastWorkspaceName retrieves the name of the last workspace accessed by a user func (db *database) GetLastWorkspaceName(userID int) (string, error) { + query := NewQuery(db.dbType). + Select("w.name"). + From("workspaces w"). + Join(InnerJoin, "users u", "u.last_workspace_id = w.id"). + Where("u.id = ").Placeholder(userID) + var workspaceName string - err := db.QueryRow(` - SELECT - w.name - FROM workspaces w - JOIN users u ON u.last_workspace_id = w.id - WHERE u.id = ?`, userID). - Scan(&workspaceName) + err := db.QueryRow(query.String(), query.Args()...).Scan(&workspaceName) if err == sql.ErrNoRows { return "", fmt.Errorf("no last workspace found") @@ -275,8 +302,13 @@ func (db *database) GetLastWorkspaceName(userID int) (string, error) { // CountAdminUsers returns the number of admin users in the system func (db *database) CountAdminUsers() (int, error) { + query := NewQuery(db.dbType). + Select("COUNT(*)"). + From("users"). + Where("role = ").Placeholder(models.RoleAdmin) + var count int - err := db.QueryRow("SELECT COUNT(*) FROM users WHERE role = 'admin'").Scan(&count) + err := db.QueryRow(query.String(), query.Args()...).Scan(&count) if err != nil { return 0, fmt.Errorf("failed to count admin users: %w", err) } From 3b7deaa10773db867573a26c2bd86ccb15393210 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Tue, 25 Feb 2025 22:11:59 +0100 Subject: [PATCH 12/39] Update workspaces to query builder --- server/internal/db/workspaces.go | 289 +++++++++++++++++-------------- 1 file changed, 161 insertions(+), 128 deletions(-) diff --git a/server/internal/db/workspaces.go b/server/internal/db/workspaces.go index 1d7fcc8..06bff0c 100644 --- a/server/internal/db/workspaces.go +++ b/server/internal/db/workspaces.go @@ -25,51 +25,53 @@ func (db *database) CreateWorkspace(workspace *models.Workspace) error { return fmt.Errorf("failed to encrypt token: %w", err) } - result, err := db.Exec(` - INSERT INTO workspaces ( - user_id, name, theme, auto_save, show_hidden_files, - git_enabled, git_url, git_user, git_token, - git_auto_commit, git_commit_msg_template, - git_commit_name, git_commit_email - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, - workspace.UserID, workspace.Name, workspace.Theme, workspace.AutoSave, workspace.ShowHiddenFiles, - workspace.GitEnabled, workspace.GitURL, workspace.GitUser, encryptedToken, - workspace.GitAutoCommit, workspace.GitCommitMsgTemplate, workspace.GitCommitName, workspace.GitCommitEmail, - ) + query := NewQuery(db.dbType). + Insert("workspaces", + "user_id", "name", "theme", "auto_save", "show_hidden_files", + "git_enabled", "git_url", "git_user", "git_token", + "git_auto_commit", "git_commit_msg_template", + "git_commit_name", "git_commit_email"). + Values(13). + AddArgs( + workspace.UserID, workspace.Name, workspace.Theme, workspace.AutoSave, workspace.ShowHiddenFiles, + workspace.GitEnabled, workspace.GitURL, workspace.GitUser, encryptedToken, + workspace.GitAutoCommit, workspace.GitCommitMsgTemplate, workspace.GitCommitName, workspace.GitCommitEmail). + Returning("id", "created_at") + + err = db.QueryRow(query.String(), query.Args()...). + Scan(&workspace.ID, &workspace.CreatedAt) if err != nil { return fmt.Errorf("failed to insert workspace: %w", err) } - id, err := result.LastInsertId() - if err != nil { - return fmt.Errorf("failed to get workspace ID: %w", err) - } - workspace.ID = int(id) - return nil } // GetWorkspaceByID retrieves a workspace by its ID func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) { + query := NewQuery(db.dbType). + Select( + "id", "user_id", "name", "created_at", + "theme", "auto_save", "show_hidden_files", + "git_enabled", "git_url", "git_user", "git_token", + "git_auto_commit", "git_commit_msg_template", + "git_commit_name", "git_commit_email", + "last_opened_file_path"). + From("workspaces"). + Where("id = ").Placeholder(id) + workspace := &models.Workspace{} var encryptedToken string - err := db.QueryRow(` - SELECT - id, user_id, name, created_at, - theme, auto_save, show_hidden_files, - git_enabled, git_url, git_user, git_token, - git_auto_commit, git_commit_msg_template, - git_commit_name, git_commit_email - FROM workspaces - WHERE id = ?`, - id, - ).Scan( + var lastOpenedFile sql.NullString + + err := db.QueryRow(query.String(), query.Args()...).Scan( &workspace.ID, &workspace.UserID, &workspace.Name, &workspace.CreatedAt, &workspace.Theme, &workspace.AutoSave, &workspace.ShowHiddenFiles, &workspace.GitEnabled, &workspace.GitURL, &workspace.GitUser, &encryptedToken, &workspace.GitAutoCommit, &workspace.GitCommitMsgTemplate, &workspace.GitCommitName, &workspace.GitCommitEmail, + &lastOpenedFile, ) if err == sql.ErrNoRows { @@ -79,6 +81,10 @@ func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) { return nil, fmt.Errorf("failed to fetch workspace: %w", err) } + if lastOpenedFile.Valid { + workspace.LastOpenedFilePath = lastOpenedFile.String + } + // Decrypt token workspace.GitToken, err = db.decryptToken(encryptedToken) if err != nil { @@ -90,25 +96,29 @@ func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) { // GetWorkspaceByName retrieves a workspace by its name and user ID func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) { + query := NewQuery(db.dbType). + Select( + "id", "user_id", "name", "created_at", + "theme", "auto_save", "show_hidden_files", + "git_enabled", "git_url", "git_user", "git_token", + "git_auto_commit", "git_commit_msg_template", + "git_commit_name", "git_commit_email", + "last_opened_file_path"). + From("workspaces"). + Where("user_id = ").Placeholder(userID). + And("name = ").Placeholder(workspaceName) + workspace := &models.Workspace{} var encryptedToken string + var lastOpenedFile sql.NullString - err := db.QueryRow(` - SELECT - id, user_id, name, created_at, - theme, auto_save, show_hidden_files, - git_enabled, git_url, git_user, git_token, - git_auto_commit, git_commit_msg_template, - git_commit_name, git_commit_email - FROM workspaces - WHERE user_id = ? AND name = ?`, - userID, workspaceName, - ).Scan( + err := db.QueryRow(query.String(), query.Args()...).Scan( &workspace.ID, &workspace.UserID, &workspace.Name, &workspace.CreatedAt, &workspace.Theme, &workspace.AutoSave, &workspace.ShowHiddenFiles, &workspace.GitEnabled, &workspace.GitURL, &workspace.GitUser, &encryptedToken, &workspace.GitAutoCommit, &workspace.GitCommitMsgTemplate, &workspace.GitCommitName, &workspace.GitCommitEmail, + &lastOpenedFile, ) if err == sql.ErrNoRows { @@ -118,6 +128,10 @@ func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*model return nil, fmt.Errorf("failed to fetch workspace: %w", err) } + if lastOpenedFile.Valid { + workspace.LastOpenedFilePath = lastOpenedFile.String + } + // Decrypt token workspace.GitToken, err = db.decryptToken(encryptedToken) if err != nil { @@ -135,37 +149,24 @@ func (db *database) UpdateWorkspace(workspace *models.Workspace) error { return fmt.Errorf("failed to encrypt token: %w", err) } - _, err = db.Exec(` - UPDATE workspaces - SET - name = ?, - theme = ?, - auto_save = ?, - show_hidden_files = ?, - git_enabled = ?, - git_url = ?, - git_user = ?, - git_token = ?, - git_auto_commit = ?, - git_commit_msg_template = ?, - git_commit_name = ?, - git_commit_email = ? - WHERE id = ? AND user_id = ?`, - workspace.Name, - workspace.Theme, - workspace.AutoSave, - workspace.ShowHiddenFiles, - workspace.GitEnabled, - workspace.GitURL, - workspace.GitUser, - encryptedToken, - workspace.GitAutoCommit, - workspace.GitCommitMsgTemplate, - workspace.GitCommitName, - workspace.GitCommitEmail, - workspace.ID, - workspace.UserID, - ) + query := NewQuery(db.dbType). + Update("workspaces"). + Set("name").Placeholder(workspace.Name). + Set("theme").Placeholder(workspace.Theme). + Set("auto_save").Placeholder(workspace.AutoSave). + Set("show_hidden_files").Placeholder(workspace.ShowHiddenFiles). + Set("git_enabled").Placeholder(workspace.GitEnabled). + Set("git_url").Placeholder(workspace.GitURL). + Set("git_user").Placeholder(workspace.GitUser). + Set("git_token").Placeholder(encryptedToken). + Set("git_auto_commit").Placeholder(workspace.GitAutoCommit). + Set("git_commit_msg_template").Placeholder(workspace.GitCommitMsgTemplate). + Set("git_commit_name").Placeholder(workspace.GitCommitName). + Set("git_commit_email").Placeholder(workspace.GitCommitEmail). + Where("id = ").Placeholder(workspace.ID). + And("user_id = ").Placeholder(workspace.UserID) + + _, err = db.Exec(query.String(), query.Args()...) if err != nil { return fmt.Errorf("failed to update workspace: %w", err) } @@ -175,17 +176,18 @@ func (db *database) UpdateWorkspace(workspace *models.Workspace) error { // GetWorkspacesByUserID retrieves all workspaces for a user func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) { - rows, err := db.Query(` - SELECT - id, user_id, name, created_at, - theme, auto_save, show_hidden_files, - git_enabled, git_url, git_user, git_token, - git_auto_commit, git_commit_msg_template, - git_commit_name, git_commit_email - FROM workspaces - WHERE user_id = ?`, - userID, - ) + query := NewQuery(db.dbType). + Select( + "id", "user_id", "name", "created_at", + "theme", "auto_save", "show_hidden_files", + "git_enabled", "git_url", "git_user", "git_token", + "git_auto_commit", "git_commit_msg_template", + "git_commit_name", "git_commit_email", + "last_opened_file_path"). + From("workspaces"). + Where("user_id = ").Placeholder(userID) + + rows, err := db.Query(query.String(), query.Args()...) if err != nil { return nil, fmt.Errorf("failed to query workspaces: %w", err) } @@ -195,17 +197,23 @@ func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, erro for rows.Next() { workspace := &models.Workspace{} var encryptedToken string + var lastOpenedFile sql.NullString err := rows.Scan( &workspace.ID, &workspace.UserID, &workspace.Name, &workspace.CreatedAt, &workspace.Theme, &workspace.AutoSave, &workspace.ShowHiddenFiles, &workspace.GitEnabled, &workspace.GitURL, &workspace.GitUser, &encryptedToken, &workspace.GitAutoCommit, &workspace.GitCommitMsgTemplate, &workspace.GitCommitName, &workspace.GitCommitEmail, + &lastOpenedFile, ) if err != nil { return nil, fmt.Errorf("failed to scan workspace row: %w", err) } + if lastOpenedFile.Valid { + workspace.LastOpenedFilePath = lastOpenedFile.String + } + // Decrypt token workspace.GitToken, err = db.decryptToken(encryptedToken) if err != nil { @@ -224,34 +232,28 @@ func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, erro // UpdateWorkspaceSettings updates only the settings portion of a workspace func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error { - _, err := db.Exec(` - UPDATE workspaces - SET - theme = ?, - auto_save = ?, - show_hidden_files = ?, - git_enabled = ?, - git_url = ?, - git_user = ?, - git_token = ?, - git_auto_commit = ?, - git_commit_msg_template = ?, - git_commit_name = ?, - git_commit_email = ? - WHERE id = ?`, - workspace.Theme, - workspace.AutoSave, - workspace.ShowHiddenFiles, - workspace.GitEnabled, - workspace.GitURL, - workspace.GitUser, - workspace.GitToken, - workspace.GitAutoCommit, - workspace.GitCommitMsgTemplate, - workspace.GitCommitName, - workspace.GitCommitEmail, - workspace.ID, - ) + // Encrypt token before storing + encryptedToken, err := db.encryptToken(workspace.GitToken) + if err != nil { + return fmt.Errorf("failed to encrypt token: %w", err) + } + + query := NewQuery(db.dbType). + Update("workspaces"). + Set("theme").Placeholder(workspace.Theme). + Set("auto_save").Placeholder(workspace.AutoSave). + Set("show_hidden_files").Placeholder(workspace.ShowHiddenFiles). + Set("git_enabled").Placeholder(workspace.GitEnabled). + Set("git_url").Placeholder(workspace.GitURL). + Set("git_user").Placeholder(workspace.GitUser). + Set("git_token").Placeholder(encryptedToken). + Set("git_auto_commit").Placeholder(workspace.GitAutoCommit). + Set("git_commit_msg_template").Placeholder(workspace.GitCommitMsgTemplate). + Set("git_commit_name").Placeholder(workspace.GitCommitName). + Set("git_commit_email").Placeholder(workspace.GitCommitEmail). + Where("id = ").Placeholder(workspace.ID) + + _, err = db.Exec(query.String(), query.Args()...) if err != nil { return fmt.Errorf("failed to update workspace settings: %w", err) } @@ -263,7 +265,12 @@ func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error { func (db *database) DeleteWorkspace(id int) error { log := getLogger().WithGroup("workspaces") - _, err := db.Exec("DELETE FROM workspaces WHERE id = ?", id) + query := NewQuery(db.dbType). + Delete(). + From("workspaces"). + Where("id = ").Placeholder(id) + + _, err := db.Exec(query.String(), query.Args()...) if err != nil { return fmt.Errorf("failed to delete workspace: %w", err) } @@ -275,7 +282,13 @@ func (db *database) DeleteWorkspace(id int) error { // DeleteWorkspaceTx removes a workspace record from the database within a transaction func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error { log := getLogger().WithGroup("workspaces") - result, err := tx.Exec("DELETE FROM workspaces WHERE id = ?", id) + + query := NewQuery(db.dbType). + Delete(). + From("workspaces"). + Where("id = ").Placeholder(id) + + result, err := tx.Exec(query.String(), query.Args()...) if err != nil { return fmt.Errorf("failed to delete workspace in transaction: %w", err) } @@ -285,15 +298,18 @@ func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error { return fmt.Errorf("failed to get rows affected in transaction: %w", err) } - log.Debug("workspace deleted", - "workspace_id", id) + log.Debug("workspace deleted", "workspace_id", id) return nil } // UpdateLastWorkspaceTx sets the last workspace for a user in a transaction func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error { - result, err := tx.Exec("UPDATE users SET last_workspace_id = ? WHERE id = ?", - workspaceID, userID) + query := NewQuery(db.dbType). + Update("users"). + Set("last_workspace_id").Placeholder(workspaceID). + Where("id = ").Placeholder(userID) + + result, err := tx.Exec(query.String(), query.Args()...) if err != nil { return fmt.Errorf("failed to update last workspace in transaction: %w", err) } @@ -308,8 +324,12 @@ func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) e // UpdateLastOpenedFile updates the last opened file path for a workspace func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error { - _, err := db.Exec("UPDATE workspaces SET last_opened_file_path = ? WHERE id = ?", - filePath, workspaceID) + query := NewQuery(db.dbType). + Update("workspaces"). + Set("last_opened_file_path").Placeholder(filePath). + Where("id = ").Placeholder(workspaceID) + + _, err := db.Exec(query.String(), query.Args()...) if err != nil { return fmt.Errorf("failed to update last opened file: %w", err) } @@ -319,9 +339,13 @@ func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error // GetLastOpenedFile retrieves the last opened file path for a workspace func (db *database) GetLastOpenedFile(workspaceID int) (string, error) { + query := NewQuery(db.dbType). + Select("last_opened_file_path"). + From("workspaces"). + Where("id = ").Placeholder(workspaceID) + var filePath sql.NullString - err := db.QueryRow("SELECT last_opened_file_path FROM workspaces WHERE id = ?", - workspaceID).Scan(&filePath) + err := db.QueryRow(query.String(), query.Args()...).Scan(&filePath) if err == sql.ErrNoRows { return "", fmt.Errorf("workspace not found") @@ -339,15 +363,17 @@ func (db *database) GetLastOpenedFile(workspaceID int) (string, error) { // GetAllWorkspaces retrieves all workspaces in the database func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) { - rows, err := db.Query(` - SELECT - id, user_id, name, created_at, - theme, auto_save, show_hidden_files, - git_enabled, git_url, git_user, git_token, - git_auto_commit, git_commit_msg_template, - git_commit_name, git_commit_email - FROM workspaces`, - ) + query := NewQuery(db.dbType). + Select( + "id", "user_id", "name", "created_at", + "theme", "auto_save", "show_hidden_files", + "git_enabled", "git_url", "git_user", "git_token", + "git_auto_commit", "git_commit_msg_template", + "git_commit_name", "git_commit_email", + "last_opened_file_path"). + From("workspaces") + + rows, err := db.Query(query.String(), query.Args()...) if err != nil { return nil, fmt.Errorf("failed to query workspaces: %w", err) } @@ -357,17 +383,24 @@ func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) { for rows.Next() { workspace := &models.Workspace{} var encryptedToken string + var lastOpenedFile sql.NullString + err := rows.Scan( &workspace.ID, &workspace.UserID, &workspace.Name, &workspace.CreatedAt, &workspace.Theme, &workspace.AutoSave, &workspace.ShowHiddenFiles, &workspace.GitEnabled, &workspace.GitURL, &workspace.GitUser, &encryptedToken, &workspace.GitAutoCommit, &workspace.GitCommitMsgTemplate, &workspace.GitCommitName, &workspace.GitCommitEmail, + &lastOpenedFile, ) if err != nil { return nil, fmt.Errorf("failed to scan workspace row: %w", err) } + if lastOpenedFile.Valid { + workspace.LastOpenedFilePath = lastOpenedFile.String + } + // Decrypt token workspace.GitToken, err = db.decryptToken(encryptedToken) if err != nil { From 96fc490c1d91bd934ba6802115a9fff38d71c54e Mon Sep 17 00:00:00 2001 From: LordMathis Date: Tue, 25 Feb 2025 22:29:06 +0100 Subject: [PATCH 13/39] Update migrations for postgres --- server/internal/db/migrations.go | 16 ++++- .../postgres/001_initial_schema.down.sql | 9 +++ .../postgres/001_initial_schema.up.sql | 61 +++++++++++++++++++ .../{ => sqlite}/001_initial_schema.down.sql | 1 + .../{ => sqlite}/001_initial_schema.up.sql | 3 +- server/internal/db/migrations_test.go | 2 +- 6 files changed, 88 insertions(+), 4 deletions(-) create mode 100644 server/internal/db/migrations/postgres/001_initial_schema.down.sql create mode 100644 server/internal/db/migrations/postgres/001_initial_schema.up.sql rename server/internal/db/migrations/{ => sqlite}/001_initial_schema.down.sql (86%) rename server/internal/db/migrations/{ => sqlite}/001_initial_schema.up.sql (97%) diff --git a/server/internal/db/migrations.go b/server/internal/db/migrations.go index efa3769..a01f30b 100644 --- a/server/internal/db/migrations.go +++ b/server/internal/db/migrations.go @@ -10,7 +10,7 @@ import ( "github.com/golang-migrate/migrate/v4/source/iofs" ) -//go:embed migrations/*.sql +//go:embed migrations/sqlite/*.sql migrations/postgres/*.sql var migrationsFS embed.FS // Migrate applies all database migrations @@ -18,7 +18,19 @@ func (db *database) Migrate() error { log := getLogger().WithGroup("migrations") log.Info("starting database migration") - sourceInstance, err := iofs.New(migrationsFS, "migrations") + var migrationPath string + switch db.dbType { + case DBTypePostgres: + migrationPath = "migrations/postgres" + case DBTypeSQLite: + migrationPath = "migrations/sqlite" + default: + return fmt.Errorf("unsupported database driver: %s", db.dbType) + } + + log.Debug("using migration path", "path", migrationPath) + + sourceInstance, err := iofs.New(migrationsFS, migrationPath) if err != nil { return fmt.Errorf("failed to create source instance: %w", err) } diff --git a/server/internal/db/migrations/postgres/001_initial_schema.down.sql b/server/internal/db/migrations/postgres/001_initial_schema.down.sql new file mode 100644 index 0000000..e95e055 --- /dev/null +++ b/server/internal/db/migrations/postgres/001_initial_schema.down.sql @@ -0,0 +1,9 @@ +-- 001_initial_schema.down.sql (PostgreSQL version) +DROP INDEX IF EXISTS idx_sessions_refresh_token; +DROP INDEX IF EXISTS idx_sessions_expires_at; +DROP INDEX IF EXISTS idx_sessions_user_id; +DROP INDEX IF EXISTS idx_workspaces_user_id; +DROP TABLE IF EXISTS sessions; +DROP TABLE IF EXISTS workspaces; +DROP TABLE IF EXISTS system_settings; +DROP TABLE IF EXISTS users; \ No newline at end of file diff --git a/server/internal/db/migrations/postgres/001_initial_schema.up.sql b/server/internal/db/migrations/postgres/001_initial_schema.up.sql new file mode 100644 index 0000000..288aec0 --- /dev/null +++ b/server/internal/db/migrations/postgres/001_initial_schema.up.sql @@ -0,0 +1,61 @@ +-- 001_initial_schema.up.sql (PostgreSQL version) + +-- Create users table +CREATE TABLE IF NOT EXISTS users ( + id SERIAL PRIMARY KEY, + email TEXT NOT NULL UNIQUE, + display_name TEXT, + password_hash TEXT NOT NULL, + role TEXT NOT NULL CHECK(role IN ('admin', 'editor', 'viewer')), + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + last_workspace_id INTEGER +); + +-- Create workspaces table with integrated settings +CREATE TABLE IF NOT EXISTS workspaces ( + id SERIAL PRIMARY KEY, + user_id INTEGER NOT NULL, + name TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + last_opened_file_path TEXT, + -- Settings fields + theme TEXT NOT NULL DEFAULT 'light' CHECK(theme IN ('light', 'dark')), + auto_save BOOLEAN NOT NULL DEFAULT FALSE, + git_enabled BOOLEAN NOT NULL DEFAULT FALSE, + git_url TEXT, + git_user TEXT, + git_token TEXT, + git_auto_commit BOOLEAN NOT NULL DEFAULT FALSE, + git_commit_msg_template TEXT DEFAULT '${action} ${filename}', + git_commit_name TEXT, + git_commit_email TEXT, + show_hidden_files BOOLEAN NOT NULL DEFAULT FALSE, + created_by INTEGER REFERENCES users(id), + updated_by INTEGER REFERENCES users(id), + updated_at TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE +); + +-- Create sessions table for authentication +CREATE TABLE IF NOT EXISTS sessions ( + id TEXT PRIMARY KEY, + user_id INTEGER NOT NULL, + refresh_token TEXT NOT NULL, + expires_at TIMESTAMP NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES users (id) ON DELETE CASCADE +); + +-- Create system_settings table for application settings +CREATE TABLE IF NOT EXISTS system_settings ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +-- Create indexes for performance +CREATE INDEX idx_sessions_user_id ON sessions(user_id); +CREATE INDEX idx_sessions_expires_at ON sessions(expires_at); +CREATE INDEX idx_sessions_refresh_token ON sessions(refresh_token); +CREATE INDEX idx_workspaces_user_id ON workspaces(user_id); \ No newline at end of file diff --git a/server/internal/db/migrations/001_initial_schema.down.sql b/server/internal/db/migrations/sqlite/001_initial_schema.down.sql similarity index 86% rename from server/internal/db/migrations/001_initial_schema.down.sql rename to server/internal/db/migrations/sqlite/001_initial_schema.down.sql index f32272a..ba0f7de 100644 --- a/server/internal/db/migrations/001_initial_schema.down.sql +++ b/server/internal/db/migrations/sqlite/001_initial_schema.down.sql @@ -2,6 +2,7 @@ DROP INDEX IF EXISTS idx_sessions_refresh_token; DROP INDEX IF EXISTS idx_sessions_expires_at; DROP INDEX IF EXISTS idx_sessions_user_id; +DROP INDEX IF EXISTS idx_workspaces_user_id; DROP TABLE IF EXISTS sessions; DROP TABLE IF EXISTS workspaces; DROP TABLE IF EXISTS system_settings; diff --git a/server/internal/db/migrations/001_initial_schema.up.sql b/server/internal/db/migrations/sqlite/001_initial_schema.up.sql similarity index 97% rename from server/internal/db/migrations/001_initial_schema.up.sql rename to server/internal/db/migrations/sqlite/001_initial_schema.up.sql index 8c13e9b..b632442 100644 --- a/server/internal/db/migrations/001_initial_schema.up.sql +++ b/server/internal/db/migrations/sqlite/001_initial_schema.up.sql @@ -56,4 +56,5 @@ CREATE TABLE IF NOT EXISTS system_settings ( -- Create indexes for performance CREATE INDEX idx_sessions_user_id ON sessions(user_id); CREATE INDEX idx_sessions_expires_at ON sessions(expires_at); -CREATE INDEX idx_sessions_refresh_token ON sessions(refresh_token); \ No newline at end of file +CREATE INDEX idx_sessions_refresh_token ON sessions(refresh_token); +CREATE INDEX idx_workspaces_user_id ON workspaces(user_id); \ No newline at end of file diff --git a/server/internal/db/migrations_test.go b/server/internal/db/migrations_test.go index ecce24c..80c9661 100644 --- a/server/internal/db/migrations_test.go +++ b/server/internal/db/migrations_test.go @@ -27,7 +27,6 @@ func TestMigrate(t *testing.T) { "workspaces", "sessions", "system_settings", - // Note: golang-migrate uses its own migrations table "schema_migrations", } @@ -45,6 +44,7 @@ func TestMigrate(t *testing.T) { {"sessions", "idx_sessions_user_id"}, {"sessions", "idx_sessions_expires_at"}, {"sessions", "idx_sessions_refresh_token"}, + {"workspaces", "idx_workspaces_user_id"}, } for _, idx := range indexes { if !indexExists(t, database, idx.table, idx.name) { From a80b48956a3746815647dcade550f1b88ca142f4 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Tue, 25 Feb 2025 22:36:22 +0100 Subject: [PATCH 14/39] Add returning tests --- server/internal/db/query_test.go | 152 +++++++++++++++++++++++++++++++ 1 file changed, 152 insertions(+) diff --git a/server/internal/db/query_test.go b/server/internal/db/query_test.go index 6c29924..4910328 100644 --- a/server/internal/db/query_test.go +++ b/server/internal/db/query_test.go @@ -658,6 +658,158 @@ func TestHavingClause(t *testing.T) { } } +func TestQueryReturning(t *testing.T) { + testCases := []struct { + name string + dbType db.DBType + buildQuery func(q *db.Query) *db.Query + expectedSQL string + expectedArgs []any + }{ + { + name: "SQLite INSERT with RETURNING single column", + dbType: db.DBTypeSQLite, + buildQuery: func(q *db.Query) *db.Query { + return q.Insert("users", "name", "email"). + Values(2). + AddArgs("John Doe", "john@example.com"). + Returning("id") + }, + expectedSQL: "INSERT INTO users (name, email) VALUES (?, ?) RETURNING id", + expectedArgs: []any{"John Doe", "john@example.com"}, + }, + { + name: "PostgreSQL INSERT with RETURNING single column", + dbType: db.DBTypePostgres, + buildQuery: func(q *db.Query) *db.Query { + return q.Insert("users", "name", "email"). + Values(2). + AddArgs("John Doe", "john@example.com"). + Returning("id") + }, + expectedSQL: "INSERT INTO users (name, email) VALUES ($1, $2) RETURNING id", + expectedArgs: []any{"John Doe", "john@example.com"}, + }, + { + name: "SQLite INSERT with RETURNING multiple columns", + dbType: db.DBTypeSQLite, + buildQuery: func(q *db.Query) *db.Query { + return q.Insert("users", "name", "email"). + Values(2). + AddArgs("John Doe", "john@example.com"). + Returning("id", "created_at") + }, + expectedSQL: "INSERT INTO users (name, email) VALUES (?, ?) RETURNING id, created_at", + expectedArgs: []any{"John Doe", "john@example.com"}, + }, + { + name: "PostgreSQL INSERT with RETURNING multiple columns", + dbType: db.DBTypePostgres, + buildQuery: func(q *db.Query) *db.Query { + return q.Insert("users", "name", "email"). + Values(2). + AddArgs("John Doe", "john@example.com"). + Returning("id", "created_at") + }, + expectedSQL: "INSERT INTO users (name, email) VALUES ($1, $2) RETURNING id, created_at", + expectedArgs: []any{"John Doe", "john@example.com"}, + }, + { + name: "SQLite UPDATE with RETURNING", + dbType: db.DBTypeSQLite, + buildQuery: func(q *db.Query) *db.Query { + return q.Update("users"). + Set("name").Placeholder("Jane Doe"). + Where("id = ").Placeholder(1). + Returning("name", "updated_at") + }, + expectedSQL: "UPDATE users SET name = ? WHERE id = ? RETURNING name, updated_at", + expectedArgs: []any{"Jane Doe", 1}, + }, + { + name: "PostgreSQL UPDATE with RETURNING", + dbType: db.DBTypePostgres, + buildQuery: func(q *db.Query) *db.Query { + return q.Update("users"). + Set("name").Placeholder("Jane Doe"). + Where("id = ").Placeholder(1). + Returning("name", "updated_at") + }, + expectedSQL: "UPDATE users SET name = $1 WHERE id = $2 RETURNING name, updated_at", + expectedArgs: []any{"Jane Doe", 1}, + }, + { + name: "SQLite DELETE with RETURNING", + dbType: db.DBTypeSQLite, + buildQuery: func(q *db.Query) *db.Query { + return q.Delete(). + From("users"). + Where("id = ").Placeholder(1). + Returning("id", "name", "email") + }, + expectedSQL: "DELETE FROM users WHERE id = ? RETURNING id, name, email", + expectedArgs: []any{1}, + }, + { + name: "PostgreSQL DELETE with RETURNING", + dbType: db.DBTypePostgres, + buildQuery: func(q *db.Query) *db.Query { + return q.Delete(). + From("users"). + Where("id = ").Placeholder(1). + Returning("id", "name", "email") + }, + expectedSQL: "DELETE FROM users WHERE id = $1 RETURNING id, name, email", + expectedArgs: []any{1}, + }, + { + name: "SQLite INSERT with RETURNING *", + dbType: db.DBTypeSQLite, + buildQuery: func(q *db.Query) *db.Query { + return q.Insert("users", "name", "email"). + Values(2). + AddArgs("John Doe", "john@example.com"). + Returning("*") + }, + expectedSQL: "INSERT INTO users (name, email) VALUES (?, ?) RETURNING *", + expectedArgs: []any{"John Doe", "john@example.com"}, + }, + { + name: "PostgreSQL INSERT with RETURNING *", + dbType: db.DBTypePostgres, + buildQuery: func(q *db.Query) *db.Query { + return q.Insert("users", "name", "email"). + Values(2). + AddArgs("John Doe", "john@example.com"). + Returning("*") + }, + expectedSQL: "INSERT INTO users (name, email) VALUES ($1, $2) RETURNING *", + expectedArgs: []any{"John Doe", "john@example.com"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + query := db.NewQuery(tc.dbType) + result := tc.buildQuery(query) + + if result.String() != tc.expectedSQL { + t.Errorf("Expected SQL: %s, got: %s", tc.expectedSQL, result.String()) + } + + if len(result.Args()) != len(tc.expectedArgs) { + t.Errorf("Expected %d args, got %d", len(tc.expectedArgs), len(result.Args())) + } + + for i, arg := range result.Args() { + if arg != tc.expectedArgs[i] { + t.Errorf("Expected arg %d to be %v, got %v", i, tc.expectedArgs[i], arg) + } + } + }) + } +} + func TestComplexQueries(t *testing.T) { tests := []struct { name string From c0de3538dcb7ebfeb6dd617ec4743e5b68ee1dd1 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Thu, 27 Feb 2025 21:16:43 +0100 Subject: [PATCH 15/39] Implement insert struct --- server/internal/db/scanner.go | 288 ---------------------------- server/internal/db/scanner_test.go | 298 ----------------------------- server/internal/db/struct_query.go | 99 ++++++++++ 3 files changed, 99 insertions(+), 586 deletions(-) delete mode 100644 server/internal/db/scanner.go delete mode 100644 server/internal/db/scanner_test.go create mode 100644 server/internal/db/struct_query.go diff --git a/server/internal/db/scanner.go b/server/internal/db/scanner.go deleted file mode 100644 index f50c296..0000000 --- a/server/internal/db/scanner.go +++ /dev/null @@ -1,288 +0,0 @@ -package db - -import ( - "database/sql" - "fmt" - "reflect" - "regexp" - "strings" -) - -// Scanner provides methods for scanning rows into structs -type Scanner struct { - db *sql.DB - dbType DBType -} - -// NewScanner creates a new Scanner instance -func NewScanner(db *sql.DB, dbType DBType) *Scanner { - return &Scanner{ - db: db, - dbType: dbType, - } -} - -// QueryRow executes a query and scans the result into a struct -func (s *Scanner) QueryRow(dest any, q *Query) error { - row := s.db.QueryRow(q.String(), q.Args()...) - - // Handle primitive types - v := reflect.ValueOf(dest) - if v.Kind() != reflect.Ptr { - return fmt.Errorf("dest must be a pointer") - } - - elem := v.Elem() - switch elem.Kind() { - case reflect.Struct: - return scanStruct(row, dest) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, - reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, - reflect.Float32, reflect.Float64, reflect.Bool, reflect.String: - return row.Scan(dest) - default: - return fmt.Errorf("unsupported dest type: %T", dest) - } -} - -// Query executes a query and scans multiple results into a slice of structs -func (s *Scanner) Query(dest any, q *Query) error { - rows, err := s.db.Query(q.String(), q.Args()...) - if err != nil { - return err - } - defer rows.Close() - - return scanStructs(rows, dest) -} - -// scanStruct scans a single row into a struct -func scanStruct(row *sql.Row, dest any) error { - v := reflect.ValueOf(dest) - if v.Kind() != reflect.Ptr { - return fmt.Errorf("dest must be a pointer") - } - v = v.Elem() - if v.Kind() != reflect.Struct { - return fmt.Errorf("dest must be a pointer to a struct") - } - - fields := make([]any, 0, v.NumField()) - - for i := 0; i < v.NumField(); i++ { - field := v.Field(i) - if field.CanSet() { - fields = append(fields, field.Addr().Interface()) - } - } - - return row.Scan(fields...) -} - -// scanStructs scans multiple rows into a slice of structs -func scanStructs(rows *sql.Rows, dest any) error { - v := reflect.ValueOf(dest) - if v.Kind() != reflect.Ptr { - return fmt.Errorf("dest must be a pointer") - } - - sliceVal := v.Elem() - if sliceVal.Kind() != reflect.Slice { - return fmt.Errorf("dest must be a pointer to a slice") - } - - elemType := sliceVal.Type().Elem() - - for rows.Next() { - newElem := reflect.New(elemType).Elem() - fields := make([]any, 0, newElem.NumField()) - - for i := 0; i < newElem.NumField(); i++ { - field := newElem.Field(i) - if field.CanSet() { - fields = append(fields, field.Addr().Interface()) - } - } - - if err := rows.Scan(fields...); err != nil { - return err - } - - sliceVal.Set(reflect.Append(sliceVal, newElem)) - } - - return rows.Err() -} - -// ScannerEx is an extended version of Scanner with more features -type ScannerEx struct { - db *sql.DB - dbType DBType -} - -// NewScannerEx creates a new ScannerEx instance -func NewScannerEx(db *sql.DB, dbType DBType) *ScannerEx { - return &ScannerEx{ - db: db, - dbType: dbType, - } -} - -// QueryRow executes a query and scans the result into a struct -func (s *ScannerEx) QueryRow(dest any, q *Query) error { - row := s.db.QueryRow(q.String(), q.Args()...) - - // Get column names - // Note: This is a workaround since sql.Row doesn't expose column names. - // In a real implementation, you'd likely need to execute the query to get columns first. - // For simplicity, we'll infer them from the struct tags. - - return scanStructTags(row, dest) -} - -// Query executes a query and scans multiple results into a slice of structs -func (s *ScannerEx) Query(dest any, q *Query) error { - rows, err := s.db.Query(q.String(), q.Args()...) - if err != nil { - return err - } - defer rows.Close() - - return scanStructsTags(rows, dest) -} - -// getFieldMap builds a map of db column names to struct fields using struct tags -func getFieldMap(t reflect.Type) map[string]int { - fieldMap := make(map[string]int) - - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - - // Check db tag first - tag := field.Tag.Get("db") - if tag != "" && tag != "-" { - fieldMap[tag] = i - continue - } - - // Check json tag next - tag = field.Tag.Get("json") - if tag != "" && tag != "-" { - // Handle json tag options like omitempty - parts := strings.Split(tag, ",") - fieldMap[parts[0]] = i - continue - } - - // Default to field name with snake_case conversion - fieldMap[toSnakeCase(field.Name)] = i - } - - return fieldMap -} - -var camelRegex = regexp.MustCompile(`([a-z0-9])([A-Z])`) - -// toSnakeCase converts a camelCase string to snake_case -func toSnakeCase(s string) string { - return strings.ToLower(camelRegex.ReplaceAllString(s, "${1}_${2}")) -} - -// scanStructTags scans a single row into a struct using field tags -func scanStructTags(row *sql.Row, dest any) error { - v := reflect.ValueOf(dest) - if v.Kind() != reflect.Ptr { - return fmt.Errorf("dest must be a pointer") - } - v = v.Elem() - if v.Kind() != reflect.Struct { - return fmt.Errorf("dest must be a pointer to a struct") - } - - fields := make([]any, 0, v.NumField()) - - for i := 0; i < v.NumField(); i++ { - field := v.Field(i) - if field.CanSet() { - fields = append(fields, field.Addr().Interface()) - } - } - - return row.Scan(fields...) -} - -// scanStructsTags scans multiple rows into a slice of structs using field tags -func scanStructsTags(rows *sql.Rows, dest any) error { - v := reflect.ValueOf(dest) - if v.Kind() != reflect.Ptr { - return fmt.Errorf("dest must be a pointer") - } - - sliceVal := v.Elem() - if sliceVal.Kind() != reflect.Slice { - return fmt.Errorf("dest must be a pointer to a slice") - } - - elemType := sliceVal.Type().Elem() - isPtr := elemType.Kind() == reflect.Ptr - if isPtr { - elemType = elemType.Elem() - } - - if elemType.Kind() != reflect.Struct { - return fmt.Errorf("dest must be a pointer to a slice of structs") - } - - // Get column names - columns, err := rows.Columns() - if err != nil { - return err - } - - // Build field mapping - fieldMap := getFieldMap(elemType) - - // Prepare values slice for each scan - values := make([]any, len(columns)) - scanFields := make([]any, len(columns)) - for i := range values { - scanFields[i] = &values[i] - } - - for rows.Next() { - // Create a new struct instance - newElem := reflect.New(elemType).Elem() - - // Scan row into values - if err := rows.Scan(scanFields...); err != nil { - return err - } - - // Map values to struct fields - for i, colName := range columns { - if fieldIndex, ok := fieldMap[colName]; ok { - field := newElem.Field(fieldIndex) - if field.CanSet() { - val := reflect.ValueOf(values[i]) - if val.Elem().Kind() == reflect.Interface { - val = val.Elem() - } - if val.Kind() == reflect.Ptr && !val.IsNil() { - field.Set(val.Elem()) - } else if !val.IsNil() { - field.Set(val) - } - } - } - } - - // Append to result slice - if isPtr { - sliceVal.Set(reflect.Append(sliceVal, newElem.Addr())) - } else { - sliceVal.Set(reflect.Append(sliceVal, newElem)) - } - } - - return rows.Err() -} diff --git a/server/internal/db/scanner_test.go b/server/internal/db/scanner_test.go deleted file mode 100644 index a1fbc28..0000000 --- a/server/internal/db/scanner_test.go +++ /dev/null @@ -1,298 +0,0 @@ -package db_test - -import ( - "database/sql" - "testing" - "time" - - "lemma/internal/db" -) - -func TestScannerQueryRow(t *testing.T) { - mockSecrets := &mockSecretsService{} - testDB, err := db.NewTestDB(mockSecrets) - if err != nil { - t.Fatalf("Failed to create test database: %v", err) - } - defer testDB.Close() - - // Create a test table - _, err = testDB.TestDB().Exec(` - CREATE TABLE users ( - id INTEGER PRIMARY KEY, - email TEXT NOT NULL, - created_at TIMESTAMP NOT NULL, - active BOOLEAN NOT NULL - ) - `) - if err != nil { - t.Fatalf("Failed to create test table: %v", err) - } - - type User struct { - ID int - Email string - CreatedAt time.Time - Active bool - } - - // Insert test data - now := time.Now().UTC().Truncate(time.Second) - _, err = testDB.TestDB().Exec( - "INSERT INTO users (id, email, created_at, active) VALUES (?, ?, ?, ?)", - 1, "test@example.com", now, true, - ) - if err != nil { - t.Fatalf("Failed to insert test data: %v", err) - } - - // Test query row success - t.Run("QueryRow success", func(t *testing.T) { - scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) - q := db.NewQuery(db.DBTypeSQLite) - q.Select("id", "email", "created_at", "active"). - From("users"). - Where("id = "). - Placeholder(1) - - var user User - err := scanner.QueryRow(&user, q) - - if err != nil { - t.Errorf("Expected no error, got %v", err) - } - - if user.ID != 1 { - t.Errorf("Expected ID 1, got %d", user.ID) - } - if user.Email != "test@example.com" { - t.Errorf("Expected Email test@example.com, got %s", user.Email) - } - if !user.CreatedAt.Equal(now) { - t.Errorf("Expected CreatedAt %v, got %v", now, user.CreatedAt) - } - if !user.Active { - t.Errorf("Expected Active true, got %v", user.Active) - } - }) - - // Test query row no results - t.Run("QueryRow no results", func(t *testing.T) { - scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) - q := db.NewQuery(db.DBTypeSQLite) - q.Select("id", "email", "created_at", "active"). - From("users"). - Where("id = "). - Placeholder(999) - - var user User - err := scanner.QueryRow(&user, q) - - if err != sql.ErrNoRows { - t.Errorf("Expected ErrNoRows, got %v", err) - } - }) - - // Test scanning a single value - t.Run("QueryRow single value", func(t *testing.T) { - scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) - q := db.NewQuery(db.DBTypeSQLite) - q.Select("COUNT(*)").From("users") - - var count int - err := scanner.QueryRow(&count, q) - - if err != nil { - t.Errorf("Expected no error, got %v", err) - } - - if count != 1 { - t.Errorf("Expected count 1, got %d", count) - } - }) -} - -func TestScannerQuery(t *testing.T) { - mockSecrets := &mockSecretsService{} - testDB, err := db.NewTestDB(mockSecrets) - if err != nil { - t.Fatalf("Failed to create test database: %v", err) - } - defer testDB.Close() - - // Create a test table - _, err = testDB.TestDB().Exec(` - CREATE TABLE users ( - id INTEGER PRIMARY KEY, - email TEXT NOT NULL, - created_at TIMESTAMP NOT NULL, - active BOOLEAN NOT NULL - ) - `) - if err != nil { - t.Fatalf("Failed to create test table: %v", err) - } - - type User struct { - ID int - Email string - CreatedAt time.Time - Active bool - } - - // Insert test data - now := time.Now().UTC().Truncate(time.Second) - testUsers := []User{ - {ID: 1, Email: "user1@example.com", CreatedAt: now, Active: true}, - {ID: 2, Email: "user2@example.com", CreatedAt: now, Active: false}, - {ID: 3, Email: "user3@example.com", CreatedAt: now, Active: true}, - } - - for _, user := range testUsers { - _, err = testDB.TestDB().Exec( - "INSERT INTO users (id, email, created_at, active) VALUES (?, ?, ?, ?)", - user.ID, user.Email, user.CreatedAt, user.Active, - ) - if err != nil { - t.Fatalf("Failed to insert test data: %v", err) - } - } - - // Test query multiple rows - t.Run("Query multiple rows", func(t *testing.T) { - scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) - q := db.NewQuery(db.DBTypeSQLite) - q.Select("id", "email", "created_at", "active"). - From("users"). - OrderBy("id ASC") - - var users []User - err := scanner.Query(&users, q) - - if err != nil { - t.Errorf("Expected no error, got %v", err) - } - - if len(users) != len(testUsers) { - t.Errorf("Expected %d users, got %d", len(testUsers), len(users)) - } - - for i, u := range users { - if u.ID != testUsers[i].ID { - t.Errorf("Expected user[%d].ID %d, got %d", i, testUsers[i].ID, u.ID) - } - if u.Email != testUsers[i].Email { - t.Errorf("Expected user[%d].Email %s, got %s", i, testUsers[i].Email, u.Email) - } - if !u.CreatedAt.Equal(testUsers[i].CreatedAt) { - t.Errorf("Expected user[%d].CreatedAt %v, got %v", i, testUsers[i].CreatedAt, u.CreatedAt) - } - if u.Active != testUsers[i].Active { - t.Errorf("Expected user[%d].Active %v, got %v", i, testUsers[i].Active, u.Active) - } - } - }) - - // Test query with filter - t.Run("Query with filter", func(t *testing.T) { - scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) - q := db.NewQuery(db.DBTypeSQLite) - q.Select("id", "email", "created_at", "active"). - From("users"). - Where("active = "). - Placeholder(true). - OrderBy("id ASC") - - var users []User - err := scanner.Query(&users, q) - - if err != nil { - t.Errorf("Expected no error, got %v", err) - } - - if len(users) != 2 { - t.Errorf("Expected 2 users, got %d", len(users)) - } - - for _, u := range users { - if !u.Active { - t.Errorf("Expected only active users, got inactive user: %+v", u) - } - } - }) - - // Test query empty result - t.Run("Query empty result", func(t *testing.T) { - scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) - q := db.NewQuery(db.DBTypeSQLite) - q.Select("id", "email", "created_at", "active"). - From("users"). - Where("id > "). - Placeholder(100) - - var users []User - err := scanner.Query(&users, q) - - if err != nil { - t.Errorf("Expected no error, got %v", err) - } - - if len(users) != 0 { - t.Errorf("Expected 0 users, got %d", len(users)) - } - }) -} - -func TestScanErrors(t *testing.T) { - mockSecrets := &mockSecretsService{} - testDB, err := db.NewTestDB(mockSecrets) - if err != nil { - t.Fatalf("Failed to create test database: %v", err) - } - defer testDB.Close() - - scanner := db.NewScanner(testDB.TestDB(), db.DBTypeSQLite) - q := db.NewQuery(db.DBTypeSQLite) - q.Select("1") - - // Test non-pointer - t.Run("QueryRow non-pointer", func(t *testing.T) { - var user struct{} - err := scanner.QueryRow(user, q) // Passing non-pointer - - if err == nil { - t.Error("Expected error for non-pointer, got nil") - } - }) - - // Test pointer to non-slice for Query - t.Run("Query pointer to non-slice", func(t *testing.T) { - var user struct{} - err := scanner.Query(&user, q) // Passing pointer to struct, not slice - - if err == nil { - t.Error("Expected error for non-slice pointer, got nil") - } - }) - - // Test non-pointer for Query - t.Run("Query non-pointer", func(t *testing.T) { - var users []struct{} - err := scanner.Query(users, q) // Passing non-pointer - - if err == nil { - t.Error("Expected error for non-pointer, got nil") - } - }) -} - -// Mock secrets service for testing -type mockSecretsService struct{} - -func (m *mockSecretsService) Encrypt(plaintext string) (string, error) { - return plaintext, nil -} - -func (m *mockSecretsService) Decrypt(ciphertext string) (string, error) { - return ciphertext, nil -} diff --git a/server/internal/db/struct_query.go b/server/internal/db/struct_query.go new file mode 100644 index 0000000..b465b52 --- /dev/null +++ b/server/internal/db/struct_query.go @@ -0,0 +1,99 @@ +package db + +import ( + "fmt" + "reflect" + "strings" + "unicode" +) + +type DBField struct { + Name string + Value any + Type reflect.Type +} + +func StructTagsToFields(s any) ([]DBField, error) { + v := reflect.ValueOf(s) + + if v.Kind() == reflect.Ptr { + if v.IsNil() { + return nil, fmt.Errorf("nil pointer provided") + } + + v = v.Elem() + } + + if v.Kind() != reflect.Struct { + return nil, fmt.Errorf("provided value is %s, expected struct", v.Kind()) + } + + t := v.Type() + fields := make([]DBField, 0, t.NumField()) + + for i := range t.NumField() { + f := t.Field(i) + + if !f.IsExported() { + continue + } + + tag := f.Tag.Get("db") + if tag == "-" { + continue + } + + if tag == "" { + tag = toSnakeCase(f.Name) + } + + if strings.Contains(tag, "omitempty") && reflect.DeepEqual(v.Field(i).Interface(), reflect.Zero(f.Type).Interface()) { + continue + } + + fields = append(fields, DBField{ + Name: tag, + Value: v.Field(i).Interface(), + Type: f.Type, + }) + } + return fields, nil +} + +func toSnakeCase(s string) string { + var res string + + for i, r := range s { + if unicode.IsUpper(r) { + if i > 0 { + res += "_" + } + res += string(unicode.ToLower(r)) + } else { + res += string(r) + } + } + return res +} + +func (q *Query) InsertStruct(s any, table string) (*Query, error) { + fields, err := StructTagsToFields(s) + if err != nil { + return nil, err + } + + columns := make([]string, 0, len(fields)) + values := make([]any, 0, len(fields)) + + for _, f := range fields { + columns = append(columns, f.Name) + values = append(values, f.Value) + } + + if len(columns) == 0 { + return nil, fmt.Errorf("no columns to insert") + } + + q.Insert(table, columns...).Values(len(columns)).AddArgs(values...) + return q, nil +} From e89b4a0e146049714c050bed4c9605f17a42d195 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Thu, 27 Feb 2025 21:44:32 +0100 Subject: [PATCH 16/39] Use InsertStruct --- server/internal/db/sessions.go | 11 +++++---- server/internal/db/struct_query.go | 35 ++++++++++++++++++++------ server/internal/db/users.go | 38 +++++++++++++---------------- server/internal/db/workspaces.go | 21 +++++++--------- server/internal/models/session.go | 10 ++++---- server/internal/models/user.go | 14 +++++------ server/internal/models/workspace.go | 32 ++++++++++++------------ 7 files changed, 87 insertions(+), 74 deletions(-) diff --git a/server/internal/db/sessions.go b/server/internal/db/sessions.go index 7df51c6..66c148b 100644 --- a/server/internal/db/sessions.go +++ b/server/internal/db/sessions.go @@ -10,11 +10,12 @@ import ( // CreateSession inserts a new session record into the database func (db *database) CreateSession(session *models.Session) error { - query := NewQuery(db.dbType). - Insert("sessions", "id", "user_id", "refresh_token", "expires_at", "created_at"). - Values(5). - AddArgs(session.ID, session.UserID, session.RefreshToken, session.ExpiresAt, session.CreatedAt) - _, err := db.Exec(query.String(), query.Args()...) + query, err := NewQuery(db.dbType). + InsertStruct(session, "sessions") + if err != nil { + return fmt.Errorf("failed to create query: %w", err) + } + _, err = db.Exec(query.String(), query.Args()...) if err != nil { return fmt.Errorf("failed to store session: %w", err) } diff --git a/server/internal/db/struct_query.go b/server/internal/db/struct_query.go index b465b52..c18b6d5 100644 --- a/server/internal/db/struct_query.go +++ b/server/internal/db/struct_query.go @@ -8,9 +8,10 @@ import ( ) type DBField struct { - Name string - Value any - Type reflect.Type + Name string + Value any + Type reflect.Type + useDefault bool } func StructTagsToFields(s any) ([]DBField, error) { @@ -47,14 +48,28 @@ func StructTagsToFields(s any) ([]DBField, error) { tag = toSnakeCase(f.Name) } - if strings.Contains(tag, "omitempty") && reflect.DeepEqual(v.Field(i).Interface(), reflect.Zero(f.Type).Interface()) { - continue + useDefault := false + if strings.Contains(tag, ",") { + parts := strings.Split(tag, ",") + tag = parts[0] + + for _, opt := range parts[1:] { + switch opt { + case "omitempty": + if reflect.DeepEqual(v.Field(i).Interface(), reflect.Zero(f.Type).Interface()) { + continue + } + case "default": + useDefault = true + } + } } fields = append(fields, DBField{ - Name: tag, - Value: v.Field(i).Interface(), - Type: f.Type, + Name: tag, + Value: v.Field(i).Interface(), + Type: f.Type, + useDefault: useDefault, }) } return fields, nil @@ -86,6 +101,10 @@ func (q *Query) InsertStruct(s any, table string) (*Query, error) { values := make([]any, 0, len(fields)) for _, f := range fields { + if f.useDefault { + continue + } + columns = append(columns, f.Name) values = append(values, f.Value) } diff --git a/server/internal/db/users.go b/server/internal/db/users.go index 5260a19..b326949 100644 --- a/server/internal/db/users.go +++ b/server/internal/db/users.go @@ -17,11 +17,14 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) { } defer tx.Rollback() - query := NewQuery(db.dbType). - Insert("users", "email", "display_name", "password_hash", "role"). - Values(4). - AddArgs(user.Email, user.DisplayName, user.PasswordHash, user.Role). - Returning("id", "created_at") + query, err := NewQuery(db.dbType). + InsertStruct(user, "users") + + if err != nil { + return nil, fmt.Errorf("failed to create query: %w", err) + } + + query.Returning("id", "created_at") err = tx.QueryRow(query.String(), query.Args()...). Scan(&user.ID, &user.CreatedAt) @@ -69,23 +72,16 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) { func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error { log := getLogger().WithGroup("users") - insertQuery := NewQuery(db.dbType). - Insert("workspaces", - "user_id", "name", - "theme", "auto_save", "show_hidden_files", - "git_enabled", "git_url", "git_user", "git_token", - "git_auto_commit", "git_commit_msg_template", - "git_commit_name", "git_commit_email"). - Values(13). - AddArgs( - workspace.UserID, workspace.Name, - workspace.Theme, workspace.AutoSave, workspace.ShowHiddenFiles, - workspace.GitEnabled, workspace.GitURL, workspace.GitUser, workspace.GitToken, - workspace.GitAutoCommit, workspace.GitCommitMsgTemplate, - workspace.GitCommitName, workspace.GitCommitEmail). - Returning("id") + insertQuery, err := NewQuery(db.dbType). + InsertStruct(workspace, "workspaces") - err := tx.QueryRow(insertQuery.String(), insertQuery.Args()...).Scan(&workspace.ID) + if err != nil { + return fmt.Errorf("failed to create query: %w", err) + } + + insertQuery.Returning("id") + + err = tx.QueryRow(insertQuery.String(), insertQuery.Args()...).Scan(&workspace.ID) if err != nil { return fmt.Errorf("failed to insert workspace: %w", err) } diff --git a/server/internal/db/workspaces.go b/server/internal/db/workspaces.go index 06bff0c..dc07911 100644 --- a/server/internal/db/workspaces.go +++ b/server/internal/db/workspaces.go @@ -24,19 +24,16 @@ func (db *database) CreateWorkspace(workspace *models.Workspace) error { if err != nil { return fmt.Errorf("failed to encrypt token: %w", err) } + workspace.GitToken = encryptedToken - query := NewQuery(db.dbType). - Insert("workspaces", - "user_id", "name", "theme", "auto_save", "show_hidden_files", - "git_enabled", "git_url", "git_user", "git_token", - "git_auto_commit", "git_commit_msg_template", - "git_commit_name", "git_commit_email"). - Values(13). - AddArgs( - workspace.UserID, workspace.Name, workspace.Theme, workspace.AutoSave, workspace.ShowHiddenFiles, - workspace.GitEnabled, workspace.GitURL, workspace.GitUser, encryptedToken, - workspace.GitAutoCommit, workspace.GitCommitMsgTemplate, workspace.GitCommitName, workspace.GitCommitEmail). - Returning("id", "created_at") + query, err := NewQuery(db.dbType). + InsertStruct(workspace, "workspaces") + + if err != nil { + return fmt.Errorf("failed to create query: %w", err) + } + + query.Returning("id", "created_at") err = db.QueryRow(query.String(), query.Args()...). Scan(&workspace.ID, &workspace.CreatedAt) diff --git a/server/internal/models/session.go b/server/internal/models/session.go index d0b8119..305dd85 100644 --- a/server/internal/models/session.go +++ b/server/internal/models/session.go @@ -5,9 +5,9 @@ import "time" // Session represents a user session in the database type Session struct { - ID string // Unique session identifier - UserID int // ID of the user this session belongs to - RefreshToken string // The refresh token associated with this session - ExpiresAt time.Time // When this session expires - CreatedAt time.Time // When this session was created + ID string `db:"id,default"` // Unique session identifier + UserID int `db:"user_id"` // ID of the user this session belongs to + RefreshToken string `db:"refresh_token"` // The refresh token associated with this session + ExpiresAt time.Time `db:"expires_at"` // When this session expires + CreatedAt time.Time `db:"created_at,default"` // When this session was created } diff --git a/server/internal/models/user.go b/server/internal/models/user.go index 3832cda..11d6903 100644 --- a/server/internal/models/user.go +++ b/server/internal/models/user.go @@ -20,13 +20,13 @@ const ( // User represents a user in the system type User struct { - ID int `json:"id" validate:"required,min=1"` - Email string `json:"email" validate:"required,email"` - DisplayName string `json:"displayName"` - PasswordHash string `json:"-"` - Role UserRole `json:"role" validate:"required,oneof=admin editor viewer"` - CreatedAt time.Time `json:"createdAt"` - LastWorkspaceID int `json:"lastWorkspaceId"` + ID int `json:"id" db:"id,default" validate:"required,min=1"` + Email string `json:"email" db:"email" validate:"required,email"` + DisplayName string `json:"displayName" db:"display_name"` + PasswordHash string `json:"-" db:"password_hash"` + Role UserRole `json:"role" db:"role" validate:"required,oneof=admin editor viewer"` + CreatedAt time.Time `json:"createdAt" db:"created_at,default"` + LastWorkspaceID int `json:"lastWorkspaceId" db:"last_workspace_id"` } // Validate validates the user struct diff --git a/server/internal/models/workspace.go b/server/internal/models/workspace.go index 6589957..886847b 100644 --- a/server/internal/models/workspace.go +++ b/server/internal/models/workspace.go @@ -6,24 +6,24 @@ import ( // Workspace represents a user's workspace in the system type Workspace struct { - ID int `json:"id" validate:"required,min=1"` - UserID int `json:"userId" validate:"required,min=1"` - Name string `json:"name" validate:"required"` - CreatedAt time.Time `json:"createdAt"` - LastOpenedFilePath string `json:"lastOpenedFilePath"` + ID int `json:"id" db:"id,default" validate:"required,min=1"` + UserID int `json:"userId" db:"user_id" validate:"required,min=1"` + Name string `json:"name" db:"name" validate:"required"` + CreatedAt time.Time `json:"createdAt" db:"created_at,default"` + LastOpenedFilePath string `json:"lastOpenedFilePath" db:"last_opened_file_path"` // Integrated settings - Theme string `json:"theme" validate:"oneof=light dark"` - AutoSave bool `json:"autoSave"` - ShowHiddenFiles bool `json:"showHiddenFiles"` - GitEnabled bool `json:"gitEnabled"` - GitURL string `json:"gitUrl" validate:"required_if=GitEnabled true"` - GitUser string `json:"gitUser" validate:"required_if=GitEnabled true"` - GitToken string `json:"gitToken" validate:"required_if=GitEnabled true"` - GitAutoCommit bool `json:"gitAutoCommit"` - GitCommitMsgTemplate string `json:"gitCommitMsgTemplate"` - GitCommitName string `json:"gitCommitName"` - GitCommitEmail string `json:"gitCommitEmail" validate:"omitempty,required_if=GitEnabled true,email"` + Theme string `json:"theme" db:"theme" validate:"oneof=light dark"` + AutoSave bool `json:"autoSave" db:"auto_save"` + ShowHiddenFiles bool `json:"showHiddenFiles" db:"show_hidden_files"` + GitEnabled bool `json:"gitEnabled" db:"git_enabled"` + GitURL string `json:"gitUrl" db:"git_url" validate:"required_if=GitEnabled true"` + GitUser string `json:"gitUser" db:"git_user" validate:"required_if=GitEnabled true"` + GitToken string `json:"gitToken" db:"git_token" validate:"required_if=GitEnabled true"` + GitAutoCommit bool `json:"gitAutoCommit" db:"git_auto_commit"` + GitCommitMsgTemplate string `json:"gitCommitMsgTemplate" db:"git_commit_msg_template"` + GitCommitName string `json:"gitCommitName" db:"git_commit_name"` + GitCommitEmail string `json:"gitCommitEmail" db:"git_commit_email" validate:"omitempty,required_if=GitEnabled true,email"` } // Validate validates the workspace struct From 3ce92322f4ecb458bbe74f2171fc1d5b4839eadb Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sat, 1 Mar 2025 21:59:04 +0100 Subject: [PATCH 17/39] Encrypt git token in insertstruct --- server/internal/db/sessions.go | 2 +- server/internal/db/struct_query.go | 17 +++++++++++++---- server/internal/db/users.go | 4 ++-- server/internal/db/workspaces.go | 9 +-------- server/internal/models/session.go | 2 +- server/internal/models/workspace.go | 6 +++--- 6 files changed, 21 insertions(+), 19 deletions(-) diff --git a/server/internal/db/sessions.go b/server/internal/db/sessions.go index 66c148b..bdf55db 100644 --- a/server/internal/db/sessions.go +++ b/server/internal/db/sessions.go @@ -11,7 +11,7 @@ import ( // CreateSession inserts a new session record into the database func (db *database) CreateSession(session *models.Session) error { query, err := NewQuery(db.dbType). - InsertStruct(session, "sessions") + InsertStruct(session, "sessions", db.secretsService) if err != nil { return fmt.Errorf("failed to create query: %w", err) } diff --git a/server/internal/db/struct_query.go b/server/internal/db/struct_query.go index c18b6d5..68fc7a3 100644 --- a/server/internal/db/struct_query.go +++ b/server/internal/db/struct_query.go @@ -2,6 +2,7 @@ package db import ( "fmt" + "lemma/internal/secrets" "reflect" "strings" "unicode" @@ -14,7 +15,7 @@ type DBField struct { useDefault bool } -func StructTagsToFields(s any) ([]DBField, error) { +func StructTagsToFields(s any, secretsService secrets.Service) ([]DBField, error) { v := reflect.ValueOf(s) if v.Kind() == reflect.Ptr { @@ -49,6 +50,8 @@ func StructTagsToFields(s any) ([]DBField, error) { } useDefault := false + value := v.Field(i).Interface() + if strings.Contains(tag, ",") { parts := strings.Split(tag, ",") tag = parts[0] @@ -61,13 +64,19 @@ func StructTagsToFields(s any) ([]DBField, error) { } case "default": useDefault = true + case "encrypted": + val, err := secretsService.Encrypt(value.(string)) + if err != nil { + return nil, fmt.Errorf("failed to encrypt field %s: %w", f.Name, err) + } + value = val } } } fields = append(fields, DBField{ Name: tag, - Value: v.Field(i).Interface(), + Value: value, Type: f.Type, useDefault: useDefault, }) @@ -91,8 +100,8 @@ func toSnakeCase(s string) string { return res } -func (q *Query) InsertStruct(s any, table string) (*Query, error) { - fields, err := StructTagsToFields(s) +func (q *Query) InsertStruct(s any, table string, secretsService secrets.Service) (*Query, error) { + fields, err := StructTagsToFields(s, secretsService) if err != nil { return nil, err } diff --git a/server/internal/db/users.go b/server/internal/db/users.go index b326949..ed792d6 100644 --- a/server/internal/db/users.go +++ b/server/internal/db/users.go @@ -18,7 +18,7 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) { defer tx.Rollback() query, err := NewQuery(db.dbType). - InsertStruct(user, "users") + InsertStruct(user, "users", db.secretsService) if err != nil { return nil, fmt.Errorf("failed to create query: %w", err) @@ -73,7 +73,7 @@ func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) e log := getLogger().WithGroup("users") insertQuery, err := NewQuery(db.dbType). - InsertStruct(workspace, "workspaces") + InsertStruct(workspace, "workspaces", db.secretsService) if err != nil { return fmt.Errorf("failed to create query: %w", err) diff --git a/server/internal/db/workspaces.go b/server/internal/db/workspaces.go index dc07911..120033f 100644 --- a/server/internal/db/workspaces.go +++ b/server/internal/db/workspaces.go @@ -19,15 +19,8 @@ func (db *database) CreateWorkspace(workspace *models.Workspace) error { workspace.SetDefaultSettings() } - // Encrypt token if present - encryptedToken, err := db.encryptToken(workspace.GitToken) - if err != nil { - return fmt.Errorf("failed to encrypt token: %w", err) - } - workspace.GitToken = encryptedToken - query, err := NewQuery(db.dbType). - InsertStruct(workspace, "workspaces") + InsertStruct(workspace, "workspaces", db.secretsService) if err != nil { return fmt.Errorf("failed to create query: %w", err) diff --git a/server/internal/models/session.go b/server/internal/models/session.go index 305dd85..60652a7 100644 --- a/server/internal/models/session.go +++ b/server/internal/models/session.go @@ -5,7 +5,7 @@ import "time" // Session represents a user session in the database type Session struct { - ID string `db:"id,default"` // Unique session identifier + ID string `db:"id"` // Unique session identifier UserID int `db:"user_id"` // ID of the user this session belongs to RefreshToken string `db:"refresh_token"` // The refresh token associated with this session ExpiresAt time.Time `db:"expires_at"` // When this session expires diff --git a/server/internal/models/workspace.go b/server/internal/models/workspace.go index 886847b..083a14e 100644 --- a/server/internal/models/workspace.go +++ b/server/internal/models/workspace.go @@ -17,9 +17,9 @@ type Workspace struct { AutoSave bool `json:"autoSave" db:"auto_save"` ShowHiddenFiles bool `json:"showHiddenFiles" db:"show_hidden_files"` GitEnabled bool `json:"gitEnabled" db:"git_enabled"` - GitURL string `json:"gitUrl" db:"git_url" validate:"required_if=GitEnabled true"` - GitUser string `json:"gitUser" db:"git_user" validate:"required_if=GitEnabled true"` - GitToken string `json:"gitToken" db:"git_token" validate:"required_if=GitEnabled true"` + GitURL string `json:"gitUrl" db:"git_url,ommitempty" validate:"required_if=GitEnabled true"` + GitUser string `json:"gitUser" db:"git_user,ommitempty" validate:"required_if=GitEnabled true"` + GitToken string `json:"gitToken" db:"git_token,ommitempty,encrypted" validate:"required_if=GitEnabled true"` GitAutoCommit bool `json:"gitAutoCommit" db:"git_auto_commit"` GitCommitMsgTemplate string `json:"gitCommitMsgTemplate" db:"git_commit_msg_template"` GitCommitName string `json:"gitCommitName" db:"git_commit_name"` From 204dacd15e81d0669c5a7c73bb7a0d0198aff896 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sat, 1 Mar 2025 22:26:50 +0100 Subject: [PATCH 18/39] Encrypt field in query --- server/internal/db/struct_query.go | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/server/internal/db/struct_query.go b/server/internal/db/struct_query.go index 68fc7a3..d6c8cde 100644 --- a/server/internal/db/struct_query.go +++ b/server/internal/db/struct_query.go @@ -13,9 +13,10 @@ type DBField struct { Value any Type reflect.Type useDefault bool + encrypted bool } -func StructTagsToFields(s any, secretsService secrets.Service) ([]DBField, error) { +func StructTagsToFields(s any) ([]DBField, error) { v := reflect.ValueOf(s) if v.Kind() == reflect.Ptr { @@ -50,7 +51,7 @@ func StructTagsToFields(s any, secretsService secrets.Service) ([]DBField, error } useDefault := false - value := v.Field(i).Interface() + encrypted := false if strings.Contains(tag, ",") { parts := strings.Split(tag, ",") @@ -65,20 +66,17 @@ func StructTagsToFields(s any, secretsService secrets.Service) ([]DBField, error case "default": useDefault = true case "encrypted": - val, err := secretsService.Encrypt(value.(string)) - if err != nil { - return nil, fmt.Errorf("failed to encrypt field %s: %w", f.Name, err) - } - value = val + encrypted = true } } } fields = append(fields, DBField{ Name: tag, - Value: value, + Value: v.Field(i).Interface(), Type: f.Type, useDefault: useDefault, + encrypted: encrypted, }) } return fields, nil @@ -101,7 +99,7 @@ func toSnakeCase(s string) string { } func (q *Query) InsertStruct(s any, table string, secretsService secrets.Service) (*Query, error) { - fields, err := StructTagsToFields(s, secretsService) + fields, err := StructTagsToFields(s) if err != nil { return nil, err } @@ -110,12 +108,22 @@ func (q *Query) InsertStruct(s any, table string, secretsService secrets.Service values := make([]any, 0, len(fields)) for _, f := range fields { + value := f.Value + if f.useDefault { continue } + if f.encrypted { + encValue, err := secretsService.Encrypt(value.(string)) + if err != nil { + return nil, err + } + value = encValue + } + columns = append(columns, f.Name) - values = append(values, f.Value) + values = append(values, value) } if len(columns) == 0 { From ccac439465e425ff990d43dd0455e4a2a2d12014 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sun, 2 Mar 2025 18:40:12 +0100 Subject: [PATCH 19/39] Implement update struct --- server/internal/db/db.go | 27 +---------- server/internal/db/query.go | 37 +++++++------- server/internal/db/query_test.go | 20 ++++---- server/internal/db/sessions.go | 12 ++--- server/internal/db/struct_query.go | 45 +++++++++++++++-- server/internal/db/system.go | 10 ++-- server/internal/db/users.go | 30 ++++++------ server/internal/db/workspaces.go | 77 +++++++++++------------------- 8 files changed, 127 insertions(+), 131 deletions(-) diff --git a/server/internal/db/db.go b/server/internal/db/db.go index ff80725..e171158 100644 --- a/server/internal/db/db.go +++ b/server/internal/db/db.go @@ -195,29 +195,6 @@ func (db *database) Close() error { return nil } -// Helper methods for token encryption/decryption -func (db *database) encryptToken(token string) (string, error) { - if token == "" { - return "", nil - } - - encrypted, err := db.secretsService.Encrypt(token) - if err != nil { - return "", fmt.Errorf("failed to encrypt token: %w", err) - } - - return encrypted, nil -} - -func (db *database) decryptToken(token string) (string, error) { - if token == "" { - return "", nil - } - - decrypted, err := db.secretsService.Decrypt(token) - if err != nil { - return "", fmt.Errorf("failed to decrypt token: %w", err) - } - - return decrypted, nil +func (db *database) NewQuery() *Query { + return NewQuery(db.dbType, db.secretsService) } diff --git a/server/internal/db/query.go b/server/internal/db/query.go index a552a8e..192e848 100644 --- a/server/internal/db/query.go +++ b/server/internal/db/query.go @@ -2,6 +2,7 @@ package db import ( "fmt" + "lemma/internal/secrets" "strings" ) @@ -15,27 +16,29 @@ const ( // Query represents a SQL query with its parameters type Query struct { - builder strings.Builder - args []any - dbType DBType - pos int // tracks the current placeholder position - hasSelect bool - hasFrom bool - hasWhere bool - hasOrderBy bool - hasGroupBy bool - hasHaving bool - hasLimit bool - hasOffset bool - isInParens bool - parensDepth int + builder strings.Builder + args []any + dbType DBType + secretsService secrets.Service + pos int // tracks the current placeholder position + hasSelect bool + hasFrom bool + hasWhere bool + hasOrderBy bool + hasGroupBy bool + hasHaving bool + hasLimit bool + hasOffset bool + isInParens bool + parensDepth int } // NewQuery creates a new Query instance -func NewQuery(dbType DBType) *Query { +func NewQuery(dbType DBType, secretsService secrets.Service) *Query { return &Query{ - dbType: dbType, - args: make([]any, 0), + dbType: dbType, + secretsService: secretsService, + args: make([]any, 0), } } diff --git a/server/internal/db/query_test.go b/server/internal/db/query_test.go index 4910328..907b633 100644 --- a/server/internal/db/query_test.go +++ b/server/internal/db/query_test.go @@ -24,7 +24,7 @@ func TestNewQuery(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - q := db.NewQuery(tt.dbType) + q := db.NewQuery(tt.dbType, &mockSecrets{}) // Test that a new query is empty if q.String() != "" { @@ -120,7 +120,7 @@ func TestBasicQueryBuilding(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - q := db.NewQuery(tt.dbType) + q := db.NewQuery(tt.dbType, &mockSecrets{}) q = tt.buildFn(q) gotSQL := q.String() @@ -215,7 +215,7 @@ func TestPlaceholders(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - q := db.NewQuery(tt.dbType) + q := db.NewQuery(tt.dbType, &mockSecrets{}) q = tt.buildFn(q) gotSQL := q.String() @@ -328,7 +328,7 @@ func TestWhereClauseBuilding(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - q := db.NewQuery(tt.dbType) + q := db.NewQuery(tt.dbType, &mockSecrets{}) q = tt.buildFn(q) gotSQL := q.String() @@ -403,7 +403,7 @@ func TestJoinClauseBuilding(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - q := db.NewQuery(tt.dbType) + q := db.NewQuery(tt.dbType, &mockSecrets{}) q = tt.buildFn(q) gotSQL := q.String() @@ -482,7 +482,7 @@ func TestOrderLimitOffset(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - q := db.NewQuery(tt.dbType) + q := db.NewQuery(tt.dbType, &mockSecrets{}) q = tt.buildFn(q) gotSQL := q.String() @@ -575,7 +575,7 @@ func TestInsertUpdateDelete(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - q := db.NewQuery(tt.dbType) + q := db.NewQuery(tt.dbType, &mockSecrets{}) q = tt.buildFn(q) gotSQL := q.String() @@ -641,7 +641,7 @@ func TestHavingClause(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - q := db.NewQuery(tt.dbType) + q := db.NewQuery(tt.dbType, &mockSecrets{}) q = tt.buildFn(q) gotSQL := q.String() @@ -790,7 +790,7 @@ func TestQueryReturning(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - query := db.NewQuery(tc.dbType) + query := db.NewQuery(tc.dbType, &mockSecrets{}) result := tc.buildQuery(query) if result.String() != tc.expectedSQL { @@ -838,7 +838,7 @@ func TestComplexQueries(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - q := db.NewQuery(tt.dbType) + q := db.NewQuery(tt.dbType, &mockSecrets{}) q = tt.buildFn(q) gotSQL := q.String() diff --git a/server/internal/db/sessions.go b/server/internal/db/sessions.go index bdf55db..fdfcc7d 100644 --- a/server/internal/db/sessions.go +++ b/server/internal/db/sessions.go @@ -10,8 +10,8 @@ import ( // CreateSession inserts a new session record into the database func (db *database) CreateSession(session *models.Session) error { - query, err := NewQuery(db.dbType). - InsertStruct(session, "sessions", db.secretsService) + query, err := db.NewQuery(). + InsertStruct(session, "sessions") if err != nil { return fmt.Errorf("failed to create query: %w", err) } @@ -26,7 +26,7 @@ func (db *database) CreateSession(session *models.Session) error { // GetSessionByRefreshToken retrieves a session by its refresh token func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) { session := &models.Session{} - query := NewQuery(db.dbType). + query := db.NewQuery(). Select("id", "user_id", "refresh_token", "expires_at", "created_at"). From("sessions"). Where("refresh_token = "). @@ -48,7 +48,7 @@ func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Sessi // GetSessionByID retrieves a session by its ID func (db *database) GetSessionByID(sessionID string) (*models.Session, error) { session := &models.Session{} - query := NewQuery(db.dbType). + query := db.NewQuery(). Select("id", "user_id", "refresh_token", "expires_at", "created_at"). From("sessions"). Where("id = "). @@ -69,7 +69,7 @@ func (db *database) GetSessionByID(sessionID string) (*models.Session, error) { // DeleteSession removes a session from the database func (db *database) DeleteSession(sessionID string) error { - query := NewQuery(db.dbType). + query := db.NewQuery(). Delete(). From("sessions"). Where("id = "). @@ -95,7 +95,7 @@ func (db *database) DeleteSession(sessionID string) error { // CleanExpiredSessions removes all expired sessions from the database func (db *database) CleanExpiredSessions() error { log := getLogger().WithGroup("sessions") - query := NewQuery(db.dbType). + query := db.NewQuery(). Delete(). From("sessions"). Where("expires_at <="). diff --git a/server/internal/db/struct_query.go b/server/internal/db/struct_query.go index d6c8cde..efa6321 100644 --- a/server/internal/db/struct_query.go +++ b/server/internal/db/struct_query.go @@ -2,7 +2,6 @@ package db import ( "fmt" - "lemma/internal/secrets" "reflect" "strings" "unicode" @@ -98,7 +97,7 @@ func toSnakeCase(s string) string { return res } -func (q *Query) InsertStruct(s any, table string, secretsService secrets.Service) (*Query, error) { +func (q *Query) InsertStruct(s any, table string) (*Query, error) { fields, err := StructTagsToFields(s) if err != nil { return nil, err @@ -115,7 +114,7 @@ func (q *Query) InsertStruct(s any, table string, secretsService secrets.Service } if f.encrypted { - encValue, err := secretsService.Encrypt(value.(string)) + encValue, err := q.secretsService.Encrypt(value.(string)) if err != nil { return nil, err } @@ -133,3 +132,43 @@ func (q *Query) InsertStruct(s any, table string, secretsService secrets.Service q.Insert(table, columns...).Values(len(columns)).AddArgs(values...) return q, nil } + +func (q *Query) UpdateStruct(s any, table string, where []string, args []any) (*Query, error) { + fields, err := StructTagsToFields(s) + if err != nil { + return nil, err + } + + if len(where) != len(args) { + return nil, fmt.Errorf("number of where clauses does not match number of arguments") + } + + q = q.Update("users") + + for _, f := range fields { + value := f.Value + + if f.useDefault { + continue + } + + if f.encrypted { + encValue, err := q.secretsService.Encrypt(value.(string)) + if err != nil { + return nil, err + } + value = encValue + } + + q = q.Set(f.Name).Placeholder(value) + } + + for i, w := range where { + if i != 0 && i < len(args) { + q = q.And(w) + } + q = q.Where(w).Placeholder(args[i]) + } + + return q, nil +} diff --git a/server/internal/db/system.go b/server/internal/db/system.go index d81aef3..a193ee7 100644 --- a/server/internal/db/system.go +++ b/server/internal/db/system.go @@ -49,7 +49,7 @@ func (db *database) EnsureJWTSecret() (string, error) { // GetSystemSetting retrieves a system setting by key func (db *database) GetSystemSetting(key string) (string, error) { var value string - query := NewQuery(db.dbType). + query := db.NewQuery(). Select("value"). From("system_settings"). Where("key = "). @@ -64,7 +64,7 @@ func (db *database) GetSystemSetting(key string) (string, error) { // SetSystemSetting stores or updates a system setting func (db *database) SetSystemSetting(key, value string) error { - query := NewQuery(db.dbType). + query := db.NewQuery(). Insert("system_settings", "key", "value"). Values(2). AddArgs(key, value). @@ -100,7 +100,7 @@ func (db *database) GetSystemStats() (*UserStats, error) { stats := &UserStats{} // Get total users - query := NewQuery(db.dbType). + query := db.NewQuery(). Select("COUNT(*)"). From("users") err := db.QueryRow(query.String()).Scan(&stats.TotalUsers) @@ -109,7 +109,7 @@ func (db *database) GetSystemStats() (*UserStats, error) { } // Get total workspaces - query = NewQuery(db.dbType). + query = db.NewQuery(). Select("COUNT(*)"). From("workspaces") err = db.QueryRow(query.String()).Scan(&stats.TotalWorkspaces) @@ -118,7 +118,7 @@ func (db *database) GetSystemStats() (*UserStats, error) { } // Get active users (users with activity in last 30 days) - query = NewQuery(db.dbType). + query = db.NewQuery(). Select("COUNT(DISTINCT user_id)"). From("sessions"). Where("created_at > datetime('now', '-30 days')") diff --git a/server/internal/db/users.go b/server/internal/db/users.go index ed792d6..5d0cc55 100644 --- a/server/internal/db/users.go +++ b/server/internal/db/users.go @@ -17,8 +17,8 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) { } defer tx.Rollback() - query, err := NewQuery(db.dbType). - InsertStruct(user, "users", db.secretsService) + query, err := db.NewQuery(). + InsertStruct(user, "users") if err != nil { return nil, fmt.Errorf("failed to create query: %w", err) @@ -46,7 +46,7 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) { } // Update user's last workspace ID - query = NewQuery(db.dbType). + query = db.NewQuery(). Update("users"). Set("last_workspace_id"). Placeholder(defaultWorkspace.ID). @@ -72,8 +72,8 @@ func (db *database) CreateUser(user *models.User) (*models.User, error) { func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) error { log := getLogger().WithGroup("users") - insertQuery, err := NewQuery(db.dbType). - InsertStruct(workspace, "workspaces", db.secretsService) + insertQuery, err := db.NewQuery(). + InsertStruct(workspace, "workspaces") if err != nil { return fmt.Errorf("failed to create query: %w", err) @@ -94,7 +94,7 @@ func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) e // GetUserByID retrieves a user by its ID func (db *database) GetUserByID(id int) (*models.User, error) { - query := NewQuery(db.dbType). + query := db.NewQuery(). Select("id", "email", "display_name", "password_hash", "role", "created_at", "last_workspace_id"). From("users"). Where("id = ").Placeholder(id) @@ -115,7 +115,7 @@ func (db *database) GetUserByID(id int) (*models.User, error) { // GetUserByEmail retrieves a user by its email func (db *database) GetUserByEmail(email string) (*models.User, error) { - query := NewQuery(db.dbType). + query := db.NewQuery(). Select("id", "email", "display_name", "password_hash", "role", "created_at", "last_workspace_id"). From("users"). Where("email = ").Placeholder(email) @@ -137,7 +137,7 @@ func (db *database) GetUserByEmail(email string) (*models.User, error) { // UpdateUser updates an existing user record in the database func (db *database) UpdateUser(user *models.User) error { - query := NewQuery(db.dbType). + query := db.NewQuery(). Update("users"). Set("email").Placeholder(user.Email). Set("display_name").Placeholder(user.DisplayName). @@ -165,7 +165,7 @@ func (db *database) UpdateUser(user *models.User) error { // GetAllUsers retrieves all users from the database func (db *database) GetAllUsers() ([]*models.User, error) { - query := NewQuery(db.dbType). + query := db.NewQuery(). Select("id", "email", "display_name", "role", "created_at", "last_workspace_id"). From("users"). OrderBy("id ASC") @@ -200,7 +200,7 @@ func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error defer tx.Rollback() // Find workspace ID from name - workspaceQuery := NewQuery(db.dbType). + workspaceQuery := db.NewQuery(). Select("id"). From("workspaces"). Where("user_id = ").Placeholder(userID). @@ -213,7 +213,7 @@ func (db *database) UpdateLastWorkspace(userID int, workspaceName string) error } // Update user's last workspace - updateQuery := NewQuery(db.dbType). + updateQuery := db.NewQuery(). Update("users"). Set("last_workspace_id").Placeholder(workspaceID). Where("id = ").Placeholder(userID) @@ -245,7 +245,7 @@ func (db *database) DeleteUser(id int) error { // Delete all user's workspaces first log.Debug("deleting user workspaces", "user_id", id) - deleteWorkspacesQuery := NewQuery(db.dbType). + deleteWorkspacesQuery := db.NewQuery(). Delete(). From("workspaces"). Where("user_id = ").Placeholder(id) @@ -256,7 +256,7 @@ func (db *database) DeleteUser(id int) error { } // Delete the user - deleteUserQuery := NewQuery(db.dbType). + deleteUserQuery := db.NewQuery(). Delete(). From("users"). Where("id = ").Placeholder(id) @@ -277,7 +277,7 @@ func (db *database) DeleteUser(id int) error { // GetLastWorkspaceName retrieves the name of the last workspace accessed by a user func (db *database) GetLastWorkspaceName(userID int) (string, error) { - query := NewQuery(db.dbType). + query := db.NewQuery(). Select("w.name"). From("workspaces w"). Join(InnerJoin, "users u", "u.last_workspace_id = w.id"). @@ -298,7 +298,7 @@ func (db *database) GetLastWorkspaceName(userID int) (string, error) { // CountAdminUsers returns the number of admin users in the system func (db *database) CountAdminUsers() (int, error) { - query := NewQuery(db.dbType). + query := db.NewQuery(). Select("COUNT(*)"). From("users"). Where("role = ").Placeholder(models.RoleAdmin) diff --git a/server/internal/db/workspaces.go b/server/internal/db/workspaces.go index 120033f..64ebd7e 100644 --- a/server/internal/db/workspaces.go +++ b/server/internal/db/workspaces.go @@ -19,7 +19,7 @@ func (db *database) CreateWorkspace(workspace *models.Workspace) error { workspace.SetDefaultSettings() } - query, err := NewQuery(db.dbType). + query, err := db.NewQuery(). InsertStruct(workspace, "workspaces", db.secretsService) if err != nil { @@ -39,7 +39,7 @@ func (db *database) CreateWorkspace(workspace *models.Workspace) error { // GetWorkspaceByID retrieves a workspace by its ID func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) { - query := NewQuery(db.dbType). + query := db.NewQuery(). Select( "id", "user_id", "name", "created_at", "theme", "auto_save", "show_hidden_files", @@ -86,7 +86,7 @@ func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) { // GetWorkspaceByName retrieves a workspace by its name and user ID func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) { - query := NewQuery(db.dbType). + query := db.NewQuery(). Select( "id", "user_id", "name", "created_at", "theme", "auto_save", "show_hidden_files", @@ -133,28 +133,14 @@ func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*model // UpdateWorkspace updates a workspace record in the database func (db *database) UpdateWorkspace(workspace *models.Workspace) error { - // Encrypt token before storing - encryptedToken, err := db.encryptToken(workspace.GitToken) - if err != nil { - return fmt.Errorf("failed to encrypt token: %w", err) - } - query := NewQuery(db.dbType). - Update("workspaces"). - Set("name").Placeholder(workspace.Name). - Set("theme").Placeholder(workspace.Theme). - Set("auto_save").Placeholder(workspace.AutoSave). - Set("show_hidden_files").Placeholder(workspace.ShowHiddenFiles). - Set("git_enabled").Placeholder(workspace.GitEnabled). - Set("git_url").Placeholder(workspace.GitURL). - Set("git_user").Placeholder(workspace.GitUser). - Set("git_token").Placeholder(encryptedToken). - Set("git_auto_commit").Placeholder(workspace.GitAutoCommit). - Set("git_commit_msg_template").Placeholder(workspace.GitCommitMsgTemplate). - Set("git_commit_name").Placeholder(workspace.GitCommitName). - Set("git_commit_email").Placeholder(workspace.GitCommitEmail). - Where("id = ").Placeholder(workspace.ID). - And("user_id = ").Placeholder(workspace.UserID) + query := db.NewQuery() + query, err := query. + UpdateStruct(workspace, "workspaces", []string{"id =", "user_id ="}, []interface{}{workspace.ID, workspace.UserID}) + + if err != nil { + return fmt.Errorf("failed to create query: %w", err) + } _, err = db.Exec(query.String(), query.Args()...) if err != nil { @@ -166,7 +152,7 @@ func (db *database) UpdateWorkspace(workspace *models.Workspace) error { // GetWorkspacesByUserID retrieves all workspaces for a user func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) { - query := NewQuery(db.dbType). + query := db.NewQuery(). Select( "id", "user_id", "name", "created_at", "theme", "auto_save", "show_hidden_files", @@ -222,26 +208,17 @@ func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, erro // UpdateWorkspaceSettings updates only the settings portion of a workspace func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error { - // Encrypt token before storing - encryptedToken, err := db.encryptToken(workspace.GitToken) - if err != nil { - return fmt.Errorf("failed to encrypt token: %w", err) - } - query := NewQuery(db.dbType). - Update("workspaces"). - Set("theme").Placeholder(workspace.Theme). - Set("auto_save").Placeholder(workspace.AutoSave). - Set("show_hidden_files").Placeholder(workspace.ShowHiddenFiles). - Set("git_enabled").Placeholder(workspace.GitEnabled). - Set("git_url").Placeholder(workspace.GitURL). - Set("git_user").Placeholder(workspace.GitUser). - Set("git_token").Placeholder(encryptedToken). - Set("git_auto_commit").Placeholder(workspace.GitAutoCommit). - Set("git_commit_msg_template").Placeholder(workspace.GitCommitMsgTemplate). - Set("git_commit_name").Placeholder(workspace.GitCommitName). - Set("git_commit_email").Placeholder(workspace.GitCommitEmail). - Where("id = ").Placeholder(workspace.ID) + where := []string{"id ="} + args := []interface{}{workspace.ID} + + query := db.NewQuery() + query, err := query. + UpdateStruct(workspace, "workspaces", where, args) + + if err != nil { + return fmt.Errorf("failed to create query: %w", err) + } _, err = db.Exec(query.String(), query.Args()...) if err != nil { @@ -255,7 +232,7 @@ func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error { func (db *database) DeleteWorkspace(id int) error { log := getLogger().WithGroup("workspaces") - query := NewQuery(db.dbType). + query := db.NewQuery(). Delete(). From("workspaces"). Where("id = ").Placeholder(id) @@ -273,7 +250,7 @@ func (db *database) DeleteWorkspace(id int) error { func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error { log := getLogger().WithGroup("workspaces") - query := NewQuery(db.dbType). + query := db.NewQuery(). Delete(). From("workspaces"). Where("id = ").Placeholder(id) @@ -294,7 +271,7 @@ func (db *database) DeleteWorkspaceTx(tx *sql.Tx, id int) error { // UpdateLastWorkspaceTx sets the last workspace for a user in a transaction func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) error { - query := NewQuery(db.dbType). + query := db.NewQuery(). Update("users"). Set("last_workspace_id").Placeholder(workspaceID). Where("id = ").Placeholder(userID) @@ -314,7 +291,7 @@ func (db *database) UpdateLastWorkspaceTx(tx *sql.Tx, userID, workspaceID int) e // UpdateLastOpenedFile updates the last opened file path for a workspace func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error { - query := NewQuery(db.dbType). + query := db.NewQuery(). Update("workspaces"). Set("last_opened_file_path").Placeholder(filePath). Where("id = ").Placeholder(workspaceID) @@ -329,7 +306,7 @@ func (db *database) UpdateLastOpenedFile(workspaceID int, filePath string) error // GetLastOpenedFile retrieves the last opened file path for a workspace func (db *database) GetLastOpenedFile(workspaceID int) (string, error) { - query := NewQuery(db.dbType). + query := db.NewQuery(). Select("last_opened_file_path"). From("workspaces"). Where("id = ").Placeholder(workspaceID) @@ -353,7 +330,7 @@ func (db *database) GetLastOpenedFile(workspaceID int) (string, error) { // GetAllWorkspaces retrieves all workspaces in the database func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) { - query := NewQuery(db.dbType). + query := db.NewQuery(). Select( "id", "user_id", "name", "created_at", "theme", "auto_save", "show_hidden_files", From 5fd9755f120293d90f352dbd654c40e66fda4500 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Sun, 2 Mar 2025 21:54:10 +0100 Subject: [PATCH 20/39] Implement scan struct --- server/internal/db/struct_query.go | 84 ++++++++++++++++++++++++++---- 1 file changed, 74 insertions(+), 10 deletions(-) diff --git a/server/internal/db/struct_query.go b/server/internal/db/struct_query.go index efa6321..3484307 100644 --- a/server/internal/db/struct_query.go +++ b/server/internal/db/struct_query.go @@ -1,6 +1,7 @@ package db import ( + "database/sql" "fmt" "reflect" "strings" @@ -8,11 +9,12 @@ import ( ) type DBField struct { - Name string - Value any - Type reflect.Type - useDefault bool - encrypted bool + Name string + Value any + Type reflect.Type + OriginalName string + useDefault bool + encrypted bool } func StructTagsToFields(s any) ([]DBField, error) { @@ -71,11 +73,12 @@ func StructTagsToFields(s any) ([]DBField, error) { } fields = append(fields, DBField{ - Name: tag, - Value: v.Field(i).Interface(), - Type: f.Type, - useDefault: useDefault, - encrypted: encrypted, + Name: tag, + Value: v.Field(i).Interface(), + Type: f.Type, + OriginalName: f.Name, + useDefault: useDefault, + encrypted: encrypted, }) } return fields, nil @@ -172,3 +175,64 @@ func (q *Query) UpdateStruct(s any, table string, where []string, args []any) (* return q, nil } + +func (db *database) ScanStruct(row *sql.Row, dest any) error { + // Get the fields of the destination struct + fields, err := StructTagsToFields(dest) + if err != nil { + return fmt.Errorf("failed to extract struct fields: %w", err) + } + + // Create a slice of pointers to hold the scan destinations + scanDest := make([]interface{}, len(fields)) + destVal := reflect.ValueOf(dest).Elem() + + var fieldsToDecrypt []string + nullStringIndexes := make(map[int]reflect.Value) + + for i, field := range fields { + // Find the field in the struct + structField := destVal.FieldByName(field.OriginalName) + if !structField.IsValid() { + return fmt.Errorf("struct field %s not found", field.OriginalName) + } + + if field.encrypted { + fieldsToDecrypt = append(fieldsToDecrypt, field.OriginalName) + } + + if structField.Kind() == reflect.String { + nullStringIndexes[i] = structField + + var ns sql.NullString + scanDest[i] = &ns + } else { + scanDest[i] = structField.Addr().Interface() + } + } + + // Scan the row into the destination pointers + if err := row.Scan(scanDest...); err != nil { + return err + } + + // Set null strings to nil if they are null + for i, field := range nullStringIndexes { + ns := scanDest[i].(*sql.NullString) + if ns.Valid { + field.SetString(ns.String) + } + } + + // Decrypt encrypted fields + for _, fieldName := range fieldsToDecrypt { + field := destVal.FieldByName(fieldName) + decValue, err := db.secretsService.Decrypt(field.Interface().(string)) + if err != nil { + return err + } + field.SetString(decValue) + } + + return nil +} From 829b359e82b99ea660353d651a9f53f0214bc1e1 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Mon, 3 Mar 2025 21:36:04 +0100 Subject: [PATCH 21/39] Implement select struct and scan struct --- server/internal/db/struct_query.go | 134 ++++++++++++++++++--- server/internal/db/workspaces.go | 187 ++++++----------------------- 2 files changed, 158 insertions(+), 163 deletions(-) diff --git a/server/internal/db/struct_query.go b/server/internal/db/struct_query.go index 3484307..ec86764 100644 --- a/server/internal/db/struct_query.go +++ b/server/internal/db/struct_query.go @@ -4,6 +4,7 @@ import ( "database/sql" "fmt" "reflect" + "sort" "strings" "unicode" ) @@ -81,6 +82,11 @@ func StructTagsToFields(s any) ([]DBField, error) { encrypted: encrypted, }) } + + sort.Slice(fields, func(i, j int) bool { + return fields[i].Name < fields[j].Name + }) + return fields, nil } @@ -176,17 +182,34 @@ func (q *Query) UpdateStruct(s any, table string, where []string, args []any) (* return q, nil } -func (db *database) ScanStruct(row *sql.Row, dest any) error { - // Get the fields of the destination struct - fields, err := StructTagsToFields(dest) +func (q *Query) SelectStruct(s any, table string) (*Query, error) { + fields, err := StructTagsToFields(s) + if err != nil { + return nil, err + } + + columns := make([]string, 0, len(fields)) + for _, f := range fields { + columns = append(columns, f.Name) + } + + q = q.Select(columns...).From(table) + return q, nil +} + +// Scanner is an interface that both sql.Row and sql.Rows satisfy +type Scanner interface { + Scan(dest ...interface{}) error +} + +// scanStructInstance is an internal function that handles the scanning logic for a single instance +func (db *database) scanStructInstance(destVal reflect.Value, scanner Scanner) error { + fields, err := StructTagsToFields(destVal.Interface()) if err != nil { return fmt.Errorf("failed to extract struct fields: %w", err) } - // Create a slice of pointers to hold the scan destinations scanDest := make([]interface{}, len(fields)) - destVal := reflect.ValueOf(dest).Elem() - var fieldsToDecrypt []string nullStringIndexes := make(map[int]reflect.Value) @@ -202,8 +225,8 @@ func (db *database) ScanStruct(row *sql.Row, dest any) error { } if structField.Kind() == reflect.String { + // Handle null strings separately nullStringIndexes[i] = structField - var ns sql.NullString scanDest[i] = &ns } else { @@ -211,12 +234,12 @@ func (db *database) ScanStruct(row *sql.Row, dest any) error { } } - // Scan the row into the destination pointers - if err := row.Scan(scanDest...); err != nil { + // Scan using the scanner interface + if err := scanner.Scan(scanDest...); err != nil { return err } - // Set null strings to nil if they are null + // Set null strings to their values if they are valid for i, field := range nullStringIndexes { ns := scanDest[i].(*sql.NullString) if ns.Valid { @@ -227,11 +250,94 @@ func (db *database) ScanStruct(row *sql.Row, dest any) error { // Decrypt encrypted fields for _, fieldName := range fieldsToDecrypt { field := destVal.FieldByName(fieldName) - decValue, err := db.secretsService.Decrypt(field.Interface().(string)) - if err != nil { - return err + if !field.IsZero() { + decValue, err := db.secretsService.Decrypt(field.Interface().(string)) + if err != nil { + return err + } + field.SetString(decValue) } - field.SetString(decValue) + } + + return nil +} + +// ScanStruct scans a single row into a struct +func (db *database) ScanStruct(row *sql.Row, dest any) error { + if row == nil { + return fmt.Errorf("row cannot be nil") + } + + if row.Err() != nil { + return row.Err() + } + + // Get the destination value + destVal := reflect.ValueOf(dest) + if destVal.Kind() != reflect.Ptr || destVal.IsNil() { + return fmt.Errorf("destination must be a non-nil pointer to a struct, got %T", dest) + } + + destVal = destVal.Elem() + if destVal.Kind() != reflect.Struct { + return fmt.Errorf("destination must be a pointer to a struct, got pointer to %s", destVal.Kind()) + } + + return db.scanStructInstance(destVal, row) +} + +// ScanStructs scans multiple rows into a slice of structs +func (db *database) ScanStructs(rows *sql.Rows, destSlice any) error { + if rows == nil { + return fmt.Errorf("rows cannot be nil") + } + defer rows.Close() + + // Get the slice value and element type + sliceVal := reflect.ValueOf(destSlice) + if sliceVal.Kind() != reflect.Ptr || sliceVal.IsNil() { + return fmt.Errorf("destination must be a non-nil pointer to a slice, got %T", destSlice) + } + + sliceVal = sliceVal.Elem() + if sliceVal.Kind() != reflect.Slice { + return fmt.Errorf("destination must be a pointer to a slice, got pointer to %s", sliceVal.Kind()) + } + + // Get the element type of the slice + elemType := sliceVal.Type().Elem() + + // Check if we have a direct struct type or a pointer to struct + isPtr := elemType.Kind() == reflect.Ptr + structType := elemType + if isPtr { + structType = elemType.Elem() + } + if structType.Kind() != reflect.Struct { + return fmt.Errorf("slice element type must be a struct or pointer to struct, got %s", elemType.String()) + } + + // Process each row + for rows.Next() { + // Create a new instance of the struct for each row + newElem := reflect.New(structType).Elem() + + // Scan this row into the new element + if err := db.scanStructInstance(newElem, rows); err != nil { + return err + } + + // Add the new element to the result slice + if isPtr { + sliceVal.Set(reflect.Append(sliceVal, newElem.Addr())) + } else { + sliceVal.Set(reflect.Append(sliceVal, newElem)) + } + } + + // Check for errors from iterating over rows + if err := rows.Err(); err != nil { + return err } return nil diff --git a/server/internal/db/workspaces.go b/server/internal/db/workspaces.go index 64ebd7e..1783ad3 100644 --- a/server/internal/db/workspaces.go +++ b/server/internal/db/workspaces.go @@ -20,7 +20,7 @@ func (db *database) CreateWorkspace(workspace *models.Workspace) error { } query, err := db.NewQuery(). - InsertStruct(workspace, "workspaces", db.secretsService) + InsertStruct(workspace, "workspaces") if err != nil { return fmt.Errorf("failed to create query: %w", err) @@ -39,30 +39,16 @@ func (db *database) CreateWorkspace(workspace *models.Workspace) error { // GetWorkspaceByID retrieves a workspace by its ID func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) { - query := db.NewQuery(). - Select( - "id", "user_id", "name", "created_at", - "theme", "auto_save", "show_hidden_files", - "git_enabled", "git_url", "git_user", "git_token", - "git_auto_commit", "git_commit_msg_template", - "git_commit_name", "git_commit_email", - "last_opened_file_path"). - From("workspaces"). - Where("id = ").Placeholder(id) - workspace := &models.Workspace{} - var encryptedToken string + query := db.NewQuery() + query, err := query.SelectStruct(workspace, "workspaces") + if err != nil { + return nil, fmt.Errorf("failed to create query: %w", err) + } + query = query.Where("id = ").Placeholder(id) - var lastOpenedFile sql.NullString - - err := db.QueryRow(query.String(), query.Args()...).Scan( - &workspace.ID, &workspace.UserID, &workspace.Name, &workspace.CreatedAt, - &workspace.Theme, &workspace.AutoSave, &workspace.ShowHiddenFiles, - &workspace.GitEnabled, &workspace.GitURL, &workspace.GitUser, &encryptedToken, - &workspace.GitAutoCommit, &workspace.GitCommitMsgTemplate, - &workspace.GitCommitName, &workspace.GitCommitEmail, - &lastOpenedFile, - ) + row := db.QueryRow(query.String(), query.Args()...) + err = db.ScanStruct(row, workspace) if err == sql.ErrNoRows { return nil, fmt.Errorf("workspace not found") @@ -71,45 +57,22 @@ func (db *database) GetWorkspaceByID(id int) (*models.Workspace, error) { return nil, fmt.Errorf("failed to fetch workspace: %w", err) } - if lastOpenedFile.Valid { - workspace.LastOpenedFilePath = lastOpenedFile.String - } - - // Decrypt token - workspace.GitToken, err = db.decryptToken(encryptedToken) - if err != nil { - return nil, fmt.Errorf("failed to decrypt token: %w", err) - } - return workspace, nil } // GetWorkspaceByName retrieves a workspace by its name and user ID func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*models.Workspace, error) { - query := db.NewQuery(). - Select( - "id", "user_id", "name", "created_at", - "theme", "auto_save", "show_hidden_files", - "git_enabled", "git_url", "git_user", "git_token", - "git_auto_commit", "git_commit_msg_template", - "git_commit_name", "git_commit_email", - "last_opened_file_path"). - From("workspaces"). - Where("user_id = ").Placeholder(userID). + workspace := &models.Workspace{} + query := db.NewQuery() + query, err := query.SelectStruct(workspace, "workspaces") + if err != nil { + return nil, fmt.Errorf("failed to create query: %w", err) + } + query = query.Where("user_id = ").Placeholder(userID). And("name = ").Placeholder(workspaceName) - workspace := &models.Workspace{} - var encryptedToken string - var lastOpenedFile sql.NullString - - err := db.QueryRow(query.String(), query.Args()...).Scan( - &workspace.ID, &workspace.UserID, &workspace.Name, &workspace.CreatedAt, - &workspace.Theme, &workspace.AutoSave, &workspace.ShowHiddenFiles, - &workspace.GitEnabled, &workspace.GitURL, &workspace.GitUser, &encryptedToken, - &workspace.GitAutoCommit, &workspace.GitCommitMsgTemplate, - &workspace.GitCommitName, &workspace.GitCommitEmail, - &lastOpenedFile, - ) + row := db.QueryRow(query.String(), query.Args()...) + err = db.ScanStruct(row, workspace) if err == sql.ErrNoRows { return nil, fmt.Errorf("workspace not found") @@ -118,16 +81,6 @@ func (db *database) GetWorkspaceByName(userID int, workspaceName string) (*model return nil, fmt.Errorf("failed to fetch workspace: %w", err) } - if lastOpenedFile.Valid { - workspace.LastOpenedFilePath = lastOpenedFile.String - } - - // Decrypt token - workspace.GitToken, err = db.decryptToken(encryptedToken) - if err != nil { - return nil, fmt.Errorf("failed to decrypt token: %w", err) - } - return workspace, nil } @@ -136,7 +89,7 @@ func (db *database) UpdateWorkspace(workspace *models.Workspace) error { query := db.NewQuery() query, err := query. - UpdateStruct(workspace, "workspaces", []string{"id =", "user_id ="}, []interface{}{workspace.ID, workspace.UserID}) + UpdateStruct(workspace, "workspaces", []string{"id =", "user_id ="}, []any{workspace.ID, workspace.UserID}) if err != nil { return fmt.Errorf("failed to create query: %w", err) @@ -152,16 +105,13 @@ func (db *database) UpdateWorkspace(workspace *models.Workspace) error { // GetWorkspacesByUserID retrieves all workspaces for a user func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, error) { - query := db.NewQuery(). - Select( - "id", "user_id", "name", "created_at", - "theme", "auto_save", "show_hidden_files", - "git_enabled", "git_url", "git_user", "git_token", - "git_auto_commit", "git_commit_msg_template", - "git_commit_name", "git_commit_email", - "last_opened_file_path"). - From("workspaces"). - Where("user_id = ").Placeholder(userID) + workspace := &models.Workspace{} + query := db.NewQuery() + query, err := query.SelectStruct(workspace, "workspaces") + if err != nil { + return nil, fmt.Errorf("failed to create query: %w", err) + } + query = query.Where("user_id = ").Placeholder(userID) rows, err := db.Query(query.String(), query.Args()...) if err != nil { @@ -170,37 +120,9 @@ func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, erro defer rows.Close() var workspaces []*models.Workspace - for rows.Next() { - workspace := &models.Workspace{} - var encryptedToken string - var lastOpenedFile sql.NullString - err := rows.Scan( - &workspace.ID, &workspace.UserID, &workspace.Name, &workspace.CreatedAt, - &workspace.Theme, &workspace.AutoSave, &workspace.ShowHiddenFiles, - &workspace.GitEnabled, &workspace.GitURL, &workspace.GitUser, &encryptedToken, - &workspace.GitAutoCommit, &workspace.GitCommitMsgTemplate, - &workspace.GitCommitName, &workspace.GitCommitEmail, - &lastOpenedFile, - ) - if err != nil { - return nil, fmt.Errorf("failed to scan workspace row: %w", err) - } - - if lastOpenedFile.Valid { - workspace.LastOpenedFilePath = lastOpenedFile.String - } - - // Decrypt token - workspace.GitToken, err = db.decryptToken(encryptedToken) - if err != nil { - return nil, fmt.Errorf("failed to decrypt token: %w", err) - } - - workspaces = append(workspaces, workspace) - } - - if err = rows.Err(); err != nil { - return nil, fmt.Errorf("error iterating workspace rows: %w", err) + err = db.ScanStructs(rows, workspaces) + if err != nil { + return nil, fmt.Errorf("failed to scan workspaces: %w", err) } return workspaces, nil @@ -210,7 +132,7 @@ func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, erro func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error { where := []string{"id ="} - args := []interface{}{workspace.ID} + args := []any{workspace.ID} query := db.NewQuery() query, err := query. @@ -330,15 +252,11 @@ func (db *database) GetLastOpenedFile(workspaceID int) (string, error) { // GetAllWorkspaces retrieves all workspaces in the database func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) { - query := db.NewQuery(). - Select( - "id", "user_id", "name", "created_at", - "theme", "auto_save", "show_hidden_files", - "git_enabled", "git_url", "git_user", "git_token", - "git_auto_commit", "git_commit_msg_template", - "git_commit_name", "git_commit_email", - "last_opened_file_path"). - From("workspaces") + query := db.NewQuery() + query, err := query.SelectStruct(&models.Workspace{}, "workspaces") + if err != nil { + return nil, fmt.Errorf("failed to create query: %w", err) + } rows, err := db.Query(query.String(), query.Args()...) if err != nil { @@ -347,38 +265,9 @@ func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) { defer rows.Close() var workspaces []*models.Workspace - for rows.Next() { - workspace := &models.Workspace{} - var encryptedToken string - var lastOpenedFile sql.NullString - - err := rows.Scan( - &workspace.ID, &workspace.UserID, &workspace.Name, &workspace.CreatedAt, - &workspace.Theme, &workspace.AutoSave, &workspace.ShowHiddenFiles, - &workspace.GitEnabled, &workspace.GitURL, &workspace.GitUser, &encryptedToken, - &workspace.GitAutoCommit, &workspace.GitCommitMsgTemplate, - &workspace.GitCommitName, &workspace.GitCommitEmail, - &lastOpenedFile, - ) - if err != nil { - return nil, fmt.Errorf("failed to scan workspace row: %w", err) - } - - if lastOpenedFile.Valid { - workspace.LastOpenedFilePath = lastOpenedFile.String - } - - // Decrypt token - workspace.GitToken, err = db.decryptToken(encryptedToken) - if err != nil { - return nil, fmt.Errorf("failed to decrypt token: %w", err) - } - - workspaces = append(workspaces, workspace) - } - - if err = rows.Err(); err != nil { - return nil, fmt.Errorf("error iterating workspace rows: %w", err) + err = db.ScanStructs(rows, workspaces) + if err != nil { + return nil, fmt.Errorf("failed to scan workspaces: %w", err) } return workspaces, nil From 0f97927219b4fc4f23a955c9eaf04cdcf5ec2327 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Mon, 3 Mar 2025 22:04:38 +0100 Subject: [PATCH 22/39] Rework UpdateStruct --- server/internal/db/struct_query.go | 16 ++-------------- server/internal/db/workspaces.go | 13 ++++++------- 2 files changed, 8 insertions(+), 21 deletions(-) diff --git a/server/internal/db/struct_query.go b/server/internal/db/struct_query.go index ec86764..a4d8ff7 100644 --- a/server/internal/db/struct_query.go +++ b/server/internal/db/struct_query.go @@ -142,17 +142,13 @@ func (q *Query) InsertStruct(s any, table string) (*Query, error) { return q, nil } -func (q *Query) UpdateStruct(s any, table string, where []string, args []any) (*Query, error) { +func (q *Query) UpdateStruct(s any, table string) (*Query, error) { fields, err := StructTagsToFields(s) if err != nil { return nil, err } - if len(where) != len(args) { - return nil, fmt.Errorf("number of where clauses does not match number of arguments") - } - - q = q.Update("users") + q = q.Update(table) for _, f := range fields { value := f.Value @@ -172,13 +168,6 @@ func (q *Query) UpdateStruct(s any, table string, where []string, args []any) (* q = q.Set(f.Name).Placeholder(value) } - for i, w := range where { - if i != 0 && i < len(args) { - q = q.And(w) - } - q = q.Where(w).Placeholder(args[i]) - } - return q, nil } @@ -291,7 +280,6 @@ func (db *database) ScanStructs(rows *sql.Rows, destSlice any) error { if rows == nil { return fmt.Errorf("rows cannot be nil") } - defer rows.Close() // Get the slice value and element type sliceVal := reflect.ValueOf(destSlice) diff --git a/server/internal/db/workspaces.go b/server/internal/db/workspaces.go index 1783ad3..348e841 100644 --- a/server/internal/db/workspaces.go +++ b/server/internal/db/workspaces.go @@ -89,7 +89,8 @@ func (db *database) UpdateWorkspace(workspace *models.Workspace) error { query := db.NewQuery() query, err := query. - UpdateStruct(workspace, "workspaces", []string{"id =", "user_id ="}, []any{workspace.ID, workspace.UserID}) + UpdateStruct(workspace, "workspaces") + query = query.Where("id =").Placeholder(workspace.ID).And("user_id =").Placeholder(workspace.UserID) if err != nil { return fmt.Errorf("failed to create query: %w", err) @@ -120,7 +121,7 @@ func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, erro defer rows.Close() var workspaces []*models.Workspace - err = db.ScanStructs(rows, workspaces) + err = db.ScanStructs(rows, &workspaces) if err != nil { return nil, fmt.Errorf("failed to scan workspaces: %w", err) } @@ -131,12 +132,10 @@ func (db *database) GetWorkspacesByUserID(userID int) ([]*models.Workspace, erro // UpdateWorkspaceSettings updates only the settings portion of a workspace func (db *database) UpdateWorkspaceSettings(workspace *models.Workspace) error { - where := []string{"id ="} - args := []any{workspace.ID} - query := db.NewQuery() query, err := query. - UpdateStruct(workspace, "workspaces", where, args) + UpdateStruct(workspace, "workspaces") + query = query.Where("id =").Placeholder(workspace.ID) if err != nil { return fmt.Errorf("failed to create query: %w", err) @@ -265,7 +264,7 @@ func (db *database) GetAllWorkspaces() ([]*models.Workspace, error) { defer rows.Close() var workspaces []*models.Workspace - err = db.ScanStructs(rows, workspaces) + err = db.ScanStructs(rows, &workspaces) if err != nil { return nil, fmt.Errorf("failed to scan workspaces: %w", err) } From 976425d660d18eba7e669410cc8d82e6b18d0386 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Wed, 5 Mar 2025 21:07:05 +0100 Subject: [PATCH 23/39] Use ScanStruct in sessions --- server/internal/db/sessions.go | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/server/internal/db/sessions.go b/server/internal/db/sessions.go index fdfcc7d..4d69115 100644 --- a/server/internal/db/sessions.go +++ b/server/internal/db/sessions.go @@ -26,15 +26,18 @@ func (db *database) CreateSession(session *models.Session) error { // GetSessionByRefreshToken retrieves a session by its refresh token func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Session, error) { session := &models.Session{} - query := db.NewQuery(). - Select("id", "user_id", "refresh_token", "expires_at", "created_at"). - From("sessions"). - Where("refresh_token = "). + query := db.NewQuery() + query, err := query.SelectStruct(session, "sessions") + if err != nil { + return nil, fmt.Errorf("failed to create query: %w", err) + } + query = query.Where("refresh_token = "). Placeholder(refreshToken). And("expires_at >"). Placeholder(time.Now()) - err := db.QueryRow(query.String(), query.Args()...).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt) + row := db.QueryRow(query.String(), query.Args()...) + err = db.ScanStruct(row, session) if err == sql.ErrNoRows { return nil, fmt.Errorf("session not found or expired") } @@ -48,15 +51,18 @@ func (db *database) GetSessionByRefreshToken(refreshToken string) (*models.Sessi // GetSessionByID retrieves a session by its ID func (db *database) GetSessionByID(sessionID string) (*models.Session, error) { session := &models.Session{} - query := db.NewQuery(). - Select("id", "user_id", "refresh_token", "expires_at", "created_at"). - From("sessions"). - Where("id = "). + query := db.NewQuery() + query, err := query.SelectStruct(session, "sessions") + if err != nil { + return nil, fmt.Errorf("failed to create query: %w", err) + } + query = query.Where("id = "). Placeholder(sessionID). And("expires_at >"). Placeholder(time.Now()) - err := db.QueryRow(query.String(), query.Args()...).Scan(&session.ID, &session.UserID, &session.RefreshToken, &session.ExpiresAt, &session.CreatedAt) + row := db.QueryRow(query.String(), query.Args()...) + err = db.ScanStruct(row, session) if err == sql.ErrNoRows { return nil, fmt.Errorf("session not found") } From 3aa8c838e8d4fb4eea03ffff756611b91ff88faf Mon Sep 17 00:00:00 2001 From: LordMathis Date: Wed, 5 Mar 2025 21:20:57 +0100 Subject: [PATCH 24/39] Use struct queries in users --- server/internal/db/users.go | 72 +++++++++++++++++-------------------- 1 file changed, 33 insertions(+), 39 deletions(-) diff --git a/server/internal/db/users.go b/server/internal/db/users.go index 5d0cc55..f6b00e2 100644 --- a/server/internal/db/users.go +++ b/server/internal/db/users.go @@ -94,16 +94,16 @@ func (db *database) createWorkspaceTx(tx *sql.Tx, workspace *models.Workspace) e // GetUserByID retrieves a user by its ID func (db *database) GetUserByID(id int) (*models.User, error) { - query := db.NewQuery(). - Select("id", "email", "display_name", "password_hash", "role", "created_at", "last_workspace_id"). - From("users"). - Where("id = ").Placeholder(id) - user := &models.User{} - err := db.QueryRow(query.String(), query.Args()...). - Scan(&user.ID, &user.Email, &user.DisplayName, &user.PasswordHash, - &user.Role, &user.CreatedAt, &user.LastWorkspaceID) + query := db.NewQuery() + query, err := query.SelectStruct(user, "users") + if err != nil { + return nil, fmt.Errorf("failed to create query: %w", err) + } + query = query.Where("id = ").Placeholder(id) + row := db.QueryRow(query.String(), query.Args()...) + err = db.ScanStruct(row, user) if err == sql.ErrNoRows { return nil, fmt.Errorf("user not found") } @@ -115,15 +115,16 @@ func (db *database) GetUserByID(id int) (*models.User, error) { // GetUserByEmail retrieves a user by its email func (db *database) GetUserByEmail(email string) (*models.User, error) { - query := db.NewQuery(). - Select("id", "email", "display_name", "password_hash", "role", "created_at", "last_workspace_id"). - From("users"). - Where("email = ").Placeholder(email) - user := &models.User{} - err := db.QueryRow(query.String(), query.Args()...). - Scan(&user.ID, &user.Email, &user.DisplayName, &user.PasswordHash, - &user.Role, &user.CreatedAt, &user.LastWorkspaceID) + query := db.NewQuery() + query, err := query.SelectStruct(user, "users") + if err != nil { + return nil, fmt.Errorf("failed to create query: %w", err) + } + + query = query.Where("email = ").Placeholder(email) + row := db.QueryRow(query.String(), query.Args()...) + err = db.ScanStruct(row, user) if err == sql.ErrNoRows { return nil, fmt.Errorf("user not found") @@ -137,14 +138,12 @@ func (db *database) GetUserByEmail(email string) (*models.User, error) { // UpdateUser updates an existing user record in the database func (db *database) UpdateUser(user *models.User) error { - query := db.NewQuery(). - Update("users"). - Set("email").Placeholder(user.Email). - Set("display_name").Placeholder(user.DisplayName). - Set("password_hash").Placeholder(user.PasswordHash). - Set("role").Placeholder(user.Role). - Set("last_workspace_id").Placeholder(user.LastWorkspaceID). - Where("id = ").Placeholder(user.ID) + query := db.NewQuery() + query, err := query.UpdateStruct(user, "users") + if err != nil { + return fmt.Errorf("failed to create query: %w", err) + } + query = query.Where("id = ").Placeholder(user.ID) result, err := db.Exec(query.String(), query.Args()...) if err != nil { @@ -165,10 +164,12 @@ func (db *database) UpdateUser(user *models.User) error { // GetAllUsers retrieves all users from the database func (db *database) GetAllUsers() ([]*models.User, error) { - query := db.NewQuery(). - Select("id", "email", "display_name", "role", "created_at", "last_workspace_id"). - From("users"). - OrderBy("id ASC") + query := db.NewQuery() + query, err := query.SelectStruct(&models.User{}, "users") + if err != nil { + return nil, fmt.Errorf("failed to create query: %w", err) + } + query = query.OrderBy("id ASC") rows, err := db.Query(query.String(), query.Args()...) if err != nil { @@ -176,17 +177,10 @@ func (db *database) GetAllUsers() ([]*models.User, error) { } defer rows.Close() - var users []*models.User - for rows.Next() { - user := &models.User{} - err := rows.Scan( - &user.ID, &user.Email, &user.DisplayName, &user.Role, - &user.CreatedAt, &user.LastWorkspaceID, - ) - if err != nil { - return nil, fmt.Errorf("failed to scan user row: %w", err) - } - users = append(users, user) + users := []*models.User{} + err = db.ScanStructs(rows, &users) + if err != nil { + return nil, fmt.Errorf("failed to scan users: %w", err) } return users, nil From 52aa406c6d5ee13361d06c72b9c3b0e59e77ded1 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Wed, 5 Mar 2025 21:30:40 +0100 Subject: [PATCH 25/39] Add docs comments to struct query --- server/internal/db/struct_query.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/server/internal/db/struct_query.go b/server/internal/db/struct_query.go index a4d8ff7..d3d2ddd 100644 --- a/server/internal/db/struct_query.go +++ b/server/internal/db/struct_query.go @@ -18,6 +18,7 @@ type DBField struct { encrypted bool } +// StructTagsToFields converts a struct to a slice of DBField instances func StructTagsToFields(s any) ([]DBField, error) { v := reflect.ValueOf(s) @@ -106,6 +107,7 @@ func toSnakeCase(s string) string { return res } +// InsertStruct creates an INSERT query from a struct func (q *Query) InsertStruct(s any, table string) (*Query, error) { fields, err := StructTagsToFields(s) if err != nil { @@ -142,6 +144,7 @@ func (q *Query) InsertStruct(s any, table string) (*Query, error) { return q, nil } +// UpdateStruct creates an UPDATE query from a struct func (q *Query) UpdateStruct(s any, table string) (*Query, error) { fields, err := StructTagsToFields(s) if err != nil { @@ -171,6 +174,7 @@ func (q *Query) UpdateStruct(s any, table string) (*Query, error) { return q, nil } +// SelectStruct creates a SELECT query from a struct func (q *Query) SelectStruct(s any, table string) (*Query, error) { fields, err := StructTagsToFields(s) if err != nil { @@ -188,7 +192,7 @@ func (q *Query) SelectStruct(s any, table string) (*Query, error) { // Scanner is an interface that both sql.Row and sql.Rows satisfy type Scanner interface { - Scan(dest ...interface{}) error + Scan(dest ...any) error } // scanStructInstance is an internal function that handles the scanning logic for a single instance @@ -198,7 +202,7 @@ func (db *database) scanStructInstance(destVal reflect.Value, scanner Scanner) e return fmt.Errorf("failed to extract struct fields: %w", err) } - scanDest := make([]interface{}, len(fields)) + scanDest := make([]any, len(fields)) var fieldsToDecrypt []string nullStringIndexes := make(map[int]reflect.Value) From 904d4ce106b8b9ab5080b8467156ce2732ed709a Mon Sep 17 00:00:00 2001 From: LordMathis Date: Wed, 5 Mar 2025 21:31:58 +0100 Subject: [PATCH 26/39] Update documentation --- server/documentation.md | 195 ++++++++++++++++++++++++++++++++-------- 1 file changed, 160 insertions(+), 35 deletions(-) diff --git a/server/documentation.md b/server/documentation.md index 21729a6..c6451df 100644 --- a/server/documentation.md +++ b/server/documentation.md @@ -57,14 +57,20 @@ package app // import "lemma/internal/app" Package app provides application-level functionality for initializing and running the server +FUNCTIONS + +func ParseDBURL(dbURL string) (db.DBType, string, error) + ParseDBURL parses a database URL and returns the driver name and data source + + TYPES type Config struct { - DBPath string + DBURL string + DBType db.DBType WorkDir string StaticPath string Port string - RootURL string Domain string CORSOrigins []string AdminEmail string @@ -281,6 +287,24 @@ const ( TYPES +type DBField struct { + Name string + Value any + Type reflect.Type + OriginalName string + + // Has unexported fields. +} + +func StructTagsToFields(s any) ([]DBField, error) + StructTagsToFields converts a struct to a slice of DBField instances + +type DBType string + +const ( + DBTypeSQLite DBType = "sqlite3" + DBTypePostgres DBType = "postgres" +) type Database interface { UserStore WorkspaceStore @@ -292,14 +316,115 @@ type Database interface { } Database defines the methods for interacting with the database -func Init(dbPath string, secretsService secrets.Service) (Database, error) +func Init(dbType DBType, dbURL string, secretsService secrets.Service) (Database, error) Init initializes the database connection -type Migration struct { - Version int - SQL string +type JoinType string + +const ( + InnerJoin JoinType = "INNER JOIN" + LeftJoin JoinType = "LEFT JOIN" + RightJoin JoinType = "RIGHT JOIN" +) +type Query struct { + // Has unexported fields. } - Migration represents a database migration + Query represents a SQL query with its parameters + +func NewQuery(dbType DBType, secretsService secrets.Service) *Query + NewQuery creates a new Query instance + +func (q *Query) AddArgs(args ...any) *Query + AddArgs adds arguments to the query + +func (q *Query) And(condition string) *Query + And adds an AND condition + +func (q *Query) Args() []any + Args returns the query arguments + +func (q *Query) Delete() *Query + Delete starts a DELETE statement + +func (q *Query) EndGroup() *Query + EndGroup ends a parenthetical group + +func (q *Query) From(table string) *Query + From adds a FROM clause + +func (q *Query) GroupBy(columns ...string) *Query + GroupBy adds a GROUP BY clause + +func (q *Query) Having(condition string) *Query + Having adds a HAVING clause for filtering groups + +func (q *Query) Insert(table string, columns ...string) *Query + Insert starts an INSERT statement + +func (q *Query) InsertStruct(s any, table string) (*Query, error) + InsertStruct creates an INSERT query from a struct + +func (q *Query) Join(joinType JoinType, table, condition string) *Query + Join adds a JOIN clause + +func (q *Query) Limit(limit int) *Query + Limit adds a LIMIT clause + +func (q *Query) Offset(offset int) *Query + Offset adds an OFFSET clause + +func (q *Query) Or(condition string) *Query + Or adds an OR condition + +func (q *Query) OrderBy(columns ...string) *Query + OrderBy adds an ORDER BY clause + +func (q *Query) Placeholder(arg any) *Query + Placeholder adds a placeholder for a single argument + +func (q *Query) Placeholders(n int) *Query + Placeholders adds n placeholders separated by commas + +func (q *Query) Returning(columns ...string) *Query + Returning adds a RETURNING clause for both PostgreSQL and SQLite (3.35.0+) + +func (q *Query) Select(columns ...string) *Query + Select adds a SELECT clause + +func (q *Query) SelectStruct(s any, table string) (*Query, error) + SelectStruct creates a SELECT query from a struct + +func (q *Query) Set(column string) *Query + Set adds a SET clause for updates + +func (q *Query) StartGroup() *Query + StartGroup starts a parenthetical group + +func (q *Query) String() string + String returns the formatted query string + +func (q *Query) Update(table string) *Query + Update starts an UPDATE statement + +func (q *Query) UpdateStruct(s any, table string) (*Query, error) + UpdateStruct creates an UPDATE query from a struct + +func (q *Query) Values(count int) *Query + Values adds a VALUES clause + +func (q *Query) Where(condition string) *Query + Where adds a WHERE clause + +func (q *Query) WhereIn(column string, count int) *Query + WhereIn adds a WHERE IN clause + +func (q *Query) Write(s string) *Query + Write adds a string to the query + +type Scanner interface { + Scan(dest ...any) error +} + Scanner is an interface that both sql.Row and sql.Rows satisfy type SessionStore interface { CreateSession(session *models.Session) error @@ -897,22 +1022,22 @@ and serialize data in the application. TYPES type Session struct { - ID string // Unique session identifier - UserID int // ID of the user this session belongs to - RefreshToken string // The refresh token associated with this session - ExpiresAt time.Time // When this session expires - CreatedAt time.Time // When this session was created + ID string `db:"id"` // Unique session identifier + UserID int `db:"user_id"` // ID of the user this session belongs to + RefreshToken string `db:"refresh_token"` // The refresh token associated with this session + ExpiresAt time.Time `db:"expires_at"` // When this session expires + CreatedAt time.Time `db:"created_at,default"` // When this session was created } Session represents a user session in the database type User struct { - ID int `json:"id" validate:"required,min=1"` - Email string `json:"email" validate:"required,email"` - DisplayName string `json:"displayName"` - PasswordHash string `json:"-"` - Role UserRole `json:"role" validate:"required,oneof=admin editor viewer"` - CreatedAt time.Time `json:"createdAt"` - LastWorkspaceID int `json:"lastWorkspaceId"` + ID int `json:"id" db:"id,default" validate:"required,min=1"` + Email string `json:"email" db:"email" validate:"required,email"` + DisplayName string `json:"displayName" db:"display_name"` + PasswordHash string `json:"-" db:"password_hash"` + Role UserRole `json:"role" db:"role" validate:"required,oneof=admin editor viewer"` + CreatedAt time.Time `json:"createdAt" db:"created_at,default"` + LastWorkspaceID int `json:"lastWorkspaceId" db:"last_workspace_id"` } User represents a user in the system @@ -930,24 +1055,24 @@ const ( User roles type Workspace struct { - ID int `json:"id" validate:"required,min=1"` - UserID int `json:"userId" validate:"required,min=1"` - Name string `json:"name" validate:"required"` - CreatedAt time.Time `json:"createdAt"` - LastOpenedFilePath string `json:"lastOpenedFilePath"` + ID int `json:"id" db:"id,default" validate:"required,min=1"` + UserID int `json:"userId" db:"user_id" validate:"required,min=1"` + Name string `json:"name" db:"name" validate:"required"` + CreatedAt time.Time `json:"createdAt" db:"created_at,default"` + LastOpenedFilePath string `json:"lastOpenedFilePath" db:"last_opened_file_path"` // Integrated settings - Theme string `json:"theme" validate:"oneof=light dark"` - AutoSave bool `json:"autoSave"` - ShowHiddenFiles bool `json:"showHiddenFiles"` - GitEnabled bool `json:"gitEnabled"` - GitURL string `json:"gitUrl" validate:"required_if=GitEnabled true"` - GitUser string `json:"gitUser" validate:"required_if=GitEnabled true"` - GitToken string `json:"gitToken" validate:"required_if=GitEnabled true"` - GitAutoCommit bool `json:"gitAutoCommit"` - GitCommitMsgTemplate string `json:"gitCommitMsgTemplate"` - GitCommitName string `json:"gitCommitName"` - GitCommitEmail string `json:"gitCommitEmail" validate:"omitempty,required_if=GitEnabled true,email"` + Theme string `json:"theme" db:"theme" validate:"oneof=light dark"` + AutoSave bool `json:"autoSave" db:"auto_save"` + ShowHiddenFiles bool `json:"showHiddenFiles" db:"show_hidden_files"` + GitEnabled bool `json:"gitEnabled" db:"git_enabled"` + GitURL string `json:"gitUrl" db:"git_url,ommitempty" validate:"required_if=GitEnabled true"` + GitUser string `json:"gitUser" db:"git_user,ommitempty" validate:"required_if=GitEnabled true"` + GitToken string `json:"gitToken" db:"git_token,ommitempty,encrypted" validate:"required_if=GitEnabled true"` + GitAutoCommit bool `json:"gitAutoCommit" db:"git_auto_commit"` + GitCommitMsgTemplate string `json:"gitCommitMsgTemplate" db:"git_commit_msg_template"` + GitCommitName string `json:"gitCommitName" db:"git_commit_name"` + GitCommitEmail string `json:"gitCommitEmail" db:"git_commit_email" validate:"omitempty,required_if=GitEnabled true,email"` } Workspace represents a user's workspace in the system From 7e9aab01cb3135dcffcc54568a7744f95f8794dd Mon Sep 17 00:00:00 2001 From: LordMathis Date: Wed, 5 Mar 2025 22:03:32 +0100 Subject: [PATCH 27/39] Fix ommit empty db tag --- server/internal/db/struct_query.go | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/server/internal/db/struct_query.go b/server/internal/db/struct_query.go index d3d2ddd..9230fdb 100644 --- a/server/internal/db/struct_query.go +++ b/server/internal/db/struct_query.go @@ -55,6 +55,7 @@ func StructTagsToFields(s any) ([]DBField, error) { useDefault := false encrypted := false + ommit := false if strings.Contains(tag, ",") { parts := strings.Split(tag, ",") @@ -64,7 +65,7 @@ func StructTagsToFields(s any) ([]DBField, error) { switch opt { case "omitempty": if reflect.DeepEqual(v.Field(i).Interface(), reflect.Zero(f.Type).Interface()) { - continue + ommit = true } case "default": useDefault = true @@ -74,6 +75,10 @@ func StructTagsToFields(s any) ([]DBField, error) { } } + if ommit { + continue + } + fields = append(fields, DBField{ Name: tag, Value: v.Field(i).Interface(), From 4766a166df9880b69620c31f95e0abb576257400 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Wed, 5 Mar 2025 22:36:27 +0100 Subject: [PATCH 28/39] Add struct_query tests --- server/internal/db/db.go | 7 + server/internal/db/struct_query_test.go | 507 ++++++++++++++++++++++++ 2 files changed, 514 insertions(+) create mode 100644 server/internal/db/struct_query_test.go diff --git a/server/internal/db/db.go b/server/internal/db/db.go index e171158..e4e873b 100644 --- a/server/internal/db/db.go +++ b/server/internal/db/db.go @@ -76,12 +76,18 @@ type SystemStore interface { SetSystemSetting(key, value string) error } +type StructScanner interface { + ScanStruct(row *sql.Row, dest interface{}) error + ScanStructs(rows *sql.Rows, dest interface{}) error +} + // Database defines the methods for interacting with the database type Database interface { UserStore WorkspaceStore SessionStore SystemStore + StructScanner Begin() (*sql.Tx, error) Close() error Migrate() error @@ -101,6 +107,7 @@ var ( // Sub-interfaces _ WorkspaceReader = (*database)(nil) _ WorkspaceWriter = (*database)(nil) + _ StructScanner = (*database)(nil) ) var logger logging.Logger diff --git a/server/internal/db/struct_query_test.go b/server/internal/db/struct_query_test.go new file mode 100644 index 0000000..7c28799 --- /dev/null +++ b/server/internal/db/struct_query_test.go @@ -0,0 +1,507 @@ +package db_test + +import ( + "reflect" + "testing" + "time" + + "lemma/internal/db" + "lemma/internal/models" + _ "lemma/internal/testenv" +) + +// TestStructTagsToFields tests the exported StructTagsToFields function +func TestStructTagsToFields(t *testing.T) { + type testStruct struct { + ID int `db:"id"` + Name string `db:"custom_name"` + CreatedAt time.Time `db:"created_at,default"` + Skip string `db:"-"` + Empty string `db:"empty,omitempty"` + Secret string `db:"secret,encrypted"` + NoTag string + } + + tests := []struct { + name string + input interface{} + wantFields int + wantErr bool + }{ + { + name: "valid struct", + input: testStruct{ + ID: 1, + Name: "Test", + CreatedAt: time.Now(), + Skip: "skip me", + Secret: "secret value", + NoTag: "no tag", + }, + wantFields: 5, // ID, Name, CreatedAt, Secret, NoTag (Empty is omitted) + wantErr: false, + }, + { + name: "nil pointer", + input: (*testStruct)(nil), + wantFields: 0, + wantErr: true, + }, + { + name: "non-struct", + input: "not a struct", + wantFields: 0, + wantErr: true, + }, + { + name: "struct pointer", + input: &testStruct{ + ID: 2, + Name: "Test Pointer", + }, + wantFields: 5, // Same fields as above + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fields, err := db.StructTagsToFields(tt.input) + + if tt.wantErr { + if err == nil { + t.Error("Expected error, got nil") + } + return + } + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if len(fields) != tt.wantFields { + t.Errorf("Expected %d fields, got %d", tt.wantFields, len(fields)) + } + + // Check specific field handling for valid struct test + if tt.name == "valid struct" { + // Find fields by name + var idField, nameField, createdAtField, secretField, emptyField, noTagField *db.DBField + for i := range fields { + f := &fields[i] + switch f.Name { + case "id": + idField = f + case "custom_name": + nameField = f + case "created_at": + createdAtField = f + case "secret": + secretField = f + case "empty": + emptyField = f + case "no_tag": + noTagField = f + } + } + + // Check fields exist + if idField == nil { + t.Error("ID field not found") + } + if nameField == nil { + t.Error("Name field not found") + } + if createdAtField == nil { + t.Error("CreatedAt field not found") + } + if secretField == nil { + t.Error("Secret field not found") + } + if noTagField == nil { + t.Error("NoTag field not found") + } + if emptyField != nil { + t.Error("Empty field should be omitted") + } + + // Check original names + if idField != nil && idField.OriginalName != "ID" { + t.Errorf("Expected OriginalName 'ID', got '%s'", idField.OriginalName) + } + if nameField != nil && nameField.OriginalName != "Name" { + t.Errorf("Expected OriginalName 'Name', got '%s'", nameField.OriginalName) + } + } + }) + } +} + +// TestStructQueries tests the struct-based query methods using the test database +func TestStructQueries(t *testing.T) { + // Setup test database + database, err := db.NewTestDB(&mockSecrets{}) + if err != nil { + t.Fatalf("Failed to create test database: %v", err) + } + defer database.Close() + + if err := database.Migrate(); err != nil { + t.Fatalf("Failed to run migrations: %v", err) + } + + // Define test data + user := &models.User{ + Email: "structquery@example.com", + DisplayName: "Struct Query Test", + PasswordHash: "hashed_password", + Role: models.RoleEditor, + } + + t.Run("InsertStructQuery", func(t *testing.T) { + // Insert user with struct query + createdUser, err := database.CreateUser(user) + if err != nil { + t.Fatalf("Failed to create user with struct query: %v", err) + } + + // Verify user was created with proper values + if createdUser.ID == 0 { + t.Error("Expected non-zero user ID") + } + + if createdUser.Email != user.Email { + t.Errorf("Email = %v, want %v", createdUser.Email, user.Email) + } + + if createdUser.DisplayName != user.DisplayName { + t.Errorf("DisplayName = %v, want %v", createdUser.DisplayName, user.DisplayName) + } + + if createdUser.Role != user.Role { + t.Errorf("Role = %v, want %v", createdUser.Role, user.Role) + } + + // We will use this user for the next test cases + user = createdUser + }) + + t.Run("SelectStructQuery", func(t *testing.T) { + // Get the created user + fetchedUser, err := database.GetUserByID(user.ID) + if err != nil { + t.Fatalf("Failed to get user with struct query: %v", err) + } + + // Verify fetched user matches the original + if fetchedUser.ID != user.ID { + t.Errorf("ID = %v, want %v", fetchedUser.ID, user.ID) + } + + if fetchedUser.Email != user.Email { + t.Errorf("Email = %v, want %v", fetchedUser.Email, user.Email) + } + + if fetchedUser.DisplayName != user.DisplayName { + t.Errorf("DisplayName = %v, want %v", fetchedUser.DisplayName, user.DisplayName) + } + + if fetchedUser.Role != user.Role { + t.Errorf("Role = %v, want %v", fetchedUser.Role, user.Role) + } + }) + + t.Run("UpdateStructQuery", func(t *testing.T) { + // Update the user + user.DisplayName = "Updated Display Name" + user.Role = models.RoleAdmin + + err := database.UpdateUser(user) + if err != nil { + t.Fatalf("Failed to update user with struct query: %v", err) + } + + // Verify update worked + updatedUser, err := database.GetUserByID(user.ID) + if err != nil { + t.Fatalf("Failed to get updated user: %v", err) + } + + if updatedUser.DisplayName != "Updated Display Name" { + t.Errorf("DisplayName = %v, want %v", updatedUser.DisplayName, "Updated Display Name") + } + + if updatedUser.Role != models.RoleAdmin { + t.Errorf("Role = %v, want %v", updatedUser.Role, models.RoleAdmin) + } + }) + + t.Run("ScanStructs", func(t *testing.T) { + // Create another user to test multiple rows + secondUser := &models.User{ + Email: "structquery2@example.com", + DisplayName: "Struct Query Test 2", + PasswordHash: "hashed_password2", + Role: models.RoleViewer, + } + + createdUser2, err := database.CreateUser(secondUser) + if err != nil { + t.Fatalf("Failed to create second user: %v", err) + } + + // Get all users + users, err := database.GetAllUsers() + if err != nil { + t.Fatalf("Failed to get all users: %v", err) + } + + // Verify we have at least the two users we created + if len(users) < 2 { + t.Errorf("Expected at least 2 users, got %d", len(users)) + } + + // Check if both our users are in the result + foundUser1 := false + foundUser2 := false + + for _, u := range users { + if u.ID == user.ID { + foundUser1 = true + if u.DisplayName != user.DisplayName { + t.Errorf("DisplayName = %v, want %v", u.DisplayName, user.DisplayName) + } + } + if u.ID == createdUser2.ID { + foundUser2 = true + if u.DisplayName != secondUser.DisplayName { + t.Errorf("DisplayName = %v, want %v", u.DisplayName, secondUser.DisplayName) + } + } + } + + if !foundUser1 { + t.Errorf("First user (ID: %d) not found in results", user.ID) + } + if !foundUser2 { + t.Errorf("Second user (ID: %d) not found in results", createdUser2.ID) + } + }) + + t.Run("ScanStruct with null values", func(t *testing.T) { + // Test handling of NULL values by creating a workspace with null values + workspace := &models.Workspace{ + UserID: user.ID, + Name: "Null Test Workspace", + // Leave all optional fields as zero values + } + workspace.SetDefaultSettings() // This will set default values + + err := database.CreateWorkspace(workspace) + if err != nil { + t.Fatalf("Failed to create test workspace: %v", err) + } + + // Clear the GitToken to test NULL handling + testDB := database.TestDB() + _, err = testDB.Exec("UPDATE workspaces SET git_token = NULL WHERE id = ?", workspace.ID) + if err != nil { + t.Fatalf("Failed to set git_token to NULL: %v", err) + } + + // Fetch the workspace with NULL field + fetchedWorkspace, err := database.GetWorkspaceByID(workspace.ID) + if err != nil { + t.Fatalf("Failed to get workspace with NULL field: %v", err) + } + + // Verify the NULL field is empty + if fetchedWorkspace.GitToken != "" { + t.Errorf("Expected empty GitToken, got '%s'", fetchedWorkspace.GitToken) + } + }) + + t.Run("ScanStructErrors", func(t *testing.T) { + // Test error handling in ScanStruct + testDB := database.TestDB() + + // Attempt to scan too many columns into a struct with fewer fields + row := testDB.QueryRow("SELECT 1, 2, 3") + var singleField struct { + One int `db:"one"` + } + + err := database.ScanStruct(row, &singleField) + if err == nil { + t.Error("Expected error when scanning too many columns, got nil") + } + + // Test scanning into a non-struct + var notAStruct int + row = testDB.QueryRow("SELECT 1") + err = database.ScanStruct(row, ¬AStruct) + if err == nil { + t.Error("Expected error when scanning into non-struct, got nil") + } + + // Test scanning into nil + var nilPtr *struct{} + row = testDB.QueryRow("SELECT 1") + err = database.ScanStruct(row, nilPtr) + if err == nil { + t.Error("Expected error when scanning into nil pointer, got nil") + } + }) +} + +// TestScanStructsErrors tests error handling for ScanStructs +func TestScanStructsErrors(t *testing.T) { + database, err := db.NewTestDB(&mockSecrets{}) + if err != nil { + t.Fatalf("Failed to create test database: %v", err) + } + defer database.Close() + + if err := database.Migrate(); err != nil { + t.Fatalf("Failed to run migrations: %v", err) + } + + testDB := database.TestDB() + + t.Run("ScanStructsWithNilRows", func(t *testing.T) { + var users []*models.User + err := database.ScanStructs(nil, &users) + if err == nil { + t.Error("Expected error with nil rows, got nil") + } + }) + + t.Run("ScanStructsWithNilDest", func(t *testing.T) { + rows, err := testDB.Query("SELECT 1") + if err != nil { + t.Fatalf("Failed to execute query: %v", err) + } + defer rows.Close() + + var nilSlice *[]*models.User + err = database.ScanStructs(rows, nilSlice) + if err == nil { + t.Error("Expected error with nil destination, got nil") + } + }) + + t.Run("ScanStructsWithNonSlice", func(t *testing.T) { + rows, err := testDB.Query("SELECT 1") + if err != nil { + t.Fatalf("Failed to execute query: %v", err) + } + defer rows.Close() + + var nonSlice int + err = database.ScanStructs(rows, &nonSlice) + if err == nil { + t.Error("Expected error with non-slice destination, got nil") + } + }) + + t.Run("ScanStructsWithNonStructSlice", func(t *testing.T) { + rows, err := testDB.Query("SELECT 1") + if err != nil { + t.Fatalf("Failed to execute query: %v", err) + } + defer rows.Close() + + var intSlice []int + err = database.ScanStructs(rows, &intSlice) + if err == nil { + t.Error("Expected error with non-struct slice, got nil") + } + }) +} + +// TestEncryptedFields tests handling of encrypted fields +func TestEncryptedFields(t *testing.T) { + database, err := db.NewTestDB(&mockSecrets{}) + if err != nil { + t.Fatalf("Failed to create test database: %v", err) + } + defer database.Close() + + if err := database.Migrate(); err != nil { + t.Fatalf("Failed to run migrations: %v", err) + } + + // Create user with workspace that has encrypted token + user, err := database.CreateUser(&models.User{ + Email: "encrypted@example.com", + DisplayName: "Encryption Test", + PasswordHash: "hash", + Role: models.RoleEditor, + }) + if err != nil { + t.Fatalf("Failed to create test user: %v", err) + } + + // Create workspace with encrypted field + workspace := &models.Workspace{ + UserID: user.ID, + Name: "Encryption Test", + Theme: "dark", + GitEnabled: true, + GitURL: "https://github.com/user/repo", + GitUser: "username", + GitToken: "secret-token", // This field is encrypted + GitCommitName: "Test User", + GitCommitEmail: "test@example.com", + } + + if err := database.CreateWorkspace(workspace); err != nil { + t.Fatalf("Failed to create test workspace: %v", err) + } + + // Verify our mock secrets service passed the token through unmodified + // In a real application, the token would be encrypted in the database + testDB := database.TestDB() + var rawToken string + err = testDB.QueryRow("SELECT git_token FROM workspaces WHERE id = ?", workspace.ID).Scan(&rawToken) + if err != nil { + t.Fatalf("Failed to query raw token: %v", err) + } + + // With the mock secrets service, encryption is a no-op so the token is stored as-is + if rawToken != "secret-token" { + t.Errorf("Expected raw token 'secret-token', got '%s'", rawToken) + } + + // Verify the fetched workspace has the correct token + fetchedWorkspace, err := database.GetWorkspaceByID(workspace.ID) + if err != nil { + t.Fatalf("Failed to get workspace: %v", err) + } + + if fetchedWorkspace.GitToken != "secret-token" { + t.Errorf("Expected GitToken 'secret-token', got '%s'", fetchedWorkspace.GitToken) + } +} + +// Helper function to compare slices of DBFields +func compareDBFields(t *testing.T, got, want []db.DBField) { + t.Helper() + + if len(got) != len(want) { + t.Errorf("Got %d fields, want %d", len(got), len(want)) + return + } + + for i := range got { + if got[i].Name != want[i].Name { + t.Errorf("Field %d name: got %s, want %s", i, got[i].Name, want[i].Name) + } + if got[i].OriginalName != want[i].OriginalName { + t.Errorf("Field %d original name: got %s, want %s", i, got[i].OriginalName, want[i].OriginalName) + } + if !reflect.DeepEqual(got[i].Value, want[i].Value) { + t.Errorf("Field %d value: got %v, want %v", i, got[i].Value, want[i].Value) + } + } +} From 629baa9952cf867bbc99436d30c74c283488c0a4 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Thu, 6 Mar 2025 19:23:24 +0100 Subject: [PATCH 29/39] Add test postgres db connection --- server/internal/db/migrations_test.go | 2 +- server/internal/db/sessions_test.go | 2 +- server/internal/db/struct_query_test.go | 6 +- server/internal/db/system_test.go | 2 +- server/internal/db/testdb.go | 72 ++++++++++++++++++-- server/internal/db/users_test.go | 2 +- server/internal/db/workspaces_test.go | 2 +- server/internal/handlers/integration_test.go | 2 +- 8 files changed, 77 insertions(+), 13 deletions(-) diff --git a/server/internal/db/migrations_test.go b/server/internal/db/migrations_test.go index 80c9661..4998db2 100644 --- a/server/internal/db/migrations_test.go +++ b/server/internal/db/migrations_test.go @@ -9,7 +9,7 @@ import ( ) func TestMigrate(t *testing.T) { - database, err := db.NewTestDB(&mockSecrets{}) + database, err := db.NewTestSQLiteDB(&mockSecrets{}) if err != nil { t.Fatalf("failed to initialize database: %v", err) } diff --git a/server/internal/db/sessions_test.go b/server/internal/db/sessions_test.go index 67b2526..208d1ae 100644 --- a/server/internal/db/sessions_test.go +++ b/server/internal/db/sessions_test.go @@ -13,7 +13,7 @@ import ( ) func TestSessionOperations(t *testing.T) { - database, err := db.NewTestDB(&mockSecrets{}) + database, err := db.NewTestSQLiteDB(&mockSecrets{}) if err != nil { t.Fatalf("failed to create test database: %v", err) } diff --git a/server/internal/db/struct_query_test.go b/server/internal/db/struct_query_test.go index 7c28799..71e2c64 100644 --- a/server/internal/db/struct_query_test.go +++ b/server/internal/db/struct_query_test.go @@ -140,7 +140,7 @@ func TestStructTagsToFields(t *testing.T) { // TestStructQueries tests the struct-based query methods using the test database func TestStructQueries(t *testing.T) { // Setup test database - database, err := db.NewTestDB(&mockSecrets{}) + database, err := db.NewTestSQLiteDB(&mockSecrets{}) if err != nil { t.Fatalf("Failed to create test database: %v", err) } @@ -356,7 +356,7 @@ func TestStructQueries(t *testing.T) { // TestScanStructsErrors tests error handling for ScanStructs func TestScanStructsErrors(t *testing.T) { - database, err := db.NewTestDB(&mockSecrets{}) + database, err := db.NewTestSQLiteDB(&mockSecrets{}) if err != nil { t.Fatalf("Failed to create test database: %v", err) } @@ -421,7 +421,7 @@ func TestScanStructsErrors(t *testing.T) { // TestEncryptedFields tests handling of encrypted fields func TestEncryptedFields(t *testing.T) { - database, err := db.NewTestDB(&mockSecrets{}) + database, err := db.NewTestSQLiteDB(&mockSecrets{}) if err != nil { t.Fatalf("Failed to create test database: %v", err) } diff --git a/server/internal/db/system_test.go b/server/internal/db/system_test.go index 8e65023..9c54dd9 100644 --- a/server/internal/db/system_test.go +++ b/server/internal/db/system_test.go @@ -15,7 +15,7 @@ import ( ) func TestSystemOperations(t *testing.T) { - database, err := db.NewTestDB(&mockSecrets{}) + database, err := db.NewTestSQLiteDB(&mockSecrets{}) if err != nil { t.Fatalf("failed to create test database: %v", err) } diff --git a/server/internal/db/testdb.go b/server/internal/db/testdb.go index 14e13e1..0cdd2b8 100644 --- a/server/internal/db/testdb.go +++ b/server/internal/db/testdb.go @@ -4,7 +4,10 @@ package db import ( "database/sql" + "fmt" "lemma/internal/secrets" + "log" + "time" ) type TestDatabase interface { @@ -12,19 +15,80 @@ type TestDatabase interface { TestDB() *sql.DB } -func NewTestDB(secretsService secrets.Service) (TestDatabase, error) { +func NewTestSQLiteDB(secretsService secrets.Service) (TestDatabase, error) { db, err := Init(DBTypeSQLite, ":memory:", secretsService) if err != nil { return nil, err } - return &testDatabase{db.(*database)}, nil + return &testSQLiteDatabase{db.(*database)}, nil } -type testDatabase struct { +type testSQLiteDatabase struct { *database } -func (td *testDatabase) TestDB() *sql.DB { +func (td *testSQLiteDatabase) TestDB() *sql.DB { return td.DB } + +// NewPostgresTestDB creates a test database using PostgreSQL +func NewPostgresTestDB(dbURL string, secretsSvc secrets.Service) (TestDatabase, error) { + if dbURL == "" { + return nil, fmt.Errorf("postgres URL cannot be empty") + } + + db, err := sql.Open("postgres", dbURL) + if err != nil { + return nil, fmt.Errorf("failed to open postgres database: %w", err) + } + + if err := db.Ping(); err != nil { + db.Close() + return nil, fmt.Errorf("failed to ping postgres database: %w", err) + } + + // Create a unique schema name for this test run to avoid conflicts + schemaName := fmt.Sprintf("lemma_test_%d", time.Now().UnixNano()) + _, err = db.Exec(fmt.Sprintf("CREATE SCHEMA %s", schemaName)) + if err != nil { + db.Close() + return nil, fmt.Errorf("failed to create schema: %w", err) + } + + // Set search path to use our schema + _, err = db.Exec(fmt.Sprintf("SET search_path TO %s", schemaName)) + if err != nil { + db.Close() + return nil, fmt.Errorf("failed to set search path: %w", err) + } + + // Create database instance + database := &postgresTestDatabase{ + database: &database{DB: db, secretsService: secretsSvc, dbType: DBTypePostgres}, + schemaName: schemaName, + } + + return database, nil +} + +// postgresTestDatabase extends the regular postgres database to add test-specific cleanup +type postgresTestDatabase struct { + *database + schemaName string +} + +// Close closes the database connection and drops the test schema +func (db *postgresTestDatabase) Close() error { + _, err := db.TestDB().Exec(fmt.Sprintf("DROP SCHEMA %s CASCADE", db.schemaName)) + if err != nil { + log.Printf("Failed to drop schema %s: %v", db.schemaName, err) + } + + return db.TestDB().Close() +} + +// TestDB returns the underlying *sql.DB instance +func (db *postgresTestDatabase) TestDB() *sql.DB { + return db.DB +} diff --git a/server/internal/db/users_test.go b/server/internal/db/users_test.go index 5709ab4..8dfe4f0 100644 --- a/server/internal/db/users_test.go +++ b/server/internal/db/users_test.go @@ -10,7 +10,7 @@ import ( ) func TestUserOperations(t *testing.T) { - database, err := db.NewTestDB(&mockSecrets{}) + database, err := db.NewTestSQLiteDB(&mockSecrets{}) if err != nil { t.Fatalf("failed to create test database: %v", err) } diff --git a/server/internal/db/workspaces_test.go b/server/internal/db/workspaces_test.go index 3d6fd9f..fda1c98 100644 --- a/server/internal/db/workspaces_test.go +++ b/server/internal/db/workspaces_test.go @@ -10,7 +10,7 @@ import ( ) func TestWorkspaceOperations(t *testing.T) { - database, err := db.NewTestDB(&mockSecrets{}) + database, err := db.NewTestSQLiteDB(&mockSecrets{}) if err != nil { t.Fatalf("failed to create test database: %v", err) } diff --git a/server/internal/handlers/integration_test.go b/server/internal/handlers/integration_test.go index 0bb8bd3..43e2f64 100644 --- a/server/internal/handlers/integration_test.go +++ b/server/internal/handlers/integration_test.go @@ -61,7 +61,7 @@ func setupTestHarness(t *testing.T) *testHarness { t.Fatalf("Failed to initialize secrets service: %v", err) } - database, err := db.NewTestDB(secretsSvc) + database, err := db.NewTestSQLiteDB(secretsSvc) if err != nil { t.Fatalf("Failed to initialize test database: %v", err) } From f55d2644c3492d9cb6a8243c48ea067f420a1347 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Thu, 6 Mar 2025 21:39:56 +0100 Subject: [PATCH 30/39] Run integration tests with both dbs --- .../admin_handlers_integration_test.go | 30 ++++++----- .../auth_handlers_integration_test.go | 6 ++- .../file_handlers_integration_test.go | 6 ++- .../handlers/git_handlers_integration_test.go | 6 ++- server/internal/handlers/integration_test.go | 50 +++++++++++++++++-- .../user_handlers_integration_test.go | 6 ++- .../workspace_handlers_integration_test.go | 6 ++- 7 files changed, 88 insertions(+), 22 deletions(-) diff --git a/server/internal/handlers/admin_handlers_integration_test.go b/server/internal/handlers/admin_handlers_integration_test.go index ddb96f9..f53d659 100644 --- a/server/internal/handlers/admin_handlers_integration_test.go +++ b/server/internal/handlers/admin_handlers_integration_test.go @@ -15,21 +15,12 @@ import ( "github.com/stretchr/testify/require" ) -// Helper function to check if a user exists in a slice of users -func containsUser(users []*models.User, searchUser *models.User) bool { - for _, u := range users { - if u.ID == searchUser.ID && - u.Email == searchUser.Email && - u.DisplayName == searchUser.DisplayName && - u.Role == searchUser.Role { - return true - } - } - return false +func TestAdminHandlers_Integration(t *testing.T) { + runWithDatabases(t, testAdminHandlers) } -func TestAdminHandlers_Integration(t *testing.T) { - h := setupTestHarness(t) +func testAdminHandlers(t *testing.T, dbConfig DatabaseConfig) { + h := setupTestHarness(t, dbConfig) defer h.teardown(t) t.Run("user management", func(t *testing.T) { @@ -241,3 +232,16 @@ func TestAdminHandlers_Integration(t *testing.T) { assert.Equal(t, http.StatusForbidden, rr.Code) }) } + +// Helper function to check if a user exists in a slice of users +func containsUser(users []*models.User, searchUser *models.User) bool { + for _, u := range users { + if u.ID == searchUser.ID && + u.Email == searchUser.Email && + u.DisplayName == searchUser.DisplayName && + u.Role == searchUser.Role { + return true + } + } + return false +} diff --git a/server/internal/handlers/auth_handlers_integration_test.go b/server/internal/handlers/auth_handlers_integration_test.go index 45a2049..1f00d0f 100644 --- a/server/internal/handlers/auth_handlers_integration_test.go +++ b/server/internal/handlers/auth_handlers_integration_test.go @@ -19,7 +19,11 @@ import ( ) func TestAuthHandlers_Integration(t *testing.T) { - h := setupTestHarness(t) + runWithDatabases(t, testAuthHandlers) +} + +func testAuthHandlers(t *testing.T, dbConfig DatabaseConfig) { + h := setupTestHarness(t, dbConfig) defer h.teardown(t) t.Run("login", func(t *testing.T) { diff --git a/server/internal/handlers/file_handlers_integration_test.go b/server/internal/handlers/file_handlers_integration_test.go index c9a35a4..962b8f9 100644 --- a/server/internal/handlers/file_handlers_integration_test.go +++ b/server/internal/handlers/file_handlers_integration_test.go @@ -18,7 +18,11 @@ import ( ) func TestFileHandlers_Integration(t *testing.T) { - h := setupTestHarness(t) + runWithDatabases(t, testFileHandlers) +} + +func testFileHandlers(t *testing.T, dbConfig DatabaseConfig) { + h := setupTestHarness(t, dbConfig) defer h.teardown(t) t.Run("file operations", func(t *testing.T) { diff --git a/server/internal/handlers/git_handlers_integration_test.go b/server/internal/handlers/git_handlers_integration_test.go index f458b86..23e2bf3 100644 --- a/server/internal/handlers/git_handlers_integration_test.go +++ b/server/internal/handlers/git_handlers_integration_test.go @@ -16,7 +16,11 @@ import ( ) func TestGitHandlers_Integration(t *testing.T) { - h := setupTestHarness(t) + runWithDatabases(t, testGitHandlers) +} + +func testGitHandlers(t *testing.T, dbConfig DatabaseConfig) { + h := setupTestHarness(t, dbConfig) defer h.teardown(t) t.Run("git operations", func(t *testing.T) { diff --git a/server/internal/handlers/integration_test.go b/server/internal/handlers/integration_test.go index 43e2f64..7e9894c 100644 --- a/server/internal/handlers/integration_test.go +++ b/server/internal/handlers/integration_test.go @@ -45,8 +45,13 @@ type testUser struct { session *models.Session } +type DatabaseConfig struct { + Type db.DBType + URL string +} + // setupTestHarness creates a new test environment -func setupTestHarness(t *testing.T) *testHarness { +func setupTestHarness(t *testing.T, dbConfig DatabaseConfig) *testHarness { t.Helper() // Create temporary directory for test files @@ -61,9 +66,20 @@ func setupTestHarness(t *testing.T) *testHarness { t.Fatalf("Failed to initialize secrets service: %v", err) } - database, err := db.NewTestSQLiteDB(secretsSvc) - if err != nil { - t.Fatalf("Failed to initialize test database: %v", err) + var database db.TestDatabase + switch dbConfig.Type { + case db.DBTypeSQLite: + database, err = db.NewTestSQLiteDB(secretsSvc) + if err != nil { + t.Fatalf("Failed to initialize test database: %v", err) + } + case db.DBTypePostgres: + database, err = db.NewPostgresTestDB(dbConfig.URL, secretsSvc) + if err != nil { + t.Fatalf("Failed to initialize test database: %v", err) + } + default: + t.Fatalf("Unsupported database type: %s", dbConfig.Type) } if err := database.Migrate(); err != nil { @@ -156,6 +172,32 @@ func (h *testHarness) teardown(t *testing.T) { } } +// runWithDatabases runs a test function with both SQLite and PostgreSQL databases +func runWithDatabases(t *testing.T, testFn func(*testing.T, DatabaseConfig)) { + // Get PostgreSQL connection URL from environment variable + postgresURL := os.Getenv("LEMMA_TEST_POSTGRES_URL") + + // Always run with SQLite in-memory + t.Run("SQLite", func(t *testing.T) { + testFn(t, DatabaseConfig{ + Type: db.DBTypeSQLite, + URL: "sqlite://:memory:", + }) + }) + + // Run with PostgreSQL if connection URL is provided + if postgresURL != "" { + t.Run("PostgreSQL", func(t *testing.T) { + testFn(t, DatabaseConfig{ + Type: db.DBTypePostgres, + URL: postgresURL, + }) + }) + } else { + t.Log("Skipping PostgreSQL tests, LEMMA_TEST_POSTGRES_URL environment variable not set") + } +} + // createTestUser creates a test user and returns the user and access token func (h *testHarness) createTestUser(t *testing.T, email, password string, role models.UserRole) *testUser { t.Helper() diff --git a/server/internal/handlers/user_handlers_integration_test.go b/server/internal/handlers/user_handlers_integration_test.go index ff1363f..61b54f6 100644 --- a/server/internal/handlers/user_handlers_integration_test.go +++ b/server/internal/handlers/user_handlers_integration_test.go @@ -15,7 +15,11 @@ import ( ) func TestUserHandlers_Integration(t *testing.T) { - h := setupTestHarness(t) + runWithDatabases(t, testUserHandlers) +} + +func testUserHandlers(t *testing.T, dbConfig DatabaseConfig) { + h := setupTestHarness(t, dbConfig) defer h.teardown(t) currentEmail := h.RegularTestUser.userModel.Email diff --git a/server/internal/handlers/workspace_handlers_integration_test.go b/server/internal/handlers/workspace_handlers_integration_test.go index a0fc06f..cb73efa 100644 --- a/server/internal/handlers/workspace_handlers_integration_test.go +++ b/server/internal/handlers/workspace_handlers_integration_test.go @@ -15,7 +15,11 @@ import ( ) func TestWorkspaceHandlers_Integration(t *testing.T) { - h := setupTestHarness(t) + runWithDatabases(t, testWorkspaceHandlers) +} + +func testWorkspaceHandlers(t *testing.T, dbConfig DatabaseConfig) { + h := setupTestHarness(t, dbConfig) defer h.teardown(t) t.Run("list workspaces", func(t *testing.T) { From d8de67ae6ce0c158366e7a2206e9d9865c28094c Mon Sep 17 00:00:00 2001 From: LordMathis Date: Thu, 6 Mar 2025 21:40:16 +0100 Subject: [PATCH 31/39] Add test postgres docker compose --- server/docker-compose.test.yaml | 21 +++++++++++++++++++++ server/run_integration_tests.sh | 20 ++++++++++++++++++++ 2 files changed, 41 insertions(+) create mode 100644 server/docker-compose.test.yaml create mode 100755 server/run_integration_tests.sh diff --git a/server/docker-compose.test.yaml b/server/docker-compose.test.yaml new file mode 100644 index 0000000..eba3b6d --- /dev/null +++ b/server/docker-compose.test.yaml @@ -0,0 +1,21 @@ +version: "3.8" + +services: + postgres: + image: postgres:17 + environment: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: lemma_test + ports: + - "5432:5432" + volumes: + - postgres-data:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U postgres"] + interval: 5s + timeout: 5s + retries: 5 + +volumes: + postgres-data: diff --git a/server/run_integration_tests.sh b/server/run_integration_tests.sh new file mode 100755 index 0000000..f92dd07 --- /dev/null +++ b/server/run_integration_tests.sh @@ -0,0 +1,20 @@ +#!/bin/bash + +set -e + +COMPOSE_FILE=docker-compose.test.yaml + +if ! docker compose -f $COMPOSE_FILE ps postgres | grep -q "running"; then + docker compose -f $COMPOSE_FILE up -d + until docker compose -f $COMPOSE_FILE exec postgres pg_isready -U postgres; do + sleep 1 + done + echo "PostgreSQL is ready!" +fi + +export LEMMA_TEST_POSTGRES_URL="postgres://postgres:postgres@localhost:5432/lemma_test?sslmode=disable" + +echo "Running integration tests..." +go test -v -tags=integration ./... + +docker compose -f $COMPOSE_FILE down From 32628abf0967d24fa0e56a13ec25eaa3cb925818 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Fri, 7 Mar 2025 19:47:15 +0100 Subject: [PATCH 32/39] Add pgadmin to test compose --- server/docker-compose.test.yaml | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/server/docker-compose.test.yaml b/server/docker-compose.test.yaml index eba3b6d..6304c21 100644 --- a/server/docker-compose.test.yaml +++ b/server/docker-compose.test.yaml @@ -17,5 +17,19 @@ services: timeout: 5s retries: 5 + pgadmin: + image: dpage/pgadmin4 + environment: + PGADMIN_DEFAULT_EMAIL: admin@admin.com + PGADMIN_DEFAULT_PASSWORD: admin + PGADMIN_CONFIG_SERVER_MODE: "False" + ports: + - "8080:80" + volumes: + - pgadmin-data:/var/lib/pgadmin + depends_on: + - postgres + volumes: postgres-data: + pgadmin-data: From 72b0ac08ce09e8625db268d35a16d4b78e54c07e Mon Sep 17 00:00:00 2001 From: LordMathis Date: Fri, 7 Mar 2025 19:47:44 +0100 Subject: [PATCH 33/39] Add test env var to settings.json --- .vscode/settings.json | 3 +++ server/run_integration_tests.sh | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 52c38cb..f0043cc 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -15,6 +15,9 @@ "go.lintOnSave": "package", "go.formatTool": "goimports", "go.testFlags": ["-tags=test,integration"], + "go.testEnvVars": { + "LEMMA_TEST_POSTGRES_URL": "postgres://postgres:postgres@localhost:5432/lemma_test?sslmode=disable" + }, "[go]": { "editor.formatOnSave": true, "editor.codeActionsOnSave": { diff --git a/server/run_integration_tests.sh b/server/run_integration_tests.sh index f92dd07..63edde2 100755 --- a/server/run_integration_tests.sh +++ b/server/run_integration_tests.sh @@ -15,6 +15,6 @@ fi export LEMMA_TEST_POSTGRES_URL="postgres://postgres:postgres@localhost:5432/lemma_test?sslmode=disable" echo "Running integration tests..." -go test -v -tags=integration ./... +go test -v -tags=test,integration ./... docker compose -f $COMPOSE_FILE down From f7825e5a672f0a5ba581d515c704435f9bf22351 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Fri, 7 Mar 2025 19:48:17 +0100 Subject: [PATCH 34/39] Use search path in connection string --- server/internal/db/testdb.go | 41 ++++++++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/server/internal/db/testdb.go b/server/internal/db/testdb.go index 0cdd2b8..201d3fe 100644 --- a/server/internal/db/testdb.go +++ b/server/internal/db/testdb.go @@ -7,6 +7,7 @@ import ( "fmt" "lemma/internal/secrets" "log" + "strings" "time" ) @@ -38,7 +39,37 @@ func NewPostgresTestDB(dbURL string, secretsSvc secrets.Service) (TestDatabase, return nil, fmt.Errorf("postgres URL cannot be empty") } - db, err := sql.Open("postgres", dbURL) + initialDB, err := sql.Open("postgres", dbURL) + if err != nil { + return nil, fmt.Errorf("failed to open postgres database: %w", err) + } + + if err := initialDB.Ping(); err != nil { + initialDB.Close() + return nil, fmt.Errorf("failed to ping postgres database: %w", err) + } + + // Create a unique schema name for this test run to avoid conflicts + schemaName := fmt.Sprintf("lemma_test_%d", time.Now().UnixNano()) + _, err = initialDB.Exec(fmt.Sprintf("CREATE SCHEMA %s", schemaName)) + if err != nil { + initialDB.Close() + return nil, fmt.Errorf("failed to create schema: %w", err) + } + + // Close the initial connection and create a new one with the schema set + initialDB.Close() + + var newDBURL string + if strings.Contains(dbURL, "?") { + // URL already has parameters + newDBURL = fmt.Sprintf("%s&search_path=%s", dbURL, schemaName) + } else { + // URL has no parameters yet + newDBURL = fmt.Sprintf("%s?search_path=%s", dbURL, schemaName) + } + + db, err := sql.Open("postgres", newDBURL) if err != nil { return nil, fmt.Errorf("failed to open postgres database: %w", err) } @@ -48,14 +79,6 @@ func NewPostgresTestDB(dbURL string, secretsSvc secrets.Service) (TestDatabase, return nil, fmt.Errorf("failed to ping postgres database: %w", err) } - // Create a unique schema name for this test run to avoid conflicts - schemaName := fmt.Sprintf("lemma_test_%d", time.Now().UnixNano()) - _, err = db.Exec(fmt.Sprintf("CREATE SCHEMA %s", schemaName)) - if err != nil { - db.Close() - return nil, fmt.Errorf("failed to create schema: %w", err) - } - // Set search path to use our schema _, err = db.Exec(fmt.Sprintf("SET search_path TO %s", schemaName)) if err != nil { From 3eb4424e86bbc2eaeccd0e182c1ec0ce27321690 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Fri, 7 Mar 2025 19:48:36 +0100 Subject: [PATCH 35/39] Implement time since query --- server/internal/db/query.go | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/server/internal/db/query.go b/server/internal/db/query.go index 192e848..fcec089 100644 --- a/server/internal/db/query.go +++ b/server/internal/db/query.go @@ -266,7 +266,7 @@ func (q *Query) Placeholder(arg any) *Query { func (q *Query) Placeholders(n int) *Query { placeholders := make([]string, n) - for i := 0; i < n; i++ { + for i := range n { q.pos++ if q.dbType == DBTypePostgres { placeholders[i] = fmt.Sprintf("$%d", q.pos) @@ -279,6 +279,14 @@ func (q *Query) Placeholders(n int) *Query { return q } +func (q *Query) TimeSince(days int) string { + if q.dbType == DBTypePostgres { + return fmt.Sprintf("NOW() - INTERVAL '%d days'", days) + } + + return fmt.Sprintf("datetime('now', '-%d days')", days) +} + // AddArgs adds arguments to the query func (q *Query) AddArgs(args ...any) *Query { q.args = append(q.args, args...) From 94ea2d0d78b33dfd89fdae82f86906d0bdadd3dd Mon Sep 17 00:00:00 2001 From: LordMathis Date: Fri, 7 Mar 2025 19:53:03 +0100 Subject: [PATCH 36/39] Fix time since --- server/internal/db/query.go | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/server/internal/db/query.go b/server/internal/db/query.go index fcec089..ca00088 100644 --- a/server/internal/db/query.go +++ b/server/internal/db/query.go @@ -279,12 +279,14 @@ func (q *Query) Placeholders(n int) *Query { return q } -func (q *Query) TimeSince(days int) string { +func (q *Query) TimeSince(days int) *Query { if q.dbType == DBTypePostgres { - return fmt.Sprintf("NOW() - INTERVAL '%d days'", days) + q.builder.WriteString(fmt.Sprintf("NOW() - INTERVAL '%d days'", days)) + } else { + q.builder.WriteString(fmt.Sprintf("datetime('now', '-%d days')", days)) } - return fmt.Sprintf("datetime('now', '-%d days')", days) + return q } // AddArgs adds arguments to the query From d97f5a01784c4f61abe9ad735d175c0b3dfdb691 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Fri, 7 Mar 2025 19:53:19 +0100 Subject: [PATCH 37/39] Use time since --- server/internal/db/system.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/server/internal/db/system.go b/server/internal/db/system.go index a193ee7..ae51fb8 100644 --- a/server/internal/db/system.go +++ b/server/internal/db/system.go @@ -121,7 +121,8 @@ func (db *database) GetSystemStats() (*UserStats, error) { query = db.NewQuery(). Select("COUNT(DISTINCT user_id)"). From("sessions"). - Where("created_at > datetime('now', '-30 days')") + Where("created_at >"). + TimeSince(30) err = db.QueryRow(query.String()). Scan(&stats.ActiveUsers) if err != nil { From d0f6f275264422a3d562614cfd43b81027dc4ced Mon Sep 17 00:00:00 2001 From: LordMathis Date: Fri, 7 Mar 2025 22:31:00 +0100 Subject: [PATCH 38/39] Add PostgreSQL service to GitHub Actions workflow for integration tests --- .github/workflows/go-test.yml | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/.github/workflows/go-test.yml b/.github/workflows/go-test.yml index 33bbbb6..aee2d5b 100644 --- a/.github/workflows/go-test.yml +++ b/.github/workflows/go-test.yml @@ -13,6 +13,21 @@ jobs: name: Run Tests runs-on: ubuntu-latest + services: + postgres: + image: postgres:17 + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: lemma_test + ports: + - 5432:5432 + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + defaults: run: working-directory: ./server @@ -29,6 +44,8 @@ jobs: - name: Run Tests run: go test -tags=test,integration ./... -v + env: + LEMMA_TEST_POSTGRES_URL: "postgres://postgres:postgres@localhost:5432/lemma_test?sslmode=disable" - name: Run Tests with Race Detector run: go test -tags=test,integration -race ./... -v From 0be1bbf9a7fe690b1b0b70df5788290a33c7c1d9 Mon Sep 17 00:00:00 2001 From: LordMathis Date: Fri, 7 Mar 2025 22:37:20 +0100 Subject: [PATCH 39/39] Update README with postgres url info --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 2bb3e08..4d8e87b 100644 --- a/README.md +++ b/README.md @@ -33,8 +33,8 @@ Lemma can be configured using environment variables. Here are the available conf ### Optional Environment Variables - `LEMMA_ENV`: Set to "development" to enable development mode -- `LEMMA_DB_PATH`: Path to the SQLite database file (default: "./lemma.db") -- `LEMMA_WORKDIR`: Working directory for application data (default: "./data") +- `LEMMA_DB_URL`: URL (Connection string) to the database. Supported databases are sqlite and postgres a (default: "./lemma.db") +- `LEMMA_WORKDIR`: Working directory for application data (default: "sqlite://lemma.db") - `LEMMA_STATIC_PATH`: Path to static files (default: "../app/dist") - `LEMMA_PORT`: Port to run the server on (default: "8080") - `LEMMA_DOMAIN`: Domain name where the application is hosted for cookie authentication