Skip to content

Commit

Permalink
pgsql fix more tests
Browse files Browse the repository at this point in the history
i fix bugs, therefore i am
  • Loading branch information
NodudeWasTaken committed Oct 18, 2024
1 parent 8fa2b38 commit aa4f257
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 41 deletions.
39 changes: 21 additions & 18 deletions pkg/sqlite/criterion_handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ func resolutionCriterionHandler(resolution *models.ResolutionCriterionInput, hei
min := resolution.Value.GetMinResolution()
max := resolution.Value.GetMaxResolution()

widthHeight := fmt.Sprintf("MIN(%s, %s)", widthColumn, heightColumn)
widthHeight := fmt.Sprintf("%s(%s, %s)", getDBMinFunc(), widthColumn, heightColumn)

switch resolution.Modifier {
case models.CriterionModifierEquals:
Expand Down Expand Up @@ -596,7 +596,7 @@ type hierarchicalMultiCriterionHandlerBuilder struct {
relationsTable string
}

func getHierarchicalValues(ctx context.Context, values []string, table, relationsTable, parentFK string, childFK string, depth *int) (string, error) {
func getHierarchicalValues(ctx context.Context, values []string, table, relationsTable, parentFK string, childFK string, depth *int, parenthesis bool) (string, error) {
var args []interface{}

if parentFK == "" {
Expand Down Expand Up @@ -627,7 +627,11 @@ func getHierarchicalValues(ctx context.Context, values []string, table, relation
}

if valid {
return "VALUES" + strings.Join(valuesClauses, ","), nil
values := "VALUES" + strings.Join(valuesClauses, ",")
if parenthesis {
values = "(" + values + ")" + getDBValuesFix()
}
return values, nil
}
}

Expand Down Expand Up @@ -690,6 +694,10 @@ WHERE id in {inBinding}
valuesClause.String = "VALUES" + strings.Join(values, ",")
}

if parenthesis {
valuesClause.String = "(" + valuesClause.String + ")" + getDBValuesFix()
}

return valuesClause.String, nil
}

Expand Down Expand Up @@ -742,35 +750,30 @@ func (m *hierarchicalMultiCriterionHandlerBuilder) handler(c *models.Hierarchica
criterion.Value = nil
}

var pgsql_fix string
if dbWrapper.dbType == PostgresBackend {
pgsql_fix = " AS v(column1, column2)"
}

if len(criterion.Value) > 0 {
valuesClause, err := getHierarchicalValues(ctx, criterion.Value, m.foreignTable, m.relationsTable, m.parentFK, m.childFK, criterion.Depth)
valuesClause, err := getHierarchicalValues(ctx, criterion.Value, m.foreignTable, m.relationsTable, m.parentFK, m.childFK, criterion.Depth, true)
if err != nil {
f.setError(err)
return
}

switch criterion.Modifier {
case models.CriterionModifierIncludes:
f.addWhere(fmt.Sprintf("%s.%s IN (SELECT column2 FROM (%s)%s)", m.primaryTable, m.foreignFK, valuesClause, pgsql_fix))
f.addWhere(fmt.Sprintf("%s.%s IN (SELECT column2 FROM %s)", m.primaryTable, m.foreignFK, valuesClause))
case models.CriterionModifierIncludesAll:
f.addWhere(fmt.Sprintf("%s.%s IN (SELECT column2 FROM (%s)%s)", m.primaryTable, m.foreignFK, valuesClause, pgsql_fix))
f.addWhere(fmt.Sprintf("%s.%s IN (SELECT column2 FROM %s)", m.primaryTable, m.foreignFK, valuesClause))
f.addHaving(fmt.Sprintf("count(distinct %s.%s) = %d", m.primaryTable, m.foreignFK, len(criterion.Value)))
}
}

if len(criterion.Excludes) > 0 {
valuesClause, err := getHierarchicalValues(ctx, criterion.Excludes, m.foreignTable, m.relationsTable, m.parentFK, m.childFK, criterion.Depth)
valuesClause, err := getHierarchicalValues(ctx, criterion.Excludes, m.foreignTable, m.relationsTable, m.parentFK, m.childFK, criterion.Depth, true)
if err != nil {
f.setError(err)
return
}

f.addWhere(fmt.Sprintf("%s.%s NOT IN (SELECT column2 FROM (%s)%s) OR %[1]s.%[2]s IS NULL", m.primaryTable, m.foreignFK, valuesClause, pgsql_fix))
f.addWhere(fmt.Sprintf("%s.%s NOT IN (SELECT column2 FROM %s) OR %[1]s.%[2]s IS NULL", m.primaryTable, m.foreignFK, valuesClause))
}
}
}
Expand Down Expand Up @@ -859,7 +862,7 @@ func (m *joinedHierarchicalMultiCriterionHandlerBuilder) handler(c *models.Hiera
}

