diff --git a/go.mod b/go.mod index d1bd409390a..eed0b50765c 100644 --- a/go.mod +++ b/go.mod @@ -14,7 +14,7 @@ require ( github.com/corona10/goimagehash v1.1.0 github.com/disintegration/imaging v1.6.2 github.com/dop251/goja v0.0.0-20231027120936-b396bb4c349d - github.com/doug-martin/goqu/v9 v9.18.0 + github.com/doug-martin/goqu/v9 v9.19.1-0.20231214054827-21b6e6d1cb1b github.com/go-chi/chi/v5 v5.0.12 github.com/go-chi/cors v1.2.1 github.com/go-chi/httplog v0.3.1 @@ -26,6 +26,7 @@ require ( github.com/gorilla/sessions v1.2.1 github.com/gorilla/websocket v1.5.0 github.com/hashicorp/golang-lru/v2 v2.0.7 + github.com/jackc/pgx/v5 v5.3.1 github.com/jinzhu/copier v0.4.0 github.com/jmoiron/sqlx v1.4.0 github.com/json-iterator/go v1.1.12 @@ -82,7 +83,10 @@ require ( github.com/hashicorp/go-multierror v1.1.1 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/josharian/intern v1.0.0 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/mattn/go-colorable v0.1.13 // indirect diff --git a/go.sum b/go.sum index 15516f7ff98..13cffd3f297 100644 --- a/go.sum +++ b/go.sum @@ -53,11 +53,15 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/99designs/gqlgen v0.17.55 h1:3vzrNWYyzSZjGDFo68e5j9sSauLxfKvLp+6ioRokVtM= github.com/99designs/gqlgen v0.17.55/go.mod h1:3Bq768f8hgVPGZxL8aY9MaYmbxa6llPM/qu1IGH1EJo= +github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= +github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= 
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= github.com/DataDog/datadog-go v3.2.0+incompatible/go.mod h1:LButxg5PwREeZtORoXG3tL4fMGNddJ+vMq1mwgfaqoQ= +github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= +github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/PuerkitoBio/goquery v1.9.3 h1:mpJr/ikUA9/GNJB/DBZcGeFDXUtosHRyRrwh7KGdTG0= github.com/PuerkitoBio/goquery v1.9.3/go.mod h1:1ndLHPdTz+DyQPICCWYlYQMPl0oXZj0G6D4LCYA6u4U= @@ -166,22 +170,31 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/denisenkom/go-mssqldb v0.10.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/dgryski/trifles v0.0.0-20230903005119-f50d829f2e54 h1:SG7nF6SRlWhcT7cNTs5R6Hk4V2lcmLz2NsG2VnInyNo= github.com/dgryski/trifles v0.0.0-20230903005119-f50d829f2e54/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA= +github.com/dhui/dktest v0.3.16 h1:i6gq2YQEtcrjKbeJpBkWjE8MmLZPYllcjOFbTZuPDnw= +github.com/dhui/dktest v0.3.16/go.mod h1:gYaA3LRmM8Z4vJl2MA0THIigJoZrwOansEOsp+kqxp0= github.com/disintegration/imaging v1.6.2 h1:w1LecBlG2Lnp8B3jk5zSuNqd7b4DXhcjwek1ei82L+c= github.com/disintegration/imaging v1.6.2/go.mod h1:44/5580QXChDfwIclfc/PCwrr44amcmDAg8hxG0Ewe4= 
github.com/dlclark/regexp2 v1.4.1-0.20201116162257-a2a8dda75c91/go.mod h1:2pZnwuY/m+8K6iRw6wQdMtk+rH5tNGR1i55kozfMjCc= github.com/dlclark/regexp2 v1.7.0 h1:7lJfhqlPssTb1WQx4yvTHN0uElPEv52sbaECrAQxjAo= github.com/dlclark/regexp2 v1.7.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/docker/distribution v2.8.2+incompatible h1:T3de5rq0dB1j30rp0sA2rER+m322EBzniBPB6ZIzuh8= +github.com/docker/distribution v2.8.2+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= +github.com/docker/docker v20.10.24+incompatible h1:Ugvxm7a8+Gz6vqQYQQ2W7GYq5EUPaAiuPgIfVyI3dYE= +github.com/docker/docker v20.10.24+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/go-connections v0.4.0 h1:El9xVISelRB7BuFusrZozjnkIM5YnzCViNKohAFqRJQ= +github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec= +github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= +github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE= github.com/dop251/goja v0.0.0-20211022113120-dc8c55024d06/go.mod h1:R9ET47fwRVRPZnOGvHxxhuZcbrMCuiqOz3Rlrh4KSnk= github.com/dop251/goja v0.0.0-20231027120936-b396bb4c349d h1:wi6jN5LVt/ljaBG4ue79Ekzb12QfJ52L9Q98tl8SWhw= github.com/dop251/goja v0.0.0-20231027120936-b396bb4c349d/go.mod h1:QMWlm50DNe14hD7t24KEqZuUdC9sOTy8W6XbCU1mlw4= github.com/dop251/goja_nodejs v0.0.0-20210225215109-d91c329300e7/go.mod h1:hn7BA7c8pLvoGndExHudxTDKZ84Pyvv+90pbBjbTz0Y= github.com/dop251/goja_nodejs v0.0.0-20211022123610-8dd9abb0616d/go.mod h1:DngW8aVqWbuLRMHItjPUyqdj+HWPvnQe8V8y1nDpIbM= -github.com/doug-martin/goqu/v9 v9.18.0 h1:/6bcuEtAe6nsSMVK/M+fOiXUNfyFF3yYtE07DBPFMYY= -github.com/doug-martin/goqu/v9 v9.18.0/go.mod h1:nf0Wc2/hV3gYK9LiyqIrzBEVGlI8qW3GuDCEobC4wBQ= +github.com/doug-martin/goqu/v9 v9.19.1-0.20231214054827-21b6e6d1cb1b 
h1:WaCes6lOJCbIDgABfA8gB1ADMQo6+ftGEkj+oIB+vm4= +github.com/doug-martin/goqu/v9 v9.19.1-0.20231214054827-21b6e6d1cb1b/go.mod h1:1MqhYk2p5QFEUT9ZzH+M02Jv8BbOYlvzupULdHl7Mjs= github.com/dustin/go-humanize v0.0.0-20180421182945-02af3965c54e/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= @@ -227,7 +240,6 @@ github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyLFfuisuzeidWPMPWmECqU= github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= -github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= @@ -245,12 +257,12 @@ github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5x github.com/gofrs/uuid/v5 v5.1.0 h1:S5rqVKIigghZTCBKPCw0Y+bXkn26K3TB5mvQq2Ix8dk= github.com/gofrs/uuid/v5 v5.1.0/go.mod h1:CDOjlDMVAtN56jqyRUZh58JT31Tiw7/oQyEXZV+9bD8= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt/v4 v4.5.1 h1:JdqV9zKUdtaa9gdPlywC3aeoEsR681PlKC+4F5gQgeo= github.com/golang-jwt/jwt/v4 v4.5.1/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang-migrate/migrate/v4 v4.16.2 
h1:8coYbMKUyInrFk1lfGfRovTLAW7PhWp8qQDT2iKfuoA= github.com/golang-migrate/migrate/v4 v4.16.2/go.mod h1:pfcJX4nPHaVdc5nmdCikFBWtm+UBpiZjRNNsyBbp0/o= -github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -404,6 +416,12 @@ github.com/ianlancetaylor/demangle v0.0.0-20220319035150-800ac71e25c2/go.mod h1: github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.3.1 h1:Fcr8QJ1ZeLi5zsPZqQeUZhNhxfkkKBOgJuYkJHoBOtU= +github.com/jackc/pgx/v5 v5.3.1/go.mod h1:t3JDKnCBlYIc0ewLF0Q7B8MXmoIaBOZj/ic7iHozM/8= github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8= github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= @@ -450,7 +468,6 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod 
h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80 h1:6Yzfa6GP0rIo/kULo2bwGEkFvCePZ3qHDDTC3/J9Swo= github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80/go.mod h1:imJHygn/1yfhB7XSJJKlFZKl/J+dCPAknuiaGOshXAs= -github.com/lib/pq v1.10.1/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69Aj6K7nkY= @@ -478,7 +495,6 @@ github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.7/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= @@ -503,6 +519,8 @@ github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RR github.com/mitchellh/reflectwalk v1.0.0/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ= github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= +github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0= +github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod 
h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -510,6 +528,8 @@ github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lN github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= +github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/mschoch/smat v0.0.0-20160514031455-90eadee771ae/go.mod h1:qAyveg+e4CE+eKJXWVjKXM4ck2QobLqTDytGJbLLhJg= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= @@ -521,6 +541,10 @@ github.com/npillmayer/nestext v0.1.3/go.mod h1:h2lrijH8jpicr25dFY+oAJLyzlya6jhnu github.com/nu7hatch/gouuid v0.0.0-20131221200532-179d4d0c4d8d h1:VhgPp6v9qf9Agr/56bj7Y/xa04UccTW04VP0Qed4vnQ= github.com/nu7hatch/gouuid v0.0.0-20131221200532-179d4d0c4d8d/go.mod h1:YUTz3bUH2ZwIWBy3CJBeOBEugqcmXREj14T+iG/4k4U= github.com/oklog/run v1.0.0/go.mod h1:dlhp/R75TPv97u0XWUtDeV/lRKWPKSdTuV0TZvrmrQA= +github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= +github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= +github.com/opencontainers/image-spec v1.0.2 h1:9yCKha/T5XdGtO0q9Q9a6T5NUCsTn/DrBg0D7ufOcFM= +github.com/opencontainers/image-spec v1.0.2/go.mod h1:BtxoFyWECRxE4U/7sNtV5W15zMzWCbyJoFRP3s7yZA0= 
github.com/orisano/pixelmatch v0.0.0-20220722002657-fb0b55479cde h1:x0TT0RDC7UhAVbbWWBzr41ElhJx5tXPWkIHA2HWPRuw= github.com/orisano/pixelmatch v0.0.0-20220722002657-fb0b55479cde/go.mod h1:nZgzbfBr3hhjoZnS66nKrHmduYNpc34ny7RK4z5/HM0= github.com/pascaldekloe/goe v0.0.0-20180627143212-57f6aae5913c/go.mod h1:lzWF7FIEvWOWxwDKqyGYQf6ZUaNfKdP144TG7ZOy1lc= @@ -699,7 +723,6 @@ go.uber.org/zap v1.17.0/go.mod h1:MXVU+bhUf/A7Xi2HNOnopQOrmycQ5Ih87HtOu4q5SSo= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20181029021203-45a5f77698d3/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= diff --git a/internal/api/resolver.go b/internal/api/resolver.go index ab6eead7e5e..35d4ca065d8 100644 --- a/internal/api/resolver.go +++ b/internal/api/resolver.go @@ -319,12 +319,11 @@ func (r *queryResolver) Latestversion(ctx context.Context) (*LatestVersion, erro func (r *mutationResolver) ExecSQL(ctx context.Context, sql string, args []interface{}) (*SQLExecResult, error) { var rowsAffected *int64 - var lastInsertID *int64 db := manager.GetInstance().Database if err := r.withTxn(ctx, func(ctx context.Context) error { var err error - rowsAffected, lastInsertID, err = db.ExecSQL(ctx, sql, args) + rowsAffected, err = db.ExecSQL(ctx, sql, args) return err }); err != nil { return nil, err @@ -332,7 +331,6 @@ func (r *mutationResolver) ExecSQL(ctx context.Context, sql 
string, args []inter return &SQLExecResult{ RowsAffected: rowsAffected, - LastInsertID: lastInsertID, }, nil } diff --git a/internal/api/resolver_mutation_configure.go b/internal/api/resolver_mutation_configure.go index 34b627b3c66..f18243aed7f 100644 --- a/internal/api/resolver_mutation_configure.go +++ b/internal/api/resolver_mutation_configure.go @@ -127,8 +127,8 @@ func (r *mutationResolver) ConfigureGeneral(ctx context.Context, input ConfigGen return nil } - existingDBPath := c.GetDatabasePath() - if input.DatabasePath != nil && existingDBPath != *input.DatabasePath { + existingDBUrl := c.GetDatabaseUrl() + if input.DatabasePath != nil && existingDBUrl != *input.DatabasePath { if err := checkConfigOverride(config.Database); err != nil { return makeConfigGeneralResult(), err } diff --git a/internal/api/resolver_mutation_migrate.go b/internal/api/resolver_mutation_migrate.go index 083d307e9fd..d684da8e335 100644 --- a/internal/api/resolver_mutation_migrate.go +++ b/internal/api/resolver_mutation_migrate.go @@ -30,7 +30,7 @@ func (r *mutationResolver) MigrateBlobs(ctx context.Context, input MigrateBlobsI mgr := manager.GetInstance() t := &task.MigrateBlobsJob{ TxnManager: mgr.Database, - BlobStore: mgr.Database.Blobs, + BlobStore: mgr.Database.GetRepo().Blobs, Vacuumer: mgr.Database, DeleteOld: utils.IsTrue(input.DeleteOld), } diff --git a/internal/api/resolver_query_configuration.go b/internal/api/resolver_query_configuration.go index 3328e4a356b..e28128fca9e 100644 --- a/internal/api/resolver_query_configuration.go +++ b/internal/api/resolver_query_configuration.go @@ -81,7 +81,7 @@ func makeConfigGeneralResult() *ConfigGeneralResult { return &ConfigGeneralResult{ Stashes: config.GetStashPaths(), - DatabasePath: config.GetDatabasePath(), + DatabasePath: config.GetDatabaseUrl(), BackupDirectoryPath: config.GetBackupDirectoryPath(), GeneratedPath: config.GetGeneratedPath(), MetadataPath: config.GetMetadataPath(), diff --git a/internal/autotag/integration_test.go 
b/internal/autotag/integration_test.go index e74cb30aa66..efb010463e3 100644 --- a/internal/autotag/integration_test.go +++ b/internal/autotag/integration_test.go @@ -33,20 +33,28 @@ var existingStudioID int const expectedMatchTitle = "expected match" -var db *sqlite.Database +var db sqlite.DBInterface var r models.Repository func testTeardown(databaseFile string) { - err := db.Close() + err := db.Remove() if err != nil { panic(err) } +} - err = os.Remove(databaseFile) - if err != nil { - panic(err) +func getNewDB(databaseFile string) sqlite.DBInterface { + dbUrl, valid := os.LookupEnv("PGSQL_TEST") + if valid { + fmt.Printf("Postgres backend for tests detected\n") + db = sqlite.NewPostgresDatabase(sqlite.NewDatabase(), dbUrl, true) + } else { + fmt.Printf("SQLite backend for tests detected\n") + db = sqlite.NewSQLiteDatabase(sqlite.NewDatabase(), databaseFile, true) } + + return db } func runTests(m *testing.M) int { @@ -58,10 +66,12 @@ func runTests(m *testing.M) int { f.Close() databaseFile := f.Name() - db = sqlite.NewDatabase() - if err := db.Open(databaseFile); err != nil { + db = getNewDB(databaseFile) + + if err := db.Open(); err != nil { panic(fmt.Sprintf("Could not initialize database: %s", err.Error())) } + db.TestMode() r = db.Repository() diff --git a/internal/manager/config/config.go b/internal/manager/config/config.go index 6a568c1da0b..49cef271f47 100644 --- a/internal/manager/config/config.go +++ b/internal/manager/config/config.go @@ -691,7 +691,7 @@ func (i *Config) GetMetadataPath() string { return i.getString(Metadata) } -func (i *Config) GetDatabasePath() string { +func (i *Config) GetDatabaseUrl() string { return i.getString(Database) } @@ -703,7 +703,7 @@ func (i *Config) GetBackupDirectoryPathOrDefault() string { ret := i.GetBackupDirectoryPath() if ret == "" { // #4915 - default to the same directory as the database - return filepath.Dir(i.GetDatabasePath()) + return filepath.Dir(i.GetDatabaseUrl()) } return ret diff --git 
a/internal/manager/config/config_concurrency_test.go b/internal/manager/config/config_concurrency_test.go index fd9b067c7e7..3c5313113dc 100644 --- a/internal/manager/config/config_concurrency_test.go +++ b/internal/manager/config/config_concurrency_test.go @@ -35,7 +35,7 @@ func TestConcurrentConfigAccess(t *testing.T) { i.SetInterface(Cache, i.GetCachePath()) i.SetInterface(Generated, i.GetGeneratedPath()) i.SetInterface(Metadata, i.GetMetadataPath()) - i.SetInterface(Database, i.GetDatabasePath()) + i.SetInterface(Database, i.GetDatabaseUrl()) // these must be set as strings since the original values are also strings // setting them as []byte will cause the returned string to be corrupted diff --git a/internal/manager/init.go b/internal/manager/init.go index dd1640ed368..1834b2df4b6 100644 --- a/internal/manager/init.go +++ b/internal/manager/init.go @@ -35,7 +35,8 @@ import ( func Initialize(cfg *config.Config, l *log.Logger) (*Manager, error) { ctx := context.TODO() - db := sqlite.NewDatabase() + var db *sqlite.Database = sqlite.NewDatabase() + repo := db.Repository() // start with empty paths @@ -46,30 +47,32 @@ func Initialize(cfg *config.Config, l *log.Logger) (*Manager, error) { pluginCache := plugin.NewCache(cfg) + dbRepo := db.GetRepo() + sceneService := &scene.Service{ - File: db.File, - Repository: db.Scene, - MarkerRepository: db.SceneMarker, + File: dbRepo.File, + Repository: dbRepo.Scene, + MarkerRepository: dbRepo.SceneMarker, PluginCache: pluginCache, Paths: mgrPaths, Config: cfg, } imageService := &image.Service{ - File: db.File, - Repository: db.Image, + File: dbRepo.File, + Repository: dbRepo.Image, } galleryService := &gallery.Service{ - Repository: db.Gallery, - ImageFinder: db.Image, + Repository: dbRepo.Gallery, + ImageFinder: dbRepo.Image, ImageService: imageService, - File: db.File, - Folder: db.Folder, + File: dbRepo.File, + Folder: dbRepo.Folder, } groupService := &group.Service{ - Repository: db.Group, + Repository: dbRepo.Group, } 
sceneServer := &SceneServer{ @@ -183,11 +186,31 @@ func initJobManager(cfg *config.Config) *job.Manager { return ret } +// Initializes the specific DB type +func (s *Manager) RefreshDB() { + cfg := s.Config + + var odb *sqlite.Database = s.Database.Pointer() + + dbUrl := cfg.GetDatabaseUrl() + upperUrl := strings.ToUpper(dbUrl) + switch { + case strings.HasPrefix(upperUrl, string(sqlite.PostgresBackend)+":"): + s.Database = sqlite.NewPostgresDatabase(odb, dbUrl, true) + case strings.HasPrefix(upperUrl, string(sqlite.SqliteBackend)+":"): + s.Database = sqlite.NewSQLiteDatabase(odb, dbUrl[len(sqlite.SqliteBackend)+1:], true) + default: + // Assume it's the path to a SQLite database - for backwards compat + s.Database = sqlite.NewSQLiteDatabase(odb, dbUrl, true) + } +} + // postInit initialises the paths, caches and database after the initial // configuration has been set. Should only be called if the configuration // is valid. func (s *Manager) postInit(ctx context.Context) error { s.RefreshConfig() + s.RefreshDB() s.SessionStore = session.NewStore(s.Config) s.PluginCache.RegisterSessionStore(s.SessionStore) @@ -227,7 +250,7 @@ func (s *Manager) postInit(ctx context.Context) error { }) } - if err := s.Database.Open(s.Config.GetDatabasePath()); err != nil { + if err := s.Database.Open(); err != nil { var migrationNeededErr *sqlite.MigrationNeededError if errors.As(err, &migrationNeededErr) { logger.Warn(err) diff --git a/internal/manager/manager.go b/internal/manager/manager.go index 4827a3e3d92..9a55f3d0185 100644 --- a/internal/manager/manager.go +++ b/internal/manager/manager.go @@ -6,6 +6,7 @@ import ( "context" "errors" "fmt" + "math" "net/http" "os" "path/filepath" @@ -60,7 +61,7 @@ type Manager struct { DLNAService *dlna.Service - Database *sqlite.Database + Database sqlite.DBInterface Repository models.Repository SceneService SceneService @@ -383,7 +384,12 @@ func (s *Manager) GetSystemStatus() *SystemStatus { database := s.Database dbSchema := 
int(database.Version()) dbPath := database.DatabasePath() - appSchema := int(database.AppSchemaVersion()) + appSchema := math.MaxInt32 + + // Database is not initialized and cannot be used yet + if !s.Config.IsNewSystem() { + appSchema = int(database.AppSchemaVersion()) + } status := SystemStatusEnumOk if s.Config.IsNewSystem() { diff --git a/internal/manager/task/migrate.go b/internal/manager/task/migrate.go index 37062329e48..609512b2f4b 100644 --- a/internal/manager/task/migrate.go +++ b/internal/manager/task/migrate.go @@ -20,7 +20,7 @@ type migrateJobConfig interface { type MigrateJob struct { BackupPath string Config migrateJobConfig - Database *sqlite.Database + Database sqlite.DBInterface } type databaseSchemaInfo struct { diff --git a/pkg/models/find_filter.go b/pkg/models/find_filter.go index 9934a9ea9c4..4f9b0248c78 100644 --- a/pkg/models/find_filter.go +++ b/pkg/models/find_filter.go @@ -127,3 +127,16 @@ func BatchFindFilter(batchSize int) *FindFilterType { Page: &page, } } + +// EnsureFindFilterSorted sets an ordering if one is missing +// Database ordering is unstable otherwise +func EnsureFindFilterSorted(findFilter *FindFilterType) { + if findFilter == nil { + findFilter = &FindFilterType{} + } + + if findFilter.Sort == nil || *findFilter.Sort == "" { + idStr := "id" + findFilter.Sort = &idStr + } +} diff --git a/pkg/models/relationships.go b/pkg/models/relationships.go index 5495f858b17..a899490ec52 100644 --- a/pkg/models/relationships.go +++ b/pkg/models/relationships.go @@ -2,6 +2,7 @@ package models import ( "context" + "slices" "github.com/stashapp/stash/pkg/sliceutil" ) @@ -86,6 +87,10 @@ func (r RelatedIDs) Loaded() bool { return r.list != nil } +func (r RelatedIDs) Sort() { + slices.Sort(r.list) +} + func (r RelatedIDs) mustLoaded() { if !r.Loaded() { panic("list has not been loaded") diff --git a/pkg/scene/update_test.go b/pkg/scene/update_test.go index f72c9640394..c91f57d9f7d 100644 --- a/pkg/scene/update_test.go +++ 
b/pkg/scene/update_test.go @@ -93,6 +93,11 @@ func TestUpdater_IsEmpty(t *testing.T) { } } +func getUUID(_ string) string { + // TODO: Encode input string + return "00000000-0000-0000-0000-000000000000" +} + func TestUpdater_Update(t *testing.T) { const ( sceneID = iota + 1 @@ -107,7 +112,7 @@ func TestUpdater_Update(t *testing.T) { performerIDs := []int{performerID} tagIDs := []int{tagID} - stashID := "stashID" + stashID := getUUID("stashID") endpoint := "endpoint" title := "title" @@ -235,7 +240,7 @@ func TestUpdateSet_UpdateInput(t *testing.T) { performerIDStrs := intslice.IntSliceToStringSlice(performerIDs) tagIDs := []int{tagID} tagIDStrs := intslice.IntSliceToStringSlice(tagIDs) - stashID := "stashID" + stashID := getUUID("stashID") endpoint := "endpoint" updatedAt := time.Now() stashIDs := []models.StashID{ diff --git a/pkg/sqlite/anonymise.go b/pkg/sqlite/anonymise.go index 519489abfc6..1a02444fa57 100644 --- a/pkg/sqlite/anonymise.go +++ b/pkg/sqlite/anonymise.go @@ -24,20 +24,19 @@ const ( ) type Anonymiser struct { - *Database + *SQLiteDB + sourceDb DBInterface } -func NewAnonymiser(db *Database, outPath string) (*Anonymiser, error) { - if _, err := db.writeDB.Exec(fmt.Sprintf(`VACUUM INTO "%s"`, outPath)); err != nil { - return nil, fmt.Errorf("vacuuming into %s: %w", outPath, err) - } +var anon_dialect = goqu.Dialect("sqlite3") - newDB := NewDatabase() - if err := newDB.Open(outPath); err != nil { +func NewAnonymiser(db DBInterface, outPath string) (*Anonymiser, error) { + newDB := NewSQLiteDatabase(NewDatabase(), outPath, false) + if err := newDB.Open(); err != nil { return nil, fmt.Errorf("opening %s: %w", outPath, err) } - return &Anonymiser{Database: newDB}, nil + return &Anonymiser{SQLiteDB: newDB, sourceDb: db}, nil } func (db *Anonymiser) Anonymise(ctx context.Context) error { @@ -45,6 +44,7 @@ func (db *Anonymiser) Anonymise(ctx context.Context) error { defer db.Close() return utils.Do([]func() error{ + func() error { return db.fetch(ctx) }, 
func() error { return db.deleteBlobs() }, func() error { return db.deleteStashIDs() }, func() error { return db.clearOHistory() }, @@ -74,6 +74,127 @@ func (db *Anonymiser) Anonymise(ctx context.Context) error { return nil } +func (db *Anonymiser) fetch(ctx context.Context) error { + const disableForeignKeys = true + const writable = true + const batchSize = 5000 + + conn, err := db.open(disableForeignKeys, writable) + if err != nil { + return fmt.Errorf("failed to open db: %w", err) + } + + for _, table := range []exp.IdentifierExpression{ + goqu.I(fileTable), + goqu.I(fingerprintTable), + goqu.I(folderTable), + goqu.I(galleryTable), + goqu.I(galleriesChaptersTable), + goqu.I(galleriesFilesTable), + goqu.I(galleriesImagesTable), + goqu.I(galleriesTagsTable), + goqu.I(galleriesURLsTable), + goqu.I(groupURLsTable), + goqu.I(groupTable), + goqu.I(groupRelationsTable), + goqu.I(groupsScenesTable), + goqu.I(groupsTagsTable), + goqu.I(imageFileTable), + goqu.I(imagesURLsTable), + goqu.I(imageTable), + goqu.I(imagesFilesTable), + goqu.I(imagesTagsTable), + goqu.I(performersAliasesTable), + goqu.I("performer_stash_ids"), + goqu.I(performerURLsTable), + goqu.I(performerTable), + goqu.I(performersGalleriesTable), + goqu.I(performersImagesTable), + goqu.I(performersScenesTable), + goqu.I(performersTagsTable), + goqu.I(savedFilterTable), + goqu.I(sceneMarkerTable), + goqu.I("scene_markers_tags"), + goqu.I(scenesURLsTable), + goqu.I(sceneTable), + goqu.I(scenesFilesTable), + goqu.I(scenesGalleriesTable), + goqu.I(scenesODatesTable), + goqu.I(scenesTagsTable), + goqu.I(scenesViewDatesTable), + goqu.I(studioAliasesTable), + goqu.I("studio_stash_ids"), + goqu.I(studioTable), + goqu.I(studiosTagsTable), + goqu.I(tagAliasesTable), + goqu.I(tagTable), + goqu.I(tagRelationsTable), + goqu.I(videoCaptionsTable), + goqu.I(videoFileTable), + } { + offset := 0 + for { + q := dialect.From(table).Select(table.All()).Limit(uint(batchSize)).Offset(uint(offset)) + var rowsSlice 
[]map[string]interface{} + + // Fetch + if err := txn.WithTxn(ctx, db.sourceDb, func(ctx context.Context) error { + if err := queryFunc(ctx, q, false, func(r *sqlx.Rows) error { + for r.Next() { + row := make(map[string]interface{}) + if err := r.MapScan(row); err != nil { + return fmt.Errorf("failed structscan: %w", err) + } + rowsSlice = append(rowsSlice, row) + } + + return nil + }); err != nil { + return fmt.Errorf("querying %s: %w", table, err) + } + + return nil + }); err != nil { + return fmt.Errorf("failed fetch transaction: %w", err) + } + + if len(rowsSlice) == 0 { + break + } + + // Insert + txn, err := conn.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("begin tx: %w", err) + } + + i := anon_dialect.Insert(table).Rows(rowsSlice) + sql, args, err := i.ToSQL() + if err != nil { + return fmt.Errorf("failed tosql: %w", err) + } + + _, err = txn.ExecContext(ctx, sql, args...) + if err != nil { + return fmt.Errorf("exec `%s` [%v]: %w", sql, args, err) + } + + if err := txn.Commit(); err != nil { + return fmt.Errorf("commit: %w", err) + } + + // Move to the next batch + offset += batchSize + } + } + + if err := conn.Close(); err != nil { + return fmt.Errorf("close: %w", err) + } + + return nil +} + func (db *Anonymiser) truncateColumn(tableName string, column string) error { _, err := db.writeDB.Exec("UPDATE " + tableName + " SET " + column + " = NULL") return err @@ -127,7 +248,7 @@ func (db *Anonymiser) anonymiseFolders(ctx context.Context) error { func (db *Anonymiser) anonymiseFoldersRecurse(ctx context.Context, parentFolderID int, parentPath string) error { table := folderTableMgr.table - stmt := dialect.Update(table) + stmt := anon_dialect.Update(table) if parentFolderID == 0 { stmt = stmt.Set(goqu.Record{"path": goqu.Cast(table.Col(idColumn), "VARCHAR")}).Where(table.Col("parent_folder_id").IsNull()) @@ -142,7 +263,7 @@ func (db *Anonymiser) anonymiseFoldersRecurse(ctx context.Context, parentFolderI } // now recurse to sub-folders - query := 
dialect.From(table).Select(table.Col(idColumn), table.Col("path")) + query := anon_dialect.From(table).Select(table.Col(idColumn), table.Col("path")) if parentFolderID == 0 { query = query.Where(table.Col("parent_folder_id").IsNull()) } else { @@ -165,7 +286,7 @@ func (db *Anonymiser) anonymiseFiles(ctx context.Context) error { logger.Infof("Anonymising files") return txn.WithTxn(ctx, db, func(ctx context.Context) error { table := fileTableMgr.table - stmt := dialect.Update(table).Set(goqu.Record{"basename": goqu.Cast(table.Col(idColumn), "VARCHAR")}) + stmt := anon_dialect.Update(table).Set(goqu.Record{"basename": goqu.Cast(table.Col(idColumn), "VARCHAR")}) if _, err := exec(ctx, stmt); err != nil { return fmt.Errorf("anonymising %s: %w", table.GetTable(), err) @@ -179,7 +300,7 @@ func (db *Anonymiser) anonymiseCaptions(ctx context.Context) error { logger.Infof("Anonymising captions") return txn.WithTxn(ctx, db, func(ctx context.Context) error { table := goqu.T(videoCaptionsTable) - stmt := dialect.Update(table).Set(goqu.Record{"filename": goqu.Cast(table.Col("file_id"), "VARCHAR")}) + stmt := anon_dialect.Update(table).Set(goqu.Record{"filename": goqu.Cast(table.Col("file_id"), "VARCHAR")}) if _, err := exec(ctx, stmt); err != nil { return fmt.Errorf("anonymising %s: %w", table.GetTable(), err) @@ -199,7 +320,7 @@ func (db *Anonymiser) anonymiseFingerprints(ctx context.Context) error { for gotSome := true; gotSome; { if err := txn.WithTxn(ctx, db, func(ctx context.Context) error { - query := dialect.From(table).Select( + query := anon_dialect.From(table).Select( table.Col(fileIDColumn), table.Col("type"), table.Col("fingerprint"), @@ -256,7 +377,7 @@ func (db *Anonymiser) anonymiseScenes(ctx context.Context) error { for gotSome := true; gotSome; { if err := txn.WithTxn(ctx, db, func(ctx context.Context) error { - query := dialect.From(table).Select( + query := anon_dialect.From(table).Select( table.Col(idColumn), table.Col("title"), table.Col("details"), @@ 
-293,7 +414,7 @@ func (db *Anonymiser) anonymiseScenes(ctx context.Context) error { db.obfuscateNullString(set, "details", details) if len(set) > 0 { - stmt := dialect.Update(table).Set(set).Where(table.Col(idColumn).Eq(id)) + stmt := anon_dialect.Update(table).Set(set).Where(table.Col(idColumn).Eq(id)) if _, err := exec(ctx, stmt); err != nil { return fmt.Errorf("anonymising %s: %w", table.GetTable(), err) @@ -343,7 +464,7 @@ func (db *Anonymiser) anonymiseMarkers(ctx context.Context) error { for gotSome := true; gotSome; { if err := txn.WithTxn(ctx, db, func(ctx context.Context) error { - query := dialect.From(table).Select( + query := anon_dialect.From(table).Select( table.Col(idColumn), table.Col("title"), ).Where(table.Col(idColumn).Gt(lastID)).Limit(1000) @@ -395,7 +516,7 @@ func (db *Anonymiser) anonymiseImages(ctx context.Context) error { for gotSome := true; gotSome; { if err := txn.WithTxn(ctx, db, func(ctx context.Context) error { - query := dialect.From(table).Select( + query := anon_dialect.From(table).Select( table.Col(idColumn), table.Col("title"), ).Where(table.Col(idColumn).Gt(lastID)).Limit(1000) @@ -420,7 +541,7 @@ func (db *Anonymiser) anonymiseImages(ctx context.Context) error { db.obfuscateNullString(set, "title", title) if len(set) > 0 { - stmt := dialect.Update(table).Set(set).Where(table.Col(idColumn).Eq(id)) + stmt := anon_dialect.Update(table).Set(set).Where(table.Col(idColumn).Eq(id)) if _, err := exec(ctx, stmt); err != nil { return fmt.Errorf("anonymising %s: %w", table.GetTable(), err) @@ -458,7 +579,7 @@ func (db *Anonymiser) anonymiseGalleries(ctx context.Context) error { for gotSome := true; gotSome; { if err := txn.WithTxn(ctx, db, func(ctx context.Context) error { - query := dialect.From(table).Select( + query := anon_dialect.From(table).Select( table.Col(idColumn), table.Col("title"), table.Col("details"), @@ -491,7 +612,7 @@ func (db *Anonymiser) anonymiseGalleries(ctx context.Context) error { db.obfuscateNullString(set, 
"photographer", photographer) if len(set) > 0 { - stmt := dialect.Update(table).Set(set).Where(table.Col(idColumn).Eq(id)) + stmt := anon_dialect.Update(table).Set(set).Where(table.Col(idColumn).Eq(id)) if _, err := exec(ctx, stmt); err != nil { return fmt.Errorf("anonymising %s: %w", table.GetTable(), err) @@ -529,7 +650,7 @@ func (db *Anonymiser) anonymisePerformers(ctx context.Context) error { for gotSome := true; gotSome; { if err := txn.WithTxn(ctx, db, func(ctx context.Context) error { - query := dialect.From(table).Select( + query := anon_dialect.From(table).Select( table.Col(idColumn), table.Col("name"), table.Col("disambiguation"), @@ -570,7 +691,7 @@ func (db *Anonymiser) anonymisePerformers(ctx context.Context) error { db.obfuscateNullString(set, "piercings", piercings) if len(set) > 0 { - stmt := dialect.Update(table).Set(set).Where(table.Col(idColumn).Eq(id)) + stmt := anon_dialect.Update(table).Set(set).Where(table.Col(idColumn).Eq(id)) if _, err := exec(ctx, stmt); err != nil { return fmt.Errorf("anonymising %s: %w", table.GetTable(), err) @@ -612,7 +733,7 @@ func (db *Anonymiser) anonymiseStudios(ctx context.Context) error { for gotSome := true; gotSome; { if err := txn.WithTxn(ctx, db, func(ctx context.Context) error { - query := dialect.From(table).Select( + query := anon_dialect.From(table).Select( table.Col(idColumn), table.Col("name"), table.Col("url"), @@ -645,7 +766,7 @@ func (db *Anonymiser) anonymiseStudios(ctx context.Context) error { db.obfuscateNullString(set, "details", details) if len(set) > 0 { - stmt := dialect.Update(table).Set(set).Where(table.Col(idColumn).Eq(id)) + stmt := anon_dialect.Update(table).Set(set).Where(table.Col(idColumn).Eq(id)) if _, err := exec(ctx, stmt); err != nil { return fmt.Errorf("anonymising %s: %w", table.GetTable(), err) @@ -684,7 +805,7 @@ func (db *Anonymiser) anonymiseAliases(ctx context.Context, table exp.Identifier for gotSome := true; gotSome; { if err := txn.WithTxn(ctx, db, func(ctx 
context.Context) error { - query := dialect.From(table).Select( + query := anon_dialect.From(table).Select( table.Col(idColumn), table.Col("alias"), ).Where(goqu.L("(" + idColumn + ", alias)").Gt(goqu.L("(?, ?)", lastID, lastAlias))).Limit(1000) @@ -709,7 +830,7 @@ func (db *Anonymiser) anonymiseAliases(ctx context.Context, table exp.Identifier db.obfuscateNullString(set, "alias", alias) if len(set) > 0 { - stmt := dialect.Update(table).Set(set).Where( + stmt := anon_dialect.Update(table).Set(set).Where( table.Col(idColumn).Eq(id), table.Col("alias").Eq(alias), ) @@ -746,7 +867,7 @@ func (db *Anonymiser) anonymiseURLs(ctx context.Context, table exp.IdentifierExp for gotSome := true; gotSome; { if err := txn.WithTxn(ctx, db, func(ctx context.Context) error { - query := dialect.From(table).Select( + query := anon_dialect.From(table).Select( table.Col(idColumn), table.Col("url"), ).Where(goqu.L("(" + idColumn + ", url)").Gt(goqu.L("(?, ?)", lastID, lastURL))).Limit(1000) @@ -771,7 +892,7 @@ func (db *Anonymiser) anonymiseURLs(ctx context.Context, table exp.IdentifierExp db.obfuscateNullString(set, "url", url) if len(set) > 0 { - stmt := dialect.Update(table).Set(set).Where( + stmt := anon_dialect.Update(table).Set(set).Where( table.Col(idColumn).Eq(id), table.Col("url").Eq(url), ) @@ -809,7 +930,7 @@ func (db *Anonymiser) anonymiseTags(ctx context.Context) error { for gotSome := true; gotSome; { if err := txn.WithTxn(ctx, db, func(ctx context.Context) error { - query := dialect.From(table).Select( + query := anon_dialect.From(table).Select( table.Col(idColumn), table.Col("name"), table.Col("description"), @@ -838,7 +959,7 @@ func (db *Anonymiser) anonymiseTags(ctx context.Context) error { db.obfuscateNullString(set, "description", description) if len(set) > 0 { - stmt := dialect.Update(table).Set(set).Where(table.Col(idColumn).Eq(id)) + stmt := anon_dialect.Update(table).Set(set).Where(table.Col(idColumn).Eq(id)) if _, err := exec(ctx, stmt); err != nil { return 
fmt.Errorf("anonymising %s: %w", table.GetTable(), err) @@ -876,7 +997,7 @@ func (db *Anonymiser) anonymiseGroups(ctx context.Context) error { for gotSome := true; gotSome; { if err := txn.WithTxn(ctx, db, func(ctx context.Context) error { - query := dialect.From(table).Select( + query := anon_dialect.From(table).Select( table.Col(idColumn), table.Col("name"), table.Col("aliases"), @@ -913,7 +1034,7 @@ func (db *Anonymiser) anonymiseGroups(ctx context.Context) error { db.obfuscateNullString(set, "director", director) if len(set) > 0 { - stmt := dialect.Update(table).Set(set).Where(table.Col(idColumn).Eq(id)) + stmt := anon_dialect.Update(table).Set(set).Where(table.Col(idColumn).Eq(id)) if _, err := exec(ctx, stmt); err != nil { return fmt.Errorf("anonymising %s: %w", table.GetTable(), err) @@ -951,7 +1072,7 @@ func (db *Anonymiser) anonymiseSavedFilters(ctx context.Context) error { for gotSome := true; gotSome; { if err := txn.WithTxn(ctx, db, func(ctx context.Context) error { - query := dialect.From(table).Select( + query := anon_dialect.From(table).Select( table.Col(idColumn), table.Col("name"), ).Where(table.Col(idColumn).Gt(lastID)).Limit(1000) @@ -976,7 +1097,7 @@ func (db *Anonymiser) anonymiseSavedFilters(ctx context.Context) error { db.obfuscateNullString(set, "name", name) if len(set) > 0 { - stmt := dialect.Update(table).Set(set).Where(table.Col(idColumn).Eq(id)) + stmt := anon_dialect.Update(table).Set(set).Where(table.Col(idColumn).Eq(id)) if _, err := exec(ctx, stmt); err != nil { return fmt.Errorf("anonymising %s: %w", table.GetTable(), err) @@ -1005,7 +1126,7 @@ func (db *Anonymiser) anonymiseText(ctx context.Context, table exp.IdentifierExp set := goqu.Record{} set[column] = db.obfuscateString(value, letters) - stmt := dialect.Update(table).Set(set).Where(table.Col(column).Eq(value)) + stmt := anon_dialect.Update(table).Set(set).Where(table.Col(column).Eq(value)) if _, err := exec(ctx, stmt); err != nil { return fmt.Errorf("anonymising %s: %w", 
column, err) @@ -1018,7 +1139,7 @@ func (db *Anonymiser) anonymiseFingerprint(ctx context.Context, table exp.Identi set := goqu.Record{} set[column] = db.obfuscateString(value, hex) - stmt := dialect.Update(table).Set(set).Where(table.Col(column).Eq(value)) + stmt := anon_dialect.Update(table).Set(set).Where(table.Col(column).Eq(value)) if _, err := exec(ctx, stmt); err != nil { return fmt.Errorf("anonymising %s: %w", column, err) diff --git a/pkg/sqlite/blob.go b/pkg/sqlite/blob.go index 241b63d23cf..f278568a16b 100644 --- a/pkg/sqlite/blob.go +++ b/pkg/sqlite/blob.go @@ -10,12 +10,10 @@ import ( "github.com/doug-martin/goqu/v9" "github.com/doug-martin/goqu/v9/exp" "github.com/jmoiron/sqlx" - "github.com/mattn/go-sqlite3" "github.com/stashapp/stash/pkg/file" "github.com/stashapp/stash/pkg/hash/md5" "github.com/stashapp/stash/pkg/logger" "github.com/stashapp/stash/pkg/sqlite/blob" - "github.com/stashapp/stash/pkg/utils" "gopkg.in/guregu/null.v4" ) @@ -70,8 +68,8 @@ func NewBlobStore(options BlobStoreOptions) *BlobStore { } type blobRow struct { - Checksum string `db:"checksum"` - Blob []byte `db:"blob"` + Checksum string `db:"checksum"` + Blob sql.Null[[]byte] `db:"blob"` } func (qb *BlobStore) table() exp.IdentifierExpression { @@ -105,9 +103,10 @@ func (qb *BlobStore) Write(ctx context.Context, data []byte) (string, error) { // only write blob to the database if UseDatabase is true // always at least write the checksum - var storedData []byte + var storedData sql.Null[[]byte] if qb.options.UseDatabase { - storedData = data + storedData.V = data + storedData.Valid = len(storedData.V) > 0 } if err := qb.write(ctx, checksum, storedData); err != nil { @@ -123,12 +122,12 @@ func (qb *BlobStore) Write(ctx context.Context, data []byte) (string, error) { return checksum, nil } -func (qb *BlobStore) write(ctx context.Context, checksum string, data []byte) error { +func (qb *BlobStore) write(ctx context.Context, checksum string, data sql.Null[[]byte]) error { table := 
qb.table() - q := dialect.Insert(table).Prepared(true).Rows(blobRow{ + q := dialect.Insert(table).Rows(blobRow{ Checksum: checksum, Blob: data, - }).OnConflict(goqu.DoNothing()) + }).OnConflict(goqu.DoNothing()).Prepared(true) _, err := exec(ctx, q) if err != nil { @@ -140,9 +139,9 @@ func (qb *BlobStore) write(ctx context.Context, checksum string, data []byte) er func (qb *BlobStore) update(ctx context.Context, checksum string, data []byte) error { table := qb.table() - q := dialect.Update(table).Prepared(true).Set(goqu.Record{ + q := dialect.Update(table).Set(goqu.Record{ "blob": data, - }).Where(goqu.C(blobChecksumColumn).Eq(checksum)) + }).Where(goqu.C(blobChecksumColumn).Eq(checksum)).Prepared(true) _, err := exec(ctx, q) if err != nil { @@ -168,16 +167,21 @@ func (e *ChecksumBlobNotExistError) Error() string { return fmt.Sprintf("blob for checksum %s does not exist", e.Checksum) } -func (qb *BlobStore) readSQL(ctx context.Context, querySQL string, args ...interface{}) ([]byte, string, error) { +func (qb *BlobStore) readSQL(ctx context.Context, querySQL sqler) ([]byte, string, error) { if !qb.options.UseDatabase && !qb.options.UseFilesystem { panic("no blob store configured") } + query, args, err := querySQL.ToSQL() + if err != nil { + return nil, "", fmt.Errorf("reading blob tosql: %w", err) + } + // always try to get from the database first, even if set to use filesystem var row blobRow found := false const single = true - if err := qb.queryFunc(ctx, querySQL, args, single, func(r *sqlx.Rows) error { + if err := qb.queryFunc(ctx, query, args, single, func(r *sqlx.Rows) error { found = true if err := r.StructScan(&row); err != nil { return err @@ -195,8 +199,8 @@ func (qb *BlobStore) readSQL(ctx context.Context, querySQL string, args ...inter checksum := row.Checksum - if row.Blob != nil { - return row.Blob, checksum, nil + if row.Blob.Valid { + return row.Blob.V, checksum, nil } // don't use the filesystem if not configured to do so @@ -265,8 +269,8 @@ func 
(qb *BlobStore) Read(ctx context.Context, checksum string) ([]byte, error) } } - if ret != nil { - return ret, nil + if ret.Valid { + return ret.V, nil } // don't use the filesystem if not configured to do so @@ -280,9 +284,10 @@ func (qb *BlobStore) Read(ctx context.Context, checksum string) ([]byte, error) } } -func (qb *BlobStore) readFromDatabase(ctx context.Context, checksum string) ([]byte, error) { +func (qb *BlobStore) readFromDatabase(ctx context.Context, checksum string) (sql.Null[[]byte], error) { q := dialect.From(qb.table()).Select(qb.table().All()).Where(qb.tableMgr.byID(checksum)) + var empty sql.Null[[]byte] var row blobRow const single = true if err := queryFunc(ctx, q, single, func(r *sqlx.Rows) error { @@ -292,7 +297,7 @@ func (qb *BlobStore) readFromDatabase(ctx context.Context, checksum string) ([]b return nil }); err != nil { - return nil, fmt.Errorf("querying %s: %w", qb.table(), err) + return empty, fmt.Errorf("querying %s: %w", qb.table(), err) } return row.Blob, nil @@ -303,7 +308,7 @@ func (qb *BlobStore) readFromDatabase(ctx context.Context, checksum string) ([]b func (qb *BlobStore) Delete(ctx context.Context, checksum string) error { // try to delete the blob from the database if err := qb.delete(ctx, checksum); err != nil { - if qb.isConstraintError(err) { + if isConstraintError(err) { // blob is still referenced - do not delete logger.Debugf("Blob %s is still referenced - not deleting", checksum) return nil @@ -324,24 +329,19 @@ func (qb *BlobStore) Delete(ctx context.Context, checksum string) error { return nil } -func (qb *BlobStore) isConstraintError(err error) bool { - var sqliteError sqlite3.Error - if errors.As(err, &sqliteError) { - return sqliteError.Code == sqlite3.ErrConstraint - } - return false -} - func (qb *BlobStore) delete(ctx context.Context, checksum string) error { table := qb.table() q := dialect.Delete(table).Where(goqu.C(blobChecksumColumn).Eq(checksum)) - _, err := exec(ctx, q) + err := withSavepoint(ctx, 
func(ctx context.Context) error { + _, err := exec(ctx, q) + return err + }) + if err != nil { return fmt.Errorf("deleting from %s: %w", table, err) } - return nil } @@ -353,15 +353,12 @@ type blobJoinQueryBuilder struct { } func (qb *blobJoinQueryBuilder) GetImage(ctx context.Context, id int, blobCol string) ([]byte, error) { - sqlQuery := utils.StrFormat(` -SELECT blobs.checksum, blobs.blob FROM {joinTable} INNER JOIN blobs ON {joinTable}.{joinCol} = blobs.checksum -WHERE {joinTable}.id = ? -`, utils.StrFormatMap{ - "joinTable": qb.joinTable, - "joinCol": blobCol, - }) + sqlQuery := dialect.From(qb.joinTable). + Join(goqu.I("blobs"), goqu.On(goqu.I(qb.joinTable+"."+blobCol).Eq(goqu.I("blobs.checksum")))). + Select(goqu.I("blobs.checksum"), goqu.I("blobs.blob")). + Where(goqu.Ex{"id": id}) - ret, _, err := qb.blobStore.readSQL(ctx, sqlQuery, id) + ret, _, err := qb.blobStore.readSQL(ctx, sqlQuery) return ret, err } @@ -380,8 +377,17 @@ func (qb *blobJoinQueryBuilder) UpdateImage(ctx context.Context, id int, blobCol return err } - sqlQuery := fmt.Sprintf("UPDATE %s SET %s = ? WHERE id = ?", qb.joinTable, blobCol) - if _, err := dbWrapper.Exec(ctx, sqlQuery, checksum, id); err != nil { + sqlQuery := dialect.From(qb.joinTable).Update(). + Set(goqu.Record{blobCol: checksum}). + Prepared(true). + Where(goqu.Ex{"id": id}) + + query, args, err := sqlQuery.ToSQL() + if err != nil { + return err + } + + if _, err := dbWrapper.Exec(ctx, query, args...); err != nil { return err } @@ -396,15 +402,17 @@ func (qb *blobJoinQueryBuilder) UpdateImage(ctx context.Context, id int, blobCol } func (qb *blobJoinQueryBuilder) getChecksum(ctx context.Context, id int, blobCol string) (*string, error) { - sqlQuery := utils.StrFormat(` -SELECT {joinTable}.{joinCol} FROM {joinTable} WHERE {joinTable}.id = ? -`, utils.StrFormatMap{ - "joinTable": qb.joinTable, - "joinCol": blobCol, - }) + sqlQuery := dialect.From(qb.joinTable). + Select(blobCol). 
+ Where(goqu.Ex{"id": id}) + + query, args, err := sqlQuery.ToSQL() + if err != nil { + return nil, err + } var checksum null.String - err := qb.repository.querySimple(ctx, sqlQuery, []interface{}{id}, &checksum) + err = qb.repository.querySimple(ctx, query, args, &checksum) if err != nil { return nil, err } @@ -427,8 +435,16 @@ func (qb *blobJoinQueryBuilder) DestroyImage(ctx context.Context, id int, blobCo return nil } - updateQuery := fmt.Sprintf("UPDATE %s SET %s = NULL WHERE id = ?", qb.joinTable, blobCol) - if _, err = dbWrapper.Exec(ctx, updateQuery, id); err != nil { + updateQuery := dialect.Update(qb.joinTable). + Set(goqu.Record{blobCol: nil}). + Where(goqu.Ex{"id": id}) + + query, args, err := updateQuery.ToSQL() + if err != nil { + return err + } + + if _, err = dbWrapper.Exec(ctx, query, args...); err != nil { return err } @@ -436,12 +452,22 @@ func (qb *blobJoinQueryBuilder) DestroyImage(ctx context.Context, id int, blobCo } func (qb *blobJoinQueryBuilder) HasImage(ctx context.Context, id int, blobCol string) (bool, error) { - stmt := utils.StrFormat("SELECT COUNT(*) as count FROM (SELECT {joinCol} FROM {joinTable} WHERE id = ? AND {joinCol} IS NOT NULL LIMIT 1)", utils.StrFormatMap{ - "joinTable": qb.joinTable, - "joinCol": blobCol, - }) + ds := dialect.From(goqu.T(qb.joinTable)). + Select(goqu.C(blobCol)). + Where( + goqu.C("id").Eq(id), + goqu.C(blobCol).IsNotNull(), + ). 
+ Limit(1) + + countDs := dialect.From(ds.As("subquery")).Select(goqu.COUNT("*").As("count")) + + sql, params, err := countDs.ToSQL() + if err != nil { + return false, err + } - c, err := qb.repository.runCountQuery(ctx, stmt, []interface{}{id}) + c, err := qb.repository.runCountQuery(ctx, sql, params) if err != nil { return false, err } diff --git a/pkg/sqlite/blob_migrate.go b/pkg/sqlite/blob_migrate.go index e121d07923a..7f42aff2c6d 100644 --- a/pkg/sqlite/blob_migrate.go +++ b/pkg/sqlite/blob_migrate.go @@ -63,7 +63,7 @@ func (qb *BlobStore) migrateBlobDatabase(ctx context.Context, checksum string, d return fmt.Errorf("reading from database: %w", err) } - if len(existing) == 0 { + if len(existing.V) == 0 { // find the blob in the filesystem blob, err := qb.fsStore.Read(ctx, checksum) if err != nil { @@ -94,14 +94,14 @@ func (qb *BlobStore) migrateBlobFilesystem(ctx context.Context, checksum string, return fmt.Errorf("reading from database: %w", err) } - if len(blob) == 0 { + if len(blob.V) == 0 { // it's possible that the blob is already present in the filesystem // just ignore return nil } // write the blob to the filesystem - if err := qb.fsStore.Write(ctx, checksum, blob); err != nil { + if err := qb.fsStore.Write(ctx, checksum, blob.V); err != nil { return fmt.Errorf("writing to filesystem: %w", err) } diff --git a/pkg/sqlite/criterion_handlers.go b/pkg/sqlite/criterion_handlers.go index 55ff31fca82..246ad5a3803 100644 --- a/pkg/sqlite/criterion_handlers.go +++ b/pkg/sqlite/criterion_handlers.go @@ -43,21 +43,21 @@ func stringCriterionHandler(c *models.StringCriterionInput, column string) crite case models.CriterionModifierExcludes: f.whereClauses = append(f.whereClauses, getStringSearchClause([]string{column}, c.Value, true)) case models.CriterionModifierEquals: - f.addWhere(column+" LIKE ?", c.Value) + f.addWhere("LOWER("+column+") LIKE LOWER(?)", c.Value) case models.CriterionModifierNotEquals: - f.addWhere(column+" NOT LIKE ?", c.Value) + 
f.addWhere("LOWER("+column+") NOT LIKE LOWER(?)", c.Value) case models.CriterionModifierMatchesRegex: if _, err := regexp.Compile(c.Value); err != nil { f.setError(err) return } - f.addWhere(fmt.Sprintf("(%s IS NOT NULL AND %[1]s regexp ?)", column), c.Value) + f.addWhere(fmt.Sprintf("(%s IS NOT NULL AND regexp(?, %[1]s))", column), c.Value) case models.CriterionModifierNotMatchesRegex: if _, err := regexp.Compile(c.Value); err != nil { f.setError(err) return } - f.addWhere(fmt.Sprintf("(%s IS NULL OR %[1]s NOT regexp ?)", column), c.Value) + f.addWhere(fmt.Sprintf("(%s IS NULL OR NOT regexp(?, %[1]s))", column), c.Value) case models.CriterionModifierIsNull: f.addWhere("(" + column + " IS NULL OR TRIM(" + column + ") = '')") case models.CriterionModifierNotNull: @@ -70,6 +70,49 @@ func stringCriterionHandler(c *models.StringCriterionInput, column string) crite } } +func uuidCriterionHandler(c *models.StringCriterionInput, column string) criterionHandlerFunc { + if dbWrapper.dbType == SqliteBackend { + return stringCriterionHandler(c, column) + } + + return func(ctx context.Context, f *filterBuilder) { + columnCast := "CAST(" + column + " AS TEXT)" + + if c != nil { + if modifier := c.Modifier; c.Modifier.IsValid() { + switch modifier { + case models.CriterionModifierIncludes: + f.whereClauses = append(f.whereClauses, getStringSearchClause([]string{columnCast}, c.Value, false)) + case models.CriterionModifierExcludes: + f.whereClauses = append(f.whereClauses, getStringSearchClause([]string{columnCast}, c.Value, true)) + case models.CriterionModifierEquals: + f.addWhere("LOWER("+columnCast+") LIKE LOWER(?)", c.Value) + case models.CriterionModifierNotEquals: + f.addWhere("LOWER("+columnCast+") NOT LIKE LOWER(?)", c.Value) + case models.CriterionModifierMatchesRegex: + if _, err := regexp.Compile(c.Value); err != nil { + f.setError(err) + return + } + f.addWhere(fmt.Sprintf("(%s IS NOT NULL AND regexp(?, %s))", column, columnCast), c.Value) + case 
models.CriterionModifierNotMatchesRegex: + if _, err := regexp.Compile(c.Value); err != nil { + f.setError(err) + return + } + f.addWhere(fmt.Sprintf("(%s IS NULL OR NOT regexp(?, %s))", column, columnCast), c.Value) + case models.CriterionModifierIsNull: + f.addWhere("(" + column + " IS NULL)") + case models.CriterionModifierNotNull: + f.addWhere("(" + column + " IS NOT NULL)") + default: + panic("unsupported string filter modifier") + } + } + } + } +} + func enumCriterionHandler(modifier models.CriterionModifier, values []string, column string) criterionHandlerFunc { return func(ctx context.Context, f *filterBuilder) { if modifier.IsValid() { @@ -122,14 +165,14 @@ func pathCriterionHandler(c *models.StringCriterionInput, pathColumn string, bas return } filepathColumn := fmt.Sprintf("%s || '%s' || %s", pathColumn, string(filepath.Separator), basenameColumn) - f.addWhere(fmt.Sprintf("%s IS NOT NULL AND %s IS NOT NULL AND %s regexp ?", pathColumn, basenameColumn, filepathColumn), c.Value) + f.addWhere(fmt.Sprintf("%s IS NOT NULL AND %s IS NOT NULL AND regexp(?, %s)", pathColumn, basenameColumn, filepathColumn), c.Value) case models.CriterionModifierNotMatchesRegex: if _, err := regexp.Compile(c.Value); err != nil { f.setError(err) return } filepathColumn := fmt.Sprintf("%s || '%s' || %s", pathColumn, string(filepath.Separator), basenameColumn) - f.addWhere(fmt.Sprintf("%s IS NULL OR %s IS NULL OR %s NOT regexp ?", pathColumn, basenameColumn, filepathColumn), c.Value) + f.addWhere(fmt.Sprintf("%s IS NULL OR %s IS NULL OR NOT regexp(?, %s)", pathColumn, basenameColumn, filepathColumn), c.Value) case models.CriterionModifierIsNull: f.addWhere(fmt.Sprintf("%s IS NULL OR TRIM(%[1]s) = '' OR %s IS NULL OR TRIM(%[2]s) = ''", pathColumn, basenameColumn)) case models.CriterionModifierNotNull: @@ -148,7 +191,7 @@ func getPathSearchClause(pathColumn, basenameColumn, p string, addWildcards, not } filepathColumn := fmt.Sprintf("%s || '%s' || %s", pathColumn, 
string(filepath.Separator), basenameColumn) - ret := makeClause(fmt.Sprintf("%s LIKE ?", filepathColumn), p) + ret := makeClause(fmt.Sprintf("LOWER(%s) LIKE LOWER(?)", filepathColumn), p) if not { ret = ret.not() @@ -225,12 +268,7 @@ func boolCriterionHandler(c *bool, column string, addJoinFn func(f *filterBuilde if addJoinFn != nil { addJoinFn(f) } - var v string - if *c { - v = "1" - } else { - v = "0" - } + v := getDBBoolean(*c) f.addWhere(column + " = " + v) } @@ -272,7 +310,16 @@ func (h *timestampCriterionHandler) handle(ctx context.Context, f *filterBuilder func yearFilterCriterionHandler(year *models.IntCriterionInput, col string) criterionHandlerFunc { return func(ctx context.Context, f *filterBuilder) { if year != nil && year.Modifier.IsValid() { - clause, args := getIntCriterionWhereClause("cast(strftime('%Y', "+col+") as int)", *year) + var clause string + var args []interface{} + + switch dbWrapper.dbType { + case PostgresBackend: + clause, args = getIntCriterionWhereClause("TO_CHAR("+col+", 'YYYY')::int", *year) + case SqliteBackend: + clause, args = getIntCriterionWhereClause("cast(strftime('%Y', "+col+") as int)", *year) + } + f.addWhere(clause, args...) 
} } @@ -288,7 +335,7 @@ func resolutionCriterionHandler(resolution *models.ResolutionCriterionInput, hei mn := resolution.Value.GetMinResolution() mx := resolution.Value.GetMaxResolution() - widthHeight := fmt.Sprintf("MIN(%s, %s)", widthColumn, heightColumn) + widthHeight := fmt.Sprintf("%s(%s, %s)", getDBMinFunc(), widthColumn, heightColumn) switch resolution.Modifier { case models.CriterionModifierEquals: @@ -417,7 +464,7 @@ func (m *joinedMultiCriterionHandlerBuilder) handler(c *models.MultiCriterionInp "primaryFK": m.primaryFK, "primaryTable": m.primaryTable, }) - havingClause = fmt.Sprintf("count(distinct %s.%s) IS %d", joinAlias, m.foreignFK, len(criterion.Value)) + havingClause = fmt.Sprintf("count(distinct %s.%s) = %d", joinAlias, m.foreignFK, len(criterion.Value)) args = append(args, len(criterion.Value)) case models.CriterionModifierNotEquals: f.setError(fmt.Errorf("not equals modifier is not supported for multi criterion input")) @@ -425,7 +472,7 @@ func (m *joinedMultiCriterionHandlerBuilder) handler(c *models.MultiCriterionInp // includes all of the provided ids m.addJoinTable(f) whereClause = fmt.Sprintf("%s.%s IN %s", joinAlias, m.foreignFK, getInBinding(len(criterion.Value))) - havingClause = fmt.Sprintf("count(distinct %s.%s) IS %d", joinAlias, m.foreignFK, len(criterion.Value)) + havingClause = fmt.Sprintf("count(distinct %s.%s) = %d", joinAlias, m.foreignFK, len(criterion.Value)) } f.addWhere(whereClause, args...) @@ -542,7 +589,7 @@ func (m *stringListCriterionHandlerBuilder) handler(criterion *models.StringCrit // excludes all of the provided values // need to use actual join table name for this // .id NOT IN (select . from where . 
in ) - whereClause := utils.StrFormat("{primaryTable}.id NOT IN (SELECT {joinTable}.{primaryFK} from {joinTable} where {joinTable}.{stringColumn} LIKE ?)", + whereClause := utils.StrFormat("{primaryTable}.id NOT IN (SELECT {joinTable}.{primaryFK} from {joinTable} where LOWER({joinTable}.{stringColumn}) LIKE LOWER(?))", utils.StrFormatMap{ "primaryTable": m.primaryTable, "joinTable": m.joinTable, @@ -601,7 +648,7 @@ type hierarchicalMultiCriterionHandlerBuilder struct { relationsTable string } -func getHierarchicalValues(ctx context.Context, values []string, table, relationsTable, parentFK string, childFK string, depth *int) (string, error) { +func getHierarchicalValues(ctx context.Context, values []string, table, relationsTable, parentFK string, childFK string, depth *int, parenthesis bool) (string, error) { var args []interface{} if parentFK == "" { @@ -632,7 +679,11 @@ func getHierarchicalValues(ctx context.Context, values []string, table, relation } if valid { - return "VALUES" + strings.Join(valuesClauses, ","), nil + values := "VALUES" + strings.Join(valuesClauses, ",") + if parenthesis { + values = "(" + values + ")" + getDBValuesFix() + } + return values, nil } } @@ -679,7 +730,7 @@ WHERE id in {inBinding} {unionClause}) `, withClauseMap) - query := fmt.Sprintf("WITH RECURSIVE %s SELECT 'VALUES' || GROUP_CONCAT('(' || root_id || ', ' || item_id || ')') AS val FROM items", withClause) + query := fmt.Sprintf("WITH RECURSIVE %s SELECT 'VALUES' || "+getDBGroupConcat("'(' || root_id || ', ' || item_id || ')'")+" AS val FROM items", withClause) var valuesClause sql.NullString err := dbWrapper.Get(ctx, &valuesClause, query, args...) 
@@ -695,6 +746,10 @@ WHERE id in {inBinding} valuesClause.String = "VALUES" + strings.Join(values, ",") } + if parenthesis { + valuesClause.String = "(" + valuesClause.String + ")" + getDBValuesFix() + } + return valuesClause.String, nil } @@ -704,7 +759,7 @@ func addHierarchicalConditionClauses(f *filterBuilder, criterion models.Hierarch f.addWhere(fmt.Sprintf("%s.%s IS NOT NULL", table, idColumn)) case models.CriterionModifierIncludesAll: f.addWhere(fmt.Sprintf("%s.%s IS NOT NULL", table, idColumn)) - f.addHaving(fmt.Sprintf("count(distinct %s.%s) IS %d", table, idColumn, len(criterion.Value))) + f.addHaving(fmt.Sprintf("count(distinct %s.%s) = %d", table, idColumn, len(criterion.Value))) case models.CriterionModifierExcludes: f.addWhere(fmt.Sprintf("%s.%s IS NULL", table, idColumn)) } @@ -748,7 +803,7 @@ func (m *hierarchicalMultiCriterionHandlerBuilder) handler(c *models.Hierarchica } if len(criterion.Value) > 0 { - valuesClause, err := getHierarchicalValues(ctx, criterion.Value, m.foreignTable, m.relationsTable, m.parentFK, m.childFK, criterion.Depth) + valuesClause, err := getHierarchicalValues(ctx, criterion.Value, m.foreignTable, m.relationsTable, m.parentFK, m.childFK, criterion.Depth, true) if err != nil { f.setError(err) return @@ -756,21 +811,21 @@ func (m *hierarchicalMultiCriterionHandlerBuilder) handler(c *models.Hierarchica switch criterion.Modifier { case models.CriterionModifierIncludes: - f.addWhere(fmt.Sprintf("%s.%s IN (SELECT column2 FROM (%s))", m.primaryTable, m.foreignFK, valuesClause)) + f.addWhere(fmt.Sprintf("%s.%s IN (SELECT column2 FROM %s)", m.primaryTable, m.foreignFK, valuesClause)) case models.CriterionModifierIncludesAll: - f.addWhere(fmt.Sprintf("%s.%s IN (SELECT column2 FROM (%s))", m.primaryTable, m.foreignFK, valuesClause)) - f.addHaving(fmt.Sprintf("count(distinct %s.%s) IS %d", m.primaryTable, m.foreignFK, len(criterion.Value))) + f.addWhere(fmt.Sprintf("%s.%s IN (SELECT column2 FROM %s)", m.primaryTable, m.foreignFK, 
valuesClause)) + f.addHaving(fmt.Sprintf("count(distinct %s.%s) = %d", m.primaryTable, m.foreignFK, len(criterion.Value))) } } if len(criterion.Excludes) > 0 { - valuesClause, err := getHierarchicalValues(ctx, criterion.Excludes, m.foreignTable, m.relationsTable, m.parentFK, m.childFK, criterion.Depth) + valuesClause, err := getHierarchicalValues(ctx, criterion.Excludes, m.foreignTable, m.relationsTable, m.parentFK, m.childFK, criterion.Depth, true) if err != nil { f.setError(err) return } - f.addWhere(fmt.Sprintf("%s.%s NOT IN (SELECT column2 FROM (%s)) OR %[1]s.%[2]s IS NULL", m.primaryTable, m.foreignFK, valuesClause)) + f.addWhere(fmt.Sprintf("%s.%s NOT IN (SELECT column2 FROM %s) OR %[1]s.%[2]s IS NULL", m.primaryTable, m.foreignFK, valuesClause)) } } } @@ -801,7 +856,7 @@ func (m *joinedHierarchicalMultiCriterionHandlerBuilder) addHierarchicalConditio case models.CriterionModifierEquals: // includes only the provided ids f.addWhere(fmt.Sprintf("%s.%s IS NOT NULL", table, idColumn)) - f.addHaving(fmt.Sprintf("count(distinct %s.%s) IS %d", table, idColumn, len(criterion.Value))) + f.addHaving(fmt.Sprintf("count(distinct %s.%s) = %d", table, idColumn, len(criterion.Value))) f.addWhere(utils.StrFormat("(SELECT COUNT(*) FROM {joinTable} s WHERE s.{primaryFK} = {primaryTable}.{primaryKey}) = ?", utils.StrFormatMap{ "joinTable": m.joinTable, "primaryFK": m.primaryFK, @@ -859,7 +914,7 @@ func (m *joinedHierarchicalMultiCriterionHandlerBuilder) handler(c *models.Hiera } if len(criterion.Value) > 0 { - valuesClause, err := getHierarchicalValues(ctx, criterion.Value, m.foreignTable, m.relationsTable, m.parentFK, m.childFK, criterion.Depth) + valuesClause, err := getHierarchicalValues(ctx, criterion.Value, m.foreignTable, m.relationsTable, m.parentFK, m.childFK, criterion.Depth, false) if err != nil { f.setError(err) return @@ -881,7 +936,7 @@ func (m *joinedHierarchicalMultiCriterionHandlerBuilder) handler(c *models.Hiera } if len(criterion.Excludes) > 0 { - 
valuesClause, err := getHierarchicalValues(ctx, criterion.Excludes, m.foreignTable, m.relationsTable, m.parentFK, m.childFK, criterion.Depth) + valuesClause, err := getHierarchicalValues(ctx, criterion.Excludes, m.foreignTable, m.relationsTable, m.parentFK, m.childFK, criterion.Depth, false) if err != nil { f.setError(err) return @@ -959,7 +1014,7 @@ func (h *joinedPerformerTagsHandler) handle(ctx context.Context, f *filterBuilde } if len(criterion.Value) > 0 { - valuesClause, err := getHierarchicalValues(ctx, criterion.Value, tagTable, "tags_relations", "", "", criterion.Depth) + valuesClause, err := getHierarchicalValues(ctx, criterion.Value, tagTable, "tags_relations", "", "", criterion.Depth, false) if err != nil { f.setError(err) return @@ -977,13 +1032,13 @@ INNER JOIN (`+valuesClause+`) t ON t.column2 = pt.tag_id } if len(criterion.Excludes) > 0 { - valuesClause, err := getHierarchicalValues(ctx, criterion.Excludes, tagTable, "tags_relations", "", "", criterion.Depth) + valuesClause, err := getHierarchicalValues(ctx, criterion.Excludes, tagTable, "tags_relations", "", "", criterion.Depth, true) if err != nil { f.setError(err) return } - clause := utils.StrFormat("{primaryTable}.id NOT IN (SELECT {joinTable}.{joinPrimaryKey} FROM {joinTable} INNER JOIN performers_tags ON {joinTable}.performer_id = performers_tags.performer_id WHERE performers_tags.tag_id IN (SELECT column2 FROM (%s)))", strFormatMap) + clause := utils.StrFormat("{primaryTable}.id NOT IN (SELECT {joinTable}.{joinPrimaryKey} FROM {joinTable} INNER JOIN performers_tags ON {joinTable}.performer_id = performers_tags.performer_id WHERE performers_tags.tag_id IN (SELECT column2 FROM %s))", strFormatMap) f.addWhere(fmt.Sprintf(clause, valuesClause)) } } @@ -1019,7 +1074,7 @@ func (h *stashIDCriterionHandler) handle(ctx context.Context, f *filterBuilder) v = *h.c.StashID } - stringCriterionHandler(&models.StringCriterionInput{ + uuidCriterionHandler(&models.StringCriterionInput{ Value: v, Modifier: 
h.c.Modifier, }, t+".stash_id")(ctx, f) diff --git a/pkg/sqlite/database.go b/pkg/sqlite/database.go index d2c0a8191e5..2eff0011c47 100644 --- a/pkg/sqlite/database.go +++ b/pkg/sqlite/database.go @@ -6,14 +6,16 @@ import ( "embed" "errors" "fmt" - "os" "path/filepath" + "strconv" "time" + "github.com/jackc/pgx/v5/pgconn" "github.com/jmoiron/sqlx" + "github.com/mattn/go-sqlite3" - "github.com/stashapp/stash/pkg/fsutil" "github.com/stashapp/stash/pkg/logger" + "github.com/stashapp/stash/pkg/models" ) const ( @@ -36,7 +38,7 @@ const ( var appSchemaVersion uint = 70 -//go:embed migrations/*.sql +//go:embed migrations/*.sql migrationsPostgres/*.sql var migrationsBox embed.FS var ( @@ -81,19 +83,73 @@ type storeRepository struct { Group *GroupStore } +type DatabaseType string + +const ( + PostgresBackend DatabaseType = "POSTGRESQL" + SqliteBackend DatabaseType = "SQLITE" +) + +type DBInterface interface { + Analyze(ctx context.Context) error + Anonymise(outPath string) error + AnonymousDatabasePath(backupDirectoryPath string) string + AppSchemaVersion() uint + Backup(backupPath string) (err error) + Begin(ctx context.Context, writable bool) (context.Context, error) + Close() error + Commit(ctx context.Context) error + DatabaseBackupPath(backupDirectoryPath string) string + DatabasePath() string + DatabaseType() DatabaseType + ExecSQL(ctx context.Context, query string, args []interface{}) (*int64, error) + getDatabaseSchemaVersion() (uint, error) + GetReadDB() *sqlx.DB + GetRepo() *storeRepository + GetWriteDB() *sqlx.DB + initialise() error + IsLocked(err error) bool + lock() + needsMigration() bool + Open() error + open(disableForeignKeys bool, writable bool) (conn *sqlx.DB, err error) + openReadDB() error + openWriteDB() error + Optimise(ctx context.Context) error + QuerySQL(ctx context.Context, query string, args []interface{}) ([]string, [][]interface{}, error) + Ready() error + ReInitialise() error + Remove() error + Repository() models.Repository + Reset() error + 
RestoreFromBackup(backupPath string) error + Rollback(ctx context.Context) error + RunAllMigrations() error + SetBlobStoreOptions(options BlobStoreOptions) + SetSchemaVersion(version uint) + txnComplete(ctx context.Context) + unlock() + Vacuum(ctx context.Context) error + Version() uint + WithDatabase(ctx context.Context) (context.Context, error) + TestMode() + Pointer() *Database +} + type Database struct { *storeRepository + DBInterface - readDB *sqlx.DB - writeDB *sqlx.DB - dbPath string + readDB *sqlx.DB + writeDB *sqlx.DB + dbConfig string schemaVersion uint lockChan chan struct{} } -func NewDatabase() *Database { +func newStoreRepo() *storeRepository { fileStore := NewFileStore() folderStore := NewFolderStore() galleryStore := NewGalleryStore(fileStore, folderStore) @@ -119,18 +175,101 @@ func NewDatabase() *Database { SavedFilter: NewSavedFilterStore(), } - ret := &Database{ - storeRepository: r, + return r +} + +func NewDatabase() *Database { + return &Database{ + storeRepository: newStoreRepo(), lockChan: make(chan struct{}, 1), } +} - return ret +func getDBBoolean(val bool) string { + switch dbWrapper.dbType { + case SqliteBackend: + if val { + return "1" + } else { + return "0" + } + default: + return strconv.FormatBool(val) + } +} + +func getDBValuesFix() (val string) { + if dbWrapper.dbType == PostgresBackend { + val = " AS v(column1, column2)" + } + + return val +} + +func getDBMinFunc() string { + switch dbWrapper.dbType { + case PostgresBackend: + return "LEAST" + default: + return "MIN" + } +} + +func getDBGroupConcat(columnName string) string { + switch dbWrapper.dbType { + case PostgresBackend: + return "STRING_AGG(" + columnName + "::TEXT, ',')" + default: + return "GROUP_CONCAT(" + columnName + ")" + } +} + +func getDBRowId() string { + switch dbWrapper.dbType { + case PostgresBackend: + return "ctid" + default: + return "rowid" + } +} + +func isConstraintError(err error) bool { + switch dbWrapper.dbType { + case PostgresBackend: + var pgErr 
*pgconn.PgError + if errors.As(err, &pgErr) { + // Class 23 — Integrity Constraint Violation + return pgErr.Code[:2] == "23" + } + case SqliteBackend: + var sqliteError sqlite3.Error + if errors.As(err, &sqliteError) { + return sqliteError.Code == sqlite3.ErrConstraint + } + } + return false +} + +func (db *Database) Pointer() *Database { + return db +} + +func (db *Database) SetSchemaVersion(version uint) { + db.schemaVersion = version +} + +func (db *Database) GetRepo() *storeRepository { + return db.storeRepository } func (db *Database) SetBlobStoreOptions(options BlobStoreOptions) { *db.Blobs = *NewBlobStore(options) } +func (db *Database) DatabasePath() string { + return "" +} + // Ready returns an error if the database is not ready to begin transactions. func (db *Database) Ready() error { if db.readDB == nil || db.writeDB == nil { @@ -144,12 +283,10 @@ func (db *Database) Ready() error { // performs a full migration to the latest schema version. Otherwise, any // necessary migrations must be run separately using RunMigrations. // Returns true if the database is new. 
-func (db *Database) Open(dbPath string) error { +func (db *Database) Open() error { db.lock() defer db.unlock() - db.dbPath = dbPath - databaseSchemaVersion, err := db.getDatabaseSchemaVersion() if err != nil { return fmt.Errorf("getting database schema version: %w", err) @@ -165,10 +302,10 @@ func (db *Database) Open(dbPath string) error { return fmt.Errorf("error running initial schema migrations: %w", err) } } else { - if databaseSchemaVersion > appSchemaVersion { + if databaseSchemaVersion > db.AppSchemaVersion() { return &MismatchedSchemaVersionError{ CurrentSchemaVersion: databaseSchemaVersion, - RequiredSchemaVersion: appSchemaVersion, + RequiredSchemaVersion: db.AppSchemaVersion(), } } @@ -176,7 +313,7 @@ func (db *Database) Open(dbPath string) error { if db.needsMigration() { return &MigrationNeededError{ CurrentSchemaVersion: databaseSchemaVersion, - RequiredSchemaVersion: appSchemaVersion, + RequiredSchemaVersion: db.AppSchemaVersion(), } } } @@ -196,22 +333,6 @@ func (db *Database) Open(dbPath string) error { return nil } -// lock locks the database for writing. This method will block until the lock is acquired. 
-func (db *Database) lock() { - db.lockChan <- struct{}{} -} - -// unlock unlocks the database -func (db *Database) unlock() { - // will block the caller if the lock is not held, so check first - select { - case <-db.lockChan: - return - default: - panic("database is not locked") - } -} - func (db *Database) Close() error { db.lock() defer db.unlock() @@ -234,33 +355,6 @@ func (db *Database) Close() error { return nil } -func (db *Database) open(disableForeignKeys bool, writable bool) (*sqlx.DB, error) { - // https://github.com/mattn/go-sqlite3 - url := "file:" + db.dbPath + "?_journal=WAL&_sync=NORMAL&_busy_timeout=50" - if !disableForeignKeys { - url += "&_fk=true" - } - - if writable { - url += "&_txlock=immediate" - } else { - url += "&mode=ro" - } - - // #5155 - set the cache size if the environment variable is set - // default is -2000 which is 2MB - if cacheSize := os.Getenv(cacheSizeEnv); cacheSize != "" { - url += "&_cache_size=" + cacheSize - } - - conn, err := sqlx.Open(sqlite3Driver, url) - if err != nil { - return nil, fmt.Errorf("db.Open(): %w", err) - } - - return conn, nil -} - func (db *Database) initialise() error { if err := db.openReadDB(); err != nil { return fmt.Errorf("opening read database: %w", err) @@ -272,128 +366,34 @@ func (db *Database) initialise() error { return nil } -func (db *Database) openReadDB() error { - const ( - disableForeignKeys = false - writable = false - ) - var err error - db.readDB, err = db.open(disableForeignKeys, writable) - db.readDB.SetMaxOpenConns(maxReadConnections) - db.readDB.SetMaxIdleConns(maxReadConnections) - db.readDB.SetConnMaxIdleTime(dbConnTimeout) - return err -} - -func (db *Database) openWriteDB() error { - const ( - disableForeignKeys = false - writable = true - ) - var err error - db.writeDB, err = db.open(disableForeignKeys, writable) - db.writeDB.SetMaxOpenConns(maxWriteConnections) - db.writeDB.SetMaxIdleConns(maxWriteConnections) - db.writeDB.SetConnMaxIdleTime(dbConnTimeout) - return err -} 
- -func (db *Database) Remove() error { - databasePath := db.dbPath - err := db.Close() - - if err != nil { - return fmt.Errorf("error closing database: %w", err) - } +func (db *Database) Anonymise(outPath string) error { + anon, err := NewAnonymiser(db, outPath) - err = os.Remove(databasePath) if err != nil { - return fmt.Errorf("error removing database: %w", err) + return err } - // remove the -shm, -wal files ( if they exist ) - walFiles := []string{databasePath + "-shm", databasePath + "-wal"} - for _, wf := range walFiles { - if exists, _ := fsutil.FileExists(wf); exists { - err = os.Remove(wf) - if err != nil { - return fmt.Errorf("error removing database: %w", err) - } - } - } + return anon.Anonymise(context.Background()) +} - return nil +func (db *Database) Version() uint { + return db.schemaVersion } func (db *Database) Reset() error { - databasePath := db.dbPath if err := db.Remove(); err != nil { return err } - if err := db.Open(databasePath); err != nil { + if err := db.Open(); err != nil { return fmt.Errorf("[reset DB] unable to initialize: %w", err) } return nil } -// Backup the database. If db is nil, then uses the existing database -// connection. 
-func (db *Database) Backup(backupPath string) (err error) { - thisDB := db.writeDB - if thisDB == nil { - thisDB, err = sqlx.Connect(sqlite3Driver, "file:"+db.dbPath+"?_fk=true") - if err != nil { - return fmt.Errorf("open database %s failed: %w", db.dbPath, err) - } - defer thisDB.Close() - } - - logger.Infof("Backing up database into: %s", backupPath) - _, err = thisDB.Exec(`VACUUM INTO "` + backupPath + `"`) - if err != nil { - return fmt.Errorf("vacuum failed: %w", err) - } - - return nil -} - -func (db *Database) Anonymise(outPath string) error { - anon, err := NewAnonymiser(db, outPath) - - if err != nil { - return err - } - - return anon.Anonymise(context.Background()) -} - -func (db *Database) RestoreFromBackup(backupPath string) error { - logger.Infof("Restoring backup database %s into %s", backupPath, db.dbPath) - return os.Rename(backupPath, db.dbPath) -} - -func (db *Database) AppSchemaVersion() uint { - return appSchemaVersion -} - -func (db *Database) DatabasePath() string { - return db.dbPath -} - -func (db *Database) DatabaseBackupPath(backupDirectoryPath string) string { - fn := fmt.Sprintf("%s.%d.%s", filepath.Base(db.dbPath), db.schemaVersion, time.Now().Format("20060102_150405")) - - if backupDirectoryPath != "" { - return filepath.Join(backupDirectoryPath, fn) - } - - return fn -} - func (db *Database) AnonymousDatabasePath(backupDirectoryPath string) string { - fn := fmt.Sprintf("%s.anonymous.%d.%s", filepath.Base(db.dbPath), db.schemaVersion, time.Now().Format("20060102_150405")) + fn := fmt.Sprintf("%s.anonymous.%d.%s", filepath.Base(db.DatabasePath()), db.schemaVersion, time.Now().Format("20060102_150405")) if backupDirectoryPath != "" { return filepath.Join(backupDirectoryPath, fn) @@ -402,10 +402,6 @@ func (db *Database) AnonymousDatabasePath(backupDirectoryPath string) string { return fn } -func (db *Database) Version() uint { - return db.schemaVersion -} - func (db *Database) Optimise(ctx context.Context) error { 
logger.Info("Optimising database") @@ -434,12 +430,20 @@ func (db *Database) Analyze(ctx context.Context) error { return err } -func (db *Database) ExecSQL(ctx context.Context, query string, args []interface{}) (*int64, *int64, error) { +func (db *Database) GetWriteDB() *sqlx.DB { + return db.writeDB +} + +func (db *Database) GetReadDB() *sqlx.DB { + return db.readDB +} + +func (db *Database) ExecSQL(ctx context.Context, query string, args []interface{}) (*int64, error) { wrapper := dbWrapperType{} result, err := wrapper.Exec(ctx, query, args...) if err != nil { - return nil, nil, err + return nil, err } var rowsAffected *int64 @@ -448,13 +452,7 @@ func (db *Database) ExecSQL(ctx context.Context, query string, args []interface{ rowsAffected = &ra } - var lastInsertId *int64 - li, err := result.LastInsertId() - if err == nil { - lastInsertId = &li - } - - return rowsAffected, lastInsertId, nil + return rowsAffected, nil } func (db *Database) QuerySQL(ctx context.Context, query string, args []interface{}) ([]string, [][]interface{}, error) { diff --git a/pkg/sqlite/database_postgres.go b/pkg/sqlite/database_postgres.go new file mode 100644 index 00000000000..83f79607bf5 --- /dev/null +++ b/pkg/sqlite/database_postgres.go @@ -0,0 +1,142 @@ +package sqlite + +import ( + "fmt" + + "github.com/doug-martin/goqu/v9" + _ "github.com/doug-martin/goqu/v9/dialect/postgres" + _ "github.com/jackc/pgx/v5/stdlib" + "github.com/jmoiron/sqlx" + "github.com/stashapp/stash/pkg/logger" +) + +type PostgresDB struct { + *Database +} + +const ( + // TODO: Optimize for these + maxPGWriteConnections = 5 + maxPGReadConnections = 15 +) + +func NewPostgresDatabase(odb *Database, dbConnector string, init bool) *PostgresDB { + db := &PostgresDB{ + Database: odb, + } + + db.dbConfig = dbConnector + db.DBInterface = db + + if init { + dialect = goqu.Dialect("postgres") + dbWrapper.dbType = PostgresBackend + } + + return db +} + +// Does nothing +func (db *PostgresDB) lock() {} +func (db 
*PostgresDB) unlock() {} + +func (db *PostgresDB) openReadDB() error { + const ( + disableForeignKeys = false + writable = false + ) + var err error + db.readDB, err = db.open(disableForeignKeys, writable) + db.readDB.SetMaxOpenConns(maxPGReadConnections) + db.readDB.SetMaxIdleConns(maxPGReadConnections) + db.readDB.SetConnMaxIdleTime(dbConnTimeout) + return err +} + +func (db *PostgresDB) openWriteDB() error { + const ( + disableForeignKeys = false + writable = true + ) + var err error + db.writeDB, err = db.open(disableForeignKeys, writable) + db.writeDB.SetMaxOpenConns(maxPGWriteConnections) + db.writeDB.SetMaxIdleConns(maxPGWriteConnections) + db.writeDB.SetConnMaxIdleTime(dbConnTimeout) + return err +} + +// Ensure single connection for testing to avoid race conditions +func (db *PostgresDB) TestMode() { + db.readDB.Close() + db.readDB = db.writeDB +} + +func (db *PostgresDB) DatabaseType() DatabaseType { + return PostgresBackend +} + +func (db *PostgresDB) AppSchemaVersion() uint { + return uint(0 - (66 - int(appSchemaVersion))) +} + +func (db *PostgresDB) DatabaseConnector() string { + return db.dbConfig +} + +func (db *PostgresDB) open(disableForeignKeys bool, writable bool) (conn *sqlx.DB, err error) { + conn, err = sqlx.Open("pgx", db.DatabaseConnector()) + + if err != nil { + return nil, fmt.Errorf("db.Open(): %w", err) + } + + if disableForeignKeys { + _, err = conn.Exec("SET session_replication_role = replica;") + + if err != nil { + return nil, fmt.Errorf("conn.Exec(): %w", err) + } + } + if !writable { + _, err = conn.Exec("SET SESSION CHARACTERISTICS AS TRANSACTION READ ONLY;") + + if err != nil { + return nil, fmt.Errorf("conn.Exec(): %w", err) + } + } + + return conn, nil +} + +func (db *PostgresDB) Remove() (err error) { + _, err = db.writeDB.Exec(` +DO $$ +DECLARE + r record; +BEGIN + FOR r IN SELECT quote_ident(tablename) AS tablename, quote_ident(schemaname) AS schemaname FROM pg_tables WHERE schemaname = 'public' + LOOP + RAISE INFO 'Dropping 
table %.%', r.schemaname, r.tablename; + EXECUTE format('DROP TABLE IF EXISTS %I.%I CASCADE', r.schemaname, r.tablename); + END LOOP; +END$$; +`) + + return err +} + +func (db *PostgresDB) Backup(backupPath string) (err error) { + logger.Warn("Postgres backend detected, ignoring Backup request") + return nil +} + +func (db *PostgresDB) RestoreFromBackup(backupPath string) (err error) { + logger.Warn("Postgres backend detected, ignoring RestoreFromBackup request") + return nil +} + +func (db *PostgresDB) DatabaseBackupPath(backupDirectoryPath string) string { + logger.Warn("Postgres backend detected, ignoring DatabaseBackupPath request") + return "" +} diff --git a/pkg/sqlite/database_sqlite.go b/pkg/sqlite/database_sqlite.go new file mode 100644 index 00000000000..943093117fe --- /dev/null +++ b/pkg/sqlite/database_sqlite.go @@ -0,0 +1,181 @@ +package sqlite + +import ( + "fmt" + "os" + "path/filepath" + "time" + + "github.com/doug-martin/goqu/v9" + "github.com/jmoiron/sqlx" + "github.com/stashapp/stash/pkg/fsutil" + "github.com/stashapp/stash/pkg/logger" +) + +type SQLiteDB struct { + *Database +} + +func NewSQLiteDatabase(odb *Database, dbPath string, init bool) *SQLiteDB { + db := &SQLiteDB{ + Database: odb, + } + + db.dbConfig = dbPath + db.DBInterface = db + + if init { + dialect = goqu.Dialect("sqlite3") + dbWrapper.dbType = SqliteBackend + } + + return db +} + +// Does nothing +func (db *SQLiteDB) TestMode() {} + +// lock locks the database for writing. This method will block until the lock is acquired. 
+func (db *SQLiteDB) lock() { + db.lockChan <- struct{}{} +} + +// unlock unlocks the database +func (db *SQLiteDB) unlock() { + // will block the caller if the lock is not held, so check first + select { + case <-db.lockChan: + return + default: + panic("database is not locked") + } +} + +func (db *SQLiteDB) openReadDB() error { + const ( + disableForeignKeys = false + writable = false + ) + var err error + db.readDB, err = db.open(disableForeignKeys, writable) + db.readDB.SetMaxOpenConns(maxReadConnections) + db.readDB.SetMaxIdleConns(maxReadConnections) + db.readDB.SetConnMaxIdleTime(dbConnTimeout) + return err +} + +func (db *SQLiteDB) openWriteDB() error { + const ( + disableForeignKeys = false + writable = true + ) + var err error + db.writeDB, err = db.open(disableForeignKeys, writable) + db.writeDB.SetMaxOpenConns(maxWriteConnections) + db.writeDB.SetMaxIdleConns(maxWriteConnections) + db.writeDB.SetConnMaxIdleTime(dbConnTimeout) + return err +} + +func (db *SQLiteDB) AppSchemaVersion() uint { + return appSchemaVersion +} + +func (db *SQLiteDB) DatabaseType() DatabaseType { + return SqliteBackend +} + +func (db *SQLiteDB) DatabasePath() string { + return db.dbConfig +} + +func (db *SQLiteDB) open(disableForeignKeys bool, writable bool) (conn *sqlx.DB, err error) { + // https://github.com/mattn/go-sqlite3 + url := "file:" + db.DatabasePath() + "?_journal=WAL&_sync=NORMAL&_busy_timeout=50" + if !disableForeignKeys { + url += "&_fk=true" + } + + if writable { + url += "&_txlock=immediate" + } else { + url += "&mode=ro" + } + + // #5155 - set the cache size if the environment variable is set + // default is -2000 which is 2MB + if cacheSize := os.Getenv(cacheSizeEnv); cacheSize != "" { + url += "&_cache_size=" + cacheSize + } + + conn, err = sqlx.Open(sqlite3Driver, url) + + if err != nil { + return nil, fmt.Errorf("db.Open(): %w", err) + } + + return conn, nil +} + +func (db *SQLiteDB) Remove() error { + databasePath := db.DatabasePath() + err := db.Close() + 
+ if err != nil { + return fmt.Errorf("error closing database: %w", err) + } + + err = os.Remove(databasePath) + if err != nil { + return fmt.Errorf("error removing database: %w", err) + } + + // remove the -shm, -wal files ( if they exist ) + walFiles := []string{databasePath + "-shm", databasePath + "-wal"} + for _, wf := range walFiles { + if exists, _ := fsutil.FileExists(wf); exists { + err = os.Remove(wf) + if err != nil { + return fmt.Errorf("error removing database: %w", err) + } + } + } + + return nil +} + +// Backup the database. If db is nil, then uses the existing database +// connection. +func (db *SQLiteDB) Backup(backupPath string) (err error) { + thisDB := db.writeDB + if thisDB == nil { + thisDB, err = sqlx.Connect(sqlite3Driver, "file:"+db.DatabasePath()+"?_fk=true") + if err != nil { + return fmt.Errorf("open database %s failed: %w", db.DatabasePath(), err) + } + defer thisDB.Close() + } + + logger.Infof("Backing up database into: %s", backupPath) + _, err = thisDB.Exec(`VACUUM INTO "` + backupPath + `"`) + if err != nil { + return fmt.Errorf("vacuum failed: %w", err) + } + + return nil +} + +func (db *SQLiteDB) RestoreFromBackup(backupPath string) error { + logger.Infof("Restoring backup database %s into %s", backupPath, db.DatabasePath()) + return os.Rename(backupPath, db.DatabasePath()) +} + +func (db *SQLiteDB) DatabaseBackupPath(backupDirectoryPath string) string { + fn := fmt.Sprintf("%s.%d.%s", filepath.Base(db.DatabasePath()), db.schemaVersion, time.Now().Format("20060102_150405")) + + if backupDirectoryPath != "" { + return filepath.Join(backupDirectoryPath, fn) + } + + return fn +} diff --git a/pkg/sqlite/file.go b/pkg/sqlite/file.go index 6bf6e32b51f..7a3a1ca2e4a 100644 --- a/pkg/sqlite/file.go +++ b/pkg/sqlite/file.go @@ -549,6 +549,8 @@ func (qb *FileStore) getMany(ctx context.Context, q *goqu.SelectDataset) ([]mode return err } + f.fingerprintQueryRow.correct() + rows = append(rows, f) return nil }); err != nil { @@ -622,8 +624,8 @@ 
func (qb *FileStore) FindAllByPath(ctx context.Context, p string) ([]models.File if strings.Contains(basename, "%") || strings.Contains(dirName, "%") { q = q.Where( - folderTable.Col("path").Like(dirName), - table.Col("basename").Like(basename), + folderTable.Col("path").ILike(dirName), + table.Col("basename").ILike(basename), ) } else { q = q.Where( @@ -647,7 +649,7 @@ func (qb *FileStore) allInPaths(q *goqu.SelectDataset, p []string) *goqu.SelectD for _, pp := range p { ppWildcard := pp + string(filepath.Separator) + "%" - conds = append(conds, folderTable.Col("path").Eq(pp), folderTable.Col("path").Like(ppWildcard)) + conds = append(conds, folderTable.Col("path").Eq(pp), folderTable.Col("path").ILike(ppWildcard)) } return q.Where( @@ -762,9 +764,9 @@ func (qb *FileStore) IsPrimary(ctx context.Context, fileID models.FileID) (bool, var sq *goqu.SelectDataset for _, t := range joinTables { - qq := dialect.From(t).Select(t.Col(fileIDColumn)).Where( + qq := dialect.From(t).Select(t.Col(fileIDColumn)).Prepared(true).Where( t.Col(fileIDColumn).Eq(fileID), - t.Col("primary").Eq(1), + t.Col("primary").IsTrue(), ) if sq == nil { @@ -774,7 +776,7 @@ func (qb *FileStore) IsPrimary(ctx context.Context, fileID models.FileID) (bool, } } - q := dialect.Select(goqu.COUNT("*").As("count")).Prepared(true).From( + q := dialect.Select(goqu.COUNT("*").As("count")).From( sq, ) @@ -926,9 +928,7 @@ var fileSortOptions = sortOptions{ } func (qb *FileStore) setQuerySort(query *queryBuilder, findFilter *models.FindFilterType) error { - if findFilter == nil || findFilter.Sort == nil || *findFilter.Sort == "" { - return nil - } + models.EnsureFindFilterSorted(findFilter) sort := findFilter.GetSort("path") // CVE-2024-32231 - ensure sort is in the list of allowed sorts @@ -941,8 +941,11 @@ func (qb *FileStore) setQuerySort(query *queryBuilder, findFilter *models.FindFi case "path": // special handling for path query.sortAndPagination += fmt.Sprintf(" ORDER BY folders.path %s, files.basename 
%[1]s", direction) + query.addGroupBy("folders.path", "files.basename") default: - query.sortAndPagination += getSort(sort, direction, "files") + add, agg := getSort(sort, direction, "files") + query.sortAndPagination += add + query.addGroupBy(agg...) } return nil diff --git a/pkg/sqlite/file_test.go b/pkg/sqlite/file_test.go index 766ffcc70b7..728667800b1 100644 --- a/pkg/sqlite/file_test.go +++ b/pkg/sqlite/file_test.go @@ -192,7 +192,7 @@ func Test_fileFileStore_Create(t *testing.T) { }, } - qb := db.File + qb := db.GetRepo().File for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -419,7 +419,7 @@ func Test_fileStore_Update(t *testing.T) { }, } - qb := db.File + qb := db.GetRepo().File for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) @@ -491,7 +491,7 @@ func Test_fileStore_Find(t *testing.T) { }, } - qb := db.File + qb := db.GetRepo().File for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -546,7 +546,7 @@ func Test_FileStore_FindByPath(t *testing.T) { }, } - qb := db.File + qb := db.GetRepo().File for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -598,7 +598,7 @@ func TestFileStore_FindByFingerprint(t *testing.T) { }, } - qb := db.File + qb := db.GetRepo().File for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -647,7 +647,7 @@ func TestFileStore_IsPrimary(t *testing.T) { }, } - qb := db.File + qb := db.GetRepo().File for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { diff --git a/pkg/sqlite/filter_internal_test.go b/pkg/sqlite/filter_internal_test.go index f416b661cbc..9444bdd4af3 100644 --- a/pkg/sqlite/filter_internal_test.go +++ b/pkg/sqlite/filter_internal_test.go @@ -471,7 +471,7 @@ func TestStringCriterionHandlerIncludes(t 
*testing.T) { }, column)) assert.Len(f.whereClauses, 1) - assert.Equal(fmt.Sprintf("(%[1]s LIKE ? OR %[1]s LIKE ?)", column), f.whereClauses[0].sql) + assert.Equal(fmt.Sprintf("(LOWER(%[1]s) LIKE LOWER(?) OR LOWER(%[1]s) LIKE LOWER(?))", column), f.whereClauses[0].sql) assert.Len(f.whereClauses[0].args, 2) assert.Equal("%two%", f.whereClauses[0].args[0]) assert.Equal("%words%", f.whereClauses[0].args[1]) @@ -483,7 +483,7 @@ func TestStringCriterionHandlerIncludes(t *testing.T) { }, column)) assert.Len(f.whereClauses, 1) - assert.Equal(fmt.Sprintf("(%[1]s LIKE ?)", column), f.whereClauses[0].sql) + assert.Equal(fmt.Sprintf("(LOWER(%[1]s) LIKE LOWER(?))", column), f.whereClauses[0].sql) assert.Len(f.whereClauses[0].args, 1) assert.Equal("%two words%", f.whereClauses[0].args[0]) } @@ -502,7 +502,7 @@ func TestStringCriterionHandlerExcludes(t *testing.T) { }, column)) assert.Len(f.whereClauses, 1) - assert.Equal(fmt.Sprintf("(%[1]s NOT LIKE ? AND %[1]s NOT LIKE ?)", column), f.whereClauses[0].sql) + assert.Equal(fmt.Sprintf("(LOWER(%[1]s) NOT LIKE LOWER(?) 
AND LOWER(%[1]s) NOT LIKE LOWER(?))", column), f.whereClauses[0].sql) assert.Len(f.whereClauses[0].args, 2) assert.Equal("%two%", f.whereClauses[0].args[0]) assert.Equal("%words%", f.whereClauses[0].args[1]) @@ -514,7 +514,7 @@ func TestStringCriterionHandlerExcludes(t *testing.T) { }, column)) assert.Len(f.whereClauses, 1) - assert.Equal(fmt.Sprintf("(%[1]s NOT LIKE ?)", column), f.whereClauses[0].sql) + assert.Equal(fmt.Sprintf("(LOWER(%[1]s) NOT LIKE LOWER(?))", column), f.whereClauses[0].sql) assert.Len(f.whereClauses[0].args, 1) assert.Equal("%two words%", f.whereClauses[0].args[0]) } @@ -532,7 +532,7 @@ func TestStringCriterionHandlerEquals(t *testing.T) { }, column)) assert.Len(f.whereClauses, 1) - assert.Equal(fmt.Sprintf("%[1]s LIKE ?", column), f.whereClauses[0].sql) + assert.Equal(fmt.Sprintf("LOWER(%[1]s) LIKE LOWER(?)", column), f.whereClauses[0].sql) assert.Len(f.whereClauses[0].args, 1) assert.Equal(value1, f.whereClauses[0].args[0]) } @@ -550,7 +550,7 @@ func TestStringCriterionHandlerNotEquals(t *testing.T) { }, column)) assert.Len(f.whereClauses, 1) - assert.Equal(fmt.Sprintf("%[1]s NOT LIKE ?", column), f.whereClauses[0].sql) + assert.Equal(fmt.Sprintf("LOWER(%[1]s) NOT LIKE LOWER(?)", column), f.whereClauses[0].sql) assert.Len(f.whereClauses[0].args, 1) assert.Equal(value1, f.whereClauses[0].args[0]) } @@ -569,7 +569,7 @@ func TestStringCriterionHandlerMatchesRegex(t *testing.T) { }, column)) assert.Len(f.whereClauses, 1) - assert.Equal(fmt.Sprintf("(%s IS NOT NULL AND %[1]s regexp ?)", column), f.whereClauses[0].sql) + assert.Equal(fmt.Sprintf("(%s IS NOT NULL AND regexp(?, %[1]s))", column), f.whereClauses[0].sql) assert.Len(f.whereClauses[0].args, 1) assert.Equal(validValue, f.whereClauses[0].args[0]) @@ -597,7 +597,7 @@ func TestStringCriterionHandlerNotMatchesRegex(t *testing.T) { }, column)) assert.Len(f.whereClauses, 1) - assert.Equal(fmt.Sprintf("(%s IS NULL OR %[1]s NOT regexp ?)", column), f.whereClauses[0].sql) + 
assert.Equal(fmt.Sprintf("(%s IS NULL OR NOT regexp(?, %[1]s))", column), f.whereClauses[0].sql) assert.Len(f.whereClauses[0].args, 1) assert.Equal(validValue, f.whereClauses[0].args[0]) diff --git a/pkg/sqlite/fingerprint.go b/pkg/sqlite/fingerprint.go index d65f6bab55a..c3bfaa6ddcb 100644 --- a/pkg/sqlite/fingerprint.go +++ b/pkg/sqlite/fingerprint.go @@ -3,6 +3,8 @@ package sqlite import ( "context" "fmt" + "strconv" + "strings" "github.com/doug-martin/goqu/v9" "github.com/doug-martin/goqu/v9/exp" @@ -23,6 +25,18 @@ func (r fingerprintQueryRow) valid() bool { return r.Type.Valid } +func (r *fingerprintQueryRow) correct() { + if !r.Type.Valid || strings.ToLower(r.Type.String) != "phash" { + return + } + + if val, ok := r.Fingerprint.(string); ok { + if i, err := strconv.ParseInt(val, 10, 64); err == nil { + r.Fingerprint = i + } + } +} + func (r *fingerprintQueryRow) resolve() models.Fingerprint { return models.Fingerprint{ Type: r.Type.String, diff --git a/pkg/sqlite/folder.go b/pkg/sqlite/folder.go index 4cf632d49e6..8d04d283a9f 100644 --- a/pkg/sqlite/folder.go +++ b/pkg/sqlite/folder.go @@ -254,7 +254,7 @@ func (qb *FolderStore) allInPaths(q *goqu.SelectDataset, p []string) *goqu.Selec for _, pp := range p { ppWildcard := pp + string(filepath.Separator) + "%" - conds = append(conds, table.Col("path").Eq(pp), table.Col("path").Like(ppWildcard)) + conds = append(conds, table.Col("path").Eq(pp), table.Col("path").ILike(ppWildcard)) } return q.Where( diff --git a/pkg/sqlite/folder_test.go b/pkg/sqlite/folder_test.go index 1d948d06368..c27a89e7322 100644 --- a/pkg/sqlite/folder_test.go +++ b/pkg/sqlite/folder_test.go @@ -65,7 +65,7 @@ func Test_FolderStore_Create(t *testing.T) { }, } - qb := db.Folder + qb := db.GetRepo().Folder for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -165,7 +165,7 @@ func Test_FolderStore_Update(t *testing.T) { }, } - qb := db.Folder + qb := db.GetRepo().Folder for _, tt := range 
tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) @@ -224,7 +224,7 @@ func Test_FolderStore_FindByPath(t *testing.T) { }, } - qb := db.Folder + qb := db.GetRepo().Folder for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { diff --git a/pkg/sqlite/gallery.go b/pkg/sqlite/gallery.go index 90e46bebcba..faf607a41fe 100644 --- a/pkg/sqlite/gallery.go +++ b/pkg/sqlite/gallery.go @@ -209,7 +209,7 @@ func (qb *GalleryStore) selectDataset() *goqu.SelectDataset { galleriesFilesJoinTable, goqu.On( galleriesFilesJoinTable.Col(galleryIDColumn).Eq(table.Col(idColumn)), - galleriesFilesJoinTable.Col("primary").Eq(1), + galleriesFilesJoinTable.Col("primary").IsTrue(), ), ).LeftJoin( files, @@ -785,10 +785,7 @@ var gallerySortOptions = sortOptions{ } func (qb *GalleryStore) setGallerySort(query *queryBuilder, findFilter *models.FindFilterType) error { - if findFilter == nil || findFilter.Sort == nil || *findFilter.Sort == "" { - return nil - } - + models.EnsureFindFilterSorted(findFilter) sort := findFilter.GetSort("path") direction := findFilter.GetDirection() @@ -838,20 +835,27 @@ func (qb *GalleryStore) setGallerySort(query *queryBuilder, findFilter *models.F addFileTable() addFolderTable() query.sortAndPagination += fmt.Sprintf(" ORDER BY COALESCE(folders.path, '') || COALESCE(file_folder.path, '') || COALESCE(files.basename, '') COLLATE NATURAL_CI %s", direction) + query.addGroupBy("folders.path", "file_folder.path", "files.basename") case "file_mod_time": sort = "mod_time" addFileTable() - query.sortAndPagination += getSort(sort, direction, fileTable) + add, agg := getSort(sort, direction, fileTable) + query.sortAndPagination += add + query.addGroupBy(agg...) 
case "title": addFileTable() addFolderTable() query.sortAndPagination += " ORDER BY COALESCE(galleries.title, files.basename, basename(COALESCE(folders.path, ''))) COLLATE NATURAL_CI " + direction + ", file_folder.path COLLATE NATURAL_CI " + direction + query.addGroupBy("galleries.title", "files.basename", "folders.path", "file_folder.path") default: - query.sortAndPagination += getSort(sort, direction, "galleries") + add, agg := getSort(sort, direction, "galleries") + query.sortAndPagination += add + query.addGroupBy(agg...) } // Whatever the sorting, always use title/id as a final sort - query.sortAndPagination += ", COALESCE(galleries.title, galleries.id) COLLATE NATURAL_CI ASC" + query.sortAndPagination += ", COALESCE(galleries.title, CAST(galleries.id as text)) COLLATE NATURAL_CI ASC" + query.addGroupBy("galleries.title", "galleries.id") return nil } diff --git a/pkg/sqlite/gallery_chapter_test.go b/pkg/sqlite/gallery_chapter_test.go index 4c71ae6b5a4..13ec2aa3fec 100644 --- a/pkg/sqlite/gallery_chapter_test.go +++ b/pkg/sqlite/gallery_chapter_test.go @@ -12,7 +12,7 @@ import ( func TestChapterFindByGalleryID(t *testing.T) { withTxn(func(ctx context.Context) error { - mqb := db.GalleryChapter + mqb := db.GetRepo().GalleryChapter galleryID := galleryIDs[galleryIdxWithChapters] chapters, err := mqb.FindByGalleryID(ctx, galleryID) diff --git a/pkg/sqlite/gallery_filter.go b/pkg/sqlite/gallery_filter.go index 18718c511d9..84a9e7c2f1e 100644 --- a/pkg/sqlite/gallery_filter.go +++ b/pkg/sqlite/gallery_filter.go @@ -201,13 +201,13 @@ func (qb *galleryFilterHandler) pathCriterionHandler(c *models.StringCriterionIn case models.CriterionModifierEquals: addWildcards = false clause := getPathSearchClause(pathColumn, basenameColumn, c.Value, addWildcards, not) - clause2 := makeClause(folderPathColumn+" LIKE ?", c.Value) + clause2 := makeClause("LOWER("+folderPathColumn+") LIKE LOWER(?)", c.Value) f.whereClauses = append(f.whereClauses, orClauses(clause, clause2)) case 
models.CriterionModifierNotEquals: addWildcards = false not = true clause := getPathSearchClause(pathColumn, basenameColumn, c.Value, addWildcards, not) - clause2 := makeClause(folderPathColumn+" NOT LIKE ?", c.Value) + clause2 := makeClause("LOWER("+folderPathColumn+") NOT LIKE LOWER(?)", c.Value) f.whereClauses = append(f.whereClauses, orClauses(clause, clause2)) case models.CriterionModifierMatchesRegex: if _, err := regexp.Compile(c.Value); err != nil { @@ -215,8 +215,8 @@ func (qb *galleryFilterHandler) pathCriterionHandler(c *models.StringCriterionIn return } filepathColumn := fmt.Sprintf("%s || '%s' || %s", pathColumn, string(filepath.Separator), basenameColumn) - clause := makeClause(fmt.Sprintf("%s IS NOT NULL AND %s IS NOT NULL AND %s regexp ?", pathColumn, basenameColumn, filepathColumn), c.Value) - clause2 := makeClause(fmt.Sprintf("%s IS NOT NULL AND %[1]s regexp ?", folderPathColumn), c.Value) + clause := makeClause(fmt.Sprintf("%s IS NOT NULL AND %s IS NOT NULL AND regexp(?, %s)", pathColumn, basenameColumn, filepathColumn), c.Value) + clause2 := makeClause(fmt.Sprintf("%s IS NOT NULL AND regexp(?, %[1]s)", folderPathColumn), c.Value) f.whereClauses = append(f.whereClauses, orClauses(clause, clause2)) case models.CriterionModifierNotMatchesRegex: if _, err := regexp.Compile(c.Value); err != nil { @@ -224,8 +224,8 @@ func (qb *galleryFilterHandler) pathCriterionHandler(c *models.StringCriterionIn return } filepathColumn := fmt.Sprintf("%s || '%s' || %s", pathColumn, string(filepath.Separator), basenameColumn) - f.addWhere(fmt.Sprintf("%s IS NULL OR %s IS NULL OR %s NOT regexp ?", pathColumn, basenameColumn, filepathColumn), c.Value) - f.addWhere(fmt.Sprintf("%s IS NULL OR %[1]s NOT regexp ?", folderPathColumn), c.Value) + f.addWhere(fmt.Sprintf("%s IS NULL OR %s IS NULL OR NOT regexp(?, %s)", pathColumn, basenameColumn, filepathColumn), c.Value) + f.addWhere(fmt.Sprintf("%s IS NULL OR NOT regexp(?, %[1]s)", folderPathColumn), c.Value) case 
models.CriterionModifierIsNull: f.addWhere(fmt.Sprintf("%s IS NULL OR TRIM(%[1]s) = '' OR %s IS NULL OR TRIM(%[2]s) = ''", pathColumn, basenameColumn)) f.addWhere(fmt.Sprintf("%s IS NULL OR TRIM(%[1]s) = ''", folderPathColumn)) @@ -267,12 +267,16 @@ func (qb *galleryFilterHandler) missingCriterionHandler(isMissing *string) crite galleryRepository.performers.join(f, "performers_join", "galleries.id") f.addWhere("performers_join.gallery_id IS NULL") case "date": - f.addWhere("galleries.date IS NULL OR galleries.date IS \"\"") + q := "galleries.date IS NULL" + if dbWrapper.dbType == SqliteBackend { + q += ` OR galleries.date IS ""` + } + f.addWhere(q) case "tags": galleryRepository.tags.join(f, "tags_join", "galleries.id") f.addWhere("tags_join.gallery_id IS NULL") default: - f.addWhere("(galleries." + *isMissing + " IS NULL OR TRIM(galleries." + *isMissing + ") = '')") + f.addWhere("(galleries." + *isMissing + " IS NULL OR TRIM(CAST(galleries." + *isMissing + " AS TEXT)) = '')") } } } @@ -378,12 +382,12 @@ func (qb *galleryFilterHandler) performerFavoriteCriterionHandler(performerfavor if *performerfavorite { // contains at least one favorite f.addLeftJoin("performers", "", "performers.id = performers_galleries.performer_id") - f.addWhere("performers.favorite = 1") + f.addWhere("performers.favorite = " + getDBBoolean(true)) } else { // contains zero favorites f.addLeftJoin(`(SELECT performers_galleries.gallery_id as id FROM performers_galleries JOIN performers ON performers.id = performers_galleries.performer_id -GROUP BY performers_galleries.gallery_id HAVING SUM(performers.favorite) = 0)`, "nofaves", "galleries.id = nofaves.id") +GROUP BY performers_galleries.gallery_id HAVING SUM(performers.favorite) = `+getDBBoolean(false)+")", "nofaves", "galleries.id = nofaves.id") f.addWhere("performers_galleries.gallery_id IS NULL OR nofaves.id IS NOT NULL") } } @@ -399,7 +403,15 @@ func (qb *galleryFilterHandler) performerAgeCriterionHandler(performerAge *model 
f.addWhere("galleries.date != '' AND performers.birthdate != ''") f.addWhere("galleries.date IS NOT NULL AND performers.birthdate IS NOT NULL") - ageCalc := "cast(strftime('%Y.%m%d', galleries.date) - strftime('%Y.%m%d', performers.birthdate) as int)" + var ageCalc string + + switch dbWrapper.dbType { + case PostgresBackend: + ageCalc = "EXTRACT(YEAR FROM AGE(galleries.date, performers.birthdate))" + case SqliteBackend: + ageCalc = "cast(strftime('%Y.%m%d', galleries.date) - strftime('%Y.%m%d', performers.birthdate) as int)" + } + whereClause, args := getIntWhereClause(ageCalc, performerAge.Modifier, performerAge.Value, performerAge.Value2) f.addWhere(whereClause, args...) } @@ -417,7 +429,7 @@ func (qb *galleryFilterHandler) averageResolutionCriterionHandler(resolution *mo mn := resolution.Value.GetMinResolution() mx := resolution.Value.GetMaxResolution() - const widthHeight = "avg(MIN(image_files.width, image_files.height))" + var widthHeight = "avg(" + getDBMinFunc() + "(image_files.width, image_files.height))" switch resolution.Modifier { case models.CriterionModifierEquals: diff --git a/pkg/sqlite/gallery_test.go b/pkg/sqlite/gallery_test.go index be1edb687ae..b94af647128 100644 --- a/pkg/sqlite/gallery_test.go +++ b/pkg/sqlite/gallery_test.go @@ -6,6 +6,7 @@ package sqlite_test import ( "context" "math" + "sort" "strconv" "testing" "time" @@ -18,27 +19,27 @@ var invalidID = -1 func loadGalleryRelationships(ctx context.Context, expected models.Gallery, actual *models.Gallery) error { if expected.URLs.Loaded() { - if err := actual.LoadURLs(ctx, db.Gallery); err != nil { + if err := actual.LoadURLs(ctx, db.GetRepo().Gallery); err != nil { return err } } if expected.SceneIDs.Loaded() { - if err := actual.LoadSceneIDs(ctx, db.Gallery); err != nil { + if err := actual.LoadSceneIDs(ctx, db.GetRepo().Gallery); err != nil { return err } } if expected.TagIDs.Loaded() { - if err := actual.LoadTagIDs(ctx, db.Gallery); err != nil { + if err := actual.LoadTagIDs(ctx, 
db.GetRepo().Gallery); err != nil { return err } } if expected.PerformerIDs.Loaded() { - if err := actual.LoadPerformerIDs(ctx, db.Gallery); err != nil { + if err := actual.LoadPerformerIDs(ctx, db.GetRepo().Gallery); err != nil { return err } } if expected.Files.Loaded() { - if err := actual.LoadFiles(ctx, db.Gallery); err != nil { + if err := actual.LoadFiles(ctx, db.GetRepo().Gallery); err != nil { return err } } @@ -54,6 +55,19 @@ func loadGalleryRelationships(ctx context.Context, expected models.Gallery, actu return nil } +func sortGallery(copy *models.Gallery) { + // Ordering is not ensured + copy.SceneIDs.Sort() + copy.PerformerIDs.Sort() + copy.TagIDs.Sort() +} + +func sortByID[T any](list []T, getID func(T) int) { + sort.Slice(list, func(i, j int) bool { + return getID(list[i]) < getID(list[j]) + }) +} + func Test_galleryQueryBuilder_Create(t *testing.T) { var ( title = "title" @@ -148,7 +162,7 @@ func Test_galleryQueryBuilder_Create(t *testing.T) { }, } - qb := db.Gallery + qb := db.GetRepo().Gallery for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -180,6 +194,10 @@ func Test_galleryQueryBuilder_Create(t *testing.T) { return } + // Ordering is not ensured + sortGallery(©) + sortGallery(&s) + assert.Equal(copy, s) // ensure can find the scene @@ -198,6 +216,9 @@ func Test_galleryQueryBuilder_Create(t *testing.T) { return } + sortGallery(©) + sortGallery(found) + assert.Equal(copy, *found) return @@ -353,7 +374,7 @@ func Test_galleryQueryBuilder_Update(t *testing.T) { }, } - qb := db.Gallery + qb := db.GetRepo().Gallery for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) @@ -380,6 +401,10 @@ func Test_galleryQueryBuilder_Update(t *testing.T) { return } + // Ordering is not ensured + sortGallery(©) + sortGallery(s) + assert.Equal(copy, *s) return @@ -510,7 +535,7 @@ func Test_galleryQueryBuilder_UpdatePartial(t *testing.T) { }, } for 
_, tt := range tests { - qb := db.Gallery + qb := db.GetRepo().Gallery runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) @@ -779,7 +804,7 @@ func Test_galleryQueryBuilder_UpdatePartialRelationships(t *testing.T) { } for _, tt := range tests { - qb := db.Gallery + qb := db.GetRepo().Gallery runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) @@ -809,6 +834,11 @@ func Test_galleryQueryBuilder_UpdatePartialRelationships(t *testing.T) { return } + // Ordering is not ensured + sortGallery(s) + sortGallery(got) + sortGallery(&tt.want) + // only compare fields that were in the partial if tt.partial.PerformerIDs != nil { assert.ElementsMatch(tt.want.PerformerIDs.List(), got.PerformerIDs.List()) @@ -844,7 +874,7 @@ func Test_galleryQueryBuilder_Destroy(t *testing.T) { }, } - qb := db.Gallery + qb := db.GetRepo().Gallery for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -908,7 +938,7 @@ func Test_galleryQueryBuilder_Find(t *testing.T) { }, } - qb := db.Gallery + qb := db.GetRepo().Gallery for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -971,7 +1001,7 @@ func Test_galleryQueryBuilder_FindMany(t *testing.T) { }, } - qb := db.Gallery + qb := db.GetRepo().Gallery for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1029,7 +1059,7 @@ func Test_galleryQueryBuilder_FindByChecksum(t *testing.T) { }, } - qb := db.Gallery + qb := db.GetRepo().Gallery for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1092,7 +1122,7 @@ func Test_galleryQueryBuilder_FindByChecksums(t *testing.T) { }, } - qb := db.Gallery + qb := db.GetRepo().Gallery for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1108,6 +1138,8 @@ func 
Test_galleryQueryBuilder_FindByChecksums(t *testing.T) { return } + sortByID(tt.want, func(g *models.Gallery) int { return g.ID }) + sortByID(got, func(g *models.Gallery) int { return g.ID }) assert.Equal(tt.want, got) }) } @@ -1150,7 +1182,7 @@ func Test_galleryQueryBuilder_FindByPath(t *testing.T) { }, } - qb := db.Gallery + qb := db.GetRepo().Gallery for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1192,7 +1224,7 @@ func Test_galleryQueryBuilder_FindBySceneID(t *testing.T) { }, } - qb := db.Gallery + qb := db.GetRepo().Gallery for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1208,6 +1240,8 @@ func Test_galleryQueryBuilder_FindBySceneID(t *testing.T) { return } + sortByID(tt.want, func(g *models.Gallery) int { return g.ID }) + sortByID(got, func(g *models.Gallery) int { return g.ID }) assert.Equal(tt.want, got) }) } @@ -1237,7 +1271,7 @@ func Test_galleryQueryBuilder_FindByImageID(t *testing.T) { }, } - qb := db.Gallery + qb := db.GetRepo().Gallery for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1253,6 +1287,8 @@ func Test_galleryQueryBuilder_FindByImageID(t *testing.T) { return } + sortByID(tt.want, func(g *models.Gallery) int { return g.ID }) + sortByID(got, func(g *models.Gallery) int { return g.ID }) assert.Equal(tt.want, got) }) } @@ -1279,7 +1315,7 @@ func Test_galleryQueryBuilder_CountByImageID(t *testing.T) { }, } - qb := db.Gallery + qb := db.GetRepo().Gallery for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1325,7 +1361,7 @@ func Test_galleryStore_FindByFileID(t *testing.T) { }, } - qb := db.Gallery + qb := db.GetRepo().Gallery for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1369,7 +1405,7 @@ func Test_galleryStore_FindByFolderID(t *testing.T) { }, } - qb := db.Gallery + qb := 
db.GetRepo().Gallery for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1409,7 +1445,7 @@ func TestGalleryQueryQ(t *testing.T) { } func galleryQueryQ(ctx context.Context, t *testing.T, q string, expectedGalleryIdx int) { - qb := db.Gallery + qb := db.GetRepo().Gallery filter := models.FindFilterType{ Q: &q, @@ -1484,7 +1520,7 @@ func TestGalleryQueryPath(t *testing.T) { }, } - qb := db.Gallery + qb := db.GetRepo().Gallery for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1511,7 +1547,7 @@ func verifyGalleriesPath(ctx context.Context, t *testing.T, pathCriterion models Path: &pathCriterion, } - sqb := db.Gallery + sqb := db.GetRepo().Gallery galleries, _, err := sqb.Query(ctx, &galleryFilter, nil) if err != nil { t.Errorf("Error querying gallery: %s", err.Error()) @@ -1545,7 +1581,7 @@ func TestGalleryQueryPathOr(t *testing.T) { } withTxn(func(ctx context.Context) error { - sqb := db.Gallery + sqb := db.GetRepo().Gallery galleries := queryGallery(ctx, t, sqb, &galleryFilter, nil) @@ -1581,7 +1617,7 @@ func TestGalleryQueryPathAndRating(t *testing.T) { } withTxn(func(ctx context.Context) error { - sqb := db.Gallery + sqb := db.GetRepo().Gallery galleries := queryGallery(ctx, t, sqb, &galleryFilter, nil) @@ -1621,7 +1657,7 @@ func TestGalleryQueryPathNotRating(t *testing.T) { } withTxn(func(ctx context.Context) error { - sqb := db.Gallery + sqb := db.GetRepo().Gallery galleries := queryGallery(ctx, t, sqb, &galleryFilter, nil) @@ -1654,7 +1690,7 @@ func TestGalleryIllegalQuery(t *testing.T) { } withTxn(func(ctx context.Context) error { - sqb := db.Gallery + sqb := db.GetRepo().Gallery _, _, err := sqb.Query(ctx, galleryFilter, nil) assert.NotNil(err) @@ -1720,7 +1756,7 @@ func TestGalleryQueryURL(t *testing.T) { func verifyGalleryQuery(t *testing.T, filter models.GalleryFilterType, verifyFn func(s *models.Gallery)) { withTxn(func(ctx context.Context) error { 
t.Helper() - sqb := db.Gallery + sqb := db.GetRepo().Gallery galleries := queryGallery(ctx, t, sqb, &filter, nil) @@ -1768,7 +1804,7 @@ func TestGalleryQueryRating100(t *testing.T) { func verifyGalleriesRating100(t *testing.T, ratingCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - sqb := db.Gallery + sqb := db.GetRepo().Gallery galleryFilter := models.GalleryFilterType{ Rating100: &ratingCriterion, } @@ -1788,7 +1824,7 @@ func verifyGalleriesRating100(t *testing.T, ratingCriterion models.IntCriterionI func TestGalleryQueryIsMissingScene(t *testing.T) { withTxn(func(ctx context.Context) error { - qb := db.Gallery + qb := db.GetRepo().Gallery isMissing := "scenes" galleryFilter := models.GalleryFilterType{ IsMissing: &isMissing, @@ -1832,7 +1868,7 @@ func queryGallery(ctx context.Context, t *testing.T, sqb models.GalleryReader, g func TestGalleryQueryIsMissingStudio(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Gallery + sqb := db.GetRepo().Gallery isMissing := "studio" galleryFilter := models.GalleryFilterType{ IsMissing: &isMissing, @@ -1861,7 +1897,7 @@ func TestGalleryQueryIsMissingStudio(t *testing.T) { func TestGalleryQueryIsMissingPerformers(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Gallery + sqb := db.GetRepo().Gallery isMissing := "performers" galleryFilter := models.GalleryFilterType{ IsMissing: &isMissing, @@ -1892,7 +1928,7 @@ func TestGalleryQueryIsMissingPerformers(t *testing.T) { func TestGalleryQueryIsMissingTags(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Gallery + sqb := db.GetRepo().Gallery isMissing := "tags" galleryFilter := models.GalleryFilterType{ IsMissing: &isMissing, @@ -1918,7 +1954,7 @@ func TestGalleryQueryIsMissingTags(t *testing.T) { func TestGalleryQueryIsMissingDate(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Gallery + sqb := db.GetRepo().Gallery isMissing := "date" galleryFilter := 
models.GalleryFilterType{ IsMissing: &isMissing, @@ -2051,7 +2087,7 @@ func TestGalleryQueryPerformers(t *testing.T) { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) - results, _, err := db.Gallery.Query(ctx, &models.GalleryFilterType{ + results, _, err := db.GetRepo().Gallery.Query(ctx, &models.GalleryFilterType{ Performers: &tt.filter, }, nil) if (err != nil) != tt.wantErr { @@ -2187,7 +2223,7 @@ func TestGalleryQueryTags(t *testing.T) { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) - results, _, err := db.Gallery.Query(ctx, &models.GalleryFilterType{ + results, _, err := db.GetRepo().Gallery.Query(ctx, &models.GalleryFilterType{ Tags: &tt.filter, }, nil) if (err != nil) != tt.wantErr { @@ -2280,7 +2316,7 @@ func TestGalleryQueryStudio(t *testing.T) { }, } - qb := db.Gallery + qb := db.GetRepo().Gallery for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -2306,7 +2342,7 @@ func TestGalleryQueryStudio(t *testing.T) { func TestGalleryQueryStudioDepth(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Gallery + sqb := db.GetRepo().Gallery depth := 2 studioCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ @@ -2539,7 +2575,7 @@ func TestGalleryQueryPerformerTags(t *testing.T) { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) - results, _, err := db.Gallery.Query(ctx, tt.filter, tt.findFilter) + results, _, err := db.GetRepo().Gallery.Query(ctx, tt.filter, tt.findFilter) if (err != nil) != tt.wantErr { t.Errorf("ImageStore.Query() error = %v, wantErr %v", err, tt.wantErr) return @@ -2581,7 +2617,7 @@ func TestGalleryQueryTagCount(t *testing.T) { func verifyGalleriesTagCount(t *testing.T, tagCountCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - sqb := db.Gallery + sqb := db.GetRepo().Gallery 
galleryFilter := models.GalleryFilterType{ TagCount: &tagCountCriterion, } @@ -2622,7 +2658,7 @@ func TestGalleryQueryPerformerCount(t *testing.T) { func verifyGalleriesPerformerCount(t *testing.T, performerCountCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - sqb := db.Gallery + sqb := db.GetRepo().Gallery galleryFilter := models.GalleryFilterType{ PerformerCount: &performerCountCriterion, } @@ -2645,7 +2681,7 @@ func verifyGalleriesPerformerCount(t *testing.T, performerCountCriterion models. func TestGalleryQueryAverageResolution(t *testing.T) { withTxn(func(ctx context.Context) error { - qb := db.Gallery + qb := db.GetRepo().Gallery resolution := models.ResolutionEnumLow galleryFilter := models.GalleryFilterType{ AverageResolution: &models.ResolutionCriterionInput{ @@ -2683,7 +2719,7 @@ func TestGalleryQueryImageCount(t *testing.T) { func verifyGalleriesImageCount(t *testing.T, imageCountCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - sqb := db.Gallery + sqb := db.GetRepo().Gallery galleryFilter := models.GalleryFilterType{ ImageCount: &imageCountCriterion, } @@ -2694,7 +2730,7 @@ func verifyGalleriesImageCount(t *testing.T, imageCountCriterion models.IntCrite for _, gallery := range galleries { pp := 0 - result, err := db.Image.Query(ctx, models.ImageQueryOptions{ + result, err := db.GetRepo().Image.Query(ctx, models.ImageQueryOptions{ QueryOptions: models.QueryOptions{ FindFilter: &models.FindFilterType{ PerPage: &pp, @@ -2749,7 +2785,7 @@ func TestGalleryQuerySorting(t *testing.T) { }, } - qb := db.Gallery + qb := db.GetRepo().Gallery for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -2835,7 +2871,7 @@ func TestGalleryStore_AddImages(t *testing.T) { }, } - qb := db.Gallery + qb := db.GetRepo().Gallery for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -2914,7 +2950,7 @@ func 
TestGalleryStore_RemoveImages(t *testing.T) { }, } - qb := db.Gallery + qb := db.GetRepo().Gallery for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -2944,7 +2980,7 @@ func TestGalleryStore_RemoveImages(t *testing.T) { func TestGalleryQueryHasChapters(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Gallery + sqb := db.GetRepo().Gallery hasChapters := "true" galleryFilter := models.GalleryFilterType{ HasChapters: &hasChapters, @@ -2975,25 +3011,25 @@ func TestGalleryQueryHasChapters(t *testing.T) { func TestGallerySetAndResetCover(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Gallery + sqb := db.GetRepo().Gallery imagePath2 := getFilePath(folderIdxWithImageFiles, getImageBasename(imageIdx2WithGallery)) - result, err := db.Image.CoverByGalleryID(ctx, galleryIDs[galleryIdxWithTwoImages]) + result, err := db.GetRepo().Image.CoverByGalleryID(ctx, galleryIDs[galleryIdxWithTwoImages]) assert.Nil(t, err) assert.Nil(t, result) err = sqb.SetCover(ctx, galleryIDs[galleryIdxWithTwoImages], imageIDs[imageIdx2WithGallery]) assert.Nil(t, err) - result, err = db.Image.CoverByGalleryID(ctx, galleryIDs[galleryIdxWithTwoImages]) + result, err = db.GetRepo().Image.CoverByGalleryID(ctx, galleryIDs[galleryIdxWithTwoImages]) assert.Nil(t, err) assert.Equal(t, result.Path, imagePath2) err = sqb.ResetCover(ctx, galleryIDs[galleryIdxWithTwoImages]) assert.Nil(t, err) - result, err = db.Image.CoverByGalleryID(ctx, galleryIDs[galleryIdxWithTwoImages]) + result, err = db.GetRepo().Image.CoverByGalleryID(ctx, galleryIDs[galleryIdxWithTwoImages]) assert.Nil(t, err) assert.Nil(t, result) diff --git a/pkg/sqlite/group.go b/pkg/sqlite/group.go index 5a214f81853..0325a7b5b6d 100644 --- a/pkg/sqlite/group.go +++ b/pkg/sqlite/group.go @@ -513,23 +513,30 @@ func (qb *GroupStore) setGroupSort(query *queryBuilder, findFilter *models.FindF case "sub_group_order": // sub_group_order is a special sort that 
sorts by the order_index of the subgroups if query.hasJoin("groups_parents") { - query.sortAndPagination += getSort("order_index", direction, "groups_parents") + add, agg := getSort("order_index", direction, "groups_parents") + query.sortAndPagination += add + query.addGroupBy(agg...) } else { // this will give unexpected results if the query is not filtered by a parent group and // the group has multiple parents and order indexes query.join(groupRelationsTable, "", "groups.id = groups_relations.sub_id") - query.sortAndPagination += getSort("order_index", direction, groupRelationsTable) + add, agg := getSort("order_index", direction, groupRelationsTable) + query.sortAndPagination += add + query.addGroupBy(agg...) } case "tag_count": query.sortAndPagination += getCountSort(groupTable, groupsTagsTable, groupIDColumn, direction) case "scenes_count": // generic getSort won't work for this query.sortAndPagination += getCountSort(groupTable, groupsScenesTable, groupIDColumn, direction) default: - query.sortAndPagination += getSort(sort, direction, "groups") + add, agg := getSort(sort, direction, "groups") + query.sortAndPagination += add + query.addGroupBy(agg...) } // Whatever the sorting, always use name/id as a final sort - query.sortAndPagination += ", COALESCE(groups.name, groups.id) COLLATE NATURAL_CI ASC" + query.sortAndPagination += ", COALESCE(groups.name, CAST(groups.id as text)) COLLATE NATURAL_CI ASC" + query.addGroupBy("groups.name", "groups.id") return nil } diff --git a/pkg/sqlite/group_filter.go b/pkg/sqlite/group_filter.go index dcb7bcdfc94..11bc51bd49d 100644 --- a/pkg/sqlite/group_filter.go +++ b/pkg/sqlite/group_filter.go @@ -110,7 +110,7 @@ func (qb *groupFilterHandler) missingCriterionHandler(isMissing *string) criteri f.addLeftJoin("groups_scenes", "", "groups_scenes.group_id = groups.id") f.addWhere("groups_scenes.scene_id IS NULL") default: - f.addWhere("(groups." + *isMissing + " IS NULL OR TRIM(groups." 
+ *isMissing + ") = '')") + f.addWhere("(groups." + *isMissing + " IS NULL OR TRIM(CAST(groups." + *isMissing + " AS TEXT)) = '')") } } } diff --git a/pkg/sqlite/group_test.go b/pkg/sqlite/group_test.go index d4a177e86cf..556eaf800c4 100644 --- a/pkg/sqlite/group_test.go +++ b/pkg/sqlite/group_test.go @@ -21,22 +21,22 @@ import ( func loadGroupRelationships(ctx context.Context, expected models.Group, actual *models.Group) error { if expected.URLs.Loaded() { - if err := actual.LoadURLs(ctx, db.Group); err != nil { + if err := actual.LoadURLs(ctx, db.GetRepo().Group); err != nil { return err } } if expected.TagIDs.Loaded() { - if err := actual.LoadTagIDs(ctx, db.Group); err != nil { + if err := actual.LoadTagIDs(ctx, db.GetRepo().Group); err != nil { return err } } if expected.ContainingGroups.Loaded() { - if err := actual.LoadContainingGroupIDs(ctx, db.Group); err != nil { + if err := actual.LoadContainingGroupIDs(ctx, db.GetRepo().Group); err != nil { return err } } if expected.SubGroups.Loaded() { - if err := actual.LoadSubGroupIDs(ctx, db.Group); err != nil { + if err := actual.LoadSubGroupIDs(ctx, db.GetRepo().Group); err != nil { return err } } @@ -115,7 +115,7 @@ func Test_GroupStore_Create(t *testing.T) { }, } - qb := db.Group + qb := db.GetRepo().Group for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -277,7 +277,7 @@ func Test_groupQueryBuilder_Update(t *testing.T) { }, } - qb := db.Group + qb := db.GetRepo().Group for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) @@ -527,7 +527,7 @@ func Test_groupQueryBuilder_UpdatePartial(t *testing.T) { }, } for _, tt := range tests { - qb := db.Group + qb := db.GetRepo().Group runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) @@ -568,7 +568,7 @@ func Test_groupQueryBuilder_UpdatePartial(t *testing.T) { func TestGroupFindByName(t *testing.T) { 
withTxn(func(ctx context.Context) error { - mqb := db.Group + mqb := db.GetRepo().Group name := groupNames[groupIdxWithScene] // find a group by name @@ -601,7 +601,7 @@ func TestGroupFindByNames(t *testing.T) { withTxn(func(ctx context.Context) error { var names []string - mqb := db.Group + mqb := db.GetRepo().Group names = append(names, groupNames[groupIdxWithScene]) // find groups by names @@ -675,7 +675,7 @@ func TestGroupQuery(t *testing.T) { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) - results, _, err := db.Group.Query(ctx, tt.filter, tt.findFilter) + results, _, err := db.GetRepo().Group.Query(ctx, tt.filter, tt.findFilter) if (err != nil) != tt.wantErr { t.Errorf("GroupQueryBuilder.Query() error = %v, wantErr %v", err, tt.wantErr) return @@ -697,7 +697,7 @@ func TestGroupQuery(t *testing.T) { func TestGroupQueryStudio(t *testing.T) { withTxn(func(ctx context.Context) error { - mqb := db.Group + mqb := db.GetRepo().Group studioCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(studioIDs[studioIdxWithGroup]), @@ -788,7 +788,7 @@ func TestGroupQueryURL(t *testing.T) { func TestGroupQueryURLExcludes(t *testing.T) { withRollbackTxn(func(ctx context.Context) error { - mqb := db.Group + mqb := db.GetRepo().Group // create group with two URLs group := models.Group{ @@ -839,7 +839,7 @@ func TestGroupQueryURLExcludes(t *testing.T) { func verifyGroupQuery(t *testing.T, filter models.GroupFilterType, verifyFn func(s *models.Group)) { withTxn(func(ctx context.Context) error { t.Helper() - sqb := db.Group + sqb := db.GetRepo().Group groups := queryGroups(ctx, t, &filter, nil) @@ -861,7 +861,7 @@ func verifyGroupQuery(t *testing.T, filter models.GroupFilterType, verifyFn func } func queryGroups(ctx context.Context, t *testing.T, groupFilter *models.GroupFilterType, findFilter *models.FindFilterType) []*models.Group { - sqb := db.Group + sqb := db.GetRepo().Group groups, _, err := 
sqb.Query(ctx, groupFilter, findFilter) if err != nil { t.Errorf("Error querying group: %s", err.Error()) @@ -946,7 +946,7 @@ func TestGroupQueryTagCount(t *testing.T) { func verifyGroupsTagCount(t *testing.T, tagCountCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - sqb := db.Group + sqb := db.GetRepo().Group groupFilter := models.GroupFilterType{ TagCount: &tagCountCriterion, } @@ -1011,12 +1011,12 @@ func TestGroupQuerySortOrderIndex(t *testing.T) { withTxn(func(ctx context.Context) error { // just ensure there are no errors - _, _, err := db.Group.Query(ctx, &groupFilter, &findFilter) + _, _, err := db.GetRepo().Group.Query(ctx, &groupFilter, &findFilter) if err != nil { t.Errorf("Error querying group: %s", err.Error()) } - _, _, err = db.Group.Query(ctx, nil, &findFilter) + _, _, err = db.GetRepo().Group.Query(ctx, nil, &findFilter) if err != nil { t.Errorf("Error querying group: %s", err.Error()) } @@ -1027,7 +1027,7 @@ func TestGroupQuerySortOrderIndex(t *testing.T) { func TestGroupUpdateFrontImage(t *testing.T) { if err := withRollbackTxn(func(ctx context.Context) error { - qb := db.Group + qb := db.GetRepo().Group // create group to test against const name = "TestGroupUpdateGroupImages" @@ -1047,7 +1047,7 @@ func TestGroupUpdateFrontImage(t *testing.T) { func TestGroupUpdateBackImage(t *testing.T) { if err := withRollbackTxn(func(ctx context.Context) error { - qb := db.Group + qb := db.GetRepo().Group // create group to test against const name = "TestGroupUpdateGroupImages" @@ -1142,7 +1142,7 @@ func TestGroupQueryContainingGroups(t *testing.T) { }, } - qb := db.Group + qb := db.GetRepo().Group for _, tt := range tests { valueIDs := indexesToIDs(groupIDs, tt.c.valueIdxs) @@ -1255,7 +1255,7 @@ func TestGroupQuerySubGroups(t *testing.T) { }, } - qb := db.Group + qb := db.GetRepo().Group for _, tt := range tests { valueIDs := indexesToIDs(groupIDs, tt.c.valueIdxs) @@ -1331,7 +1331,7 @@ func TestGroupQueryContainingGroupCount(t 
*testing.T) { }, } - qb := db.Group + qb := db.GetRepo().Group for _, tt := range tests { expectedIDs := indexesToIDs(groupIDs, tt.expectedIdxs) @@ -1402,7 +1402,7 @@ func TestGroupQuerySubGroupCount(t *testing.T) { }, } - qb := db.Group + qb := db.GetRepo().Group for _, tt := range tests { expectedIDs := indexesToIDs(groupIDs, tt.expectedIdxs) @@ -1460,7 +1460,7 @@ func TestGroupFindInAncestors(t *testing.T) { }, } - qb := db.Group + qb := db.GetRepo().Group for _, tt := range tests { ancestorIDs := indexesToIDs(groupIDs, tt.ancestorIdxs) @@ -1556,7 +1556,7 @@ func TestGroupReorderSubGroups(t *testing.T) { }, } - qb := db.Group + qb := db.GetRepo().Group for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1666,7 +1666,7 @@ func TestGroupAddSubGroups(t *testing.T) { }, } - qb := db.Group + qb := db.GetRepo().Group for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1781,7 +1781,7 @@ func TestGroupRemoveSubGroups(t *testing.T) { }, } - qb := db.Group + qb := db.GetRepo().Group for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1869,7 +1869,7 @@ func TestGroupFindSubGroupIDs(t *testing.T) { }, } - qb := db.Group + qb := db.GetRepo().Group for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { diff --git a/pkg/sqlite/image.go b/pkg/sqlite/image.go index 4bc28fad820..7031bf628bd 100644 --- a/pkg/sqlite/image.go +++ b/pkg/sqlite/image.go @@ -211,7 +211,7 @@ func (qb *ImageStore) selectDataset() *goqu.SelectDataset { imagesFilesJoinTable, goqu.On( imagesFilesJoinTable.Col(imageIDColumn).Eq(table.Col(idColumn)), - imagesFilesJoinTable.Col("primary").Eq(1), + imagesFilesJoinTable.Col("primary").IsTrue(), ), ).LeftJoin( files, @@ -493,7 +493,7 @@ func (qb *ImageStore) CoverByGalleryID(ctx context.Context, galleryID int) (*mod Select(table.Col(idColumn)). 
Where(goqu.And( galleriesImagesJoinTable.Col("gallery_id").Eq(galleryID), - galleriesImagesJoinTable.Col("cover").Eq(true), + galleriesImagesJoinTable.Col("cover").IsTrue(), )) q := qb.selectDataset().Prepared(true).Where( @@ -608,7 +608,7 @@ func (qb *ImageStore) FindByChecksum(ctx context.Context, checksum string) ([]*m var defaultGalleryOrder = []exp.OrderedExpression{ goqu.L("COALESCE(folders.path, '') || COALESCE(files.basename, '') COLLATE NATURAL_CI").Asc(), - goqu.L("COALESCE(images.title, images.id) COLLATE NATURAL_CI").Asc(), + goqu.L("COALESCE(images.title, cast(images.id as text)) COLLATE NATURAL_CI").Asc(), } func (qb *ImageStore) FindByGalleryID(ctx context.Context, galleryID int) ([]*models.Image, error) { @@ -869,6 +869,7 @@ func (qb *ImageStore) queryGroupedFields(ctx context.Context, options models.Ima ) query.addColumn("COALESCE(image_files.width, 0) * COALESCE(image_files.height, 0) as megapixels") aggregateQuery.addColumn("COALESCE(SUM(temp.megapixels), 0) / 1000000 as megapixels") + query.addGroupBy("image_files.width", "image_files.height") } if options.TotalSize { @@ -884,6 +885,7 @@ func (qb *ImageStore) queryGroupedFields(ctx context.Context, options models.Ima ) query.addColumn("COALESCE(files.size, 0) as size") aggregateQuery.addColumn("SUM(temp.size) as size") + query.addGroupBy("files.size") } const includeSortPagination = false @@ -932,68 +934,74 @@ var imageSortOptions = sortOptions{ } func (qb *ImageStore) setImageSortAndPagination(q *queryBuilder, findFilter *models.FindFilterType) error { + models.EnsureFindFilterSorted(findFilter) sortClause := "" - if findFilter != nil && findFilter.Sort != nil && *findFilter.Sort != "" { - sort := findFilter.GetSort("title") - direction := findFilter.GetDirection() + sort := findFilter.GetSort("title") + direction := findFilter.GetDirection() - // CVE-2024-32231 - ensure sort is in the list of allowed sorts - if err := imageSortOptions.validateSort(sort); err != nil { - return err - } - - // 
translate sort field - if sort == "file_mod_time" { - sort = "mod_time" - } + // CVE-2024-32231 - ensure sort is in the list of allowed sorts + if err := imageSortOptions.validateSort(sort); err != nil { + return err + } - addFilesJoin := func() { - q.addJoins( - join{ - table: imagesFilesTable, - onClause: "images_files.image_id = images.id", - }, - join{ - table: fileTable, - onClause: "images_files.file_id = files.id", - }, - ) - } + // translate sort field + if sort == "file_mod_time" { + sort = "mod_time" + } - addFolderJoin := func() { - q.addJoins(join{ - table: folderTable, - onClause: "files.parent_folder_id = folders.id", - }) - } + addFilesJoin := func() { + q.addJoins( + join{ + table: imagesFilesTable, + onClause: "images_files.image_id = images.id", + }, + join{ + table: fileTable, + onClause: "images_files.file_id = files.id", + }, + ) + } - switch sort { - case "path": - addFilesJoin() - addFolderJoin() - sortClause = " ORDER BY COALESCE(folders.path, '') || COALESCE(files.basename, '') COLLATE NATURAL_CI " + direction - case "file_count": - sortClause = getCountSort(imageTable, imagesFilesTable, imageIDColumn, direction) - case "tag_count": - sortClause = getCountSort(imageTable, imagesTagsTable, imageIDColumn, direction) - case "performer_count": - sortClause = getCountSort(imageTable, performersImagesTable, imageIDColumn, direction) - case "mod_time", "filesize": - addFilesJoin() - sortClause = getSort(sort, direction, "files") - case "title": - addFilesJoin() - addFolderJoin() - sortClause = " ORDER BY COALESCE(images.title, files.basename) COLLATE NATURAL_CI " + direction + ", folders.path COLLATE NATURAL_CI " + direction - default: - sortClause = getSort(sort, direction, "images") - } + addFolderJoin := func() { + q.addJoins(join{ + table: folderTable, + onClause: "files.parent_folder_id = folders.id", + }) + } - // Whatever the sorting, always use title/id as a final sort - sortClause += ", COALESCE(images.title, images.id) COLLATE NATURAL_CI 
ASC" + switch sort { + case "path": + addFilesJoin() + addFolderJoin() + sortClause = " ORDER BY COALESCE(folders.path, '') || COALESCE(files.basename, '') COLLATE NATURAL_CI " + direction + q.addGroupBy("folders.path", "files.basename") + case "file_count": + sortClause = getCountSort(imageTable, imagesFilesTable, imageIDColumn, direction) + case "tag_count": + sortClause = getCountSort(imageTable, imagesTagsTable, imageIDColumn, direction) + case "performer_count": + sortClause = getCountSort(imageTable, performersImagesTable, imageIDColumn, direction) + case "mod_time", "filesize": + addFilesJoin() + add, agg := getSort(sort, direction, "files") + sortClause = add + q.addGroupBy(agg...) + case "title": + addFilesJoin() + addFolderJoin() + sortClause = " ORDER BY COALESCE(images.title, files.basename) COLLATE NATURAL_CI " + direction + ", folders.path COLLATE NATURAL_CI " + direction + q.addGroupBy("images.title", "files.basename", "folders.path") + default: + add, agg := getSort(sort, direction, "images") + sortClause = add + q.addGroupBy(agg...) } + // Whatever the sorting, always use title/id as a final sort + sortClause += ", COALESCE(images.title, CAST(images.id as text)) COLLATE NATURAL_CI ASC" + q.addGroupBy("images.title", "images.id") + q.sortAndPagination = sortClause + getPagination(findFilter) return nil diff --git a/pkg/sqlite/image_filter.go b/pkg/sqlite/image_filter.go index 8f2d5d6b90a..e440e39fa13 100644 --- a/pkg/sqlite/image_filter.go +++ b/pkg/sqlite/image_filter.go @@ -152,7 +152,7 @@ func (qb *imageFilterHandler) missingCriterionHandler(isMissing *string) criteri imageRepository.tags.join(f, "tags_join", "images.id") f.addWhere("tags_join.image_id IS NULL") default: - f.addWhere("(images." + *isMissing + " IS NULL OR TRIM(images." + *isMissing + ") = '')") + f.addWhere("(images." + *isMissing + " IS NULL OR TRIM(CAST(images." 
+ *isMissing + " AS TEXT)) = '')") } } } @@ -254,12 +254,12 @@ func (qb *imageFilterHandler) performerFavoriteCriterionHandler(performerfavorit if *performerfavorite { // contains at least one favorite f.addLeftJoin("performers", "", "performers.id = performers_images.performer_id") - f.addWhere("performers.favorite = 1") + f.addWhere("performers.favorite = " + getDBBoolean(true)) } else { // contains zero favorites f.addLeftJoin(`(SELECT performers_images.image_id as id FROM performers_images JOIN performers ON performers.id = performers_images.performer_id -GROUP BY performers_images.image_id HAVING SUM(performers.favorite) = 0)`, "nofaves", "images.id = nofaves.id") +GROUP BY performers_images.image_id HAVING SUM(performers.favorite) = `+getDBBoolean(false)+")", "nofaves", "images.id = nofaves.id") f.addWhere("performers_images.image_id IS NULL OR nofaves.id IS NOT NULL") } } @@ -275,7 +275,15 @@ func (qb *imageFilterHandler) performerAgeCriterionHandler(performerAge *models. f.addWhere("images.date != '' AND performers.birthdate != ''") f.addWhere("images.date IS NOT NULL AND performers.birthdate IS NOT NULL") - ageCalc := "cast(strftime('%Y.%m%d', images.date) - strftime('%Y.%m%d', performers.birthdate) as int)" + var ageCalc string + + switch dbWrapper.dbType { + case PostgresBackend: + ageCalc = "EXTRACT(YEAR FROM AGE(images.date, performers.birthdate))" + case SqliteBackend: + ageCalc = "cast(strftime('%Y.%m%d', images.date) - strftime('%Y.%m%d', performers.birthdate) as int)" + } + whereClause, args := getIntWhereClause(ageCalc, performerAge.Modifier, performerAge.Value, performerAge.Value2) f.addWhere(whereClause, args...) 
} diff --git a/pkg/sqlite/image_test.go b/pkg/sqlite/image_test.go index aa4ed3b99ad..a823f8fc146 100644 --- a/pkg/sqlite/image_test.go +++ b/pkg/sqlite/image_test.go @@ -16,27 +16,27 @@ import ( func loadImageRelationships(ctx context.Context, expected models.Image, actual *models.Image) error { if expected.URLs.Loaded() { - if err := actual.LoadURLs(ctx, db.Image); err != nil { + if err := actual.LoadURLs(ctx, db.GetRepo().Image); err != nil { return err } } if expected.GalleryIDs.Loaded() { - if err := actual.LoadGalleryIDs(ctx, db.Image); err != nil { + if err := actual.LoadGalleryIDs(ctx, db.GetRepo().Image); err != nil { return err } } if expected.TagIDs.Loaded() { - if err := actual.LoadTagIDs(ctx, db.Image); err != nil { + if err := actual.LoadTagIDs(ctx, db.GetRepo().Image); err != nil { return err } } if expected.PerformerIDs.Loaded() { - if err := actual.LoadPerformerIDs(ctx, db.Image); err != nil { + if err := actual.LoadPerformerIDs(ctx, db.GetRepo().Image); err != nil { return err } } if expected.Files.Loaded() { - if err := actual.LoadFiles(ctx, db.Image); err != nil { + if err := actual.LoadFiles(ctx, db.GetRepo().Image); err != nil { return err } } @@ -153,7 +153,7 @@ func Test_imageQueryBuilder_Create(t *testing.T) { }, } - qb := db.Image + qb := db.GetRepo().Image for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -360,7 +360,7 @@ func Test_imageQueryBuilder_Update(t *testing.T) { }, } - qb := db.Image + qb := db.GetRepo().Image for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) @@ -511,7 +511,7 @@ func Test_imageQueryBuilder_UpdatePartial(t *testing.T) { }, } for _, tt := range tests { - qb := db.Image + qb := db.GetRepo().Image runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) @@ -785,7 +785,7 @@ func Test_imageQueryBuilder_UpdatePartialRelationships(t *testing.T) { } for _, 
tt := range tests { - qb := db.Image + qb := db.GetRepo().Image runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) @@ -853,7 +853,7 @@ func Test_imageQueryBuilder_IncrementOCounter(t *testing.T) { }, } - qb := db.Image + qb := db.GetRepo().Image for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -896,7 +896,7 @@ func Test_imageQueryBuilder_DecrementOCounter(t *testing.T) { }, } - qb := db.Image + qb := db.GetRepo().Image for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -939,7 +939,7 @@ func Test_imageQueryBuilder_ResetOCounter(t *testing.T) { }, } - qb := db.Image + qb := db.GetRepo().Image for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -973,7 +973,7 @@ func Test_imageQueryBuilder_Destroy(t *testing.T) { }, } - qb := db.Image + qb := db.GetRepo().Image for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1034,7 +1034,7 @@ func Test_imageQueryBuilder_Find(t *testing.T) { }, } - qb := db.Image + qb := db.GetRepo().Image for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1097,7 +1097,7 @@ func Test_imageQueryBuilder_FindMany(t *testing.T) { }, } - qb := db.Image + qb := db.GetRepo().Image for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1156,7 +1156,7 @@ func Test_imageQueryBuilder_FindByChecksum(t *testing.T) { }, } - qb := db.Image + qb := db.GetRepo().Image for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1234,7 +1234,7 @@ func Test_imageQueryBuilder_FindByFingerprints(t *testing.T) { }, } - qb := db.Image + qb := db.GetRepo().Image for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1276,7 +1276,7 
@@ func Test_imageQueryBuilder_FindByGalleryID(t *testing.T) { }, } - qb := db.Image + qb := db.GetRepo().Image for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1319,7 +1319,7 @@ func Test_imageQueryBuilder_CountByGalleryID(t *testing.T) { }, } - qb := db.Image + qb := db.GetRepo().Image for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1365,7 +1365,7 @@ func Test_imageStore_FindByFileID(t *testing.T) { }, } - qb := db.Image + qb := db.GetRepo().Image for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1420,7 +1420,7 @@ func Test_imageStore_FindByFolderID(t *testing.T) { }, } - qb := db.Image + qb := db.GetRepo().Image for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1469,7 +1469,7 @@ func Test_imageStore_FindByZipFileID(t *testing.T) { }, } - qb := db.Image + qb := db.GetRepo().Image for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1503,7 +1503,7 @@ func TestImageQueryQ(t *testing.T) { q := getImageStringValue(imageIdx, titleField) - sqb := db.Image + sqb := db.GetRepo().Image imageQueryQ(ctx, t, sqb, q, imageIdx) @@ -1558,7 +1558,7 @@ func verifyImageQuery(t *testing.T, filter models.ImageFilterType, verifyFn func t.Helper() withTxn(func(ctx context.Context) error { t.Helper() - sqb := db.Image + sqb := db.GetRepo().Image images := queryImages(ctx, t, sqb, &filter, nil) @@ -1587,7 +1587,7 @@ func TestImageQueryURL(t *testing.T) { verifyFn := func(ctx context.Context, o *models.Image) { t.Helper() - if err := o.LoadURLs(ctx, db.Image); err != nil { + if err := o.LoadURLs(ctx, db.GetRepo().Image); err != nil { t.Errorf("Error loading scene URLs: %v", err) } @@ -1639,7 +1639,7 @@ func TestImageQueryPath(t *testing.T) { func verifyImagePath(t *testing.T, pathCriterion models.StringCriterionInput, 
expected int) { withTxn(func(ctx context.Context) error { - sqb := db.Image + sqb := db.GetRepo().Image imageFilter := models.ImageFilterType{ Path: &pathCriterion, } @@ -1679,7 +1679,7 @@ func TestImageQueryPathOr(t *testing.T) { } withTxn(func(ctx context.Context) error { - sqb := db.Image + sqb := db.GetRepo().Image images := queryImages(ctx, t, sqb, &imageFilter, nil) @@ -1715,7 +1715,7 @@ func TestImageQueryPathAndRating(t *testing.T) { } withTxn(func(ctx context.Context) error { - sqb := db.Image + sqb := db.GetRepo().Image images := queryImages(ctx, t, sqb, &imageFilter, nil) @@ -1755,7 +1755,7 @@ func TestImageQueryPathNotRating(t *testing.T) { } withTxn(func(ctx context.Context) error { - sqb := db.Image + sqb := db.GetRepo().Image images := queryImages(ctx, t, sqb, &imageFilter, nil) @@ -1788,7 +1788,7 @@ func TestImageIllegalQuery(t *testing.T) { } withTxn(func(ctx context.Context) error { - sqb := db.Image + sqb := db.GetRepo().Image _, _, err := queryImagesWithCount(ctx, sqb, imageFilter, nil) assert.NotNil(err) @@ -1834,7 +1834,7 @@ func TestImageQueryRating100(t *testing.T) { func verifyImagesRating100(t *testing.T, ratingCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - sqb := db.Image + sqb := db.GetRepo().Image imageFilter := models.ImageFilterType{ Rating100: &ratingCriterion, } @@ -1873,7 +1873,7 @@ func TestImageQueryOCounter(t *testing.T) { func verifyImagesOCounter(t *testing.T, oCounterCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - sqb := db.Image + sqb := db.GetRepo().Image imageFilter := models.ImageFilterType{ OCounter: &oCounterCriterion, } @@ -1902,7 +1902,7 @@ func TestImageQueryResolution(t *testing.T) { func verifyImagesResolution(t *testing.T, resolution models.ResolutionEnum) { withTxn(func(ctx context.Context) error { - sqb := db.Image + sqb := db.GetRepo().Image imageFilter := models.ImageFilterType{ Resolution: &models.ResolutionCriterionInput{ Value: resolution, 
@@ -1916,7 +1916,7 @@ func verifyImagesResolution(t *testing.T, resolution models.ResolutionEnum) { } for _, image := range images { - if err := image.LoadPrimaryFile(ctx, db.File); err != nil { + if err := image.LoadPrimaryFile(ctx, db.GetRepo().File); err != nil { t.Errorf("Error loading primary file: %s", err.Error()) return nil } @@ -1955,7 +1955,7 @@ func verifyImageResolution(t *testing.T, height int, resolution models.Resolutio func TestImageQueryIsMissingGalleries(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Image + sqb := db.GetRepo().Image isMissing := "galleries" imageFilter := models.ImageFilterType{ IsMissing: &isMissing, @@ -1992,7 +1992,7 @@ func TestImageQueryIsMissingGalleries(t *testing.T) { func TestImageQueryIsMissingStudio(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Image + sqb := db.GetRepo().Image isMissing := "studio" imageFilter := models.ImageFilterType{ IsMissing: &isMissing, @@ -2027,7 +2027,7 @@ func TestImageQueryIsMissingStudio(t *testing.T) { func TestImageQueryIsMissingPerformers(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Image + sqb := db.GetRepo().Image isMissing := "performers" imageFilter := models.ImageFilterType{ IsMissing: &isMissing, @@ -2064,7 +2064,7 @@ func TestImageQueryIsMissingPerformers(t *testing.T) { func TestImageQueryIsMissingTags(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Image + sqb := db.GetRepo().Image isMissing := "tags" imageFilter := models.ImageFilterType{ IsMissing: &isMissing, @@ -2096,7 +2096,7 @@ func TestImageQueryIsMissingTags(t *testing.T) { func TestImageQueryIsMissingRating(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Image + sqb := db.GetRepo().Image isMissing := "rating" imageFilter := models.ImageFilterType{ IsMissing: &isMissing, @@ -2120,7 +2120,7 @@ func TestImageQueryIsMissingRating(t *testing.T) { func TestImageQueryGallery(t *testing.T) { withTxn(func(ctx 
context.Context) error { - sqb := db.Image + sqb := db.GetRepo().Image galleryCriterion := models.MultiCriterionInput{ Value: []string{ strconv.Itoa(galleryIDs[galleryIdxWithImage]), @@ -2289,7 +2289,7 @@ func TestImageQueryPerformers(t *testing.T) { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) - results, err := db.Image.Query(ctx, models.ImageQueryOptions{ + results, err := db.GetRepo().Image.Query(ctx, models.ImageQueryOptions{ ImageFilter: &models.ImageFilterType{ Performers: &tt.filter, }, @@ -2425,7 +2425,7 @@ func TestImageQueryTags(t *testing.T) { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) - results, err := db.Image.Query(ctx, models.ImageQueryOptions{ + results, err := db.GetRepo().Image.Query(ctx, models.ImageQueryOptions{ ImageFilter: &models.ImageFilterType{ Tags: &tt.filter, }, @@ -2518,7 +2518,7 @@ func TestImageQueryStudio(t *testing.T) { }, } - qb := db.Image + qb := db.GetRepo().Image for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -2544,7 +2544,7 @@ func TestImageQueryStudio(t *testing.T) { func TestImageQueryStudioDepth(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Image + sqb := db.GetRepo().Image depth := 2 studioCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ @@ -2786,7 +2786,7 @@ func TestImageQueryPerformerTags(t *testing.T) { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) - results, err := db.Image.Query(ctx, models.ImageQueryOptions{ + results, err := db.GetRepo().Image.Query(ctx, models.ImageQueryOptions{ ImageFilter: tt.filter, QueryOptions: models.QueryOptions{ FindFilter: tt.findFilter, @@ -2831,7 +2831,7 @@ func TestImageQueryTagCount(t *testing.T) { func verifyImagesTagCount(t *testing.T, tagCountCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - sqb 
:= db.Image + sqb := db.GetRepo().Image imageFilter := models.ImageFilterType{ TagCount: &tagCountCriterion, } @@ -2872,7 +2872,7 @@ func TestImageQueryPerformerCount(t *testing.T) { func verifyImagesPerformerCount(t *testing.T, performerCountCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - sqb := db.Image + sqb := db.GetRepo().Image imageFilter := models.ImageFilterType{ PerformerCount: &performerCountCriterion, } @@ -2930,7 +2930,7 @@ func TestImageQuerySorting(t *testing.T) { }, } - qb := db.Image + qb := db.GetRepo().Image for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -2982,7 +2982,7 @@ func TestImageQueryPagination(t *testing.T) { PerPage: &perPage, } - sqb := db.Image + sqb := db.GetRepo().Image images, _, err := queryImagesWithCount(ctx, sqb, nil, &findFilter) if err != nil { t.Errorf("Error querying image: %s", err.Error()) diff --git a/pkg/sqlite/migrate.go b/pkg/sqlite/migrate.go index ba47544588d..c3ed895ada9 100644 --- a/pkg/sqlite/migrate.go +++ b/pkg/sqlite/migrate.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/golang-migrate/migrate/v4" + postgresmig "github.com/golang-migrate/migrate/v4/database/postgres" sqlite3mig "github.com/golang-migrate/migrate/v4/database/sqlite3" "github.com/golang-migrate/migrate/v4/source/iofs" "github.com/jmoiron/sqlx" @@ -12,16 +13,16 @@ import ( ) func (db *Database) needsMigration() bool { - return db.schemaVersion != appSchemaVersion + return db.schemaVersion != db.AppSchemaVersion() } type Migrator struct { - db *Database + db DBInterface conn *sqlx.DB m *migrate.Migrate } -func NewMigrator(db *Database) (*Migrator, error) { +func NewMigrator(db DBInterface) (*Migrator, error) { m := &Migrator{ db: db, } @@ -55,16 +56,25 @@ func (m *Migrator) CurrentSchemaVersion() uint { } func (m *Migrator) RequiredSchemaVersion() uint { - return appSchemaVersion + return m.db.AppSchemaVersion() } func (m *Migrator) getMigrate() 
(*migrate.Migrate, error) { + if m.db.DatabaseType() == PostgresBackend { + return m._getMigratePostgres() + } + + return m._getMigrateSqlite() +} + +func (m *Migrator) _getMigrateSqlite() (*migrate.Migrate, error) { migrations, err := iofs.New(migrationsBox, "migrations") if err != nil { return nil, err } driver, err := sqlite3mig.WithInstance(m.conn.DB, &sqlite3mig.Config{}) + if err != nil { return nil, err } @@ -73,7 +83,27 @@ func (m *Migrator) getMigrate() (*migrate.Migrate, error) { return migrate.NewWithInstance( "iofs", migrations, - m.db.dbPath, + m.db.DatabasePath(), + driver, + ) +} + +func (m *Migrator) _getMigratePostgres() (*migrate.Migrate, error) { + migrations, err := iofs.New(migrationsBox, "migrationsPostgres") + if err != nil { + return nil, err + } + + driver, err := postgresmig.WithInstance(m.conn.DB, &postgresmig.Config{}) + + if err != nil { + return nil, err + } + + return migrate.NewWithInstance( + "iofs", + migrations, + "postgres", driver, ) } @@ -101,7 +131,8 @@ func (m *Migrator) RunMigration(ctx context.Context, newVersion uint) error { } // update the schema version - m.db.schemaVersion, _, _ = m.m.Version() + schemaVersion, _, _ := m.m.Version() + m.db.SetSchemaVersion(schemaVersion) return nil } @@ -150,9 +181,9 @@ func (db *Database) RunAllMigrations() error { defer m.Close() databaseSchemaVersion, _, _ := m.m.Version() - stepNumber := appSchemaVersion - databaseSchemaVersion + stepNumber := db.AppSchemaVersion() - databaseSchemaVersion if stepNumber != 0 { - logger.Infof("Migrating database from version %d to %d", databaseSchemaVersion, appSchemaVersion) + logger.Infof("Migrating database from version %d to %d", databaseSchemaVersion, db.AppSchemaVersion()) // run each migration individually, and run custom migrations as needed var i uint = 1 diff --git a/pkg/sqlite/migrationsPostgres/1_initial.up.sql b/pkg/sqlite/migrationsPostgres/1_initial.up.sql new file mode 100644 index 00000000000..24191c0a096 --- /dev/null +++ 
b/pkg/sqlite/migrationsPostgres/1_initial.up.sql @@ -0,0 +1,480 @@ +CREATE COLLATION IF NOT EXISTS NATURAL_CI (provider = icu, locale = 'en@colNumeric=yes'); +CREATE COLLATION IF NOT EXISTS NOCASE (provider = icu, locale = 'und-u-ks-level2', deterministic = false); +CREATE TABLE blobs ( + checksum varchar(255) NOT NULL PRIMARY KEY, + blob bytea +); +CREATE TABLE tags ( + id serial not null primary key, + name text, + created_at timestamp not null, + updated_at timestamp not null, + ignore_auto_tag boolean not null default FALSE, + description text, + image_blob varchar(255) REFERENCES blobs(checksum), + favorite boolean not null default false +); +CREATE TABLE folders ( + id serial not null primary key, + path text NOT NULL, + parent_folder_id integer, + mod_time timestamp not null, + created_at timestamp not null, + updated_at timestamp not null, + foreign key(parent_folder_id) references folders(id) on delete SET NULL +); +CREATE TABLE files ( + id serial not null primary key, + basename varchar(255) NOT NULL, + zip_file_id integer, + parent_folder_id integer not null, + size bigint NOT NULL, + mod_time timestamp not null, + created_at timestamp not null, + updated_at timestamp not null, + foreign key(zip_file_id) references files(id), + foreign key(parent_folder_id) references folders(id), + CHECK (basename != '') +); +ALTER TABLE folders ADD COLUMN zip_file_id integer REFERENCES files(id); +CREATE TABLE IF NOT EXISTS performers ( + id serial not null primary key, + name text not null, + disambiguation text, + gender varchar(20), + birthdate date, + ethnicity text, + country text, + eye_color text, + height int, + measurements text, + fake_tits text, + career_length text, + tattoos text, + piercings text, + favorite boolean not null default FALSE, + created_at timestamp not null, + updated_at timestamp not null, + details text, + death_date date, + hair_color text, + weight integer, + rating smallint, + ignore_auto_tag boolean not null default FALSE, + 
image_blob varchar(255) REFERENCES blobs(checksum), + penis_length float, + circumcised text +); +CREATE TABLE IF NOT EXISTS studios ( + id serial not null primary key, + name text NOT NULL, + url VARCHAR(2048), + parent_id INTEGER DEFAULT NULL REFERENCES studios(id) ON DELETE SET NULL, + created_at timestamp NOT NULL, + updated_at timestamp NOT NULL, + details TEXT, + rating smallint, + ignore_auto_tag BOOLEAN NOT NULL DEFAULT FALSE, + image_blob VARCHAR(255) REFERENCES blobs(checksum), + favorite boolean not null default FALSE, + CHECK (id != parent_id) +); +CREATE TABLE IF NOT EXISTS saved_filters ( + id serial not null primary key, + name text not null, + mode varchar(255) not null, + find_filter bytea, + object_filter bytea, + ui_options bytea +); +CREATE TABLE IF NOT EXISTS images ( + id serial not null primary key, + title text, + rating smallint, + studio_id integer, + o_counter smallint not null default 0, + organized boolean not null default FALSE, + created_at timestamp not null, + updated_at timestamp not null, + date date, + code text, + photographer text, + details text, + foreign key(studio_id) references studios(id) on delete SET NULL +); +CREATE TABLE image_urls ( + image_id integer NOT NULL, + position integer NOT NULL, + url varchar(2048) NOT NULL, + foreign key(image_id) references images(id) on delete CASCADE, + PRIMARY KEY(image_id, position, url) +); +CREATE TABLE IF NOT EXISTS galleries ( + id serial not null primary key, + folder_id integer, + title text, + date date, + details text, + studio_id integer, + rating smallint, + organized boolean not null default FALSE, + created_at timestamp not null, + updated_at timestamp not null, + code text, + photographer text, + foreign key(studio_id) references studios(id) on delete SET NULL, + foreign key(folder_id) references folders(id) on delete SET NULL +); +CREATE TABLE gallery_urls ( + gallery_id integer NOT NULL, + position integer NOT NULL, + url varchar(2048) NOT NULL, + foreign 
key(gallery_id) references galleries(id) on delete CASCADE, + PRIMARY KEY(gallery_id, position, url) +); +CREATE TABLE IF NOT EXISTS scenes ( + id serial not null primary key, + title text, + details text, + date date, + rating smallint, + studio_id integer, + organized boolean not null default FALSE, + created_at timestamp not null, + updated_at timestamp not null, + code text, + director text, + resume_time float not null default 0, + play_duration float not null default 0, + cover_blob varchar(255) REFERENCES blobs(checksum), + foreign key(studio_id) references studios(id) on delete SET NULL +); +CREATE TABLE IF NOT EXISTS groups ( + id serial not null primary key, + name text not null, + aliases text, + duration integer, + date date, + rating smallint, + studio_id integer REFERENCES studios(id) ON DELETE SET NULL, + director text, + "description" text, + created_at timestamp not null, + updated_at timestamp not null, + front_image_blob varchar(255) REFERENCES blobs(checksum), + back_image_blob varchar(255) REFERENCES blobs(checksum) +); +CREATE TABLE IF NOT EXISTS group_urls ( + "group_id" integer NOT NULL, + position integer NOT NULL, + url varchar(2048) NOT NULL, + foreign key("group_id") references "groups"(id) on delete CASCADE, + PRIMARY KEY("group_id", position, url) +); +CREATE TABLE IF NOT EXISTS groups_tags ( + "group_id" integer NOT NULL, + tag_id integer NOT NULL, + foreign key("group_id") references "groups"(id) on delete CASCADE, + foreign key(tag_id) references tags(id) on delete CASCADE, + PRIMARY KEY("group_id", tag_id) +); +CREATE TABLE performer_urls ( + performer_id integer NOT NULL, + position integer NOT NULL, + url varchar(2048) NOT NULL, + foreign key(performer_id) references performers(id) on delete CASCADE, + PRIMARY KEY(performer_id, position, url) +); +CREATE TABLE studios_tags ( + studio_id integer NOT NULL, + tag_id integer NOT NULL, + foreign key(studio_id) references studios(id) on delete CASCADE, + foreign key(tag_id) references 
tags(id) on delete CASCADE, + PRIMARY KEY(studio_id, tag_id) +); +CREATE TABLE IF NOT EXISTS scenes_view_dates ( + scene_id integer not null, + view_date timestamp not null, + foreign key(scene_id) references scenes(id) on delete CASCADE +); +CREATE TABLE IF NOT EXISTS scenes_o_dates ( + scene_id integer not null, + o_date timestamp not null, + foreign key(scene_id) references scenes(id) on delete CASCADE +); +CREATE TABLE performer_stash_ids ( + performer_id integer, + endpoint varchar(2048), + stash_id uuid, + foreign key(performer_id) references performers(id) on delete CASCADE +); +CREATE TABLE studio_stash_ids ( + studio_id integer, + endpoint varchar(2048), + stash_id uuid, + foreign key(studio_id) references studios(id) on delete CASCADE +); +CREATE TABLE tags_relations ( + parent_id integer, + child_id integer, + primary key (parent_id, child_id), + foreign key (parent_id) references tags(id) on delete cascade, + foreign key (child_id) references tags(id) on delete cascade +); +CREATE TABLE files_fingerprints ( + file_id integer NOT NULL, + type varchar(255) NOT NULL, + fingerprint text NOT NULL, + foreign key(file_id) references files(id) on delete CASCADE, + PRIMARY KEY (file_id, type, fingerprint) +); +CREATE TABLE video_files ( + file_id integer NOT NULL primary key, + duration float NOT NULL, + video_codec varchar(255) NOT NULL, + format varchar(255) NOT NULL, + audio_codec varchar(255) NOT NULL, + width smallint NOT NULL, + height smallint NOT NULL, + frame_rate float NOT NULL, + bit_rate integer NOT NULL, + interactive boolean not null default FALSE, + interactive_speed int, + foreign key(file_id) references files(id) on delete CASCADE +); +CREATE TABLE video_captions ( + file_id integer NOT NULL, + language_code varchar(255) NOT NULL, + filename varchar(255) NOT NULL, + caption_type varchar(255) NOT NULL, + primary key (file_id, language_code, caption_type), + foreign key(file_id) references video_files(file_id) on delete CASCADE +); +CREATE TABLE 
image_files ( + file_id integer NOT NULL primary key, + format varchar(255) NOT NULL, + width smallint NOT NULL, + height smallint NOT NULL, + foreign key(file_id) references files(id) on delete CASCADE +); +CREATE TABLE images_files ( + image_id integer NOT NULL, + file_id integer NOT NULL, + "primary" boolean NOT NULL, + foreign key(image_id) references images(id) on delete CASCADE, + foreign key(file_id) references files(id) on delete CASCADE, + PRIMARY KEY(image_id, file_id) +); +CREATE TABLE galleries_files ( + gallery_id integer NOT NULL, + file_id integer NOT NULL, + "primary" boolean NOT NULL, + foreign key(gallery_id) references galleries(id) on delete CASCADE, + foreign key(file_id) references files(id) on delete CASCADE, + PRIMARY KEY(gallery_id, file_id) +); +CREATE TABLE scenes_files ( + scene_id integer NOT NULL, + file_id integer NOT NULL, + "primary" boolean NOT NULL, + foreign key(scene_id) references scenes(id) on delete CASCADE, + foreign key(file_id) references files(id) on delete CASCADE, + PRIMARY KEY(scene_id, file_id) +); +CREATE TABLE IF NOT EXISTS performers_scenes ( + performer_id integer, + scene_id integer, + foreign key(performer_id) references performers(id) on delete CASCADE, + foreign key(scene_id) references scenes(id) on delete CASCADE, + PRIMARY KEY (scene_id, performer_id) +); +CREATE TABLE IF NOT EXISTS scene_markers ( + id serial not null primary key, + title text NOT NULL, + seconds FLOAT NOT NULL, + primary_tag_id INTEGER NOT NULL, + scene_id INTEGER NOT NULL, + created_at timestamp NOT NULL, + updated_at timestamp NOT NULL, + FOREIGN KEY(primary_tag_id) REFERENCES tags(id), + FOREIGN KEY(scene_id) REFERENCES scenes(id) +); +CREATE TABLE IF NOT EXISTS scene_markers_tags ( + scene_marker_id integer, + tag_id integer, + foreign key(scene_marker_id) references scene_markers(id) on delete CASCADE, + foreign key(tag_id) references tags(id) on delete CASCADE, + PRIMARY KEY(scene_marker_id, tag_id) +); +CREATE TABLE IF NOT EXISTS 
scenes_tags ( + scene_id integer, + tag_id integer, + foreign key(scene_id) references scenes(id) on delete CASCADE, + foreign key(tag_id) references tags(id) on delete CASCADE, + PRIMARY KEY(scene_id, tag_id) +); +CREATE TABLE IF NOT EXISTS groups_scenes ( + "group_id" integer, + scene_id integer, + scene_index smallint, + foreign key("group_id") references "groups"(id) on delete cascade, + foreign key(scene_id) references scenes(id) on delete cascade, + PRIMARY KEY("group_id", scene_id) +); +CREATE TABLE IF NOT EXISTS performers_images ( + performer_id integer, + image_id integer, + foreign key(performer_id) references performers(id) on delete CASCADE, + foreign key(image_id) references images(id) on delete CASCADE, + PRIMARY KEY(image_id, performer_id) +); +CREATE TABLE IF NOT EXISTS images_tags ( + image_id integer, + tag_id integer, + foreign key(image_id) references images(id) on delete CASCADE, + foreign key(tag_id) references tags(id) on delete CASCADE, + PRIMARY KEY(image_id, tag_id) +); +CREATE TABLE IF NOT EXISTS scene_stash_ids ( + scene_id integer NOT NULL, + endpoint varchar(2048) NOT NULL, + stash_id uuid NOT NULL, + foreign key(scene_id) references scenes(id) on delete CASCADE, + PRIMARY KEY(scene_id, endpoint) +); +CREATE TABLE IF NOT EXISTS scenes_galleries ( + scene_id integer NOT NULL, + gallery_id integer NOT NULL, + foreign key(scene_id) references scenes(id) on delete CASCADE, + foreign key(gallery_id) references galleries(id) on delete CASCADE, + PRIMARY KEY(scene_id, gallery_id) +); +CREATE TABLE IF NOT EXISTS galleries_images ( + gallery_id integer NOT NULL, + image_id integer NOT NULL, + cover boolean not null default FALSE, + foreign key(gallery_id) references galleries(id) on delete CASCADE, + foreign key(image_id) references images(id) on delete CASCADE, + PRIMARY KEY(gallery_id, image_id) +); +CREATE TABLE IF NOT EXISTS performers_galleries ( + performer_id integer NOT NULL, + gallery_id integer NOT NULL, + foreign key(performer_id) 
references performers(id) on delete CASCADE, + foreign key(gallery_id) references galleries(id) on delete CASCADE, + PRIMARY KEY(gallery_id, performer_id) +); +CREATE TABLE IF NOT EXISTS galleries_tags ( + gallery_id integer NOT NULL, + tag_id integer NOT NULL, + foreign key(gallery_id) references galleries(id) on delete CASCADE, + foreign key(tag_id) references tags(id) on delete CASCADE, + PRIMARY KEY(gallery_id, tag_id) +); +CREATE TABLE IF NOT EXISTS performers_tags ( + performer_id integer NOT NULL, + tag_id integer NOT NULL, + foreign key(performer_id) references performers(id) on delete CASCADE, + foreign key(tag_id) references tags(id) on delete CASCADE, + PRIMARY KEY(performer_id, tag_id) +); +CREATE TABLE IF NOT EXISTS tag_aliases ( + tag_id integer NOT NULL, + alias text NOT NULL, + foreign key(tag_id) references tags(id) on delete CASCADE, + PRIMARY KEY(tag_id, alias) +); +CREATE TABLE IF NOT EXISTS studio_aliases ( + studio_id integer NOT NULL, + alias text NOT NULL, + foreign key(studio_id) references studios(id) on delete CASCADE, + PRIMARY KEY(studio_id, alias) +); +CREATE TABLE performer_aliases ( + performer_id integer NOT NULL, + alias text NOT NULL, + foreign key(performer_id) references performers(id) on delete CASCADE, + PRIMARY KEY(performer_id, alias) +); +CREATE TABLE galleries_chapters ( + id serial not null primary key, + title text not null, + image_index integer not null, + gallery_id integer not null, + created_at timestamp not null, + updated_at timestamp not null, + foreign key(gallery_id) references galleries(id) on delete CASCADE +); +CREATE TABLE scene_urls ( + scene_id integer NOT NULL, + position integer NOT NULL, + url varchar(2048) NOT NULL, + foreign key(scene_id) references scenes(id) on delete CASCADE, + PRIMARY KEY(scene_id, position, url) +); +CREATE TABLE groups_relations ( + containing_id integer not null, + sub_id integer not null, + order_index integer not null, + description text, + primary key (containing_id, 
sub_id), + foreign key (containing_id) references groups(id) on delete cascade, + foreign key (sub_id) references groups(id) on delete cascade, + check (containing_id != sub_id) +); +CREATE INDEX index_tags_on_name on tags (name); +CREATE INDEX index_folders_on_parent_folder_id on folders (parent_folder_id); +CREATE UNIQUE INDEX index_folders_on_path_unique on folders (path); +CREATE UNIQUE INDEX index_files_zip_basename_unique ON files (zip_file_id, parent_folder_id, basename) WHERE zip_file_id IS NOT NULL; +CREATE UNIQUE INDEX index_files_on_parent_folder_id_basename_unique on files (parent_folder_id, basename); +CREATE INDEX index_files_on_basename on files (basename); +CREATE INDEX index_folders_on_zip_file_id on folders (zip_file_id) WHERE zip_file_id IS NOT NULL; +CREATE INDEX index_fingerprint_type_fingerprint ON files_fingerprints (type, fingerprint); +CREATE INDEX index_images_files_on_file_id on images_files (file_id); +CREATE UNIQUE INDEX unique_index_images_files_on_primary on images_files (image_id) WHERE "primary" = TRUE; +CREATE INDEX index_galleries_files_file_id ON galleries_files (file_id); +CREATE UNIQUE INDEX unique_index_galleries_files_on_primary on galleries_files (gallery_id) WHERE "primary" = TRUE; +CREATE INDEX index_scenes_files_file_id ON scenes_files (file_id); +CREATE UNIQUE INDEX unique_index_scenes_files_on_primary on scenes_files (scene_id) WHERE "primary" = TRUE; +CREATE INDEX index_performer_stash_ids_on_performer_id ON performer_stash_ids (performer_id); +CREATE INDEX index_studio_stash_ids_on_studio_id ON studio_stash_ids (studio_id); +CREATE INDEX index_performers_scenes_on_performer_id on performers_scenes (performer_id); +CREATE INDEX index_scene_markers_tags_on_tag_id on scene_markers_tags (tag_id); +CREATE INDEX index_scenes_tags_on_tag_id on scenes_tags (tag_id); +CREATE INDEX index_movies_scenes_on_movie_id on groups_scenes (group_id); +CREATE INDEX index_performers_images_on_performer_id on performers_images 
(performer_id); +CREATE INDEX index_images_tags_on_tag_id on images_tags (tag_id); +CREATE INDEX index_scenes_galleries_on_gallery_id on scenes_galleries (gallery_id); +CREATE INDEX index_galleries_images_on_image_id on galleries_images (image_id); +CREATE INDEX index_performers_galleries_on_performer_id on performers_galleries (performer_id); +CREATE INDEX index_galleries_tags_on_tag_id on galleries_tags (tag_id); +CREATE INDEX index_performers_tags_on_tag_id on performers_tags (tag_id); +CREATE UNIQUE INDEX tag_aliases_alias_unique on tag_aliases (alias); +CREATE UNIQUE INDEX studio_aliases_alias_unique on studio_aliases (alias); +CREATE INDEX performer_aliases_alias on performer_aliases (alias); +CREATE INDEX index_galleries_chapters_on_gallery_id on galleries_chapters (gallery_id); +CREATE INDEX scene_urls_url on scene_urls (url); +CREATE INDEX index_scene_markers_on_primary_tag_id ON scene_markers(primary_tag_id); +CREATE INDEX index_scene_markers_on_scene_id ON scene_markers(scene_id); +CREATE UNIQUE INDEX index_studios_on_name_unique ON studios(name); +CREATE UNIQUE INDEX index_saved_filters_on_mode_name_unique on saved_filters (mode, name); +CREATE INDEX image_urls_url on image_urls (url); +CREATE INDEX index_images_on_studio_id on images (studio_id); +CREATE INDEX gallery_urls_url on gallery_urls (url); +CREATE INDEX index_galleries_on_studio_id on galleries (studio_id); +CREATE UNIQUE INDEX index_galleries_on_folder_id_unique on galleries (folder_id); +CREATE INDEX index_scenes_on_studio_id on scenes (studio_id); +CREATE INDEX performers_urls_url on performer_urls (url); +CREATE UNIQUE INDEX performers_name_disambiguation_unique on performers (name, disambiguation) WHERE disambiguation IS NOT NULL; +CREATE UNIQUE INDEX performers_name_unique on performers (name) WHERE disambiguation IS NULL; +CREATE INDEX index_studios_tags_on_tag_id on studios_tags (tag_id); +CREATE INDEX index_scenes_view_dates ON scenes_view_dates (scene_id); +CREATE INDEX 
index_scenes_o_dates ON scenes_o_dates (scene_id); +CREATE INDEX index_groups_on_name ON groups(name); +CREATE INDEX index_groups_on_studio_id on groups (studio_id); +CREATE INDEX group_urls_url on group_urls (url); +CREATE INDEX index_groups_tags_on_tag_id on groups_tags (tag_id); +CREATE INDEX index_groups_tags_on_movie_id on groups_tags (group_id); +CREATE UNIQUE INDEX index_galleries_images_gallery_id_cover on galleries_images (gallery_id, cover) WHERE cover = TRUE; +CREATE INDEX index_groups_relations_sub_id ON groups_relations (sub_id); +CREATE UNIQUE INDEX index_groups_relations_order_index_unique ON groups_relations (containing_id, order_index); diff --git a/pkg/sqlite/migrationsPostgres/2_image_studio_index.up.sql b/pkg/sqlite/migrationsPostgres/2_image_studio_index.up.sql new file mode 100644 index 00000000000..3c28cf1196c --- /dev/null +++ b/pkg/sqlite/migrationsPostgres/2_image_studio_index.up.sql @@ -0,0 +1,7 @@ +-- with the existing index, if no images have a studio id, then the index is +-- not used when filtering by studio id. The assumption with this change is that +-- most images don't have a studio id, so filtering by non-null studio id should +-- be faster with this index. This is a tradeoff, as filtering by null studio id +-- will be slower. 
+DROP INDEX index_images_on_studio_id; +CREATE INDEX index_images_on_studio_id on images (studio_id) WHERE studio_id IS NOT NULL; \ No newline at end of file diff --git a/pkg/sqlite/migrationsPostgres/3_stash_id_updated_at.up.sql b/pkg/sqlite/migrationsPostgres/3_stash_id_updated_at.up.sql new file mode 100644 index 00000000000..8bf9a8cb004 --- /dev/null +++ b/pkg/sqlite/migrationsPostgres/3_stash_id_updated_at.up.sql @@ -0,0 +1,3 @@ +ALTER TABLE performer_stash_ids ADD COLUMN updated_at timestamp not null default '1970-01-01T00:00:00Z'; +ALTER TABLE scene_stash_ids ADD COLUMN updated_at timestamp not null default '1970-01-01T00:00:00Z'; +ALTER TABLE studio_stash_ids ADD COLUMN updated_at timestamp not null default '1970-01-01T00:00:00Z'; diff --git a/pkg/sqlite/performer.go b/pkg/sqlite/performer.go index e20dc9c4cc5..883ba403091 100644 --- a/pkg/sqlite/performer.go +++ b/pkg/sqlite/performer.go @@ -568,14 +568,14 @@ func (qb *PerformerStore) QueryForAutoTag(ctx context.Context, words []string) ( var whereClauses []exp.Expression for _, w := range words { - whereClauses = append(whereClauses, table.Col("name").Like(w+"%")) + whereClauses = append(whereClauses, table.Col("name").ILike(w+"%")) // TODO - see above - // whereClauses = append(whereClauses, performersAliasesJoinTable.Col("alias").Like(w+"%")) + // whereClauses = append(whereClauses, performersAliasesJoinTable.Col("alias").ILike(w+"%")) } sq = sq.Where( goqu.Or(whereClauses...), - table.Col("ignore_auto_tag").Eq(0), + table.Col("ignore_auto_tag").IsFalse(), ) ret, err := qb.findBySubquery(ctx, sq) @@ -613,11 +613,13 @@ func (qb *PerformerStore) makeQuery(ctx context.Context, performerFilter *models } var err error - query.sortAndPagination, err = qb.getPerformerSort(findFilter) + var agg []string + query.sortAndPagination, agg, err = qb.getPerformerSort(findFilter) if err != nil { return nil, err } query.sortAndPagination += getPagination(findFilter) + query.addGroupBy(agg...) 
return &query, nil } @@ -731,7 +733,7 @@ var performerSortOptions = sortOptions{ "weight", } -func (qb *PerformerStore) getPerformerSort(findFilter *models.FindFilterType) (string, error) { +func (qb *PerformerStore) getPerformerSort(findFilter *models.FindFilterType) (string, []string, error) { var sort string var direction string if findFilter == nil { @@ -744,9 +746,10 @@ func (qb *PerformerStore) getPerformerSort(findFilter *models.FindFilterType) (s // CVE-2024-32231 - ensure sort is in the list of allowed sorts if err := performerSortOptions.validateSort(sort); err != nil { - return "", err + return "", nil, err } + var agg []string sortQuery := "" switch sort { case "tag_count": @@ -766,12 +769,15 @@ func (qb *PerformerStore) getPerformerSort(findFilter *models.FindFilterType) (s case "last_o_at": sortQuery += qb.sortByLastOAt(direction) default: - sortQuery += getSort(sort, direction, "performers") + var add string + add, agg = getSort(sort, direction, "performers") + sortQuery += add } // Whatever the sorting, always use name/id as a final sort - sortQuery += ", COALESCE(performers.name, performers.id) COLLATE NATURAL_CI ASC" - return sortQuery, nil + sortQuery += ", COALESCE(performers.name, CAST(performers.id as text)) COLLATE NATURAL_CI ASC" + agg = append(agg, "performers.name", "performers.id") + return sortQuery, agg, nil } func (qb *PerformerStore) GetTagIDs(ctx context.Context, id int) ([]int, error) { diff --git a/pkg/sqlite/performer_filter.go b/pkg/sqlite/performer_filter.go index 72990a7febd..dad7cd1062e 100644 --- a/pkg/sqlite/performer_filter.go +++ b/pkg/sqlite/performer_filter.go @@ -139,7 +139,7 @@ func (qb *performerFilterHandler) criterionHandler() criterionHandler { criterionHandlerFunc(func(ctx context.Context, f *filterBuilder) { if filter.StashID != nil { performerRepository.stashIDs.join(f, "performer_stash_ids", "performers.id") - stringCriterionHandler(filter.StashID, "performer_stash_ids.stash_id")(ctx, f) + 
uuidCriterionHandler(filter.StashID, "performer_stash_ids.stash_id")(ctx, f) } }), &stashIDCriterionHandler{ @@ -226,7 +226,7 @@ func (qb *performerFilterHandler) performerIsMissingCriterionHandler(isMissing * performersAliasesTableMgr.join(f, "", "performers.id") f.addWhere("performer_aliases.alias IS NULL") default: - f.addWhere("(performers." + *isMissing + " IS NULL OR TRIM(performers." + *isMissing + ") = '')") + f.addWhere("(performers." + *isMissing + " IS NULL OR TRIM(CAST(performers." + *isMissing + " AS TEXT)) = '')") } } } @@ -235,10 +235,23 @@ func (qb *performerFilterHandler) performerIsMissingCriterionHandler(isMissing * func (qb *performerFilterHandler) performerAgeFilterCriterionHandler(age *models.IntCriterionInput) criterionHandlerFunc { return func(ctx context.Context, f *filterBuilder) { if age != nil && age.Modifier.IsValid() { - clause, args := getIntCriterionWhereClause( - "cast(IFNULL(strftime('%Y.%m%d', performers.death_date), strftime('%Y.%m%d', 'now')) - strftime('%Y.%m%d', performers.birthdate) as int)", - *age, - ) + + var clause string + var args []interface{} + + switch dbWrapper.dbType { + case PostgresBackend: + clause, args = getIntCriterionWhereClause( + "EXTRACT(YEAR FROM COALESCE(performers.death_date, CURRENT_DATE)) - EXTRACT(YEAR FROM performers.birthdate)", + *age, + ) + case SqliteBackend: + clause, args = getIntCriterionWhereClause( + "cast(IFNULL(strftime('%Y.%m%d', performers.death_date), strftime('%Y.%m%d', 'now')) - strftime('%Y.%m%d', performers.birthdate) as int)", + *age, + ) + } + f.addWhere(clause, args...) 
} } @@ -456,7 +469,7 @@ func (qb *performerFilterHandler) studiosCriterionHandler(studios *models.Hierar } const derivedPerformerStudioTable = "performer_studio" - valuesClause, err := getHierarchicalValues(ctx, studios.Value, studioTable, "", "parent_id", "child_id", studios.Depth) + valuesClause, err := getHierarchicalValues(ctx, studios.Value, studioTable, "", "parent_id", "child_id", studios.Depth, false) if err != nil { f.setError(err) return @@ -519,7 +532,7 @@ func (qb *performerFilterHandler) appearsWithCriterionHandler(performers *models if performers.Modifier == models.CriterionModifierIncludesAll && len(performers.Value) > 1 { templStr += ` GROUP BY {primaryTable}2.performer_id - HAVING(count(distinct {primaryTable}.performer_id) IS ` + strconv.Itoa(len(performers.Value)) + `)` + HAVING(count(distinct {primaryTable}.performer_id) = ` + strconv.Itoa(len(performers.Value)) + `)` } var unions []string diff --git a/pkg/sqlite/performer_test.go b/pkg/sqlite/performer_test.go index e0294f3e442..f522a102501 100644 --- a/pkg/sqlite/performer_test.go +++ b/pkg/sqlite/performer_test.go @@ -18,22 +18,22 @@ import ( func loadPerformerRelationships(ctx context.Context, expected models.Performer, actual *models.Performer) error { if expected.Aliases.Loaded() { - if err := actual.LoadAliases(ctx, db.Performer); err != nil { + if err := actual.LoadAliases(ctx, db.GetRepo().Performer); err != nil { return err } } if expected.URLs.Loaded() { - if err := actual.LoadURLs(ctx, db.Performer); err != nil { + if err := actual.LoadURLs(ctx, db.GetRepo().Performer); err != nil { return err } } if expected.TagIDs.Loaded() { - if err := actual.LoadTagIDs(ctx, db.Performer); err != nil { + if err := actual.LoadTagIDs(ctx, db.GetRepo().Performer); err != nil { return err } } if expected.StashIDs.Loaded() { - if err := actual.LoadStashIDs(ctx, db.Performer); err != nil { + if err := actual.LoadStashIDs(ctx, db.GetRepo().Performer); err != nil { return err } } @@ -70,8 +70,8 @@ func 
Test_PerformerStore_Create(t *testing.T) { favorite = true endpoint1 = "endpoint1" endpoint2 = "endpoint2" - stashID1 = "stashid1" - stashID2 = "stashid2" + stashID1 = getUUID("stashid1") + stashID2 = getUUID("stashid2") createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) @@ -137,7 +137,7 @@ func Test_PerformerStore_Create(t *testing.T) { }, } - qb := db.Performer + qb := db.GetRepo().Performer for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -217,8 +217,8 @@ func Test_PerformerStore_Update(t *testing.T) { favorite = true endpoint1 = "endpoint1" endpoint2 = "endpoint2" - stashID1 = "stashid1" - stashID2 = "stashid2" + stashID1 = getUUID("stashid1") + stashID2 = getUUID("stashid2") createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) @@ -304,7 +304,7 @@ func Test_PerformerStore_Update(t *testing.T) { }, } - qb := db.Performer + qb := db.GetRepo().Performer for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) @@ -398,8 +398,8 @@ func Test_PerformerStore_UpdatePartial(t *testing.T) { favorite = true endpoint1 = "endpoint1" endpoint2 = "endpoint2" - stashID1 = "stashid1" - stashID2 = "stashid2" + stashID1 = getUUID("stashid1") + stashID2 = getUUID("stashid2") createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC) @@ -535,7 +535,7 @@ func Test_PerformerStore_UpdatePartial(t *testing.T) { }, } for _, tt := range tests { - qb := db.Performer + qb := db.GetRepo().Performer runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) @@ -575,7 +575,7 @@ func Test_PerformerStore_UpdatePartial(t *testing.T) { func TestPerformerFindBySceneID(t *testing.T) { withTxn(func(ctx context.Context) error { - pqb := db.Performer + pqb := 
db.GetRepo().Performer sceneID := sceneIDs[sceneIdxWithPerformer] performers, err := pqb.FindBySceneID(ctx, sceneID) @@ -606,7 +606,7 @@ func TestPerformerFindBySceneID(t *testing.T) { func TestPerformerFindByImageID(t *testing.T) { withTxn(func(ctx context.Context) error { - pqb := db.Performer + pqb := db.GetRepo().Performer imageID := imageIDs[imageIdxWithPerformer] performers, err := pqb.FindByImageID(ctx, imageID) @@ -637,7 +637,7 @@ func TestPerformerFindByImageID(t *testing.T) { func TestPerformerFindByGalleryID(t *testing.T) { withTxn(func(ctx context.Context) error { - pqb := db.Performer + pqb := db.GetRepo().Performer galleryID := galleryIDs[galleryIdxWithPerformer] performers, err := pqb.FindByGalleryID(ctx, galleryID) @@ -678,7 +678,7 @@ func TestPerformerFindByNames(t *testing.T) { withTxn(func(ctx context.Context) error { var names []string - pqb := db.Performer + pqb := db.GetRepo().Performer names = append(names, performerNames[performerIdxWithScene]) // find performers by names @@ -893,7 +893,7 @@ func TestPerformerIllegalQuery(t *testing.T) { }, } - sqb := db.Performer + sqb := db.GetRepo().Performer for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1022,7 +1022,7 @@ func TestPerformerQuery(t *testing.T) { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) - performers, _, err := db.Performer.Query(ctx, tt.filter, tt.findFilter) + performers, _, err := db.GetRepo().Performer.Query(ctx, tt.filter, tt.findFilter) if (err != nil) != tt.wantErr { t.Errorf("PerformerStore.Query() error = %v, wantErr %v", err, tt.wantErr) return @@ -1105,7 +1105,7 @@ func TestPerformerQueryPenisLength(t *testing.T) { }, } - performers, _, err := db.Performer.Query(ctx, filter, nil) + performers, _, err := db.GetRepo().Performer.Query(ctx, filter, nil) if err != nil { t.Errorf("PerformerStore.Query() error = %v", err) return @@ -1145,7 +1145,7 @@ func verifyFloat(t 
*testing.T, value *float64, criterion models.FloatCriterionIn func TestPerformerQueryForAutoTag(t *testing.T) { withTxn(func(ctx context.Context) error { - tqb := db.Performer + tqb := db.GetRepo().Performer name := performerNames[performerIdx1WithScene] // find a performer by name @@ -1165,7 +1165,7 @@ func TestPerformerQueryForAutoTag(t *testing.T) { func TestPerformerUpdatePerformerImage(t *testing.T) { if err := withRollbackTxn(func(ctx context.Context) error { - qb := db.Performer + qb := db.GetRepo().Performer // create performer to test against const name = "TestPerformerUpdatePerformerImage" @@ -1204,7 +1204,7 @@ func TestPerformerQueryAge(t *testing.T) { func verifyPerformerAge(t *testing.T, ageCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - qb := db.Performer + qb := db.GetRepo().Performer performerFilter := models.PerformerFilterType{ Age: &ageCriterion, } @@ -1259,7 +1259,7 @@ func TestPerformerQueryCareerLength(t *testing.T) { func verifyPerformerCareerLength(t *testing.T, criterion models.StringCriterionInput) { withTxn(func(ctx context.Context) error { - qb := db.Performer + qb := db.GetRepo().Performer performerFilter := models.PerformerFilterType{ CareerLength: &criterion, } @@ -1329,7 +1329,7 @@ func verifyPerformerQuery(t *testing.T, filter models.PerformerFilterType, verif performers := queryPerformers(ctx, t, &filter, nil) for _, performer := range performers { - if err := performer.LoadURLs(ctx, db.Performer); err != nil { + if err := performer.LoadURLs(ctx, db.GetRepo().Performer); err != nil { t.Errorf("Error loading url relationships: %v", err) } } @@ -1347,7 +1347,7 @@ func verifyPerformerQuery(t *testing.T, filter models.PerformerFilterType, verif func queryPerformers(ctx context.Context, t *testing.T, performerFilter *models.PerformerFilterType, findFilter *models.FindFilterType) []*models.Performer { t.Helper() - performers, _, err := db.Performer.Query(ctx, performerFilter, findFilter) + performers, _, 
err := db.GetRepo().Performer.Query(ctx, performerFilter, findFilter) if err != nil { t.Errorf("Error querying performers: %s", err.Error()) } @@ -1429,7 +1429,7 @@ func TestPerformerQueryTagCount(t *testing.T) { func verifyPerformersTagCount(t *testing.T, tagCountCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - sqb := db.Performer + sqb := db.GetRepo().Performer performerFilter := models.PerformerFilterType{ TagCount: &tagCountCriterion, } @@ -1478,7 +1478,7 @@ func verifyPerformersSceneCount(t *testing.T, sceneCountCriterion models.IntCrit assert.Greater(t, len(performers), 0) for _, performer := range performers { - ids, err := db.Scene.FindByPerformerID(ctx, performer.ID) + ids, err := db.GetRepo().Scene.FindByPerformerID(ctx, performer.ID) if err != nil { return err } @@ -1520,7 +1520,7 @@ func verifyPerformersImageCount(t *testing.T, imageCountCriterion models.IntCrit for _, performer := range performers { pp := 0 - result, err := db.Image.Query(ctx, models.ImageQueryOptions{ + result, err := db.GetRepo().Image.Query(ctx, models.ImageQueryOptions{ QueryOptions: models.QueryOptions{ FindFilter: &models.FindFilterType{ PerPage: &pp, @@ -1575,7 +1575,7 @@ func verifyPerformersGalleryCount(t *testing.T, galleryCountCriterion models.Int for _, performer := range performers { pp := 0 - _, count, err := db.Gallery.Query(ctx, &models.GalleryFilterType{ + _, count, err := db.GetRepo().Gallery.Query(ctx, &models.GalleryFilterType{ Performers: &models.MultiCriterionInput{ Value: []string{strconv.Itoa(performer.ID)}, Modifier: models.CriterionModifierIncludes, @@ -1673,7 +1673,7 @@ func TestPerformerQueryStudio(t *testing.T) { func TestPerformerStashIDs(t *testing.T) { if err := withRollbackTxn(func(ctx context.Context) error { - qb := db.Performer + qb := db.GetRepo().Performer // create scene to test against const name = "TestPerformerStashIDs" @@ -1700,14 +1700,14 @@ func testPerformerStashIDs(ctx context.Context, t *testing.T, s 
*models.Performe assert.Len(t, s.StashIDs.List(), 0) // add stash ids - const stashIDStr = "stashID" + var stashIDStr = getUUID("stashID") const endpoint = "endpoint" stashID := models.StashID{ StashID: stashIDStr, Endpoint: endpoint, } - qb := db.Performer + qb := db.GetRepo().Performer // update stash ids and ensure was updated var err error @@ -1817,7 +1817,7 @@ func TestPerformerQueryIsMissingImage(t *testing.T) { assert.True(t, len(performers) > 0) for _, performer := range performers { - img, err := db.Performer.GetImage(ctx, performer.ID) + img, err := db.GetRepo().Performer.GetImage(ctx, performer.ID) if err != nil { t.Errorf("error getting performer image: %s", err.Error()) } @@ -1835,7 +1835,7 @@ func TestPerformerQueryIsMissingAlias(t *testing.T) { assert.True(t, len(performers) > 0) for _, performer := range performers { - a, err := db.Performer.GetAliases(ctx, performer.ID) + a, err := db.GetRepo().Performer.GetAliases(ctx, performer.ID) if err != nil { t.Errorf("error getting performer aliases: %s", err.Error()) } @@ -1856,7 +1856,7 @@ func TestPerformerQuerySortScenesCount(t *testing.T) { withTxn(func(ctx context.Context) error { // just ensure it queries without error - performers, _, err := db.Performer.Query(ctx, nil, findFilter) + performers, _, err := db.GetRepo().Performer.Query(ctx, nil, findFilter) if err != nil { t.Errorf("Error querying performers: %s", err.Error()) } @@ -1871,7 +1871,7 @@ func TestPerformerQuerySortScenesCount(t *testing.T) { // sort in ascending order direction = models.SortDirectionEnumAsc - performers, _, err = db.Performer.Query(ctx, nil, findFilter) + performers, _, err = db.GetRepo().Performer.Query(ctx, nil, findFilter) if err != nil { t.Errorf("Error querying performers: %s", err.Error()) } @@ -1887,7 +1887,7 @@ func TestPerformerQuerySortScenesCount(t *testing.T) { func TestPerformerCountByTagID(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Performer + sqb := db.GetRepo().Performer count, 
err := sqb.CountByTagID(ctx, tagIDs[tagIdxWithPerformer]) if err != nil { @@ -1910,7 +1910,7 @@ func TestPerformerCountByTagID(t *testing.T) { func TestPerformerCount(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Performer + sqb := db.GetRepo().Performer count, err := sqb.Count(ctx) if err != nil { @@ -1925,7 +1925,7 @@ func TestPerformerCount(t *testing.T) { func TestPerformerAll(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Performer + sqb := db.GetRepo().Performer all, err := sqb.All(ctx) if err != nil { @@ -1966,7 +1966,7 @@ func TestPerformerStore_FindByStashID(t *testing.T) { { name: "non-existing", stashID: models.StashID{ - StashID: getPerformerStringValue(performerIdxWithScene, "stashid"), + StashID: getUUID(getPerformerStringValue(performerIdxWithScene, "stashid")), Endpoint: "non-existing", }, expectedIDs: []int{}, @@ -1974,7 +1974,7 @@ func TestPerformerStore_FindByStashID(t *testing.T) { }, } - qb := db.Performer + qb := db.GetRepo().Performer for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -2025,7 +2025,7 @@ func TestPerformerStore_FindByStashIDStatus(t *testing.T) { }, } - qb := db.Performer + qb := db.GetRepo().Performer for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { diff --git a/pkg/sqlite/query.go b/pkg/sqlite/query.go index 9c09d8beaed..26cc0542b81 100644 --- a/pkg/sqlite/query.go +++ b/pkg/sqlite/query.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/stashapp/stash/pkg/models" + "github.com/stashapp/stash/pkg/sliceutil" ) type queryBuilder struct { @@ -14,12 +15,13 @@ type queryBuilder struct { columns []string from string - joins joins - whereClauses []string - havingClauses []string - args []interface{} - withClauses []string - recursiveWith bool + joins joins + whereClauses []string + havingClauses []string + args []interface{} + withClauses []string + recursiveWith bool + groupByClauses 
[]string sortAndPagination string } @@ -32,6 +34,12 @@ func (qb *queryBuilder) addColumn(column string) { qb.columns = append(qb.columns, column) } +func (qb *queryBuilder) addGroupBy(columns ...string) { + if len(columns) > 0 { + qb.groupByClauses = sliceutil.AppendUniques(qb.groupByClauses, columns) + } +} + func (qb queryBuilder) toSQL(includeSortPagination bool) string { body := qb.body() @@ -44,7 +52,8 @@ func (qb queryBuilder) toSQL(includeSortPagination bool) string { withClause = "WITH " + recursive + strings.Join(qb.withClauses, ", ") + " " } - body = withClause + qb.repository.buildQueryBody(body, qb.whereClauses, qb.havingClauses) + body = withClause + qb.repository.buildQueryBody(body, qb.whereClauses, qb.havingClauses, qb.groupByClauses) + if includeSortPagination { body += qb.sortAndPagination } @@ -60,7 +69,7 @@ func (qb queryBuilder) findIDs(ctx context.Context) ([]int, error) { func (qb queryBuilder) executeFind(ctx context.Context) ([]int, int, error) { body := qb.body() - return qb.repository.executeFindQuery(ctx, body, qb.args, qb.sortAndPagination, qb.whereClauses, qb.havingClauses, qb.withClauses, qb.recursiveWith) + return qb.repository.executeFindQuery(ctx, body, qb.args, qb.sortAndPagination, qb.whereClauses, qb.havingClauses, qb.withClauses, qb.groupByClauses, qb.recursiveWith) } func (qb queryBuilder) executeCount(ctx context.Context) (int, error) { @@ -75,7 +84,7 @@ func (qb queryBuilder) executeCount(ctx context.Context) (int, error) { withClause = "WITH " + recursive + strings.Join(qb.withClauses, ", ") + " " } - body = qb.repository.buildQueryBody(body, qb.whereClauses, qb.havingClauses) + body = qb.repository.buildQueryBody(body, qb.whereClauses, qb.havingClauses, qb.groupByClauses) countQuery := withClause + qb.repository.buildCountQuery(body) return qb.repository.runCountQuery(ctx, countQuery, qb.args) } @@ -181,7 +190,7 @@ func (qb *queryBuilder) parseQueryString(columns []string, q string) { var clauses []string for _, column := 
range columns { - clauses = append(clauses, column+" LIKE ?") + clauses = append(clauses, "LOWER("+column+") LIKE LOWER(?)") qb.addArg(like(t)) } @@ -190,7 +199,7 @@ func (qb *queryBuilder) parseQueryString(columns []string, q string) { for _, t := range specs.MustNot { for _, column := range columns { - qb.addWhere(coalesce(column) + " NOT LIKE ?") + qb.addWhere("LOWER(" + coalesce(column) + ") NOT LIKE LOWER(?)") qb.addArg(like(t)) } } @@ -200,7 +209,7 @@ func (qb *queryBuilder) parseQueryString(columns []string, q string) { for _, column := range columns { for _, v := range set { - clauses = append(clauses, column+" LIKE ?") + clauses = append(clauses, "LOWER("+column+") LIKE LOWER(?)") qb.addArg(like(v)) } } diff --git a/pkg/sqlite/repository.go b/pkg/sqlite/repository.go index 2035b11c2fc..1d4961df85e 100644 --- a/pkg/sqlite/repository.go +++ b/pkg/sqlite/repository.go @@ -92,6 +92,7 @@ func (r *repository) runIdsQuery(ctx context.Context, query string, args []inter for i, v := range result { vsm[i] = v.Int } + return vsm, nil } @@ -153,20 +154,25 @@ func (r *repository) querySimple(ctx context.Context, query string, args []inter return nil } -func (r *repository) buildQueryBody(body string, whereClauses []string, havingClauses []string) string { +func (r *repository) buildQueryBody(body string, whereClauses []string, havingClauses []string, groupByClauses []string) string { if len(whereClauses) > 0 { body = body + " WHERE " + strings.Join(whereClauses, " AND ") // TODO handle AND or OR } if len(havingClauses) > 0 { - body = body + " GROUP BY " + r.tableName + ".id " + groupByClauses = append(groupByClauses, r.tableName+".id") + } + if len(groupByClauses) > 0 { + body += " GROUP BY " + strings.Join(groupByClauses, ", ") + " " + } + if len(havingClauses) > 0 { body = body + " HAVING " + strings.Join(havingClauses, " AND ") // TODO handle AND or OR } return body } -func (r *repository) executeFindQuery(ctx context.Context, body string, args []interface{}, 
sortAndPagination string, whereClauses []string, havingClauses []string, withClauses []string, recursiveWith bool) ([]int, int, error) { - body = r.buildQueryBody(body, whereClauses, havingClauses) +func (r *repository) executeFindQuery(ctx context.Context, body string, args []interface{}, sortAndPagination string, whereClauses []string, havingClauses []string, withClauses []string, groupByClauses []string, recursiveWith bool) ([]int, int, error) { + body = r.buildQueryBody(body, whereClauses, havingClauses, groupByClauses) withClause := "" if len(withClauses) > 0 { @@ -447,10 +453,10 @@ func idToIndexMap(ids []int) map[int]int { func (r *filesRepository) getMany(ctx context.Context, ids []int, primaryOnly bool) ([][]models.FileID, error) { var primaryClause string if primaryOnly { - primaryClause = " AND `primary` = 1" + primaryClause = ` AND "primary" = ` + getDBBoolean(true) } - query := fmt.Sprintf("SELECT %s as id, file_id, `primary` from %s WHERE %[1]s IN %[3]s%s", r.idColumn, r.tableName, getInBinding(len(ids)), primaryClause) + query := fmt.Sprintf(`SELECT %s as id, file_id, "primary" from %s WHERE %[1]s IN %[3]s%s`, r.idColumn, r.tableName, getInBinding(len(ids)), primaryClause) idi := make([]interface{}, len(ids)) for i, id := range ids { @@ -491,7 +497,7 @@ func (r *filesRepository) getMany(ctx context.Context, ids []int, primaryOnly bo } func (r *filesRepository) get(ctx context.Context, id int) ([]models.FileID, error) { - query := fmt.Sprintf("SELECT file_id, `primary` from %s WHERE %s = ?", r.tableName, r.idColumn) + query := fmt.Sprintf(`SELECT file_id, "primary" from %s WHERE %s = ?`, r.tableName, r.idColumn) type relatedFile struct { FileID models.FileID `db:"file_id"` diff --git a/pkg/sqlite/saved_filter_test.go b/pkg/sqlite/saved_filter_test.go index 60592a923df..dc37b5a8d22 100644 --- a/pkg/sqlite/saved_filter_test.go +++ b/pkg/sqlite/saved_filter_test.go @@ -13,7 +13,7 @@ import ( func TestSavedFilterFind(t *testing.T) { withTxn(func(ctx 
context.Context) error { - savedFilter, err := db.SavedFilter.Find(ctx, savedFilterIDs[savedFilterIdxImage]) + savedFilter, err := db.GetRepo().SavedFilter.Find(ctx, savedFilterIDs[savedFilterIdxImage]) if err != nil { t.Errorf("Error finding saved filter: %s", err.Error()) @@ -27,7 +27,7 @@ func TestSavedFilterFind(t *testing.T) { func TestSavedFilterFindByMode(t *testing.T) { withTxn(func(ctx context.Context) error { - savedFilters, err := db.SavedFilter.FindByMode(ctx, models.FilterModeScenes) + savedFilters, err := db.GetRepo().SavedFilter.FindByMode(ctx, models.FilterModeScenes) if err != nil { t.Errorf("Error finding saved filters: %s", err.Error()) @@ -72,7 +72,7 @@ func TestSavedFilterDestroy(t *testing.T) { ObjectFilter: objectFilter, UIOptions: uiOptions, } - err := db.SavedFilter.Create(ctx, &newFilter) + err := db.GetRepo().SavedFilter.Create(ctx, &newFilter) if err == nil { id = newFilter.ID @@ -82,12 +82,12 @@ func TestSavedFilterDestroy(t *testing.T) { }) withTxn(func(ctx context.Context) error { - return db.SavedFilter.Destroy(ctx, id) + return db.GetRepo().SavedFilter.Destroy(ctx, id) }) // now try to find it withTxn(func(ctx context.Context) error { - found, err := db.SavedFilter.Find(ctx, id) + found, err := db.GetRepo().SavedFilter.Find(ctx, id) if err == nil { assert.Nil(t, found) } diff --git a/pkg/sqlite/scene.go b/pkg/sqlite/scene.go index edd36348329..1d702cf0424 100644 --- a/pkg/sqlite/scene.go +++ b/pkg/sqlite/scene.go @@ -41,7 +41,7 @@ const ( ) var findExactDuplicateQuery = ` -SELECT GROUP_CONCAT(DISTINCT scene_id) as ids +SELECT %s as ids FROM ( SELECT scenes.id as scene_id , video_files.duration as file_duration @@ -52,10 +52,10 @@ FROM ( INNER JOIN scenes_files ON (scenes.id = scenes_files.scene_id) INNER JOIN files ON (scenes_files.file_id = files.id) INNER JOIN files_fingerprints ON (scenes_files.file_id = files_fingerprints.file_id AND files_fingerprints.type = 'phash') - INNER JOIN video_files ON (files.id == video_files.file_id) 
-) -WHERE durationDiff <= ?1 - OR ?1 < 0 -- Always TRUE if the parameter is negative. + INNER JOIN video_files ON (files.id = video_files.file_id) +) as subq +WHERE durationDiff <= $1 + OR $1 < 0 -- Always TRUE if the parameter is negative. -- That will disable the durationDiff checking. GROUP BY phash HAVING COUNT(phash) > 1 @@ -71,7 +71,7 @@ FROM scenes INNER JOIN scenes_files ON (scenes.id = scenes_files.scene_id) INNER JOIN files ON (scenes_files.file_id = files.id) INNER JOIN files_fingerprints ON (scenes_files.file_id = files_fingerprints.file_id AND files_fingerprints.type = 'phash') -INNER JOIN video_files ON (files.id == video_files.file_id) +INNER JOIN video_files ON (files.id = video_files.file_id) ORDER BY files.size DESC; ` @@ -268,7 +268,7 @@ func (qb *SceneStore) selectDataset() *goqu.SelectDataset { scenesFilesJoinTable, goqu.On( scenesFilesJoinTable.Col(sceneIDColumn).Eq(table.Col(idColumn)), - scenesFilesJoinTable.Col("primary").Eq(1), + scenesFilesJoinTable.Col("primary").IsTrue(), ), ).LeftJoin( files, @@ -631,7 +631,7 @@ func (qb *SceneStore) FindByFileID(ctx context.Context, fileID models.FileID) ([ func (qb *SceneStore) FindByPrimaryFileID(ctx context.Context, fileID models.FileID) ([]*models.Scene, error) { sq := dialect.From(scenesFilesJoinTable).Select(scenesFilesJoinTable.Col(sceneIDColumn)).Where( scenesFilesJoinTable.Col(fileIDColumn).Eq(fileID), - scenesFilesJoinTable.Col("primary").Eq(1), + scenesFilesJoinTable.Col("primary").IsTrue(), ) ret, err := qb.findBySubquery(ctx, sq) @@ -711,8 +711,8 @@ func (qb *SceneStore) FindByPath(ctx context.Context, p string) ([]*models.Scene foldersTable, goqu.On(foldersTable.Col(idColumn).Eq(filesTable.Col("parent_folder_id"))), ).Select(scenesFilesJoinTable.Col(sceneIDColumn)).Where( - foldersTable.Col("path").Like(dir), - filesTable.Col("basename").Like(basename), + foldersTable.Col("path").ILike(dir), + filesTable.Col("basename").ILike(basename), ) ret, err := qb.findBySubquery(ctx, sq) @@ -891,7 
+891,7 @@ func (qb *SceneStore) Wall(ctx context.Context, q *string) ([]*models.Scene, err } table := qb.table() - qq := qb.selectDataset().Prepared(true).Where(table.Col("details").Like("%" + s + "%")).Order(goqu.L("RANDOM()").Asc()).Limit(80) + qq := qb.selectDataset().Prepared(true).Where(table.Col("details").ILike("%" + s + "%")).Order(goqu.L("RANDOM()").Asc()).Limit(80) return qb.getMany(ctx, qq) } @@ -922,11 +922,11 @@ func (qb *SceneStore) makeQuery(ctx context.Context, sceneFilter *models.SceneFi query.addJoins( join{ table: scenesFilesTable, - onClause: "scenes_files.scene_id = scenes.id", + onClause: "scenes_files.scene_id = scenes.id AND scenes_files.\"primary\" = " + getDBBoolean(true), }, join{ table: fileTable, - onClause: "scenes_files.file_id = files.id", + onClause: "scenes_files.file_id = files.id AND scenes_files.\"primary\" = " + getDBBoolean(true), }, join{ table: folderTable, @@ -999,30 +999,32 @@ func (qb *SceneStore) queryGroupedFields(ctx context.Context, options models.Sce query.addJoins( join{ table: scenesFilesTable, - onClause: "scenes_files.scene_id = scenes.id", + onClause: "scenes_files.scene_id = scenes.id AND scenes_files.\"primary\" = " + getDBBoolean(true), }, join{ table: videoFileTable, - onClause: "scenes_files.file_id = video_files.file_id", + onClause: "scenes_files.file_id = video_files.file_id AND scenes_files.\"primary\" = " + getDBBoolean(true), }, ) query.addColumn("COALESCE(video_files.duration, 0) as duration") aggregateQuery.addColumn("SUM(temp.duration) as duration") + query.addGroupBy("video_files.duration") } if options.TotalSize { query.addJoins( join{ table: scenesFilesTable, - onClause: "scenes_files.scene_id = scenes.id", + onClause: "scenes_files.scene_id = scenes.id AND scenes_files.\"primary\" = " + getDBBoolean(true), }, join{ table: fileTable, - onClause: "scenes_files.file_id = files.id", + onClause: "scenes_files.file_id = files.id AND scenes_files.\"primary\" = " + getDBBoolean(true), }, ) 
query.addColumn("COALESCE(files.size, 0) as size") aggregateQuery.addColumn("SUM(temp.size) as size") + query.addGroupBy("files.size") } const includeSortPagination = false @@ -1085,9 +1087,7 @@ var sceneSortOptions = sortOptions{ } func (qb *SceneStore) setSceneSort(query *queryBuilder, findFilter *models.FindFilterType) error { - if findFilter == nil || findFilter.Sort == nil || *findFilter.Sort == "" { - return nil - } + models.EnsureFindFilterSorted(findFilter) sort := findFilter.GetSort("title") // CVE-2024-32231 - ensure sort is in the list of allowed sorts @@ -1099,11 +1099,11 @@ func (qb *SceneStore) setSceneSort(query *queryBuilder, findFilter *models.FindF query.addJoins( join{ table: scenesFilesTable, - onClause: "scenes_files.scene_id = scenes.id", + onClause: "scenes_files.scene_id = scenes.id AND scenes_files.\"primary\" = " + getDBBoolean(true), }, join{ table: fileTable, - onClause: "scenes_files.file_id = files.id", + onClause: "scenes_files.file_id = files.id AND scenes_files.\"primary\" = " + getDBBoolean(true), }, ) } @@ -1131,10 +1131,14 @@ func (qb *SceneStore) setSceneSort(query *queryBuilder, findFilter *models.FindF switch sort { case "movie_scene_number": query.join(groupsScenesTable, "", "scenes.id = groups_scenes.scene_id") - query.sortAndPagination += getSort("scene_index", direction, groupsScenesTable) + add, group := getSort("scene_index", direction, groupsScenesTable) + query.sortAndPagination += add + query.addGroupBy(group...) case "group_scene_number": query.join(groupsScenesTable, "scene_group", "scenes.id = scene_group.scene_id") - query.sortAndPagination += getSort("scene_index", direction, "scene_group") + add, group := getSort("scene_index", direction, "scene_group") + query.sortAndPagination += add + query.addGroupBy(group...) 
case "tag_count": query.sortAndPagination += getCountSort(sceneTable, scenesTagsTable, sceneIDColumn, direction) case "performer_count": @@ -1146,6 +1150,7 @@ func (qb *SceneStore) setSceneSort(query *queryBuilder, findFilter *models.FindF addFileTable() addFolderTable() query.sortAndPagination += fmt.Sprintf(" ORDER BY COALESCE(folders.path, '') || COALESCE(files.basename, '') COLLATE NATURAL_CI %s", direction) + query.addGroupBy("folders.path", "files.basename") case "perceptual_similarity": // special handling for phash addFileTable() @@ -1153,36 +1158,50 @@ func (qb *SceneStore) setSceneSort(query *queryBuilder, findFilter *models.FindF join{ table: fingerprintTable, as: "fingerprints_phash", - onClause: "scenes_files.file_id = fingerprints_phash.file_id AND fingerprints_phash.type = 'phash'", + onClause: "scenes_files.file_id = fingerprints_phash.file_id AND scenes_files.\"primary\" = " + getDBBoolean(true) + " AND fingerprints_phash.type = 'phash'", }, ) query.sortAndPagination += " ORDER BY fingerprints_phash.fingerprint " + direction + ", files.size DESC" + query.addGroupBy("fingerprints_phash.fingerprint", "files.size") case "bitrate": sort = "bit_rate" addVideoFileTable() - query.sortAndPagination += getSort(sort, direction, videoFileTable) + add, group := getSort(sort, direction, videoFileTable) + query.sortAndPagination += add + query.addGroupBy(group...) case "file_mod_time": sort = "mod_time" addFileTable() - query.sortAndPagination += getSort(sort, direction, fileTable) + add, agg := getSort(sort, direction, fileTable) + query.sortAndPagination += add + query.addGroupBy(agg...) case "framerate": sort = "frame_rate" addVideoFileTable() - query.sortAndPagination += getSort(sort, direction, videoFileTable) + add, agg := getSort(sort, direction, videoFileTable) + query.sortAndPagination += add + query.addGroupBy(agg...) 
case "filesize": addFileTable() - query.sortAndPagination += getSort(sort, direction, fileTable) + add, agg := getSort(sort, direction, fileTable) + query.sortAndPagination += add + query.addGroupBy(agg...) case "duration": addVideoFileTable() - query.sortAndPagination += getSort(sort, direction, videoFileTable) + add, agg := getSort(sort, direction, videoFileTable) + query.sortAndPagination += add + query.addGroupBy(agg...) case "interactive", "interactive_speed": addVideoFileTable() - query.sortAndPagination += getSort(sort, direction, videoFileTable) + add, agg := getSort(sort, direction, videoFileTable) + query.sortAndPagination += add + query.addGroupBy(agg...) case "title": addFileTable() addFolderTable() query.sortAndPagination += " ORDER BY COALESCE(scenes.title, files.basename) COLLATE NATURAL_CI " + direction + ", folders.path COLLATE NATURAL_CI " + direction + query.addGroupBy("scenes.title", "files.basename", "folders.path") case "play_count": query.sortAndPagination += getCountSort(sceneTable, scenesViewDatesTable, sceneIDColumn, direction) case "last_played_at": @@ -1192,11 +1211,14 @@ func (qb *SceneStore) setSceneSort(query *queryBuilder, findFilter *models.FindF case "o_counter": query.sortAndPagination += getCountSort(sceneTable, scenesODatesTable, sceneIDColumn, direction) default: - query.sortAndPagination += getSort(sort, direction, "scenes") + add, agg := getSort(sort, direction, "scenes") + query.sortAndPagination += add + query.addGroupBy(agg...) 
} // Whatever the sorting, always use title/id as a final sort - query.sortAndPagination += ", COALESCE(scenes.title, scenes.id) COLLATE NATURAL_CI ASC" + query.sortAndPagination += ", COALESCE(scenes.title, CAST(scenes.id as text)) COLLATE NATURAL_CI ASC" + query.addGroupBy("scenes.title", "scenes.id") return nil } @@ -1332,7 +1354,9 @@ func (qb *SceneStore) FindDuplicates(ctx context.Context, distance int, duration var dupeIds [][]int if distance == 0 { var ids []string - if err := dbWrapper.Select(ctx, &ids, findExactDuplicateQuery, durationDiff); err != nil { + + dbfix_findExactDuplicateQuery := fmt.Sprintf(findExactDuplicateQuery, getDBGroupConcat("DISTINCT scene_id")) + if err := dbWrapper.Select(ctx, &ids, dbfix_findExactDuplicateQuery, durationDiff); err != nil { return nil, err } diff --git a/pkg/sqlite/scene_filter.go b/pkg/sqlite/scene_filter.go index 2e63dad975f..f2fdf6f3821 100644 --- a/pkg/sqlite/scene_filter.go +++ b/pkg/sqlite/scene_filter.go @@ -65,7 +65,7 @@ func (qb *sceneFilterHandler) criterionHandler() criterionHandler { criterionHandlerFunc(func(ctx context.Context, f *filterBuilder) { if sceneFilter.Oshash != nil { qb.addSceneFilesTable(f) - f.addLeftJoin(fingerprintTable, "fingerprints_oshash", "scenes_files.file_id = fingerprints_oshash.file_id AND fingerprints_oshash.type = 'oshash'") + f.addLeftJoin(fingerprintTable, "fingerprints_oshash", "scenes_files.file_id = fingerprints_oshash.file_id AND scenes_files.\"primary\" = "+getDBBoolean(true)+" AND fingerprints_oshash.type = 'oshash'") } stringCriterionHandler(sceneFilter.Oshash, "fingerprints_oshash.fingerprint")(ctx, f) @@ -74,7 +74,7 @@ func (qb *sceneFilterHandler) criterionHandler() criterionHandler { criterionHandlerFunc(func(ctx context.Context, f *filterBuilder) { if sceneFilter.Checksum != nil { qb.addSceneFilesTable(f) - f.addLeftJoin(fingerprintTable, "fingerprints_md5", "scenes_files.file_id = fingerprints_md5.file_id AND fingerprints_md5.type = 'md5'") + 
f.addLeftJoin(fingerprintTable, "fingerprints_md5", "scenes_files.file_id = fingerprints_md5.file_id AND scenes_files.\"primary\" = "+getDBBoolean(true)+" AND fingerprints_md5.type = 'md5'") } stringCriterionHandler(sceneFilter.Checksum, "fingerprints_md5.fingerprint")(ctx, f) @@ -111,7 +111,7 @@ func (qb *sceneFilterHandler) criterionHandler() criterionHandler { criterionHandlerFunc(func(ctx context.Context, f *filterBuilder) { if sceneFilter.StashID != nil { sceneRepository.stashIDs.join(f, "scene_stash_ids", "scenes.id") - stringCriterionHandler(sceneFilter.StashID, "scene_stash_ids.stash_id")(ctx, f) + uuidCriterionHandler(sceneFilter.StashID, "scene_stash_ids.stash_id")(ctx, f) } }), @@ -214,12 +214,12 @@ func (qb *sceneFilterHandler) criterionHandler() criterionHandler { } func (qb *sceneFilterHandler) addSceneFilesTable(f *filterBuilder) { - f.addLeftJoin(scenesFilesTable, "", "scenes_files.scene_id = scenes.id") + f.addLeftJoin(scenesFilesTable, "", "scenes_files.scene_id = scenes.id AND scenes_files.\"primary\" = "+getDBBoolean(true)) } func (qb *sceneFilterHandler) addFilesTable(f *filterBuilder) { qb.addSceneFilesTable(f) - f.addLeftJoin(fileTable, "", "scenes_files.file_id = files.id") + f.addLeftJoin(fileTable, "", "scenes_files.file_id = files.id AND scenes_files.\"primary\" = "+getDBBoolean(true)) } func (qb *sceneFilterHandler) addFoldersTable(f *filterBuilder) { @@ -229,7 +229,7 @@ func (qb *sceneFilterHandler) addFoldersTable(f *filterBuilder) { func (qb *sceneFilterHandler) addVideoFilesTable(f *filterBuilder) { qb.addSceneFilesTable(f) - f.addLeftJoin(videoFileTable, "", "video_files.file_id = scenes_files.file_id") + f.addLeftJoin(videoFileTable, "", "video_files.file_id = scenes_files.file_id AND scenes_files.\"primary\" = "+getDBBoolean(true)) } func (qb *sceneFilterHandler) playCountCriterionHandler(count *models.IntCriterionInput) criterionHandlerFunc { @@ -277,7 +277,7 @@ func (qb *sceneFilterHandler) 
phashDuplicatedCriterionHandler(duplicatedFilter * v = "=" } - f.addInnerJoin("(SELECT file_id FROM files_fingerprints INNER JOIN (SELECT fingerprint FROM files_fingerprints WHERE type = 'phash' GROUP BY fingerprint HAVING COUNT (fingerprint) "+v+" 1) dupes on files_fingerprints.fingerprint = dupes.fingerprint)", "scph", "scenes_files.file_id = scph.file_id") + f.addInnerJoin("(SELECT file_id FROM files_fingerprints INNER JOIN (SELECT fingerprint FROM files_fingerprints WHERE type = 'phash' GROUP BY fingerprint HAVING COUNT (fingerprint) "+v+" 1) dupes on files_fingerprints.fingerprint = dupes.fingerprint)", "scph", "scenes_files.file_id = scph.file_id AND scenes_files.\"primary\" = "+getDBBoolean(true)) } } } @@ -326,7 +326,11 @@ func (qb *sceneFilterHandler) isMissingCriterionHandler(isMissing *string) crite sceneRepository.performers.join(f, "performers_join", "scenes.id") f.addWhere("performers_join.scene_id IS NULL") case "date": - f.addWhere(`scenes.date IS NULL OR scenes.date IS ""`) + q := "scenes.date IS NULL" + if dbWrapper.dbType == SqliteBackend { + q += ` OR scenes.date IS ""` + } + f.addWhere(q) case "tags": sceneRepository.tags.join(f, "tags_join", "scenes.id") f.addWhere("tags_join.scene_id IS NULL") @@ -335,12 +339,12 @@ func (qb *sceneFilterHandler) isMissingCriterionHandler(isMissing *string) crite f.addWhere("scene_stash_ids.scene_id IS NULL") case "phash": qb.addSceneFilesTable(f) - f.addLeftJoin(fingerprintTable, "fingerprints_phash", "scenes_files.file_id = fingerprints_phash.file_id AND fingerprints_phash.type = 'phash'") + f.addLeftJoin(fingerprintTable, "fingerprints_phash", "scenes_files.file_id = fingerprints_phash.file_id AND scenes_files.\"primary\" = "+getDBBoolean(true)+" AND fingerprints_phash.type = 'phash'") f.addWhere("fingerprints_phash.fingerprint IS NULL") case "cover": f.addWhere("scenes.cover_blob IS NULL") default: - f.addWhere("(scenes." + *isMissing + " IS NULL OR TRIM(scenes." 
+ *isMissing + ") = '')") + f.addWhere("(scenes." + *isMissing + " IS NULL OR TRIM(CAST(scenes." + *isMissing + " AS TEXT)) = '')") } } } @@ -385,7 +389,7 @@ func (qb *sceneFilterHandler) captionCriterionHandler(captions *models.StringCri excludeClause := `scenes.id NOT IN ( SELECT scenes_files.scene_id from scenes_files INNER JOIN video_captions on video_captions.file_id = scenes_files.file_id - WHERE video_captions.language_code LIKE ? + WHERE LOWER(video_captions.language_code) LIKE LOWER(?) )` f.addWhere(excludeClause, criterion.Value) @@ -455,12 +459,12 @@ func (qb *sceneFilterHandler) performerFavoriteCriterionHandler(performerfavorit if *performerfavorite { // contains at least one favorite f.addLeftJoin("performers", "", "performers.id = performers_scenes.performer_id") - f.addWhere("performers.favorite = 1") + f.addWhere("performers.favorite = " + getDBBoolean(true)) } else { // contains zero favorites f.addLeftJoin(`(SELECT performers_scenes.scene_id as id FROM performers_scenes JOIN performers ON performers.id = performers_scenes.performer_id -GROUP BY performers_scenes.scene_id HAVING SUM(performers.favorite) = 0)`, "nofaves", "scenes.id = nofaves.id") +GROUP BY performers_scenes.scene_id HAVING SUM(performers.favorite) = `+getDBBoolean(false)+")", "nofaves", "scenes.id = nofaves.id") f.addWhere("performers_scenes.scene_id IS NULL OR nofaves.id IS NOT NULL") } } @@ -476,7 +480,15 @@ func (qb *sceneFilterHandler) performerAgeCriterionHandler(performerAge *models. 
f.addWhere("scenes.date != '' AND performers.birthdate != ''") f.addWhere("scenes.date IS NOT NULL AND performers.birthdate IS NOT NULL") - ageCalc := "cast(strftime('%Y.%m%d', scenes.date) - strftime('%Y.%m%d', performers.birthdate) as int)" + var ageCalc string + + switch dbWrapper.dbType { + case PostgresBackend: + ageCalc = "EXTRACT(YEAR FROM AGE(scenes.date, performers.birthdate))" + case SqliteBackend: + ageCalc = "cast(strftime('%Y.%m%d', scenes.date) - strftime('%Y.%m%d', performers.birthdate) as int)" + } + whereClause, args := getIntWhereClause(ageCalc, performerAge.Modifier, performerAge.Value, performerAge.Value2) f.addWhere(whereClause, args...) } @@ -532,7 +544,7 @@ func (qb *sceneFilterHandler) phashDistanceCriterionHandler(phashDistance *model return func(ctx context.Context, f *filterBuilder) { if phashDistance != nil { qb.addSceneFilesTable(f) - f.addLeftJoin(fingerprintTable, "fingerprints_phash", "scenes_files.file_id = fingerprints_phash.file_id AND fingerprints_phash.type = 'phash'") + f.addLeftJoin(fingerprintTable, "fingerprints_phash", "scenes_files.file_id = fingerprints_phash.file_id AND scenes_files.\"primary\" = "+getDBBoolean(true)+" AND fingerprints_phash.type = 'phash'") value, _ := utils.StringToPhash(phashDistance.Value) distance := 0 @@ -545,23 +557,19 @@ func (qb *sceneFilterHandler) phashDistanceCriterionHandler(phashDistance *model intCriterionHandler(&models.IntCriterionInput{ Value: int(value), Modifier: phashDistance.Modifier, - }, "fingerprints_phash.fingerprint", nil)(ctx, f) + }, "CAST(fingerprints_phash.fingerprint AS bigint)", nil)(ctx, f) } switch { case phashDistance.Modifier == models.CriterionModifierEquals && distance > 0: - // needed to avoid a type mismatch - f.addWhere("typeof(fingerprints_phash.fingerprint) = 'integer'") - f.addWhere("phash_distance(fingerprints_phash.fingerprint, ?) < ?", value, distance) + f.addWhere("phash_distance(CAST(fingerprints_phash.fingerprint AS bigint), ?) 
< ?", value, distance) case phashDistance.Modifier == models.CriterionModifierNotEquals && distance > 0: - // needed to avoid a type mismatch - f.addWhere("typeof(fingerprints_phash.fingerprint) = 'integer'") - f.addWhere("phash_distance(fingerprints_phash.fingerprint, ?) > ?", value, distance) + f.addWhere("phash_distance(CAST(fingerprints_phash.fingerprint AS bigint), ?) > ?", value, distance) default: intCriterionHandler(&models.IntCriterionInput{ Value: int(value), Modifier: phashDistance.Modifier, - }, "fingerprints_phash.fingerprint", nil)(ctx, f) + }, "CAST(fingerprints_phash.fingerprint AS bigint)", nil)(ctx, f) } } } diff --git a/pkg/sqlite/scene_marker.go b/pkg/sqlite/scene_marker.go index 8b2306eab4b..2cdc3234f85 100644 --- a/pkg/sqlite/scene_marker.go +++ b/pkg/sqlite/scene_marker.go @@ -273,7 +273,7 @@ func (qb *SceneMarkerStore) CountByTagID(ctx context.Context, tagID int) (int, e func (qb *SceneMarkerStore) GetMarkerStrings(ctx context.Context, q *string, sort *string) ([]*models.MarkerStringsResultType, error) { query := "SELECT count(*) as `count`, scene_markers.id as id, scene_markers.title as title FROM scene_markers" if q != nil { - query += " WHERE title LIKE '%" + *q + "%'" + query += " WHERE LOWER(title) LIKE LOWER('%" + *q + "%')" } query += " GROUP BY title" if sort != nil && *sort == "count" { @@ -292,7 +292,7 @@ func (qb *SceneMarkerStore) Wall(ctx context.Context, q *string) ([]*models.Scen } table := qb.table() - qq := qb.selectDataset().Prepared(true).Where(table.Col("title").Like("%" + s + "%")).Order(goqu.L("RANDOM()").Asc()).Limit(80) + qq := qb.selectDataset().Prepared(true).Where(table.Col("title").ILike("%" + s + "%")).Order(goqu.L("RANDOM()").Asc()).Limit(80) return qb.getMany(ctx, qq) } @@ -382,12 +382,17 @@ func (qb *SceneMarkerStore) setSceneMarkerSort(query *queryBuilder, findFilter * case "scenes_updated_at": sort = "updated_at" query.join(sceneTable, "", "scenes.id = scene_markers.scene_id") - query.sortAndPagination += 
getSort(sort, direction, sceneTable) + add, agg := getSort(sort, direction, sceneTable) + query.sortAndPagination += add + query.addGroupBy(agg...) case "title": query.join(tagTable, "", "scene_markers.primary_tag_id = tags.id") query.sortAndPagination += " ORDER BY COALESCE(NULLIF(scene_markers.title,''), tags.name) COLLATE NATURAL_CI " + direction + query.addGroupBy("scene_markers.title", "tags.name") default: - query.sortAndPagination += getSort(sort, direction, sceneMarkerTable) + add, agg := getSort(sort, direction, sceneMarkerTable) + query.sortAndPagination += add + query.addGroupBy(agg...) } query.sortAndPagination += ", scene_markers.scene_id ASC, scene_markers.seconds ASC" diff --git a/pkg/sqlite/scene_marker_filter.go b/pkg/sqlite/scene_marker_filter.go index d5e044e85a7..1bee6b83489 100644 --- a/pkg/sqlite/scene_marker_filter.go +++ b/pkg/sqlite/scene_marker_filter.go @@ -95,7 +95,7 @@ func (qb *sceneMarkerFilterHandler) tagsCriterionHandler(criterion *models.Hiera } if len(tags.Value) > 0 { - valuesClause, err := getHierarchicalValues(ctx, tags.Value, tagTable, "tags_relations", "parent_id", "child_id", tags.Depth) + valuesClause, err := getHierarchicalValues(ctx, tags.Value, tagTable, "tags_relations", "parent_id", "child_id", tags.Depth, false) if err != nil { f.setError(err) return @@ -116,7 +116,7 @@ func (qb *sceneMarkerFilterHandler) tagsCriterionHandler(criterion *models.Hiera // includes only the provided ids f.addWhere("marker_tags.root_tag_id IS NOT NULL") tagsLen := len(tags.Value) - f.addHaving(fmt.Sprintf("count(distinct marker_tags.root_tag_id) IS %d", tagsLen)) + f.addHaving(fmt.Sprintf("count(distinct marker_tags.root_tag_id) = %d", tagsLen)) // decrement by one to account for primary tag id f.addWhere("(SELECT COUNT(*) FROM scene_markers_tags s WHERE s.scene_marker_id = scene_markers.id) = ?", tagsLen-1) case models.CriterionModifierNotEquals: @@ -127,16 +127,16 @@ func (qb *sceneMarkerFilterHandler) tagsCriterionHandler(criterion 
*models.Hiera } if len(criterion.Excludes) > 0 { - valuesClause, err := getHierarchicalValues(ctx, tags.Excludes, tagTable, "tags_relations", "parent_id", "child_id", tags.Depth) + valuesClause, err := getHierarchicalValues(ctx, tags.Excludes, tagTable, "tags_relations", "parent_id", "child_id", tags.Depth, true) if err != nil { f.setError(err) return } - clause := "scene_markers.id NOT IN (SELECT scene_markers_tags.scene_marker_id FROM scene_markers_tags WHERE scene_markers_tags.tag_id IN (SELECT column2 FROM (%s)))" + clause := "scene_markers.id NOT IN (SELECT scene_markers_tags.scene_marker_id FROM scene_markers_tags WHERE scene_markers_tags.tag_id IN (SELECT column2 FROM %s))" f.addWhere(fmt.Sprintf(clause, valuesClause)) - f.addWhere(fmt.Sprintf("scene_markers.primary_tag_id NOT IN (SELECT column2 FROM (%s))", valuesClause)) + f.addWhere(fmt.Sprintf("scene_markers.primary_tag_id NOT IN (SELECT column2 FROM %s)", valuesClause)) } } } diff --git a/pkg/sqlite/scene_marker_test.go b/pkg/sqlite/scene_marker_test.go index ce8f4d3ad6b..0b0e3adfc81 100644 --- a/pkg/sqlite/scene_marker_test.go +++ b/pkg/sqlite/scene_marker_test.go @@ -16,7 +16,7 @@ import ( func TestMarkerFindBySceneID(t *testing.T) { withTxn(func(ctx context.Context) error { - mqb := db.SceneMarker + mqb := db.GetRepo().SceneMarker sceneID := sceneIDs[sceneIdxWithMarkers] markers, err := mqb.FindBySceneID(ctx, sceneID) @@ -44,7 +44,7 @@ func TestMarkerFindBySceneID(t *testing.T) { func TestMarkerCountByTagID(t *testing.T) { withTxn(func(ctx context.Context) error { - mqb := db.SceneMarker + mqb := db.GetRepo().SceneMarker markerCount, err := mqb.CountByTagID(ctx, tagIDs[tagIdxWithPrimaryMarkers]) @@ -77,7 +77,7 @@ func TestMarkerCountByTagID(t *testing.T) { func TestMarkerQueryQ(t *testing.T) { withTxn(func(ctx context.Context) error { q := getSceneTitle(sceneIdxWithMarkers) - m, _, err := db.SceneMarker.Query(ctx, nil, &models.FindFilterType{ + m, _, err := db.GetRepo().SceneMarker.Query(ctx, nil, 
&models.FindFilterType{ Q: &q, }) @@ -98,7 +98,7 @@ func TestMarkerQueryQ(t *testing.T) { func TestMarkerQuerySortBySceneUpdated(t *testing.T) { withTxn(func(ctx context.Context) error { sort := "scenes_updated_at" - _, _, err := db.SceneMarker.Query(ctx, nil, &models.FindFilterType{ + _, _, err := db.GetRepo().SceneMarker.Query(ctx, nil, &models.FindFilterType{ Sort: &sort, }) @@ -153,7 +153,7 @@ func TestMarkerQueryTags(t *testing.T) { withTxn(func(ctx context.Context) error { testTags := func(t *testing.T, m *models.SceneMarker, markerFilter *models.SceneMarkerFilterType) { - tagIDs, err := db.SceneMarker.GetTagIDs(ctx, m.ID) + tagIDs, err := db.GetRepo().SceneMarker.GetTagIDs(ctx, m.ID) if err != nil { t.Errorf("error getting marker tag ids: %v", err) } @@ -255,7 +255,7 @@ func TestMarkerQueryTags(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - markers := queryMarkers(ctx, t, db.SceneMarker, tc.markerFilter, tc.findFilter) + markers := queryMarkers(ctx, t, db.GetRepo().SceneMarker, tc.markerFilter, tc.findFilter) assert.Greater(t, len(markers), 0) for _, m := range markers { testTags(t, m, tc.markerFilter) @@ -276,13 +276,13 @@ func TestMarkerQuerySceneTags(t *testing.T) { withTxn(func(ctx context.Context) error { testTags := func(t *testing.T, m *models.SceneMarker, markerFilter *models.SceneMarkerFilterType) { - s, err := db.Scene.Find(ctx, m.SceneID) + s, err := db.GetRepo().Scene.Find(ctx, m.SceneID) if err != nil { t.Errorf("error getting marker tag ids: %v", err) return } - if err := s.LoadTagIDs(ctx, db.Scene); err != nil { + if err := s.LoadTagIDs(ctx, db.GetRepo().Scene); err != nil { t.Errorf("error getting marker tag ids: %v", err) return } @@ -379,7 +379,7 @@ func TestMarkerQuerySceneTags(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - markers := queryMarkers(ctx, t, db.SceneMarker, tc.markerFilter, tc.findFilter) + markers := queryMarkers(ctx, t, db.GetRepo().SceneMarker, 
tc.markerFilter, tc.findFilter) assert.Greater(t, len(markers), 0) for _, m := range markers { testTags(t, m, tc.markerFilter) diff --git a/pkg/sqlite/scene_test.go b/pkg/sqlite/scene_test.go index a3174d7278d..c316c32870a 100644 --- a/pkg/sqlite/scene_test.go +++ b/pkg/sqlite/scene_test.go @@ -22,38 +22,38 @@ import ( func loadSceneRelationships(ctx context.Context, expected models.Scene, actual *models.Scene) error { if expected.URLs.Loaded() { - if err := actual.LoadURLs(ctx, db.Scene); err != nil { + if err := actual.LoadURLs(ctx, db.GetRepo().Scene); err != nil { return err } } if expected.GalleryIDs.Loaded() { - if err := actual.LoadGalleryIDs(ctx, db.Scene); err != nil { + if err := actual.LoadGalleryIDs(ctx, db.GetRepo().Scene); err != nil { return err } } if expected.TagIDs.Loaded() { - if err := actual.LoadTagIDs(ctx, db.Scene); err != nil { + if err := actual.LoadTagIDs(ctx, db.GetRepo().Scene); err != nil { return err } } if expected.PerformerIDs.Loaded() { - if err := actual.LoadPerformerIDs(ctx, db.Scene); err != nil { + if err := actual.LoadPerformerIDs(ctx, db.GetRepo().Scene); err != nil { return err } } if expected.Groups.Loaded() { - if err := actual.LoadGroups(ctx, db.Scene); err != nil { + if err := actual.LoadGroups(ctx, db.GetRepo().Scene); err != nil { return err } } if expected.StashIDs.Loaded() { - if err := actual.LoadStashIDs(ctx, db.Scene); err != nil { + if err := actual.LoadStashIDs(ctx, db.GetRepo().Scene); err != nil { return err } } if expected.Files.Loaded() { - if err := actual.LoadFiles(ctx, db.Scene); err != nil { + if err := actual.LoadFiles(ctx, db.GetRepo().Scene); err != nil { return err } } @@ -75,6 +75,13 @@ func loadSceneRelationships(ctx context.Context, expected models.Scene, actual * return nil } +func sortScene(copy *models.Scene) { + // Ordering is not ensured + copy.GalleryIDs.Sort() + copy.TagIDs.Sort() + copy.PerformerIDs.Sort() +} + func Test_sceneQueryBuilder_Create(t *testing.T) { var ( title = "title" @@ 
-91,8 +98,8 @@ func Test_sceneQueryBuilder_Create(t *testing.T) { sceneIndex2 = 234 endpoint1 = "endpoint1" endpoint2 = "endpoint2" - stashID1 = "stashid1" - stashID2 = "stashid2" + stashID1 = getUUID("stashid1") + stashID2 = getUUID("stashid2") date, _ = models.ParseDate("2003-02-01") @@ -233,7 +240,7 @@ func Test_sceneQueryBuilder_Create(t *testing.T) { }, } - qb := db.Scene + qb := db.GetRepo().Scene for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -267,6 +274,8 @@ func Test_sceneQueryBuilder_Create(t *testing.T) { return } + sortScene(©) + sortScene(&s) assert.Equal(copy, s) // ensure can find the scene @@ -284,6 +293,7 @@ func Test_sceneQueryBuilder_Create(t *testing.T) { t.Errorf("loadSceneRelationships() error = %v", err) return } + sortScene(found) assert.Equal(copy, *found) return @@ -321,8 +331,8 @@ func Test_sceneQueryBuilder_Update(t *testing.T) { sceneIndex2 = 234 endpoint1 = "endpoint1" endpoint2 = "endpoint2" - stashID1 = "stashid1" - stashID2 = "stashid2" + stashID1 = getUUID("stashid1") + stashID2 = getUUID("stashid2") date, _ = models.ParseDate("2003-02-01") ) @@ -466,7 +476,7 @@ func Test_sceneQueryBuilder_Update(t *testing.T) { }, } - qb := db.Scene + qb := db.GetRepo().Scene for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) @@ -492,6 +502,8 @@ func Test_sceneQueryBuilder_Update(t *testing.T) { return } + sortScene(©) + sortScene(s) assert.Equal(copy, *s) }) } @@ -531,8 +543,8 @@ func Test_sceneQueryBuilder_UpdatePartial(t *testing.T) { sceneIndex2 = 234 endpoint1 = "endpoint1" endpoint2 = "endpoint2" - stashID1 = "stashid1" - stashID2 = "stashid2" + stashID1 = getUUID("stashid1") + stashID2 = getUUID("stashid2") date, _ = models.ParseDate("2003-02-01") ) @@ -675,7 +687,7 @@ func Test_sceneQueryBuilder_UpdatePartial(t *testing.T) { }, } for _, tt := range tests { - qb := db.Scene + qb := db.GetRepo().Scene 
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) @@ -699,6 +711,8 @@ func Test_sceneQueryBuilder_UpdatePartial(t *testing.T) { // ignore file ids clearSceneFileIDs(got) + sortScene(&tt.want) + sortScene(got) assert.Equal(tt.want, *got) s, err := qb.Find(ctx, tt.id) @@ -714,6 +728,7 @@ func Test_sceneQueryBuilder_UpdatePartial(t *testing.T) { // ignore file ids clearSceneFileIDs(s) + sortScene(s) assert.Equal(tt.want, *s) }) } @@ -725,8 +740,8 @@ func Test_sceneQueryBuilder_UpdatePartialRelationships(t *testing.T) { sceneIndex2 = 234 endpoint1 = "endpoint1" endpoint2 = "endpoint2" - stashID1 = "stashid1" - stashID2 = "stashid2" + stashID1 = getUUID("stashid1") + stashID2 = getUUID("stashid2") groupScenes = []models.GroupsScenes{ { @@ -1215,7 +1230,7 @@ func Test_sceneQueryBuilder_UpdatePartialRelationships(t *testing.T) { } for _, tt := range tests { - qb := db.Scene + qb := db.GetRepo().Scene runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) @@ -1291,7 +1306,7 @@ func Test_sceneQueryBuilder_AddO(t *testing.T) { }, } - qb := db.Scene + qb := db.GetRepo().Scene for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1328,7 +1343,7 @@ func Test_sceneQueryBuilder_DeleteO(t *testing.T) { }, } - qb := db.Scene + qb := db.GetRepo().Scene for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1365,7 +1380,7 @@ func Test_sceneQueryBuilder_ResetO(t *testing.T) { }, } - qb := db.Scene + qb := db.GetRepo().Scene for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1403,7 +1418,7 @@ func Test_sceneQueryBuilder_Destroy(t *testing.T) { }, } - qb := db.Scene + qb := db.GetRepo().Scene for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1475,7 +1490,7 @@ func Test_sceneQueryBuilder_Find(t 
*testing.T) { }, } - qb := db.Scene + qb := db.GetRepo().Scene for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1546,7 +1561,7 @@ func Test_sceneQueryBuilder_FindMany(t *testing.T) { }, } - qb := db.Scene + qb := db.GetRepo().Scene for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1616,7 +1631,7 @@ func Test_sceneQueryBuilder_FindByChecksum(t *testing.T) { }, } - qb := db.Scene + qb := db.GetRepo().Scene for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1686,7 +1701,7 @@ func Test_sceneQueryBuilder_FindByOSHash(t *testing.T) { }, } - qb := db.Scene + qb := db.GetRepo().Scene for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1757,7 +1772,7 @@ func Test_sceneQueryBuilder_FindByPath(t *testing.T) { }, } - qb := db.Scene + qb := db.GetRepo().Scene for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1799,7 +1814,7 @@ func Test_sceneQueryBuilder_FindByGalleryID(t *testing.T) { }, } - qb := db.Scene + qb := db.GetRepo().Scene for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1823,7 +1838,7 @@ func Test_sceneQueryBuilder_FindByGalleryID(t *testing.T) { func TestSceneCountByPerformerID(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene count, err := sqb.CountByPerformerID(ctx, performerIDs[performerIdxWithScene]) if err != nil { @@ -1874,7 +1889,7 @@ func Test_sceneStore_FindByFileID(t *testing.T) { }, } - qb := db.Scene + qb := db.GetRepo().Scene for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1920,7 +1935,7 @@ func Test_sceneStore_CountByFileID(t *testing.T) { }, } - qb := db.Scene + qb := db.GetRepo().Scene for _, tt := range tests { 
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1947,7 +1962,7 @@ func Test_sceneStore_CountMissingChecksum(t *testing.T) { }, } - qb := db.Scene + qb := db.GetRepo().Scene for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1974,7 +1989,7 @@ func Test_sceneStore_CountMissingOshash(t *testing.T) { }, } - qb := db.Scene + qb := db.GetRepo().Scene for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -1992,7 +2007,7 @@ func Test_sceneStore_CountMissingOshash(t *testing.T) { func TestSceneWall(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene const sceneIdx = 2 wallQuery := getSceneStringValue(sceneIdx, "Details") @@ -2029,7 +2044,7 @@ func TestSceneQueryQ(t *testing.T) { q := getSceneStringValue(sceneIdx, titleField) withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene sceneQueryQ(ctx, t, sqb, q, sceneIdx) @@ -2211,7 +2226,7 @@ func TestSceneQuery(t *testing.T) { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) - results, err := db.Scene.Query(ctx, models.SceneQueryOptions{ + results, err := db.GetRepo().Scene.Query(ctx, models.SceneQueryOptions{ SceneFilter: tt.filter, QueryOptions: models.QueryOptions{ FindFilter: tt.findFilter, @@ -2324,7 +2339,7 @@ func TestSceneQueryPath(t *testing.T) { }, } - qb := db.Scene + qb := db.GetRepo().Scene for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -2423,7 +2438,7 @@ func TestSceneQueryPathOr(t *testing.T) { } withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene scenes := queryScene(ctx, t, sqb, &sceneFilter, nil) @@ -2458,7 +2473,7 @@ func TestSceneQueryPathAndRating(t *testing.T) { } withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene scenes 
:= queryScene(ctx, t, sqb, &sceneFilter, nil) @@ -2497,7 +2512,7 @@ func TestSceneQueryPathNotRating(t *testing.T) { } withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene scenes := queryScene(ctx, t, sqb, &sceneFilter, nil) @@ -2530,7 +2545,7 @@ func TestSceneIllegalQuery(t *testing.T) { } withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene queryOptions := models.SceneQueryOptions{ SceneFilter: sceneFilter, @@ -2557,7 +2572,7 @@ func verifySceneQuery(t *testing.T, filter models.SceneFilterType, verifyFn func t.Helper() withTxn(func(ctx context.Context) error { t.Helper() - sqb := db.Scene + sqb := db.GetRepo().Scene scenes := queryScene(ctx, t, sqb, &filter, nil) @@ -2580,7 +2595,7 @@ func verifySceneQuery(t *testing.T, filter models.SceneFilterType, verifyFn func func verifyScenesPath(t *testing.T, pathCriterion models.StringCriterionInput) { withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene sceneFilter := models.SceneFilterType{ Path: &pathCriterion, } @@ -2674,7 +2689,7 @@ func TestSceneQueryRating100(t *testing.T) { func verifyScenesRating100(t *testing.T, ratingCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene sceneFilter := models.SceneFilterType{ Rating100: &ratingCriterion, } @@ -2733,7 +2748,7 @@ func TestSceneQueryOCounter(t *testing.T) { func verifyScenesOCounter(t *testing.T, oCounterCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene sceneFilter := models.SceneFilterType{ OCounter: &oCounterCriterion, } @@ -2798,7 +2813,7 @@ func TestSceneQueryDuration(t *testing.T) { func verifyScenesDuration(t *testing.T, durationCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene sceneFilter := models.SceneFilterType{ Duration: 
&durationCriterion, } @@ -2806,7 +2821,7 @@ func verifyScenesDuration(t *testing.T, durationCriterion models.IntCriterionInp scenes := queryScene(ctx, t, sqb, &sceneFilter, nil) for _, scene := range scenes { - if err := scene.LoadPrimaryFile(ctx, db.File); err != nil { + if err := scene.LoadPrimaryFile(ctx, db.GetRepo().File); err != nil { t.Errorf("Error querying scene files: %v", err) return nil } @@ -2870,7 +2885,7 @@ func TestSceneQueryResolution(t *testing.T) { func verifyScenesResolution(t *testing.T, resolution models.ResolutionEnum) { withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene sceneFilter := models.SceneFilterType{ Resolution: &models.ResolutionCriterionInput{ Value: resolution, @@ -2881,7 +2896,7 @@ func verifyScenesResolution(t *testing.T, resolution models.ResolutionEnum) { scenes := queryScene(ctx, t, sqb, &sceneFilter, nil) for _, scene := range scenes { - if err := scene.LoadPrimaryFile(ctx, db.File); err != nil { + if err := scene.LoadPrimaryFile(ctx, db.GetRepo().File); err != nil { t.Errorf("Error querying scene files: %v", err) return nil } @@ -2933,7 +2948,7 @@ func TestAllResolutionsHaveResolutionRange(t *testing.T) { func TestSceneQueryResolutionModifiers(t *testing.T) { if err := withRollbackTxn(func(ctx context.Context) error { - qb := db.Scene + qb := db.GetRepo().Scene sceneNoResolution, _ := createScene(ctx, 0, 0) firstScene540P, _ := createScene(ctx, 960, 540) secondScene540P, _ := createScene(ctx, 1280, 719) @@ -2994,13 +3009,13 @@ func createScene(ctx context.Context, width int, height int) (*models.Scene, err Height: height, } - if err := db.File.Create(ctx, sceneFile); err != nil { + if err := db.GetRepo().File.Create(ctx, sceneFile); err != nil { return nil, err } scene := &models.Scene{} - if err := db.Scene.Create(ctx, scene, []models.FileID{sceneFile.ID}); err != nil { + if err := db.GetRepo().Scene.Create(ctx, scene, []models.FileID{sceneFile.ID}); err != nil { return nil, err } @@ 
-3009,7 +3024,7 @@ func createScene(ctx context.Context, width int, height int) (*models.Scene, err func TestSceneQueryHasMarkers(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene hasMarkers := "true" sceneFilter := models.SceneFilterType{ HasMarkers: &hasMarkers, @@ -3045,7 +3060,7 @@ func TestSceneQueryHasMarkers(t *testing.T) { func TestSceneQueryIsMissingGallery(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene isMissing := "galleries" sceneFilter := models.SceneFilterType{ IsMissing: &isMissing, @@ -3074,7 +3089,7 @@ func TestSceneQueryIsMissingGallery(t *testing.T) { func TestSceneQueryIsMissingStudio(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene isMissing := "studio" sceneFilter := models.SceneFilterType{ IsMissing: &isMissing, @@ -3103,7 +3118,7 @@ func TestSceneQueryIsMissingStudio(t *testing.T) { func TestSceneQueryIsMissingMovies(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene isMissing := "movie" sceneFilter := models.SceneFilterType{ IsMissing: &isMissing, @@ -3132,7 +3147,7 @@ func TestSceneQueryIsMissingMovies(t *testing.T) { func TestSceneQueryIsMissingPerformers(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene isMissing := "performers" sceneFilter := models.SceneFilterType{ IsMissing: &isMissing, @@ -3163,7 +3178,7 @@ func TestSceneQueryIsMissingPerformers(t *testing.T) { func TestSceneQueryIsMissingDate(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene isMissing := "date" sceneFilter := models.SceneFilterType{ IsMissing: &isMissing, @@ -3185,7 +3200,7 @@ func TestSceneQueryIsMissingDate(t *testing.T) { func TestSceneQueryIsMissingTags(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := 
db.GetRepo().Scene isMissing := "tags" sceneFilter := models.SceneFilterType{ IsMissing: &isMissing, @@ -3211,7 +3226,7 @@ func TestSceneQueryIsMissingTags(t *testing.T) { func TestSceneQueryIsMissingRating(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene isMissing := "rating" sceneFilter := models.SceneFilterType{ IsMissing: &isMissing, @@ -3232,7 +3247,7 @@ func TestSceneQueryIsMissingRating(t *testing.T) { func TestSceneQueryIsMissingPhash(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene isMissing := "phash" sceneFilter := models.SceneFilterType{ IsMissing: &isMissing, @@ -3363,7 +3378,7 @@ func TestSceneQueryPerformers(t *testing.T) { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) - results, err := db.Scene.Query(ctx, models.SceneQueryOptions{ + results, err := db.GetRepo().Scene.Query(ctx, models.SceneQueryOptions{ SceneFilter: &models.SceneFilterType{ Performers: &tt.filter, }, @@ -3499,7 +3514,7 @@ func TestSceneQueryTags(t *testing.T) { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) - results, err := db.Scene.Query(ctx, models.SceneQueryOptions{ + results, err := db.GetRepo().Scene.Query(ctx, models.SceneQueryOptions{ SceneFilter: &models.SceneFilterType{ Tags: &tt.filter, }, @@ -3696,7 +3711,7 @@ func TestSceneQueryPerformerTags(t *testing.T) { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { assert := assert.New(t) - results, err := db.Scene.Query(ctx, models.SceneQueryOptions{ + results, err := db.GetRepo().Scene.Query(ctx, models.SceneQueryOptions{ SceneFilter: tt.filter, QueryOptions: models.QueryOptions{ FindFilter: tt.findFilter, @@ -3790,7 +3805,7 @@ func TestSceneQueryStudio(t *testing.T) { }, } - qb := db.Scene + qb := db.GetRepo().Scene for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx 
context.Context) { @@ -3816,7 +3831,7 @@ func TestSceneQueryStudio(t *testing.T) { func TestSceneQueryStudioDepth(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene depth := 2 studioCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ @@ -3945,7 +3960,7 @@ func TestSceneGroups(t *testing.T) { findFilter.Q = &tt.q } - results, err := db.Scene.Query(ctx, models.SceneQueryOptions{ + results, err := db.GetRepo().Scene.Query(ctx, models.SceneQueryOptions{ SceneFilter: sceneFilter, QueryOptions: models.QueryOptions{ FindFilter: findFilter, @@ -3970,7 +3985,7 @@ func TestSceneGroups(t *testing.T) { func TestSceneQueryMovies(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene movieCriterion := models.MultiCriterionInput{ Value: []string{ strconv.Itoa(groupIDs[groupIdxWithScene]), @@ -4010,7 +4025,7 @@ func TestSceneQueryMovies(t *testing.T) { func TestSceneQueryPhashDuplicated(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene duplicated := true phashCriterion := models.PHashDuplicationCriterionInput{ Duplicated: &duplicated, @@ -4121,7 +4136,7 @@ func TestSceneQuerySorting(t *testing.T) { }, } - qb := db.Scene + qb := db.GetRepo().Scene for _, tt := range tests { runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) { @@ -4173,7 +4188,7 @@ func TestSceneQueryPagination(t *testing.T) { } withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene scenes := queryScene(ctx, t, sqb, nil, &findFilter) assert.Len(t, scenes, 1) @@ -4221,7 +4236,7 @@ func TestSceneQueryTagCount(t *testing.T) { func verifyScenesTagCount(t *testing.T, tagCountCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene sceneFilter := models.SceneFilterType{ TagCount: &tagCountCriterion, } @@ -4262,7 +4277,7 @@ func 
TestSceneQueryPerformerCount(t *testing.T) { func verifyScenesPerformerCount(t *testing.T, performerCountCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene sceneFilter := models.SceneFilterType{ PerformerCount: &performerCountCriterion, } @@ -4285,7 +4300,7 @@ func verifyScenesPerformerCount(t *testing.T, performerCountCriterion models.Int func TestFindByMovieID(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene scenes, err := sqb.FindByGroupID(ctx, groupIDs[groupIdxWithScene]) @@ -4310,7 +4325,7 @@ func TestFindByMovieID(t *testing.T) { func TestFindByPerformerID(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Scene + sqb := db.GetRepo().Scene scenes, err := sqb.FindByPerformerID(ctx, performerIDs[performerIdxWithScene]) @@ -4335,7 +4350,7 @@ func TestFindByPerformerID(t *testing.T) { func TestSceneUpdateSceneCover(t *testing.T) { if err := withTxn(func(ctx context.Context) error { - qb := db.Scene + qb := db.GetRepo().Scene sceneID := sceneIDs[sceneIdxWithGallery] @@ -4347,7 +4362,7 @@ func TestSceneUpdateSceneCover(t *testing.T) { func TestSceneStashIDs(t *testing.T) { if err := withTxn(func(ctx context.Context) error { - qb := db.Scene + qb := db.GetRepo().Scene // create scene to test against const name = "TestSceneStashIDs" @@ -4374,14 +4389,14 @@ func testSceneStashIDs(ctx context.Context, t *testing.T, s *models.Scene) { assert.Len(t, s.StashIDs.List(), 0) // add stash ids - const stashIDStr = "stashID" + var stashIDStr = getUUID("stashID") const endpoint = "endpoint" stashID := models.StashID{ StashID: stashIDStr, Endpoint: endpoint, } - qb := db.Scene + qb := db.GetRepo().Scene // update stash ids and ensure was updated var err error @@ -4423,7 +4438,7 @@ func testSceneStashIDs(ctx context.Context, t *testing.T, s *models.Scene) { func TestSceneQueryQTrim(t *testing.T) { if err := withTxn(func(ctx 
context.Context) error { - qb := db.Scene + qb := db.GetRepo().Scene expectedID := sceneIDs[sceneIdxWithSpacedName] @@ -4465,7 +4480,7 @@ func TestSceneQueryQTrim(t *testing.T) { } func TestSceneStore_All(t *testing.T) { - qb := db.Scene + qb := db.GetRepo().Scene withRollbackTxn(func(ctx context.Context) error { got, err := qb.All(ctx) @@ -4482,7 +4497,7 @@ func TestSceneStore_All(t *testing.T) { } func TestSceneStore_FindDuplicates(t *testing.T) { - qb := db.Scene + qb := db.GetRepo().Scene withRollbackTxn(func(ctx context.Context) error { distance := 0 @@ -4536,7 +4551,7 @@ func TestSceneStore_AssignFiles(t *testing.T) { }, } - qb := db.Scene + qb := db.GetRepo().Scene for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -4572,7 +4587,7 @@ func TestSceneStore_AddView(t *testing.T) { }, } - qb := db.Scene + qb := db.GetRepo().Scene for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -4663,7 +4678,7 @@ func TestSceneStore_SaveActivity(t *testing.T) { }, } - qb := db.Scene + qb := db.GetRepo().Scene for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -4715,7 +4730,7 @@ func TestSceneStore_SaveActivity(t *testing.T) { // TODO - this should be in history_test and generalised func TestSceneStore_CountAllViews(t *testing.T) { withRollbackTxn(func(ctx context.Context) error { - qb := db.Scene + qb := db.GetRepo().Scene sceneID := sceneIDs[sceneIdx1WithPerformer] @@ -4748,7 +4763,7 @@ func TestSceneStore_CountAllViews(t *testing.T) { func TestSceneStore_CountUniqueViews(t *testing.T) { withRollbackTxn(func(ctx context.Context) error { - qb := db.Scene + qb := db.GetRepo().Scene sceneID := sceneIDs[sceneIdx1WithPerformer] diff --git a/pkg/sqlite/setup_test.go b/pkg/sqlite/setup_test.go index 1c3f914d3ba..56d51be1300 100644 --- a/pkg/sqlite/setup_test.go +++ b/pkg/sqlite/setup_test.go @@ -584,7 +584,7 @@ func indexFromID(ids []int, id int) int { return -1 } -var db *sqlite.Database +var db sqlite.DBInterface func TestMain(m 
*testing.M) { // initialise empty config - needed by some migrations @@ -630,6 +630,19 @@ func testTeardown(databaseFile string) { } } +func getNewDB(databaseFile string) sqlite.DBInterface { + dbUrl, valid := os.LookupEnv("PGSQL_TEST") + if valid { + fmt.Printf("Postgres backend for tests detected\n") + db = sqlite.NewPostgresDatabase(sqlite.NewDatabase(), dbUrl, true) + } else { + fmt.Printf("SQLite backend for tests detected\n") + db = sqlite.NewSQLiteDatabase(sqlite.NewDatabase(), databaseFile, true) + } + + return db +} + func runTests(m *testing.M) int { // create the database file f, err := os.CreateTemp("", "*.sqlite") @@ -639,15 +652,17 @@ func runTests(m *testing.M) int { f.Close() databaseFile := f.Name() - db = sqlite.NewDatabase() + db = getNewDB(databaseFile) + db.SetBlobStoreOptions(sqlite.BlobStoreOptions{ UseDatabase: true, // don't use filesystem }) - if err := db.Open(databaseFile); err != nil { + if err := db.Open(); err != nil { panic(fmt.Sprintf("Could not initialize database: %s", err.Error())) } + db.TestMode() // defer close and delete the database defer testTeardown(databaseFile) @@ -673,11 +688,11 @@ func populateDB() error { // TODO - link folders to zip files - if err := createTags(ctx, db.Tag, tagsNameCase, tagsNameNoCase); err != nil { + if err := createTags(ctx, db.GetRepo().Tag, tagsNameCase, tagsNameNoCase); err != nil { return fmt.Errorf("error creating tags: %s", err.Error()) } - if err := createGroups(ctx, db.Group, groupsNameCase, groupsNameNoCase); err != nil { + if err := createGroups(ctx, db.GetRepo().Group, groupsNameCase, groupsNameNoCase); err != nil { return fmt.Errorf("error creating groups: %s", err.Error()) } @@ -701,15 +716,15 @@ func populateDB() error { return fmt.Errorf("error creating images: %s", err.Error()) } - if err := addTagImage(ctx, db.Tag, tagIdxWithCoverImage); err != nil { + if err := addTagImage(ctx, db.GetRepo().Tag, tagIdxWithCoverImage); err != nil { return fmt.Errorf("error adding tag image: %s", 
err.Error()) } - if err := createSavedFilters(ctx, db.SavedFilter, totalSavedFilters); err != nil { + if err := createSavedFilters(ctx, db.GetRepo().SavedFilter, totalSavedFilters); err != nil { return fmt.Errorf("error creating saved filters: %s", err.Error()) } - if err := linkGroupStudios(ctx, db.Group); err != nil { + if err := linkGroupStudios(ctx, db.GetRepo().Group); err != nil { return fmt.Errorf("error linking group studios: %s", err.Error()) } @@ -717,21 +732,21 @@ func populateDB() error { return fmt.Errorf("error linking studios parent: %s", err.Error()) } - if err := linkTagsParent(ctx, db.Tag); err != nil { + if err := linkTagsParent(ctx, db.GetRepo().Tag); err != nil { return fmt.Errorf("error linking tags parent: %s", err.Error()) } - if err := linkGroupsParent(ctx, db.Group); err != nil { + if err := linkGroupsParent(ctx, db.GetRepo().Group); err != nil { return fmt.Errorf("error linking tags parent: %s", err.Error()) } for _, ms := range markerSpecs { - if err := createMarker(ctx, db.SceneMarker, ms); err != nil { + if err := createMarker(ctx, db.GetRepo().SceneMarker, ms); err != nil { return fmt.Errorf("error creating scene marker: %s", err.Error()) } } for _, cs := range chapterSpecs { - if err := createChapter(ctx, db.GalleryChapter, cs); err != nil { + if err := createChapter(ctx, db.GetRepo().GalleryChapter, cs); err != nil { return fmt.Errorf("error creating gallery chapter: %s", err.Error()) } } @@ -778,7 +793,7 @@ func makeFolder(i int) models.Folder { } func createFolders(ctx context.Context) error { - qb := db.Folder + qb := db.GetRepo().Folder for i := 0; i < totalFolders; i++ { folder := makeFolder(i) @@ -881,7 +896,7 @@ func makeFile(i int) models.File { } func createFiles(ctx context.Context) error { - qb := db.File + qb := db.GetRepo().File for i := 0; i < totalFiles; i++ { file := makeFile(i) @@ -1026,7 +1041,7 @@ func getObjectDate(index int) *models.Date { func sceneStashID(i int) models.StashID { return models.StashID{ - 
StashID: getSceneStringValue(i, "stashid"), + StashID: getUUID("stashid"), Endpoint: getSceneStringValue(i, "endpoint"), } } @@ -1130,8 +1145,8 @@ func makeScene(i int) *models.Scene { } func createScenes(ctx context.Context, n int) error { - sqb := db.Scene - fqb := db.File + sqb := db.GetRepo().Scene + fqb := db.GetRepo().File for i := 0; i < n; i++ { f := makeSceneFile(i) @@ -1219,8 +1234,8 @@ func makeImage(i int) *models.Image { } func createImages(ctx context.Context, n int) error { - qb := db.Image - fqb := db.File + qb := db.GetRepo().Image + fqb := db.GetRepo().File for i := 0; i < n; i++ { f := makeImageFile(i) @@ -1316,8 +1331,8 @@ func makeGallery(i int, includeScenes bool) *models.Gallery { } func createGalleries(ctx context.Context, n int) error { - gqb := db.Gallery - fqb := db.File + gqb := db.GetRepo().Gallery + fqb := db.GetRepo().File for i := 0; i < n; i++ { var fileIDs []models.FileID @@ -1491,9 +1506,14 @@ func getIgnoreAutoTag(index int) bool { return index%5 == 0 } +func getUUID(_ string) string { + // TODO: Encode input string + return "00000000-0000-0000-0000-000000000000" +} + func performerStashID(i int) models.StashID { return models.StashID{ - StashID: getPerformerStringValue(i, "stashid"), + StashID: getUUID("stashid"), Endpoint: getPerformerStringValue(i, "endpoint"), } } @@ -1508,7 +1528,7 @@ func performerAliases(i int) []string { // createPerformers creates n performers with plain Name and o performers with camel cased NaMe included func createPerformers(ctx context.Context, n int, o int) error { - pqb := db.Performer + pqb := db.GetRepo().Performer const namePlain = "Name" const nameNoCase = "NaMe" @@ -1713,7 +1733,7 @@ func getStudioBoolValue(index int) bool { // createStudios creates n studios with plain Name and o studios with camel cased NaMe included func createStudios(ctx context.Context, n int, o int) error { - sqb := db.Studio + sqb := db.GetRepo().Studio const namePlain = "Name" const nameNoCase = "NaMe" @@ -1882,7 
+1902,7 @@ func linkGroupStudios(ctx context.Context, mqb models.GroupWriter) error { } func linkStudiosParent(ctx context.Context) error { - qb := db.Studio + qb := db.GetRepo().Studio return doLinks(studioParentLinks, func(parentIndex, childIndex int) error { input := &models.StudioPartial{ ID: studioIDs[childIndex], diff --git a/pkg/sqlite/sql.go b/pkg/sqlite/sql.go index 780d2e9881b..4e58c22403a 100644 --- a/pkg/sqlite/sql.go +++ b/pkg/sqlite/sql.go @@ -17,12 +17,15 @@ func selectAll(tableName string) string { } func distinctIDs(qb *queryBuilder, tableName string) { - qb.addColumn("DISTINCT " + getColumn(tableName, "id")) + columnId := getColumn(tableName, "id") + qb.addColumn(columnId) + qb.addGroupBy(columnId) qb.from = tableName } func selectIDs(qb *queryBuilder, tableName string) { - qb.addColumn(getColumn(tableName, "id")) + columnId := getColumn(tableName, "id") + qb.addColumn(columnId) qb.from = tableName } @@ -78,17 +81,22 @@ func getSortDirection(direction string) string { return direction } } -func getSort(sort string, direction string, tableName string) string { +func getSort(sort string, direction string, tableName string) (string, []string) { direction = getSortDirection(direction) + nullsfix := "" + if dbWrapper.dbType == PostgresBackend { + nullsfix = " NULLS LAST" + } + switch { case strings.HasSuffix(sort, "_count"): var relationTableName = strings.TrimSuffix(sort, "_count") // TODO: pluralize? 
colName := getColumn(relationTableName, "id") - return " ORDER BY COUNT(distinct " + colName + ") " + direction + return " ORDER BY COUNT(distinct " + colName + ") " + direction, nil case strings.Compare(sort, "filesize") == 0: colName := getColumn(tableName, "size") - return " ORDER BY " + colName + " " + direction + return " ORDER BY " + colName + " " + direction + nullsfix, []string{colName} case strings.HasPrefix(sort, randomSeedPrefix): // seed as a parameter from the UI seedStr := sort[len(randomSeedPrefix):] @@ -97,22 +105,22 @@ func getSort(sort string, direction string, tableName string) string { // fallback to a random seed seed = rand.Uint64() } - return getRandomSort(tableName, direction, seed) + return getRandomSort(tableName, direction, seed), nil case strings.Compare(sort, "random") == 0: - return getRandomSort(tableName, direction, rand.Uint64()) + return getRandomSort(tableName, direction, rand.Uint64()), nil default: colName := getColumn(tableName, sort) if strings.Contains(sort, ".") { colName = sort } if strings.Compare(sort, "name") == 0 { - return " ORDER BY " + colName + " COLLATE NATURAL_CI " + direction + return " ORDER BY " + colName + " COLLATE NATURAL_CI " + direction + nullsfix, []string{colName} } if strings.Compare(sort, "title") == 0 { - return " ORDER BY " + colName + " COLLATE NATURAL_CI " + direction + return " ORDER BY " + colName + " COLLATE NATURAL_CI " + direction + nullsfix, []string{colName} } - return " ORDER BY " + colName + " " + direction + return " ORDER BY " + colName + " " + direction + nullsfix, []string{colName} } } @@ -130,6 +138,11 @@ func getRandomSort(tableName string, direction string, seed uint64) string { // ORDER BY ((n+seed)*(n+seed)*p1 + (n+seed)*p2) % p3 // since sqlite converts overflowing numbers to reals, a custom db function that uses uints with overflow should be faster, // however in practice the overhead of calling a custom function vastly outweighs the benefits + + if dbWrapper.dbType == 
PostgresBackend { + colName = "CAST(" + colName + " AS DECIMAL)" + } + return fmt.Sprintf(" ORDER BY mod((%[1]s + %[2]d) * (%[1]s + %[2]d) * 52959209 + (%[1]s + %[2]d) * 1047483763, 2147483647) %[3]s", colName, seed, direction) } @@ -156,14 +169,14 @@ func getStringSearchClause(columns []string, q string, not bool) sqlClause { // Search for any word for _, word := range queryWords { for _, column := range columns { - likeClauses = append(likeClauses, column+notStr+" LIKE ?") + likeClauses = append(likeClauses, "LOWER("+column+")"+notStr+" LIKE LOWER(?)") args = append(args, "%"+word+"%") } } } else { // Search the exact query for _, column := range columns { - likeClauses = append(likeClauses, column+notStr+" LIKE ?") + likeClauses = append(likeClauses, "LOWER("+column+")"+notStr+" LIKE LOWER(?)") args = append(args, "%"+trimmedQuery+"%") } } @@ -333,10 +346,10 @@ func getMultiCriterionClause(primaryTable, foreignTable, joinTable, primaryFK, f // includes all of the provided ids if joinTable != "" { whereClause = joinTable + "." + foreignFK + " IN " + getInBinding(len(criterion.Value)) - havingClause = "count(distinct " + joinTable + "." + foreignFK + ") IS " + strconv.Itoa(len(criterion.Value)) + havingClause = "count(distinct " + joinTable + "." 
+ foreignFK + ") = " + strconv.Itoa(len(criterion.Value)) } else { whereClause = foreignTable + ".id IN " + getInBinding(len(criterion.Value)) - havingClause = "count(distinct " + foreignTable + ".id) IS " + strconv.Itoa(len(criterion.Value)) + havingClause = "count(distinct " + foreignTable + ".id) = " + strconv.Itoa(len(criterion.Value)) } case models.CriterionModifierExcludes: // excludes all of the provided ids diff --git a/pkg/sqlite/stash_id_test.go b/pkg/sqlite/stash_id_test.go index 10949b47533..bf5e9b30c41 100644 --- a/pkg/sqlite/stash_id_test.go +++ b/pkg/sqlite/stash_id_test.go @@ -24,7 +24,7 @@ func testStashIDReaderWriter(ctx context.Context, t *testing.T, r stashIDReaderW testNoStashIDs(ctx, t, r, -1) // add stash ids - const stashIDStr = "stashID" + var stashIDStr = getUUID("stashID") const endpoint = "endpoint" stashID := models.StashID{ StashID: stashIDStr, diff --git a/pkg/sqlite/studio.go b/pkg/sqlite/studio.go index 6b81109b1fa..9e82bb532bf 100644 --- a/pkg/sqlite/studio.go +++ b/pkg/sqlite/studio.go @@ -495,13 +495,13 @@ func (qb *StudioStore) QueryForAutoTag(ctx context.Context, words []string) ([]* var whereClauses []exp.Expression for _, w := range words { - whereClauses = append(whereClauses, table.Col(studioNameColumn).Like(w+"%")) - whereClauses = append(whereClauses, studiosAliasesJoinTable.Col("alias").Like(w+"%")) + whereClauses = append(whereClauses, table.Col(studioNameColumn).ILike(w+"%")) + whereClauses = append(whereClauses, studiosAliasesJoinTable.Col("alias").ILike(w+"%")) } sq = sq.Where( goqu.Or(whereClauses...), - table.Col("ignore_auto_tag").Eq(0), + table.Col("ignore_auto_tag").IsFalse(), ) ret, err := qb.findBySubquery(ctx, sq) @@ -539,11 +539,13 @@ func (qb *StudioStore) makeQuery(ctx context.Context, studioFilter *models.Studi } var err error - query.sortAndPagination, err = qb.getStudioSort(findFilter) + var group []string + query.sortAndPagination, group, err = qb.getStudioSort(findFilter) if err != nil { return nil, 
err } query.sortAndPagination += getPagination(findFilter) + query.addGroupBy(group...) return &query, nil } @@ -589,7 +591,7 @@ var studioSortOptions = sortOptions{ "updated_at", } -func (qb *StudioStore) getStudioSort(findFilter *models.FindFilterType) (string, error) { +func (qb *StudioStore) getStudioSort(findFilter *models.FindFilterType) (string, []string, error) { var sort string var direction string if findFilter == nil { @@ -602,9 +604,10 @@ func (qb *StudioStore) getStudioSort(findFilter *models.FindFilterType) (string, // CVE-2024-32231 - ensure sort is in the list of allowed sorts if err := studioSortOptions.validateSort(sort); err != nil { - return "", err + return "", nil, err } + group := []string{} sortQuery := "" switch sort { case "tag_count": @@ -618,12 +621,15 @@ func (qb *StudioStore) getStudioSort(findFilter *models.FindFilterType) (string, case "child_count": sortQuery += getCountSort(studioTable, studioTable, studioParentIDColumn, direction) default: - sortQuery += getSort(sort, direction, "studios") + var add string + add, group = getSort(sort, direction, "studios") + sortQuery += add } // Whatever the sorting, always use name/id as a final sort - sortQuery += ", COALESCE(studios.name, studios.id) COLLATE NATURAL_CI ASC" - return sortQuery, nil + sortQuery += ", COALESCE(studios.name, CAST(studios.id as text)) COLLATE NATURAL_CI ASC" + group = append(group, "studios.name", "studios.id") + return sortQuery, group, nil } func (qb *StudioStore) GetImage(ctx context.Context, studioID int) ([]byte, error) { diff --git a/pkg/sqlite/studio_filter.go b/pkg/sqlite/studio_filter.go index c514364c4ff..fdfe7563f09 100644 --- a/pkg/sqlite/studio_filter.go +++ b/pkg/sqlite/studio_filter.go @@ -63,7 +63,7 @@ func (qb *studioFilterHandler) criterionHandler() criterionHandler { criterionHandlerFunc(func(ctx context.Context, f *filterBuilder) { if studioFilter.StashID != nil { studioRepository.stashIDs.join(f, "studio_stash_ids", "studios.id") - 
stringCriterionHandler(studioFilter.StashID, "studio_stash_ids.stash_id")(ctx, f) + uuidCriterionHandler(studioFilter.StashID, "studio_stash_ids.stash_id")(ctx, f) } }), &stashIDCriterionHandler{ @@ -124,7 +124,7 @@ func (qb *studioFilterHandler) isMissingCriterionHandler(isMissing *string) crit studioRepository.stashIDs.join(f, "studio_stash_ids", "studios.id") f.addWhere("studio_stash_ids.studio_id IS NULL") default: - f.addWhere("(studios." + *isMissing + " IS NULL OR TRIM(studios." + *isMissing + ") = '')") + f.addWhere("(studios." + *isMissing + " IS NULL OR TRIM(CAST(studios." + *isMissing + " AS TEXT)) = '')") } } } diff --git a/pkg/sqlite/studio_test.go b/pkg/sqlite/studio_test.go index a61dadc245f..5ddba93ff16 100644 --- a/pkg/sqlite/studio_test.go +++ b/pkg/sqlite/studio_test.go @@ -18,7 +18,7 @@ import ( func TestStudioFindByName(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Studio + sqb := db.GetRepo().Studio name := studioNames[studioIdxWithScene] // find a studio by name @@ -70,7 +70,7 @@ func TestStudioQueryNameOr(t *testing.T) { } withTxn(func(ctx context.Context) error { - sqb := db.Studio + sqb := db.GetRepo().Studio studios := queryStudio(ctx, t, sqb, &studioFilter, nil) @@ -103,7 +103,7 @@ func TestStudioQueryNameAndUrl(t *testing.T) { } withTxn(func(ctx context.Context) error { - sqb := db.Studio + sqb := db.GetRepo().Studio studios := queryStudio(ctx, t, sqb, &studioFilter, nil) @@ -140,7 +140,7 @@ func TestStudioQueryNameNotUrl(t *testing.T) { } withTxn(func(ctx context.Context) error { - sqb := db.Studio + sqb := db.GetRepo().Studio studios := queryStudio(ctx, t, sqb, &studioFilter, nil) @@ -173,7 +173,7 @@ func TestStudioIllegalQuery(t *testing.T) { } withTxn(func(ctx context.Context) error { - sqb := db.Studio + sqb := db.GetRepo().Studio _, _, err := sqb.Query(ctx, studioFilter, nil) assert.NotNil(err) @@ -199,7 +199,7 @@ func TestStudioQueryIgnoreAutoTag(t *testing.T) { IgnoreAutoTag: &ignoreAutoTag, } - sqb := 
db.Studio + sqb := db.GetRepo().Studio studios := queryStudio(ctx, t, sqb, &studioFilter, nil) @@ -214,7 +214,7 @@ func TestStudioQueryIgnoreAutoTag(t *testing.T) { func TestStudioQueryForAutoTag(t *testing.T) { withTxn(func(ctx context.Context) error { - tqb := db.Studio + tqb := db.GetRepo().Studio name := studioNames[studioIdxWithGroup] // find a studio by name @@ -242,7 +242,7 @@ func TestStudioQueryForAutoTag(t *testing.T) { func TestStudioQueryParent(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Studio + sqb := db.GetRepo().Studio studioCriterion := models.MultiCriterionInput{ Value: []string{ strconv.Itoa(studioIDs[studioIdxWithChildStudio]), @@ -292,18 +292,18 @@ func TestStudioDestroyParent(t *testing.T) { // create parent and child studios if err := withTxn(func(ctx context.Context) error { - createdParent, err := createStudio(ctx, db.Studio, parentName, nil) + createdParent, err := createStudio(ctx, db.GetRepo().Studio, parentName, nil) if err != nil { return fmt.Errorf("Error creating parent studio: %s", err.Error()) } parentID := createdParent.ID - createdChild, err := createStudio(ctx, db.Studio, childName, &parentID) + createdChild, err := createStudio(ctx, db.GetRepo().Studio, childName, &parentID) if err != nil { return fmt.Errorf("Error creating child studio: %s", err.Error()) } - sqb := db.Studio + sqb := db.GetRepo().Studio // destroy the parent err = sqb.Destroy(ctx, createdParent.ID) @@ -325,7 +325,7 @@ func TestStudioDestroyParent(t *testing.T) { func TestStudioFindChildren(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Studio + sqb := db.GetRepo().Studio studios, err := sqb.FindChildren(ctx, studioIDs[studioIdxWithChildStudio]) @@ -354,18 +354,18 @@ func TestStudioUpdateClearParent(t *testing.T) { // create parent and child studios if err := withTxn(func(ctx context.Context) error { - createdParent, err := createStudio(ctx, db.Studio, parentName, nil) + createdParent, err := createStudio(ctx, 
db.GetRepo().Studio, parentName, nil) if err != nil { return fmt.Errorf("Error creating parent studio: %s", err.Error()) } parentID := createdParent.ID - createdChild, err := createStudio(ctx, db.Studio, childName, &parentID) + createdChild, err := createStudio(ctx, db.GetRepo().Studio, childName, &parentID) if err != nil { return fmt.Errorf("Error creating child studio: %s", err.Error()) } - sqb := db.Studio + sqb := db.GetRepo().Studio // clear the parent id from the child input := models.StudioPartial{ @@ -391,11 +391,11 @@ func TestStudioUpdateClearParent(t *testing.T) { func TestStudioUpdateStudioImage(t *testing.T) { if err := withTxn(func(ctx context.Context) error { - qb := db.Studio + qb := db.GetRepo().Studio // create studio to test against const name = "TestStudioUpdateStudioImage" - created, err := createStudio(ctx, db.Studio, name, nil) + created, err := createStudio(ctx, db.GetRepo().Studio, name, nil) if err != nil { return fmt.Errorf("Error creating studio: %s", err.Error()) } @@ -427,7 +427,7 @@ func TestStudioQuerySceneCount(t *testing.T) { func verifyStudiosSceneCount(t *testing.T, sceneCountCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - sqb := db.Studio + sqb := db.GetRepo().Studio studioFilter := models.StudioFilterType{ SceneCount: &sceneCountCriterion, } @@ -436,7 +436,7 @@ func verifyStudiosSceneCount(t *testing.T, sceneCountCriterion models.IntCriteri assert.Greater(t, len(studios), 0) for _, studio := range studios { - sceneCount, err := db.Scene.CountByStudioID(ctx, studio.ID) + sceneCount, err := db.GetRepo().Scene.CountByStudioID(ctx, studio.ID) if err != nil { return err } @@ -468,7 +468,7 @@ func TestStudioQueryImageCount(t *testing.T) { func verifyStudiosImageCount(t *testing.T, imageCountCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - sqb := db.Studio + sqb := db.GetRepo().Studio studioFilter := models.StudioFilterType{ ImageCount: &imageCountCriterion, } @@ 
-479,7 +479,7 @@ func verifyStudiosImageCount(t *testing.T, imageCountCriterion models.IntCriteri for _, studio := range studios { pp := 0 - result, err := db.Image.Query(ctx, models.ImageQueryOptions{ + result, err := db.GetRepo().Image.Query(ctx, models.ImageQueryOptions{ QueryOptions: models.QueryOptions{ FindFilter: &models.FindFilterType{ PerPage: &pp, @@ -524,7 +524,7 @@ func TestStudioQueryGalleryCount(t *testing.T) { func verifyStudiosGalleryCount(t *testing.T, galleryCountCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - sqb := db.Studio + sqb := db.GetRepo().Studio studioFilter := models.StudioFilterType{ GalleryCount: &galleryCountCriterion, } @@ -535,7 +535,7 @@ func verifyStudiosGalleryCount(t *testing.T, galleryCountCriterion models.IntCri for _, studio := range studios { pp := 0 - _, count, err := db.Gallery.Query(ctx, &models.GalleryFilterType{ + _, count, err := db.GetRepo().Gallery.Query(ctx, &models.GalleryFilterType{ Studios: &models.HierarchicalMultiCriterionInput{ Value: []string{strconv.Itoa(studio.ID)}, Modifier: models.CriterionModifierIncludes, @@ -555,11 +555,11 @@ func verifyStudiosGalleryCount(t *testing.T, galleryCountCriterion models.IntCri func TestStudioStashIDs(t *testing.T) { if err := withRollbackTxn(func(ctx context.Context) error { - qb := db.Studio + qb := db.GetRepo().Studio // create studio to test against const name = "TestStudioStashIDs" - created, err := createStudio(ctx, db.Studio, name, nil) + created, err := createStudio(ctx, db.GetRepo().Studio, name, nil) if err != nil { return fmt.Errorf("Error creating studio: %s", err.Error()) } @@ -581,7 +581,7 @@ func TestStudioStashIDs(t *testing.T) { } func testStudioStashIDs(ctx context.Context, t *testing.T, s *models.Studio) { - qb := db.Studio + qb := db.GetRepo().Studio if err := s.LoadStashIDs(ctx, qb); err != nil { t.Error(err.Error()) @@ -592,7 +592,7 @@ func testStudioStashIDs(ctx context.Context, t *testing.T, s *models.Studio) { 
assert.Len(t, s.StashIDs.List(), 0) // add stash ids - const stashIDStr = "stashID" + var stashIDStr = getUUID("stashID") const endpoint = "endpoint" stashID := models.StashID{ StashID: stashIDStr, @@ -706,7 +706,7 @@ func TestStudioQueryRating(t *testing.T) { func queryStudios(ctx context.Context, t *testing.T, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) []*models.Studio { t.Helper() - studios, _, err := db.Studio.Query(ctx, studioFilter, findFilter) + studios, _, err := db.GetRepo().Studio.Query(ctx, studioFilter, findFilter) if err != nil { t.Errorf("Error querying studio: %s", err.Error()) } @@ -788,7 +788,7 @@ func TestStudioQueryTagCount(t *testing.T) { func verifyStudiosTagCount(t *testing.T, tagCountCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - sqb := db.Studio + sqb := db.GetRepo().Studio studioFilter := models.StudioFilterType{ TagCount: &tagCountCriterion, } @@ -811,7 +811,7 @@ func verifyStudiosTagCount(t *testing.T, tagCountCriterion models.IntCriterionIn func verifyStudioQuery(t *testing.T, filter models.StudioFilterType, verifyFn func(ctx context.Context, s *models.Studio)) { withTxn(func(ctx context.Context) error { t.Helper() - sqb := db.Studio + sqb := db.GetRepo().Studio studios := queryStudio(ctx, t, sqb, &filter, nil) @@ -828,7 +828,7 @@ func verifyStudioQuery(t *testing.T, filter models.StudioFilterType, verifyFn fu func verifyStudiosRating(t *testing.T, ratingCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - sqb := db.Studio + sqb := db.GetRepo().Studio studioFilter := models.StudioFilterType{ Rating100: &ratingCriterion, } @@ -849,7 +849,7 @@ func verifyStudiosRating(t *testing.T, ratingCriterion models.IntCriterionInput) func TestStudioQueryIsMissingRating(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Studio + sqb := db.GetRepo().Studio isMissing := "rating" studioFilter := models.StudioFilterType{ IsMissing: &isMissing, 
@@ -925,7 +925,7 @@ func TestStudioQueryAlias(t *testing.T) { verifyFn := func(ctx context.Context, studio *models.Studio) { t.Helper() - aliases, err := db.Studio.GetAliases(ctx, studio.ID) + aliases, err := db.GetRepo().Studio.GetAliases(ctx, studio.ID) if err != nil { t.Errorf("Error querying studios: %s", err.Error()) } @@ -960,11 +960,11 @@ func TestStudioQueryAlias(t *testing.T) { func TestStudioAlias(t *testing.T) { if err := withRollbackTxn(func(ctx context.Context) error { - qb := db.Studio + qb := db.GetRepo().Studio // create studio to test against const name = "TestStudioAlias" - created, err := createStudio(ctx, db.Studio, name, nil) + created, err := createStudio(ctx, db.GetRepo().Studio, name, nil) if err != nil { return fmt.Errorf("Error creating studio: %s", err.Error()) } @@ -986,7 +986,7 @@ func TestStudioAlias(t *testing.T) { } func testStudioAlias(ctx context.Context, t *testing.T, s *models.Studio) { - qb := db.Studio + qb := db.GetRepo().Studio if err := s.LoadAliases(ctx, qb); err != nil { t.Error(err.Error()) return @@ -1044,13 +1044,14 @@ func TestStudioQueryFast(t *testing.T) { tsString := "test" tsInt := 1 + tsId := "1" testStringCriterion := models.StringCriterionInput{ Value: tsString, Modifier: models.CriterionModifierEquals, } - testIncludesMultiCriterion := models.MultiCriterionInput{ - Value: []string{tsString}, + testIncludesMultiCriterionId := models.MultiCriterionInput{ + Value: []string{tsId}, Modifier: models.CriterionModifierIncludes, } testIntCriterion := models.IntCriterionInput{ @@ -1080,7 +1081,7 @@ func TestStudioQueryFast(t *testing.T) { SceneCount: &testIntCriterion, } parentsFilter := models.StudioFilterType{ - Parents: &testIncludesMultiCriterion, + Parents: &testIncludesMultiCriterionId, } filters := []models.StudioFilterType{nameFilter, aliasesFilter, stashIDFilter, urlFilter, ratingFilter, sceneCountFilter, imageCountFilter, parentsFilter} @@ -1108,7 +1109,7 @@ func TestStudioQueryFast(t *testing.T) { } 
withTxn(func(ctx context.Context) error { - sqb := db.Studio + sqb := db.GetRepo().Studio for _, f := range filters { for _, ff := range findFilters { _, _, err := sqb.Query(ctx, &f, &ff) diff --git a/pkg/sqlite/table.go b/pkg/sqlite/table.go index e374f0790e3..0a80ee84a8d 100644 --- a/pkg/sqlite/table.go +++ b/pkg/sqlite/table.go @@ -42,17 +42,13 @@ func (t *table) insert(ctx context.Context, o interface{}) (sql.Result, error) { } func (t *table) insertID(ctx context.Context, o interface{}) (int, error) { - result, err := t.insert(ctx, o) + q := dialect.Insert(t.table).Prepared(true).Rows(o).Returning(goqu.I("id")) + val, err := execID(ctx, q) if err != nil { - return 0, err - } - - ret, err := result.LastInsertId() - if err != nil { - return 0, err + return -1, fmt.Errorf("inserting into %s: %w", t.table.GetTable(), err) } - return int(ret), nil + return int(*val), nil } func (t *table) updateByID(ctx context.Context, id interface{}, o interface{}) error { @@ -743,7 +739,7 @@ func (t *imageGalleriesTable) resetCover(ctx context.Context, galleryID int) err "cover": false, }).Where( table.Col(galleryIDColumn).Eq(galleryID), - table.Col("cover").Eq(true), + table.Col("cover").IsTrue(), ) if _, err := exec(ctx, q); err != nil { @@ -832,7 +828,7 @@ func (t *relatedFilesTable) setPrimary(ctx context.Context, id int, fileID model table := t.table.table q := dialect.Update(table).Prepared(true).Set(goqu.Record{ - "primary": 0, + "primary": false, }).Where(t.idColumn.Eq(id), table.Col(fileIDColumn).Neq(fileID)) if _, err := exec(ctx, q); err != nil { @@ -840,7 +836,7 @@ func (t *relatedFilesTable) setPrimary(ctx context.Context, id int, fileID model } q = dialect.Update(table).Prepared(true).Set(goqu.Record{ - "primary": 1, + "primary": true, }).Where(t.idColumn.Eq(id), table.Col(fileIDColumn).Eq(fileID)) if _, err := exec(ctx, q); err != nil { @@ -1079,21 +1075,23 @@ func (t *viewHistoryTable) deleteDates(ctx context.Context, id int, dates []time dates = 
[]time.Time{time.Now()} } + rowIdColumn := getDBRowId() + for _, date := range dates { var subquery *goqu.SelectDataset if mostRecent { // delete the most recent - subquery = dialect.Select("rowid").From(table).Where( + subquery = dialect.Select(rowIdColumn).From(table).Where( t.idColumn.Eq(id), ).Order(t.dateColumn.Desc()).Limit(1) } else { - subquery = dialect.Select("rowid").From(table).Where( + subquery = dialect.Select(rowIdColumn).From(table).Where( t.idColumn.Eq(id), t.dateColumn.Eq(UTCTimestamp{Timestamp{date}}), ).Limit(1) } - q := dialect.Delete(table).Where(goqu.I("rowid").Eq(subquery)) + q := dialect.Delete(table).Where(goqu.I(rowIdColumn).Eq(subquery)) if _, err := exec(ctx, q); err != nil { return nil, fmt.Errorf("deleting from %s: %w", table.GetTable(), err) @@ -1138,6 +1136,28 @@ func exec(ctx context.Context, stmt sqler) (sql.Result, error) { return ret, nil } +// Execute, but returns an ID +func execID(ctx context.Context, stmt sqler) (*int64, error) { + tx, err := getTx(ctx) + if err != nil { + return nil, err + } + + sql, args, err := stmt.ToSQL() + if err != nil { + return nil, fmt.Errorf("generating sql: %w", err) + } + + logger.Tracef("SQL: %s [%v]", sql, args) + var id int64 + err = tx.QueryRowContext(ctx, sql, args...).Scan(&id) + if err != nil { + return nil, fmt.Errorf("executing `%s` [%v]: %w", sql, args, err) + } + + return &id, nil +} + func count(ctx context.Context, q *goqu.SelectDataset) (int, error) { var count int if err := querySimple(ctx, q, &count); err != nil { diff --git a/pkg/sqlite/tag.go b/pkg/sqlite/tag.go index 919ef4847df..c2c5c8d5b71 100644 --- a/pkg/sqlite/tag.go +++ b/pkg/sqlite/tag.go @@ -383,7 +383,8 @@ func (qb *TagStore) FindBySceneID(ctx context.Context, sceneID int) ([]*models.T WHERE scenes_join.scene_id = ? 
GROUP BY tags.id ` - query += qb.getDefaultTagSort() + add, _ := qb.getDefaultTagSort() + query += add args := []interface{}{sceneID} return qb.queryTags(ctx, query, args) } @@ -395,7 +396,8 @@ func (qb *TagStore) FindByPerformerID(ctx context.Context, performerID int) ([]* WHERE performers_join.performer_id = ? GROUP BY tags.id ` - query += qb.getDefaultTagSort() + add, _ := qb.getDefaultTagSort() + query += add args := []interface{}{performerID} return qb.queryTags(ctx, query, args) } @@ -407,7 +409,8 @@ func (qb *TagStore) FindByImageID(ctx context.Context, imageID int) ([]*models.T WHERE images_join.image_id = ? GROUP BY tags.id ` - query += qb.getDefaultTagSort() + add, _ := qb.getDefaultTagSort() + query += add args := []interface{}{imageID} return qb.queryTags(ctx, query, args) } @@ -419,7 +422,8 @@ func (qb *TagStore) FindByGalleryID(ctx context.Context, galleryID int) ([]*mode WHERE galleries_join.gallery_id = ? GROUP BY tags.id ` - query += qb.getDefaultTagSort() + add, _ := qb.getDefaultTagSort() + query += add args := []interface{}{galleryID} return qb.queryTags(ctx, query, args) } @@ -431,7 +435,8 @@ func (qb *TagStore) FindByGroupID(ctx context.Context, groupID int) ([]*models.T WHERE groups_join.group_id = ? GROUP BY tags.id ` - query += qb.getDefaultTagSort() + add, _ := qb.getDefaultTagSort() + query += add args := []interface{}{groupID} return qb.queryTags(ctx, query, args) } @@ -443,7 +448,8 @@ func (qb *TagStore) FindBySceneMarkerID(ctx context.Context, sceneMarkerID int) WHERE scene_markers_join.scene_marker_id = ? GROUP BY tags.id ` - query += qb.getDefaultTagSort() + add, _ := qb.getDefaultTagSort() + query += add args := []interface{}{sceneMarkerID} return qb.queryTags(ctx, query, args) } @@ -455,7 +461,8 @@ func (qb *TagStore) FindByStudioID(ctx context.Context, studioID int) ([]*models WHERE studios_join.studio_id = ? 
GROUP BY tags.id ` - query += qb.getDefaultTagSort() + add, _ := qb.getDefaultTagSort() + query += add args := []interface{}{studioID} return qb.queryTags(ctx, query, args) } @@ -519,7 +526,8 @@ func (qb *TagStore) FindByParentTagID(ctx context.Context, parentID int) ([]*mod INNER JOIN tags_relations ON tags_relations.child_id = tags.id WHERE tags_relations.parent_id = ? ` - query += qb.getDefaultTagSort() + add, _ := qb.getDefaultTagSort() + query += add args := []interface{}{parentID} return qb.queryTags(ctx, query, args) } @@ -530,7 +538,8 @@ func (qb *TagStore) FindByChildTagID(ctx context.Context, parentID int) ([]*mode INNER JOIN tags_relations ON tags_relations.parent_id = tags.id WHERE tags_relations.child_id = ? ` - query += qb.getDefaultTagSort() + add, _ := qb.getDefaultTagSort() + query += add args := []interface{}{parentID} return qb.queryTags(ctx, query, args) } @@ -574,17 +583,17 @@ func (qb *TagStore) QueryForAutoTag(ctx context.Context, words []string) ([]*mod for _, w := range words { ww := w + "%" - whereClauses = append(whereClauses, "tags.name like ?") + whereClauses = append(whereClauses, "LOWER(tags.name) LIKE LOWER(?)") args = append(args, ww) // include aliases - whereClauses = append(whereClauses, "tag_aliases.alias like ?") + whereClauses = append(whereClauses, "LOWER(tag_aliases.alias) LIKE LOWER(?)") args = append(args, ww) } whereOr := "(" + strings.Join(whereClauses, " OR ") + ")" where := strings.Join([]string{ - "tags.ignore_auto_tag = 0", + "tags.ignore_auto_tag = " + getDBBoolean(false), whereOr, }, " AND ") return qb.queryTags(ctx, query+" WHERE "+where, args) @@ -616,11 +625,13 @@ func (qb *TagStore) Query(ctx context.Context, tagFilter *models.TagFilterType, } var err error - query.sortAndPagination, err = qb.getTagSort(&query, findFilter) + var group []string + query.sortAndPagination, group, err = qb.getTagSort(&query, findFilter) if err != nil { return nil, 0, err } query.sortAndPagination += getPagination(findFilter) + 
query.addGroupBy(group...) idsResult, countResult, err := query.executeFind(ctx) if err != nil { return nil, 0, err @@ -650,11 +661,11 @@ var tagSortOptions = sortOptions{ "updated_at", } -func (qb *TagStore) getDefaultTagSort() string { +func (qb *TagStore) getDefaultTagSort() (string, []string) { return getSort("name", "ASC", "tags") } -func (qb *TagStore) getTagSort(query *queryBuilder, findFilter *models.FindFilterType) (string, error) { +func (qb *TagStore) getTagSort(query *queryBuilder, findFilter *models.FindFilterType) (string, []string, error) { var sort string var direction string if findFilter == nil { @@ -667,9 +678,10 @@ func (qb *TagStore) getTagSort(query *queryBuilder, findFilter *models.FindFilte // CVE-2024-32231 - ensure sort is in the list of allowed sorts if err := tagSortOptions.validateSort(sort); err != nil { - return "", err + return "", nil, err } + group := []string{} sortQuery := "" switch sort { case "scenes_count": @@ -687,12 +699,15 @@ func (qb *TagStore) getTagSort(query *queryBuilder, findFilter *models.FindFilte case "movies_count", "groups_count": sortQuery += getCountSort(tagTable, groupsTagsTable, tagIDColumn, direction) default: - sortQuery += getSort(sort, direction, "tags") + var add string + add, group = getSort(sort, direction, "tags") + sortQuery += add } // Whatever the sorting, always use name/id as a final sort - sortQuery += ", COALESCE(tags.name, tags.id) COLLATE NATURAL_CI ASC" - return sortQuery, nil + sortQuery += ", COALESCE(tags.name, CAST(tags.id as text)) COLLATE NATURAL_CI ASC" + group = append(group, "tags.name", "tags.id") + return sortQuery, group, nil } func (qb *TagStore) queryTags(ctx context.Context, query string, args []interface{}) ([]*models.Tag, error) { @@ -786,16 +801,21 @@ func (qb *TagStore) Merge(ctx context.Context, source []int, destination int) er "studios_tags": "studio_id", } - args = append(args, destination) for table, idColumn := range tagTables { - _, err := dbWrapper.Exec(ctx, 
`UPDATE OR IGNORE `+table+` -SET tag_id = ? -WHERE tag_id IN `+inBinding+` -AND NOT EXISTS(SELECT 1 FROM `+table+` o WHERE o.`+idColumn+` = `+table+`.`+idColumn+` AND o.tag_id = ?)`, - args..., - ) - if err != nil { - return err + for _, to_migrate_id := range srcArgs { + err := withSavepoint(ctx, func(ctx context.Context) error { + _, err := dbWrapper.Exec(ctx, `UPDATE `+table+` + SET tag_id = $1 + WHERE tag_id = $2 + AND NOT EXISTS(SELECT 1 FROM `+table+` o WHERE o.`+idColumn+` = `+table+`.`+idColumn+` AND o.tag_id = $1)`, + destination, to_migrate_id, + ) + return err + }) + + if err != nil && !isConstraintError(err) { + return err + } } // delete source tag ids from the table where they couldn't be set @@ -887,12 +907,10 @@ parents AS ( SELECT t.*, p.path FROM tags t INNER JOIN parents p ON t.id = p.parent_id ` - excludeArgs := []interface{}{tagID} + args := []interface{}{tagID, tagID} for _, excludeID := range excludeIDs { - excludeArgs = append(excludeArgs, excludeID) + args = append(args, excludeID) } - args := []interface{}{tagID} - args = append(args, append(append(excludeArgs, excludeArgs...), excludeArgs...)...) return qb.queryTagPaths(ctx, query, args) } @@ -911,12 +929,10 @@ children AS ( SELECT t.*, c.path FROM tags t INNER JOIN children c ON t.id = c.child_id ` - excludeArgs := []interface{}{tagID} + args := []interface{}{tagID, tagID} for _, excludeID := range excludeIDs { - excludeArgs = append(excludeArgs, excludeID) + args = append(args, excludeID) } - args := []interface{}{tagID} - args = append(args, append(append(excludeArgs, excludeArgs...), excludeArgs...)...) 
return qb.queryTagPaths(ctx, query, args) } diff --git a/pkg/sqlite/tag_filter.go b/pkg/sqlite/tag_filter.go index ba9e9bb08ec..638fca5e853 100644 --- a/pkg/sqlite/tag_filter.go +++ b/pkg/sqlite/tag_filter.go @@ -136,7 +136,7 @@ func (qb *tagFilterHandler) isMissingCriterionHandler(isMissing *string) criteri case "image": f.addWhere("tags.image_blob IS NULL") default: - f.addWhere("(tags." + *isMissing + " IS NULL OR TRIM(tags." + *isMissing + ") = '')") + f.addWhere("(tags." + *isMissing + " IS NULL OR TRIM(CAST(tags." + *isMissing + " AS TEXT)) = '')") } } } diff --git a/pkg/sqlite/tag_test.go b/pkg/sqlite/tag_test.go index 5359be78517..e9efcd0aa07 100644 --- a/pkg/sqlite/tag_test.go +++ b/pkg/sqlite/tag_test.go @@ -17,7 +17,7 @@ import ( func TestMarkerFindBySceneMarkerID(t *testing.T) { withTxn(func(ctx context.Context) error { - tqb := db.Tag + tqb := db.GetRepo().Tag markerID := markerIDs[markerIdxWithTag] @@ -44,7 +44,7 @@ func TestMarkerFindBySceneMarkerID(t *testing.T) { func TestTagFindByGroupID(t *testing.T) { withTxn(func(ctx context.Context) error { - tqb := db.Tag + tqb := db.GetRepo().Tag groupID := groupIDs[groupIdxWithTag] @@ -71,7 +71,7 @@ func TestTagFindByGroupID(t *testing.T) { func TestTagFindByName(t *testing.T) { withTxn(func(ctx context.Context) error { - tqb := db.Tag + tqb := db.GetRepo().Tag name := tagNames[tagIdxWithScene] // find a tag by name @@ -107,7 +107,7 @@ func TestTagQueryIgnoreAutoTag(t *testing.T) { IgnoreAutoTag: &ignoreAutoTag, } - sqb := db.Tag + sqb := db.GetRepo().Tag tags := queryTags(ctx, t, sqb, &tagFilter, nil) @@ -122,7 +122,7 @@ func TestTagQueryIgnoreAutoTag(t *testing.T) { func TestTagQueryForAutoTag(t *testing.T) { withTxn(func(ctx context.Context) error { - tqb := db.Tag + tqb := db.GetRepo().Tag name := tagNames[tagIdx1WithScene] // find a tag by name @@ -156,7 +156,7 @@ func TestTagFindByNames(t *testing.T) { var names []string withTxn(func(ctx context.Context) error { - tqb := db.Tag + tqb := 
db.GetRepo().Tag names = append(names, tagNames[tagIdxWithScene]) // find tags by names @@ -201,7 +201,7 @@ func TestTagFindByNames(t *testing.T) { func TestTagQuerySort(t *testing.T) { withTxn(func(ctx context.Context) error { - sqb := db.Tag + sqb := db.GetRepo().Tag sortBy := "scenes_count" dir := models.SortDirectionEnumDesc @@ -286,7 +286,7 @@ func TestTagQueryAlias(t *testing.T) { } verifyFn := func(ctx context.Context, tag *models.Tag) { - aliases, err := db.Tag.GetAliases(ctx, tag.ID) + aliases, err := db.GetRepo().Tag.GetAliases(ctx, tag.ID) if err != nil { t.Errorf("Error querying tags: %s", err.Error()) } @@ -321,7 +321,7 @@ func TestTagQueryAlias(t *testing.T) { func verifyTagQuery(t *testing.T, tagFilter *models.TagFilterType, findFilter *models.FindFilterType, verifyFn func(ctx context.Context, t *models.Tag)) { withTxn(func(ctx context.Context) error { - sqb := db.Tag + sqb := db.GetRepo().Tag tags := queryTags(ctx, t, sqb, tagFilter, findFilter) @@ -345,7 +345,7 @@ func queryTags(ctx context.Context, t *testing.T, qb models.TagReader, tagFilter func TestTagQueryIsMissingImage(t *testing.T) { withTxn(func(ctx context.Context) error { - qb := db.Tag + qb := db.GetRepo().Tag isMissing := "image" tagFilter := models.TagFilterType{ IsMissing: &isMissing, @@ -399,7 +399,7 @@ func TestTagQuerySceneCount(t *testing.T) { func verifyTagSceneCount(t *testing.T, sceneCountCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - qb := db.Tag + qb := db.GetRepo().Tag tagFilter := models.TagFilterType{ SceneCount: &sceneCountCriterion, } @@ -438,7 +438,7 @@ func TestTagQueryMarkerCount(t *testing.T) { func verifyTagMarkerCount(t *testing.T, markerCountCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - qb := db.Tag + qb := db.GetRepo().Tag tagFilter := models.TagFilterType{ MarkerCount: &markerCountCriterion, } @@ -477,7 +477,7 @@ func TestTagQueryImageCount(t *testing.T) { func verifyTagImageCount(t 
*testing.T, imageCountCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - qb := db.Tag + qb := db.GetRepo().Tag tagFilter := models.TagFilterType{ ImageCount: &imageCountCriterion, } @@ -516,7 +516,7 @@ func TestTagQueryGalleryCount(t *testing.T) { func verifyTagGalleryCount(t *testing.T, imageCountCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - qb := db.Tag + qb := db.GetRepo().Tag tagFilter := models.TagFilterType{ GalleryCount: &imageCountCriterion, } @@ -555,7 +555,7 @@ func TestTagQueryPerformerCount(t *testing.T) { func verifyTagPerformerCount(t *testing.T, imageCountCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - qb := db.Tag + qb := db.GetRepo().Tag tagFilter := models.TagFilterType{ PerformerCount: &imageCountCriterion, } @@ -594,7 +594,7 @@ func TestTagQueryStudioCount(t *testing.T) { func verifyTagStudioCount(t *testing.T, imageCountCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - qb := db.Tag + qb := db.GetRepo().Tag tagFilter := models.TagFilterType{ StudioCount: &imageCountCriterion, } @@ -633,7 +633,7 @@ func TestTagQueryParentCount(t *testing.T) { func verifyTagParentCount(t *testing.T, sceneCountCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - qb := db.Tag + qb := db.GetRepo().Tag tagFilter := models.TagFilterType{ ParentCount: &sceneCountCriterion, } @@ -673,7 +673,7 @@ func TestTagQueryChildCount(t *testing.T) { func verifyTagChildCount(t *testing.T, sceneCountCriterion models.IntCriterionInput) { withTxn(func(ctx context.Context) error { - qb := db.Tag + qb := db.GetRepo().Tag tagFilter := models.TagFilterType{ ChildCount: &sceneCountCriterion, } @@ -695,7 +695,7 @@ func verifyTagChildCount(t *testing.T, sceneCountCriterion models.IntCriterionIn func TestTagQueryParent(t *testing.T) { withTxn(func(ctx context.Context) error { const nameField = "Name" - sqb := db.Tag + sqb := 
db.GetRepo().Tag tagCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdxWithChildTag]), @@ -773,7 +773,7 @@ func TestTagQueryChild(t *testing.T) { withTxn(func(ctx context.Context) error { const nameField = "Name" - sqb := db.Tag + sqb := db.GetRepo().Tag tagCriterion := models.HierarchicalMultiCriterionInput{ Value: []string{ strconv.Itoa(tagIDs[tagIdxWithParentTag]), @@ -849,7 +849,7 @@ func TestTagQueryChild(t *testing.T) { func TestTagUpdateTagImage(t *testing.T) { if err := withTxn(func(ctx context.Context) error { - qb := db.Tag + qb := db.GetRepo().Tag // create tag to test against const name = "TestTagUpdateTagImage" @@ -869,7 +869,7 @@ func TestTagUpdateTagImage(t *testing.T) { func TestTagUpdateAlias(t *testing.T) { if err := withTxn(func(ctx context.Context) error { - qb := db.Tag + qb := db.GetRepo().Tag // create tag to test against const name = "TestTagUpdateAlias" @@ -905,8 +905,8 @@ func TestTagMerge(t *testing.T) { // merge tests - perform these in a transaction that we'll rollback if err := withRollbackTxn(func(ctx context.Context) error { - qb := db.Tag - mqb := db.SceneMarker + qb := db.GetRepo().Tag + mqb := db.GetRepo().SceneMarker // try merging into same tag err := qb.Merge(ctx, []int{tagIDs[tagIdx1WithScene]}, tagIDs[tagIdx1WithScene]) @@ -962,11 +962,11 @@ func TestTagMerge(t *testing.T) { } // ensure scene points to new tag - s, err := db.Scene.Find(ctx, sceneIDs[sceneIdxWithTwoTags]) + s, err := db.GetRepo().Scene.Find(ctx, sceneIDs[sceneIdxWithTwoTags]) if err != nil { return err } - if err := s.LoadTagIDs(ctx, db.Scene); err != nil { + if err := s.LoadTagIDs(ctx, db.GetRepo().Scene); err != nil { return err } sceneTagIDs := s.TagIDs.List() @@ -989,19 +989,19 @@ func TestTagMerge(t *testing.T) { assert.Contains(markerTagIDs, destID) // ensure image points to new tag - imageTagIDs, err := db.Image.GetTagIDs(ctx, imageIDs[imageIdxWithTwoTags]) + imageTagIDs, err := 
db.GetRepo().Image.GetTagIDs(ctx, imageIDs[imageIdxWithTwoTags]) if err != nil { return err } assert.Contains(imageTagIDs, destID) - g, err := db.Gallery.Find(ctx, galleryIDs[galleryIdxWithTwoTags]) + g, err := db.GetRepo().Gallery.Find(ctx, galleryIDs[galleryIdxWithTwoTags]) if err != nil { return err } - if err := g.LoadTagIDs(ctx, db.Gallery); err != nil { + if err := g.LoadTagIDs(ctx, db.GetRepo().Gallery); err != nil { return err } @@ -1009,7 +1009,7 @@ func TestTagMerge(t *testing.T) { assert.Contains(g.TagIDs.List(), destID) // ensure performer points to new tag - performerTagIDs, err := db.Performer.GetTagIDs(ctx, performerIDs[performerIdxWithTwoTags]) + performerTagIDs, err := db.GetRepo().Performer.GetTagIDs(ctx, performerIDs[performerIdxWithTwoTags]) if err != nil { return err } @@ -1017,7 +1017,7 @@ func TestTagMerge(t *testing.T) { assert.Contains(performerTagIDs, destID) // ensure studio points to new tag - studioTagIDs, err := db.Studio.GetTagIDs(ctx, studioIDs[studioIdxWithTwoTags]) + studioTagIDs, err := db.GetRepo().Studio.GetTagIDs(ctx, studioIDs[studioIdxWithTwoTags]) if err != nil { return err } diff --git a/pkg/sqlite/transaction.go b/pkg/sqlite/transaction.go index fb86723bdff..17b927cfff9 100644 --- a/pkg/sqlite/transaction.go +++ b/pkg/sqlite/transaction.go @@ -6,6 +6,7 @@ import ( "fmt" "runtime/debug" + "github.com/jackc/pgx/v5/pgconn" "github.com/jmoiron/sqlx" "github.com/mattn/go-sqlite3" "github.com/stashapp/stash/pkg/logger" @@ -108,9 +109,18 @@ func getDBReader(ctx context.Context) (dbReader, error) { } func (db *Database) IsLocked(err error) bool { - var sqliteError sqlite3.Error - if errors.As(err, &sqliteError) { - return sqliteError.Code == sqlite3.ErrBusy + switch dbWrapper.dbType { + case PostgresBackend: + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + // Class 53 — Insufficient Resources + return pgErr.Code[:2] == "53" + } + case SqliteBackend: + var sqliteError sqlite3.Error + if errors.As(err, &sqliteError) { + 
return sqliteError.Code == sqlite3.ErrBusy + } } return false } diff --git a/pkg/sqlite/transaction_test.go b/pkg/sqlite/transaction_test.go index 513a60a2065..8a1e809779c 100644 --- a/pkg/sqlite/transaction_test.go +++ b/pkg/sqlite/transaction_test.go @@ -36,11 +36,11 @@ import ( // Title: "test", // } -// if err := db.Scene.Create(ctx, scene, nil); err != nil { +// if err := db.GetRepo().Scene.Create(ctx, scene, nil); err != nil { // return err // } -// if err := db.Scene.Destroy(ctx, scene.ID); err != nil { +// if err := db.GetRepo().Scene.Destroy(ctx, scene.ID); err != nil { // return err // } // } @@ -94,7 +94,7 @@ func waitForOtherThread(c chan struct{}) error { // Title: "test", // } -// if err := db.Scene.Create(ctx, scene, nil); err != nil { +// if err := db.GetRepo().Scene.Create(ctx, scene, nil); err != nil { // return err // } @@ -106,7 +106,7 @@ func waitForOtherThread(c chan struct{}) error { // return err // } -// if err := db.Scene.Destroy(ctx, scene.ID); err != nil { +// if err := db.GetRepo().Scene.Destroy(ctx, scene.ID); err != nil { // return err // } @@ -139,7 +139,7 @@ func waitForOtherThread(c chan struct{}) error { // // expect error when we try to do this, as the other thread has already // // modified this table // // this takes time to fail, so we need to wait for it -// if err := db.Scene.Create(ctx, scene, nil); err != nil { +// if err := db.GetRepo().Scene.Create(ctx, scene, nil); err != nil { // if !db.IsLocked(err) { // t.Errorf("unexpected error: %v", err) // } @@ -169,7 +169,7 @@ func TestConcurrentExclusiveAndReadTxn(t *testing.T) { Title: "test", } - if err := db.Scene.Create(ctx, scene, nil); err != nil { + if err := db.GetRepo().Scene.Create(ctx, scene, nil); err != nil { return err } @@ -181,7 +181,7 @@ func TestConcurrentExclusiveAndReadTxn(t *testing.T) { return err } - if err := db.Scene.Destroy(ctx, scene.ID); err != nil { + if err := db.GetRepo().Scene.Destroy(ctx, scene.ID); err != nil { return err } @@ -207,7 +207,7 @@ 
func TestConcurrentExclusiveAndReadTxn(t *testing.T) { } }() - if _, err := db.Scene.Find(ctx, sceneIDs[sceneIdx1WithPerformer]); err != nil { + if _, err := db.GetRepo().Scene.Find(ctx, sceneIDs[sceneIdx1WithPerformer]); err != nil { t.Errorf("unexpected error: %v", err) return err } @@ -241,11 +241,11 @@ func TestConcurrentExclusiveAndReadTxn(t *testing.T) { // Title: "test", // } -// if err := db.Scene.Create(ctx, scene, nil); err != nil { +// if err := db.GetRepo().Scene.Create(ctx, scene, nil); err != nil { // return err // } -// if err := db.Scene.Destroy(ctx, scene.ID); err != nil { +// if err := db.GetRepo().Scene.Destroy(ctx, scene.ID); err != nil { // return err // } // } @@ -267,7 +267,7 @@ func TestConcurrentExclusiveAndReadTxn(t *testing.T) { // for l := 0; l < loops; l++ { // if err := txn.WithReadTxn(ctx, db, func(ctx context.Context) error { // for ll := 0; ll < innerLoops; ll++ { -// if _, err := db.Scene.Find(ctx, sceneIDs[ll%totalScenes]); err != nil { +// if _, err := db.GetRepo().Scene.Find(ctx, sceneIDs[ll%totalScenes]); err != nil { // return err // } // } diff --git a/pkg/sqlite/tx.go b/pkg/sqlite/tx.go index a2e272aa9f3..4733a9ed8e4 100644 --- a/pkg/sqlite/tx.go +++ b/pkg/sqlite/tx.go @@ -7,6 +7,7 @@ import ( "time" "github.com/jmoiron/sqlx" + "github.com/stashapp/stash/pkg/hash" "github.com/stashapp/stash/pkg/logger" ) @@ -35,7 +36,9 @@ func logSQL(start time.Time, query string, args ...interface{}) { } } -type dbWrapperType struct{} +type dbWrapperType struct { + dbType DatabaseType +} var dbWrapper = dbWrapperType{} @@ -47,7 +50,17 @@ func sqlError(err error, sql string, args ...interface{}) error { return fmt.Errorf("error executing `%s` [%v]: %w", sql, args, err) } -func (*dbWrapperType) Get(ctx context.Context, dest interface{}, query string, args ...interface{}) error { +func (db *dbWrapperType) Rebind(query string) string { + switch db.dbType { + case PostgresBackend: + return sqlx.Rebind(sqlx.DOLLAR, query) + default: + return 
query + } +} + +func (db *dbWrapperType) Get(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + query = db.Rebind(query) tx, err := getDBReader(ctx) if err != nil { return sqlError(err, query, args...) @@ -60,7 +73,8 @@ func (*dbWrapperType) Get(ctx context.Context, dest interface{}, query string, a return sqlError(err, query, args...) } -func (*dbWrapperType) Select(ctx context.Context, dest interface{}, query string, args ...interface{}) error { +func (db *dbWrapperType) Select(ctx context.Context, dest interface{}, query string, args ...interface{}) error { + query = db.Rebind(query) tx, err := getDBReader(ctx) if err != nil { return sqlError(err, query, args...) @@ -73,7 +87,8 @@ func (*dbWrapperType) Select(ctx context.Context, dest interface{}, query string return sqlError(err, query, args...) } -func (*dbWrapperType) Queryx(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) { +func (db *dbWrapperType) Queryx(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) { + query = db.Rebind(query) tx, err := getDBReader(ctx) if err != nil { return nil, sqlError(err, query, args...) @@ -86,7 +101,8 @@ func (*dbWrapperType) Queryx(ctx context.Context, query string, args ...interfac return ret, sqlError(err, query, args...) } -func (*dbWrapperType) QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) { +func (db *dbWrapperType) QueryxContext(ctx context.Context, query string, args ...interface{}) (*sqlx.Rows, error) { + query = db.Rebind(query) tx, err := getDBReader(ctx) if err != nil { return nil, sqlError(err, query, args...) @@ -99,7 +115,8 @@ func (*dbWrapperType) QueryxContext(ctx context.Context, query string, args ...i return ret, sqlError(err, query, args...) 
} -func (*dbWrapperType) NamedExec(ctx context.Context, query string, arg interface{}) (sql.Result, error) { +func (db *dbWrapperType) NamedExec(ctx context.Context, query string, arg interface{}) (sql.Result, error) { + query = db.Rebind(query) tx, err := getTx(ctx) if err != nil { return nil, sqlError(err, query, arg) @@ -112,7 +129,8 @@ func (*dbWrapperType) NamedExec(ctx context.Context, query string, arg interface return ret, sqlError(err, query, arg) } -func (*dbWrapperType) Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { +func (db *dbWrapperType) Exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + query = db.Rebind(query) tx, err := getTx(ctx) if err != nil { return nil, sqlError(err, query, args...) @@ -126,7 +144,8 @@ func (*dbWrapperType) Exec(ctx context.Context, query string, args ...interface{ } // Prepare creates a prepared statement. -func (*dbWrapperType) Prepare(ctx context.Context, query string, args ...interface{}) (*stmt, error) { +func (db *dbWrapperType) Prepare(ctx context.Context, query string, args ...interface{}) (*stmt, error) { + query = db.Rebind(query) tx, err := getTx(ctx) if err != nil { return nil, sqlError(err, query, args...) @@ -144,7 +163,7 @@ func (*dbWrapperType) Prepare(ctx context.Context, query string, args ...interfa }, nil } -func (*dbWrapperType) ExecStmt(ctx context.Context, stmt *stmt, args ...interface{}) (sql.Result, error) { +func (db *dbWrapperType) ExecStmt(ctx context.Context, stmt *stmt, args ...interface{}) (sql.Result, error) { _, err := getTx(ctx) if err != nil { return nil, sqlError(err, stmt.query, args...) @@ -156,3 +175,47 @@ func (*dbWrapperType) ExecStmt(ctx context.Context, stmt *stmt, args ...interfac return ret, sqlError(err, stmt.query, args...) 
}

// SavepointAction is the callback executed inside a savepoint.
type SavepointAction func(ctx context.Context) error

// withSavepoint encapsulates an action in a SQL savepoint.
// It's mostly used to roll back if an error occurred in postgres, as
// errors in postgres abort the enclosing transaction; rolling back to
// the savepoint keeps the outer transaction usable.
func withSavepoint(ctx context.Context, action SavepointAction) error {
	tx, err := getTx(ctx)
	if err != nil {
		return err
	}

	// Generate a random savepoint identifier so nested/concurrent uses
	// within the same transaction cannot collide.
	rnd, err := hash.GenerateRandomKey(64)
	if err != nil {
		return err
	}

	// Sqlite needs some letters in front of the identifier.
	rnd = "savepoint_" + rnd

	// Create the savepoint.
	if _, err := tx.Exec("SAVEPOINT " + rnd); err != nil {
		return fmt.Errorf("failed to create savepoint: %w", err)
	}

	// Execute the action.
	if err := action(ctx); err != nil {
		// Roll back to the savepoint on error.
		if _, rbErr := tx.Exec("ROLLBACK TO SAVEPOINT " + rnd); rbErr != nil {
			// Keep the original action error visible (wrapped) alongside
			// the rollback failure instead of discarding it.
			return fmt.Errorf("action failed: %w (rollback to savepoint failed: %v)", err, rbErr)
		}
		return fmt.Errorf("action failed: %w", err)
	}

	// Release the savepoint on success.
	if _, err := tx.Exec("RELEASE SAVEPOINT " + rnd); err != nil {
		return fmt.Errorf("failed to release savepoint: %w", err)
	}

	return nil
}