if len(criterion.Value) > 0 {
valuesClause, err := getHierarchicalValues(ctx, criterion.Value, m.foreignTable, m.relationsTable, m.parentFK, m.childFK, criterion.Depth)
valuesClause, err := getHierarchicalValues(ctx, criterion.Value, m.foreignTable, m.relationsTable, m.parentFK, m.childFK, criterion.Depth, false)
if err != nil {
f.setError(err)
return
Expand All @@ -881,7 +884,7 @@ func (m *joinedHierarchicalMultiCriterionHandlerBuilder) handler(c *models.Hiera
}

if len(criterion.Excludes) > 0 {
valuesClause, err := getHierarchicalValues(ctx, criterion.Excludes, m.foreignTable, m.relationsTable, m.parentFK, m.childFK, criterion.Depth)
valuesClause, err := getHierarchicalValues(ctx, criterion.Excludes, m.foreignTable, m.relationsTable, m.parentFK, m.childFK, criterion.Depth, false)
if err != nil {
f.setError(err)
return
Expand Down Expand Up @@ -959,7 +962,7 @@ func (h *joinedPerformerTagsHandler) handle(ctx context.Context, f *filterBuilde
}

if len(criterion.Value) > 0 {
valuesClause, err := getHierarchicalValues(ctx, criterion.Value, tagTable, "tags_relations", "", "", criterion.Depth)
valuesClause, err := getHierarchicalValues(ctx, criterion.Value, tagTable, "tags_relations", "", "", criterion.Depth, false)
if err != nil {
f.setError(err)
return
Expand All @@ -977,13 +980,13 @@ INNER JOIN (`+valuesClause+`) t ON t.column2 = pt.tag_id
}

if len(criterion.Excludes) > 0 {
valuesClause, err := getHierarchicalValues(ctx, criterion.Excludes, tagTable, "tags_relations", "", "", criterion.Depth)
valuesClause, err := getHierarchicalValues(ctx, criterion.Excludes, tagTable, "tags_relations", "", "", criterion.Depth, true)
if err != nil {
f.setError(err)
return
}

clause := utils.StrFormat("{primaryTable}.id NOT IN (SELECT {joinTable}.{joinPrimaryKey} FROM {joinTable} INNER JOIN performers_tags ON {joinTable}.performer_id = performers_tags.performer_id WHERE performers_tags.tag_id IN (SELECT column2 FROM (%s)))", strFormatMap)
clause := utils.StrFormat("{primaryTable}.id NOT IN (SELECT {joinTable}.{joinPrimaryKey} FROM {joinTable} INNER JOIN performers_tags ON {joinTable}.performer_id = performers_tags.performer_id WHERE performers_tags.tag_id IN (SELECT column2 FROM %s))", strFormatMap)
f.addWhere(fmt.Sprintf(clause, valuesClause))
}
}
Expand Down
17 changes: 17 additions & 0 deletions pkg/sqlite/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,23 @@ func getDBBoolean(val bool) string {
}
}

func getDBValuesFix() (val string) {
if dbWrapper.dbType == PostgresBackend {
val = " AS v(column1, column2)"
}

return val
}

func getDBMinFunc() string {
switch dbWrapper.dbType {
case PostgresBackend:
return "LEAST"
default:
return "MIN"
}
}

func (db *Database) SetSchemaVersion(version uint) {
db.schemaVersion = version
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/sqlite/gallery_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ func (qb *galleryFilterHandler) averageResolutionCriterionHandler(resolution *mo
min := resolution.Value.GetMinResolution()
max := resolution.Value.GetMaxResolution()

const widthHeight = "avg(MIN(image_files.width, image_files.height))"
var widthHeight = "avg(" + getDBMinFunc() + "(image_files.width, image_files.height))"

switch resolution.Modifier {
case models.CriterionModifierEquals:
Expand Down
4 changes: 2 additions & 2 deletions pkg/sqlite/performer_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -456,12 +456,12 @@ func (qb *performerFilterHandler) studiosCriterionHandler(studios *models.Hierar
}

const derivedPerformerStudioTable = "performer_studio"
valuesClause, err := getHierarchicalValues(ctx, studios.Value, studioTable, "", "parent_id", "child_id", studios.Depth)
valuesClause, err := getHierarchicalValues(ctx, studios.Value, studioTable, "", "parent_id", "child_id", studios.Depth, true)
if err != nil {
f.setError(err)
return
}
f.addWith("studio(root_id, item_id) AS (" + valuesClause + ")")
f.addWith("studio(root_id, item_id) AS " + valuesClause)

templStr := `SELECT performer_id FROM {primaryTable}
INNER JOIN {joinTable} ON {primaryTable}.id = {joinTable}.{primaryFK}
Expand Down
12 changes: 6 additions & 6 deletions pkg/sqlite/performer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ func Test_PerformerStore_Create(t *testing.T) {
favorite = true
endpoint1 = "endpoint1"
endpoint2 = "endpoint2"
stashID1 = "stashid1"
stashID2 = "stashid2"
stashID1 = getUUID("stashid1")
stashID2 = getUUID("stashid2")
createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)
updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)

Expand Down Expand Up @@ -217,8 +217,8 @@ func Test_PerformerStore_Update(t *testing.T) {
favorite = true
endpoint1 = "endpoint1"
endpoint2 = "endpoint2"
stashID1 = "stashid1"
stashID2 = "stashid2"
stashID1 = getUUID("stashid1")
stashID2 = getUUID("stashid2")
createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)
updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)

Expand Down Expand Up @@ -398,8 +398,8 @@ func Test_PerformerStore_UpdatePartial(t *testing.T) {
favorite = true
endpoint1 = "endpoint1"
endpoint2 = "endpoint2"
stashID1 = "stashid1"
stashID2 = "stashid2"
stashID1 = getUUID("stashid1")
stashID2 = getUUID("stashid2")
createdAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)
updatedAt = time.Date(2001, 1, 1, 0, 0, 0, 0, time.UTC)

Expand Down
12 changes: 6 additions & 6 deletions pkg/sqlite/scene_marker_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,18 +95,18 @@ func (qb *sceneMarkerFilterHandler) tagsCriterionHandler(criterion *models.Hiera
}

if len(tags.Value) > 0 {
valuesClause, err := getHierarchicalValues(ctx, tags.Value, tagTable, "tags_relations", "parent_id", "child_id", tags.Depth)
valuesClause, err := getHierarchicalValues(ctx, tags.Value, tagTable, "tags_relations", "parent_id", "child_id", tags.Depth, true)
if err != nil {
f.setError(err)
return
}

f.addWith(`marker_tags AS (
SELECT mt.scene_marker_id, t.column1 AS root_tag_id FROM scene_markers_tags mt
INNER JOIN (` + valuesClause + `) t ON t.column2 = mt.tag_id
INNER JOIN ` + valuesClause + ` t ON t.column2 = mt.tag_id
UNION
SELECT m.id, t.column1 FROM scene_markers m
INNER JOIN (` + valuesClause + `) t ON t.column2 = m.primary_tag_id
INNER JOIN ` + valuesClause + ` t ON t.column2 = m.primary_tag_id
)`)

f.addLeftJoin("marker_tags", "", "marker_tags.scene_marker_id = scene_markers.id")
Expand All @@ -127,16 +127,16 @@ func (qb *sceneMarkerFilterHandler) tagsCriterionHandler(criterion *models.Hiera
}

if len(criterion.Excludes) > 0 {
valuesClause, err := getHierarchicalValues(ctx, tags.Excludes, tagTable, "tags_relations", "parent_id", "child_id", tags.Depth)
valuesClause, err := getHierarchicalValues(ctx, tags.Excludes, tagTable, "tags_relations", "parent_id", "child_id", tags.Depth, true)
if err != nil {
f.setError(err)
return
}

clause := "scene_markers.id NOT IN (SELECT scene_markers_tags.scene_marker_id FROM scene_markers_tags WHERE scene_markers_tags.tag_id IN (SELECT column2 FROM (%s)))"
clause := "scene_markers.id NOT IN (SELECT scene_markers_tags.scene_marker_id FROM scene_markers_tags WHERE scene_markers_tags.tag_id IN (SELECT column2 FROM %s))"
f.addWhere(fmt.Sprintf(clause, valuesClause))

f.addWhere(fmt.Sprintf("scene_markers.primary_tag_id NOT IN (SELECT column2 FROM (%s))", valuesClause))
f.addWhere(fmt.Sprintf("scene_markers.primary_tag_id NOT IN (SELECT column2 FROM %s)", valuesClause))
}
}
}
Expand Down
16 changes: 8 additions & 8 deletions pkg/sqlite/scene_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ func Test_sceneQueryBuilder_Create(t *testing.T) {
sceneIndex2 = 234
endpoint1 = "endpoint1"
endpoint2 = "endpoint2"
stashID1 = "stashid1"
stashID2 = "stashid2"
stashID1 = getUUID("stashid1")
stashID2 = getUUID("stashid2")

date, _ = models.ParseDate("2003-02-01")

Expand Down Expand Up @@ -321,8 +321,8 @@ func Test_sceneQueryBuilder_Update(t *testing.T) {
sceneIndex2 = 234
endpoint1 = "endpoint1"
endpoint2 = "endpoint2"
stashID1 = "stashid1"
stashID2 = "stashid2"
stashID1 = getUUID("stashid1")
stashID2 = getUUID("stashid2")

date, _ = models.ParseDate("2003-02-01")
)
Expand Down Expand Up @@ -531,8 +531,8 @@ func Test_sceneQueryBuilder_UpdatePartial(t *testing.T) {
sceneIndex2 = 234
endpoint1 = "endpoint1"
endpoint2 = "endpoint2"
stashID1 = "stashid1"
stashID2 = "stashid2"
stashID1 = getUUID("stashid1")
stashID2 = getUUID("stashid2")

date, _ = models.ParseDate("2003-02-01")
)
Expand Down Expand Up @@ -725,8 +725,8 @@ func Test_sceneQueryBuilder_UpdatePartialRelationships(t *testing.T) {
sceneIndex2 = 234
endpoint1 = "endpoint1"
endpoint2 = "endpoint2"
stashID1 = "stashid1"
stashID2 = "stashid2"
stashID1 = getUUID("stashid1")
stashID2 = getUUID("stashid2")

groupScenes = []models.GroupsScenes{
{
Expand Down

0 comments on commit aa4f257

Please sign in to comment.