diff --git a/ent.resolvers.go b/ent.resolvers.go index 582ea7cc..9a377f2a 100644 --- a/ent.resolvers.go +++ b/ent.resolvers.go @@ -20,6 +20,7 @@ package archivista import ( "context" + "entgo.io/contrib/entgql" "github.com/testifysec/archivista/ent" ) @@ -31,11 +32,11 @@ func (r *queryResolver) Nodes(ctx context.Context, ids []int) ([]ent.Noder, erro return r.client.Noders(ctx, ids) } -func (r *queryResolver) Dsses(ctx context.Context, after *ent.Cursor, first *int, before *ent.Cursor, last *int, where *ent.DsseWhereInput) (*ent.DsseConnection, error) { +func (r *queryResolver) Dsses(ctx context.Context, after *entgql.Cursor[int], first *int, before *entgql.Cursor[int], last *int, where *ent.DsseWhereInput) (*ent.DsseConnection, error) { return r.client.Dsse.Query().Paginate(ctx, after, first, before, last, ent.WithDsseFilter(where.Filter)) } -func (r *queryResolver) Subjects(ctx context.Context, after *ent.Cursor, first *int, before *ent.Cursor, last *int, where *ent.SubjectWhereInput) (*ent.SubjectConnection, error) { +func (r *queryResolver) Subjects(ctx context.Context, after *entgql.Cursor[int], first *int, before *entgql.Cursor[int], last *int, where *ent.SubjectWhereInput) (*ent.SubjectConnection, error) { return r.client.Subject.Query().Paginate(ctx, after, first, before, last, ent.WithSubjectFilter(where.Filter)) } diff --git a/ent/attestation.go b/ent/attestation.go index 4e20bc11..62ba8e50 100644 --- a/ent/attestation.go +++ b/ent/attestation.go @@ -6,6 +6,7 @@ import ( "fmt" "strings" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/testifysec/archivista/ent/attestation" "github.com/testifysec/archivista/ent/attestationcollection" @@ -22,6 +23,7 @@ type Attestation struct { // The values are being populated by the AttestationQuery when eager-loading is set. Edges AttestationEdges `json:"edges"` attestation_collection_attestations *int + selectValues sql.SelectValues } // AttestationEdges holds the relations/edges for other nodes in the graph. @@ -32,7 +34,7 @@ type AttestationEdges struct { // type was loaded (or requested) in eager-loading or not. loadedTypes [1]bool // totalCount holds the count of the edges above. - totalCount [1]*int + totalCount [1]map[string]int } // AttestationCollectionOrErr returns the AttestationCollection value or an error if the edge @@ -60,7 +62,7 @@ func (*Attestation) scanValues(columns []string) ([]any, error) { case attestation.ForeignKeys[0]: // attestation_collection_attestations values[i] = new(sql.NullInt64) default: - return nil, fmt.Errorf("unexpected column %q for type Attestation", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -93,21 +95,29 @@ func (a *Attestation) assignValues(columns []string, values []any) error { a.attestation_collection_attestations = new(int) *a.attestation_collection_attestations = int(value.Int64) } + default: + a.selectValues.Set(columns[i], values[i]) } } return nil } +// Value returns the ent.Value that was dynamically selected and assigned to the Attestation. +// This includes values selected through modifiers, order, etc. +func (a *Attestation) Value(name string) (ent.Value, error) { + return a.selectValues.Get(name) +} + // QueryAttestationCollection queries the "attestation_collection" edge of the Attestation entity. 
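// Usage sketch (illustrative only, with assumptions): the resolver signatures above now
// take *entgql.Cursor[int] instead of the old *ent.Cursor alias. A minimal caller might
// look like the function below; "client" is an assumed *ent.Client, and the assumed
// imports are "context", "entgo.io/contrib/entgql", and "github.com/testifysec/archivista/ent".
func firstTenDsses(ctx context.Context, client *ent.Client) (*ent.DsseConnection, error) {
	first := 10
	var after *entgql.Cursor[int] // a nil cursor starts from the beginning
	// Paginate accepts the same after/first/before/last arguments the resolver receives.
	return client.Dsse.Query().Paginate(ctx, after, &first, nil, nil)
}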
func (a *Attestation) QueryAttestationCollection() *AttestationCollectionQuery { - return (&AttestationClient{config: a.config}).QueryAttestationCollection(a) + return NewAttestationClient(a.config).QueryAttestationCollection(a) } // Update returns a builder for updating this Attestation. // Note that you need to call Attestation.Unwrap() before calling this method if this Attestation // was returned from a transaction, and the transaction was committed or rolled back. func (a *Attestation) Update() *AttestationUpdateOne { - return (&AttestationClient{config: a.config}).UpdateOne(a) + return NewAttestationClient(a.config).UpdateOne(a) } // Unwrap unwraps the Attestation entity that was returned from a transaction after it was closed, @@ -134,9 +144,3 @@ func (a *Attestation) String() string { // Attestations is a parsable slice of Attestation. type Attestations []*Attestation - -func (a Attestations) config(cfg config) { - for _i := range a { - a[_i].config = cfg - } -} diff --git a/ent/attestation/attestation.go b/ent/attestation/attestation.go index 1f77e187..ad3226ba 100644 --- a/ent/attestation/attestation.go +++ b/ent/attestation/attestation.go @@ -2,6 +2,11 @@ package attestation +import ( + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + const ( // Label holds the string label denoting the attestation type in the database. Label = "attestation" @@ -53,3 +58,30 @@ var ( // TypeValidator is a validator for the "type" field. It is called by the builders before save. TypeValidator func(string) error ) + +// OrderOption defines the ordering options for the Attestation queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByType orders the results by the type field. +func ByType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldType, opts...).ToFunc() +} + +// ByAttestationCollectionField orders the results by attestation_collection field. +func ByAttestationCollectionField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAttestationCollectionStep(), sql.OrderByField(field, opts...)) + } +} +func newAttestationCollectionStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AttestationCollectionInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, AttestationCollectionTable, AttestationCollectionColumn), + ) +} diff --git a/ent/attestation/where.go b/ent/attestation/where.go index 4dee743c..48e980a8 100644 --- a/ent/attestation/where.go +++ b/ent/attestation/where.go @@ -10,179 +10,117 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.Attestation { - return predicate.Attestation(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Attestation(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.Attestation { - return predicate.Attestation(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Attestation(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. 
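// Usage sketch (illustrative only, with assumptions): the typed OrderOption helpers added
// above replace the old OrderFunc values. Assumed imports: "context",
// "entgo.io/ent/dialect/sql", "github.com/testifysec/archivista/ent",
// "github.com/testifysec/archivista/ent/attestation", and
// "github.com/testifysec/archivista/ent/attestationcollection".
func attestationsByType(ctx context.Context, client *ent.Client) ([]*ent.Attestation, error) {
	return client.Attestation.Query().
		Order(
			// order by the "type" column, descending
			attestation.ByType(sql.OrderDesc()),
			// then by the owning collection's "name" column via the new edge helper
			attestation.ByAttestationCollectionField(attestationcollection.FieldName),
		).
		All(ctx)
}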
func IDNEQ(id int) predicate.Attestation { - return predicate.Attestation(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Attestation(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.Attestation { - return predicate.Attestation(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Attestation(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.Attestation { - return predicate.Attestation(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Attestation(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.Attestation { - return predicate.Attestation(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Attestation(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.Attestation { - return predicate.Attestation(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Attestation(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.Attestation { - return predicate.Attestation(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Attestation(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.Attestation { - return predicate.Attestation(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Attestation(sql.FieldLTE(FieldID, id)) } // Type applies equality check predicate on the "type" field. It's identical to TypeEQ. func Type(v string) predicate.Attestation { - return predicate.Attestation(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldType), v)) - }) + return predicate.Attestation(sql.FieldEQ(FieldType, v)) } // TypeEQ applies the EQ predicate on the "type" field. func TypeEQ(v string) predicate.Attestation { - return predicate.Attestation(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldType), v)) - }) + return predicate.Attestation(sql.FieldEQ(FieldType, v)) } // TypeNEQ applies the NEQ predicate on the "type" field. func TypeNEQ(v string) predicate.Attestation { - return predicate.Attestation(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldType), v)) - }) + return predicate.Attestation(sql.FieldNEQ(FieldType, v)) } // TypeIn applies the In predicate on the "type" field. func TypeIn(vs ...string) predicate.Attestation { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Attestation(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldType), v...)) - }) + return predicate.Attestation(sql.FieldIn(FieldType, vs...)) } // TypeNotIn applies the NotIn predicate on the "type" field. func TypeNotIn(vs ...string) predicate.Attestation { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Attestation(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldType), v...)) - }) + return predicate.Attestation(sql.FieldNotIn(FieldType, vs...)) } // TypeGT applies the GT predicate on the "type" field. 
func TypeGT(v string) predicate.Attestation { - return predicate.Attestation(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldType), v)) - }) + return predicate.Attestation(sql.FieldGT(FieldType, v)) } // TypeGTE applies the GTE predicate on the "type" field. func TypeGTE(v string) predicate.Attestation { - return predicate.Attestation(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldType), v)) - }) + return predicate.Attestation(sql.FieldGTE(FieldType, v)) } // TypeLT applies the LT predicate on the "type" field. func TypeLT(v string) predicate.Attestation { - return predicate.Attestation(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldType), v)) - }) + return predicate.Attestation(sql.FieldLT(FieldType, v)) } // TypeLTE applies the LTE predicate on the "type" field. func TypeLTE(v string) predicate.Attestation { - return predicate.Attestation(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldType), v)) - }) + return predicate.Attestation(sql.FieldLTE(FieldType, v)) } // TypeContains applies the Contains predicate on the "type" field. func TypeContains(v string) predicate.Attestation { - return predicate.Attestation(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldType), v)) - }) + return predicate.Attestation(sql.FieldContains(FieldType, v)) } // TypeHasPrefix applies the HasPrefix predicate on the "type" field. func TypeHasPrefix(v string) predicate.Attestation { - return predicate.Attestation(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldType), v)) - }) + return predicate.Attestation(sql.FieldHasPrefix(FieldType, v)) } // TypeHasSuffix applies the HasSuffix predicate on the "type" field. func TypeHasSuffix(v string) predicate.Attestation { - return predicate.Attestation(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldType), v)) - }) + return predicate.Attestation(sql.FieldHasSuffix(FieldType, v)) } // TypeEqualFold applies the EqualFold predicate on the "type" field. func TypeEqualFold(v string) predicate.Attestation { - return predicate.Attestation(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldType), v)) - }) + return predicate.Attestation(sql.FieldEqualFold(FieldType, v)) } // TypeContainsFold applies the ContainsFold predicate on the "type" field. func TypeContainsFold(v string) predicate.Attestation { - return predicate.Attestation(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldType), v)) - }) + return predicate.Attestation(sql.FieldContainsFold(FieldType, v)) } // HasAttestationCollection applies the HasEdge predicate on the "attestation_collection" edge. @@ -190,7 +128,6 @@ func HasAttestationCollection() predicate.Attestation { return predicate.Attestation(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(AttestationCollectionTable, FieldID), sqlgraph.Edge(sqlgraph.M2O, true, AttestationCollectionTable, AttestationCollectionColumn), ) sqlgraph.HasNeighbors(s, step) @@ -200,11 +137,7 @@ func HasAttestationCollection() predicate.Attestation { // HasAttestationCollectionWith applies the HasEdge predicate on the "attestation_collection" edge with a given conditions (other predicates). 
func HasAttestationCollectionWith(preds ...predicate.AttestationCollection) predicate.Attestation { return predicate.Attestation(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(AttestationCollectionInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.M2O, true, AttestationCollectionTable, AttestationCollectionColumn), - ) + step := newAttestationCollectionStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -215,32 +148,15 @@ func HasAttestationCollectionWith(preds ...predicate.AttestationCollection) pred // And groups predicates with the AND operator between them. func And(predicates ...predicate.Attestation) predicate.Attestation { - return predicate.Attestation(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Attestation(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. func Or(predicates ...predicate.Attestation) predicate.Attestation { - return predicate.Attestation(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Attestation(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.Attestation) predicate.Attestation { - return predicate.Attestation(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.Attestation(sql.NotPredicates(p)) } diff --git a/ent/attestation_create.go b/ent/attestation_create.go index 7f11223c..5926e323 100644 --- a/ent/attestation_create.go +++ b/ent/attestation_create.go @@ -44,49 +44,7 @@ func (ac *AttestationCreate) Mutation() *AttestationMutation { // Save creates the Attestation in the database. func (ac *AttestationCreate) Save(ctx context.Context) (*Attestation, error) { - var ( - err error - node *Attestation - ) - if len(ac.hooks) == 0 { - if err = ac.check(); err != nil { - return nil, err - } - node, err = ac.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*AttestationMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = ac.check(); err != nil { - return nil, err - } - ac.mutation = mutation - if node, err = ac.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(ac.hooks) - 1; i >= 0; i-- { - if ac.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = ac.hooks[i](mut) - } - v, err := mut.Mutate(ctx, ac.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Attestation) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from AttestationMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, ac.sqlSave, ac.mutation, ac.hooks) } // SaveX calls Save and panics if Save returns an error. 
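// Behavior sketch (illustrative only, with assumptions): Save and Exec now funnel through
// the generic withHooks helper instead of the per-builder hook loop, but hooks registered
// on the client should keep working unchanged. Assumes an initialized *ent.Client named
// "client" and the generated "github.com/testifysec/archivista/ent/hook" package, which is
// not shown in this diff.
func registerAttestationHook(client *ent.Client) {
	client.Attestation.Use(func(next ent.Mutator) ent.Mutator {
		return hook.AttestationFunc(func(ctx context.Context, m *ent.AttestationMutation) (ent.Value, error) {
			// runs before the SQL statement, as it did before the refactor
			return next.Mutate(ctx, m)
		})
	})
}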
@@ -128,6 +86,9 @@ func (ac *AttestationCreate) check() error { } func (ac *AttestationCreate) sqlSave(ctx context.Context) (*Attestation, error) { + if err := ac.check(); err != nil { + return nil, err + } _node, _spec := ac.createSpec() if err := sqlgraph.CreateNode(ctx, ac.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -137,26 +98,18 @@ func (ac *AttestationCreate) sqlSave(ctx context.Context) (*Attestation, error) } id := _spec.ID.Value.(int64) _node.ID = int(id) + ac.mutation.id = &_node.ID + ac.mutation.done = true return _node, nil } func (ac *AttestationCreate) createSpec() (*Attestation, *sqlgraph.CreateSpec) { var ( _node = &Attestation{config: ac.config} - _spec = &sqlgraph.CreateSpec{ - Table: attestation.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: attestation.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(attestation.Table, sqlgraph.NewFieldSpec(attestation.FieldID, field.TypeInt)) ) if value, ok := ac.mutation.GetType(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: attestation.FieldType, - }) + _spec.SetField(attestation.FieldType, field.TypeString, value) _node.Type = value } if nodes := ac.mutation.AttestationCollectionIDs(); len(nodes) > 0 { @@ -167,10 +120,7 @@ func (ac *AttestationCreate) createSpec() (*Attestation, *sqlgraph.CreateSpec) { Columns: []string{attestation.AttestationCollectionColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: attestationcollection.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(attestationcollection.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -185,11 +135,15 @@ func (ac *AttestationCreate) createSpec() (*Attestation, *sqlgraph.CreateSpec) { // AttestationCreateBulk is the builder for creating many Attestation entities in bulk. type AttestationCreateBulk struct { config + err error builders []*AttestationCreate } // Save creates the Attestation entities in the database. func (acb *AttestationCreateBulk) Save(ctx context.Context) ([]*Attestation, error) { + if acb.err != nil { + return nil, acb.err + } specs := make([]*sqlgraph.CreateSpec, len(acb.builders)) nodes := make([]*Attestation, len(acb.builders)) mutators := make([]Mutator, len(acb.builders)) @@ -205,8 +159,8 @@ func (acb *AttestationCreateBulk) Save(ctx context.Context) ([]*Attestation, err return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, acb.builders[i+1].mutation) } else { diff --git a/ent/attestation_delete.go b/ent/attestation_delete.go index 392287bb..e2024b15 100644 --- a/ent/attestation_delete.go +++ b/ent/attestation_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (ad *AttestationDelete) Where(ps ...predicate.Attestation) *AttestationDele // Exec executes the deletion query and returns how many vertices were deleted. 
func (ad *AttestationDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(ad.hooks) == 0 { - affected, err = ad.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*AttestationMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - ad.mutation = mutation - affected, err = ad.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(ad.hooks) - 1; i >= 0; i-- { - if ad.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = ad.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, ad.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, ad.sqlExec, ad.mutation, ad.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (ad *AttestationDelete) ExecX(ctx context.Context) int { } func (ad *AttestationDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: attestation.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: attestation.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(attestation.Table, sqlgraph.NewFieldSpec(attestation.FieldID, field.TypeInt)) if ps := ad.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (ad *AttestationDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + ad.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type AttestationDeleteOne struct { ad *AttestationDelete } +// Where appends a list predicates to the AttestationDelete builder. +func (ado *AttestationDeleteOne) Where(ps ...predicate.Attestation) *AttestationDeleteOne { + ado.ad.mutation.Where(ps...) + return ado +} + // Exec executes the deletion query. func (ado *AttestationDeleteOne) Exec(ctx context.Context) error { n, err := ado.ad.Exec(ctx) @@ -111,5 +82,7 @@ func (ado *AttestationDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (ado *AttestationDeleteOne) ExecX(ctx context.Context) { - ado.ad.ExecX(ctx) + if err := ado.Exec(ctx); err != nil { + panic(err) + } } diff --git a/ent/attestation_query.go b/ent/attestation_query.go index 406aeff6..47edcef7 100644 --- a/ent/attestation_query.go +++ b/ent/attestation_query.go @@ -18,11 +18,9 @@ import ( // AttestationQuery is the builder for querying Attestation entities. type AttestationQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string + ctx *QueryContext + order []attestation.OrderOption + inters []Interceptor predicates []predicate.Attestation withAttestationCollection *AttestationCollectionQuery withFKs bool @@ -39,34 +37,34 @@ func (aq *AttestationQuery) Where(ps ...predicate.Attestation) *AttestationQuery return aq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (aq *AttestationQuery) Limit(limit int) *AttestationQuery { - aq.limit = &limit + aq.ctx.Limit = &limit return aq } -// Offset adds an offset step to the query. +// Offset to start from. 
func (aq *AttestationQuery) Offset(offset int) *AttestationQuery { - aq.offset = &offset + aq.ctx.Offset = &offset return aq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (aq *AttestationQuery) Unique(unique bool) *AttestationQuery { - aq.unique = &unique + aq.ctx.Unique = &unique return aq } -// Order adds an order step to the query. -func (aq *AttestationQuery) Order(o ...OrderFunc) *AttestationQuery { +// Order specifies how the records should be ordered. +func (aq *AttestationQuery) Order(o ...attestation.OrderOption) *AttestationQuery { aq.order = append(aq.order, o...) return aq } // QueryAttestationCollection chains the current query on the "attestation_collection" edge. func (aq *AttestationQuery) QueryAttestationCollection() *AttestationCollectionQuery { - query := &AttestationCollectionQuery{config: aq.config} + query := (&AttestationCollectionClient{config: aq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := aq.prepareQuery(ctx); err != nil { return nil, err @@ -89,7 +87,7 @@ func (aq *AttestationQuery) QueryAttestationCollection() *AttestationCollectionQ // First returns the first Attestation entity from the query. // Returns a *NotFoundError when no Attestation was found. func (aq *AttestationQuery) First(ctx context.Context) (*Attestation, error) { - nodes, err := aq.Limit(1).All(ctx) + nodes, err := aq.Limit(1).All(setContextOp(ctx, aq.ctx, "First")) if err != nil { return nil, err } @@ -112,7 +110,7 @@ func (aq *AttestationQuery) FirstX(ctx context.Context) *Attestation { // Returns a *NotFoundError when no Attestation ID was found. func (aq *AttestationQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = aq.Limit(1).IDs(ctx); err != nil { + if ids, err = aq.Limit(1).IDs(setContextOp(ctx, aq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -135,7 +133,7 @@ func (aq *AttestationQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Attestation entity is found. // Returns a *NotFoundError when no Attestation entities are found. func (aq *AttestationQuery) Only(ctx context.Context) (*Attestation, error) { - nodes, err := aq.Limit(2).All(ctx) + nodes, err := aq.Limit(2).All(setContextOp(ctx, aq.ctx, "Only")) if err != nil { return nil, err } @@ -163,7 +161,7 @@ func (aq *AttestationQuery) OnlyX(ctx context.Context) *Attestation { // Returns a *NotFoundError when no entities are found. func (aq *AttestationQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = aq.Limit(2).IDs(ctx); err != nil { + if ids, err = aq.Limit(2).IDs(setContextOp(ctx, aq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -188,10 +186,12 @@ func (aq *AttestationQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Attestations. func (aq *AttestationQuery) All(ctx context.Context) ([]*Attestation, error) { + ctx = setContextOp(ctx, aq.ctx, "All") if err := aq.prepareQuery(ctx); err != nil { return nil, err } - return aq.sqlAll(ctx) + qr := querierAll[[]*Attestation, *AttestationQuery]() + return withInterceptors[[]*Attestation](ctx, aq, qr, aq.inters) } // AllX is like All, but panics if an error occurs. @@ -204,9 +204,12 @@ func (aq *AttestationQuery) AllX(ctx context.Context) []*Attestation { } // IDs executes the query and returns a list of Attestation IDs. 
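// Usage sketch (illustrative only, with assumptions): Limit, Offset, and Unique now write
// to the query's internal QueryContext (aq.ctx) instead of fields on the builder, but the
// public builder API is unchanged for callers. Assumed imports: "context",
// "github.com/testifysec/archivista/ent", "github.com/testifysec/archivista/ent/attestation".
func pageOfAttestations(ctx context.Context, client *ent.Client, page, size int) ([]*ent.Attestation, error) {
	return client.Attestation.Query().
		Order(attestation.ByID()). // stable ordering for pagination
		Offset(page * size).
		Limit(size).
		All(ctx)
}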
-func (aq *AttestationQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := aq.Select(attestation.FieldID).Scan(ctx, &ids); err != nil { +func (aq *AttestationQuery) IDs(ctx context.Context) (ids []int, err error) { + if aq.ctx.Unique == nil && aq.path != nil { + aq.Unique(true) + } + ctx = setContextOp(ctx, aq.ctx, "IDs") + if err = aq.Select(attestation.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -223,10 +226,11 @@ func (aq *AttestationQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (aq *AttestationQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, aq.ctx, "Count") if err := aq.prepareQuery(ctx); err != nil { return 0, err } - return aq.sqlCount(ctx) + return withInterceptors[int](ctx, aq, querierCount[*AttestationQuery](), aq.inters) } // CountX is like Count, but panics if an error occurs. @@ -240,10 +244,15 @@ func (aq *AttestationQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (aq *AttestationQuery) Exist(ctx context.Context) (bool, error) { - if err := aq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, aq.ctx, "Exist") + switch _, err := aq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return aq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -263,22 +272,21 @@ func (aq *AttestationQuery) Clone() *AttestationQuery { } return &AttestationQuery{ config: aq.config, - limit: aq.limit, - offset: aq.offset, - order: append([]OrderFunc{}, aq.order...), + ctx: aq.ctx.Clone(), + order: append([]attestation.OrderOption{}, aq.order...), + inters: append([]Interceptor{}, aq.inters...), predicates: append([]predicate.Attestation{}, aq.predicates...), withAttestationCollection: aq.withAttestationCollection.Clone(), // clone intermediate query. - sql: aq.sql.Clone(), - path: aq.path, - unique: aq.unique, + sql: aq.sql.Clone(), + path: aq.path, } } // WithAttestationCollection tells the query-builder to eager-load the nodes that are connected to // the "attestation_collection" edge. The optional arguments are used to configure the query builder of the edge. func (aq *AttestationQuery) WithAttestationCollection(opts ...func(*AttestationCollectionQuery)) *AttestationQuery { - query := &AttestationCollectionQuery{config: aq.config} + query := (&AttestationCollectionClient{config: aq.config}).Query() for _, opt := range opts { opt(query) } @@ -301,16 +309,11 @@ func (aq *AttestationQuery) WithAttestationCollection(opts ...func(*AttestationC // Aggregate(ent.Count()). // Scan(ctx, &v) func (aq *AttestationQuery) GroupBy(field string, fields ...string) *AttestationGroupBy { - grbuild := &AttestationGroupBy{config: aq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := aq.prepareQuery(ctx); err != nil { - return nil, err - } - return aq.sqlQuery(ctx), nil - } + aq.ctx.Fields = append([]string{field}, fields...) 
+ grbuild := &AttestationGroupBy{build: aq} + grbuild.flds = &aq.ctx.Fields grbuild.label = attestation.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -327,15 +330,30 @@ func (aq *AttestationQuery) GroupBy(field string, fields ...string) *Attestation // Select(attestation.FieldType). // Scan(ctx, &v) func (aq *AttestationQuery) Select(fields ...string) *AttestationSelect { - aq.fields = append(aq.fields, fields...) - selbuild := &AttestationSelect{AttestationQuery: aq} - selbuild.label = attestation.Label - selbuild.flds, selbuild.scan = &aq.fields, selbuild.Scan - return selbuild + aq.ctx.Fields = append(aq.ctx.Fields, fields...) + sbuild := &AttestationSelect{AttestationQuery: aq} + sbuild.label = attestation.Label + sbuild.flds, sbuild.scan = &aq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a AttestationSelect configured with the given aggregations. +func (aq *AttestationQuery) Aggregate(fns ...AggregateFunc) *AttestationSelect { + return aq.Select().Aggregate(fns...) } func (aq *AttestationQuery) prepareQuery(ctx context.Context) error { - for _, f := range aq.fields { + for _, inter := range aq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, aq); err != nil { + return err + } + } + } + for _, f := range aq.ctx.Fields { if !attestation.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -413,6 +431,9 @@ func (aq *AttestationQuery) loadAttestationCollection(ctx context.Context, query } nodeids[fk] = append(nodeids[fk], nodes[i]) } + if len(ids) == 0 { + return nil + } query.Where(attestationcollection.IDIn(ids...)) neighbors, err := query.All(ctx) if err != nil { @@ -435,41 +456,22 @@ func (aq *AttestationQuery) sqlCount(ctx context.Context) (int, error) { if len(aq.modifiers) > 0 { _spec.Modifiers = aq.modifiers } - _spec.Node.Columns = aq.fields - if len(aq.fields) > 0 { - _spec.Unique = aq.unique != nil && *aq.unique + _spec.Node.Columns = aq.ctx.Fields + if len(aq.ctx.Fields) > 0 { + _spec.Unique = aq.ctx.Unique != nil && *aq.ctx.Unique } return sqlgraph.CountNodes(ctx, aq.driver, _spec) } -func (aq *AttestationQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := aq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (aq *AttestationQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: attestation.Table, - Columns: attestation.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: attestation.FieldID, - }, - }, - From: aq.sql, - Unique: true, - } - if unique := aq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(attestation.Table, attestation.Columns, sqlgraph.NewFieldSpec(attestation.FieldID, field.TypeInt)) + _spec.From = aq.sql + if unique := aq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if aq.path != nil { + _spec.Unique = true } - if fields := aq.fields; len(fields) > 0 { + if fields := aq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, attestation.FieldID) for i := range fields { @@ -485,10 +487,10 @@ func (aq *AttestationQuery) querySpec() *sqlgraph.QuerySpec { } } } - 
if limit := aq.limit; limit != nil { + if limit := aq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := aq.offset; offset != nil { + if offset := aq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := aq.order; len(ps) > 0 { @@ -504,7 +506,7 @@ func (aq *AttestationQuery) querySpec() *sqlgraph.QuerySpec { func (aq *AttestationQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(aq.driver.Dialect()) t1 := builder.Table(attestation.Table) - columns := aq.fields + columns := aq.ctx.Fields if len(columns) == 0 { columns = attestation.Columns } @@ -513,7 +515,7 @@ func (aq *AttestationQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = aq.sql selector.Select(selector.Columns(columns...)...) } - if aq.unique != nil && *aq.unique { + if aq.ctx.Unique != nil && *aq.ctx.Unique { selector.Distinct() } for _, p := range aq.predicates { @@ -522,12 +524,12 @@ func (aq *AttestationQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range aq.order { p(selector) } - if offset := aq.offset; offset != nil { + if offset := aq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := aq.limit; limit != nil { + if limit := aq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -535,13 +537,8 @@ func (aq *AttestationQuery) sqlQuery(ctx context.Context) *sql.Selector { // AttestationGroupBy is the group-by builder for Attestation entities. type AttestationGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *AttestationQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -550,74 +547,77 @@ func (agb *AttestationGroupBy) Aggregate(fns ...AggregateFunc) *AttestationGroup return agb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. func (agb *AttestationGroupBy) Scan(ctx context.Context, v any) error { - query, err := agb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, agb.build.ctx, "GroupBy") + if err := agb.build.prepareQuery(ctx); err != nil { return err } - agb.sql = query - return agb.sqlScan(ctx, v) + return scanWithInterceptors[*AttestationQuery, *AttestationGroupBy](ctx, agb.build, agb, agb.build.inters, v) } -func (agb *AttestationGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range agb.fields { - if !attestation.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (agb *AttestationGroupBy) sqlScan(ctx context.Context, root *AttestationQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(agb.fns)) + for _, fn := range agb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*agb.flds)+len(agb.fns)) + for _, f := range *agb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := agb.sqlQuery() + selector.GroupBy(selector.Columns(*agb.flds...)...) 
if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := agb.driver.Query(ctx, query, args, rows); err != nil { + if err := agb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (agb *AttestationGroupBy) sqlQuery() *sql.Selector { - selector := agb.sql.Select() - aggregation := make([]string, 0, len(agb.fns)) - for _, fn := range agb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(agb.fields)+len(agb.fns)) - for _, f := range agb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(agb.fields...)...) -} - // AttestationSelect is the builder for selecting fields of Attestation entities. type AttestationSelect struct { *AttestationQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (as *AttestationSelect) Aggregate(fns ...AggregateFunc) *AttestationSelect { + as.fns = append(as.fns, fns...) + return as } // Scan applies the selector query and scans the result into the given value. func (as *AttestationSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, as.ctx, "Select") if err := as.prepareQuery(ctx); err != nil { return err } - as.sql = as.AttestationQuery.sqlQuery(ctx) - return as.sqlScan(ctx, v) + return scanWithInterceptors[*AttestationQuery, *AttestationSelect](ctx, as.AttestationQuery, as, as.inters, v) } -func (as *AttestationSelect) sqlScan(ctx context.Context, v any) error { +func (as *AttestationSelect) sqlScan(ctx context.Context, root *AttestationQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(as.fns)) + for _, fn := range as.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*as.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := as.sql.Query() + query, args := selector.Query() if err := as.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/ent/attestation_update.go b/ent/attestation_update.go index 82c45cd2..7f6b1d49 100644 --- a/ent/attestation_update.go +++ b/ent/attestation_update.go @@ -58,40 +58,7 @@ func (au *AttestationUpdate) ClearAttestationCollection() *AttestationUpdate { // Save executes the query and returns the number of nodes affected by the update operation. 
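// Usage sketch (illustrative only, with assumptions): GroupBy and Select now scan through
// the interceptor chain, and AttestationSelect gains an Aggregate method. The struct tags
// and the "count" column name follow the usual ent aggregation convention and are
// assumptions here. Assumed imports: "context", "github.com/testifysec/archivista/ent",
// "github.com/testifysec/archivista/ent/attestation".
func countPerType(ctx context.Context, client *ent.Client) error {
	var rows []struct {
		Type  string `json:"type"`
		Count int    `json:"count"`
	}
	// counts grouped by the "type" column
	if err := client.Attestation.Query().
		GroupBy(attestation.FieldType).
		Aggregate(ent.Count()).
		Scan(ctx, &rows); err != nil {
		return err
	}
	// total count via the new Select.Aggregate
	_, err := client.Attestation.Query().Aggregate(ent.Count()).Int(ctx)
	return err
}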
func (au *AttestationUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(au.hooks) == 0 { - if err = au.check(); err != nil { - return 0, err - } - affected, err = au.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*AttestationMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = au.check(); err != nil { - return 0, err - } - au.mutation = mutation - affected, err = au.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(au.hooks) - 1; i >= 0; i-- { - if au.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = au.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, au.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, au.sqlSave, au.mutation, au.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -130,16 +97,10 @@ func (au *AttestationUpdate) check() error { } func (au *AttestationUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: attestation.Table, - Columns: attestation.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: attestation.FieldID, - }, - }, + if err := au.check(); err != nil { + return n, err } + _spec := sqlgraph.NewUpdateSpec(attestation.Table, attestation.Columns, sqlgraph.NewFieldSpec(attestation.FieldID, field.TypeInt)) if ps := au.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -148,11 +109,7 @@ func (au *AttestationUpdate) sqlSave(ctx context.Context) (n int, err error) { } } if value, ok := au.mutation.GetType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: attestation.FieldType, - }) + _spec.SetField(attestation.FieldType, field.TypeString, value) } if au.mutation.AttestationCollectionCleared() { edge := &sqlgraph.EdgeSpec{ @@ -162,10 +119,7 @@ func (au *AttestationUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{attestation.AttestationCollectionColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: attestationcollection.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(attestationcollection.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -178,10 +132,7 @@ func (au *AttestationUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{attestation.AttestationCollectionColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: attestationcollection.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(attestationcollection.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -197,6 +148,7 @@ func (au *AttestationUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + au.mutation.done = true return n, nil } @@ -236,6 +188,12 @@ func (auo *AttestationUpdateOne) ClearAttestationCollection() *AttestationUpdate return auo } +// Where appends a list predicates to the AttestationUpdate builder. +func (auo *AttestationUpdateOne) Where(ps ...predicate.Attestation) *AttestationUpdateOne { + auo.mutation.Where(ps...) + return auo +} + // Select allows selecting one or more fields (columns) of the returned entity. 
// The default is selecting all fields defined in the entity schema. func (auo *AttestationUpdateOne) Select(field string, fields ...string) *AttestationUpdateOne { @@ -245,46 +203,7 @@ func (auo *AttestationUpdateOne) Select(field string, fields ...string) *Attesta // Save executes the query and returns the updated Attestation entity. func (auo *AttestationUpdateOne) Save(ctx context.Context) (*Attestation, error) { - var ( - err error - node *Attestation - ) - if len(auo.hooks) == 0 { - if err = auo.check(); err != nil { - return nil, err - } - node, err = auo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*AttestationMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = auo.check(); err != nil { - return nil, err - } - auo.mutation = mutation - node, err = auo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(auo.hooks) - 1; i >= 0; i-- { - if auo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = auo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, auo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Attestation) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from AttestationMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, auo.sqlSave, auo.mutation, auo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -323,16 +242,10 @@ func (auo *AttestationUpdateOne) check() error { } func (auo *AttestationUpdateOne) sqlSave(ctx context.Context) (_node *Attestation, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: attestation.Table, - Columns: attestation.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: attestation.FieldID, - }, - }, + if err := auo.check(); err != nil { + return _node, err } + _spec := sqlgraph.NewUpdateSpec(attestation.Table, attestation.Columns, sqlgraph.NewFieldSpec(attestation.FieldID, field.TypeInt)) id, ok := auo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Attestation.id" for update`)} @@ -358,11 +271,7 @@ func (auo *AttestationUpdateOne) sqlSave(ctx context.Context) (_node *Attestatio } } if value, ok := auo.mutation.GetType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: attestation.FieldType, - }) + _spec.SetField(attestation.FieldType, field.TypeString, value) } if auo.mutation.AttestationCollectionCleared() { edge := &sqlgraph.EdgeSpec{ @@ -372,10 +281,7 @@ func (auo *AttestationUpdateOne) sqlSave(ctx context.Context) (_node *Attestatio Columns: []string{attestation.AttestationCollectionColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: attestationcollection.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(attestationcollection.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -388,10 +294,7 @@ func (auo *AttestationUpdateOne) sqlSave(ctx context.Context) (_node *Attestatio Columns: []string{attestation.AttestationCollectionColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: attestationcollection.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(attestationcollection.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -410,5 +313,6 @@ func 
(auo *AttestationUpdateOne) sqlSave(ctx context.Context) (_node *Attestatio } return nil, err } + auo.mutation.done = true return _node, nil } diff --git a/ent/attestationcollection.go b/ent/attestationcollection.go index 489e954f..38ac47e7 100644 --- a/ent/attestationcollection.go +++ b/ent/attestationcollection.go @@ -6,6 +6,7 @@ import ( "fmt" "strings" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/testifysec/archivista/ent/attestationcollection" "github.com/testifysec/archivista/ent/statement" @@ -22,6 +23,7 @@ type AttestationCollection struct { // The values are being populated by the AttestationCollectionQuery when eager-loading is set. Edges AttestationCollectionEdges `json:"edges"` statement_attestation_collections *int + selectValues sql.SelectValues } // AttestationCollectionEdges holds the relations/edges for other nodes in the graph. @@ -34,7 +36,9 @@ type AttestationCollectionEdges struct { // type was loaded (or requested) in eager-loading or not. loadedTypes [2]bool // totalCount holds the count of the edges above. - totalCount [2]*int + totalCount [2]map[string]int + + namedAttestations map[string][]*Attestation } // AttestationsOrErr returns the Attestations value or an error if the edge @@ -71,7 +75,7 @@ func (*AttestationCollection) scanValues(columns []string) ([]any, error) { case attestationcollection.ForeignKeys[0]: // statement_attestation_collections values[i] = new(sql.NullInt64) default: - return nil, fmt.Errorf("unexpected column %q for type AttestationCollection", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -104,26 +108,34 @@ func (ac *AttestationCollection) assignValues(columns []string, values []any) er ac.statement_attestation_collections = new(int) *ac.statement_attestation_collections = int(value.Int64) } + default: + ac.selectValues.Set(columns[i], values[i]) } } return nil } +// Value returns the ent.Value that was dynamically selected and assigned to the AttestationCollection. +// This includes values selected through modifiers, order, etc. +func (ac *AttestationCollection) Value(name string) (ent.Value, error) { + return ac.selectValues.Get(name) +} + // QueryAttestations queries the "attestations" edge of the AttestationCollection entity. func (ac *AttestationCollection) QueryAttestations() *AttestationQuery { - return (&AttestationCollectionClient{config: ac.config}).QueryAttestations(ac) + return NewAttestationCollectionClient(ac.config).QueryAttestations(ac) } // QueryStatement queries the "statement" edge of the AttestationCollection entity. func (ac *AttestationCollection) QueryStatement() *StatementQuery { - return (&AttestationCollectionClient{config: ac.config}).QueryStatement(ac) + return NewAttestationCollectionClient(ac.config).QueryStatement(ac) } // Update returns a builder for updating this AttestationCollection. // Note that you need to call AttestationCollection.Unwrap() before calling this method if this AttestationCollection // was returned from a transaction, and the transaction was committed or rolled back. 
func (ac *AttestationCollection) Update() *AttestationCollectionUpdateOne { - return (&AttestationCollectionClient{config: ac.config}).UpdateOne(ac) + return NewAttestationCollectionClient(ac.config).UpdateOne(ac) } // Unwrap unwraps the AttestationCollection entity that was returned from a transaction after it was closed, @@ -148,11 +160,29 @@ func (ac *AttestationCollection) String() string { return builder.String() } -// AttestationCollections is a parsable slice of AttestationCollection. -type AttestationCollections []*AttestationCollection +// NamedAttestations returns the Attestations named value or an error if the edge was not +// loaded in eager-loading with this name. +func (ac *AttestationCollection) NamedAttestations(name string) ([]*Attestation, error) { + if ac.Edges.namedAttestations == nil { + return nil, &NotLoadedError{edge: name} + } + nodes, ok := ac.Edges.namedAttestations[name] + if !ok { + return nil, &NotLoadedError{edge: name} + } + return nodes, nil +} -func (ac AttestationCollections) config(cfg config) { - for _i := range ac { - ac[_i].config = cfg +func (ac *AttestationCollection) appendNamedAttestations(name string, edges ...*Attestation) { + if ac.Edges.namedAttestations == nil { + ac.Edges.namedAttestations = make(map[string][]*Attestation) + } + if len(edges) == 0 { + ac.Edges.namedAttestations[name] = []*Attestation{} + } else { + ac.Edges.namedAttestations[name] = append(ac.Edges.namedAttestations[name], edges...) } } + +// AttestationCollections is a parsable slice of AttestationCollection. +type AttestationCollections []*AttestationCollection diff --git a/ent/attestationcollection/attestationcollection.go b/ent/attestationcollection/attestationcollection.go index 4bebfb86..8b509f8b 100644 --- a/ent/attestationcollection/attestationcollection.go +++ b/ent/attestationcollection/attestationcollection.go @@ -2,6 +2,11 @@ package attestationcollection +import ( + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + const ( // Label holds the string label denoting the attestationcollection type in the database. Label = "attestation_collection" @@ -62,3 +67,51 @@ var ( // NameValidator is a validator for the "name" field. It is called by the builders before save. NameValidator func(string) error ) + +// OrderOption defines the ordering options for the AttestationCollection queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} + +// ByAttestationsCount orders the results by attestations count. +func ByAttestationsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newAttestationsStep(), opts...) + } +} + +// ByAttestations orders the results by attestations terms. +func ByAttestations(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAttestationsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByStatementField orders the results by statement field. 
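// Usage sketch (illustrative only, with assumptions): the edge-ordering helpers added here
// allow sorting collections by how many attestations they contain. WithAttestations is the
// standard generated eager-load helper and is assumed, since it is not shown in this diff.
// Assumed imports: "context", "entgo.io/ent/dialect/sql", "github.com/testifysec/archivista/ent",
// "github.com/testifysec/archivista/ent/attestationcollection".
func busiestCollections(ctx context.Context, client *ent.Client) ([]*ent.AttestationCollection, error) {
	return client.AttestationCollection.Query().
		Order(attestationcollection.ByAttestationsCount(sql.OrderDesc())).
		WithAttestations(). // eager-load the counted edge as well
		All(ctx)
}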
+func ByStatementField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newStatementStep(), sql.OrderByField(field, opts...)) + } +} +func newAttestationsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AttestationsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, AttestationsTable, AttestationsColumn), + ) +} +func newStatementStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(StatementInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2O, true, StatementTable, StatementColumn), + ) +} diff --git a/ent/attestationcollection/where.go b/ent/attestationcollection/where.go index b986c42a..d903ea97 100644 --- a/ent/attestationcollection/where.go +++ b/ent/attestationcollection/where.go @@ -10,179 +10,117 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.AttestationCollection { - return predicate.AttestationCollection(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.AttestationCollection(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.AttestationCollection { - return predicate.AttestationCollection(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.AttestationCollection(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.AttestationCollection { - return predicate.AttestationCollection(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.AttestationCollection(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.AttestationCollection { - return predicate.AttestationCollection(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.AttestationCollection(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.AttestationCollection { - return predicate.AttestationCollection(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.AttestationCollection(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.AttestationCollection { - return predicate.AttestationCollection(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.AttestationCollection(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.AttestationCollection { - return predicate.AttestationCollection(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.AttestationCollection(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.AttestationCollection { - return predicate.AttestationCollection(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.AttestationCollection(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. 
func IDLTE(id int) predicate.AttestationCollection { - return predicate.AttestationCollection(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.AttestationCollection(sql.FieldLTE(FieldID, id)) } // Name applies equality check predicate on the "name" field. It's identical to NameEQ. func Name(v string) predicate.AttestationCollection { - return predicate.AttestationCollection(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldName), v)) - }) + return predicate.AttestationCollection(sql.FieldEQ(FieldName, v)) } // NameEQ applies the EQ predicate on the "name" field. func NameEQ(v string) predicate.AttestationCollection { - return predicate.AttestationCollection(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldName), v)) - }) + return predicate.AttestationCollection(sql.FieldEQ(FieldName, v)) } // NameNEQ applies the NEQ predicate on the "name" field. func NameNEQ(v string) predicate.AttestationCollection { - return predicate.AttestationCollection(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldName), v)) - }) + return predicate.AttestationCollection(sql.FieldNEQ(FieldName, v)) } // NameIn applies the In predicate on the "name" field. func NameIn(vs ...string) predicate.AttestationCollection { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.AttestationCollection(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldName), v...)) - }) + return predicate.AttestationCollection(sql.FieldIn(FieldName, vs...)) } // NameNotIn applies the NotIn predicate on the "name" field. func NameNotIn(vs ...string) predicate.AttestationCollection { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.AttestationCollection(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldName), v...)) - }) + return predicate.AttestationCollection(sql.FieldNotIn(FieldName, vs...)) } // NameGT applies the GT predicate on the "name" field. func NameGT(v string) predicate.AttestationCollection { - return predicate.AttestationCollection(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldName), v)) - }) + return predicate.AttestationCollection(sql.FieldGT(FieldName, v)) } // NameGTE applies the GTE predicate on the "name" field. func NameGTE(v string) predicate.AttestationCollection { - return predicate.AttestationCollection(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldName), v)) - }) + return predicate.AttestationCollection(sql.FieldGTE(FieldName, v)) } // NameLT applies the LT predicate on the "name" field. func NameLT(v string) predicate.AttestationCollection { - return predicate.AttestationCollection(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldName), v)) - }) + return predicate.AttestationCollection(sql.FieldLT(FieldName, v)) } // NameLTE applies the LTE predicate on the "name" field. func NameLTE(v string) predicate.AttestationCollection { - return predicate.AttestationCollection(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldName), v)) - }) + return predicate.AttestationCollection(sql.FieldLTE(FieldName, v)) } // NameContains applies the Contains predicate on the "name" field. func NameContains(v string) predicate.AttestationCollection { - return predicate.AttestationCollection(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldName), v)) - }) + return predicate.AttestationCollection(sql.FieldContains(FieldName, v)) } // NameHasPrefix applies the HasPrefix predicate on the "name" field. 
func NameHasPrefix(v string) predicate.AttestationCollection { - return predicate.AttestationCollection(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldName), v)) - }) + return predicate.AttestationCollection(sql.FieldHasPrefix(FieldName, v)) } // NameHasSuffix applies the HasSuffix predicate on the "name" field. func NameHasSuffix(v string) predicate.AttestationCollection { - return predicate.AttestationCollection(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldName), v)) - }) + return predicate.AttestationCollection(sql.FieldHasSuffix(FieldName, v)) } // NameEqualFold applies the EqualFold predicate on the "name" field. func NameEqualFold(v string) predicate.AttestationCollection { - return predicate.AttestationCollection(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldName), v)) - }) + return predicate.AttestationCollection(sql.FieldEqualFold(FieldName, v)) } // NameContainsFold applies the ContainsFold predicate on the "name" field. func NameContainsFold(v string) predicate.AttestationCollection { - return predicate.AttestationCollection(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldName), v)) - }) + return predicate.AttestationCollection(sql.FieldContainsFold(FieldName, v)) } // HasAttestations applies the HasEdge predicate on the "attestations" edge. @@ -190,7 +128,6 @@ func HasAttestations() predicate.AttestationCollection { return predicate.AttestationCollection(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(AttestationsTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, AttestationsTable, AttestationsColumn), ) sqlgraph.HasNeighbors(s, step) @@ -200,11 +137,7 @@ func HasAttestations() predicate.AttestationCollection { // HasAttestationsWith applies the HasEdge predicate on the "attestations" edge with a given conditions (other predicates). func HasAttestationsWith(preds ...predicate.Attestation) predicate.AttestationCollection { return predicate.AttestationCollection(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(AttestationsInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.O2M, false, AttestationsTable, AttestationsColumn), - ) + step := newAttestationsStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -218,7 +151,6 @@ func HasStatement() predicate.AttestationCollection { return predicate.AttestationCollection(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(StatementTable, FieldID), sqlgraph.Edge(sqlgraph.O2O, true, StatementTable, StatementColumn), ) sqlgraph.HasNeighbors(s, step) @@ -228,11 +160,7 @@ func HasStatement() predicate.AttestationCollection { // HasStatementWith applies the HasEdge predicate on the "statement" edge with a given conditions (other predicates). func HasStatementWith(preds ...predicate.Statement) predicate.AttestationCollection { return predicate.AttestationCollection(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(StatementInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.O2O, true, StatementTable, StatementColumn), - ) + step := newStatementStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -243,32 +171,15 @@ func HasStatementWith(preds ...predicate.Statement) predicate.AttestationCollect // And groups predicates with the AND operator between them. 
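The new step helpers also back the edge predicates; a sketch of filtering collections by a property of their attestations (assuming a client; the attestation type value is whatever your schema stores):

// Imports: "context", "github.com/testifysec/archivista/ent",
// "github.com/testifysec/archivista/ent/attestation",
// "github.com/testifysec/archivista/ent/attestationcollection".
func collectionsWithAttestationType(ctx context.Context, client *ent.Client, attType string) ([]*ent.AttestationCollection, error) {
	return client.AttestationCollection.Query().
		Where(attestationcollection.HasAttestationsWith(
			attestation.TypeEQ(attType),
		)).
		All(ctx)
}
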
func And(predicates ...predicate.AttestationCollection) predicate.AttestationCollection { - return predicate.AttestationCollection(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.AttestationCollection(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. func Or(predicates ...predicate.AttestationCollection) predicate.AttestationCollection { - return predicate.AttestationCollection(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.AttestationCollection(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.AttestationCollection) predicate.AttestationCollection { - return predicate.AttestationCollection(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.AttestationCollection(sql.NotPredicates(p)) } diff --git a/ent/attestationcollection_create.go b/ent/attestationcollection_create.go index 2bc08908..d7b77d90 100644 --- a/ent/attestationcollection_create.go +++ b/ent/attestationcollection_create.go @@ -60,49 +60,7 @@ func (acc *AttestationCollectionCreate) Mutation() *AttestationCollectionMutatio // Save creates the AttestationCollection in the database. func (acc *AttestationCollectionCreate) Save(ctx context.Context) (*AttestationCollection, error) { - var ( - err error - node *AttestationCollection - ) - if len(acc.hooks) == 0 { - if err = acc.check(); err != nil { - return nil, err - } - node, err = acc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*AttestationCollectionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = acc.check(); err != nil { - return nil, err - } - acc.mutation = mutation - if node, err = acc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(acc.hooks) - 1; i >= 0; i-- { - if acc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = acc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, acc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*AttestationCollection) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from AttestationCollectionMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, acc.sqlSave, acc.mutation, acc.hooks) } // SaveX calls Save and panics if Save returns an error. 
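Save now funnels through the shared withHooks helper, so registered mutation hooks behave as before; a sketch of wiring one up through the generated hook package (assuming it exists at ent/hook, as ent normally generates):

// Imports: "context", "log", "github.com/testifysec/archivista/ent",
// "github.com/testifysec/archivista/ent/hook".
func registerCollectionHook(client *ent.Client) {
	client.AttestationCollection.Use(func(next ent.Mutator) ent.Mutator {
		return hook.AttestationCollectionFunc(func(ctx context.Context, m *ent.AttestationCollectionMutation) (ent.Value, error) {
			log.Printf("AttestationCollection mutation: %v", m.Op())
			return next.Mutate(ctx, m)
		})
	})
}
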
@@ -144,6 +102,9 @@ func (acc *AttestationCollectionCreate) check() error { } func (acc *AttestationCollectionCreate) sqlSave(ctx context.Context) (*AttestationCollection, error) { + if err := acc.check(); err != nil { + return nil, err + } _node, _spec := acc.createSpec() if err := sqlgraph.CreateNode(ctx, acc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -153,26 +114,18 @@ func (acc *AttestationCollectionCreate) sqlSave(ctx context.Context) (*Attestati } id := _spec.ID.Value.(int64) _node.ID = int(id) + acc.mutation.id = &_node.ID + acc.mutation.done = true return _node, nil } func (acc *AttestationCollectionCreate) createSpec() (*AttestationCollection, *sqlgraph.CreateSpec) { var ( _node = &AttestationCollection{config: acc.config} - _spec = &sqlgraph.CreateSpec{ - Table: attestationcollection.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: attestationcollection.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(attestationcollection.Table, sqlgraph.NewFieldSpec(attestationcollection.FieldID, field.TypeInt)) ) if value, ok := acc.mutation.Name(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: attestationcollection.FieldName, - }) + _spec.SetField(attestationcollection.FieldName, field.TypeString, value) _node.Name = value } if nodes := acc.mutation.AttestationsIDs(); len(nodes) > 0 { @@ -183,10 +136,7 @@ func (acc *AttestationCollectionCreate) createSpec() (*AttestationCollection, *s Columns: []string{attestationcollection.AttestationsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: attestation.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(attestation.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -202,10 +152,7 @@ func (acc *AttestationCollectionCreate) createSpec() (*AttestationCollection, *s Columns: []string{attestationcollection.StatementColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: statement.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(statement.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -220,11 +167,15 @@ func (acc *AttestationCollectionCreate) createSpec() (*AttestationCollection, *s // AttestationCollectionCreateBulk is the builder for creating many AttestationCollection entities in bulk. type AttestationCollectionCreateBulk struct { config + err error builders []*AttestationCollectionCreate } // Save creates the AttestationCollection entities in the database. 
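The createSpec changes (NewCreateSpec, SetField, NewFieldSpec) are internal; creating a collection still reads the same from callers, e.g. (sketch, edge IDs illustrative):

// Imports: "context", "github.com/testifysec/archivista/ent".
func createCollection(ctx context.Context, client *ent.Client, name string, attestationIDs []int) (*ent.AttestationCollection, error) {
	return client.AttestationCollection.Create().
		SetName(name).
		AddAttestationIDs(attestationIDs...).
		Save(ctx)
}
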
func (accb *AttestationCollectionCreateBulk) Save(ctx context.Context) ([]*AttestationCollection, error) { + if accb.err != nil { + return nil, accb.err + } specs := make([]*sqlgraph.CreateSpec, len(accb.builders)) nodes := make([]*AttestationCollection, len(accb.builders)) mutators := make([]Mutator, len(accb.builders)) @@ -240,8 +191,8 @@ func (accb *AttestationCollectionCreateBulk) Save(ctx context.Context) ([]*Attes return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, accb.builders[i+1].mutation) } else { diff --git a/ent/attestationcollection_delete.go b/ent/attestationcollection_delete.go index 3be55ee7..1be981b9 100644 --- a/ent/attestationcollection_delete.go +++ b/ent/attestationcollection_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (acd *AttestationCollectionDelete) Where(ps ...predicate.AttestationCollect // Exec executes the deletion query and returns how many vertices were deleted. func (acd *AttestationCollectionDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(acd.hooks) == 0 { - affected, err = acd.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*AttestationCollectionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - acd.mutation = mutation - affected, err = acd.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(acd.hooks) - 1; i >= 0; i-- { - if acd.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = acd.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, acd.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, acd.sqlExec, acd.mutation, acd.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (acd *AttestationCollectionDelete) ExecX(ctx context.Context) int { } func (acd *AttestationCollectionDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: attestationcollection.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: attestationcollection.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(attestationcollection.Table, sqlgraph.NewFieldSpec(attestationcollection.FieldID, field.TypeInt)) if ps := acd.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (acd *AttestationCollectionDelete) sqlExec(ctx context.Context) (int, error if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + acd.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type AttestationCollectionDeleteOne struct { acd *AttestationCollectionDelete } +// Where appends a list predicates to the AttestationCollectionDelete builder. +func (acdo *AttestationCollectionDeleteOne) Where(ps ...predicate.AttestationCollection) *AttestationCollectionDeleteOne { + acdo.acd.mutation.Where(ps...) + return acdo +} + // Exec executes the deletion query. 
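The new Where on the delete-one builder allows an extra guard next to the ID; a sketch:

// Imports: "context", "github.com/testifysec/archivista/ent",
// "github.com/testifysec/archivista/ent/attestationcollection".
func deleteCollectionIfNamed(ctx context.Context, client *ent.Client, id int, expectedName string) error {
	return client.AttestationCollection.DeleteOneID(id).
		Where(attestationcollection.NameEQ(expectedName)).
		Exec(ctx) // returns *ent.NotFoundError if the predicate filtered the row out
}
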
func (acdo *AttestationCollectionDeleteOne) Exec(ctx context.Context) error { n, err := acdo.acd.Exec(ctx) @@ -111,5 +82,7 @@ func (acdo *AttestationCollectionDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (acdo *AttestationCollectionDeleteOne) ExecX(ctx context.Context) { - acdo.acd.ExecX(ctx) + if err := acdo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/ent/attestationcollection_query.go b/ent/attestationcollection_query.go index f0567a7c..cc8cd7be 100644 --- a/ent/attestationcollection_query.go +++ b/ent/attestationcollection_query.go @@ -20,17 +20,16 @@ import ( // AttestationCollectionQuery is the builder for querying AttestationCollection entities. type AttestationCollectionQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string - predicates []predicate.AttestationCollection - withAttestations *AttestationQuery - withStatement *StatementQuery - withFKs bool - modifiers []func(*sql.Selector) - loadTotal []func(context.Context, []*AttestationCollection) error + ctx *QueryContext + order []attestationcollection.OrderOption + inters []Interceptor + predicates []predicate.AttestationCollection + withAttestations *AttestationQuery + withStatement *StatementQuery + withFKs bool + modifiers []func(*sql.Selector) + loadTotal []func(context.Context, []*AttestationCollection) error + withNamedAttestations map[string]*AttestationQuery // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) @@ -42,34 +41,34 @@ func (acq *AttestationCollectionQuery) Where(ps ...predicate.AttestationCollecti return acq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (acq *AttestationCollectionQuery) Limit(limit int) *AttestationCollectionQuery { - acq.limit = &limit + acq.ctx.Limit = &limit return acq } -// Offset adds an offset step to the query. +// Offset to start from. func (acq *AttestationCollectionQuery) Offset(offset int) *AttestationCollectionQuery { - acq.offset = &offset + acq.ctx.Offset = &offset return acq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (acq *AttestationCollectionQuery) Unique(unique bool) *AttestationCollectionQuery { - acq.unique = &unique + acq.ctx.Unique = &unique return acq } -// Order adds an order step to the query. -func (acq *AttestationCollectionQuery) Order(o ...OrderFunc) *AttestationCollectionQuery { +// Order specifies how the records should be ordered. +func (acq *AttestationCollectionQuery) Order(o ...attestationcollection.OrderOption) *AttestationCollectionQuery { acq.order = append(acq.order, o...) return acq } // QueryAttestations chains the current query on the "attestations" edge. func (acq *AttestationCollectionQuery) QueryAttestations() *AttestationQuery { - query := &AttestationQuery{config: acq.config} + query := (&AttestationClient{config: acq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := acq.prepareQuery(ctx); err != nil { return nil, err @@ -91,7 +90,7 @@ func (acq *AttestationCollectionQuery) QueryAttestations() *AttestationQuery { // QueryStatement chains the current query on the "statement" edge. 
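Limit, Offset and the typed Order options now live on the shared QueryContext, but caller code is unchanged; a pagination-style sketch:

// Imports: "context", "github.com/testifysec/archivista/ent",
// "github.com/testifysec/archivista/ent/attestationcollection".
func pageCollections(ctx context.Context, client *ent.Client, page, size int) ([]*ent.AttestationCollection, error) {
	return client.AttestationCollection.Query().
		Order(attestationcollection.ByID()).
		Limit(size).
		Offset(page * size).
		All(ctx)
}
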
func (acq *AttestationCollectionQuery) QueryStatement() *StatementQuery { - query := &StatementQuery{config: acq.config} + query := (&StatementClient{config: acq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := acq.prepareQuery(ctx); err != nil { return nil, err @@ -114,7 +113,7 @@ func (acq *AttestationCollectionQuery) QueryStatement() *StatementQuery { // First returns the first AttestationCollection entity from the query. // Returns a *NotFoundError when no AttestationCollection was found. func (acq *AttestationCollectionQuery) First(ctx context.Context) (*AttestationCollection, error) { - nodes, err := acq.Limit(1).All(ctx) + nodes, err := acq.Limit(1).All(setContextOp(ctx, acq.ctx, "First")) if err != nil { return nil, err } @@ -137,7 +136,7 @@ func (acq *AttestationCollectionQuery) FirstX(ctx context.Context) *AttestationC // Returns a *NotFoundError when no AttestationCollection ID was found. func (acq *AttestationCollectionQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = acq.Limit(1).IDs(ctx); err != nil { + if ids, err = acq.Limit(1).IDs(setContextOp(ctx, acq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -160,7 +159,7 @@ func (acq *AttestationCollectionQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one AttestationCollection entity is found. // Returns a *NotFoundError when no AttestationCollection entities are found. func (acq *AttestationCollectionQuery) Only(ctx context.Context) (*AttestationCollection, error) { - nodes, err := acq.Limit(2).All(ctx) + nodes, err := acq.Limit(2).All(setContextOp(ctx, acq.ctx, "Only")) if err != nil { return nil, err } @@ -188,7 +187,7 @@ func (acq *AttestationCollectionQuery) OnlyX(ctx context.Context) *AttestationCo // Returns a *NotFoundError when no entities are found. func (acq *AttestationCollectionQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = acq.Limit(2).IDs(ctx); err != nil { + if ids, err = acq.Limit(2).IDs(setContextOp(ctx, acq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -213,10 +212,12 @@ func (acq *AttestationCollectionQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of AttestationCollections. func (acq *AttestationCollectionQuery) All(ctx context.Context) ([]*AttestationCollection, error) { + ctx = setContextOp(ctx, acq.ctx, "All") if err := acq.prepareQuery(ctx); err != nil { return nil, err } - return acq.sqlAll(ctx) + qr := querierAll[[]*AttestationCollection, *AttestationCollectionQuery]() + return withInterceptors[[]*AttestationCollection](ctx, acq, qr, acq.inters) } // AllX is like All, but panics if an error occurs. @@ -229,9 +230,12 @@ func (acq *AttestationCollectionQuery) AllX(ctx context.Context) []*AttestationC } // IDs executes the query and returns a list of AttestationCollection IDs. 
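First/Only/Exist now tag the context with the operation name for interceptors, but read the same from callers; a sketch combining Only with eager loading (assuming a client):

// Imports: "context", "fmt", "github.com/testifysec/archivista/ent",
// "github.com/testifysec/archivista/ent/attestationcollection".
func getCollection(ctx context.Context, client *ent.Client, id int) (*ent.AttestationCollection, error) {
	col, err := client.AttestationCollection.Query().
		Where(attestationcollection.ID(id)).
		WithAttestations().
		WithStatement().
		Only(ctx)
	if ent.IsNotFound(err) {
		return nil, fmt.Errorf("collection %d does not exist: %w", id, err)
	}
	return col, err
}
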
-func (acq *AttestationCollectionQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := acq.Select(attestationcollection.FieldID).Scan(ctx, &ids); err != nil { +func (acq *AttestationCollectionQuery) IDs(ctx context.Context) (ids []int, err error) { + if acq.ctx.Unique == nil && acq.path != nil { + acq.Unique(true) + } + ctx = setContextOp(ctx, acq.ctx, "IDs") + if err = acq.Select(attestationcollection.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -248,10 +252,11 @@ func (acq *AttestationCollectionQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (acq *AttestationCollectionQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, acq.ctx, "Count") if err := acq.prepareQuery(ctx); err != nil { return 0, err } - return acq.sqlCount(ctx) + return withInterceptors[int](ctx, acq, querierCount[*AttestationCollectionQuery](), acq.inters) } // CountX is like Count, but panics if an error occurs. @@ -265,10 +270,15 @@ func (acq *AttestationCollectionQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (acq *AttestationCollectionQuery) Exist(ctx context.Context) (bool, error) { - if err := acq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, acq.ctx, "Exist") + switch _, err := acq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return acq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -288,23 +298,22 @@ func (acq *AttestationCollectionQuery) Clone() *AttestationCollectionQuery { } return &AttestationCollectionQuery{ config: acq.config, - limit: acq.limit, - offset: acq.offset, - order: append([]OrderFunc{}, acq.order...), + ctx: acq.ctx.Clone(), + order: append([]attestationcollection.OrderOption{}, acq.order...), + inters: append([]Interceptor{}, acq.inters...), predicates: append([]predicate.AttestationCollection{}, acq.predicates...), withAttestations: acq.withAttestations.Clone(), withStatement: acq.withStatement.Clone(), // clone intermediate query. - sql: acq.sql.Clone(), - path: acq.path, - unique: acq.unique, + sql: acq.sql.Clone(), + path: acq.path, } } // WithAttestations tells the query-builder to eager-load the nodes that are connected to // the "attestations" edge. The optional arguments are used to configure the query builder of the edge. func (acq *AttestationCollectionQuery) WithAttestations(opts ...func(*AttestationQuery)) *AttestationCollectionQuery { - query := &AttestationQuery{config: acq.config} + query := (&AttestationClient{config: acq.config}).Query() for _, opt := range opts { opt(query) } @@ -315,7 +324,7 @@ func (acq *AttestationCollectionQuery) WithAttestations(opts ...func(*Attestatio // WithStatement tells the query-builder to eager-load the nodes that are connected to // the "statement" edge. The optional arguments are used to configure the query builder of the edge. func (acq *AttestationCollectionQuery) WithStatement(opts ...func(*StatementQuery)) *AttestationCollectionQuery { - query := &StatementQuery{config: acq.config} + query := (&StatementClient{config: acq.config}).Query() for _, opt := range opts { opt(query) } @@ -338,16 +347,11 @@ func (acq *AttestationCollectionQuery) WithStatement(opts ...func(*StatementQuer // Aggregate(ent.Count()). 
// Scan(ctx, &v) func (acq *AttestationCollectionQuery) GroupBy(field string, fields ...string) *AttestationCollectionGroupBy { - grbuild := &AttestationCollectionGroupBy{config: acq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := acq.prepareQuery(ctx); err != nil { - return nil, err - } - return acq.sqlQuery(ctx), nil - } + acq.ctx.Fields = append([]string{field}, fields...) + grbuild := &AttestationCollectionGroupBy{build: acq} + grbuild.flds = &acq.ctx.Fields grbuild.label = attestationcollection.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -364,15 +368,30 @@ func (acq *AttestationCollectionQuery) GroupBy(field string, fields ...string) * // Select(attestationcollection.FieldName). // Scan(ctx, &v) func (acq *AttestationCollectionQuery) Select(fields ...string) *AttestationCollectionSelect { - acq.fields = append(acq.fields, fields...) - selbuild := &AttestationCollectionSelect{AttestationCollectionQuery: acq} - selbuild.label = attestationcollection.Label - selbuild.flds, selbuild.scan = &acq.fields, selbuild.Scan - return selbuild + acq.ctx.Fields = append(acq.ctx.Fields, fields...) + sbuild := &AttestationCollectionSelect{AttestationCollectionQuery: acq} + sbuild.label = attestationcollection.Label + sbuild.flds, sbuild.scan = &acq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a AttestationCollectionSelect configured with the given aggregations. +func (acq *AttestationCollectionQuery) Aggregate(fns ...AggregateFunc) *AttestationCollectionSelect { + return acq.Select().Aggregate(fns...) } func (acq *AttestationCollectionQuery) prepareQuery(ctx context.Context) error { - for _, f := range acq.fields { + for _, inter := range acq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, acq); err != nil { + return err + } + } + } + for _, f := range acq.ctx.Fields { if !attestationcollection.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -437,6 +456,13 @@ func (acq *AttestationCollectionQuery) sqlAll(ctx context.Context, hooks ...quer return nil, err } } + for name, query := range acq.withNamedAttestations { + if err := acq.loadAttestations(ctx, query, nodes, + func(n *AttestationCollection) { n.appendNamedAttestations(name) }, + func(n *AttestationCollection, e *Attestation) { n.appendNamedAttestations(name, e) }); err != nil { + return nil, err + } + } for i := range acq.loadTotal { if err := acq.loadTotal[i](ctx, nodes); err != nil { return nil, err @@ -457,7 +483,7 @@ func (acq *AttestationCollectionQuery) loadAttestations(ctx context.Context, que } query.withFKs = true query.Where(predicate.Attestation(func(s *sql.Selector) { - s.Where(sql.InValues(attestationcollection.AttestationsColumn, fks...)) + s.Where(sql.InValues(s.C(attestationcollection.AttestationsColumn), fks...)) })) neighbors, err := query.All(ctx) if err != nil { @@ -470,7 +496,7 @@ func (acq *AttestationCollectionQuery) loadAttestations(ctx context.Context, que } node, ok := nodeids[*fk] if !ok { - return fmt.Errorf(`unexpected foreign-key "attestation_collection_attestations" returned %v for node %v`, *fk, n.ID) + return fmt.Errorf(`unexpected referenced foreign-key "attestation_collection_attestations" returned %v for node %v`, 
*fk, n.ID) } assign(node, n) } @@ -489,6 +515,9 @@ func (acq *AttestationCollectionQuery) loadStatement(ctx context.Context, query } nodeids[fk] = append(nodeids[fk], nodes[i]) } + if len(ids) == 0 { + return nil + } query.Where(statement.IDIn(ids...)) neighbors, err := query.All(ctx) if err != nil { @@ -511,41 +540,22 @@ func (acq *AttestationCollectionQuery) sqlCount(ctx context.Context) (int, error if len(acq.modifiers) > 0 { _spec.Modifiers = acq.modifiers } - _spec.Node.Columns = acq.fields - if len(acq.fields) > 0 { - _spec.Unique = acq.unique != nil && *acq.unique + _spec.Node.Columns = acq.ctx.Fields + if len(acq.ctx.Fields) > 0 { + _spec.Unique = acq.ctx.Unique != nil && *acq.ctx.Unique } return sqlgraph.CountNodes(ctx, acq.driver, _spec) } -func (acq *AttestationCollectionQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := acq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (acq *AttestationCollectionQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: attestationcollection.Table, - Columns: attestationcollection.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: attestationcollection.FieldID, - }, - }, - From: acq.sql, - Unique: true, - } - if unique := acq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(attestationcollection.Table, attestationcollection.Columns, sqlgraph.NewFieldSpec(attestationcollection.FieldID, field.TypeInt)) + _spec.From = acq.sql + if unique := acq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if acq.path != nil { + _spec.Unique = true } - if fields := acq.fields; len(fields) > 0 { + if fields := acq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, attestationcollection.FieldID) for i := range fields { @@ -561,10 +571,10 @@ func (acq *AttestationCollectionQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := acq.limit; limit != nil { + if limit := acq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := acq.offset; offset != nil { + if offset := acq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := acq.order; len(ps) > 0 { @@ -580,7 +590,7 @@ func (acq *AttestationCollectionQuery) querySpec() *sqlgraph.QuerySpec { func (acq *AttestationCollectionQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(acq.driver.Dialect()) t1 := builder.Table(attestationcollection.Table) - columns := acq.fields + columns := acq.ctx.Fields if len(columns) == 0 { columns = attestationcollection.Columns } @@ -589,7 +599,7 @@ func (acq *AttestationCollectionQuery) sqlQuery(ctx context.Context) *sql.Select selector = acq.sql selector.Select(selector.Columns(columns...)...) } - if acq.unique != nil && *acq.unique { + if acq.ctx.Unique != nil && *acq.ctx.Unique { selector.Distinct() } for _, p := range acq.predicates { @@ -598,26 +608,35 @@ func (acq *AttestationCollectionQuery) sqlQuery(ctx context.Context) *sql.Select for _, p := range acq.order { p(selector) } - if offset := acq.offset; offset != nil { + if offset := acq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. 
selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := acq.limit; limit != nil { + if limit := acq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector } +// WithNamedAttestations tells the query-builder to eager-load the nodes that are connected to the "attestations" +// edge with the given name. The optional arguments are used to configure the query builder of the edge. +func (acq *AttestationCollectionQuery) WithNamedAttestations(name string, opts ...func(*AttestationQuery)) *AttestationCollectionQuery { + query := (&AttestationClient{config: acq.config}).Query() + for _, opt := range opts { + opt(query) + } + if acq.withNamedAttestations == nil { + acq.withNamedAttestations = make(map[string]*AttestationQuery) + } + acq.withNamedAttestations[name] = query + return acq +} + // AttestationCollectionGroupBy is the group-by builder for AttestationCollection entities. type AttestationCollectionGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *AttestationCollectionQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -626,74 +645,77 @@ func (acgb *AttestationCollectionGroupBy) Aggregate(fns ...AggregateFunc) *Attes return acgb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. func (acgb *AttestationCollectionGroupBy) Scan(ctx context.Context, v any) error { - query, err := acgb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, acgb.build.ctx, "GroupBy") + if err := acgb.build.prepareQuery(ctx); err != nil { return err } - acgb.sql = query - return acgb.sqlScan(ctx, v) + return scanWithInterceptors[*AttestationCollectionQuery, *AttestationCollectionGroupBy](ctx, acgb.build, acgb, acgb.build.inters, v) } -func (acgb *AttestationCollectionGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range acgb.fields { - if !attestationcollection.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (acgb *AttestationCollectionGroupBy) sqlScan(ctx context.Context, root *AttestationCollectionQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(acgb.fns)) + for _, fn := range acgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*acgb.flds)+len(acgb.fns)) + for _, f := range *acgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := acgb.sqlQuery() + selector.GroupBy(selector.Columns(*acgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := acgb.driver.Query(ctx, query, args, rows); err != nil { + if err := acgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (acgb *AttestationCollectionGroupBy) sqlQuery() *sql.Selector { - selector := acgb.sql.Select() - aggregation := make([]string, 0, len(acgb.fns)) - for _, fn := range acgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. 
- if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(acgb.fields)+len(acgb.fns)) - for _, f := range acgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(acgb.fields...)...) -} - // AttestationCollectionSelect is the builder for selecting fields of AttestationCollection entities. type AttestationCollectionSelect struct { *AttestationCollectionQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (acs *AttestationCollectionSelect) Aggregate(fns ...AggregateFunc) *AttestationCollectionSelect { + acs.fns = append(acs.fns, fns...) + return acs } // Scan applies the selector query and scans the result into the given value. func (acs *AttestationCollectionSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, acs.ctx, "Select") if err := acs.prepareQuery(ctx); err != nil { return err } - acs.sql = acs.AttestationCollectionQuery.sqlQuery(ctx) - return acs.sqlScan(ctx, v) + return scanWithInterceptors[*AttestationCollectionQuery, *AttestationCollectionSelect](ctx, acs.AttestationCollectionQuery, acs, acs.inters, v) } -func (acs *AttestationCollectionSelect) sqlScan(ctx context.Context, v any) error { +func (acs *AttestationCollectionSelect) sqlScan(ctx context.Context, root *AttestationCollectionQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(acs.fns)) + for _, fn := range acs.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*acs.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := acs.sql.Query() + query, args := selector.Query() if err := acs.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/ent/attestationcollection_update.go b/ent/attestationcollection_update.go index 84a6a3cb..63959913 100644 --- a/ent/attestationcollection_update.go +++ b/ent/attestationcollection_update.go @@ -95,40 +95,7 @@ func (acu *AttestationCollectionUpdate) ClearStatement() *AttestationCollectionU // Save executes the query and returns the number of nodes affected by the update operation. func (acu *AttestationCollectionUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(acu.hooks) == 0 { - if err = acu.check(); err != nil { - return 0, err - } - affected, err = acu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*AttestationCollectionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = acu.check(); err != nil { - return 0, err - } - acu.mutation = mutation - affected, err = acu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(acu.hooks) - 1; i >= 0; i-- { - if acu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = acu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, acu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, acu.sqlSave, acu.mutation, acu.hooks) } // SaveX is like Save, but panics if an error occurs. 
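WithNamedAttestations, added just above, backs entgql's named edge loading; a direct-usage sketch (the "git" key and the filter are illustrative, and NamedAttestations is the getter ent generates alongside appendNamedAttestations):

// Imports: "context", "github.com/testifysec/archivista/ent",
// "github.com/testifysec/archivista/ent/attestation".
func collectionsWithGitAttestations(ctx context.Context, client *ent.Client) error {
	cols, err := client.AttestationCollection.Query().
		WithNamedAttestations("git", func(q *ent.AttestationQuery) {
			q.Where(attestation.TypeContains("git"))
		}).
		All(ctx)
	if err != nil {
		return err
	}
	for _, c := range cols {
		atts, err := c.NamedAttestations("git")
		if err != nil {
			return err
		}
		_ = atts // only the attestations loaded under the "git" key
	}
	return nil
}
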
@@ -167,16 +134,10 @@ func (acu *AttestationCollectionUpdate) check() error { } func (acu *AttestationCollectionUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: attestationcollection.Table, - Columns: attestationcollection.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: attestationcollection.FieldID, - }, - }, + if err := acu.check(); err != nil { + return n, err } + _spec := sqlgraph.NewUpdateSpec(attestationcollection.Table, attestationcollection.Columns, sqlgraph.NewFieldSpec(attestationcollection.FieldID, field.TypeInt)) if ps := acu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -185,11 +146,7 @@ func (acu *AttestationCollectionUpdate) sqlSave(ctx context.Context) (n int, err } } if value, ok := acu.mutation.Name(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: attestationcollection.FieldName, - }) + _spec.SetField(attestationcollection.FieldName, field.TypeString, value) } if acu.mutation.AttestationsCleared() { edge := &sqlgraph.EdgeSpec{ @@ -199,10 +156,7 @@ func (acu *AttestationCollectionUpdate) sqlSave(ctx context.Context) (n int, err Columns: []string{attestationcollection.AttestationsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: attestation.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(attestation.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -215,10 +169,7 @@ func (acu *AttestationCollectionUpdate) sqlSave(ctx context.Context) (n int, err Columns: []string{attestationcollection.AttestationsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: attestation.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(attestation.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -234,10 +185,7 @@ func (acu *AttestationCollectionUpdate) sqlSave(ctx context.Context) (n int, err Columns: []string{attestationcollection.AttestationsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: attestation.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(attestation.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -253,10 +201,7 @@ func (acu *AttestationCollectionUpdate) sqlSave(ctx context.Context) (n int, err Columns: []string{attestationcollection.StatementColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: statement.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(statement.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -269,10 +214,7 @@ func (acu *AttestationCollectionUpdate) sqlSave(ctx context.Context) (n int, err Columns: []string{attestationcollection.StatementColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: statement.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(statement.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -288,6 +230,7 @@ func (acu *AttestationCollectionUpdate) sqlSave(ctx context.Context) (n int, err } return 0, err } + acu.mutation.done = true return n, nil } @@ -363,6 +306,12 @@ func (acuo *AttestationCollectionUpdateOne) ClearStatement() *AttestationCollect return acuo } +// Where appends a list predicates to the 
AttestationCollectionUpdate builder. +func (acuo *AttestationCollectionUpdateOne) Where(ps ...predicate.AttestationCollection) *AttestationCollectionUpdateOne { + acuo.mutation.Where(ps...) + return acuo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (acuo *AttestationCollectionUpdateOne) Select(field string, fields ...string) *AttestationCollectionUpdateOne { @@ -372,46 +321,7 @@ func (acuo *AttestationCollectionUpdateOne) Select(field string, fields ...strin // Save executes the query and returns the updated AttestationCollection entity. func (acuo *AttestationCollectionUpdateOne) Save(ctx context.Context) (*AttestationCollection, error) { - var ( - err error - node *AttestationCollection - ) - if len(acuo.hooks) == 0 { - if err = acuo.check(); err != nil { - return nil, err - } - node, err = acuo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*AttestationCollectionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = acuo.check(); err != nil { - return nil, err - } - acuo.mutation = mutation - node, err = acuo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(acuo.hooks) - 1; i >= 0; i-- { - if acuo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = acuo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, acuo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*AttestationCollection) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from AttestationCollectionMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, acuo.sqlSave, acuo.mutation, acuo.hooks) } // SaveX is like Save, but panics if an error occurs. 
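UpdateOne gains the same Where guard as the delete builder, which makes compare-and-set style updates possible; a sketch:

// Imports: "context", "github.com/testifysec/archivista/ent",
// "github.com/testifysec/archivista/ent/attestationcollection".
func renameCollection(ctx context.Context, client *ent.Client, id int, from, to string) (*ent.AttestationCollection, error) {
	return client.AttestationCollection.UpdateOneID(id).
		Where(attestationcollection.NameEQ(from)). // guard: only update while the old name still matches
		SetName(to).
		Save(ctx) // *ent.NotFoundError if the guard no longer holds
}
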
@@ -450,16 +360,10 @@ func (acuo *AttestationCollectionUpdateOne) check() error { } func (acuo *AttestationCollectionUpdateOne) sqlSave(ctx context.Context) (_node *AttestationCollection, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: attestationcollection.Table, - Columns: attestationcollection.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: attestationcollection.FieldID, - }, - }, + if err := acuo.check(); err != nil { + return _node, err } + _spec := sqlgraph.NewUpdateSpec(attestationcollection.Table, attestationcollection.Columns, sqlgraph.NewFieldSpec(attestationcollection.FieldID, field.TypeInt)) id, ok := acuo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "AttestationCollection.id" for update`)} @@ -485,11 +389,7 @@ func (acuo *AttestationCollectionUpdateOne) sqlSave(ctx context.Context) (_node } } if value, ok := acuo.mutation.Name(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: attestationcollection.FieldName, - }) + _spec.SetField(attestationcollection.FieldName, field.TypeString, value) } if acuo.mutation.AttestationsCleared() { edge := &sqlgraph.EdgeSpec{ @@ -499,10 +399,7 @@ func (acuo *AttestationCollectionUpdateOne) sqlSave(ctx context.Context) (_node Columns: []string{attestationcollection.AttestationsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: attestation.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(attestation.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -515,10 +412,7 @@ func (acuo *AttestationCollectionUpdateOne) sqlSave(ctx context.Context) (_node Columns: []string{attestationcollection.AttestationsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: attestation.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(attestation.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -534,10 +428,7 @@ func (acuo *AttestationCollectionUpdateOne) sqlSave(ctx context.Context) (_node Columns: []string{attestationcollection.AttestationsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: attestation.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(attestation.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -553,10 +444,7 @@ func (acuo *AttestationCollectionUpdateOne) sqlSave(ctx context.Context) (_node Columns: []string{attestationcollection.StatementColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: statement.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(statement.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -569,10 +457,7 @@ func (acuo *AttestationCollectionUpdateOne) sqlSave(ctx context.Context) (_node Columns: []string{attestationcollection.StatementColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: statement.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(statement.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -591,5 +476,6 @@ func (acuo *AttestationCollectionUpdateOne) sqlSave(ctx context.Context) (_node } return nil, err } + acuo.mutation.done = true return _node, nil } diff --git a/ent/client.go b/ent/client.go index 9b2a3e76..bca59122 100644 --- a/ent/client.go +++ 
b/ent/client.go @@ -7,9 +7,14 @@ import ( "errors" "fmt" "log" + "reflect" "github.com/testifysec/archivista/ent/migrate" + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" "github.com/testifysec/archivista/ent/attestation" "github.com/testifysec/archivista/ent/attestationcollection" "github.com/testifysec/archivista/ent/dsse" @@ -19,10 +24,6 @@ import ( "github.com/testifysec/archivista/ent/subject" "github.com/testifysec/archivista/ent/subjectdigest" "github.com/testifysec/archivista/ent/timestamp" - - "entgo.io/ent/dialect" - "entgo.io/ent/dialect/sql" - "entgo.io/ent/dialect/sql/sqlgraph" ) // Client is the client that holds all ent builders. @@ -54,7 +55,7 @@ type Client struct { // NewClient creates a new client configured with the given options. func NewClient(opts ...Option) *Client { - cfg := config{log: log.Println, hooks: &hooks{}} + cfg := config{log: log.Println, hooks: &hooks{}, inters: &inters{}} cfg.options(opts...) client := &Client{config: cfg} client.init() @@ -74,6 +75,55 @@ func (c *Client) init() { c.Timestamp = NewTimestampClient(c.config) } +type ( + // config is the configuration for the client and its builder. + config struct { + // driver used for executing database requests. + driver dialect.Driver + // debug enable a debug logging. + debug bool + // log used for logging on debug mode. + log func(...any) + // hooks to execute on mutations. + hooks *hooks + // interceptors to execute on queries. + inters *inters + } + // Option function to configure the client. + Option func(*config) +) + +// options applies the options on the config object. +func (c *config) options(opts ...Option) { + for _, opt := range opts { + opt(c) + } + if c.debug { + c.driver = dialect.Debug(c.driver, c.log) + } +} + +// Debug enables debug logging on the ent.Driver. +func Debug() Option { + return func(c *config) { + c.debug = true + } +} + +// Log sets the logging function for debug mode. +func Log(fn func(...any)) Option { + return func(c *config) { + c.log = fn + } +} + +// Driver configures the client driver. +func Driver(driver dialect.Driver) Option { + return func(c *config) { + c.driver = driver + } +} + // Open opens a database/sql.DB specified by the driver name and // the data source name, and returns a new client attached to it. // Optional parameters can be added for configuring the client. @@ -90,11 +140,14 @@ func Open(driverName, dataSourceName string, options ...Option) (*Client, error) } } +// ErrTxStarted is returned when trying to start a new transaction from a transactional client. +var ErrTxStarted = errors.New("ent: cannot start a transaction within a transaction") + // Tx returns a new transactional client. The provided context // is used until the transaction is committed or rolled back. func (c *Client) Tx(ctx context.Context) (*Tx, error) { if _, ok := c.driver.(*txDriver); ok { - return nil, errors.New("ent: cannot start a transaction within a transaction") + return nil, ErrTxStarted } tx, err := newTx(ctx, c.driver) if err != nil { @@ -170,15 +223,49 @@ func (c *Client) Close() error { // Use adds the mutation hooks to all the entity clients. // In order to add hooks to a specific client, call: `client.Node.Use(...)`. func (c *Client) Use(hooks ...Hook) { - c.Attestation.Use(hooks...) - c.AttestationCollection.Use(hooks...) - c.Dsse.Use(hooks...) - c.PayloadDigest.Use(hooks...) - c.Signature.Use(hooks...) - c.Statement.Use(hooks...) - c.Subject.Use(hooks...) - c.SubjectDigest.Use(hooks...) 
- c.Timestamp.Use(hooks...) + for _, n := range []interface{ Use(...Hook) }{ + c.Attestation, c.AttestationCollection, c.Dsse, c.PayloadDigest, c.Signature, + c.Statement, c.Subject, c.SubjectDigest, c.Timestamp, + } { + n.Use(hooks...) + } +} + +// Intercept adds the query interceptors to all the entity clients. +// In order to add interceptors to a specific client, call: `client.Node.Intercept(...)`. +func (c *Client) Intercept(interceptors ...Interceptor) { + for _, n := range []interface{ Intercept(...Interceptor) }{ + c.Attestation, c.AttestationCollection, c.Dsse, c.PayloadDigest, c.Signature, + c.Statement, c.Subject, c.SubjectDigest, c.Timestamp, + } { + n.Intercept(interceptors...) + } +} + +// Mutate implements the ent.Mutator interface. +func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { + switch m := m.(type) { + case *AttestationMutation: + return c.Attestation.mutate(ctx, m) + case *AttestationCollectionMutation: + return c.AttestationCollection.mutate(ctx, m) + case *DsseMutation: + return c.Dsse.mutate(ctx, m) + case *PayloadDigestMutation: + return c.PayloadDigest.mutate(ctx, m) + case *SignatureMutation: + return c.Signature.mutate(ctx, m) + case *StatementMutation: + return c.Statement.mutate(ctx, m) + case *SubjectMutation: + return c.Subject.mutate(ctx, m) + case *SubjectDigestMutation: + return c.SubjectDigest.mutate(ctx, m) + case *TimestampMutation: + return c.Timestamp.mutate(ctx, m) + default: + return nil, fmt.Errorf("ent: unknown mutation type %T", m) + } } // AttestationClient is a client for the Attestation schema. @@ -197,6 +284,12 @@ func (c *AttestationClient) Use(hooks ...Hook) { c.hooks.Attestation = append(c.hooks.Attestation, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `attestation.Intercept(f(g(h())))`. +func (c *AttestationClient) Intercept(interceptors ...Interceptor) { + c.inters.Attestation = append(c.inters.Attestation, interceptors...) +} + // Create returns a builder for creating a Attestation entity. func (c *AttestationClient) Create() *AttestationCreate { mutation := newAttestationMutation(c.config, OpCreate) @@ -208,6 +301,21 @@ func (c *AttestationClient) CreateBulk(builders ...*AttestationCreate) *Attestat return &AttestationCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *AttestationClient) MapCreateBulk(slice any, setFunc func(*AttestationCreate, int)) *AttestationCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &AttestationCreateBulk{err: fmt.Errorf("calling to AttestationClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*AttestationCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &AttestationCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for Attestation. func (c *AttestationClient) Update() *AttestationUpdate { mutation := newAttestationMutation(c.config, OpUpdate) @@ -237,7 +345,7 @@ func (c *AttestationClient) DeleteOne(a *Attestation) *AttestationDeleteOne { return c.DeleteOneID(a.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. 
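client.go now hosts the config/Option plumbing and the new Intercept entry point; a sketch of building a client with the Debug and Log options plus a global query interceptor (the "mysql" dialect is only an example, and InterceptFunc/QuerierFunc are the standard aliases ent generates for this version):

// Imports: "context", "log", "time", "github.com/testifysec/archivista/ent",
// plus the SQL driver for the chosen dialect.
func newInstrumentedClient(dsn string) (*ent.Client, error) {
	client, err := ent.Open("mysql", dsn, ent.Debug(), ent.Log(log.Println))
	if err != nil {
		return nil, err
	}
	// Intercept applies to every entity client; InterceptFunc wraps the
	// Querier that actually executes the query.
	client.Intercept(ent.InterceptFunc(func(next ent.Querier) ent.Querier {
		return ent.QuerierFunc(func(ctx context.Context, q ent.Query) (ent.Value, error) {
			start := time.Now()
			v, err := next.Query(ctx, q)
			log.Printf("ent query finished in %s", time.Since(start))
			return v, err
		})
	}))
	return client, nil
}
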
func (c *AttestationClient) DeleteOneID(id int) *AttestationDeleteOne { builder := c.Delete().Where(attestation.ID(id)) builder.mutation.id = &id @@ -249,6 +357,8 @@ func (c *AttestationClient) DeleteOneID(id int) *AttestationDeleteOne { func (c *AttestationClient) Query() *AttestationQuery { return &AttestationQuery{ config: c.config, + ctx: &QueryContext{Type: TypeAttestation}, + inters: c.Interceptors(), } } @@ -268,8 +378,8 @@ func (c *AttestationClient) GetX(ctx context.Context, id int) *Attestation { // QueryAttestationCollection queries the attestation_collection edge of a Attestation. func (c *AttestationClient) QueryAttestationCollection(a *Attestation) *AttestationCollectionQuery { - query := &AttestationCollectionQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&AttestationCollectionClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := a.ID step := sqlgraph.NewStep( sqlgraph.From(attestation.Table, attestation.FieldID, id), @@ -287,6 +397,26 @@ func (c *AttestationClient) Hooks() []Hook { return c.hooks.Attestation } +// Interceptors returns the client interceptors. +func (c *AttestationClient) Interceptors() []Interceptor { + return c.inters.Attestation +} + +func (c *AttestationClient) mutate(ctx context.Context, m *AttestationMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&AttestationCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&AttestationUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&AttestationUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&AttestationDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Attestation mutation op: %q", m.Op()) + } +} + // AttestationCollectionClient is a client for the AttestationCollection schema. type AttestationCollectionClient struct { config @@ -303,6 +433,12 @@ func (c *AttestationCollectionClient) Use(hooks ...Hook) { c.hooks.AttestationCollection = append(c.hooks.AttestationCollection, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `attestationcollection.Intercept(f(g(h())))`. +func (c *AttestationCollectionClient) Intercept(interceptors ...Interceptor) { + c.inters.AttestationCollection = append(c.inters.AttestationCollection, interceptors...) +} + // Create returns a builder for creating a AttestationCollection entity. func (c *AttestationCollectionClient) Create() *AttestationCollectionCreate { mutation := newAttestationCollectionMutation(c.config, OpCreate) @@ -314,6 +450,21 @@ func (c *AttestationCollectionClient) CreateBulk(builders ...*AttestationCollect return &AttestationCollectionCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. 
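MapCreateBulk derives one create builder per slice element; a sketch (any required fields or edges from the real schema must still be set, and a non-slice argument surfaces later as an error from Save via the bulk builder's new err field):

// Imports: "context", "github.com/testifysec/archivista/ent".
func bulkCreateAttestations(ctx context.Context, client *ent.Client, types []string) ([]*ent.Attestation, error) {
	return client.Attestation.MapCreateBulk(types, func(c *ent.AttestationCreate, i int) {
		c.SetType(types[i])
	}).Save(ctx)
}
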
+func (c *AttestationCollectionClient) MapCreateBulk(slice any, setFunc func(*AttestationCollectionCreate, int)) *AttestationCollectionCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &AttestationCollectionCreateBulk{err: fmt.Errorf("calling to AttestationCollectionClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*AttestationCollectionCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &AttestationCollectionCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for AttestationCollection. func (c *AttestationCollectionClient) Update() *AttestationCollectionUpdate { mutation := newAttestationCollectionMutation(c.config, OpUpdate) @@ -343,7 +494,7 @@ func (c *AttestationCollectionClient) DeleteOne(ac *AttestationCollection) *Atte return c.DeleteOneID(ac.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *AttestationCollectionClient) DeleteOneID(id int) *AttestationCollectionDeleteOne { builder := c.Delete().Where(attestationcollection.ID(id)) builder.mutation.id = &id @@ -355,6 +506,8 @@ func (c *AttestationCollectionClient) DeleteOneID(id int) *AttestationCollection func (c *AttestationCollectionClient) Query() *AttestationCollectionQuery { return &AttestationCollectionQuery{ config: c.config, + ctx: &QueryContext{Type: TypeAttestationCollection}, + inters: c.Interceptors(), } } @@ -374,8 +527,8 @@ func (c *AttestationCollectionClient) GetX(ctx context.Context, id int) *Attesta // QueryAttestations queries the attestations edge of a AttestationCollection. func (c *AttestationCollectionClient) QueryAttestations(ac *AttestationCollection) *AttestationQuery { - query := &AttestationQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&AttestationClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := ac.ID step := sqlgraph.NewStep( sqlgraph.From(attestationcollection.Table, attestationcollection.FieldID, id), @@ -390,8 +543,8 @@ func (c *AttestationCollectionClient) QueryAttestations(ac *AttestationCollectio // QueryStatement queries the statement edge of a AttestationCollection. func (c *AttestationCollectionClient) QueryStatement(ac *AttestationCollection) *StatementQuery { - query := &StatementQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&StatementClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := ac.ID step := sqlgraph.NewStep( sqlgraph.From(attestationcollection.Table, attestationcollection.FieldID, id), @@ -409,6 +562,26 @@ func (c *AttestationCollectionClient) Hooks() []Hook { return c.hooks.AttestationCollection } +// Interceptors returns the client interceptors. 
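Interceptors can also be scoped to a single entity client; a sketch using a TraverseFunc, which the new prepareQuery runs before execution (assuming the standard TraverseFunc alias is generated in this package):

// Imports: "context", "log", "github.com/testifysec/archivista/ent".
func logCollectionTraversals(client *ent.Client) {
	client.AttestationCollection.Intercept(ent.TraverseFunc(func(ctx context.Context, q ent.Query) error {
		log.Printf("traversing %T", q)
		return nil // returning an error aborts the query during prepareQuery
	}))
}
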
+func (c *AttestationCollectionClient) Interceptors() []Interceptor { + return c.inters.AttestationCollection +} + +func (c *AttestationCollectionClient) mutate(ctx context.Context, m *AttestationCollectionMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&AttestationCollectionCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&AttestationCollectionUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&AttestationCollectionUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&AttestationCollectionDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown AttestationCollection mutation op: %q", m.Op()) + } +} + // DsseClient is a client for the Dsse schema. type DsseClient struct { config @@ -425,6 +598,12 @@ func (c *DsseClient) Use(hooks ...Hook) { c.hooks.Dsse = append(c.hooks.Dsse, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `dsse.Intercept(f(g(h())))`. +func (c *DsseClient) Intercept(interceptors ...Interceptor) { + c.inters.Dsse = append(c.inters.Dsse, interceptors...) +} + // Create returns a builder for creating a Dsse entity. func (c *DsseClient) Create() *DsseCreate { mutation := newDsseMutation(c.config, OpCreate) @@ -436,6 +615,21 @@ func (c *DsseClient) CreateBulk(builders ...*DsseCreate) *DsseCreateBulk { return &DsseCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *DsseClient) MapCreateBulk(slice any, setFunc func(*DsseCreate, int)) *DsseCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &DsseCreateBulk{err: fmt.Errorf("calling to DsseClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*DsseCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &DsseCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for Dsse. func (c *DsseClient) Update() *DsseUpdate { mutation := newDsseMutation(c.config, OpUpdate) @@ -465,7 +659,7 @@ func (c *DsseClient) DeleteOne(d *Dsse) *DsseDeleteOne { return c.DeleteOneID(d.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *DsseClient) DeleteOneID(id int) *DsseDeleteOne { builder := c.Delete().Where(dsse.ID(id)) builder.mutation.id = &id @@ -477,6 +671,8 @@ func (c *DsseClient) DeleteOneID(id int) *DsseDeleteOne { func (c *DsseClient) Query() *DsseQuery { return &DsseQuery{ config: c.config, + ctx: &QueryContext{Type: TypeDsse}, + inters: c.Interceptors(), } } @@ -496,8 +692,8 @@ func (c *DsseClient) GetX(ctx context.Context, id int) *Dsse { // QueryStatement queries the statement edge of a Dsse. 
func (c *DsseClient) QueryStatement(d *Dsse) *StatementQuery { - query := &StatementQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&StatementClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := d.ID step := sqlgraph.NewStep( sqlgraph.From(dsse.Table, dsse.FieldID, id), @@ -512,8 +708,8 @@ func (c *DsseClient) QueryStatement(d *Dsse) *StatementQuery { // QuerySignatures queries the signatures edge of a Dsse. func (c *DsseClient) QuerySignatures(d *Dsse) *SignatureQuery { - query := &SignatureQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&SignatureClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := d.ID step := sqlgraph.NewStep( sqlgraph.From(dsse.Table, dsse.FieldID, id), @@ -528,8 +724,8 @@ func (c *DsseClient) QuerySignatures(d *Dsse) *SignatureQuery { // QueryPayloadDigests queries the payload_digests edge of a Dsse. func (c *DsseClient) QueryPayloadDigests(d *Dsse) *PayloadDigestQuery { - query := &PayloadDigestQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&PayloadDigestClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := d.ID step := sqlgraph.NewStep( sqlgraph.From(dsse.Table, dsse.FieldID, id), @@ -547,6 +743,26 @@ func (c *DsseClient) Hooks() []Hook { return c.hooks.Dsse } +// Interceptors returns the client interceptors. +func (c *DsseClient) Interceptors() []Interceptor { + return c.inters.Dsse +} + +func (c *DsseClient) mutate(ctx context.Context, m *DsseMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&DsseCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&DsseUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&DsseUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&DsseDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Dsse mutation op: %q", m.Op()) + } +} + // PayloadDigestClient is a client for the PayloadDigest schema. type PayloadDigestClient struct { config @@ -563,6 +779,12 @@ func (c *PayloadDigestClient) Use(hooks ...Hook) { c.hooks.PayloadDigest = append(c.hooks.PayloadDigest, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `payloaddigest.Intercept(f(g(h())))`. +func (c *PayloadDigestClient) Intercept(interceptors ...Interceptor) { + c.inters.PayloadDigest = append(c.inters.PayloadDigest, interceptors...) +} + // Create returns a builder for creating a PayloadDigest entity. func (c *PayloadDigestClient) Create() *PayloadDigestCreate { mutation := newPayloadDigestMutation(c.config, OpCreate) @@ -574,6 +796,21 @@ func (c *PayloadDigestClient) CreateBulk(builders ...*PayloadDigestCreate) *Payl return &PayloadDigestCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. 
+func (c *PayloadDigestClient) MapCreateBulk(slice any, setFunc func(*PayloadDigestCreate, int)) *PayloadDigestCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &PayloadDigestCreateBulk{err: fmt.Errorf("calling to PayloadDigestClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*PayloadDigestCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &PayloadDigestCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for PayloadDigest. func (c *PayloadDigestClient) Update() *PayloadDigestUpdate { mutation := newPayloadDigestMutation(c.config, OpUpdate) @@ -603,7 +840,7 @@ func (c *PayloadDigestClient) DeleteOne(pd *PayloadDigest) *PayloadDigestDeleteO return c.DeleteOneID(pd.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *PayloadDigestClient) DeleteOneID(id int) *PayloadDigestDeleteOne { builder := c.Delete().Where(payloaddigest.ID(id)) builder.mutation.id = &id @@ -615,6 +852,8 @@ func (c *PayloadDigestClient) DeleteOneID(id int) *PayloadDigestDeleteOne { func (c *PayloadDigestClient) Query() *PayloadDigestQuery { return &PayloadDigestQuery{ config: c.config, + ctx: &QueryContext{Type: TypePayloadDigest}, + inters: c.Interceptors(), } } @@ -634,8 +873,8 @@ func (c *PayloadDigestClient) GetX(ctx context.Context, id int) *PayloadDigest { // QueryDsse queries the dsse edge of a PayloadDigest. func (c *PayloadDigestClient) QueryDsse(pd *PayloadDigest) *DsseQuery { - query := &DsseQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&DsseClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := pd.ID step := sqlgraph.NewStep( sqlgraph.From(payloaddigest.Table, payloaddigest.FieldID, id), @@ -653,6 +892,26 @@ func (c *PayloadDigestClient) Hooks() []Hook { return c.hooks.PayloadDigest } +// Interceptors returns the client interceptors. +func (c *PayloadDigestClient) Interceptors() []Interceptor { + return c.inters.PayloadDigest +} + +func (c *PayloadDigestClient) mutate(ctx context.Context, m *PayloadDigestMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&PayloadDigestCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&PayloadDigestUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&PayloadDigestUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&PayloadDigestDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown PayloadDigest mutation op: %q", m.Op()) + } +} + // SignatureClient is a client for the Signature schema. type SignatureClient struct { config @@ -669,6 +928,12 @@ func (c *SignatureClient) Use(hooks ...Hook) { c.hooks.Signature = append(c.hooks.Signature, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `signature.Intercept(f(g(h())))`. +func (c *SignatureClient) Intercept(interceptors ...Interceptor) { + c.inters.Signature = append(c.inters.Signature, interceptors...) +} + // Create returns a builder for creating a Signature entity. 
func (c *SignatureClient) Create() *SignatureCreate { mutation := newSignatureMutation(c.config, OpCreate) @@ -680,6 +945,21 @@ func (c *SignatureClient) CreateBulk(builders ...*SignatureCreate) *SignatureCre return &SignatureCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *SignatureClient) MapCreateBulk(slice any, setFunc func(*SignatureCreate, int)) *SignatureCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &SignatureCreateBulk{err: fmt.Errorf("calling to SignatureClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*SignatureCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &SignatureCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for Signature. func (c *SignatureClient) Update() *SignatureUpdate { mutation := newSignatureMutation(c.config, OpUpdate) @@ -709,7 +989,7 @@ func (c *SignatureClient) DeleteOne(s *Signature) *SignatureDeleteOne { return c.DeleteOneID(s.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *SignatureClient) DeleteOneID(id int) *SignatureDeleteOne { builder := c.Delete().Where(signature.ID(id)) builder.mutation.id = &id @@ -721,6 +1001,8 @@ func (c *SignatureClient) DeleteOneID(id int) *SignatureDeleteOne { func (c *SignatureClient) Query() *SignatureQuery { return &SignatureQuery{ config: c.config, + ctx: &QueryContext{Type: TypeSignature}, + inters: c.Interceptors(), } } @@ -740,8 +1022,8 @@ func (c *SignatureClient) GetX(ctx context.Context, id int) *Signature { // QueryDsse queries the dsse edge of a Signature. func (c *SignatureClient) QueryDsse(s *Signature) *DsseQuery { - query := &DsseQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&DsseClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := s.ID step := sqlgraph.NewStep( sqlgraph.From(signature.Table, signature.FieldID, id), @@ -756,8 +1038,8 @@ func (c *SignatureClient) QueryDsse(s *Signature) *DsseQuery { // QueryTimestamps queries the timestamps edge of a Signature. func (c *SignatureClient) QueryTimestamps(s *Signature) *TimestampQuery { - query := &TimestampQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&TimestampClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := s.ID step := sqlgraph.NewStep( sqlgraph.From(signature.Table, signature.FieldID, id), @@ -775,6 +1057,26 @@ func (c *SignatureClient) Hooks() []Hook { return c.hooks.Signature } +// Interceptors returns the client interceptors. 
+func (c *SignatureClient) Interceptors() []Interceptor { + return c.inters.Signature +} + +func (c *SignatureClient) mutate(ctx context.Context, m *SignatureMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&SignatureCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&SignatureUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&SignatureUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&SignatureDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Signature mutation op: %q", m.Op()) + } +} + // StatementClient is a client for the Statement schema. type StatementClient struct { config @@ -791,6 +1093,12 @@ func (c *StatementClient) Use(hooks ...Hook) { c.hooks.Statement = append(c.hooks.Statement, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `statement.Intercept(f(g(h())))`. +func (c *StatementClient) Intercept(interceptors ...Interceptor) { + c.inters.Statement = append(c.inters.Statement, interceptors...) +} + // Create returns a builder for creating a Statement entity. func (c *StatementClient) Create() *StatementCreate { mutation := newStatementMutation(c.config, OpCreate) @@ -802,6 +1110,21 @@ func (c *StatementClient) CreateBulk(builders ...*StatementCreate) *StatementCre return &StatementCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *StatementClient) MapCreateBulk(slice any, setFunc func(*StatementCreate, int)) *StatementCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &StatementCreateBulk{err: fmt.Errorf("calling to StatementClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*StatementCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &StatementCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for Statement. func (c *StatementClient) Update() *StatementUpdate { mutation := newStatementMutation(c.config, OpUpdate) @@ -831,7 +1154,7 @@ func (c *StatementClient) DeleteOne(s *Statement) *StatementDeleteOne { return c.DeleteOneID(s.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *StatementClient) DeleteOneID(id int) *StatementDeleteOne { builder := c.Delete().Where(statement.ID(id)) builder.mutation.id = &id @@ -843,6 +1166,8 @@ func (c *StatementClient) DeleteOneID(id int) *StatementDeleteOne { func (c *StatementClient) Query() *StatementQuery { return &StatementQuery{ config: c.config, + ctx: &QueryContext{Type: TypeStatement}, + inters: c.Interceptors(), } } @@ -862,8 +1187,8 @@ func (c *StatementClient) GetX(ctx context.Context, id int) *Statement { // QuerySubjects queries the subjects edge of a Statement. 
func (c *StatementClient) QuerySubjects(s *Statement) *SubjectQuery { - query := &SubjectQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&SubjectClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := s.ID step := sqlgraph.NewStep( sqlgraph.From(statement.Table, statement.FieldID, id), @@ -878,8 +1203,8 @@ func (c *StatementClient) QuerySubjects(s *Statement) *SubjectQuery { // QueryAttestationCollections queries the attestation_collections edge of a Statement. func (c *StatementClient) QueryAttestationCollections(s *Statement) *AttestationCollectionQuery { - query := &AttestationCollectionQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&AttestationCollectionClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := s.ID step := sqlgraph.NewStep( sqlgraph.From(statement.Table, statement.FieldID, id), @@ -894,8 +1219,8 @@ func (c *StatementClient) QueryAttestationCollections(s *Statement) *Attestation // QueryDsse queries the dsse edge of a Statement. func (c *StatementClient) QueryDsse(s *Statement) *DsseQuery { - query := &DsseQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&DsseClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := s.ID step := sqlgraph.NewStep( sqlgraph.From(statement.Table, statement.FieldID, id), @@ -913,6 +1238,26 @@ func (c *StatementClient) Hooks() []Hook { return c.hooks.Statement } +// Interceptors returns the client interceptors. +func (c *StatementClient) Interceptors() []Interceptor { + return c.inters.Statement +} + +func (c *StatementClient) mutate(ctx context.Context, m *StatementMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&StatementCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&StatementUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&StatementUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&StatementDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Statement mutation op: %q", m.Op()) + } +} + // SubjectClient is a client for the Subject schema. type SubjectClient struct { config @@ -929,6 +1274,12 @@ func (c *SubjectClient) Use(hooks ...Hook) { c.hooks.Subject = append(c.hooks.Subject, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `subject.Intercept(f(g(h())))`. +func (c *SubjectClient) Intercept(interceptors ...Interceptor) { + c.inters.Subject = append(c.inters.Subject, interceptors...) +} + // Create returns a builder for creating a Subject entity. func (c *SubjectClient) Create() *SubjectCreate { mutation := newSubjectMutation(c.config, OpCreate) @@ -940,6 +1291,21 @@ func (c *SubjectClient) CreateBulk(builders ...*SubjectCreate) *SubjectCreateBul return &SubjectCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. 
+func (c *SubjectClient) MapCreateBulk(slice any, setFunc func(*SubjectCreate, int)) *SubjectCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &SubjectCreateBulk{err: fmt.Errorf("calling to SubjectClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*SubjectCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &SubjectCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for Subject. func (c *SubjectClient) Update() *SubjectUpdate { mutation := newSubjectMutation(c.config, OpUpdate) @@ -969,7 +1335,7 @@ func (c *SubjectClient) DeleteOne(s *Subject) *SubjectDeleteOne { return c.DeleteOneID(s.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *SubjectClient) DeleteOneID(id int) *SubjectDeleteOne { builder := c.Delete().Where(subject.ID(id)) builder.mutation.id = &id @@ -981,6 +1347,8 @@ func (c *SubjectClient) DeleteOneID(id int) *SubjectDeleteOne { func (c *SubjectClient) Query() *SubjectQuery { return &SubjectQuery{ config: c.config, + ctx: &QueryContext{Type: TypeSubject}, + inters: c.Interceptors(), } } @@ -1000,8 +1368,8 @@ func (c *SubjectClient) GetX(ctx context.Context, id int) *Subject { // QuerySubjectDigests queries the subject_digests edge of a Subject. func (c *SubjectClient) QuerySubjectDigests(s *Subject) *SubjectDigestQuery { - query := &SubjectDigestQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&SubjectDigestClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := s.ID step := sqlgraph.NewStep( sqlgraph.From(subject.Table, subject.FieldID, id), @@ -1016,8 +1384,8 @@ func (c *SubjectClient) QuerySubjectDigests(s *Subject) *SubjectDigestQuery { // QueryStatement queries the statement edge of a Subject. func (c *SubjectClient) QueryStatement(s *Subject) *StatementQuery { - query := &StatementQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&StatementClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := s.ID step := sqlgraph.NewStep( sqlgraph.From(subject.Table, subject.FieldID, id), @@ -1035,6 +1403,26 @@ func (c *SubjectClient) Hooks() []Hook { return c.hooks.Subject } +// Interceptors returns the client interceptors. +func (c *SubjectClient) Interceptors() []Interceptor { + return c.inters.Subject +} + +func (c *SubjectClient) mutate(ctx context.Context, m *SubjectMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&SubjectCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&SubjectUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&SubjectUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&SubjectDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown Subject mutation op: %q", m.Op()) + } +} + // SubjectDigestClient is a client for the SubjectDigest schema. type SubjectDigestClient struct { config @@ -1051,6 +1439,12 @@ func (c *SubjectDigestClient) Use(hooks ...Hook) { c.hooks.SubjectDigest = append(c.hooks.SubjectDigest, hooks...) 
} +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `subjectdigest.Intercept(f(g(h())))`. +func (c *SubjectDigestClient) Intercept(interceptors ...Interceptor) { + c.inters.SubjectDigest = append(c.inters.SubjectDigest, interceptors...) +} + // Create returns a builder for creating a SubjectDigest entity. func (c *SubjectDigestClient) Create() *SubjectDigestCreate { mutation := newSubjectDigestMutation(c.config, OpCreate) @@ -1062,6 +1456,21 @@ func (c *SubjectDigestClient) CreateBulk(builders ...*SubjectDigestCreate) *Subj return &SubjectDigestCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *SubjectDigestClient) MapCreateBulk(slice any, setFunc func(*SubjectDigestCreate, int)) *SubjectDigestCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &SubjectDigestCreateBulk{err: fmt.Errorf("calling to SubjectDigestClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*SubjectDigestCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &SubjectDigestCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for SubjectDigest. func (c *SubjectDigestClient) Update() *SubjectDigestUpdate { mutation := newSubjectDigestMutation(c.config, OpUpdate) @@ -1091,7 +1500,7 @@ func (c *SubjectDigestClient) DeleteOne(sd *SubjectDigest) *SubjectDigestDeleteO return c.DeleteOneID(sd.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *SubjectDigestClient) DeleteOneID(id int) *SubjectDigestDeleteOne { builder := c.Delete().Where(subjectdigest.ID(id)) builder.mutation.id = &id @@ -1103,6 +1512,8 @@ func (c *SubjectDigestClient) DeleteOneID(id int) *SubjectDigestDeleteOne { func (c *SubjectDigestClient) Query() *SubjectDigestQuery { return &SubjectDigestQuery{ config: c.config, + ctx: &QueryContext{Type: TypeSubjectDigest}, + inters: c.Interceptors(), } } @@ -1122,8 +1533,8 @@ func (c *SubjectDigestClient) GetX(ctx context.Context, id int) *SubjectDigest { // QuerySubject queries the subject edge of a SubjectDigest. func (c *SubjectDigestClient) QuerySubject(sd *SubjectDigest) *SubjectQuery { - query := &SubjectQuery{config: c.config} - query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) { + query := (&SubjectClient{config: c.config}).Query() + query.path = func(context.Context) (fromV *sql.Selector, _ error) { id := sd.ID step := sqlgraph.NewStep( sqlgraph.From(subjectdigest.Table, subjectdigest.FieldID, id), @@ -1141,6 +1552,26 @@ func (c *SubjectDigestClient) Hooks() []Hook { return c.hooks.SubjectDigest } +// Interceptors returns the client interceptors. 
+func (c *SubjectDigestClient) Interceptors() []Interceptor { + return c.inters.SubjectDigest +} + +func (c *SubjectDigestClient) mutate(ctx context.Context, m *SubjectDigestMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&SubjectDigestCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&SubjectDigestUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&SubjectDigestUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&SubjectDigestDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown SubjectDigest mutation op: %q", m.Op()) + } +} + // TimestampClient is a client for the Timestamp schema. type TimestampClient struct { config @@ -1157,6 +1588,12 @@ func (c *TimestampClient) Use(hooks ...Hook) { c.hooks.Timestamp = append(c.hooks.Timestamp, hooks...) } +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `timestamp.Intercept(f(g(h())))`. +func (c *TimestampClient) Intercept(interceptors ...Interceptor) { + c.inters.Timestamp = append(c.inters.Timestamp, interceptors...) +} + // Create returns a builder for creating a Timestamp entity. func (c *TimestampClient) Create() *TimestampCreate { mutation := newTimestampMutation(c.config, OpCreate) @@ -1168,6 +1605,21 @@ func (c *TimestampClient) CreateBulk(builders ...*TimestampCreate) *TimestampCre return &TimestampCreateBulk{config: c.config, builders: builders} } +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *TimestampClient) MapCreateBulk(slice any, setFunc func(*TimestampCreate, int)) *TimestampCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &TimestampCreateBulk{err: fmt.Errorf("calling to TimestampClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*TimestampCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &TimestampCreateBulk{config: c.config, builders: builders} +} + // Update returns an update builder for Timestamp. func (c *TimestampClient) Update() *TimestampUpdate { mutation := newTimestampMutation(c.config, OpUpdate) @@ -1197,7 +1649,7 @@ func (c *TimestampClient) DeleteOne(t *Timestamp) *TimestampDeleteOne { return c.DeleteOneID(t.ID) } -// DeleteOne returns a builder for deleting the given entity by its id. +// DeleteOneID returns a builder for deleting the given entity by its id. func (c *TimestampClient) DeleteOneID(id int) *TimestampDeleteOne { builder := c.Delete().Where(timestamp.ID(id)) builder.mutation.id = &id @@ -1209,6 +1661,8 @@ func (c *TimestampClient) DeleteOneID(id int) *TimestampDeleteOne { func (c *TimestampClient) Query() *TimestampQuery { return &TimestampQuery{ config: c.config, + ctx: &QueryContext{Type: TypeTimestamp}, + inters: c.Interceptors(), } } @@ -1228,8 +1682,8 @@ func (c *TimestampClient) GetX(ctx context.Context, id int) *Timestamp { // QuerySignature queries the signature edge of a Timestamp. 
 func (c *TimestampClient) QuerySignature(t *Timestamp) *SignatureQuery {
-	query := &SignatureQuery{config: c.config}
-	query.path = func(ctx context.Context) (fromV *sql.Selector, _ error) {
+	query := (&SignatureClient{config: c.config}).Query()
+	query.path = func(context.Context) (fromV *sql.Selector, _ error) {
 		id := t.ID
 		step := sqlgraph.NewStep(
 			sqlgraph.From(timestamp.Table, timestamp.FieldID, id),
@@ -1246,3 +1700,35 @@ func (c *TimestampClient) QuerySignature(t *Timestamp) *SignatureQuery {
 func (c *TimestampClient) Hooks() []Hook {
 	return c.hooks.Timestamp
 }
+
+// Interceptors returns the client interceptors.
+func (c *TimestampClient) Interceptors() []Interceptor {
+	return c.inters.Timestamp
+}
+
+func (c *TimestampClient) mutate(ctx context.Context, m *TimestampMutation) (Value, error) {
+	switch m.Op() {
+	case OpCreate:
+		return (&TimestampCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+	case OpUpdate:
+		return (&TimestampUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+	case OpUpdateOne:
+		return (&TimestampUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+	case OpDelete, OpDeleteOne:
+		return (&TimestampDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+	default:
+		return nil, fmt.Errorf("ent: unknown Timestamp mutation op: %q", m.Op())
+	}
+}
+
+// hooks and interceptors per client, for fast access.
+type (
+	hooks struct {
+		Attestation, AttestationCollection, Dsse, PayloadDigest, Signature, Statement,
+		Subject, SubjectDigest, Timestamp []ent.Hook
+	}
+	inters struct {
+		Attestation, AttestationCollection, Dsse, PayloadDigest, Signature, Statement,
+		Subject, SubjectDigest, Timestamp []ent.Interceptor
+	}
+)
diff --git a/ent/config.go b/ent/config.go
deleted file mode 100644
index f930b7b6..00000000
--- a/ent/config.go
+++ /dev/null
@@ -1,67 +0,0 @@
-// Code generated by ent, DO NOT EDIT.
-
-package ent
-
-import (
-	"entgo.io/ent"
-	"entgo.io/ent/dialect"
-)
-
-// Option function to configure the client.
-type Option func(*config)
-
-// Config is the configuration for the client and its builder.
-type config struct {
-	// driver used for executing database requests.
-	driver dialect.Driver
-	// debug enable a debug logging.
-	debug bool
-	// log used for logging on debug mode.
-	log func(...any)
-	// hooks to execute on mutations.
-	hooks *hooks
-}
-
-// hooks per client, for fast access.
-type hooks struct {
-	Attestation           []ent.Hook
-	AttestationCollection []ent.Hook
-	Dsse                  []ent.Hook
-	PayloadDigest         []ent.Hook
-	Signature             []ent.Hook
-	Statement             []ent.Hook
-	Subject               []ent.Hook
-	SubjectDigest         []ent.Hook
-	Timestamp             []ent.Hook
-}
-
-// Options applies the options on the config object.
-func (c *config) options(opts ...Option) {
-	for _, opt := range opts {
-		opt(c)
-	}
-	if c.debug {
-		c.driver = dialect.Debug(c.driver, c.log)
-	}
-}
-
-// Debug enables debug logging on the ent.Driver.
-func Debug() Option {
-	return func(c *config) {
-		c.debug = true
-	}
-}
-
-// Log sets the logging function for debug mode.
-func Log(fn func(...any)) Option {
-	return func(c *config) {
-		c.log = fn
-	}
-}
-
-// Driver configures the client driver.
-func Driver(driver dialect.Driver) Option {
-	return func(c *config) {
-		c.driver = driver
-	}
-}
diff --git a/ent/context.go b/ent/context.go
deleted file mode 100644
index 7811bfa2..00000000
--- a/ent/context.go
+++ /dev/null
@@ -1,33 +0,0 @@
-// Code generated by ent, DO NOT EDIT.
-
-package ent
-
-import (
-	"context"
-)
-
-type clientCtxKey struct{}
-
-// FromContext returns a Client stored inside a context, or nil if there isn't one.
-func FromContext(ctx context.Context) *Client {
-	c, _ := ctx.Value(clientCtxKey{}).(*Client)
-	return c
-}
-
-// NewContext returns a new context with the given Client attached.
-func NewContext(parent context.Context, c *Client) context.Context {
-	return context.WithValue(parent, clientCtxKey{}, c)
-}
-
-type txCtxKey struct{}
-
-// TxFromContext returns a Tx stored inside a context, or nil if there isn't one.
-func TxFromContext(ctx context.Context) *Tx {
-	tx, _ := ctx.Value(txCtxKey{}).(*Tx)
-	return tx
-}
-
-// NewTxContext returns a new context with the given Tx attached.
-func NewTxContext(parent context.Context, tx *Tx) context.Context {
-	return context.WithValue(parent, txCtxKey{}, tx)
-}
diff --git a/ent/dsse.go b/ent/dsse.go
index aaaf602e..4ffc38c3 100644
--- a/ent/dsse.go
+++ b/ent/dsse.go
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"strings"
 
+	"entgo.io/ent"
 	"entgo.io/ent/dialect/sql"
 	"github.com/testifysec/archivista/ent/dsse"
 	"github.com/testifysec/archivista/ent/statement"
@@ -24,6 +25,7 @@ type Dsse struct {
 	// The values are being populated by the DsseQuery when eager-loading is set.
 	Edges          DsseEdges `json:"edges"`
 	dsse_statement *int
+	selectValues   sql.SelectValues
 }
 
 // DsseEdges holds the relations/edges for other nodes in the graph.
@@ -38,7 +40,10 @@ type DsseEdges struct {
 	// type was loaded (or requested) in eager-loading or not.
 	loadedTypes [3]bool
 	// totalCount holds the count of the edges above.
-	totalCount [3]*int
+	totalCount [3]map[string]int
+
+	namedSignatures     map[string][]*Signature
+	namedPayloadDigests map[string][]*PayloadDigest
 }
 
 // StatementOrErr returns the Statement value or an error if the edge
@@ -84,7 +89,7 @@ func (*Dsse) scanValues(columns []string) ([]any, error) {
 		case dsse.ForeignKeys[0]: // dsse_statement
 			values[i] = new(sql.NullInt64)
 		default:
-			return nil, fmt.Errorf("unexpected column %q for type Dsse", columns[i])
+			values[i] = new(sql.UnknownType)
 		}
 	}
 	return values, nil
@@ -123,31 +128,39 @@ func (d *Dsse) assignValues(columns []string, values []any) error {
 				d.dsse_statement = new(int)
 				*d.dsse_statement = int(value.Int64)
 			}
+		default:
+			d.selectValues.Set(columns[i], values[i])
 		}
 	}
 	return nil
 }
 
+// Value returns the ent.Value that was dynamically selected and assigned to the Dsse.
+// This includes values selected through modifiers, order, etc.
+func (d *Dsse) Value(name string) (ent.Value, error) {
+	return d.selectValues.Get(name)
+}
+
 // QueryStatement queries the "statement" edge of the Dsse entity.
 func (d *Dsse) QueryStatement() *StatementQuery {
-	return (&DsseClient{config: d.config}).QueryStatement(d)
+	return NewDsseClient(d.config).QueryStatement(d)
 }
 
 // QuerySignatures queries the "signatures" edge of the Dsse entity.
 func (d *Dsse) QuerySignatures() *SignatureQuery {
-	return (&DsseClient{config: d.config}).QuerySignatures(d)
+	return NewDsseClient(d.config).QuerySignatures(d)
 }
 
 // QueryPayloadDigests queries the "payload_digests" edge of the Dsse entity.
 func (d *Dsse) QueryPayloadDigests() *PayloadDigestQuery {
-	return (&DsseClient{config: d.config}).QueryPayloadDigests(d)
+	return NewDsseClient(d.config).QueryPayloadDigests(d)
 }
 
 // Update returns a builder for updating this Dsse.
 // Note that you need to call Dsse.Unwrap() before calling this method if this Dsse
 // was returned from a transaction, and the transaction was committed or rolled back.
func (d *Dsse) Update() *DsseUpdateOne { - return (&DsseClient{config: d.config}).UpdateOne(d) + return NewDsseClient(d.config).UpdateOne(d) } // Unwrap unwraps the Dsse entity that was returned from a transaction after it was closed, @@ -175,11 +188,53 @@ func (d *Dsse) String() string { return builder.String() } -// Dsses is a parsable slice of Dsse. -type Dsses []*Dsse +// NamedSignatures returns the Signatures named value or an error if the edge was not +// loaded in eager-loading with this name. +func (d *Dsse) NamedSignatures(name string) ([]*Signature, error) { + if d.Edges.namedSignatures == nil { + return nil, &NotLoadedError{edge: name} + } + nodes, ok := d.Edges.namedSignatures[name] + if !ok { + return nil, &NotLoadedError{edge: name} + } + return nodes, nil +} -func (d Dsses) config(cfg config) { - for _i := range d { - d[_i].config = cfg +func (d *Dsse) appendNamedSignatures(name string, edges ...*Signature) { + if d.Edges.namedSignatures == nil { + d.Edges.namedSignatures = make(map[string][]*Signature) + } + if len(edges) == 0 { + d.Edges.namedSignatures[name] = []*Signature{} + } else { + d.Edges.namedSignatures[name] = append(d.Edges.namedSignatures[name], edges...) } } + +// NamedPayloadDigests returns the PayloadDigests named value or an error if the edge was not +// loaded in eager-loading with this name. +func (d *Dsse) NamedPayloadDigests(name string) ([]*PayloadDigest, error) { + if d.Edges.namedPayloadDigests == nil { + return nil, &NotLoadedError{edge: name} + } + nodes, ok := d.Edges.namedPayloadDigests[name] + if !ok { + return nil, &NotLoadedError{edge: name} + } + return nodes, nil +} + +func (d *Dsse) appendNamedPayloadDigests(name string, edges ...*PayloadDigest) { + if d.Edges.namedPayloadDigests == nil { + d.Edges.namedPayloadDigests = make(map[string][]*PayloadDigest) + } + if len(edges) == 0 { + d.Edges.namedPayloadDigests[name] = []*PayloadDigest{} + } else { + d.Edges.namedPayloadDigests[name] = append(d.Edges.namedPayloadDigests[name], edges...) + } +} + +// Dsses is a parsable slice of Dsse. +type Dsses []*Dsse diff --git a/ent/dsse/dsse.go b/ent/dsse/dsse.go index cc047354..41c063ce 100644 --- a/ent/dsse/dsse.go +++ b/ent/dsse/dsse.go @@ -2,6 +2,11 @@ package dsse +import ( + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + const ( // Label holds the string label denoting the dsse type in the database. Label = "dsse" @@ -76,3 +81,77 @@ var ( // PayloadTypeValidator is a validator for the "payload_type" field. It is called by the builders before save. PayloadTypeValidator func(string) error ) + +// OrderOption defines the ordering options for the Dsse queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByGitoidSha256 orders the results by the gitoid_sha256 field. +func ByGitoidSha256(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldGitoidSha256, opts...).ToFunc() +} + +// ByPayloadType orders the results by the payload_type field. +func ByPayloadType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPayloadType, opts...).ToFunc() +} + +// ByStatementField orders the results by statement field. 
+func ByStatementField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newStatementStep(), sql.OrderByField(field, opts...)) + } +} + +// BySignaturesCount orders the results by signatures count. +func BySignaturesCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newSignaturesStep(), opts...) + } +} + +// BySignatures orders the results by signatures terms. +func BySignatures(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newSignaturesStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByPayloadDigestsCount orders the results by payload_digests count. +func ByPayloadDigestsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newPayloadDigestsStep(), opts...) + } +} + +// ByPayloadDigests orders the results by payload_digests terms. +func ByPayloadDigests(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newPayloadDigestsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} +func newStatementStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(StatementInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, false, StatementTable, StatementColumn), + ) +} +func newSignaturesStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(SignaturesInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, SignaturesTable, SignaturesColumn), + ) +} +func newPayloadDigestsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(PayloadDigestsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, PayloadDigestsTable, PayloadDigestsColumn), + ) +} diff --git a/ent/dsse/where.go b/ent/dsse/where.go index 6c9ec1a7..175987bc 100644 --- a/ent/dsse/where.go +++ b/ent/dsse/where.go @@ -10,285 +10,187 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Dsse(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Dsse(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Dsse(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Dsse(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Dsse(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. 
func IDGT(id int) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Dsse(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Dsse(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Dsse(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Dsse(sql.FieldLTE(FieldID, id)) } // GitoidSha256 applies equality check predicate on the "gitoid_sha256" field. It's identical to GitoidSha256EQ. func GitoidSha256(v string) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldGitoidSha256), v)) - }) + return predicate.Dsse(sql.FieldEQ(FieldGitoidSha256, v)) } // PayloadType applies equality check predicate on the "payload_type" field. It's identical to PayloadTypeEQ. func PayloadType(v string) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldPayloadType), v)) - }) + return predicate.Dsse(sql.FieldEQ(FieldPayloadType, v)) } // GitoidSha256EQ applies the EQ predicate on the "gitoid_sha256" field. func GitoidSha256EQ(v string) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldGitoidSha256), v)) - }) + return predicate.Dsse(sql.FieldEQ(FieldGitoidSha256, v)) } // GitoidSha256NEQ applies the NEQ predicate on the "gitoid_sha256" field. func GitoidSha256NEQ(v string) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldGitoidSha256), v)) - }) + return predicate.Dsse(sql.FieldNEQ(FieldGitoidSha256, v)) } // GitoidSha256In applies the In predicate on the "gitoid_sha256" field. func GitoidSha256In(vs ...string) predicate.Dsse { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldGitoidSha256), v...)) - }) + return predicate.Dsse(sql.FieldIn(FieldGitoidSha256, vs...)) } // GitoidSha256NotIn applies the NotIn predicate on the "gitoid_sha256" field. func GitoidSha256NotIn(vs ...string) predicate.Dsse { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldGitoidSha256), v...)) - }) + return predicate.Dsse(sql.FieldNotIn(FieldGitoidSha256, vs...)) } // GitoidSha256GT applies the GT predicate on the "gitoid_sha256" field. func GitoidSha256GT(v string) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldGitoidSha256), v)) - }) + return predicate.Dsse(sql.FieldGT(FieldGitoidSha256, v)) } // GitoidSha256GTE applies the GTE predicate on the "gitoid_sha256" field. func GitoidSha256GTE(v string) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldGitoidSha256), v)) - }) + return predicate.Dsse(sql.FieldGTE(FieldGitoidSha256, v)) } // GitoidSha256LT applies the LT predicate on the "gitoid_sha256" field. 
func GitoidSha256LT(v string) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldGitoidSha256), v)) - }) + return predicate.Dsse(sql.FieldLT(FieldGitoidSha256, v)) } // GitoidSha256LTE applies the LTE predicate on the "gitoid_sha256" field. func GitoidSha256LTE(v string) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldGitoidSha256), v)) - }) + return predicate.Dsse(sql.FieldLTE(FieldGitoidSha256, v)) } // GitoidSha256Contains applies the Contains predicate on the "gitoid_sha256" field. func GitoidSha256Contains(v string) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldGitoidSha256), v)) - }) + return predicate.Dsse(sql.FieldContains(FieldGitoidSha256, v)) } // GitoidSha256HasPrefix applies the HasPrefix predicate on the "gitoid_sha256" field. func GitoidSha256HasPrefix(v string) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldGitoidSha256), v)) - }) + return predicate.Dsse(sql.FieldHasPrefix(FieldGitoidSha256, v)) } // GitoidSha256HasSuffix applies the HasSuffix predicate on the "gitoid_sha256" field. func GitoidSha256HasSuffix(v string) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldGitoidSha256), v)) - }) + return predicate.Dsse(sql.FieldHasSuffix(FieldGitoidSha256, v)) } // GitoidSha256EqualFold applies the EqualFold predicate on the "gitoid_sha256" field. func GitoidSha256EqualFold(v string) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldGitoidSha256), v)) - }) + return predicate.Dsse(sql.FieldEqualFold(FieldGitoidSha256, v)) } // GitoidSha256ContainsFold applies the ContainsFold predicate on the "gitoid_sha256" field. func GitoidSha256ContainsFold(v string) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldGitoidSha256), v)) - }) + return predicate.Dsse(sql.FieldContainsFold(FieldGitoidSha256, v)) } // PayloadTypeEQ applies the EQ predicate on the "payload_type" field. func PayloadTypeEQ(v string) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldPayloadType), v)) - }) + return predicate.Dsse(sql.FieldEQ(FieldPayloadType, v)) } // PayloadTypeNEQ applies the NEQ predicate on the "payload_type" field. func PayloadTypeNEQ(v string) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldPayloadType), v)) - }) + return predicate.Dsse(sql.FieldNEQ(FieldPayloadType, v)) } // PayloadTypeIn applies the In predicate on the "payload_type" field. func PayloadTypeIn(vs ...string) predicate.Dsse { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldPayloadType), v...)) - }) + return predicate.Dsse(sql.FieldIn(FieldPayloadType, vs...)) } // PayloadTypeNotIn applies the NotIn predicate on the "payload_type" field. func PayloadTypeNotIn(vs ...string) predicate.Dsse { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldPayloadType), v...)) - }) + return predicate.Dsse(sql.FieldNotIn(FieldPayloadType, vs...)) } // PayloadTypeGT applies the GT predicate on the "payload_type" field. 
func PayloadTypeGT(v string) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldPayloadType), v)) - }) + return predicate.Dsse(sql.FieldGT(FieldPayloadType, v)) } // PayloadTypeGTE applies the GTE predicate on the "payload_type" field. func PayloadTypeGTE(v string) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldPayloadType), v)) - }) + return predicate.Dsse(sql.FieldGTE(FieldPayloadType, v)) } // PayloadTypeLT applies the LT predicate on the "payload_type" field. func PayloadTypeLT(v string) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldPayloadType), v)) - }) + return predicate.Dsse(sql.FieldLT(FieldPayloadType, v)) } // PayloadTypeLTE applies the LTE predicate on the "payload_type" field. func PayloadTypeLTE(v string) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldPayloadType), v)) - }) + return predicate.Dsse(sql.FieldLTE(FieldPayloadType, v)) } // PayloadTypeContains applies the Contains predicate on the "payload_type" field. func PayloadTypeContains(v string) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldPayloadType), v)) - }) + return predicate.Dsse(sql.FieldContains(FieldPayloadType, v)) } // PayloadTypeHasPrefix applies the HasPrefix predicate on the "payload_type" field. func PayloadTypeHasPrefix(v string) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldPayloadType), v)) - }) + return predicate.Dsse(sql.FieldHasPrefix(FieldPayloadType, v)) } // PayloadTypeHasSuffix applies the HasSuffix predicate on the "payload_type" field. func PayloadTypeHasSuffix(v string) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldPayloadType), v)) - }) + return predicate.Dsse(sql.FieldHasSuffix(FieldPayloadType, v)) } // PayloadTypeEqualFold applies the EqualFold predicate on the "payload_type" field. func PayloadTypeEqualFold(v string) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldPayloadType), v)) - }) + return predicate.Dsse(sql.FieldEqualFold(FieldPayloadType, v)) } // PayloadTypeContainsFold applies the ContainsFold predicate on the "payload_type" field. func PayloadTypeContainsFold(v string) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldPayloadType), v)) - }) + return predicate.Dsse(sql.FieldContainsFold(FieldPayloadType, v)) } // HasStatement applies the HasEdge predicate on the "statement" edge. @@ -296,7 +198,6 @@ func HasStatement() predicate.Dsse { return predicate.Dsse(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(StatementTable, FieldID), sqlgraph.Edge(sqlgraph.M2O, false, StatementTable, StatementColumn), ) sqlgraph.HasNeighbors(s, step) @@ -306,11 +207,7 @@ func HasStatement() predicate.Dsse { // HasStatementWith applies the HasEdge predicate on the "statement" edge with a given conditions (other predicates). 
func HasStatementWith(preds ...predicate.Statement) predicate.Dsse { return predicate.Dsse(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(StatementInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.M2O, false, StatementTable, StatementColumn), - ) + step := newStatementStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -324,7 +221,6 @@ func HasSignatures() predicate.Dsse { return predicate.Dsse(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(SignaturesTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, SignaturesTable, SignaturesColumn), ) sqlgraph.HasNeighbors(s, step) @@ -334,11 +230,7 @@ func HasSignatures() predicate.Dsse { // HasSignaturesWith applies the HasEdge predicate on the "signatures" edge with a given conditions (other predicates). func HasSignaturesWith(preds ...predicate.Signature) predicate.Dsse { return predicate.Dsse(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(SignaturesInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.O2M, false, SignaturesTable, SignaturesColumn), - ) + step := newSignaturesStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -352,7 +244,6 @@ func HasPayloadDigests() predicate.Dsse { return predicate.Dsse(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(PayloadDigestsTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, PayloadDigestsTable, PayloadDigestsColumn), ) sqlgraph.HasNeighbors(s, step) @@ -362,11 +253,7 @@ func HasPayloadDigests() predicate.Dsse { // HasPayloadDigestsWith applies the HasEdge predicate on the "payload_digests" edge with a given conditions (other predicates). func HasPayloadDigestsWith(preds ...predicate.PayloadDigest) predicate.Dsse { return predicate.Dsse(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(PayloadDigestsInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.O2M, false, PayloadDigestsTable, PayloadDigestsColumn), - ) + step := newPayloadDigestsStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -377,32 +264,15 @@ func HasPayloadDigestsWith(preds ...predicate.PayloadDigest) predicate.Dsse { // And groups predicates with the AND operator between them. func And(predicates ...predicate.Dsse) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Dsse(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. func Or(predicates ...predicate.Dsse) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Dsse(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.Dsse) predicate.Dsse { - return predicate.Dsse(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.Dsse(sql.NotPredicates(p)) } diff --git a/ent/dsse_create.go b/ent/dsse_create.go index 41af789b..6f1fee71 100644 --- a/ent/dsse_create.go +++ b/ent/dsse_create.go @@ -90,49 +90,7 @@ func (dc *DsseCreate) Mutation() *DsseMutation { // Save creates the Dsse in the database. 
func (dc *DsseCreate) Save(ctx context.Context) (*Dsse, error) { - var ( - err error - node *Dsse - ) - if len(dc.hooks) == 0 { - if err = dc.check(); err != nil { - return nil, err - } - node, err = dc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*DsseMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = dc.check(); err != nil { - return nil, err - } - dc.mutation = mutation - if node, err = dc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(dc.hooks) - 1; i >= 0; i-- { - if dc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = dc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, dc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Dsse) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from DsseMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, dc.sqlSave, dc.mutation, dc.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -179,6 +137,9 @@ func (dc *DsseCreate) check() error { } func (dc *DsseCreate) sqlSave(ctx context.Context) (*Dsse, error) { + if err := dc.check(); err != nil { + return nil, err + } _node, _spec := dc.createSpec() if err := sqlgraph.CreateNode(ctx, dc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -188,34 +149,22 @@ func (dc *DsseCreate) sqlSave(ctx context.Context) (*Dsse, error) { } id := _spec.ID.Value.(int64) _node.ID = int(id) + dc.mutation.id = &_node.ID + dc.mutation.done = true return _node, nil } func (dc *DsseCreate) createSpec() (*Dsse, *sqlgraph.CreateSpec) { var ( _node = &Dsse{config: dc.config} - _spec = &sqlgraph.CreateSpec{ - Table: dsse.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: dsse.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(dsse.Table, sqlgraph.NewFieldSpec(dsse.FieldID, field.TypeInt)) ) if value, ok := dc.mutation.GitoidSha256(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: dsse.FieldGitoidSha256, - }) + _spec.SetField(dsse.FieldGitoidSha256, field.TypeString, value) _node.GitoidSha256 = value } if value, ok := dc.mutation.PayloadType(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: dsse.FieldPayloadType, - }) + _spec.SetField(dsse.FieldPayloadType, field.TypeString, value) _node.PayloadType = value } if nodes := dc.mutation.StatementIDs(); len(nodes) > 0 { @@ -226,10 +175,7 @@ func (dc *DsseCreate) createSpec() (*Dsse, *sqlgraph.CreateSpec) { Columns: []string{dsse.StatementColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: statement.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(statement.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -246,10 +192,7 @@ func (dc *DsseCreate) createSpec() (*Dsse, *sqlgraph.CreateSpec) { Columns: []string{dsse.SignaturesColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: signature.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(signature.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -265,10 +208,7 @@ func (dc *DsseCreate) createSpec() (*Dsse, *sqlgraph.CreateSpec) { Columns: []string{dsse.PayloadDigestsColumn}, Bidi: false, 
Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: payloaddigest.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(payloaddigest.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -282,11 +222,15 @@ func (dc *DsseCreate) createSpec() (*Dsse, *sqlgraph.CreateSpec) { // DsseCreateBulk is the builder for creating many Dsse entities in bulk. type DsseCreateBulk struct { config + err error builders []*DsseCreate } // Save creates the Dsse entities in the database. func (dcb *DsseCreateBulk) Save(ctx context.Context) ([]*Dsse, error) { + if dcb.err != nil { + return nil, dcb.err + } specs := make([]*sqlgraph.CreateSpec, len(dcb.builders)) nodes := make([]*Dsse, len(dcb.builders)) mutators := make([]Mutator, len(dcb.builders)) @@ -302,8 +246,8 @@ func (dcb *DsseCreateBulk) Save(ctx context.Context) ([]*Dsse, error) { return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, dcb.builders[i+1].mutation) } else { diff --git a/ent/dsse_delete.go b/ent/dsse_delete.go index ecbb011f..37dd5807 100644 --- a/ent/dsse_delete.go +++ b/ent/dsse_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (dd *DsseDelete) Where(ps ...predicate.Dsse) *DsseDelete { // Exec executes the deletion query and returns how many vertices were deleted. func (dd *DsseDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(dd.hooks) == 0 { - affected, err = dd.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*DsseMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - dd.mutation = mutation - affected, err = dd.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(dd.hooks) - 1; i >= 0; i-- { - if dd.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = dd.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, dd.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, dd.sqlExec, dd.mutation, dd.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (dd *DsseDelete) ExecX(ctx context.Context) int { } func (dd *DsseDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: dsse.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: dsse.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(dsse.Table, sqlgraph.NewFieldSpec(dsse.FieldID, field.TypeInt)) if ps := dd.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (dd *DsseDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + dd.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type DsseDeleteOne struct { dd *DsseDelete } +// Where appends a list predicates to the DsseDelete builder. +func (ddo *DsseDeleteOne) Where(ps ...predicate.Dsse) *DsseDeleteOne { + ddo.dd.mutation.Where(ps...) + return ddo +} + // Exec executes the deletion query. 
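// Editor's note (illustrative sketch, not part of the diff): DsseCreate.Save now
// routes through the generic withHooks helper and builds its spec with
// sqlgraph.NewCreateSpec/SetField, and DsseCreateBulk carries an err field so a
// broken builder fails fast. The caller-facing builder is untouched; the
// SetGitoidSha256/SetPayloadType setters below are the standard generated
// counterparts of the mutation accessors used in createSpec and are assumed
// here, along with the usual imports.
func createDsse(ctx context.Context, client *ent.Client) (*ent.Dsse, error) {
	return client.Dsse.Create().
		SetGitoidSha256("gitoid-placeholder"). // placeholder value, not a real gitoid
		SetPayloadType("application/vnd.in-toto+json").
		Save(ctx)
}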
func (ddo *DsseDeleteOne) Exec(ctx context.Context) error { n, err := ddo.dd.Exec(ctx) @@ -111,5 +82,7 @@ func (ddo *DsseDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (ddo *DsseDeleteOne) ExecX(ctx context.Context) { - ddo.dd.ExecX(ctx) + if err := ddo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/ent/dsse_query.go b/ent/dsse_query.go index 4f6a3705..ab4d9081 100644 --- a/ent/dsse_query.go +++ b/ent/dsse_query.go @@ -21,18 +21,18 @@ import ( // DsseQuery is the builder for querying Dsse entities. type DsseQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string - predicates []predicate.Dsse - withStatement *StatementQuery - withSignatures *SignatureQuery - withPayloadDigests *PayloadDigestQuery - withFKs bool - modifiers []func(*sql.Selector) - loadTotal []func(context.Context, []*Dsse) error + ctx *QueryContext + order []dsse.OrderOption + inters []Interceptor + predicates []predicate.Dsse + withStatement *StatementQuery + withSignatures *SignatureQuery + withPayloadDigests *PayloadDigestQuery + withFKs bool + modifiers []func(*sql.Selector) + loadTotal []func(context.Context, []*Dsse) error + withNamedSignatures map[string]*SignatureQuery + withNamedPayloadDigests map[string]*PayloadDigestQuery // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) @@ -44,34 +44,34 @@ func (dq *DsseQuery) Where(ps ...predicate.Dsse) *DsseQuery { return dq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (dq *DsseQuery) Limit(limit int) *DsseQuery { - dq.limit = &limit + dq.ctx.Limit = &limit return dq } -// Offset adds an offset step to the query. +// Offset to start from. func (dq *DsseQuery) Offset(offset int) *DsseQuery { - dq.offset = &offset + dq.ctx.Offset = &offset return dq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (dq *DsseQuery) Unique(unique bool) *DsseQuery { - dq.unique = &unique + dq.ctx.Unique = &unique return dq } -// Order adds an order step to the query. -func (dq *DsseQuery) Order(o ...OrderFunc) *DsseQuery { +// Order specifies how the records should be ordered. +func (dq *DsseQuery) Order(o ...dsse.OrderOption) *DsseQuery { dq.order = append(dq.order, o...) return dq } // QueryStatement chains the current query on the "statement" edge. func (dq *DsseQuery) QueryStatement() *StatementQuery { - query := &StatementQuery{config: dq.config} + query := (&StatementClient{config: dq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := dq.prepareQuery(ctx); err != nil { return nil, err @@ -93,7 +93,7 @@ func (dq *DsseQuery) QueryStatement() *StatementQuery { // QuerySignatures chains the current query on the "signatures" edge. func (dq *DsseQuery) QuerySignatures() *SignatureQuery { - query := &SignatureQuery{config: dq.config} + query := (&SignatureClient{config: dq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := dq.prepareQuery(ctx); err != nil { return nil, err @@ -115,7 +115,7 @@ func (dq *DsseQuery) QuerySignatures() *SignatureQuery { // QueryPayloadDigests chains the current query on the "payload_digests" edge. 
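// Editor's note (illustrative sketch, not part of the diff): Limit, Offset and
// Unique now live on the shared QueryContext and Order accepts the generated
// dsse.OrderOption type. ent.Asc/ent.Desc keep working because they now return
// a plain func(*sql.Selector), which stays assignable to OrderOption (assuming
// dsse.OrderOption is generated as func(*sql.Selector), as in the other
// packages touched by this change). Usual imports assumed.
func pageDsses(ctx context.Context, client *ent.Client) ([]*ent.Dsse, error) {
	return client.Dsse.Query().
		Order(ent.Asc(dsse.FieldPayloadType)).
		Offset(0).
		Limit(10).
		All(ctx)
}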
func (dq *DsseQuery) QueryPayloadDigests() *PayloadDigestQuery { - query := &PayloadDigestQuery{config: dq.config} + query := (&PayloadDigestClient{config: dq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := dq.prepareQuery(ctx); err != nil { return nil, err @@ -138,7 +138,7 @@ func (dq *DsseQuery) QueryPayloadDigests() *PayloadDigestQuery { // First returns the first Dsse entity from the query. // Returns a *NotFoundError when no Dsse was found. func (dq *DsseQuery) First(ctx context.Context) (*Dsse, error) { - nodes, err := dq.Limit(1).All(ctx) + nodes, err := dq.Limit(1).All(setContextOp(ctx, dq.ctx, "First")) if err != nil { return nil, err } @@ -161,7 +161,7 @@ func (dq *DsseQuery) FirstX(ctx context.Context) *Dsse { // Returns a *NotFoundError when no Dsse ID was found. func (dq *DsseQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = dq.Limit(1).IDs(ctx); err != nil { + if ids, err = dq.Limit(1).IDs(setContextOp(ctx, dq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -184,7 +184,7 @@ func (dq *DsseQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Dsse entity is found. // Returns a *NotFoundError when no Dsse entities are found. func (dq *DsseQuery) Only(ctx context.Context) (*Dsse, error) { - nodes, err := dq.Limit(2).All(ctx) + nodes, err := dq.Limit(2).All(setContextOp(ctx, dq.ctx, "Only")) if err != nil { return nil, err } @@ -212,7 +212,7 @@ func (dq *DsseQuery) OnlyX(ctx context.Context) *Dsse { // Returns a *NotFoundError when no entities are found. func (dq *DsseQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = dq.Limit(2).IDs(ctx); err != nil { + if ids, err = dq.Limit(2).IDs(setContextOp(ctx, dq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -237,10 +237,12 @@ func (dq *DsseQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Dsses. func (dq *DsseQuery) All(ctx context.Context) ([]*Dsse, error) { + ctx = setContextOp(ctx, dq.ctx, "All") if err := dq.prepareQuery(ctx); err != nil { return nil, err } - return dq.sqlAll(ctx) + qr := querierAll[[]*Dsse, *DsseQuery]() + return withInterceptors[[]*Dsse](ctx, dq, qr, dq.inters) } // AllX is like All, but panics if an error occurs. @@ -253,9 +255,12 @@ func (dq *DsseQuery) AllX(ctx context.Context) []*Dsse { } // IDs executes the query and returns a list of Dsse IDs. -func (dq *DsseQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := dq.Select(dsse.FieldID).Scan(ctx, &ids); err != nil { +func (dq *DsseQuery) IDs(ctx context.Context) (ids []int, err error) { + if dq.ctx.Unique == nil && dq.path != nil { + dq.Unique(true) + } + ctx = setContextOp(ctx, dq.ctx, "IDs") + if err = dq.Select(dsse.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -272,10 +277,11 @@ func (dq *DsseQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (dq *DsseQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, dq.ctx, "Count") if err := dq.prepareQuery(ctx); err != nil { return 0, err } - return dq.sqlCount(ctx) + return withInterceptors[int](ctx, dq, querierCount[*DsseQuery](), dq.inters) } // CountX is like Count, but panics if an error occurs. @@ -289,10 +295,15 @@ func (dq *DsseQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. 
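// Editor's note (illustrative sketch, not part of the diff): Exist is now a thin
// wrapper around FirstID instead of a dedicated sqlExist query, so an existence
// check like the one below behaves exactly as before from the caller's side.
// Usual imports and an initialized client assumed.
func anyDsseWithSignatures(ctx context.Context, client *ent.Client) (bool, error) {
	return client.Dsse.Query().
		Where(dsse.HasSignatures()).
		Exist(ctx)
}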
func (dq *DsseQuery) Exist(ctx context.Context) (bool, error) { - if err := dq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, dq.ctx, "Exist") + switch _, err := dq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return dq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -312,24 +323,23 @@ func (dq *DsseQuery) Clone() *DsseQuery { } return &DsseQuery{ config: dq.config, - limit: dq.limit, - offset: dq.offset, - order: append([]OrderFunc{}, dq.order...), + ctx: dq.ctx.Clone(), + order: append([]dsse.OrderOption{}, dq.order...), + inters: append([]Interceptor{}, dq.inters...), predicates: append([]predicate.Dsse{}, dq.predicates...), withStatement: dq.withStatement.Clone(), withSignatures: dq.withSignatures.Clone(), withPayloadDigests: dq.withPayloadDigests.Clone(), // clone intermediate query. - sql: dq.sql.Clone(), - path: dq.path, - unique: dq.unique, + sql: dq.sql.Clone(), + path: dq.path, } } // WithStatement tells the query-builder to eager-load the nodes that are connected to // the "statement" edge. The optional arguments are used to configure the query builder of the edge. func (dq *DsseQuery) WithStatement(opts ...func(*StatementQuery)) *DsseQuery { - query := &StatementQuery{config: dq.config} + query := (&StatementClient{config: dq.config}).Query() for _, opt := range opts { opt(query) } @@ -340,7 +350,7 @@ func (dq *DsseQuery) WithStatement(opts ...func(*StatementQuery)) *DsseQuery { // WithSignatures tells the query-builder to eager-load the nodes that are connected to // the "signatures" edge. The optional arguments are used to configure the query builder of the edge. func (dq *DsseQuery) WithSignatures(opts ...func(*SignatureQuery)) *DsseQuery { - query := &SignatureQuery{config: dq.config} + query := (&SignatureClient{config: dq.config}).Query() for _, opt := range opts { opt(query) } @@ -351,7 +361,7 @@ func (dq *DsseQuery) WithSignatures(opts ...func(*SignatureQuery)) *DsseQuery { // WithPayloadDigests tells the query-builder to eager-load the nodes that are connected to // the "payload_digests" edge. The optional arguments are used to configure the query builder of the edge. func (dq *DsseQuery) WithPayloadDigests(opts ...func(*PayloadDigestQuery)) *DsseQuery { - query := &PayloadDigestQuery{config: dq.config} + query := (&PayloadDigestClient{config: dq.config}).Query() for _, opt := range opts { opt(query) } @@ -374,16 +384,11 @@ func (dq *DsseQuery) WithPayloadDigests(opts ...func(*PayloadDigestQuery)) *Dsse // Aggregate(ent.Count()). // Scan(ctx, &v) func (dq *DsseQuery) GroupBy(field string, fields ...string) *DsseGroupBy { - grbuild := &DsseGroupBy{config: dq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := dq.prepareQuery(ctx); err != nil { - return nil, err - } - return dq.sqlQuery(ctx), nil - } + dq.ctx.Fields = append([]string{field}, fields...) + grbuild := &DsseGroupBy{build: dq} + grbuild.flds = &dq.ctx.Fields grbuild.label = dsse.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -400,15 +405,30 @@ func (dq *DsseQuery) GroupBy(field string, fields ...string) *DsseGroupBy { // Select(dsse.FieldGitoidSha256). 
// Scan(ctx, &v) func (dq *DsseQuery) Select(fields ...string) *DsseSelect { - dq.fields = append(dq.fields, fields...) - selbuild := &DsseSelect{DsseQuery: dq} - selbuild.label = dsse.Label - selbuild.flds, selbuild.scan = &dq.fields, selbuild.Scan - return selbuild + dq.ctx.Fields = append(dq.ctx.Fields, fields...) + sbuild := &DsseSelect{DsseQuery: dq} + sbuild.label = dsse.Label + sbuild.flds, sbuild.scan = &dq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a DsseSelect configured with the given aggregations. +func (dq *DsseQuery) Aggregate(fns ...AggregateFunc) *DsseSelect { + return dq.Select().Aggregate(fns...) } func (dq *DsseQuery) prepareQuery(ctx context.Context) error { - for _, f := range dq.fields { + for _, inter := range dq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, dq); err != nil { + return err + } + } + } + for _, f := range dq.ctx.Fields { if !dsse.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -481,6 +501,20 @@ func (dq *DsseQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Dsse, e return nil, err } } + for name, query := range dq.withNamedSignatures { + if err := dq.loadSignatures(ctx, query, nodes, + func(n *Dsse) { n.appendNamedSignatures(name) }, + func(n *Dsse, e *Signature) { n.appendNamedSignatures(name, e) }); err != nil { + return nil, err + } + } + for name, query := range dq.withNamedPayloadDigests { + if err := dq.loadPayloadDigests(ctx, query, nodes, + func(n *Dsse) { n.appendNamedPayloadDigests(name) }, + func(n *Dsse, e *PayloadDigest) { n.appendNamedPayloadDigests(name, e) }); err != nil { + return nil, err + } + } for i := range dq.loadTotal { if err := dq.loadTotal[i](ctx, nodes); err != nil { return nil, err @@ -502,6 +536,9 @@ func (dq *DsseQuery) loadStatement(ctx context.Context, query *StatementQuery, n } nodeids[fk] = append(nodeids[fk], nodes[i]) } + if len(ids) == 0 { + return nil + } query.Where(statement.IDIn(ids...)) neighbors, err := query.All(ctx) if err != nil { @@ -530,7 +567,7 @@ func (dq *DsseQuery) loadSignatures(ctx context.Context, query *SignatureQuery, } query.withFKs = true query.Where(predicate.Signature(func(s *sql.Selector) { - s.Where(sql.InValues(dsse.SignaturesColumn, fks...)) + s.Where(sql.InValues(s.C(dsse.SignaturesColumn), fks...)) })) neighbors, err := query.All(ctx) if err != nil { @@ -543,7 +580,7 @@ func (dq *DsseQuery) loadSignatures(ctx context.Context, query *SignatureQuery, } node, ok := nodeids[*fk] if !ok { - return fmt.Errorf(`unexpected foreign-key "dsse_signatures" returned %v for node %v`, *fk, n.ID) + return fmt.Errorf(`unexpected referenced foreign-key "dsse_signatures" returned %v for node %v`, *fk, n.ID) } assign(node, n) } @@ -561,7 +598,7 @@ func (dq *DsseQuery) loadPayloadDigests(ctx context.Context, query *PayloadDiges } query.withFKs = true query.Where(predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.InValues(dsse.PayloadDigestsColumn, fks...)) + s.Where(sql.InValues(s.C(dsse.PayloadDigestsColumn), fks...)) })) neighbors, err := query.All(ctx) if err != nil { @@ -574,7 +611,7 @@ func (dq *DsseQuery) loadPayloadDigests(ctx context.Context, query *PayloadDiges } node, ok := nodeids[*fk] if !ok { - return fmt.Errorf(`unexpected foreign-key "dsse_payload_digests" returned %v for node %v`, *fk, n.ID) + return fmt.Errorf(`unexpected referenced foreign-key 
"dsse_payload_digests" returned %v for node %v`, *fk, n.ID) } assign(node, n) } @@ -586,41 +623,22 @@ func (dq *DsseQuery) sqlCount(ctx context.Context) (int, error) { if len(dq.modifiers) > 0 { _spec.Modifiers = dq.modifiers } - _spec.Node.Columns = dq.fields - if len(dq.fields) > 0 { - _spec.Unique = dq.unique != nil && *dq.unique + _spec.Node.Columns = dq.ctx.Fields + if len(dq.ctx.Fields) > 0 { + _spec.Unique = dq.ctx.Unique != nil && *dq.ctx.Unique } return sqlgraph.CountNodes(ctx, dq.driver, _spec) } -func (dq *DsseQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := dq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (dq *DsseQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: dsse.Table, - Columns: dsse.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: dsse.FieldID, - }, - }, - From: dq.sql, - Unique: true, - } - if unique := dq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(dsse.Table, dsse.Columns, sqlgraph.NewFieldSpec(dsse.FieldID, field.TypeInt)) + _spec.From = dq.sql + if unique := dq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if dq.path != nil { + _spec.Unique = true } - if fields := dq.fields; len(fields) > 0 { + if fields := dq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, dsse.FieldID) for i := range fields { @@ -636,10 +654,10 @@ func (dq *DsseQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := dq.limit; limit != nil { + if limit := dq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := dq.offset; offset != nil { + if offset := dq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := dq.order; len(ps) > 0 { @@ -655,7 +673,7 @@ func (dq *DsseQuery) querySpec() *sqlgraph.QuerySpec { func (dq *DsseQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(dq.driver.Dialect()) t1 := builder.Table(dsse.Table) - columns := dq.fields + columns := dq.ctx.Fields if len(columns) == 0 { columns = dsse.Columns } @@ -664,7 +682,7 @@ func (dq *DsseQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = dq.sql selector.Select(selector.Columns(columns...)...) } - if dq.unique != nil && *dq.unique { + if dq.ctx.Unique != nil && *dq.ctx.Unique { selector.Distinct() } for _, p := range dq.predicates { @@ -673,26 +691,49 @@ func (dq *DsseQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range dq.order { p(selector) } - if offset := dq.offset; offset != nil { + if offset := dq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := dq.limit; limit != nil { + if limit := dq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector } +// WithNamedSignatures tells the query-builder to eager-load the nodes that are connected to the "signatures" +// edge with the given name. The optional arguments are used to configure the query builder of the edge. 
+func (dq *DsseQuery) WithNamedSignatures(name string, opts ...func(*SignatureQuery)) *DsseQuery { + query := (&SignatureClient{config: dq.config}).Query() + for _, opt := range opts { + opt(query) + } + if dq.withNamedSignatures == nil { + dq.withNamedSignatures = make(map[string]*SignatureQuery) + } + dq.withNamedSignatures[name] = query + return dq +} + +// WithNamedPayloadDigests tells the query-builder to eager-load the nodes that are connected to the "payload_digests" +// edge with the given name. The optional arguments are used to configure the query builder of the edge. +func (dq *DsseQuery) WithNamedPayloadDigests(name string, opts ...func(*PayloadDigestQuery)) *DsseQuery { + query := (&PayloadDigestClient{config: dq.config}).Query() + for _, opt := range opts { + opt(query) + } + if dq.withNamedPayloadDigests == nil { + dq.withNamedPayloadDigests = make(map[string]*PayloadDigestQuery) + } + dq.withNamedPayloadDigests[name] = query + return dq +} + // DsseGroupBy is the group-by builder for Dsse entities. type DsseGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *DsseQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -701,74 +742,77 @@ func (dgb *DsseGroupBy) Aggregate(fns ...AggregateFunc) *DsseGroupBy { return dgb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. func (dgb *DsseGroupBy) Scan(ctx context.Context, v any) error { - query, err := dgb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, dgb.build.ctx, "GroupBy") + if err := dgb.build.prepareQuery(ctx); err != nil { return err } - dgb.sql = query - return dgb.sqlScan(ctx, v) + return scanWithInterceptors[*DsseQuery, *DsseGroupBy](ctx, dgb.build, dgb, dgb.build.inters, v) } -func (dgb *DsseGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range dgb.fields { - if !dsse.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (dgb *DsseGroupBy) sqlScan(ctx context.Context, root *DsseQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(dgb.fns)) + for _, fn := range dgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*dgb.flds)+len(dgb.fns)) + for _, f := range *dgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := dgb.sqlQuery() + selector.GroupBy(selector.Columns(*dgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := dgb.driver.Query(ctx, query, args, rows); err != nil { + if err := dgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (dgb *DsseGroupBy) sqlQuery() *sql.Selector { - selector := dgb.sql.Select() - aggregation := make([]string, 0, len(dgb.fns)) - for _, fn := range dgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. 
- if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(dgb.fields)+len(dgb.fns)) - for _, f := range dgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(dgb.fields...)...) -} - // DsseSelect is the builder for selecting fields of Dsse entities. type DsseSelect struct { *DsseQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ds *DsseSelect) Aggregate(fns ...AggregateFunc) *DsseSelect { + ds.fns = append(ds.fns, fns...) + return ds } // Scan applies the selector query and scans the result into the given value. func (ds *DsseSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ds.ctx, "Select") if err := ds.prepareQuery(ctx); err != nil { return err } - ds.sql = ds.DsseQuery.sqlQuery(ctx) - return ds.sqlScan(ctx, v) + return scanWithInterceptors[*DsseQuery, *DsseSelect](ctx, ds.DsseQuery, ds, ds.inters, v) } -func (ds *DsseSelect) sqlScan(ctx context.Context, v any) error { +func (ds *DsseSelect) sqlScan(ctx context.Context, root *DsseQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ds.fns)) + for _, fn := range ds.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ds.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := ds.sql.Query() + query, args := selector.Query() if err := ds.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/ent/dsse_update.go b/ent/dsse_update.go index 53775c9d..9527e7bf 100644 --- a/ent/dsse_update.go +++ b/ent/dsse_update.go @@ -146,40 +146,7 @@ func (du *DsseUpdate) RemovePayloadDigests(p ...*PayloadDigest) *DsseUpdate { // Save executes the query and returns the number of nodes affected by the update operation. func (du *DsseUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(du.hooks) == 0 { - if err = du.check(); err != nil { - return 0, err - } - affected, err = du.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*DsseMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = du.check(); err != nil { - return 0, err - } - du.mutation = mutation - affected, err = du.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(du.hooks) - 1; i >= 0; i-- { - if du.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = du.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, du.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, du.sqlSave, du.mutation, du.hooks) } // SaveX is like Save, but panics if an error occurs. 
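// Editor's note (illustrative sketch, not part of the diff): the Aggregate entry
// points added on DsseQuery and DsseSelect above feed the interceptor-aware
// scanners. Counting via Aggregate(ent.Count()) follows the upstream ent
// pattern; the Int helper on the shared selector is assumed to be available as
// in other generated packages.
func countDsses(ctx context.Context, client *ent.Client) (int, error) {
	return client.Dsse.Query().
		Aggregate(ent.Count()).
		Int(ctx)
}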
@@ -220,16 +187,10 @@ func (du *DsseUpdate) check() error { } func (du *DsseUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: dsse.Table, - Columns: dsse.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: dsse.FieldID, - }, - }, + if err := du.check(); err != nil { + return n, err } + _spec := sqlgraph.NewUpdateSpec(dsse.Table, dsse.Columns, sqlgraph.NewFieldSpec(dsse.FieldID, field.TypeInt)) if ps := du.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -238,18 +199,10 @@ func (du *DsseUpdate) sqlSave(ctx context.Context) (n int, err error) { } } if value, ok := du.mutation.GitoidSha256(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: dsse.FieldGitoidSha256, - }) + _spec.SetField(dsse.FieldGitoidSha256, field.TypeString, value) } if value, ok := du.mutation.PayloadType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: dsse.FieldPayloadType, - }) + _spec.SetField(dsse.FieldPayloadType, field.TypeString, value) } if du.mutation.StatementCleared() { edge := &sqlgraph.EdgeSpec{ @@ -259,10 +212,7 @@ func (du *DsseUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{dsse.StatementColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: statement.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(statement.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -275,10 +225,7 @@ func (du *DsseUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{dsse.StatementColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: statement.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(statement.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -294,10 +241,7 @@ func (du *DsseUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{dsse.SignaturesColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: signature.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(signature.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -310,10 +254,7 @@ func (du *DsseUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{dsse.SignaturesColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: signature.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(signature.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -329,10 +270,7 @@ func (du *DsseUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{dsse.SignaturesColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: signature.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(signature.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -348,10 +286,7 @@ func (du *DsseUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{dsse.PayloadDigestsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: payloaddigest.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(payloaddigest.FieldID, field.TypeInt), }, } _spec.Edges.Clear = 
append(_spec.Edges.Clear, edge) @@ -364,10 +299,7 @@ func (du *DsseUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{dsse.PayloadDigestsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: payloaddigest.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(payloaddigest.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -383,10 +315,7 @@ func (du *DsseUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{dsse.PayloadDigestsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: payloaddigest.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(payloaddigest.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -402,6 +331,7 @@ func (du *DsseUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + du.mutation.done = true return n, nil } @@ -527,6 +457,12 @@ func (duo *DsseUpdateOne) RemovePayloadDigests(p ...*PayloadDigest) *DsseUpdateO return duo.RemovePayloadDigestIDs(ids...) } +// Where appends a list predicates to the DsseUpdate builder. +func (duo *DsseUpdateOne) Where(ps ...predicate.Dsse) *DsseUpdateOne { + duo.mutation.Where(ps...) + return duo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (duo *DsseUpdateOne) Select(field string, fields ...string) *DsseUpdateOne { @@ -536,46 +472,7 @@ func (duo *DsseUpdateOne) Select(field string, fields ...string) *DsseUpdateOne // Save executes the query and returns the updated Dsse entity. func (duo *DsseUpdateOne) Save(ctx context.Context) (*Dsse, error) { - var ( - err error - node *Dsse - ) - if len(duo.hooks) == 0 { - if err = duo.check(); err != nil { - return nil, err - } - node, err = duo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*DsseMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = duo.check(); err != nil { - return nil, err - } - duo.mutation = mutation - node, err = duo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(duo.hooks) - 1; i >= 0; i-- { - if duo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = duo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, duo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Dsse) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from DsseMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, duo.sqlSave, duo.mutation, duo.hooks) } // SaveX is like Save, but panics if an error occurs. 
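// Editor's note (illustrative sketch, not part of the diff): DsseUpdateOne now
// exposes Where, so a single-row update can carry an extra guard predicate, and
// the update specs above are built with sqlgraph.NewUpdateSpec/SetField. The
// UpdateOneID constructor and SetPayloadType setter are the standard generated
// helpers and are assumed here.
func retypeDsse(ctx context.Context, client *ent.Client, id int) (*ent.Dsse, error) {
	return client.Dsse.UpdateOneID(id).
		Where(dsse.PayloadTypeHasPrefix("application/vnd")). // only update rows that still match
		SetPayloadType("application/vnd.in-toto+json").
		Save(ctx)
}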
@@ -616,16 +513,10 @@ func (duo *DsseUpdateOne) check() error { } func (duo *DsseUpdateOne) sqlSave(ctx context.Context) (_node *Dsse, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: dsse.Table, - Columns: dsse.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: dsse.FieldID, - }, - }, + if err := duo.check(); err != nil { + return _node, err } + _spec := sqlgraph.NewUpdateSpec(dsse.Table, dsse.Columns, sqlgraph.NewFieldSpec(dsse.FieldID, field.TypeInt)) id, ok := duo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Dsse.id" for update`)} @@ -651,18 +542,10 @@ func (duo *DsseUpdateOne) sqlSave(ctx context.Context) (_node *Dsse, err error) } } if value, ok := duo.mutation.GitoidSha256(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: dsse.FieldGitoidSha256, - }) + _spec.SetField(dsse.FieldGitoidSha256, field.TypeString, value) } if value, ok := duo.mutation.PayloadType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: dsse.FieldPayloadType, - }) + _spec.SetField(dsse.FieldPayloadType, field.TypeString, value) } if duo.mutation.StatementCleared() { edge := &sqlgraph.EdgeSpec{ @@ -672,10 +555,7 @@ func (duo *DsseUpdateOne) sqlSave(ctx context.Context) (_node *Dsse, err error) Columns: []string{dsse.StatementColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: statement.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(statement.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -688,10 +568,7 @@ func (duo *DsseUpdateOne) sqlSave(ctx context.Context) (_node *Dsse, err error) Columns: []string{dsse.StatementColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: statement.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(statement.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -707,10 +584,7 @@ func (duo *DsseUpdateOne) sqlSave(ctx context.Context) (_node *Dsse, err error) Columns: []string{dsse.SignaturesColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: signature.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(signature.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -723,10 +597,7 @@ func (duo *DsseUpdateOne) sqlSave(ctx context.Context) (_node *Dsse, err error) Columns: []string{dsse.SignaturesColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: signature.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(signature.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -742,10 +613,7 @@ func (duo *DsseUpdateOne) sqlSave(ctx context.Context) (_node *Dsse, err error) Columns: []string{dsse.SignaturesColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: signature.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(signature.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -761,10 +629,7 @@ func (duo *DsseUpdateOne) sqlSave(ctx context.Context) (_node *Dsse, err error) Columns: []string{dsse.PayloadDigestsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: payloaddigest.FieldID, - }, + IDSpec: 
sqlgraph.NewFieldSpec(payloaddigest.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -777,10 +642,7 @@ func (duo *DsseUpdateOne) sqlSave(ctx context.Context) (_node *Dsse, err error) Columns: []string{dsse.PayloadDigestsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: payloaddigest.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(payloaddigest.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -796,10 +658,7 @@ func (duo *DsseUpdateOne) sqlSave(ctx context.Context) (_node *Dsse, err error) Columns: []string{dsse.PayloadDigestsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: payloaddigest.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(payloaddigest.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -818,5 +677,6 @@ func (duo *DsseUpdateOne) sqlSave(ctx context.Context) (_node *Dsse, err error) } return nil, err } + duo.mutation.done = true return _node, nil } diff --git a/ent/ent.go b/ent/ent.go index 72d3ec2f..a972cec1 100644 --- a/ent/ent.go +++ b/ent/ent.go @@ -6,6 +6,8 @@ import ( "context" "errors" "fmt" + "reflect" + "sync" "entgo.io/ent" "entgo.io/ent/dialect/sql" @@ -23,52 +25,81 @@ import ( // ent aliases to avoid import conflicts in user's code. type ( - Op = ent.Op - Hook = ent.Hook - Value = ent.Value - Query = ent.Query - Policy = ent.Policy - Mutator = ent.Mutator - Mutation = ent.Mutation - MutateFunc = ent.MutateFunc + Op = ent.Op + Hook = ent.Hook + Value = ent.Value + Query = ent.Query + QueryContext = ent.QueryContext + Querier = ent.Querier + QuerierFunc = ent.QuerierFunc + Interceptor = ent.Interceptor + InterceptFunc = ent.InterceptFunc + Traverser = ent.Traverser + TraverseFunc = ent.TraverseFunc + Policy = ent.Policy + Mutator = ent.Mutator + Mutation = ent.Mutation + MutateFunc = ent.MutateFunc ) +type clientCtxKey struct{} + +// FromContext returns a Client stored inside a context, or nil if there isn't one. +func FromContext(ctx context.Context) *Client { + c, _ := ctx.Value(clientCtxKey{}).(*Client) + return c +} + +// NewContext returns a new context with the given Client attached. +func NewContext(parent context.Context, c *Client) context.Context { + return context.WithValue(parent, clientCtxKey{}, c) +} + +type txCtxKey struct{} + +// TxFromContext returns a Tx stored inside a context, or nil if there isn't one. +func TxFromContext(ctx context.Context) *Tx { + tx, _ := ctx.Value(txCtxKey{}).(*Tx) + return tx +} + +// NewTxContext returns a new context with the given Tx attached. +func NewTxContext(parent context.Context, tx *Tx) context.Context { + return context.WithValue(parent, txCtxKey{}, tx) +} + // OrderFunc applies an ordering on the sql selector. +// Deprecated: Use Asc/Desc functions or the package builders instead. type OrderFunc func(*sql.Selector) -// columnChecker returns a function indicates if the column exists in the given column. 
-func columnChecker(table string) func(string) error { - checks := map[string]func(string) bool{ - attestation.Table: attestation.ValidColumn, - attestationcollection.Table: attestationcollection.ValidColumn, - dsse.Table: dsse.ValidColumn, - payloaddigest.Table: payloaddigest.ValidColumn, - signature.Table: signature.ValidColumn, - statement.Table: statement.ValidColumn, - subject.Table: subject.ValidColumn, - subjectdigest.Table: subjectdigest.ValidColumn, - timestamp.Table: timestamp.ValidColumn, - } - check, ok := checks[table] - if !ok { - return func(string) error { - return fmt.Errorf("unknown table %q", table) - } - } - return func(column string) error { - if !check(column) { - return fmt.Errorf("unknown column %q for table %q", column, table) - } - return nil - } +var ( + initCheck sync.Once + columnCheck sql.ColumnCheck +) + +// columnChecker checks if the column exists in the given table. +func checkColumn(table, column string) error { + initCheck.Do(func() { + columnCheck = sql.NewColumnCheck(map[string]func(string) bool{ + attestation.Table: attestation.ValidColumn, + attestationcollection.Table: attestationcollection.ValidColumn, + dsse.Table: dsse.ValidColumn, + payloaddigest.Table: payloaddigest.ValidColumn, + signature.Table: signature.ValidColumn, + statement.Table: statement.ValidColumn, + subject.Table: subject.ValidColumn, + subjectdigest.Table: subjectdigest.ValidColumn, + timestamp.Table: timestamp.ValidColumn, + }) + }) + return columnCheck(table, column) } // Asc applies the given fields in ASC order. -func Asc(fields ...string) OrderFunc { +func Asc(fields ...string) func(*sql.Selector) { return func(s *sql.Selector) { - check := columnChecker(s.TableName()) for _, f := range fields { - if err := check(f); err != nil { + if err := checkColumn(s.TableName(), f); err != nil { s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ent: %w", err)}) } s.OrderBy(sql.Asc(s.C(f))) @@ -77,11 +108,10 @@ func Asc(fields ...string) OrderFunc { } // Desc applies the given fields in DESC order. -func Desc(fields ...string) OrderFunc { +func Desc(fields ...string) func(*sql.Selector) { return func(s *sql.Selector) { - check := columnChecker(s.TableName()) for _, f := range fields { - if err := check(f); err != nil { + if err := checkColumn(s.TableName(), f); err != nil { s.AddError(&ValidationError{Name: f, err: fmt.Errorf("ent: %w", err)}) } s.OrderBy(sql.Desc(s.C(f))) @@ -113,8 +143,7 @@ func Count() AggregateFunc { // Max applies the "max" aggregation function on the given field of each group. func Max(field string) AggregateFunc { return func(s *sql.Selector) string { - check := columnChecker(s.TableName()) - if err := check(field); err != nil { + if err := checkColumn(s.TableName(), field); err != nil { s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) return "" } @@ -125,8 +154,7 @@ func Max(field string) AggregateFunc { // Mean applies the "mean" aggregation function on the given field of each group. func Mean(field string) AggregateFunc { return func(s *sql.Selector) string { - check := columnChecker(s.TableName()) - if err := check(field); err != nil { + if err := checkColumn(s.TableName(), field); err != nil { s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) return "" } @@ -137,8 +165,7 @@ func Mean(field string) AggregateFunc { // Min applies the "min" aggregation function on the given field of each group. 
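// Editor's note (illustrative sketch, not part of the diff): ent.go now exposes
// NewContext/FromContext and NewTxContext/TxFromContext, so lower layers can
// recover the client or transaction that an upper layer stashed in the request
// context, the pattern the entgql integration typically relies on. Assumes
// imports of context, errors and the ent package.
func countDssesFromContext(ctx context.Context) (int, error) {
	client := ent.FromContext(ctx)
	if client == nil {
		return 0, errors.New("no ent client attached to context")
	}
	return client.Dsse.Query().Count(ctx)
}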
func Min(field string) AggregateFunc { return func(s *sql.Selector) string { - check := columnChecker(s.TableName()) - if err := check(field); err != nil { + if err := checkColumn(s.TableName(), field); err != nil { s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) return "" } @@ -149,8 +176,7 @@ func Min(field string) AggregateFunc { // Sum applies the "sum" aggregation function on the given field of each group. func Sum(field string) AggregateFunc { return func(s *sql.Selector) string { - check := columnChecker(s.TableName()) - if err := check(field); err != nil { + if err := checkColumn(s.TableName(), field); err != nil { s.AddError(&ValidationError{Name: field, err: fmt.Errorf("ent: %w", err)}) return "" } @@ -279,6 +305,7 @@ func IsConstraintError(err error) bool { type selector struct { label string flds *[]string + fns []AggregateFunc scan func(context.Context, any) error } @@ -477,5 +504,121 @@ func (s *selector) BoolX(ctx context.Context) bool { return v } +// withHooks invokes the builder operation with the given hooks, if any. +func withHooks[V Value, M any, PM interface { + *M + Mutation +}](ctx context.Context, exec func(context.Context) (V, error), mutation PM, hooks []Hook) (value V, err error) { + if len(hooks) == 0 { + return exec(ctx) + } + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutationT, ok := any(m).(PM) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + // Set the mutation to the builder. + *mutation = *mutationT + return exec(ctx) + }) + for i := len(hooks) - 1; i >= 0; i-- { + if hooks[i] == nil { + return value, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") + } + mut = hooks[i](mut) + } + v, err := mut.Mutate(ctx, mutation) + if err != nil { + return value, err + } + nv, ok := v.(V) + if !ok { + return value, fmt.Errorf("unexpected node type %T returned from %T", v, mutation) + } + return nv, nil +} + +// setContextOp returns a new context with the given QueryContext attached (including its op) in case it does not exist. +func setContextOp(ctx context.Context, qc *QueryContext, op string) context.Context { + if ent.QueryFromContext(ctx) == nil { + qc.Op = op + ctx = ent.NewQueryContext(ctx, qc) + } + return ctx +} + +func querierAll[V Value, Q interface { + sqlAll(context.Context, ...queryHook) (V, error) +}]() Querier { + return QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + query, ok := q.(Q) + if !ok { + return nil, fmt.Errorf("unexpected query type %T", q) + } + return query.sqlAll(ctx) + }) +} + +func querierCount[Q interface { + sqlCount(context.Context) (int, error) +}]() Querier { + return QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + query, ok := q.(Q) + if !ok { + return nil, fmt.Errorf("unexpected query type %T", q) + } + return query.sqlCount(ctx) + }) +} + +func withInterceptors[V Value](ctx context.Context, q Query, qr Querier, inters []Interceptor) (v V, err error) { + for i := len(inters) - 1; i >= 0; i-- { + qr = inters[i].Intercept(qr) + } + rv, err := qr.Query(ctx, q) + if err != nil { + return v, err + } + vt, ok := rv.(V) + if !ok { + return v, fmt.Errorf("unexpected type %T returned from %T. 
expected type: %T", vt, q, v) + } + return vt, nil +} + +func scanWithInterceptors[Q1 ent.Query, Q2 interface { + sqlScan(context.Context, Q1, any) error +}](ctx context.Context, rootQuery Q1, selectOrGroup Q2, inters []Interceptor, v any) error { + rv := reflect.ValueOf(v) + var qr Querier = QuerierFunc(func(ctx context.Context, q Query) (Value, error) { + query, ok := q.(Q1) + if !ok { + return nil, fmt.Errorf("unexpected query type %T", q) + } + if err := selectOrGroup.sqlScan(ctx, query, v); err != nil { + return nil, err + } + if k := rv.Kind(); k == reflect.Pointer && rv.Elem().CanInterface() { + return rv.Elem().Interface(), nil + } + return v, nil + }) + for i := len(inters) - 1; i >= 0; i-- { + qr = inters[i].Intercept(qr) + } + vv, err := qr.Query(ctx, rootQuery) + if err != nil { + return err + } + switch rv2 := reflect.ValueOf(vv); { + case rv.IsNil(), rv2.IsNil(), rv.Kind() != reflect.Pointer: + case rv.Type() == rv2.Type(): + rv.Elem().Set(rv2.Elem()) + case rv.Elem().Type() == rv2.Type(): + rv.Elem().Set(rv2) + } + return nil +} + // queryHook describes an internal hook for the different sqlAll methods. type queryHook func(context.Context, *sqlgraph.QuerySpec) diff --git a/ent/gql_collection.go b/ent/gql_collection.go index cd802364..4233340c 100644 --- a/ent/gql_collection.go +++ b/ent/gql_collection.go @@ -9,7 +9,15 @@ import ( "entgo.io/ent/dialect/sql" "github.com/99designs/gqlgen/graphql" + "github.com/testifysec/archivista/ent/attestation" + "github.com/testifysec/archivista/ent/attestationcollection" + "github.com/testifysec/archivista/ent/dsse" + "github.com/testifysec/archivista/ent/payloaddigest" + "github.com/testifysec/archivista/ent/signature" "github.com/testifysec/archivista/ent/statement" + "github.com/testifysec/archivista/ent/subject" + "github.com/testifysec/archivista/ent/subjectdigest" + "github.com/testifysec/archivista/ent/timestamp" ) // CollectFields tells the query-builder to eagerly load connected nodes by resolver context. @@ -24,21 +32,39 @@ func (a *AttestationQuery) CollectFields(ctx context.Context, satisfies ...strin return a, nil } -func (a *AttestationQuery) collectField(ctx context.Context, op *graphql.OperationContext, field graphql.CollectedField, path []string, satisfies ...string) error { +func (a *AttestationQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) 
- for _, field := range graphql.CollectFields(op, field.Selections, satisfies) { + var ( + unknownSeen bool + fieldSeen = make(map[string]struct{}, len(attestation.Columns)) + selectedFields = []string{attestation.FieldID} + ) + for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { - case "attestationCollection", "attestation_collection": + case "attestationCollection": var ( - path = append(path, field.Name) - query = &AttestationCollectionQuery{config: a.config} + alias = field.Alias + path = append(path, alias) + query = (&AttestationCollectionClient{config: a.config}).Query() ) - if err := query.collectField(ctx, op, field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, field, path, satisfies...); err != nil { return err } a.withAttestationCollection = query + case "type": + if _, ok := fieldSeen[attestation.FieldType]; !ok { + selectedFields = append(selectedFields, attestation.FieldType) + fieldSeen[attestation.FieldType] = struct{}{} + } + case "id": + case "__typename": + default: + unknownSeen = true } } + if !unknownSeen { + a.Select(selectedFields...) + } return nil } @@ -48,7 +74,7 @@ type attestationPaginateArgs struct { opts []AttestationPaginateOption } -func newAttestationPaginateArgs(rv map[string]interface{}) *attestationPaginateArgs { +func newAttestationPaginateArgs(rv map[string]any) *attestationPaginateArgs { args := &attestationPaginateArgs{} if rv == nil { return args @@ -83,30 +109,51 @@ func (ac *AttestationCollectionQuery) CollectFields(ctx context.Context, satisfi return ac, nil } -func (ac *AttestationCollectionQuery) collectField(ctx context.Context, op *graphql.OperationContext, field graphql.CollectedField, path []string, satisfies ...string) error { +func (ac *AttestationCollectionQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) 
- for _, field := range graphql.CollectFields(op, field.Selections, satisfies) { + var ( + unknownSeen bool + fieldSeen = make(map[string]struct{}, len(attestationcollection.Columns)) + selectedFields = []string{attestationcollection.FieldID} + ) + for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { case "attestations": var ( - path = append(path, field.Name) - query = &AttestationQuery{config: ac.config} + alias = field.Alias + path = append(path, alias) + query = (&AttestationClient{config: ac.config}).Query() ) - if err := query.collectField(ctx, op, field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, field, path, satisfies...); err != nil { return err } - ac.withAttestations = query + ac.WithNamedAttestations(alias, func(wq *AttestationQuery) { + *wq = *query + }) case "statement": var ( - path = append(path, field.Name) - query = &StatementQuery{config: ac.config} + alias = field.Alias + path = append(path, alias) + query = (&StatementClient{config: ac.config}).Query() ) - if err := query.collectField(ctx, op, field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, field, path, satisfies...); err != nil { return err } ac.withStatement = query + case "name": + if _, ok := fieldSeen[attestationcollection.FieldName]; !ok { + selectedFields = append(selectedFields, attestationcollection.FieldName) + fieldSeen[attestationcollection.FieldName] = struct{}{} + } + case "id": + case "__typename": + default: + unknownSeen = true } } + if !unknownSeen { + ac.Select(selectedFields...) + } return nil } @@ -116,7 +163,7 @@ type attestationcollectionPaginateArgs struct { opts []AttestationCollectionPaginateOption } -func newAttestationCollectionPaginateArgs(rv map[string]interface{}) *attestationcollectionPaginateArgs { +func newAttestationCollectionPaginateArgs(rv map[string]any) *attestationcollectionPaginateArgs { args := &attestationcollectionPaginateArgs{} if rv == nil { return args @@ -151,39 +198,68 @@ func (d *DsseQuery) CollectFields(ctx context.Context, satisfies ...string) (*Ds return d, nil } -func (d *DsseQuery) collectField(ctx context.Context, op *graphql.OperationContext, field graphql.CollectedField, path []string, satisfies ...string) error { +func (d *DsseQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) 
- for _, field := range graphql.CollectFields(op, field.Selections, satisfies) { + var ( + unknownSeen bool + fieldSeen = make(map[string]struct{}, len(dsse.Columns)) + selectedFields = []string{dsse.FieldID} + ) + for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { case "statement": var ( - path = append(path, field.Name) - query = &StatementQuery{config: d.config} + alias = field.Alias + path = append(path, alias) + query = (&StatementClient{config: d.config}).Query() ) - if err := query.collectField(ctx, op, field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, field, path, satisfies...); err != nil { return err } d.withStatement = query case "signatures": var ( - path = append(path, field.Name) - query = &SignatureQuery{config: d.config} + alias = field.Alias + path = append(path, alias) + query = (&SignatureClient{config: d.config}).Query() ) - if err := query.collectField(ctx, op, field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, field, path, satisfies...); err != nil { return err } - d.withSignatures = query - case "payloadDigests", "payload_digests": + d.WithNamedSignatures(alias, func(wq *SignatureQuery) { + *wq = *query + }) + case "payloadDigests": var ( - path = append(path, field.Name) - query = &PayloadDigestQuery{config: d.config} + alias = field.Alias + path = append(path, alias) + query = (&PayloadDigestClient{config: d.config}).Query() ) - if err := query.collectField(ctx, op, field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, field, path, satisfies...); err != nil { return err } - d.withPayloadDigests = query + d.WithNamedPayloadDigests(alias, func(wq *PayloadDigestQuery) { + *wq = *query + }) + case "gitoidSha256": + if _, ok := fieldSeen[dsse.FieldGitoidSha256]; !ok { + selectedFields = append(selectedFields, dsse.FieldGitoidSha256) + fieldSeen[dsse.FieldGitoidSha256] = struct{}{} + } + case "payloadType": + if _, ok := fieldSeen[dsse.FieldPayloadType]; !ok { + selectedFields = append(selectedFields, dsse.FieldPayloadType) + fieldSeen[dsse.FieldPayloadType] = struct{}{} + } + case "id": + case "__typename": + default: + unknownSeen = true } } + if !unknownSeen { + d.Select(selectedFields...) + } return nil } @@ -193,7 +269,7 @@ type dssePaginateArgs struct { opts []DssePaginateOption } -func newDssePaginateArgs(rv map[string]interface{}) *dssePaginateArgs { +func newDssePaginateArgs(rv map[string]any) *dssePaginateArgs { args := &dssePaginateArgs{} if rv == nil { return args @@ -228,21 +304,44 @@ func (pd *PayloadDigestQuery) CollectFields(ctx context.Context, satisfies ...st return pd, nil } -func (pd *PayloadDigestQuery) collectField(ctx context.Context, op *graphql.OperationContext, field graphql.CollectedField, path []string, satisfies ...string) error { +func (pd *PayloadDigestQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) 
- for _, field := range graphql.CollectFields(op, field.Selections, satisfies) { + var ( + unknownSeen bool + fieldSeen = make(map[string]struct{}, len(payloaddigest.Columns)) + selectedFields = []string{payloaddigest.FieldID} + ) + for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { case "dsse": var ( - path = append(path, field.Name) - query = &DsseQuery{config: pd.config} + alias = field.Alias + path = append(path, alias) + query = (&DsseClient{config: pd.config}).Query() ) - if err := query.collectField(ctx, op, field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, field, path, satisfies...); err != nil { return err } pd.withDsse = query + case "algorithm": + if _, ok := fieldSeen[payloaddigest.FieldAlgorithm]; !ok { + selectedFields = append(selectedFields, payloaddigest.FieldAlgorithm) + fieldSeen[payloaddigest.FieldAlgorithm] = struct{}{} + } + case "value": + if _, ok := fieldSeen[payloaddigest.FieldValue]; !ok { + selectedFields = append(selectedFields, payloaddigest.FieldValue) + fieldSeen[payloaddigest.FieldValue] = struct{}{} + } + case "id": + case "__typename": + default: + unknownSeen = true } } + if !unknownSeen { + pd.Select(selectedFields...) + } return nil } @@ -252,7 +351,7 @@ type payloaddigestPaginateArgs struct { opts []PayloadDigestPaginateOption } -func newPayloadDigestPaginateArgs(rv map[string]interface{}) *payloaddigestPaginateArgs { +func newPayloadDigestPaginateArgs(rv map[string]any) *payloaddigestPaginateArgs { args := &payloaddigestPaginateArgs{} if rv == nil { return args @@ -287,30 +386,56 @@ func (s *SignatureQuery) CollectFields(ctx context.Context, satisfies ...string) return s, nil } -func (s *SignatureQuery) collectField(ctx context.Context, op *graphql.OperationContext, field graphql.CollectedField, path []string, satisfies ...string) error { +func (s *SignatureQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) 
- for _, field := range graphql.CollectFields(op, field.Selections, satisfies) { + var ( + unknownSeen bool + fieldSeen = make(map[string]struct{}, len(signature.Columns)) + selectedFields = []string{signature.FieldID} + ) + for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { case "dsse": var ( - path = append(path, field.Name) - query = &DsseQuery{config: s.config} + alias = field.Alias + path = append(path, alias) + query = (&DsseClient{config: s.config}).Query() ) - if err := query.collectField(ctx, op, field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, field, path, satisfies...); err != nil { return err } s.withDsse = query case "timestamps": var ( - path = append(path, field.Name) - query = &TimestampQuery{config: s.config} + alias = field.Alias + path = append(path, alias) + query = (&TimestampClient{config: s.config}).Query() ) - if err := query.collectField(ctx, op, field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, field, path, satisfies...); err != nil { return err } - s.withTimestamps = query + s.WithNamedTimestamps(alias, func(wq *TimestampQuery) { + *wq = *query + }) + case "keyID": + if _, ok := fieldSeen[signature.FieldKeyID]; !ok { + selectedFields = append(selectedFields, signature.FieldKeyID) + fieldSeen[signature.FieldKeyID] = struct{}{} + } + case "signature": + if _, ok := fieldSeen[signature.FieldSignature]; !ok { + selectedFields = append(selectedFields, signature.FieldSignature) + fieldSeen[signature.FieldSignature] = struct{}{} + } + case "id": + case "__typename": + default: + unknownSeen = true } } + if !unknownSeen { + s.Select(selectedFields...) + } return nil } @@ -320,7 +445,7 @@ type signaturePaginateArgs struct { opts []SignaturePaginateOption } -func newSignaturePaginateArgs(rv map[string]interface{}) *signaturePaginateArgs { +func newSignaturePaginateArgs(rv map[string]any) *signaturePaginateArgs { args := &signaturePaginateArgs{} if rv == nil { return args @@ -355,28 +480,36 @@ func (s *StatementQuery) CollectFields(ctx context.Context, satisfies ...string) return s, nil } -func (s *StatementQuery) collectField(ctx context.Context, op *graphql.OperationContext, field graphql.CollectedField, path []string, satisfies ...string) error { +func (s *StatementQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) 
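// Illustrative sketch (not part of the generated diff): to-many edges are now eager-loaded
// under the GraphQL field alias (WithNamed*), and the edge resolvers read them back with the
// matching Named* accessor. This keeps two aliased selections of the same edge from clobbering
// each other. The selection shape below is hypothetical and only shows fields that exist on
// the Signature type (keyID, signature).
const aliasedSelectionSketch = `
  sigs:    signatures { keyID }
  allSigs: signatures { keyID signature }
`
// Each alias ("sigs", "allSigs") gets its own eager-load bucket keyed by the alias, which the
// resolver retrieves via NamedSignatures(alias) instead of the shared Edges.Signatures slice.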
- for _, field := range graphql.CollectFields(op, field.Selections, satisfies) { + var ( + unknownSeen bool + fieldSeen = make(map[string]struct{}, len(statement.Columns)) + selectedFields = []string{statement.FieldID} + ) + for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { case "subjects": var ( - path = append(path, field.Name) - query = &SubjectQuery{config: s.config} + alias = field.Alias + path = append(path, alias) + query = (&SubjectClient{config: s.config}).Query() ) args := newSubjectPaginateArgs(fieldArgs(ctx, new(SubjectWhereInput), path...)) if err := validateFirstLast(args.first, args.last); err != nil { return fmt.Errorf("validate first and last in path %q: %w", path, err) } - pager, err := newSubjectPager(args.opts) + pager, err := newSubjectPager(args.opts, args.last != nil) if err != nil { return fmt.Errorf("create new pager in path %q: %w", path, err) } if query, err = pager.applyFilter(query); err != nil { return err } - if !hasCollectedField(ctx, append(path, edgesField)...) || args.first != nil && *args.first == 0 || args.last != nil && *args.last == 0 { - if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + ignoredEdges := !hasCollectedField(ctx, append(path, edgesField)...) + if hasCollectedField(ctx, append(path, totalCountField)...) || hasCollectedField(ctx, append(path, pageInfoField)...) { + hasPagination := args.after != nil || args.first != nil || args.before != nil || args.last != nil + if hasPagination || ignoredEdges { query := query.Clone() s.loadTotal = append(s.loadTotal, func(ctx context.Context, nodes []*Statement) error { ids := make([]driver.Value, len(nodes)) @@ -388,7 +521,7 @@ func (s *StatementQuery) collectField(ctx context.Context, op *graphql.Operation Count int `sql:"count"` } query.Where(func(s *sql.Selector) { - s.Where(sql.InValues(statement.SubjectsColumn, ids...)) + s.Where(sql.InValues(s.C(statement.SubjectsColumn), ids...)) }) if err := query.GroupBy(statement.SubjectsColumn).Aggregate(Count()).Scan(ctx, &v); err != nil { return err @@ -399,83 +532,83 @@ func (s *StatementQuery) collectField(ctx context.Context, op *graphql.Operation } for i := range nodes { n := m[nodes[i].ID] - nodes[i].Edges.totalCount[0] = &n + if nodes[i].Edges.totalCount[0] == nil { + nodes[i].Edges.totalCount[0] = make(map[string]int) + } + nodes[i].Edges.totalCount[0][alias] = n + } + return nil + }) + } else { + s.loadTotal = append(s.loadTotal, func(_ context.Context, nodes []*Statement) error { + for i := range nodes { + n := len(nodes[i].Edges.Subjects) + if nodes[i].Edges.totalCount[0] == nil { + nodes[i].Edges.totalCount[0] = make(map[string]int) + } + nodes[i].Edges.totalCount[0][alias] = n } return nil }) } + } + if ignoredEdges || (args.first != nil && *args.first == 0) || (args.last != nil && *args.last == 0) { continue } - if (args.after != nil || args.first != nil || args.before != nil || args.last != nil) && hasCollectedField(ctx, append(path, totalCountField)...) 
{ - query := query.Clone() - s.loadTotal = append(s.loadTotal, func(ctx context.Context, nodes []*Statement) error { - ids := make([]driver.Value, len(nodes)) - for i := range nodes { - ids[i] = nodes[i].ID - } - var v []struct { - NodeID int `sql:"statement_subjects"` - Count int `sql:"count"` - } - query.Where(func(s *sql.Selector) { - s.Where(sql.InValues(statement.SubjectsColumn, ids...)) - }) - if err := query.GroupBy(statement.SubjectsColumn).Aggregate(Count()).Scan(ctx, &v); err != nil { - return err - } - m := make(map[int]int, len(v)) - for i := range v { - m[v[i].NodeID] = v[i].Count - } - for i := range nodes { - n := m[nodes[i].ID] - nodes[i].Edges.totalCount[0] = &n - } - return nil - }) - } else { - s.loadTotal = append(s.loadTotal, func(_ context.Context, nodes []*Statement) error { - for i := range nodes { - n := len(nodes[i].Edges.Subjects) - nodes[i].Edges.totalCount[0] = &n - } - return nil - }) - } - query = pager.applyCursors(query, args.after, args.before) - if limit := paginateLimit(args.first, args.last); limit > 0 { - modify := limitRows(statement.SubjectsColumn, limit, pager.orderExpr(args.last != nil)) - query.modifiers = append(query.modifiers, modify) - } else { - query = pager.applyOrder(query, args.last != nil) + if query, err = pager.applyCursors(query, args.after, args.before); err != nil { + return err } path = append(path, edgesField, nodeField) if field := collectedField(ctx, path...); field != nil { - if err := query.collectField(ctx, op, *field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, *field, path, mayAddCondition(satisfies, "Subject")...); err != nil { return err } } - s.withSubjects = query - case "attestationCollections", "attestation_collections": + if limit := paginateLimit(args.first, args.last); limit > 0 { + modify := limitRows(statement.SubjectsColumn, limit, pager.orderExpr(query)) + query.modifiers = append(query.modifiers, modify) + } else { + query = pager.applyOrder(query) + } + s.WithNamedSubjects(alias, func(wq *SubjectQuery) { + *wq = *query + }) + case "attestationCollections": var ( - path = append(path, field.Name) - query = &AttestationCollectionQuery{config: s.config} + alias = field.Alias + path = append(path, alias) + query = (&AttestationCollectionClient{config: s.config}).Query() ) - if err := query.collectField(ctx, op, field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, field, path, satisfies...); err != nil { return err } s.withAttestationCollections = query case "dsse": var ( - path = append(path, field.Name) - query = &DsseQuery{config: s.config} + alias = field.Alias + path = append(path, alias) + query = (&DsseClient{config: s.config}).Query() ) - if err := query.collectField(ctx, op, field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, field, path, satisfies...); err != nil { return err } - s.withDsse = query + s.WithNamedDsse(alias, func(wq *DsseQuery) { + *wq = *query + }) + case "predicate": + if _, ok := fieldSeen[statement.FieldPredicate]; !ok { + selectedFields = append(selectedFields, statement.FieldPredicate) + fieldSeen[statement.FieldPredicate] = struct{}{} + } + case "id": + case "__typename": + default: + unknownSeen = true } } + if !unknownSeen { + s.Select(selectedFields...) 
+ } return nil } @@ -485,7 +618,7 @@ type statementPaginateArgs struct { opts []StatementPaginateOption } -func newStatementPaginateArgs(rv map[string]interface{}) *statementPaginateArgs { +func newStatementPaginateArgs(rv map[string]any) *statementPaginateArgs { args := &statementPaginateArgs{} if rv == nil { return args @@ -520,30 +653,51 @@ func (s *SubjectQuery) CollectFields(ctx context.Context, satisfies ...string) ( return s, nil } -func (s *SubjectQuery) collectField(ctx context.Context, op *graphql.OperationContext, field graphql.CollectedField, path []string, satisfies ...string) error { +func (s *SubjectQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) - for _, field := range graphql.CollectFields(op, field.Selections, satisfies) { + var ( + unknownSeen bool + fieldSeen = make(map[string]struct{}, len(subject.Columns)) + selectedFields = []string{subject.FieldID} + ) + for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { - case "subjectDigests", "subject_digests": + case "subjectDigests": var ( - path = append(path, field.Name) - query = &SubjectDigestQuery{config: s.config} + alias = field.Alias + path = append(path, alias) + query = (&SubjectDigestClient{config: s.config}).Query() ) - if err := query.collectField(ctx, op, field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, field, path, satisfies...); err != nil { return err } - s.withSubjectDigests = query + s.WithNamedSubjectDigests(alias, func(wq *SubjectDigestQuery) { + *wq = *query + }) case "statement": var ( - path = append(path, field.Name) - query = &StatementQuery{config: s.config} + alias = field.Alias + path = append(path, alias) + query = (&StatementClient{config: s.config}).Query() ) - if err := query.collectField(ctx, op, field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, field, path, satisfies...); err != nil { return err } s.withStatement = query + case "name": + if _, ok := fieldSeen[subject.FieldName]; !ok { + selectedFields = append(selectedFields, subject.FieldName) + fieldSeen[subject.FieldName] = struct{}{} + } + case "id": + case "__typename": + default: + unknownSeen = true } } + if !unknownSeen { + s.Select(selectedFields...) + } return nil } @@ -553,7 +707,7 @@ type subjectPaginateArgs struct { opts []SubjectPaginateOption } -func newSubjectPaginateArgs(rv map[string]interface{}) *subjectPaginateArgs { +func newSubjectPaginateArgs(rv map[string]any) *subjectPaginateArgs { args := &subjectPaginateArgs{} if rv == nil { return args @@ -588,21 +742,44 @@ func (sd *SubjectDigestQuery) CollectFields(ctx context.Context, satisfies ...st return sd, nil } -func (sd *SubjectDigestQuery) collectField(ctx context.Context, op *graphql.OperationContext, field graphql.CollectedField, path []string, satisfies ...string) error { +func (sd *SubjectDigestQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) 
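// Illustrative sketch (not part of the generated diff): per-edge totalCount bookkeeping moved
// from a single *int to a map keyed by the field alias, mirroring the alias-keyed eager loading
// above. A minimal, hypothetical version of that bookkeeping (type and method names invented):
type edgeCounts struct {
	totalCount [1]map[string]int // one slot per counted to-many edge
}

// setCount records the count of an edge's results for a given GraphQL alias.
func (e *edgeCounts) setCount(edge int, alias string, n int) {
	if e.totalCount[edge] == nil {
		e.totalCount[edge] = make(map[string]int)
	}
	e.totalCount[edge][alias] = n
}

// count reports the stored count for an alias; reading a nil map safely yields (0, false).
func (e *edgeCounts) count(edge int, alias string) (int, bool) {
	n, ok := e.totalCount[edge][alias]
	return n, ok
}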
- for _, field := range graphql.CollectFields(op, field.Selections, satisfies) { + var ( + unknownSeen bool + fieldSeen = make(map[string]struct{}, len(subjectdigest.Columns)) + selectedFields = []string{subjectdigest.FieldID} + ) + for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { case "subject": var ( - path = append(path, field.Name) - query = &SubjectQuery{config: sd.config} + alias = field.Alias + path = append(path, alias) + query = (&SubjectClient{config: sd.config}).Query() ) - if err := query.collectField(ctx, op, field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, field, path, satisfies...); err != nil { return err } sd.withSubject = query + case "algorithm": + if _, ok := fieldSeen[subjectdigest.FieldAlgorithm]; !ok { + selectedFields = append(selectedFields, subjectdigest.FieldAlgorithm) + fieldSeen[subjectdigest.FieldAlgorithm] = struct{}{} + } + case "value": + if _, ok := fieldSeen[subjectdigest.FieldValue]; !ok { + selectedFields = append(selectedFields, subjectdigest.FieldValue) + fieldSeen[subjectdigest.FieldValue] = struct{}{} + } + case "id": + case "__typename": + default: + unknownSeen = true } } + if !unknownSeen { + sd.Select(selectedFields...) + } return nil } @@ -612,7 +789,7 @@ type subjectdigestPaginateArgs struct { opts []SubjectDigestPaginateOption } -func newSubjectDigestPaginateArgs(rv map[string]interface{}) *subjectdigestPaginateArgs { +func newSubjectDigestPaginateArgs(rv map[string]any) *subjectdigestPaginateArgs { args := &subjectdigestPaginateArgs{} if rv == nil { return args @@ -647,21 +824,44 @@ func (t *TimestampQuery) CollectFields(ctx context.Context, satisfies ...string) return t, nil } -func (t *TimestampQuery) collectField(ctx context.Context, op *graphql.OperationContext, field graphql.CollectedField, path []string, satisfies ...string) error { +func (t *TimestampQuery) collectField(ctx context.Context, opCtx *graphql.OperationContext, collected graphql.CollectedField, path []string, satisfies ...string) error { path = append([]string(nil), path...) - for _, field := range graphql.CollectFields(op, field.Selections, satisfies) { + var ( + unknownSeen bool + fieldSeen = make(map[string]struct{}, len(timestamp.Columns)) + selectedFields = []string{timestamp.FieldID} + ) + for _, field := range graphql.CollectFields(opCtx, collected.Selections, satisfies) { switch field.Name { case "signature": var ( - path = append(path, field.Name) - query = &SignatureQuery{config: t.config} + alias = field.Alias + path = append(path, alias) + query = (&SignatureClient{config: t.config}).Query() ) - if err := query.collectField(ctx, op, field, path, satisfies...); err != nil { + if err := query.collectField(ctx, opCtx, field, path, satisfies...); err != nil { return err } t.withSignature = query + case "type": + if _, ok := fieldSeen[timestamp.FieldType]; !ok { + selectedFields = append(selectedFields, timestamp.FieldType) + fieldSeen[timestamp.FieldType] = struct{}{} + } + case "timestamp": + if _, ok := fieldSeen[timestamp.FieldTimestamp]; !ok { + selectedFields = append(selectedFields, timestamp.FieldTimestamp) + fieldSeen[timestamp.FieldTimestamp] = struct{}{} + } + case "id": + case "__typename": + default: + unknownSeen = true } } + if !unknownSeen { + t.Select(selectedFields...) 
+	}
 	return nil
 }
 
@@ -671,7 +871,7 @@ type timestampPaginateArgs struct {
 	opts []TimestampPaginateOption
 }
 
-func newTimestampPaginateArgs(rv map[string]interface{}) *timestampPaginateArgs {
+func newTimestampPaginateArgs(rv map[string]any) *timestampPaginateArgs {
 	args := &timestampPaginateArgs{}
 	if rv == nil {
 		return args
@@ -705,35 +905,18 @@ const (
 	whereField = "where"
 )
 
-func fieldArgs(ctx context.Context, whereInput interface{}, path ...string) map[string]interface{} {
-	fc := graphql.GetFieldContext(ctx)
-	if fc == nil {
+func fieldArgs(ctx context.Context, whereInput any, path ...string) map[string]any {
+	field := collectedField(ctx, path...)
+	if field == nil || field.Arguments == nil {
 		return nil
 	}
 	oc := graphql.GetOperationContext(ctx)
-	for _, name := range path {
-		var field *graphql.CollectedField
-		for _, f := range graphql.CollectFields(oc, fc.Field.Selections, nil) {
-			if f.Name == name {
-				field = &f
-				break
-			}
-		}
-		if field == nil {
-			return nil
-		}
-		cf, err := fc.Child(ctx, *field)
-		if err != nil {
-			args := field.ArgumentMap(oc.Variables)
-			return unmarshalArgs(ctx, whereInput, args)
-		}
-		fc = cf
-	}
-	return fc.Args
+	args := field.ArgumentMap(oc.Variables)
+	return unmarshalArgs(ctx, whereInput, args)
 }
 
 // unmarshalArgs allows extracting the field arguments from their raw representation.
-func unmarshalArgs(ctx context.Context, whereInput interface{}, args map[string]interface{}) map[string]interface{} {
+func unmarshalArgs(ctx context.Context, whereInput any, args map[string]any) map[string]any {
 	for _, k := range []string{firstField, lastField} {
 		v, ok := args[k]
 		if !ok {
@@ -751,7 +934,7 @@ func unmarshalArgs(ctx context.Context, whereInput interface{}, args map[string]
 		}
 		c := &Cursor{}
 		if c.UnmarshalGQL(v) == nil {
-			args[k] = &c
+			args[k] = c
 		}
 	}
 	if v, ok := args[whereField]; ok && whereInput != nil {
@@ -785,3 +968,17 @@ func limitRows(partitionBy string, limit int, orderBy ...sql.Querier) func(s *sq
 			Prefix(with)
 	}
 }
+
+// mayAddCondition appends another type condition to the satisfies list
+// if condition is enabled (Node/Nodes) and it does not exist in the list.
+func mayAddCondition(satisfies []string, typeCond string) []string { + if len(satisfies) == 0 { + return satisfies + } + for _, s := range satisfies { + if typeCond == s { + return satisfies + } + } + return append(satisfies, typeCond) +} diff --git a/ent/gql_edge.go b/ent/gql_edge.go index fb569ab3..9497978d 100644 --- a/ent/gql_edge.go +++ b/ent/gql_edge.go @@ -16,8 +16,12 @@ func (a *Attestation) AttestationCollection(ctx context.Context) (*AttestationCo return result, err } -func (ac *AttestationCollection) Attestations(ctx context.Context) ([]*Attestation, error) { - result, err := ac.Edges.AttestationsOrErr() +func (ac *AttestationCollection) Attestations(ctx context.Context) (result []*Attestation, err error) { + if fc := graphql.GetFieldContext(ctx); fc != nil && fc.Field.Alias != "" { + result, err = ac.NamedAttestations(graphql.GetFieldContext(ctx).Field.Alias) + } else { + result, err = ac.Edges.AttestationsOrErr() + } if IsNotLoaded(err) { result, err = ac.QueryAttestations().All(ctx) } @@ -40,16 +44,24 @@ func (d *Dsse) Statement(ctx context.Context) (*Statement, error) { return result, MaskNotFound(err) } -func (d *Dsse) Signatures(ctx context.Context) ([]*Signature, error) { - result, err := d.Edges.SignaturesOrErr() +func (d *Dsse) Signatures(ctx context.Context) (result []*Signature, err error) { + if fc := graphql.GetFieldContext(ctx); fc != nil && fc.Field.Alias != "" { + result, err = d.NamedSignatures(graphql.GetFieldContext(ctx).Field.Alias) + } else { + result, err = d.Edges.SignaturesOrErr() + } if IsNotLoaded(err) { result, err = d.QuerySignatures().All(ctx) } return result, err } -func (d *Dsse) PayloadDigests(ctx context.Context) ([]*PayloadDigest, error) { - result, err := d.Edges.PayloadDigestsOrErr() +func (d *Dsse) PayloadDigests(ctx context.Context) (result []*PayloadDigest, err error) { + if fc := graphql.GetFieldContext(ctx); fc != nil && fc.Field.Alias != "" { + result, err = d.NamedPayloadDigests(graphql.GetFieldContext(ctx).Field.Alias) + } else { + result, err = d.Edges.PayloadDigestsOrErr() + } if IsNotLoaded(err) { result, err = d.QueryPayloadDigests().All(ctx) } @@ -72,8 +84,12 @@ func (s *Signature) Dsse(ctx context.Context) (*Dsse, error) { return result, MaskNotFound(err) } -func (s *Signature) Timestamps(ctx context.Context) ([]*Timestamp, error) { - result, err := s.Edges.TimestampsOrErr() +func (s *Signature) Timestamps(ctx context.Context) (result []*Timestamp, err error) { + if fc := graphql.GetFieldContext(ctx); fc != nil && fc.Field.Alias != "" { + result, err = s.NamedTimestamps(graphql.GetFieldContext(ctx).Field.Alias) + } else { + result, err = s.Edges.TimestampsOrErr() + } if IsNotLoaded(err) { result, err = s.QueryTimestamps().All(ctx) } @@ -86,69 +102,18 @@ func (s *Statement) Subjects( opts := []SubjectPaginateOption{ WithSubjectFilter(where.Filter), } - totalCount := s.Edges.totalCount[0] - if nodes, err := s.Edges.SubjectsOrErr(); err == nil || totalCount != nil { - conn := &SubjectConnection{Edges: []*SubjectEdge{}} - if totalCount != nil { - conn.TotalCount = *totalCount - } - pager, err := newSubjectPager(opts) + alias := graphql.GetFieldContext(ctx).Field.Alias + totalCount, hasTotalCount := s.Edges.totalCount[0][alias] + if nodes, err := s.NamedSubjects(alias); err == nil || hasTotalCount { + pager, err := newSubjectPager(opts, last != nil) if err != nil { return nil, err } + conn := &SubjectConnection{Edges: []*SubjectEdge{}, TotalCount: totalCount} conn.build(nodes, pager, after, first, before, last) return conn, nil } - 
query := s.QuerySubjects() - if err := validateFirstLast(first, last); err != nil { - return nil, err - } - pager, err := newSubjectPager(opts) - if err != nil { - return nil, err - } - if query, err = pager.applyFilter(query); err != nil { - return nil, err - } - conn := &SubjectConnection{Edges: []*SubjectEdge{}} - if !hasCollectedField(ctx, edgesField) || first != nil && *first == 0 || last != nil && *last == 0 { - if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { - if totalCount != nil { - conn.TotalCount = *totalCount - } else if conn.TotalCount, err = query.Count(ctx); err != nil { - return nil, err - } - conn.PageInfo.HasNextPage = first != nil && conn.TotalCount > 0 - conn.PageInfo.HasPreviousPage = last != nil && conn.TotalCount > 0 - } - return conn, nil - } - - if (after != nil || first != nil || before != nil || last != nil) && hasCollectedField(ctx, totalCountField) { - count, err := query.Clone().Count(ctx) - if err != nil { - return nil, err - } - conn.TotalCount = count - } - - query = pager.applyCursors(query, after, before) - query = pager.applyOrder(query, last != nil) - if limit := paginateLimit(first, last); limit != 0 { - query.Limit(limit) - } - if field := collectedField(ctx, edgesField, nodeField); field != nil { - if err := query.collectField(ctx, graphql.GetOperationContext(ctx), *field, []string{edgesField, nodeField}); err != nil { - return nil, err - } - } - - nodes, err := query.All(ctx) - if err != nil || len(nodes) == 0 { - return conn, err - } - conn.build(nodes, pager, after, first, before, last) - return conn, nil + return s.QuerySubjects().Paginate(ctx, after, first, before, last, opts...) } func (s *Statement) AttestationCollections(ctx context.Context) (*AttestationCollection, error) { @@ -159,16 +124,24 @@ func (s *Statement) AttestationCollections(ctx context.Context) (*AttestationCol return result, MaskNotFound(err) } -func (s *Statement) Dsse(ctx context.Context) ([]*Dsse, error) { - result, err := s.Edges.DsseOrErr() +func (s *Statement) Dsse(ctx context.Context) (result []*Dsse, err error) { + if fc := graphql.GetFieldContext(ctx); fc != nil && fc.Field.Alias != "" { + result, err = s.NamedDsse(graphql.GetFieldContext(ctx).Field.Alias) + } else { + result, err = s.Edges.DsseOrErr() + } if IsNotLoaded(err) { result, err = s.QueryDsse().All(ctx) } return result, err } -func (s *Subject) SubjectDigests(ctx context.Context) ([]*SubjectDigest, error) { - result, err := s.Edges.SubjectDigestsOrErr() +func (s *Subject) SubjectDigests(ctx context.Context) (result []*SubjectDigest, err error) { + if fc := graphql.GetFieldContext(ctx); fc != nil && fc.Field.Alias != "" { + result, err = s.NamedSubjectDigests(graphql.GetFieldContext(ctx).Field.Alias) + } else { + result, err = s.Edges.SubjectDigestsOrErr() + } if IsNotLoaded(err) { result, err = s.QuerySubjectDigests().All(ctx) } diff --git a/ent/gql_node.go b/ent/gql_node.go index 97b5f97c..d48b0190 100644 --- a/ent/gql_node.go +++ b/ent/gql_node.go @@ -4,7 +4,6 @@ package ent import ( "context" - "encoding/json" "fmt" "sync" "sync/atomic" @@ -29,409 +28,35 @@ import ( // Noder wraps the basic Node method. type Noder interface { - Node(context.Context) (*Node, error) + IsNode() } -// Node in the graph. -type Node struct { - ID int `json:"id,omitempty"` // node id. - Type string `json:"type,omitempty"` // node type. - Fields []*Field `json:"fields,omitempty"` // node fields. - Edges []*Edge `json:"edges,omitempty"` // node edges. 
-} +// IsNode implements the Node interface check for GQLGen. +func (n *Attestation) IsNode() {} -// Field of a node. -type Field struct { - Type string `json:"type,omitempty"` // field type. - Name string `json:"name,omitempty"` // field name (as in struct). - Value string `json:"value,omitempty"` // stringified value. -} +// IsNode implements the Node interface check for GQLGen. +func (n *AttestationCollection) IsNode() {} -// Edges between two nodes. -type Edge struct { - Type string `json:"type,omitempty"` // edge type. - Name string `json:"name,omitempty"` // edge name. - IDs []int `json:"ids,omitempty"` // node ids (where this edge point to). -} +// IsNode implements the Node interface check for GQLGen. +func (n *Dsse) IsNode() {} -func (a *Attestation) Node(ctx context.Context) (node *Node, err error) { - node = &Node{ - ID: a.ID, - Type: "Attestation", - Fields: make([]*Field, 1), - Edges: make([]*Edge, 1), - } - var buf []byte - if buf, err = json.Marshal(a.Type); err != nil { - return nil, err - } - node.Fields[0] = &Field{ - Type: "string", - Name: "type", - Value: string(buf), - } - node.Edges[0] = &Edge{ - Type: "AttestationCollection", - Name: "attestation_collection", - } - err = a.QueryAttestationCollection(). - Select(attestationcollection.FieldID). - Scan(ctx, &node.Edges[0].IDs) - if err != nil { - return nil, err - } - return node, nil -} +// IsNode implements the Node interface check for GQLGen. +func (n *PayloadDigest) IsNode() {} -func (ac *AttestationCollection) Node(ctx context.Context) (node *Node, err error) { - node = &Node{ - ID: ac.ID, - Type: "AttestationCollection", - Fields: make([]*Field, 1), - Edges: make([]*Edge, 2), - } - var buf []byte - if buf, err = json.Marshal(ac.Name); err != nil { - return nil, err - } - node.Fields[0] = &Field{ - Type: "string", - Name: "name", - Value: string(buf), - } - node.Edges[0] = &Edge{ - Type: "Attestation", - Name: "attestations", - } - err = ac.QueryAttestations(). - Select(attestation.FieldID). - Scan(ctx, &node.Edges[0].IDs) - if err != nil { - return nil, err - } - node.Edges[1] = &Edge{ - Type: "Statement", - Name: "statement", - } - err = ac.QueryStatement(). - Select(statement.FieldID). - Scan(ctx, &node.Edges[1].IDs) - if err != nil { - return nil, err - } - return node, nil -} +// IsNode implements the Node interface check for GQLGen. +func (n *Signature) IsNode() {} -func (d *Dsse) Node(ctx context.Context) (node *Node, err error) { - node = &Node{ - ID: d.ID, - Type: "Dsse", - Fields: make([]*Field, 2), - Edges: make([]*Edge, 3), - } - var buf []byte - if buf, err = json.Marshal(d.GitoidSha256); err != nil { - return nil, err - } - node.Fields[0] = &Field{ - Type: "string", - Name: "gitoid_sha256", - Value: string(buf), - } - if buf, err = json.Marshal(d.PayloadType); err != nil { - return nil, err - } - node.Fields[1] = &Field{ - Type: "string", - Name: "payload_type", - Value: string(buf), - } - node.Edges[0] = &Edge{ - Type: "Statement", - Name: "statement", - } - err = d.QueryStatement(). - Select(statement.FieldID). - Scan(ctx, &node.Edges[0].IDs) - if err != nil { - return nil, err - } - node.Edges[1] = &Edge{ - Type: "Signature", - Name: "signatures", - } - err = d.QuerySignatures(). - Select(signature.FieldID). - Scan(ctx, &node.Edges[1].IDs) - if err != nil { - return nil, err - } - node.Edges[2] = &Edge{ - Type: "PayloadDigest", - Name: "payload_digests", - } - err = d.QueryPayloadDigests(). - Select(payloaddigest.FieldID). 
- Scan(ctx, &node.Edges[2].IDs) - if err != nil { - return nil, err - } - return node, nil -} +// IsNode implements the Node interface check for GQLGen. +func (n *Statement) IsNode() {} -func (pd *PayloadDigest) Node(ctx context.Context) (node *Node, err error) { - node = &Node{ - ID: pd.ID, - Type: "PayloadDigest", - Fields: make([]*Field, 2), - Edges: make([]*Edge, 1), - } - var buf []byte - if buf, err = json.Marshal(pd.Algorithm); err != nil { - return nil, err - } - node.Fields[0] = &Field{ - Type: "string", - Name: "algorithm", - Value: string(buf), - } - if buf, err = json.Marshal(pd.Value); err != nil { - return nil, err - } - node.Fields[1] = &Field{ - Type: "string", - Name: "value", - Value: string(buf), - } - node.Edges[0] = &Edge{ - Type: "Dsse", - Name: "dsse", - } - err = pd.QueryDsse(). - Select(dsse.FieldID). - Scan(ctx, &node.Edges[0].IDs) - if err != nil { - return nil, err - } - return node, nil -} +// IsNode implements the Node interface check for GQLGen. +func (n *Subject) IsNode() {} -func (s *Signature) Node(ctx context.Context) (node *Node, err error) { - node = &Node{ - ID: s.ID, - Type: "Signature", - Fields: make([]*Field, 2), - Edges: make([]*Edge, 2), - } - var buf []byte - if buf, err = json.Marshal(s.KeyID); err != nil { - return nil, err - } - node.Fields[0] = &Field{ - Type: "string", - Name: "key_id", - Value: string(buf), - } - if buf, err = json.Marshal(s.Signature); err != nil { - return nil, err - } - node.Fields[1] = &Field{ - Type: "string", - Name: "signature", - Value: string(buf), - } - node.Edges[0] = &Edge{ - Type: "Dsse", - Name: "dsse", - } - err = s.QueryDsse(). - Select(dsse.FieldID). - Scan(ctx, &node.Edges[0].IDs) - if err != nil { - return nil, err - } - node.Edges[1] = &Edge{ - Type: "Timestamp", - Name: "timestamps", - } - err = s.QueryTimestamps(). - Select(timestamp.FieldID). - Scan(ctx, &node.Edges[1].IDs) - if err != nil { - return nil, err - } - return node, nil -} - -func (s *Statement) Node(ctx context.Context) (node *Node, err error) { - node = &Node{ - ID: s.ID, - Type: "Statement", - Fields: make([]*Field, 1), - Edges: make([]*Edge, 3), - } - var buf []byte - if buf, err = json.Marshal(s.Predicate); err != nil { - return nil, err - } - node.Fields[0] = &Field{ - Type: "string", - Name: "predicate", - Value: string(buf), - } - node.Edges[0] = &Edge{ - Type: "Subject", - Name: "subjects", - } - err = s.QuerySubjects(). - Select(subject.FieldID). - Scan(ctx, &node.Edges[0].IDs) - if err != nil { - return nil, err - } - node.Edges[1] = &Edge{ - Type: "AttestationCollection", - Name: "attestation_collections", - } - err = s.QueryAttestationCollections(). - Select(attestationcollection.FieldID). - Scan(ctx, &node.Edges[1].IDs) - if err != nil { - return nil, err - } - node.Edges[2] = &Edge{ - Type: "Dsse", - Name: "dsse", - } - err = s.QueryDsse(). - Select(dsse.FieldID). - Scan(ctx, &node.Edges[2].IDs) - if err != nil { - return nil, err - } - return node, nil -} +// IsNode implements the Node interface check for GQLGen. +func (n *SubjectDigest) IsNode() {} -func (s *Subject) Node(ctx context.Context) (node *Node, err error) { - node = &Node{ - ID: s.ID, - Type: "Subject", - Fields: make([]*Field, 1), - Edges: make([]*Edge, 2), - } - var buf []byte - if buf, err = json.Marshal(s.Name); err != nil { - return nil, err - } - node.Fields[0] = &Field{ - Type: "string", - Name: "name", - Value: string(buf), - } - node.Edges[0] = &Edge{ - Type: "SubjectDigest", - Name: "subject_digests", - } - err = s.QuerySubjectDigests(). 
- Select(subjectdigest.FieldID). - Scan(ctx, &node.Edges[0].IDs) - if err != nil { - return nil, err - } - node.Edges[1] = &Edge{ - Type: "Statement", - Name: "statement", - } - err = s.QueryStatement(). - Select(statement.FieldID). - Scan(ctx, &node.Edges[1].IDs) - if err != nil { - return nil, err - } - return node, nil -} - -func (sd *SubjectDigest) Node(ctx context.Context) (node *Node, err error) { - node = &Node{ - ID: sd.ID, - Type: "SubjectDigest", - Fields: make([]*Field, 2), - Edges: make([]*Edge, 1), - } - var buf []byte - if buf, err = json.Marshal(sd.Algorithm); err != nil { - return nil, err - } - node.Fields[0] = &Field{ - Type: "string", - Name: "algorithm", - Value: string(buf), - } - if buf, err = json.Marshal(sd.Value); err != nil { - return nil, err - } - node.Fields[1] = &Field{ - Type: "string", - Name: "value", - Value: string(buf), - } - node.Edges[0] = &Edge{ - Type: "Subject", - Name: "subject", - } - err = sd.QuerySubject(). - Select(subject.FieldID). - Scan(ctx, &node.Edges[0].IDs) - if err != nil { - return nil, err - } - return node, nil -} - -func (t *Timestamp) Node(ctx context.Context) (node *Node, err error) { - node = &Node{ - ID: t.ID, - Type: "Timestamp", - Fields: make([]*Field, 2), - Edges: make([]*Edge, 1), - } - var buf []byte - if buf, err = json.Marshal(t.Type); err != nil { - return nil, err - } - node.Fields[0] = &Field{ - Type: "string", - Name: "type", - Value: string(buf), - } - if buf, err = json.Marshal(t.Timestamp); err != nil { - return nil, err - } - node.Fields[1] = &Field{ - Type: "time.Time", - Name: "timestamp", - Value: string(buf), - } - node.Edges[0] = &Edge{ - Type: "Signature", - Name: "signature", - } - err = t.QuerySignature(). - Select(signature.FieldID). - Scan(ctx, &node.Edges[0].IDs) - if err != nil { - return nil, err - } - return node, nil -} - -func (c *Client) Node(ctx context.Context, id int) (*Node, error) { - n, err := c.Noder(ctx, id) - if err != nil { - return nil, err - } - return n.Node(ctx) -} +// IsNode implements the Node interface check for GQLGen. +func (n *Timestamp) IsNode() {} var errNodeInvalidID = &NotFoundError{"node"} diff --git a/ent/gql_pagination.go b/ent/gql_pagination.go index 554c637a..187db132 100644 --- a/ent/gql_pagination.go +++ b/ent/gql_pagination.go @@ -4,13 +4,10 @@ package ent import ( "context" - "encoding/base64" "errors" - "fmt" - "io" - "strconv" - "strings" + "entgo.io/contrib/entgql" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/99designs/gqlgen/graphql" "github.com/99designs/gqlgen/graphql/errcode" @@ -24,165 +21,22 @@ import ( "github.com/testifysec/archivista/ent/subjectdigest" "github.com/testifysec/archivista/ent/timestamp" "github.com/vektah/gqlparser/v2/gqlerror" - "github.com/vmihailenco/msgpack/v5" ) -// OrderDirection defines the directions in which to order a list of items. -type OrderDirection string - -const ( - // OrderDirectionAsc specifies an ascending order. - OrderDirectionAsc OrderDirection = "ASC" - // OrderDirectionDesc specifies a descending order. - OrderDirectionDesc OrderDirection = "DESC" +// Common entgql types. +type ( + Cursor = entgql.Cursor[int] + PageInfo = entgql.PageInfo[int] + OrderDirection = entgql.OrderDirection ) -// Validate the order direction value. -func (o OrderDirection) Validate() error { - if o != OrderDirectionAsc && o != OrderDirectionDesc { - return fmt.Errorf("%s is not a valid OrderDirection", o) - } - return nil -} - -// String implements fmt.Stringer interface. 
-func (o OrderDirection) String() string { - return string(o) -} - -// MarshalGQL implements graphql.Marshaler interface. -func (o OrderDirection) MarshalGQL(w io.Writer) { - io.WriteString(w, strconv.Quote(o.String())) -} - -// UnmarshalGQL implements graphql.Unmarshaler interface. -func (o *OrderDirection) UnmarshalGQL(val interface{}) error { - str, ok := val.(string) - if !ok { - return fmt.Errorf("order direction %T must be a string", val) - } - *o = OrderDirection(str) - return o.Validate() -} - -func (o OrderDirection) reverse() OrderDirection { - if o == OrderDirectionDesc { - return OrderDirectionAsc - } - return OrderDirectionDesc -} - -func (o OrderDirection) orderFunc(field string) OrderFunc { - if o == OrderDirectionDesc { +func orderFunc(o OrderDirection, field string) func(*sql.Selector) { + if o == entgql.OrderDirectionDesc { return Desc(field) } return Asc(field) } -func cursorsToPredicates(direction OrderDirection, after, before *Cursor, field, idField string) []func(s *sql.Selector) { - var predicates []func(s *sql.Selector) - if after != nil { - if after.Value != nil { - var predicate func([]string, ...interface{}) *sql.Predicate - if direction == OrderDirectionAsc { - predicate = sql.CompositeGT - } else { - predicate = sql.CompositeLT - } - predicates = append(predicates, func(s *sql.Selector) { - s.Where(predicate( - s.Columns(field, idField), - after.Value, after.ID, - )) - }) - } else { - var predicate func(string, interface{}) *sql.Predicate - if direction == OrderDirectionAsc { - predicate = sql.GT - } else { - predicate = sql.LT - } - predicates = append(predicates, func(s *sql.Selector) { - s.Where(predicate( - s.C(idField), - after.ID, - )) - }) - } - } - if before != nil { - if before.Value != nil { - var predicate func([]string, ...interface{}) *sql.Predicate - if direction == OrderDirectionAsc { - predicate = sql.CompositeLT - } else { - predicate = sql.CompositeGT - } - predicates = append(predicates, func(s *sql.Selector) { - s.Where(predicate( - s.Columns(field, idField), - before.Value, before.ID, - )) - }) - } else { - var predicate func(string, interface{}) *sql.Predicate - if direction == OrderDirectionAsc { - predicate = sql.LT - } else { - predicate = sql.GT - } - predicates = append(predicates, func(s *sql.Selector) { - s.Where(predicate( - s.C(idField), - before.ID, - )) - }) - } - } - return predicates -} - -// PageInfo of a connection type. -type PageInfo struct { - HasNextPage bool `json:"hasNextPage"` - HasPreviousPage bool `json:"hasPreviousPage"` - StartCursor *Cursor `json:"startCursor"` - EndCursor *Cursor `json:"endCursor"` -} - -// Cursor of an edge type. -type Cursor struct { - ID int `msgpack:"i"` - Value Value `msgpack:"v,omitempty"` -} - -// MarshalGQL implements graphql.Marshaler interface. -func (c Cursor) MarshalGQL(w io.Writer) { - quote := []byte{'"'} - w.Write(quote) - defer w.Write(quote) - wc := base64.NewEncoder(base64.RawStdEncoding, w) - defer wc.Close() - _ = msgpack.NewEncoder(wc).Encode(c) -} - -// UnmarshalGQL implements graphql.Unmarshaler interface. 
-func (c *Cursor) UnmarshalGQL(v interface{}) error { - s, ok := v.(string) - if !ok { - return fmt.Errorf("%T is not a string", v) - } - if err := msgpack.NewDecoder( - base64.NewDecoder( - base64.RawStdEncoding, - strings.NewReader(s), - ), - ).Decode(c); err != nil { - return fmt.Errorf("cannot decode cursor: %w", err) - } - return nil -} - const errInvalidPagination = "INVALID_PAGINATION" func validateFirstLast(first, last *int) (err *gqlerror.Error) { @@ -215,7 +69,7 @@ func collectedField(ctx context.Context, path ...string) *graphql.CollectedField walk: for _, name := range path { for _, f := range graphql.CollectFields(oc, field.Selections, nil) { - if f.Name == name { + if f.Alias == name { field = f continue walk } @@ -333,12 +187,13 @@ func WithAttestationFilter(filter func(*AttestationQuery) (*AttestationQuery, er } type attestationPager struct { - order *AttestationOrder - filter func(*AttestationQuery) (*AttestationQuery, error) + reverse bool + order *AttestationOrder + filter func(*AttestationQuery) (*AttestationQuery, error) } -func newAttestationPager(opts []AttestationPaginateOption) (*attestationPager, error) { - pager := &attestationPager{} +func newAttestationPager(opts []AttestationPaginateOption, reverse bool) (*attestationPager, error) { + pager := &attestationPager{reverse: reverse} for _, opt := range opts { if err := opt(pager); err != nil { return nil, err @@ -361,37 +216,44 @@ func (p *attestationPager) toCursor(a *Attestation) Cursor { return p.order.Field.toCursor(a) } -func (p *attestationPager) applyCursors(query *AttestationQuery, after, before *Cursor) *AttestationQuery { - for _, predicate := range cursorsToPredicates( - p.order.Direction, after, before, - p.order.Field.field, DefaultAttestationOrder.Field.field, - ) { +func (p *attestationPager) applyCursors(query *AttestationQuery, after, before *Cursor) (*AttestationQuery, error) { + direction := p.order.Direction + if p.reverse { + direction = direction.Reverse() + } + for _, predicate := range entgql.CursorsPredicate(after, before, DefaultAttestationOrder.Field.column, p.order.Field.column, direction) { query = query.Where(predicate) } - return query + return query, nil } -func (p *attestationPager) applyOrder(query *AttestationQuery, reverse bool) *AttestationQuery { +func (p *attestationPager) applyOrder(query *AttestationQuery) *AttestationQuery { direction := p.order.Direction - if reverse { - direction = direction.reverse() + if p.reverse { + direction = direction.Reverse() } - query = query.Order(direction.orderFunc(p.order.Field.field)) + query = query.Order(p.order.Field.toTerm(direction.OrderTermOption())) if p.order.Field != DefaultAttestationOrder.Field { - query = query.Order(direction.orderFunc(DefaultAttestationOrder.Field.field)) + query = query.Order(DefaultAttestationOrder.Field.toTerm(direction.OrderTermOption())) + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(p.order.Field.column) } return query } -func (p *attestationPager) orderExpr(reverse bool) sql.Querier { +func (p *attestationPager) orderExpr(query *AttestationQuery) sql.Querier { direction := p.order.Direction - if reverse { - direction = direction.reverse() + if p.reverse { + direction = direction.Reverse() + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(p.order.Field.column) } return sql.ExprFunc(func(b *sql.Builder) { - b.Ident(p.order.Field.field).Pad().WriteString(string(direction)) + b.Ident(p.order.Field.column).Pad().WriteString(string(direction)) if p.order.Field != 
DefaultAttestationOrder.Field { - b.Comma().Ident(DefaultAttestationOrder.Field.field).Pad().WriteString(string(direction)) + b.Comma().Ident(DefaultAttestationOrder.Field.column).Pad().WriteString(string(direction)) } }) } @@ -404,7 +266,7 @@ func (a *AttestationQuery) Paginate( if err := validateFirstLast(first, last); err != nil { return nil, err } - pager, err := newAttestationPager(opts) + pager, err := newAttestationPager(opts, last != nil) if err != nil { return nil, err } @@ -412,27 +274,23 @@ func (a *AttestationQuery) Paginate( return nil, err } conn := &AttestationConnection{Edges: []*AttestationEdge{}} - if !hasCollectedField(ctx, edgesField) || first != nil && *first == 0 || last != nil && *last == 0 { - if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { - if conn.TotalCount, err = a.Count(ctx); err != nil { + ignoredEdges := !hasCollectedField(ctx, edgesField) + if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { + hasPagination := after != nil || first != nil || before != nil || last != nil + if hasPagination || ignoredEdges { + if conn.TotalCount, err = a.Clone().Count(ctx); err != nil { return nil, err } conn.PageInfo.HasNextPage = first != nil && conn.TotalCount > 0 conn.PageInfo.HasPreviousPage = last != nil && conn.TotalCount > 0 } + } + if ignoredEdges || (first != nil && *first == 0) || (last != nil && *last == 0) { return conn, nil } - - if (after != nil || first != nil || before != nil || last != nil) && hasCollectedField(ctx, totalCountField) { - count, err := a.Clone().Count(ctx) - if err != nil { - return nil, err - } - conn.TotalCount = count + if a, err = pager.applyCursors(a, after, before); err != nil { + return nil, err } - - a = pager.applyCursors(a, after, before) - a = pager.applyOrder(a, last != nil) if limit := paginateLimit(first, last); limit != 0 { a.Limit(limit) } @@ -441,10 +299,10 @@ func (a *AttestationQuery) Paginate( return nil, err } } - + a = pager.applyOrder(a) nodes, err := a.All(ctx) - if err != nil || len(nodes) == 0 { - return conn, err + if err != nil { + return nil, err } conn.build(nodes, pager, after, first, before, last) return conn, nil @@ -452,7 +310,10 @@ func (a *AttestationQuery) Paginate( // AttestationOrderField defines the ordering field of Attestation. type AttestationOrderField struct { - field string + // Value extracts the ordering value from the given Attestation. + Value func(*Attestation) (ent.Value, error) + column string // field or computed. + toTerm func(...sql.OrderTermOption) attestation.OrderOption toCursor func(*Attestation) Cursor } @@ -464,9 +325,13 @@ type AttestationOrder struct { // DefaultAttestationOrder is the default ordering of Attestation. 
var DefaultAttestationOrder = &AttestationOrder{ - Direction: OrderDirectionAsc, + Direction: entgql.OrderDirectionAsc, Field: &AttestationOrderField{ - field: attestation.FieldID, + Value: func(a *Attestation) (ent.Value, error) { + return a.ID, nil + }, + column: attestation.FieldID, + toTerm: attestation.ByID, toCursor: func(a *Attestation) Cursor { return Cursor{ID: a.ID} }, @@ -568,12 +433,13 @@ func WithAttestationCollectionFilter(filter func(*AttestationCollectionQuery) (* } type attestationcollectionPager struct { - order *AttestationCollectionOrder - filter func(*AttestationCollectionQuery) (*AttestationCollectionQuery, error) + reverse bool + order *AttestationCollectionOrder + filter func(*AttestationCollectionQuery) (*AttestationCollectionQuery, error) } -func newAttestationCollectionPager(opts []AttestationCollectionPaginateOption) (*attestationcollectionPager, error) { - pager := &attestationcollectionPager{} +func newAttestationCollectionPager(opts []AttestationCollectionPaginateOption, reverse bool) (*attestationcollectionPager, error) { + pager := &attestationcollectionPager{reverse: reverse} for _, opt := range opts { if err := opt(pager); err != nil { return nil, err @@ -596,37 +462,44 @@ func (p *attestationcollectionPager) toCursor(ac *AttestationCollection) Cursor return p.order.Field.toCursor(ac) } -func (p *attestationcollectionPager) applyCursors(query *AttestationCollectionQuery, after, before *Cursor) *AttestationCollectionQuery { - for _, predicate := range cursorsToPredicates( - p.order.Direction, after, before, - p.order.Field.field, DefaultAttestationCollectionOrder.Field.field, - ) { +func (p *attestationcollectionPager) applyCursors(query *AttestationCollectionQuery, after, before *Cursor) (*AttestationCollectionQuery, error) { + direction := p.order.Direction + if p.reverse { + direction = direction.Reverse() + } + for _, predicate := range entgql.CursorsPredicate(after, before, DefaultAttestationCollectionOrder.Field.column, p.order.Field.column, direction) { query = query.Where(predicate) } - return query + return query, nil } -func (p *attestationcollectionPager) applyOrder(query *AttestationCollectionQuery, reverse bool) *AttestationCollectionQuery { +func (p *attestationcollectionPager) applyOrder(query *AttestationCollectionQuery) *AttestationCollectionQuery { direction := p.order.Direction - if reverse { - direction = direction.reverse() + if p.reverse { + direction = direction.Reverse() } - query = query.Order(direction.orderFunc(p.order.Field.field)) + query = query.Order(p.order.Field.toTerm(direction.OrderTermOption())) if p.order.Field != DefaultAttestationCollectionOrder.Field { - query = query.Order(direction.orderFunc(DefaultAttestationCollectionOrder.Field.field)) + query = query.Order(DefaultAttestationCollectionOrder.Field.toTerm(direction.OrderTermOption())) + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(p.order.Field.column) } return query } -func (p *attestationcollectionPager) orderExpr(reverse bool) sql.Querier { +func (p *attestationcollectionPager) orderExpr(query *AttestationCollectionQuery) sql.Querier { direction := p.order.Direction - if reverse { - direction = direction.reverse() + if p.reverse { + direction = direction.Reverse() + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(p.order.Field.column) } return sql.ExprFunc(func(b *sql.Builder) { - b.Ident(p.order.Field.field).Pad().WriteString(string(direction)) + b.Ident(p.order.Field.column).Pad().WriteString(string(direction)) if 
p.order.Field != DefaultAttestationCollectionOrder.Field { - b.Comma().Ident(DefaultAttestationCollectionOrder.Field.field).Pad().WriteString(string(direction)) + b.Comma().Ident(DefaultAttestationCollectionOrder.Field.column).Pad().WriteString(string(direction)) } }) } @@ -639,7 +512,7 @@ func (ac *AttestationCollectionQuery) Paginate( if err := validateFirstLast(first, last); err != nil { return nil, err } - pager, err := newAttestationCollectionPager(opts) + pager, err := newAttestationCollectionPager(opts, last != nil) if err != nil { return nil, err } @@ -647,27 +520,23 @@ func (ac *AttestationCollectionQuery) Paginate( return nil, err } conn := &AttestationCollectionConnection{Edges: []*AttestationCollectionEdge{}} - if !hasCollectedField(ctx, edgesField) || first != nil && *first == 0 || last != nil && *last == 0 { - if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { - if conn.TotalCount, err = ac.Count(ctx); err != nil { + ignoredEdges := !hasCollectedField(ctx, edgesField) + if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { + hasPagination := after != nil || first != nil || before != nil || last != nil + if hasPagination || ignoredEdges { + if conn.TotalCount, err = ac.Clone().Count(ctx); err != nil { return nil, err } conn.PageInfo.HasNextPage = first != nil && conn.TotalCount > 0 conn.PageInfo.HasPreviousPage = last != nil && conn.TotalCount > 0 } + } + if ignoredEdges || (first != nil && *first == 0) || (last != nil && *last == 0) { return conn, nil } - - if (after != nil || first != nil || before != nil || last != nil) && hasCollectedField(ctx, totalCountField) { - count, err := ac.Clone().Count(ctx) - if err != nil { - return nil, err - } - conn.TotalCount = count + if ac, err = pager.applyCursors(ac, after, before); err != nil { + return nil, err } - - ac = pager.applyCursors(ac, after, before) - ac = pager.applyOrder(ac, last != nil) if limit := paginateLimit(first, last); limit != 0 { ac.Limit(limit) } @@ -676,10 +545,10 @@ func (ac *AttestationCollectionQuery) Paginate( return nil, err } } - + ac = pager.applyOrder(ac) nodes, err := ac.All(ctx) - if err != nil || len(nodes) == 0 { - return conn, err + if err != nil { + return nil, err } conn.build(nodes, pager, after, first, before, last) return conn, nil @@ -687,7 +556,10 @@ func (ac *AttestationCollectionQuery) Paginate( // AttestationCollectionOrderField defines the ordering field of AttestationCollection. type AttestationCollectionOrderField struct { - field string + // Value extracts the ordering value from the given AttestationCollection. + Value func(*AttestationCollection) (ent.Value, error) + column string // field or computed. + toTerm func(...sql.OrderTermOption) attestationcollection.OrderOption toCursor func(*AttestationCollection) Cursor } @@ -699,9 +571,13 @@ type AttestationCollectionOrder struct { // DefaultAttestationCollectionOrder is the default ordering of AttestationCollection. 
var DefaultAttestationCollectionOrder = &AttestationCollectionOrder{ - Direction: OrderDirectionAsc, + Direction: entgql.OrderDirectionAsc, Field: &AttestationCollectionOrderField{ - field: attestationcollection.FieldID, + Value: func(ac *AttestationCollection) (ent.Value, error) { + return ac.ID, nil + }, + column: attestationcollection.FieldID, + toTerm: attestationcollection.ByID, toCursor: func(ac *AttestationCollection) Cursor { return Cursor{ID: ac.ID} }, @@ -803,12 +679,13 @@ func WithDsseFilter(filter func(*DsseQuery) (*DsseQuery, error)) DssePaginateOpt } type dssePager struct { - order *DsseOrder - filter func(*DsseQuery) (*DsseQuery, error) + reverse bool + order *DsseOrder + filter func(*DsseQuery) (*DsseQuery, error) } -func newDssePager(opts []DssePaginateOption) (*dssePager, error) { - pager := &dssePager{} +func newDssePager(opts []DssePaginateOption, reverse bool) (*dssePager, error) { + pager := &dssePager{reverse: reverse} for _, opt := range opts { if err := opt(pager); err != nil { return nil, err @@ -831,37 +708,44 @@ func (p *dssePager) toCursor(d *Dsse) Cursor { return p.order.Field.toCursor(d) } -func (p *dssePager) applyCursors(query *DsseQuery, after, before *Cursor) *DsseQuery { - for _, predicate := range cursorsToPredicates( - p.order.Direction, after, before, - p.order.Field.field, DefaultDsseOrder.Field.field, - ) { +func (p *dssePager) applyCursors(query *DsseQuery, after, before *Cursor) (*DsseQuery, error) { + direction := p.order.Direction + if p.reverse { + direction = direction.Reverse() + } + for _, predicate := range entgql.CursorsPredicate(after, before, DefaultDsseOrder.Field.column, p.order.Field.column, direction) { query = query.Where(predicate) } - return query + return query, nil } -func (p *dssePager) applyOrder(query *DsseQuery, reverse bool) *DsseQuery { +func (p *dssePager) applyOrder(query *DsseQuery) *DsseQuery { direction := p.order.Direction - if reverse { - direction = direction.reverse() + if p.reverse { + direction = direction.Reverse() } - query = query.Order(direction.orderFunc(p.order.Field.field)) + query = query.Order(p.order.Field.toTerm(direction.OrderTermOption())) if p.order.Field != DefaultDsseOrder.Field { - query = query.Order(direction.orderFunc(DefaultDsseOrder.Field.field)) + query = query.Order(DefaultDsseOrder.Field.toTerm(direction.OrderTermOption())) + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(p.order.Field.column) } return query } -func (p *dssePager) orderExpr(reverse bool) sql.Querier { +func (p *dssePager) orderExpr(query *DsseQuery) sql.Querier { direction := p.order.Direction - if reverse { - direction = direction.reverse() + if p.reverse { + direction = direction.Reverse() + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(p.order.Field.column) } return sql.ExprFunc(func(b *sql.Builder) { - b.Ident(p.order.Field.field).Pad().WriteString(string(direction)) + b.Ident(p.order.Field.column).Pad().WriteString(string(direction)) if p.order.Field != DefaultDsseOrder.Field { - b.Comma().Ident(DefaultDsseOrder.Field.field).Pad().WriteString(string(direction)) + b.Comma().Ident(DefaultDsseOrder.Field.column).Pad().WriteString(string(direction)) } }) } @@ -874,7 +758,7 @@ func (d *DsseQuery) Paginate( if err := validateFirstLast(first, last); err != nil { return nil, err } - pager, err := newDssePager(opts) + pager, err := newDssePager(opts, last != nil) if err != nil { return nil, err } @@ -882,27 +766,23 @@ func (d *DsseQuery) Paginate( return nil, err } conn := 
&DsseConnection{Edges: []*DsseEdge{}} - if !hasCollectedField(ctx, edgesField) || first != nil && *first == 0 || last != nil && *last == 0 { - if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { - if conn.TotalCount, err = d.Count(ctx); err != nil { + ignoredEdges := !hasCollectedField(ctx, edgesField) + if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { + hasPagination := after != nil || first != nil || before != nil || last != nil + if hasPagination || ignoredEdges { + if conn.TotalCount, err = d.Clone().Count(ctx); err != nil { return nil, err } conn.PageInfo.HasNextPage = first != nil && conn.TotalCount > 0 conn.PageInfo.HasPreviousPage = last != nil && conn.TotalCount > 0 } + } + if ignoredEdges || (first != nil && *first == 0) || (last != nil && *last == 0) { return conn, nil } - - if (after != nil || first != nil || before != nil || last != nil) && hasCollectedField(ctx, totalCountField) { - count, err := d.Clone().Count(ctx) - if err != nil { - return nil, err - } - conn.TotalCount = count + if d, err = pager.applyCursors(d, after, before); err != nil { + return nil, err } - - d = pager.applyCursors(d, after, before) - d = pager.applyOrder(d, last != nil) if limit := paginateLimit(first, last); limit != 0 { d.Limit(limit) } @@ -911,10 +791,10 @@ func (d *DsseQuery) Paginate( return nil, err } } - + d = pager.applyOrder(d) nodes, err := d.All(ctx) - if err != nil || len(nodes) == 0 { - return conn, err + if err != nil { + return nil, err } conn.build(nodes, pager, after, first, before, last) return conn, nil @@ -922,7 +802,10 @@ func (d *DsseQuery) Paginate( // DsseOrderField defines the ordering field of Dsse. type DsseOrderField struct { - field string + // Value extracts the ordering value from the given Dsse. + Value func(*Dsse) (ent.Value, error) + column string // field or computed. + toTerm func(...sql.OrderTermOption) dsse.OrderOption toCursor func(*Dsse) Cursor } @@ -934,9 +817,13 @@ type DsseOrder struct { // DefaultDsseOrder is the default ordering of Dsse. 
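Backward pagination is now captured once at construction time (newDssePager(opts, last != nil)) and applied through the pager's reverse flag instead of a parameter on applyOrder. A hedged sketch of the caller's view:

func lastDsses(ctx context.Context, client *ent.Client) error {
    last := 5
    // The pager queries in reverse and the connection builder restores the
    // requested order, so Edges still comes back in ascending ID order.
    conn, err := client.Dsse.Query().Paginate(ctx, nil, nil, nil, &last)
    if err != nil {
        return err
    }
    fmt.Println(len(conn.Edges), conn.PageInfo.HasPreviousPage)
    return nil
}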
var DefaultDsseOrder = &DsseOrder{ - Direction: OrderDirectionAsc, + Direction: entgql.OrderDirectionAsc, Field: &DsseOrderField{ - field: dsse.FieldID, + Value: func(d *Dsse) (ent.Value, error) { + return d.ID, nil + }, + column: dsse.FieldID, + toTerm: dsse.ByID, toCursor: func(d *Dsse) Cursor { return Cursor{ID: d.ID} }, @@ -1038,12 +925,13 @@ func WithPayloadDigestFilter(filter func(*PayloadDigestQuery) (*PayloadDigestQue } type payloaddigestPager struct { - order *PayloadDigestOrder - filter func(*PayloadDigestQuery) (*PayloadDigestQuery, error) + reverse bool + order *PayloadDigestOrder + filter func(*PayloadDigestQuery) (*PayloadDigestQuery, error) } -func newPayloadDigestPager(opts []PayloadDigestPaginateOption) (*payloaddigestPager, error) { - pager := &payloaddigestPager{} +func newPayloadDigestPager(opts []PayloadDigestPaginateOption, reverse bool) (*payloaddigestPager, error) { + pager := &payloaddigestPager{reverse: reverse} for _, opt := range opts { if err := opt(pager); err != nil { return nil, err @@ -1066,37 +954,44 @@ func (p *payloaddigestPager) toCursor(pd *PayloadDigest) Cursor { return p.order.Field.toCursor(pd) } -func (p *payloaddigestPager) applyCursors(query *PayloadDigestQuery, after, before *Cursor) *PayloadDigestQuery { - for _, predicate := range cursorsToPredicates( - p.order.Direction, after, before, - p.order.Field.field, DefaultPayloadDigestOrder.Field.field, - ) { +func (p *payloaddigestPager) applyCursors(query *PayloadDigestQuery, after, before *Cursor) (*PayloadDigestQuery, error) { + direction := p.order.Direction + if p.reverse { + direction = direction.Reverse() + } + for _, predicate := range entgql.CursorsPredicate(after, before, DefaultPayloadDigestOrder.Field.column, p.order.Field.column, direction) { query = query.Where(predicate) } - return query + return query, nil } -func (p *payloaddigestPager) applyOrder(query *PayloadDigestQuery, reverse bool) *PayloadDigestQuery { +func (p *payloaddigestPager) applyOrder(query *PayloadDigestQuery) *PayloadDigestQuery { direction := p.order.Direction - if reverse { - direction = direction.reverse() + if p.reverse { + direction = direction.Reverse() } - query = query.Order(direction.orderFunc(p.order.Field.field)) + query = query.Order(p.order.Field.toTerm(direction.OrderTermOption())) if p.order.Field != DefaultPayloadDigestOrder.Field { - query = query.Order(direction.orderFunc(DefaultPayloadDigestOrder.Field.field)) + query = query.Order(DefaultPayloadDigestOrder.Field.toTerm(direction.OrderTermOption())) + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(p.order.Field.column) } return query } -func (p *payloaddigestPager) orderExpr(reverse bool) sql.Querier { +func (p *payloaddigestPager) orderExpr(query *PayloadDigestQuery) sql.Querier { direction := p.order.Direction - if reverse { - direction = direction.reverse() + if p.reverse { + direction = direction.Reverse() + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(p.order.Field.column) } return sql.ExprFunc(func(b *sql.Builder) { - b.Ident(p.order.Field.field).Pad().WriteString(string(direction)) + b.Ident(p.order.Field.column).Pad().WriteString(string(direction)) if p.order.Field != DefaultPayloadDigestOrder.Field { - b.Comma().Ident(DefaultPayloadDigestOrder.Field.field).Pad().WriteString(string(direction)) + b.Comma().Ident(DefaultPayloadDigestOrder.Field.column).Pad().WriteString(string(direction)) } }) } @@ -1109,7 +1004,7 @@ func (pd *PayloadDigestQuery) Paginate( if err := validateFirstLast(first, last); err != 
nil { return nil, err } - pager, err := newPayloadDigestPager(opts) + pager, err := newPayloadDigestPager(opts, last != nil) if err != nil { return nil, err } @@ -1117,27 +1012,23 @@ func (pd *PayloadDigestQuery) Paginate( return nil, err } conn := &PayloadDigestConnection{Edges: []*PayloadDigestEdge{}} - if !hasCollectedField(ctx, edgesField) || first != nil && *first == 0 || last != nil && *last == 0 { - if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { - if conn.TotalCount, err = pd.Count(ctx); err != nil { + ignoredEdges := !hasCollectedField(ctx, edgesField) + if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { + hasPagination := after != nil || first != nil || before != nil || last != nil + if hasPagination || ignoredEdges { + if conn.TotalCount, err = pd.Clone().Count(ctx); err != nil { return nil, err } conn.PageInfo.HasNextPage = first != nil && conn.TotalCount > 0 conn.PageInfo.HasPreviousPage = last != nil && conn.TotalCount > 0 } + } + if ignoredEdges || (first != nil && *first == 0) || (last != nil && *last == 0) { return conn, nil } - - if (after != nil || first != nil || before != nil || last != nil) && hasCollectedField(ctx, totalCountField) { - count, err := pd.Clone().Count(ctx) - if err != nil { - return nil, err - } - conn.TotalCount = count + if pd, err = pager.applyCursors(pd, after, before); err != nil { + return nil, err } - - pd = pager.applyCursors(pd, after, before) - pd = pager.applyOrder(pd, last != nil) if limit := paginateLimit(first, last); limit != 0 { pd.Limit(limit) } @@ -1146,10 +1037,10 @@ func (pd *PayloadDigestQuery) Paginate( return nil, err } } - + pd = pager.applyOrder(pd) nodes, err := pd.All(ctx) - if err != nil || len(nodes) == 0 { - return conn, err + if err != nil { + return nil, err } conn.build(nodes, pager, after, first, before, last) return conn, nil @@ -1157,7 +1048,10 @@ func (pd *PayloadDigestQuery) Paginate( // PayloadDigestOrderField defines the ordering field of PayloadDigest. type PayloadDigestOrderField struct { - field string + // Value extracts the ordering value from the given PayloadDigest. + Value func(*PayloadDigest) (ent.Value, error) + column string // field or computed. + toTerm func(...sql.OrderTermOption) payloaddigest.OrderOption toCursor func(*PayloadDigest) Cursor } @@ -1169,9 +1063,13 @@ type PayloadDigestOrder struct { // DefaultPayloadDigestOrder is the default ordering of PayloadDigest. 
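The new Value, column, and toTerm members are what allow order fields beyond the ID default. Purely as an illustration (this value is not in the diff, the variable name is made up, and the unexported fields mean it could only live inside the generated ent package), an order field for the algorithm column would look roughly like:

var payloadDigestOrderFieldAlgorithm = &PayloadDigestOrderField{
    Value: func(pd *PayloadDigest) (ent.Value, error) {
        return pd.Algorithm, nil
    },
    column: payloaddigest.FieldAlgorithm,
    toTerm: payloaddigest.ByAlgorithm,
    toCursor: func(pd *PayloadDigest) Cursor {
        return Cursor{ID: pd.ID, Value: pd.Algorithm}
    },
}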
var DefaultPayloadDigestOrder = &PayloadDigestOrder{ - Direction: OrderDirectionAsc, + Direction: entgql.OrderDirectionAsc, Field: &PayloadDigestOrderField{ - field: payloaddigest.FieldID, + Value: func(pd *PayloadDigest) (ent.Value, error) { + return pd.ID, nil + }, + column: payloaddigest.FieldID, + toTerm: payloaddigest.ByID, toCursor: func(pd *PayloadDigest) Cursor { return Cursor{ID: pd.ID} }, @@ -1273,12 +1171,13 @@ func WithSignatureFilter(filter func(*SignatureQuery) (*SignatureQuery, error)) } type signaturePager struct { - order *SignatureOrder - filter func(*SignatureQuery) (*SignatureQuery, error) + reverse bool + order *SignatureOrder + filter func(*SignatureQuery) (*SignatureQuery, error) } -func newSignaturePager(opts []SignaturePaginateOption) (*signaturePager, error) { - pager := &signaturePager{} +func newSignaturePager(opts []SignaturePaginateOption, reverse bool) (*signaturePager, error) { + pager := &signaturePager{reverse: reverse} for _, opt := range opts { if err := opt(pager); err != nil { return nil, err @@ -1301,37 +1200,44 @@ func (p *signaturePager) toCursor(s *Signature) Cursor { return p.order.Field.toCursor(s) } -func (p *signaturePager) applyCursors(query *SignatureQuery, after, before *Cursor) *SignatureQuery { - for _, predicate := range cursorsToPredicates( - p.order.Direction, after, before, - p.order.Field.field, DefaultSignatureOrder.Field.field, - ) { +func (p *signaturePager) applyCursors(query *SignatureQuery, after, before *Cursor) (*SignatureQuery, error) { + direction := p.order.Direction + if p.reverse { + direction = direction.Reverse() + } + for _, predicate := range entgql.CursorsPredicate(after, before, DefaultSignatureOrder.Field.column, p.order.Field.column, direction) { query = query.Where(predicate) } - return query + return query, nil } -func (p *signaturePager) applyOrder(query *SignatureQuery, reverse bool) *SignatureQuery { +func (p *signaturePager) applyOrder(query *SignatureQuery) *SignatureQuery { direction := p.order.Direction - if reverse { - direction = direction.reverse() + if p.reverse { + direction = direction.Reverse() } - query = query.Order(direction.orderFunc(p.order.Field.field)) + query = query.Order(p.order.Field.toTerm(direction.OrderTermOption())) if p.order.Field != DefaultSignatureOrder.Field { - query = query.Order(direction.orderFunc(DefaultSignatureOrder.Field.field)) + query = query.Order(DefaultSignatureOrder.Field.toTerm(direction.OrderTermOption())) + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(p.order.Field.column) } return query } -func (p *signaturePager) orderExpr(reverse bool) sql.Querier { +func (p *signaturePager) orderExpr(query *SignatureQuery) sql.Querier { direction := p.order.Direction - if reverse { - direction = direction.reverse() + if p.reverse { + direction = direction.Reverse() + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(p.order.Field.column) } return sql.ExprFunc(func(b *sql.Builder) { - b.Ident(p.order.Field.field).Pad().WriteString(string(direction)) + b.Ident(p.order.Field.column).Pad().WriteString(string(direction)) if p.order.Field != DefaultSignatureOrder.Field { - b.Comma().Ident(DefaultSignatureOrder.Field.field).Pad().WriteString(string(direction)) + b.Comma().Ident(DefaultSignatureOrder.Field.column).Pad().WriteString(string(direction)) } }) } @@ -1344,7 +1250,7 @@ func (s *SignatureQuery) Paginate( if err := validateFirstLast(first, last); err != nil { return nil, err } - pager, err := newSignaturePager(opts) + pager, err := 
newSignaturePager(opts, last != nil) if err != nil { return nil, err } @@ -1352,27 +1258,23 @@ func (s *SignatureQuery) Paginate( return nil, err } conn := &SignatureConnection{Edges: []*SignatureEdge{}} - if !hasCollectedField(ctx, edgesField) || first != nil && *first == 0 || last != nil && *last == 0 { - if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { - if conn.TotalCount, err = s.Count(ctx); err != nil { + ignoredEdges := !hasCollectedField(ctx, edgesField) + if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { + hasPagination := after != nil || first != nil || before != nil || last != nil + if hasPagination || ignoredEdges { + if conn.TotalCount, err = s.Clone().Count(ctx); err != nil { return nil, err } conn.PageInfo.HasNextPage = first != nil && conn.TotalCount > 0 conn.PageInfo.HasPreviousPage = last != nil && conn.TotalCount > 0 } + } + if ignoredEdges || (first != nil && *first == 0) || (last != nil && *last == 0) { return conn, nil } - - if (after != nil || first != nil || before != nil || last != nil) && hasCollectedField(ctx, totalCountField) { - count, err := s.Clone().Count(ctx) - if err != nil { - return nil, err - } - conn.TotalCount = count + if s, err = pager.applyCursors(s, after, before); err != nil { + return nil, err } - - s = pager.applyCursors(s, after, before) - s = pager.applyOrder(s, last != nil) if limit := paginateLimit(first, last); limit != 0 { s.Limit(limit) } @@ -1381,10 +1283,10 @@ func (s *SignatureQuery) Paginate( return nil, err } } - + s = pager.applyOrder(s) nodes, err := s.All(ctx) - if err != nil || len(nodes) == 0 { - return conn, err + if err != nil { + return nil, err } conn.build(nodes, pager, after, first, before, last) return conn, nil @@ -1392,7 +1294,10 @@ func (s *SignatureQuery) Paginate( // SignatureOrderField defines the ordering field of Signature. type SignatureOrderField struct { - field string + // Value extracts the ordering value from the given Signature. + Value func(*Signature) (ent.Value, error) + column string // field or computed. + toTerm func(...sql.OrderTermOption) signature.OrderOption toCursor func(*Signature) Cursor } @@ -1404,9 +1309,13 @@ type SignatureOrder struct { // DefaultSignatureOrder is the default ordering of Signature. 
var DefaultSignatureOrder = &SignatureOrder{ - Direction: OrderDirectionAsc, + Direction: entgql.OrderDirectionAsc, Field: &SignatureOrderField{ - field: signature.FieldID, + Value: func(s *Signature) (ent.Value, error) { + return s.ID, nil + }, + column: signature.FieldID, + toTerm: signature.ByID, toCursor: func(s *Signature) Cursor { return Cursor{ID: s.ID} }, @@ -1508,12 +1417,13 @@ func WithStatementFilter(filter func(*StatementQuery) (*StatementQuery, error)) } type statementPager struct { - order *StatementOrder - filter func(*StatementQuery) (*StatementQuery, error) + reverse bool + order *StatementOrder + filter func(*StatementQuery) (*StatementQuery, error) } -func newStatementPager(opts []StatementPaginateOption) (*statementPager, error) { - pager := &statementPager{} +func newStatementPager(opts []StatementPaginateOption, reverse bool) (*statementPager, error) { + pager := &statementPager{reverse: reverse} for _, opt := range opts { if err := opt(pager); err != nil { return nil, err @@ -1536,37 +1446,44 @@ func (p *statementPager) toCursor(s *Statement) Cursor { return p.order.Field.toCursor(s) } -func (p *statementPager) applyCursors(query *StatementQuery, after, before *Cursor) *StatementQuery { - for _, predicate := range cursorsToPredicates( - p.order.Direction, after, before, - p.order.Field.field, DefaultStatementOrder.Field.field, - ) { +func (p *statementPager) applyCursors(query *StatementQuery, after, before *Cursor) (*StatementQuery, error) { + direction := p.order.Direction + if p.reverse { + direction = direction.Reverse() + } + for _, predicate := range entgql.CursorsPredicate(after, before, DefaultStatementOrder.Field.column, p.order.Field.column, direction) { query = query.Where(predicate) } - return query + return query, nil } -func (p *statementPager) applyOrder(query *StatementQuery, reverse bool) *StatementQuery { +func (p *statementPager) applyOrder(query *StatementQuery) *StatementQuery { direction := p.order.Direction - if reverse { - direction = direction.reverse() + if p.reverse { + direction = direction.Reverse() } - query = query.Order(direction.orderFunc(p.order.Field.field)) + query = query.Order(p.order.Field.toTerm(direction.OrderTermOption())) if p.order.Field != DefaultStatementOrder.Field { - query = query.Order(direction.orderFunc(DefaultStatementOrder.Field.field)) + query = query.Order(DefaultStatementOrder.Field.toTerm(direction.OrderTermOption())) + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(p.order.Field.column) } return query } -func (p *statementPager) orderExpr(reverse bool) sql.Querier { +func (p *statementPager) orderExpr(query *StatementQuery) sql.Querier { direction := p.order.Direction - if reverse { - direction = direction.reverse() + if p.reverse { + direction = direction.Reverse() + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(p.order.Field.column) } return sql.ExprFunc(func(b *sql.Builder) { - b.Ident(p.order.Field.field).Pad().WriteString(string(direction)) + b.Ident(p.order.Field.column).Pad().WriteString(string(direction)) if p.order.Field != DefaultStatementOrder.Field { - b.Comma().Ident(DefaultStatementOrder.Field.field).Pad().WriteString(string(direction)) + b.Comma().Ident(DefaultStatementOrder.Field.column).Pad().WriteString(string(direction)) } }) } @@ -1579,7 +1496,7 @@ func (s *StatementQuery) Paginate( if err := validateFirstLast(first, last); err != nil { return nil, err } - pager, err := newStatementPager(opts) + pager, err := newStatementPager(opts, last != nil) if err != 
nil { return nil, err } @@ -1587,27 +1504,23 @@ func (s *StatementQuery) Paginate( return nil, err } conn := &StatementConnection{Edges: []*StatementEdge{}} - if !hasCollectedField(ctx, edgesField) || first != nil && *first == 0 || last != nil && *last == 0 { - if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { - if conn.TotalCount, err = s.Count(ctx); err != nil { + ignoredEdges := !hasCollectedField(ctx, edgesField) + if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { + hasPagination := after != nil || first != nil || before != nil || last != nil + if hasPagination || ignoredEdges { + if conn.TotalCount, err = s.Clone().Count(ctx); err != nil { return nil, err } conn.PageInfo.HasNextPage = first != nil && conn.TotalCount > 0 conn.PageInfo.HasPreviousPage = last != nil && conn.TotalCount > 0 } + } + if ignoredEdges || (first != nil && *first == 0) || (last != nil && *last == 0) { return conn, nil } - - if (after != nil || first != nil || before != nil || last != nil) && hasCollectedField(ctx, totalCountField) { - count, err := s.Clone().Count(ctx) - if err != nil { - return nil, err - } - conn.TotalCount = count + if s, err = pager.applyCursors(s, after, before); err != nil { + return nil, err } - - s = pager.applyCursors(s, after, before) - s = pager.applyOrder(s, last != nil) if limit := paginateLimit(first, last); limit != 0 { s.Limit(limit) } @@ -1616,10 +1529,10 @@ func (s *StatementQuery) Paginate( return nil, err } } - + s = pager.applyOrder(s) nodes, err := s.All(ctx) - if err != nil || len(nodes) == 0 { - return conn, err + if err != nil { + return nil, err } conn.build(nodes, pager, after, first, before, last) return conn, nil @@ -1627,7 +1540,10 @@ func (s *StatementQuery) Paginate( // StatementOrderField defines the ordering field of Statement. type StatementOrderField struct { - field string + // Value extracts the ordering value from the given Statement. + Value func(*Statement) (ent.Value, error) + column string // field or computed. + toTerm func(...sql.OrderTermOption) statement.OrderOption toCursor func(*Statement) Cursor } @@ -1639,9 +1555,13 @@ type StatementOrder struct { // DefaultStatementOrder is the default ordering of Statement. 
var DefaultStatementOrder = &StatementOrder{ - Direction: OrderDirectionAsc, + Direction: entgql.OrderDirectionAsc, Field: &StatementOrderField{ - field: statement.FieldID, + Value: func(s *Statement) (ent.Value, error) { + return s.ID, nil + }, + column: statement.FieldID, + toTerm: statement.ByID, toCursor: func(s *Statement) Cursor { return Cursor{ID: s.ID} }, @@ -1743,12 +1663,13 @@ func WithSubjectFilter(filter func(*SubjectQuery) (*SubjectQuery, error)) Subjec } type subjectPager struct { - order *SubjectOrder - filter func(*SubjectQuery) (*SubjectQuery, error) + reverse bool + order *SubjectOrder + filter func(*SubjectQuery) (*SubjectQuery, error) } -func newSubjectPager(opts []SubjectPaginateOption) (*subjectPager, error) { - pager := &subjectPager{} +func newSubjectPager(opts []SubjectPaginateOption, reverse bool) (*subjectPager, error) { + pager := &subjectPager{reverse: reverse} for _, opt := range opts { if err := opt(pager); err != nil { return nil, err @@ -1771,37 +1692,44 @@ func (p *subjectPager) toCursor(s *Subject) Cursor { return p.order.Field.toCursor(s) } -func (p *subjectPager) applyCursors(query *SubjectQuery, after, before *Cursor) *SubjectQuery { - for _, predicate := range cursorsToPredicates( - p.order.Direction, after, before, - p.order.Field.field, DefaultSubjectOrder.Field.field, - ) { +func (p *subjectPager) applyCursors(query *SubjectQuery, after, before *Cursor) (*SubjectQuery, error) { + direction := p.order.Direction + if p.reverse { + direction = direction.Reverse() + } + for _, predicate := range entgql.CursorsPredicate(after, before, DefaultSubjectOrder.Field.column, p.order.Field.column, direction) { query = query.Where(predicate) } - return query + return query, nil } -func (p *subjectPager) applyOrder(query *SubjectQuery, reverse bool) *SubjectQuery { +func (p *subjectPager) applyOrder(query *SubjectQuery) *SubjectQuery { direction := p.order.Direction - if reverse { - direction = direction.reverse() + if p.reverse { + direction = direction.Reverse() } - query = query.Order(direction.orderFunc(p.order.Field.field)) + query = query.Order(p.order.Field.toTerm(direction.OrderTermOption())) if p.order.Field != DefaultSubjectOrder.Field { - query = query.Order(direction.orderFunc(DefaultSubjectOrder.Field.field)) + query = query.Order(DefaultSubjectOrder.Field.toTerm(direction.OrderTermOption())) + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(p.order.Field.column) } return query } -func (p *subjectPager) orderExpr(reverse bool) sql.Querier { +func (p *subjectPager) orderExpr(query *SubjectQuery) sql.Querier { direction := p.order.Direction - if reverse { - direction = direction.reverse() + if p.reverse { + direction = direction.Reverse() + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(p.order.Field.column) } return sql.ExprFunc(func(b *sql.Builder) { - b.Ident(p.order.Field.field).Pad().WriteString(string(direction)) + b.Ident(p.order.Field.column).Pad().WriteString(string(direction)) if p.order.Field != DefaultSubjectOrder.Field { - b.Comma().Ident(DefaultSubjectOrder.Field.field).Pad().WriteString(string(direction)) + b.Comma().Ident(DefaultSubjectOrder.Field.column).Pad().WriteString(string(direction)) } }) } @@ -1814,7 +1742,7 @@ func (s *SubjectQuery) Paginate( if err := validateFirstLast(first, last); err != nil { return nil, err } - pager, err := newSubjectPager(opts) + pager, err := newSubjectPager(opts, last != nil) if err != nil { return nil, err } @@ -1822,27 +1750,23 @@ func (s *SubjectQuery) Paginate( 
return nil, err } conn := &SubjectConnection{Edges: []*SubjectEdge{}} - if !hasCollectedField(ctx, edgesField) || first != nil && *first == 0 || last != nil && *last == 0 { - if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { - if conn.TotalCount, err = s.Count(ctx); err != nil { + ignoredEdges := !hasCollectedField(ctx, edgesField) + if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { + hasPagination := after != nil || first != nil || before != nil || last != nil + if hasPagination || ignoredEdges { + if conn.TotalCount, err = s.Clone().Count(ctx); err != nil { return nil, err } conn.PageInfo.HasNextPage = first != nil && conn.TotalCount > 0 conn.PageInfo.HasPreviousPage = last != nil && conn.TotalCount > 0 } + } + if ignoredEdges || (first != nil && *first == 0) || (last != nil && *last == 0) { return conn, nil } - - if (after != nil || first != nil || before != nil || last != nil) && hasCollectedField(ctx, totalCountField) { - count, err := s.Clone().Count(ctx) - if err != nil { - return nil, err - } - conn.TotalCount = count + if s, err = pager.applyCursors(s, after, before); err != nil { + return nil, err } - - s = pager.applyCursors(s, after, before) - s = pager.applyOrder(s, last != nil) if limit := paginateLimit(first, last); limit != 0 { s.Limit(limit) } @@ -1851,10 +1775,10 @@ func (s *SubjectQuery) Paginate( return nil, err } } - + s = pager.applyOrder(s) nodes, err := s.All(ctx) - if err != nil || len(nodes) == 0 { - return conn, err + if err != nil { + return nil, err } conn.build(nodes, pager, after, first, before, last) return conn, nil @@ -1862,7 +1786,10 @@ func (s *SubjectQuery) Paginate( // SubjectOrderField defines the ordering field of Subject. type SubjectOrderField struct { - field string + // Value extracts the ordering value from the given Subject. + Value func(*Subject) (ent.Value, error) + column string // field or computed. + toTerm func(...sql.OrderTermOption) subject.OrderOption toCursor func(*Subject) Cursor } @@ -1874,9 +1801,13 @@ type SubjectOrder struct { // DefaultSubjectOrder is the default ordering of Subject. 
var DefaultSubjectOrder = &SubjectOrder{ - Direction: OrderDirectionAsc, + Direction: entgql.OrderDirectionAsc, Field: &SubjectOrderField{ - field: subject.FieldID, + Value: func(s *Subject) (ent.Value, error) { + return s.ID, nil + }, + column: subject.FieldID, + toTerm: subject.ByID, toCursor: func(s *Subject) Cursor { return Cursor{ID: s.ID} }, @@ -1978,12 +1909,13 @@ func WithSubjectDigestFilter(filter func(*SubjectDigestQuery) (*SubjectDigestQue } type subjectdigestPager struct { - order *SubjectDigestOrder - filter func(*SubjectDigestQuery) (*SubjectDigestQuery, error) + reverse bool + order *SubjectDigestOrder + filter func(*SubjectDigestQuery) (*SubjectDigestQuery, error) } -func newSubjectDigestPager(opts []SubjectDigestPaginateOption) (*subjectdigestPager, error) { - pager := &subjectdigestPager{} +func newSubjectDigestPager(opts []SubjectDigestPaginateOption, reverse bool) (*subjectdigestPager, error) { + pager := &subjectdigestPager{reverse: reverse} for _, opt := range opts { if err := opt(pager); err != nil { return nil, err @@ -2006,37 +1938,44 @@ func (p *subjectdigestPager) toCursor(sd *SubjectDigest) Cursor { return p.order.Field.toCursor(sd) } -func (p *subjectdigestPager) applyCursors(query *SubjectDigestQuery, after, before *Cursor) *SubjectDigestQuery { - for _, predicate := range cursorsToPredicates( - p.order.Direction, after, before, - p.order.Field.field, DefaultSubjectDigestOrder.Field.field, - ) { +func (p *subjectdigestPager) applyCursors(query *SubjectDigestQuery, after, before *Cursor) (*SubjectDigestQuery, error) { + direction := p.order.Direction + if p.reverse { + direction = direction.Reverse() + } + for _, predicate := range entgql.CursorsPredicate(after, before, DefaultSubjectDigestOrder.Field.column, p.order.Field.column, direction) { query = query.Where(predicate) } - return query + return query, nil } -func (p *subjectdigestPager) applyOrder(query *SubjectDigestQuery, reverse bool) *SubjectDigestQuery { +func (p *subjectdigestPager) applyOrder(query *SubjectDigestQuery) *SubjectDigestQuery { direction := p.order.Direction - if reverse { - direction = direction.reverse() + if p.reverse { + direction = direction.Reverse() } - query = query.Order(direction.orderFunc(p.order.Field.field)) + query = query.Order(p.order.Field.toTerm(direction.OrderTermOption())) if p.order.Field != DefaultSubjectDigestOrder.Field { - query = query.Order(direction.orderFunc(DefaultSubjectDigestOrder.Field.field)) + query = query.Order(DefaultSubjectDigestOrder.Field.toTerm(direction.OrderTermOption())) + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(p.order.Field.column) } return query } -func (p *subjectdigestPager) orderExpr(reverse bool) sql.Querier { +func (p *subjectdigestPager) orderExpr(query *SubjectDigestQuery) sql.Querier { direction := p.order.Direction - if reverse { - direction = direction.reverse() + if p.reverse { + direction = direction.Reverse() + } + if len(query.ctx.Fields) > 0 { + query.ctx.AppendFieldOnce(p.order.Field.column) } return sql.ExprFunc(func(b *sql.Builder) { - b.Ident(p.order.Field.field).Pad().WriteString(string(direction)) + b.Ident(p.order.Field.column).Pad().WriteString(string(direction)) if p.order.Field != DefaultSubjectDigestOrder.Field { - b.Comma().Ident(DefaultSubjectDigestOrder.Field.field).Pad().WriteString(string(direction)) + b.Comma().Ident(DefaultSubjectDigestOrder.Field.column).Pad().WriteString(string(direction)) } }) } @@ -2049,7 +1988,7 @@ func (sd *SubjectDigestQuery) Paginate( if err := 
validateFirstLast(first, last); err != nil { return nil, err } - pager, err := newSubjectDigestPager(opts) + pager, err := newSubjectDigestPager(opts, last != nil) if err != nil { return nil, err } @@ -2057,27 +1996,23 @@ func (sd *SubjectDigestQuery) Paginate( return nil, err } conn := &SubjectDigestConnection{Edges: []*SubjectDigestEdge{}} - if !hasCollectedField(ctx, edgesField) || first != nil && *first == 0 || last != nil && *last == 0 { - if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { - if conn.TotalCount, err = sd.Count(ctx); err != nil { + ignoredEdges := !hasCollectedField(ctx, edgesField) + if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { + hasPagination := after != nil || first != nil || before != nil || last != nil + if hasPagination || ignoredEdges { + if conn.TotalCount, err = sd.Clone().Count(ctx); err != nil { return nil, err } conn.PageInfo.HasNextPage = first != nil && conn.TotalCount > 0 conn.PageInfo.HasPreviousPage = last != nil && conn.TotalCount > 0 } + } + if ignoredEdges || (first != nil && *first == 0) || (last != nil && *last == 0) { return conn, nil } - - if (after != nil || first != nil || before != nil || last != nil) && hasCollectedField(ctx, totalCountField) { - count, err := sd.Clone().Count(ctx) - if err != nil { - return nil, err - } - conn.TotalCount = count + if sd, err = pager.applyCursors(sd, after, before); err != nil { + return nil, err } - - sd = pager.applyCursors(sd, after, before) - sd = pager.applyOrder(sd, last != nil) if limit := paginateLimit(first, last); limit != 0 { sd.Limit(limit) } @@ -2086,10 +2021,10 @@ func (sd *SubjectDigestQuery) Paginate( return nil, err } } - + sd = pager.applyOrder(sd) nodes, err := sd.All(ctx) - if err != nil || len(nodes) == 0 { - return conn, err + if err != nil { + return nil, err } conn.build(nodes, pager, after, first, before, last) return conn, nil @@ -2097,7 +2032,10 @@ func (sd *SubjectDigestQuery) Paginate( // SubjectDigestOrderField defines the ordering field of SubjectDigest. type SubjectDigestOrderField struct { - field string + // Value extracts the ordering value from the given SubjectDigest. + Value func(*SubjectDigest) (ent.Value, error) + column string // field or computed. + toTerm func(...sql.OrderTermOption) subjectdigest.OrderOption toCursor func(*SubjectDigest) Cursor } @@ -2109,9 +2047,13 @@ type SubjectDigestOrder struct { // DefaultSubjectDigestOrder is the default ordering of SubjectDigest. 
 var DefaultSubjectDigestOrder = &SubjectDigestOrder{
-    Direction: OrderDirectionAsc,
+    Direction: entgql.OrderDirectionAsc,
     Field: &SubjectDigestOrderField{
-        field: subjectdigest.FieldID,
+        Value: func(sd *SubjectDigest) (ent.Value, error) {
+            return sd.ID, nil
+        },
+        column: subjectdigest.FieldID,
+        toTerm: subjectdigest.ByID,
         toCursor: func(sd *SubjectDigest) Cursor {
             return Cursor{ID: sd.ID}
         },
@@ -2213,12 +2155,13 @@ func WithTimestampFilter(filter func(*TimestampQuery) (*TimestampQuery, error))
 }
 
 type timestampPager struct {
-    order  *TimestampOrder
-    filter func(*TimestampQuery) (*TimestampQuery, error)
+    reverse bool
+    order   *TimestampOrder
+    filter  func(*TimestampQuery) (*TimestampQuery, error)
 }
 
-func newTimestampPager(opts []TimestampPaginateOption) (*timestampPager, error) {
-    pager := &timestampPager{}
+func newTimestampPager(opts []TimestampPaginateOption, reverse bool) (*timestampPager, error) {
+    pager := &timestampPager{reverse: reverse}
     for _, opt := range opts {
         if err := opt(pager); err != nil {
             return nil, err
@@ -2241,37 +2184,44 @@ func (p *timestampPager) toCursor(t *Timestamp) Cursor {
     return p.order.Field.toCursor(t)
 }
 
-func (p *timestampPager) applyCursors(query *TimestampQuery, after, before *Cursor) *TimestampQuery {
-    for _, predicate := range cursorsToPredicates(
-        p.order.Direction, after, before,
-        p.order.Field.field, DefaultTimestampOrder.Field.field,
-    ) {
+func (p *timestampPager) applyCursors(query *TimestampQuery, after, before *Cursor) (*TimestampQuery, error) {
+    direction := p.order.Direction
+    if p.reverse {
+        direction = direction.Reverse()
+    }
+    for _, predicate := range entgql.CursorsPredicate(after, before, DefaultTimestampOrder.Field.column, p.order.Field.column, direction) {
         query = query.Where(predicate)
     }
-    return query
+    return query, nil
 }
 
-func (p *timestampPager) applyOrder(query *TimestampQuery, reverse bool) *TimestampQuery {
+func (p *timestampPager) applyOrder(query *TimestampQuery) *TimestampQuery {
     direction := p.order.Direction
-    if reverse {
-        direction = direction.reverse()
+    if p.reverse {
+        direction = direction.Reverse()
     }
-    query = query.Order(direction.orderFunc(p.order.Field.field))
+    query = query.Order(p.order.Field.toTerm(direction.OrderTermOption()))
     if p.order.Field != DefaultTimestampOrder.Field {
-        query = query.Order(direction.orderFunc(DefaultTimestampOrder.Field.field))
+        query = query.Order(DefaultTimestampOrder.Field.toTerm(direction.OrderTermOption()))
+    }
+    if len(query.ctx.Fields) > 0 {
+        query.ctx.AppendFieldOnce(p.order.Field.column)
     }
     return query
 }
 
-func (p *timestampPager) orderExpr(reverse bool) sql.Querier {
+func (p *timestampPager) orderExpr(query *TimestampQuery) sql.Querier {
     direction := p.order.Direction
-    if reverse {
-        direction = direction.reverse()
+    if p.reverse {
+        direction = direction.Reverse()
+    }
+    if len(query.ctx.Fields) > 0 {
+        query.ctx.AppendFieldOnce(p.order.Field.column)
     }
     return sql.ExprFunc(func(b *sql.Builder) {
-        b.Ident(p.order.Field.field).Pad().WriteString(string(direction))
+        b.Ident(p.order.Field.column).Pad().WriteString(string(direction))
         if p.order.Field != DefaultTimestampOrder.Field {
-            b.Comma().Ident(DefaultTimestampOrder.Field.field).Pad().WriteString(string(direction))
+            b.Comma().Ident(DefaultTimestampOrder.Field.column).Pad().WriteString(string(direction))
         }
     })
 }
@@ -2284,7 +2234,7 @@ func (t *TimestampQuery) Paginate(
     if err := validateFirstLast(first, last); err != nil {
         return nil, err
     }
-    pager, err := newTimestampPager(opts)
+    pager, err := 
newTimestampPager(opts, last != nil) if err != nil { return nil, err } @@ -2292,27 +2242,23 @@ func (t *TimestampQuery) Paginate( return nil, err } conn := &TimestampConnection{Edges: []*TimestampEdge{}} - if !hasCollectedField(ctx, edgesField) || first != nil && *first == 0 || last != nil && *last == 0 { - if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { - if conn.TotalCount, err = t.Count(ctx); err != nil { + ignoredEdges := !hasCollectedField(ctx, edgesField) + if hasCollectedField(ctx, totalCountField) || hasCollectedField(ctx, pageInfoField) { + hasPagination := after != nil || first != nil || before != nil || last != nil + if hasPagination || ignoredEdges { + if conn.TotalCount, err = t.Clone().Count(ctx); err != nil { return nil, err } conn.PageInfo.HasNextPage = first != nil && conn.TotalCount > 0 conn.PageInfo.HasPreviousPage = last != nil && conn.TotalCount > 0 } + } + if ignoredEdges || (first != nil && *first == 0) || (last != nil && *last == 0) { return conn, nil } - - if (after != nil || first != nil || before != nil || last != nil) && hasCollectedField(ctx, totalCountField) { - count, err := t.Clone().Count(ctx) - if err != nil { - return nil, err - } - conn.TotalCount = count + if t, err = pager.applyCursors(t, after, before); err != nil { + return nil, err } - - t = pager.applyCursors(t, after, before) - t = pager.applyOrder(t, last != nil) if limit := paginateLimit(first, last); limit != 0 { t.Limit(limit) } @@ -2321,10 +2267,10 @@ func (t *TimestampQuery) Paginate( return nil, err } } - + t = pager.applyOrder(t) nodes, err := t.All(ctx) - if err != nil || len(nodes) == 0 { - return conn, err + if err != nil { + return nil, err } conn.build(nodes, pager, after, first, before, last) return conn, nil @@ -2332,7 +2278,10 @@ func (t *TimestampQuery) Paginate( // TimestampOrderField defines the ordering field of Timestamp. type TimestampOrderField struct { - field string + // Value extracts the ordering value from the given Timestamp. + Value func(*Timestamp) (ent.Value, error) + column string // field or computed. + toTerm func(...sql.OrderTermOption) timestamp.OrderOption toCursor func(*Timestamp) Cursor } @@ -2344,9 +2293,13 @@ type TimestampOrder struct { // DefaultTimestampOrder is the default ordering of Timestamp. var DefaultTimestampOrder = &TimestampOrder{ - Direction: OrderDirectionAsc, + Direction: entgql.OrderDirectionAsc, Field: &TimestampOrderField{ - field: timestamp.FieldID, + Value: func(t *Timestamp) (ent.Value, error) { + return t.ID, nil + }, + column: timestamp.FieldID, + toTerm: timestamp.ByID, toCursor: func(t *Timestamp) Cursor { return Cursor{ID: t.ID} }, diff --git a/ent/hook/hook.go b/ent/hook/hook.go index 8453d45a..43da56b3 100644 --- a/ent/hook/hook.go +++ b/ent/hook/hook.go @@ -15,11 +15,10 @@ type AttestationFunc func(context.Context, *ent.AttestationMutation) (ent.Value, // Mutate calls f(ctx, m). func (f AttestationFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.AttestationMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AttestationMutation", m) + if mv, ok := m.(*ent.AttestationMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. 
expect *ent.AttestationMutation", m) } // The AttestationCollectionFunc type is an adapter to allow the use of ordinary @@ -28,11 +27,10 @@ type AttestationCollectionFunc func(context.Context, *ent.AttestationCollectionM // Mutate calls f(ctx, m). func (f AttestationCollectionFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.AttestationCollectionMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AttestationCollectionMutation", m) + if mv, ok := m.(*ent.AttestationCollectionMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AttestationCollectionMutation", m) } // The DsseFunc type is an adapter to allow the use of ordinary @@ -41,11 +39,10 @@ type DsseFunc func(context.Context, *ent.DsseMutation) (ent.Value, error) // Mutate calls f(ctx, m). func (f DsseFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.DsseMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.DsseMutation", m) + if mv, ok := m.(*ent.DsseMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.DsseMutation", m) } // The PayloadDigestFunc type is an adapter to allow the use of ordinary @@ -54,11 +51,10 @@ type PayloadDigestFunc func(context.Context, *ent.PayloadDigestMutation) (ent.Va // Mutate calls f(ctx, m). func (f PayloadDigestFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.PayloadDigestMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PayloadDigestMutation", m) + if mv, ok := m.(*ent.PayloadDigestMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.PayloadDigestMutation", m) } // The SignatureFunc type is an adapter to allow the use of ordinary @@ -67,11 +63,10 @@ type SignatureFunc func(context.Context, *ent.SignatureMutation) (ent.Value, err // Mutate calls f(ctx, m). func (f SignatureFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.SignatureMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.SignatureMutation", m) + if mv, ok := m.(*ent.SignatureMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.SignatureMutation", m) } // The StatementFunc type is an adapter to allow the use of ordinary @@ -80,11 +75,10 @@ type StatementFunc func(context.Context, *ent.StatementMutation) (ent.Value, err // Mutate calls f(ctx, m). func (f StatementFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.StatementMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.StatementMutation", m) + if mv, ok := m.(*ent.StatementMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.StatementMutation", m) } // The SubjectFunc type is an adapter to allow the use of ordinary @@ -93,11 +87,10 @@ type SubjectFunc func(context.Context, *ent.SubjectMutation) (ent.Value, error) // Mutate calls f(ctx, m). func (f SubjectFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.SubjectMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. 
expect *ent.SubjectMutation", m) + if mv, ok := m.(*ent.SubjectMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.SubjectMutation", m) } // The SubjectDigestFunc type is an adapter to allow the use of ordinary @@ -106,11 +99,10 @@ type SubjectDigestFunc func(context.Context, *ent.SubjectDigestMutation) (ent.Va // Mutate calls f(ctx, m). func (f SubjectDigestFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.SubjectDigestMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.SubjectDigestMutation", m) + if mv, ok := m.(*ent.SubjectDigestMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.SubjectDigestMutation", m) } // The TimestampFunc type is an adapter to allow the use of ordinary @@ -119,11 +111,10 @@ type TimestampFunc func(context.Context, *ent.TimestampMutation) (ent.Value, err // Mutate calls f(ctx, m). func (f TimestampFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { - mv, ok := m.(*ent.TimestampMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.TimestampMutation", m) + if mv, ok := m.(*ent.TimestampMutation); ok { + return f(ctx, mv) } - return f(ctx, mv) + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.TimestampMutation", m) } // Condition is a hook condition function. diff --git a/ent/mutation.go b/ent/mutation.go index 121af63b..1cff68c1 100644 --- a/ent/mutation.go +++ b/ent/mutation.go @@ -9,6 +9,8 @@ import ( "sync" "time" + "entgo.io/ent" + "entgo.io/ent/dialect/sql" "github.com/testifysec/archivista/ent/attestation" "github.com/testifysec/archivista/ent/attestationcollection" "github.com/testifysec/archivista/ent/dsse" @@ -19,8 +21,6 @@ import ( "github.com/testifysec/archivista/ent/subject" "github.com/testifysec/archivista/ent/subjectdigest" "github.com/testifysec/archivista/ent/timestamp" - - "entgo.io/ent" ) const ( @@ -236,11 +236,26 @@ func (m *AttestationMutation) Where(ps ...predicate.Attestation) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the AttestationMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *AttestationMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Attestation, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *AttestationMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *AttestationMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (Attestation). func (m *AttestationMutation) Type() string { return m.typ @@ -671,11 +686,26 @@ func (m *AttestationCollectionMutation) Where(ps ...predicate.AttestationCollect m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the AttestationCollectionMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *AttestationCollectionMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.AttestationCollection, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. 
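The hook.go adapters above keep their behaviour; only the shape of the type switch changed, so existing hook registrations compile unchanged. A short sketch, assuming context, log, and this module's ent and ent/hook packages:

client.Attestation.Use(func(next ent.Mutator) ent.Mutator {
    return hook.AttestationFunc(func(ctx context.Context, m *ent.AttestationMutation) (ent.Value, error) {
        log.Printf("attestation mutation: type=%s op=%v", m.Type(), m.Op())
        return next.Mutate(ctx, m)
    })
})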
func (m *AttestationCollectionMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *AttestationCollectionMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (AttestationCollection). func (m *AttestationCollectionMutation) Type() string { return m.typ @@ -1228,11 +1258,26 @@ func (m *DsseMutation) Where(ps ...predicate.Dsse) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the DsseMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *DsseMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Dsse, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *DsseMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *DsseMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (Dsse). func (m *DsseMutation) Type() string { return m.typ @@ -1714,11 +1759,26 @@ func (m *PayloadDigestMutation) Where(ps ...predicate.PayloadDigest) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the PayloadDigestMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *PayloadDigestMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.PayloadDigest, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *PayloadDigestMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *PayloadDigestMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (PayloadDigest). func (m *PayloadDigestMutation) Type() string { return m.typ @@ -2203,11 +2263,26 @@ func (m *SignatureMutation) Where(ps ...predicate.Signature) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the SignatureMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *SignatureMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Signature, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *SignatureMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *SignatureMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (Signature). func (m *SignatureMutation) Type() string { return m.typ @@ -2740,11 +2815,26 @@ func (m *StatementMutation) Where(ps ...predicate.Statement) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the StatementMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *StatementMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Statement, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *StatementMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *StatementMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (Statement). 
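WhereP is the escape hatch its comment describes: generic middleware can narrow any mutation with raw selector predicates without importing a generated predicate package. A sketch, assuming entgo.io/ent/dialect/sql and a table that actually has a type column:

func requireNonEmptyType(next ent.Mutator) ent.Mutator {
    return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
        if p, ok := m.(interface{ WhereP(...func(*sql.Selector)) }); ok {
            p.WhereP(sql.FieldNEQ("type", ""))
        }
        return next.Mutate(ctx, m)
    })
}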
func (m *StatementMutation) Type() string { return m.typ @@ -3229,11 +3319,26 @@ func (m *SubjectMutation) Where(ps ...predicate.Subject) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the SubjectMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *SubjectMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Subject, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *SubjectMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *SubjectMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (Subject). func (m *SubjectMutation) Type() string { return m.typ @@ -3672,11 +3777,26 @@ func (m *SubjectDigestMutation) Where(ps ...predicate.SubjectDigest) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the SubjectDigestMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *SubjectDigestMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.SubjectDigest, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *SubjectDigestMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *SubjectDigestMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (SubjectDigest). func (m *SubjectDigestMutation) Type() string { return m.typ @@ -4104,11 +4224,26 @@ func (m *TimestampMutation) Where(ps ...predicate.Timestamp) { m.predicates = append(m.predicates, ps...) } +// WhereP appends storage-level predicates to the TimestampMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *TimestampMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.Timestamp, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + // Op returns the operation name. func (m *TimestampMutation) Op() Op { return m.op } +// SetOp allows setting the mutation operation. +func (m *TimestampMutation) SetOp(op Op) { + m.op = op +} + // Type returns the node type of this mutation (Timestamp). func (m *TimestampMutation) Type() string { return m.typ diff --git a/ent/payloaddigest.go b/ent/payloaddigest.go index 89b73fec..a4111824 100644 --- a/ent/payloaddigest.go +++ b/ent/payloaddigest.go @@ -6,6 +6,7 @@ import ( "fmt" "strings" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/testifysec/archivista/ent/dsse" "github.com/testifysec/archivista/ent/payloaddigest" @@ -24,6 +25,7 @@ type PayloadDigest struct { // The values are being populated by the PayloadDigestQuery when eager-loading is set. Edges PayloadDigestEdges `json:"edges"` dsse_payload_digests *int + selectValues sql.SelectValues } // PayloadDigestEdges holds the relations/edges for other nodes in the graph. @@ -34,7 +36,7 @@ type PayloadDigestEdges struct { // type was loaded (or requested) in eager-loading or not. loadedTypes [1]bool // totalCount holds the count of the edges above. 
- totalCount [1]*int + totalCount [1]map[string]int } // DsseOrErr returns the Dsse value or an error if the edge @@ -62,7 +64,7 @@ func (*PayloadDigest) scanValues(columns []string) ([]any, error) { case payloaddigest.ForeignKeys[0]: // dsse_payload_digests values[i] = new(sql.NullInt64) default: - return nil, fmt.Errorf("unexpected column %q for type PayloadDigest", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -101,21 +103,29 @@ func (pd *PayloadDigest) assignValues(columns []string, values []any) error { pd.dsse_payload_digests = new(int) *pd.dsse_payload_digests = int(value.Int64) } + default: + pd.selectValues.Set(columns[i], values[i]) } } return nil } +// GetValue returns the ent.Value that was dynamically selected and assigned to the PayloadDigest. +// This includes values selected through modifiers, order, etc. +func (pd *PayloadDigest) GetValue(name string) (ent.Value, error) { + return pd.selectValues.Get(name) +} + // QueryDsse queries the "dsse" edge of the PayloadDigest entity. func (pd *PayloadDigest) QueryDsse() *DsseQuery { - return (&PayloadDigestClient{config: pd.config}).QueryDsse(pd) + return NewPayloadDigestClient(pd.config).QueryDsse(pd) } // Update returns a builder for updating this PayloadDigest. // Note that you need to call PayloadDigest.Unwrap() before calling this method if this PayloadDigest // was returned from a transaction, and the transaction was committed or rolled back. func (pd *PayloadDigest) Update() *PayloadDigestUpdateOne { - return (&PayloadDigestClient{config: pd.config}).UpdateOne(pd) + return NewPayloadDigestClient(pd.config).UpdateOne(pd) } // Unwrap unwraps the PayloadDigest entity that was returned from a transaction after it was closed, @@ -145,9 +155,3 @@ func (pd *PayloadDigest) String() string { // PayloadDigests is a parsable slice of PayloadDigest. type PayloadDigests []*PayloadDigest - -func (pd PayloadDigests) config(cfg config) { - for _i := range pd { - pd[_i].config = cfg - } -} diff --git a/ent/payloaddigest/payloaddigest.go b/ent/payloaddigest/payloaddigest.go index 85a78cd4..d2e31d7a 100644 --- a/ent/payloaddigest/payloaddigest.go +++ b/ent/payloaddigest/payloaddigest.go @@ -2,6 +2,11 @@ package payloaddigest +import ( + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + const ( // Label holds the string label denoting the payloaddigest type in the database. Label = "payload_digest" @@ -58,3 +63,35 @@ var ( // ValueValidator is a validator for the "value" field. It is called by the builders before save. ValueValidator func(string) error ) + +// OrderOption defines the ordering options for the PayloadDigest queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByAlgorithm orders the results by the algorithm field. +func ByAlgorithm(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAlgorithm, opts...).ToFunc() +} + +// ByValue orders the results by the value field. +func ByValue(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldValue, opts...).ToFunc() +} + +// ByDsseField orders the results by dsse field. 
+func ByDsseField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newDsseStep(), sql.OrderByField(field, opts...)) + } +} +func newDsseStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(DsseInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, DsseTable, DsseColumn), + ) +} diff --git a/ent/payloaddigest/where.go b/ent/payloaddigest/where.go index 3f8b6600..06f2bfbb 100644 --- a/ent/payloaddigest/where.go +++ b/ent/payloaddigest/where.go @@ -10,285 +10,187 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.PayloadDigest(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.PayloadDigest(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.PayloadDigest(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.PayloadDigest(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.PayloadDigest(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.PayloadDigest(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.PayloadDigest(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.PayloadDigest(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.PayloadDigest(sql.FieldLTE(FieldID, id)) } // Algorithm applies equality check predicate on the "algorithm" field. It's identical to AlgorithmEQ. func Algorithm(v string) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAlgorithm), v)) - }) + return predicate.PayloadDigest(sql.FieldEQ(FieldAlgorithm, v)) } // Value applies equality check predicate on the "value" field. It's identical to ValueEQ. 
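The OrderOption helpers replace raw column strings at call sites, and ByDsseField orders through the edge to the parent dsse row. A usage sketch, where sql is entgo.io/ent/dialect/sql and payload_type is assumed to be a column on the dsse table:

func orderedDigests(ctx context.Context, client *ent.Client) ([]*ent.PayloadDigest, error) {
    return client.PayloadDigest.Query().
        Order(
            payloaddigest.ByAlgorithm(sql.OrderDesc()),
            payloaddigest.ByDsseField("payload_type", sql.OrderAsc()),
        ).
        All(ctx)
}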
func Value(v string) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.PayloadDigest(sql.FieldEQ(FieldValue, v)) } // AlgorithmEQ applies the EQ predicate on the "algorithm" field. func AlgorithmEQ(v string) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAlgorithm), v)) - }) + return predicate.PayloadDigest(sql.FieldEQ(FieldAlgorithm, v)) } // AlgorithmNEQ applies the NEQ predicate on the "algorithm" field. func AlgorithmNEQ(v string) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldAlgorithm), v)) - }) + return predicate.PayloadDigest(sql.FieldNEQ(FieldAlgorithm, v)) } // AlgorithmIn applies the In predicate on the "algorithm" field. func AlgorithmIn(vs ...string) predicate.PayloadDigest { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldAlgorithm), v...)) - }) + return predicate.PayloadDigest(sql.FieldIn(FieldAlgorithm, vs...)) } // AlgorithmNotIn applies the NotIn predicate on the "algorithm" field. func AlgorithmNotIn(vs ...string) predicate.PayloadDigest { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldAlgorithm), v...)) - }) + return predicate.PayloadDigest(sql.FieldNotIn(FieldAlgorithm, vs...)) } // AlgorithmGT applies the GT predicate on the "algorithm" field. func AlgorithmGT(v string) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldAlgorithm), v)) - }) + return predicate.PayloadDigest(sql.FieldGT(FieldAlgorithm, v)) } // AlgorithmGTE applies the GTE predicate on the "algorithm" field. func AlgorithmGTE(v string) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldAlgorithm), v)) - }) + return predicate.PayloadDigest(sql.FieldGTE(FieldAlgorithm, v)) } // AlgorithmLT applies the LT predicate on the "algorithm" field. func AlgorithmLT(v string) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldAlgorithm), v)) - }) + return predicate.PayloadDigest(sql.FieldLT(FieldAlgorithm, v)) } // AlgorithmLTE applies the LTE predicate on the "algorithm" field. func AlgorithmLTE(v string) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldAlgorithm), v)) - }) + return predicate.PayloadDigest(sql.FieldLTE(FieldAlgorithm, v)) } // AlgorithmContains applies the Contains predicate on the "algorithm" field. func AlgorithmContains(v string) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldAlgorithm), v)) - }) + return predicate.PayloadDigest(sql.FieldContains(FieldAlgorithm, v)) } // AlgorithmHasPrefix applies the HasPrefix predicate on the "algorithm" field. func AlgorithmHasPrefix(v string) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldAlgorithm), v)) - }) + return predicate.PayloadDigest(sql.FieldHasPrefix(FieldAlgorithm, v)) } // AlgorithmHasSuffix applies the HasSuffix predicate on the "algorithm" field. 
func AlgorithmHasSuffix(v string) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldAlgorithm), v)) - }) + return predicate.PayloadDigest(sql.FieldHasSuffix(FieldAlgorithm, v)) } // AlgorithmEqualFold applies the EqualFold predicate on the "algorithm" field. func AlgorithmEqualFold(v string) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldAlgorithm), v)) - }) + return predicate.PayloadDigest(sql.FieldEqualFold(FieldAlgorithm, v)) } // AlgorithmContainsFold applies the ContainsFold predicate on the "algorithm" field. func AlgorithmContainsFold(v string) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldAlgorithm), v)) - }) + return predicate.PayloadDigest(sql.FieldContainsFold(FieldAlgorithm, v)) } // ValueEQ applies the EQ predicate on the "value" field. func ValueEQ(v string) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.PayloadDigest(sql.FieldEQ(FieldValue, v)) } // ValueNEQ applies the NEQ predicate on the "value" field. func ValueNEQ(v string) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldValue), v)) - }) + return predicate.PayloadDigest(sql.FieldNEQ(FieldValue, v)) } // ValueIn applies the In predicate on the "value" field. func ValueIn(vs ...string) predicate.PayloadDigest { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldValue), v...)) - }) + return predicate.PayloadDigest(sql.FieldIn(FieldValue, vs...)) } // ValueNotIn applies the NotIn predicate on the "value" field. func ValueNotIn(vs ...string) predicate.PayloadDigest { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldValue), v...)) - }) + return predicate.PayloadDigest(sql.FieldNotIn(FieldValue, vs...)) } // ValueGT applies the GT predicate on the "value" field. func ValueGT(v string) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldValue), v)) - }) + return predicate.PayloadDigest(sql.FieldGT(FieldValue, v)) } // ValueGTE applies the GTE predicate on the "value" field. func ValueGTE(v string) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldValue), v)) - }) + return predicate.PayloadDigest(sql.FieldGTE(FieldValue, v)) } // ValueLT applies the LT predicate on the "value" field. func ValueLT(v string) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldValue), v)) - }) + return predicate.PayloadDigest(sql.FieldLT(FieldValue, v)) } // ValueLTE applies the LTE predicate on the "value" field. func ValueLTE(v string) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldValue), v)) - }) + return predicate.PayloadDigest(sql.FieldLTE(FieldValue, v)) } // ValueContains applies the Contains predicate on the "value" field. 
func ValueContains(v string) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldValue), v)) - }) + return predicate.PayloadDigest(sql.FieldContains(FieldValue, v)) } // ValueHasPrefix applies the HasPrefix predicate on the "value" field. func ValueHasPrefix(v string) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldValue), v)) - }) + return predicate.PayloadDigest(sql.FieldHasPrefix(FieldValue, v)) } // ValueHasSuffix applies the HasSuffix predicate on the "value" field. func ValueHasSuffix(v string) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldValue), v)) - }) + return predicate.PayloadDigest(sql.FieldHasSuffix(FieldValue, v)) } // ValueEqualFold applies the EqualFold predicate on the "value" field. func ValueEqualFold(v string) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldValue), v)) - }) + return predicate.PayloadDigest(sql.FieldEqualFold(FieldValue, v)) } // ValueContainsFold applies the ContainsFold predicate on the "value" field. func ValueContainsFold(v string) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldValue), v)) - }) + return predicate.PayloadDigest(sql.FieldContainsFold(FieldValue, v)) } // HasDsse applies the HasEdge predicate on the "dsse" edge. @@ -296,7 +198,6 @@ func HasDsse() predicate.PayloadDigest { return predicate.PayloadDigest(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(DsseTable, FieldID), sqlgraph.Edge(sqlgraph.M2O, true, DsseTable, DsseColumn), ) sqlgraph.HasNeighbors(s, step) @@ -306,11 +207,7 @@ func HasDsse() predicate.PayloadDigest { // HasDsseWith applies the HasEdge predicate on the "dsse" edge with a given conditions (other predicates). func HasDsseWith(preds ...predicate.Dsse) predicate.PayloadDigest { return predicate.PayloadDigest(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(DsseInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.M2O, true, DsseTable, DsseColumn), - ) + step := newDsseStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -321,32 +218,15 @@ func HasDsseWith(preds ...predicate.Dsse) predicate.PayloadDigest { // And groups predicates with the AND operator between them. func And(predicates ...predicate.PayloadDigest) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.PayloadDigest(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. func Or(predicates ...predicate.PayloadDigest) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.PayloadDigest(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. 
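From a caller's perspective the slimmed-down predicates behave exactly as before; only the generated bodies now delegate to the sql.FieldEQ/FieldIn/FieldContains helpers. A short sketch, assuming an *ent.Client named client (sha256Digests is an illustrative name):

import (
	"context"

	"github.com/testifysec/archivista/ent"
	"github.com/testifysec/archivista/ent/payloaddigest"
)

// sha256Digests returns sha256 payload digests that are attached to a DSSE
// envelope. Multiple Where arguments are combined with AND, as before.
func sha256Digests(ctx context.Context, client *ent.Client) ([]*ent.PayloadDigest, error) {
	return client.PayloadDigest.Query().
		Where(
			payloaddigest.AlgorithmEQ("sha256"),
			payloaddigest.HasDsse(),
		).
		All(ctx)
}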
func Not(p predicate.PayloadDigest) predicate.PayloadDigest { - return predicate.PayloadDigest(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.PayloadDigest(sql.NotPredicates(p)) } diff --git a/ent/payloaddigest_create.go b/ent/payloaddigest_create.go index e4c09877..3d0cb992 100644 --- a/ent/payloaddigest_create.go +++ b/ent/payloaddigest_create.go @@ -58,49 +58,7 @@ func (pdc *PayloadDigestCreate) Mutation() *PayloadDigestMutation { // Save creates the PayloadDigest in the database. func (pdc *PayloadDigestCreate) Save(ctx context.Context) (*PayloadDigest, error) { - var ( - err error - node *PayloadDigest - ) - if len(pdc.hooks) == 0 { - if err = pdc.check(); err != nil { - return nil, err - } - node, err = pdc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*PayloadDigestMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = pdc.check(); err != nil { - return nil, err - } - pdc.mutation = mutation - if node, err = pdc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(pdc.hooks) - 1; i >= 0; i-- { - if pdc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = pdc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, pdc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*PayloadDigest) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from PayloadDigestMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, pdc.sqlSave, pdc.mutation, pdc.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -147,6 +105,9 @@ func (pdc *PayloadDigestCreate) check() error { } func (pdc *PayloadDigestCreate) sqlSave(ctx context.Context) (*PayloadDigest, error) { + if err := pdc.check(); err != nil { + return nil, err + } _node, _spec := pdc.createSpec() if err := sqlgraph.CreateNode(ctx, pdc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -156,34 +117,22 @@ func (pdc *PayloadDigestCreate) sqlSave(ctx context.Context) (*PayloadDigest, er } id := _spec.ID.Value.(int64) _node.ID = int(id) + pdc.mutation.id = &_node.ID + pdc.mutation.done = true return _node, nil } func (pdc *PayloadDigestCreate) createSpec() (*PayloadDigest, *sqlgraph.CreateSpec) { var ( _node = &PayloadDigest{config: pdc.config} - _spec = &sqlgraph.CreateSpec{ - Table: payloaddigest.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: payloaddigest.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(payloaddigest.Table, sqlgraph.NewFieldSpec(payloaddigest.FieldID, field.TypeInt)) ) if value, ok := pdc.mutation.Algorithm(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: payloaddigest.FieldAlgorithm, - }) + _spec.SetField(payloaddigest.FieldAlgorithm, field.TypeString, value) _node.Algorithm = value } if value, ok := pdc.mutation.Value(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: payloaddigest.FieldValue, - }) + _spec.SetField(payloaddigest.FieldValue, field.TypeString, value) _node.Value = value } if nodes := pdc.mutation.DsseIDs(); len(nodes) > 0 { @@ -194,10 +143,7 @@ func (pdc *PayloadDigestCreate) createSpec() (*PayloadDigest, *sqlgraph.CreateSp Columns: []string{payloaddigest.DsseColumn}, Bidi: false, Target: 
&sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: dsse.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(dsse.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -212,11 +158,15 @@ func (pdc *PayloadDigestCreate) createSpec() (*PayloadDigest, *sqlgraph.CreateSp // PayloadDigestCreateBulk is the builder for creating many PayloadDigest entities in bulk. type PayloadDigestCreateBulk struct { config + err error builders []*PayloadDigestCreate } // Save creates the PayloadDigest entities in the database. func (pdcb *PayloadDigestCreateBulk) Save(ctx context.Context) ([]*PayloadDigest, error) { + if pdcb.err != nil { + return nil, pdcb.err + } specs := make([]*sqlgraph.CreateSpec, len(pdcb.builders)) nodes := make([]*PayloadDigest, len(pdcb.builders)) mutators := make([]Mutator, len(pdcb.builders)) @@ -232,8 +182,8 @@ func (pdcb *PayloadDigestCreateBulk) Save(ctx context.Context) ([]*PayloadDigest return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, pdcb.builders[i+1].mutation) } else { diff --git a/ent/payloaddigest_delete.go b/ent/payloaddigest_delete.go index a4724295..37ea6c97 100644 --- a/ent/payloaddigest_delete.go +++ b/ent/payloaddigest_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (pdd *PayloadDigestDelete) Where(ps ...predicate.PayloadDigest) *PayloadDig // Exec executes the deletion query and returns how many vertices were deleted. func (pdd *PayloadDigestDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(pdd.hooks) == 0 { - affected, err = pdd.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*PayloadDigestMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - pdd.mutation = mutation - affected, err = pdd.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(pdd.hooks) - 1; i >= 0; i-- { - if pdd.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = pdd.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, pdd.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, pdd.sqlExec, pdd.mutation, pdd.hooks) } // ExecX is like Exec, but panics if an error occurs. 
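The inlined hook loops removed from Save and Exec are now handled by the shared withHooks helper, so registered hooks keep firing as they did. A sketch of a hook registration, assuming an *ent.Client named client; registerPayloadDigestLogger is an illustrative name:

import (
	"context"
	"log"

	"github.com/testifysec/archivista/ent"
)

// registerPayloadDigestLogger attaches a mutation hook; withHooks invokes it
// around every PayloadDigest create, update, and delete.
func registerPayloadDigestLogger(client *ent.Client) {
	client.PayloadDigest.Use(func(next ent.Mutator) ent.Mutator {
		return ent.MutateFunc(func(ctx context.Context, m ent.Mutation) (ent.Value, error) {
			log.Printf("mutating %s (op=%v)", m.Type(), m.Op())
			return next.Mutate(ctx, m)
		})
	})
}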
@@ -68,15 +40,7 @@ func (pdd *PayloadDigestDelete) ExecX(ctx context.Context) int { } func (pdd *PayloadDigestDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: payloaddigest.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: payloaddigest.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(payloaddigest.Table, sqlgraph.NewFieldSpec(payloaddigest.FieldID, field.TypeInt)) if ps := pdd.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (pdd *PayloadDigestDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + pdd.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type PayloadDigestDeleteOne struct { pdd *PayloadDigestDelete } +// Where appends a list predicates to the PayloadDigestDelete builder. +func (pddo *PayloadDigestDeleteOne) Where(ps ...predicate.PayloadDigest) *PayloadDigestDeleteOne { + pddo.pdd.mutation.Where(ps...) + return pddo +} + // Exec executes the deletion query. func (pddo *PayloadDigestDeleteOne) Exec(ctx context.Context) error { n, err := pddo.pdd.Exec(ctx) @@ -111,5 +82,7 @@ func (pddo *PayloadDigestDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (pddo *PayloadDigestDeleteOne) ExecX(ctx context.Context) { - pddo.pdd.ExecX(ctx) + if err := pddo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/ent/payloaddigest_query.go b/ent/payloaddigest_query.go index c02c8a94..7ad8faf6 100644 --- a/ent/payloaddigest_query.go +++ b/ent/payloaddigest_query.go @@ -18,11 +18,9 @@ import ( // PayloadDigestQuery is the builder for querying PayloadDigest entities. type PayloadDigestQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string + ctx *QueryContext + order []payloaddigest.OrderOption + inters []Interceptor predicates []predicate.PayloadDigest withDsse *DsseQuery withFKs bool @@ -39,34 +37,34 @@ func (pdq *PayloadDigestQuery) Where(ps ...predicate.PayloadDigest) *PayloadDige return pdq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (pdq *PayloadDigestQuery) Limit(limit int) *PayloadDigestQuery { - pdq.limit = &limit + pdq.ctx.Limit = &limit return pdq } -// Offset adds an offset step to the query. +// Offset to start from. func (pdq *PayloadDigestQuery) Offset(offset int) *PayloadDigestQuery { - pdq.offset = &offset + pdq.ctx.Offset = &offset return pdq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (pdq *PayloadDigestQuery) Unique(unique bool) *PayloadDigestQuery { - pdq.unique = &unique + pdq.ctx.Unique = &unique return pdq } -// Order adds an order step to the query. -func (pdq *PayloadDigestQuery) Order(o ...OrderFunc) *PayloadDigestQuery { +// Order specifies how the records should be ordered. +func (pdq *PayloadDigestQuery) Order(o ...payloaddigest.OrderOption) *PayloadDigestQuery { pdq.order = append(pdq.order, o...) return pdq } // QueryDsse chains the current query on the "dsse" edge. 
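The new Where method on PayloadDigestDeleteOne (and, later in this diff, on PayloadDigestUpdateOne) lets a single-row operation carry extra guard predicates. A sketch, assuming an *ent.Client named client; deleteDigestIfSha256 and the id argument are illustrative:

import (
	"context"

	"github.com/testifysec/archivista/ent"
	"github.com/testifysec/archivista/ent/payloaddigest"
)

// deleteDigestIfSha256 deletes one payload digest by id, but only if it is a
// sha256 digest; if the predicate filters the row out, Exec returns a
// *NotFoundError.
func deleteDigestIfSha256(ctx context.Context, client *ent.Client, id int) error {
	return client.PayloadDigest.
		DeleteOneID(id).
		Where(payloaddigest.AlgorithmEQ("sha256")).
		Exec(ctx)
}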
func (pdq *PayloadDigestQuery) QueryDsse() *DsseQuery { - query := &DsseQuery{config: pdq.config} + query := (&DsseClient{config: pdq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := pdq.prepareQuery(ctx); err != nil { return nil, err @@ -89,7 +87,7 @@ func (pdq *PayloadDigestQuery) QueryDsse() *DsseQuery { // First returns the first PayloadDigest entity from the query. // Returns a *NotFoundError when no PayloadDigest was found. func (pdq *PayloadDigestQuery) First(ctx context.Context) (*PayloadDigest, error) { - nodes, err := pdq.Limit(1).All(ctx) + nodes, err := pdq.Limit(1).All(setContextOp(ctx, pdq.ctx, "First")) if err != nil { return nil, err } @@ -112,7 +110,7 @@ func (pdq *PayloadDigestQuery) FirstX(ctx context.Context) *PayloadDigest { // Returns a *NotFoundError when no PayloadDigest ID was found. func (pdq *PayloadDigestQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = pdq.Limit(1).IDs(ctx); err != nil { + if ids, err = pdq.Limit(1).IDs(setContextOp(ctx, pdq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -135,7 +133,7 @@ func (pdq *PayloadDigestQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one PayloadDigest entity is found. // Returns a *NotFoundError when no PayloadDigest entities are found. func (pdq *PayloadDigestQuery) Only(ctx context.Context) (*PayloadDigest, error) { - nodes, err := pdq.Limit(2).All(ctx) + nodes, err := pdq.Limit(2).All(setContextOp(ctx, pdq.ctx, "Only")) if err != nil { return nil, err } @@ -163,7 +161,7 @@ func (pdq *PayloadDigestQuery) OnlyX(ctx context.Context) *PayloadDigest { // Returns a *NotFoundError when no entities are found. func (pdq *PayloadDigestQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = pdq.Limit(2).IDs(ctx); err != nil { + if ids, err = pdq.Limit(2).IDs(setContextOp(ctx, pdq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -188,10 +186,12 @@ func (pdq *PayloadDigestQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of PayloadDigests. func (pdq *PayloadDigestQuery) All(ctx context.Context) ([]*PayloadDigest, error) { + ctx = setContextOp(ctx, pdq.ctx, "All") if err := pdq.prepareQuery(ctx); err != nil { return nil, err } - return pdq.sqlAll(ctx) + qr := querierAll[[]*PayloadDigest, *PayloadDigestQuery]() + return withInterceptors[[]*PayloadDigest](ctx, pdq, qr, pdq.inters) } // AllX is like All, but panics if an error occurs. @@ -204,9 +204,12 @@ func (pdq *PayloadDigestQuery) AllX(ctx context.Context) []*PayloadDigest { } // IDs executes the query and returns a list of PayloadDigest IDs. -func (pdq *PayloadDigestQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := pdq.Select(payloaddigest.FieldID).Scan(ctx, &ids); err != nil { +func (pdq *PayloadDigestQuery) IDs(ctx context.Context) (ids []int, err error) { + if pdq.ctx.Unique == nil && pdq.path != nil { + pdq.Unique(true) + } + ctx = setContextOp(ctx, pdq.ctx, "IDs") + if err = pdq.Select(payloaddigest.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -223,10 +226,11 @@ func (pdq *PayloadDigestQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. 
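The query builder now carries a QueryContext and an interceptor chain (setContextOp, withInterceptors, and the Traverser check in prepareQuery further down). A hedged sketch of registering a read-path interceptor, assuming the generated package re-exports ent v0.12's TraverseFunc/Query aliases and a client-level Intercept method, as stock codegen does; registerQueryLogger is an illustrative name:

import (
	"context"
	"log"

	"github.com/testifysec/archivista/ent"
)

// registerQueryLogger runs before every query is executed; this is the
// Traverser branch exercised by prepareQuery.
func registerQueryLogger(client *ent.Client) {
	client.Intercept(ent.TraverseFunc(func(ctx context.Context, q ent.Query) error {
		log.Printf("ent query: %T", q)
		return nil
	}))
}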
func (pdq *PayloadDigestQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, pdq.ctx, "Count") if err := pdq.prepareQuery(ctx); err != nil { return 0, err } - return pdq.sqlCount(ctx) + return withInterceptors[int](ctx, pdq, querierCount[*PayloadDigestQuery](), pdq.inters) } // CountX is like Count, but panics if an error occurs. @@ -240,10 +244,15 @@ func (pdq *PayloadDigestQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (pdq *PayloadDigestQuery) Exist(ctx context.Context) (bool, error) { - if err := pdq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, pdq.ctx, "Exist") + switch _, err := pdq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return pdq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -263,22 +272,21 @@ func (pdq *PayloadDigestQuery) Clone() *PayloadDigestQuery { } return &PayloadDigestQuery{ config: pdq.config, - limit: pdq.limit, - offset: pdq.offset, - order: append([]OrderFunc{}, pdq.order...), + ctx: pdq.ctx.Clone(), + order: append([]payloaddigest.OrderOption{}, pdq.order...), + inters: append([]Interceptor{}, pdq.inters...), predicates: append([]predicate.PayloadDigest{}, pdq.predicates...), withDsse: pdq.withDsse.Clone(), // clone intermediate query. - sql: pdq.sql.Clone(), - path: pdq.path, - unique: pdq.unique, + sql: pdq.sql.Clone(), + path: pdq.path, } } // WithDsse tells the query-builder to eager-load the nodes that are connected to // the "dsse" edge. The optional arguments are used to configure the query builder of the edge. func (pdq *PayloadDigestQuery) WithDsse(opts ...func(*DsseQuery)) *PayloadDigestQuery { - query := &DsseQuery{config: pdq.config} + query := (&DsseClient{config: pdq.config}).Query() for _, opt := range opts { opt(query) } @@ -301,16 +309,11 @@ func (pdq *PayloadDigestQuery) WithDsse(opts ...func(*DsseQuery)) *PayloadDigest // Aggregate(ent.Count()). // Scan(ctx, &v) func (pdq *PayloadDigestQuery) GroupBy(field string, fields ...string) *PayloadDigestGroupBy { - grbuild := &PayloadDigestGroupBy{config: pdq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := pdq.prepareQuery(ctx); err != nil { - return nil, err - } - return pdq.sqlQuery(ctx), nil - } + pdq.ctx.Fields = append([]string{field}, fields...) + grbuild := &PayloadDigestGroupBy{build: pdq} + grbuild.flds = &pdq.ctx.Fields grbuild.label = payloaddigest.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -327,15 +330,30 @@ func (pdq *PayloadDigestQuery) GroupBy(field string, fields ...string) *PayloadD // Select(payloaddigest.FieldAlgorithm). // Scan(ctx, &v) func (pdq *PayloadDigestQuery) Select(fields ...string) *PayloadDigestSelect { - pdq.fields = append(pdq.fields, fields...) - selbuild := &PayloadDigestSelect{PayloadDigestQuery: pdq} - selbuild.label = payloaddigest.Label - selbuild.flds, selbuild.scan = &pdq.fields, selbuild.Scan - return selbuild + pdq.ctx.Fields = append(pdq.ctx.Fields, fields...) + sbuild := &PayloadDigestSelect{PayloadDigestQuery: pdq} + sbuild.label = payloaddigest.Label + sbuild.flds, sbuild.scan = &pdq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a PayloadDigestSelect configured with the given aggregations. 
+func (pdq *PayloadDigestQuery) Aggregate(fns ...AggregateFunc) *PayloadDigestSelect { + return pdq.Select().Aggregate(fns...) } func (pdq *PayloadDigestQuery) prepareQuery(ctx context.Context) error { - for _, f := range pdq.fields { + for _, inter := range pdq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, pdq); err != nil { + return err + } + } + } + for _, f := range pdq.ctx.Fields { if !payloaddigest.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -413,6 +431,9 @@ func (pdq *PayloadDigestQuery) loadDsse(ctx context.Context, query *DsseQuery, n } nodeids[fk] = append(nodeids[fk], nodes[i]) } + if len(ids) == 0 { + return nil + } query.Where(dsse.IDIn(ids...)) neighbors, err := query.All(ctx) if err != nil { @@ -435,41 +456,22 @@ func (pdq *PayloadDigestQuery) sqlCount(ctx context.Context) (int, error) { if len(pdq.modifiers) > 0 { _spec.Modifiers = pdq.modifiers } - _spec.Node.Columns = pdq.fields - if len(pdq.fields) > 0 { - _spec.Unique = pdq.unique != nil && *pdq.unique + _spec.Node.Columns = pdq.ctx.Fields + if len(pdq.ctx.Fields) > 0 { + _spec.Unique = pdq.ctx.Unique != nil && *pdq.ctx.Unique } return sqlgraph.CountNodes(ctx, pdq.driver, _spec) } -func (pdq *PayloadDigestQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := pdq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (pdq *PayloadDigestQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: payloaddigest.Table, - Columns: payloaddigest.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: payloaddigest.FieldID, - }, - }, - From: pdq.sql, - Unique: true, - } - if unique := pdq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(payloaddigest.Table, payloaddigest.Columns, sqlgraph.NewFieldSpec(payloaddigest.FieldID, field.TypeInt)) + _spec.From = pdq.sql + if unique := pdq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if pdq.path != nil { + _spec.Unique = true } - if fields := pdq.fields; len(fields) > 0 { + if fields := pdq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, payloaddigest.FieldID) for i := range fields { @@ -485,10 +487,10 @@ func (pdq *PayloadDigestQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := pdq.limit; limit != nil { + if limit := pdq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := pdq.offset; offset != nil { + if offset := pdq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := pdq.order; len(ps) > 0 { @@ -504,7 +506,7 @@ func (pdq *PayloadDigestQuery) querySpec() *sqlgraph.QuerySpec { func (pdq *PayloadDigestQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(pdq.driver.Dialect()) t1 := builder.Table(payloaddigest.Table) - columns := pdq.fields + columns := pdq.ctx.Fields if len(columns) == 0 { columns = payloaddigest.Columns } @@ -513,7 +515,7 @@ func (pdq *PayloadDigestQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = pdq.sql selector.Select(selector.Columns(columns...)...) 
} - if pdq.unique != nil && *pdq.unique { + if pdq.ctx.Unique != nil && *pdq.ctx.Unique { selector.Distinct() } for _, p := range pdq.predicates { @@ -522,12 +524,12 @@ func (pdq *PayloadDigestQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range pdq.order { p(selector) } - if offset := pdq.offset; offset != nil { + if offset := pdq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := pdq.limit; limit != nil { + if limit := pdq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -535,13 +537,8 @@ func (pdq *PayloadDigestQuery) sqlQuery(ctx context.Context) *sql.Selector { // PayloadDigestGroupBy is the group-by builder for PayloadDigest entities. type PayloadDigestGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *PayloadDigestQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -550,74 +547,77 @@ func (pdgb *PayloadDigestGroupBy) Aggregate(fns ...AggregateFunc) *PayloadDigest return pdgb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. func (pdgb *PayloadDigestGroupBy) Scan(ctx context.Context, v any) error { - query, err := pdgb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, pdgb.build.ctx, "GroupBy") + if err := pdgb.build.prepareQuery(ctx); err != nil { return err } - pdgb.sql = query - return pdgb.sqlScan(ctx, v) + return scanWithInterceptors[*PayloadDigestQuery, *PayloadDigestGroupBy](ctx, pdgb.build, pdgb, pdgb.build.inters, v) } -func (pdgb *PayloadDigestGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range pdgb.fields { - if !payloaddigest.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (pdgb *PayloadDigestGroupBy) sqlScan(ctx context.Context, root *PayloadDigestQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(pdgb.fns)) + for _, fn := range pdgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*pdgb.flds)+len(pdgb.fns)) + for _, f := range *pdgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := pdgb.sqlQuery() + selector.GroupBy(selector.Columns(*pdgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := pdgb.driver.Query(ctx, query, args, rows); err != nil { + if err := pdgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (pdgb *PayloadDigestGroupBy) sqlQuery() *sql.Selector { - selector := pdgb.sql.Select() - aggregation := make([]string, 0, len(pdgb.fns)) - for _, fn := range pdgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. 
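GroupBy and Select now scan through scanWithInterceptors, and the query builder gains an Aggregate shortcut, but the calling code is unchanged from earlier ent versions. A sketch, assuming an *ent.Client named client; digestCounts and totalDigests are illustrative names:

import (
	"context"

	"github.com/testifysec/archivista/ent"
	"github.com/testifysec/archivista/ent/payloaddigest"
)

// digestCounts groups payload digests by algorithm and counts each group.
func digestCounts(ctx context.Context, client *ent.Client) (map[string]int, error) {
	var rows []struct {
		Algorithm string `json:"algorithm"`
		Count     int    `json:"count"`
	}
	err := client.PayloadDigest.Query().
		GroupBy(payloaddigest.FieldAlgorithm).
		Aggregate(ent.Count()).
		Scan(ctx, &rows)
	if err != nil {
		return nil, err
	}
	counts := make(map[string]int, len(rows))
	for _, r := range rows {
		counts[r.Algorithm] = r.Count
	}
	return counts, nil
}

// totalDigests uses the new Query().Aggregate helper for a whole-table count.
func totalDigests(ctx context.Context, client *ent.Client) (int, error) {
	return client.PayloadDigest.Query().Aggregate(ent.Count()).Int(ctx)
}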
- if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(pdgb.fields)+len(pdgb.fns)) - for _, f := range pdgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(pdgb.fields...)...) -} - // PayloadDigestSelect is the builder for selecting fields of PayloadDigest entities. type PayloadDigestSelect struct { *PayloadDigestQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (pds *PayloadDigestSelect) Aggregate(fns ...AggregateFunc) *PayloadDigestSelect { + pds.fns = append(pds.fns, fns...) + return pds } // Scan applies the selector query and scans the result into the given value. func (pds *PayloadDigestSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, pds.ctx, "Select") if err := pds.prepareQuery(ctx); err != nil { return err } - pds.sql = pds.PayloadDigestQuery.sqlQuery(ctx) - return pds.sqlScan(ctx, v) + return scanWithInterceptors[*PayloadDigestQuery, *PayloadDigestSelect](ctx, pds.PayloadDigestQuery, pds, pds.inters, v) } -func (pds *PayloadDigestSelect) sqlScan(ctx context.Context, v any) error { +func (pds *PayloadDigestSelect) sqlScan(ctx context.Context, root *PayloadDigestQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(pds.fns)) + for _, fn := range pds.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*pds.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := pds.sql.Query() + query, args := selector.Query() if err := pds.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/ent/payloaddigest_update.go b/ent/payloaddigest_update.go index 43e856cb..91da51ae 100644 --- a/ent/payloaddigest_update.go +++ b/ent/payloaddigest_update.go @@ -72,40 +72,7 @@ func (pdu *PayloadDigestUpdate) ClearDsse() *PayloadDigestUpdate { // Save executes the query and returns the number of nodes affected by the update operation. func (pdu *PayloadDigestUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(pdu.hooks) == 0 { - if err = pdu.check(); err != nil { - return 0, err - } - affected, err = pdu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*PayloadDigestMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = pdu.check(); err != nil { - return 0, err - } - pdu.mutation = mutation - affected, err = pdu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(pdu.hooks) - 1; i >= 0; i-- { - if pdu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = pdu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, pdu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, pdu.sqlSave, pdu.mutation, pdu.hooks) } // SaveX is like Save, but panics if an error occurs. 
@@ -146,16 +113,10 @@ func (pdu *PayloadDigestUpdate) check() error { } func (pdu *PayloadDigestUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: payloaddigest.Table, - Columns: payloaddigest.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: payloaddigest.FieldID, - }, - }, + if err := pdu.check(); err != nil { + return n, err } + _spec := sqlgraph.NewUpdateSpec(payloaddigest.Table, payloaddigest.Columns, sqlgraph.NewFieldSpec(payloaddigest.FieldID, field.TypeInt)) if ps := pdu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -164,18 +125,10 @@ func (pdu *PayloadDigestUpdate) sqlSave(ctx context.Context) (n int, err error) } } if value, ok := pdu.mutation.Algorithm(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: payloaddigest.FieldAlgorithm, - }) + _spec.SetField(payloaddigest.FieldAlgorithm, field.TypeString, value) } if value, ok := pdu.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: payloaddigest.FieldValue, - }) + _spec.SetField(payloaddigest.FieldValue, field.TypeString, value) } if pdu.mutation.DsseCleared() { edge := &sqlgraph.EdgeSpec{ @@ -185,10 +138,7 @@ func (pdu *PayloadDigestUpdate) sqlSave(ctx context.Context) (n int, err error) Columns: []string{payloaddigest.DsseColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: dsse.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(dsse.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -201,10 +151,7 @@ func (pdu *PayloadDigestUpdate) sqlSave(ctx context.Context) (n int, err error) Columns: []string{payloaddigest.DsseColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: dsse.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(dsse.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -220,6 +167,7 @@ func (pdu *PayloadDigestUpdate) sqlSave(ctx context.Context) (n int, err error) } return 0, err } + pdu.mutation.done = true return n, nil } @@ -273,6 +221,12 @@ func (pduo *PayloadDigestUpdateOne) ClearDsse() *PayloadDigestUpdateOne { return pduo } +// Where appends a list predicates to the PayloadDigestUpdate builder. +func (pduo *PayloadDigestUpdateOne) Where(ps ...predicate.PayloadDigest) *PayloadDigestUpdateOne { + pduo.mutation.Where(ps...) + return pduo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (pduo *PayloadDigestUpdateOne) Select(field string, fields ...string) *PayloadDigestUpdateOne { @@ -282,46 +236,7 @@ func (pduo *PayloadDigestUpdateOne) Select(field string, fields ...string) *Payl // Save executes the query and returns the updated PayloadDigest entity. 
func (pduo *PayloadDigestUpdateOne) Save(ctx context.Context) (*PayloadDigest, error) { - var ( - err error - node *PayloadDigest - ) - if len(pduo.hooks) == 0 { - if err = pduo.check(); err != nil { - return nil, err - } - node, err = pduo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*PayloadDigestMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = pduo.check(); err != nil { - return nil, err - } - pduo.mutation = mutation - node, err = pduo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(pduo.hooks) - 1; i >= 0; i-- { - if pduo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = pduo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, pduo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*PayloadDigest) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from PayloadDigestMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, pduo.sqlSave, pduo.mutation, pduo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -362,16 +277,10 @@ func (pduo *PayloadDigestUpdateOne) check() error { } func (pduo *PayloadDigestUpdateOne) sqlSave(ctx context.Context) (_node *PayloadDigest, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: payloaddigest.Table, - Columns: payloaddigest.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: payloaddigest.FieldID, - }, - }, + if err := pduo.check(); err != nil { + return _node, err } + _spec := sqlgraph.NewUpdateSpec(payloaddigest.Table, payloaddigest.Columns, sqlgraph.NewFieldSpec(payloaddigest.FieldID, field.TypeInt)) id, ok := pduo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "PayloadDigest.id" for update`)} @@ -397,18 +306,10 @@ func (pduo *PayloadDigestUpdateOne) sqlSave(ctx context.Context) (_node *Payload } } if value, ok := pduo.mutation.Algorithm(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: payloaddigest.FieldAlgorithm, - }) + _spec.SetField(payloaddigest.FieldAlgorithm, field.TypeString, value) } if value, ok := pduo.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: payloaddigest.FieldValue, - }) + _spec.SetField(payloaddigest.FieldValue, field.TypeString, value) } if pduo.mutation.DsseCleared() { edge := &sqlgraph.EdgeSpec{ @@ -418,10 +319,7 @@ func (pduo *PayloadDigestUpdateOne) sqlSave(ctx context.Context) (_node *Payload Columns: []string{payloaddigest.DsseColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: dsse.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(dsse.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -434,10 +332,7 @@ func (pduo *PayloadDigestUpdateOne) sqlSave(ctx context.Context) (_node *Payload Columns: []string{payloaddigest.DsseColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: dsse.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(dsse.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -456,5 +351,6 @@ func (pduo *PayloadDigestUpdateOne) sqlSave(ctx context.Context) (_node *Payload } return nil, err } + pduo.mutation.done = 
true return _node, nil } diff --git a/ent/runtime/runtime.go b/ent/runtime/runtime.go index e7b13943..e1a2fbd9 100644 --- a/ent/runtime/runtime.go +++ b/ent/runtime/runtime.go @@ -5,6 +5,6 @@ package runtime // The schema-stitching logic is generated in github.com/testifysec/archivista/ent/runtime.go const ( - Version = "v0.11.3" // Version of ent codegen. - Sum = "h1:F5FBGAWiDCGder7YT+lqMnyzXl6d0xU3xMBM/SO3CMc=" // Sum of ent codegen. + Version = "v0.12.4" // Version of ent codegen. + Sum = "h1:LddPnAyxls/O7DTXZvUGDj0NZIdGSu317+aoNLJWbD8=" // Sum of ent codegen. ) diff --git a/ent/signature.go b/ent/signature.go index 5be8dd9d..38738a3a 100644 --- a/ent/signature.go +++ b/ent/signature.go @@ -6,6 +6,7 @@ import ( "fmt" "strings" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/testifysec/archivista/ent/dsse" "github.com/testifysec/archivista/ent/signature" @@ -24,6 +25,7 @@ type Signature struct { // The values are being populated by the SignatureQuery when eager-loading is set. Edges SignatureEdges `json:"edges"` dsse_signatures *int + selectValues sql.SelectValues } // SignatureEdges holds the relations/edges for other nodes in the graph. @@ -36,7 +38,9 @@ type SignatureEdges struct { // type was loaded (or requested) in eager-loading or not. loadedTypes [2]bool // totalCount holds the count of the edges above. - totalCount [2]*int + totalCount [2]map[string]int + + namedTimestamps map[string][]*Timestamp } // DsseOrErr returns the Dsse value or an error if the edge @@ -73,7 +77,7 @@ func (*Signature) scanValues(columns []string) ([]any, error) { case signature.ForeignKeys[0]: // dsse_signatures values[i] = new(sql.NullInt64) default: - return nil, fmt.Errorf("unexpected column %q for type Signature", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -112,26 +116,34 @@ func (s *Signature) assignValues(columns []string, values []any) error { s.dsse_signatures = new(int) *s.dsse_signatures = int(value.Int64) } + default: + s.selectValues.Set(columns[i], values[i]) } } return nil } +// Value returns the ent.Value that was dynamically selected and assigned to the Signature. +// This includes values selected through modifiers, order, etc. +func (s *Signature) Value(name string) (ent.Value, error) { + return s.selectValues.Get(name) +} + // QueryDsse queries the "dsse" edge of the Signature entity. func (s *Signature) QueryDsse() *DsseQuery { - return (&SignatureClient{config: s.config}).QueryDsse(s) + return NewSignatureClient(s.config).QueryDsse(s) } // QueryTimestamps queries the "timestamps" edge of the Signature entity. func (s *Signature) QueryTimestamps() *TimestampQuery { - return (&SignatureClient{config: s.config}).QueryTimestamps(s) + return NewSignatureClient(s.config).QueryTimestamps(s) } // Update returns a builder for updating this Signature. // Note that you need to call Signature.Unwrap() before calling this method if this Signature // was returned from a transaction, and the transaction was committed or rolled back. func (s *Signature) Update() *SignatureUpdateOne { - return (&SignatureClient{config: s.config}).UpdateOne(s) + return NewSignatureClient(s.config).UpdateOne(s) } // Unwrap unwraps the Signature entity that was returned from a transaction after it was closed, @@ -159,11 +171,29 @@ func (s *Signature) String() string { return builder.String() } -// Signatures is a parsable slice of Signature. 
-type Signatures []*Signature +// NamedTimestamps returns the Timestamps named value or an error if the edge was not +// loaded in eager-loading with this name. +func (s *Signature) NamedTimestamps(name string) ([]*Timestamp, error) { + if s.Edges.namedTimestamps == nil { + return nil, &NotLoadedError{edge: name} + } + nodes, ok := s.Edges.namedTimestamps[name] + if !ok { + return nil, &NotLoadedError{edge: name} + } + return nodes, nil +} -func (s Signatures) config(cfg config) { - for _i := range s { - s[_i].config = cfg +func (s *Signature) appendNamedTimestamps(name string, edges ...*Timestamp) { + if s.Edges.namedTimestamps == nil { + s.Edges.namedTimestamps = make(map[string][]*Timestamp) + } + if len(edges) == 0 { + s.Edges.namedTimestamps[name] = []*Timestamp{} + } else { + s.Edges.namedTimestamps[name] = append(s.Edges.namedTimestamps[name], edges...) } } + +// Signatures is a parsable slice of Signature. +type Signatures []*Signature diff --git a/ent/signature/signature.go b/ent/signature/signature.go index dea255a2..8e7d97f6 100644 --- a/ent/signature/signature.go +++ b/ent/signature/signature.go @@ -2,6 +2,11 @@ package signature +import ( + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + const ( // Label holds the string label denoting the signature type in the database. Label = "signature" @@ -67,3 +72,56 @@ var ( // SignatureValidator is a validator for the "signature" field. It is called by the builders before save. SignatureValidator func(string) error ) + +// OrderOption defines the ordering options for the Signature queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByKeyID orders the results by the key_id field. +func ByKeyID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldKeyID, opts...).ToFunc() +} + +// BySignature orders the results by the signature field. +func BySignature(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSignature, opts...).ToFunc() +} + +// ByDsseField orders the results by dsse field. +func ByDsseField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newDsseStep(), sql.OrderByField(field, opts...)) + } +} + +// ByTimestampsCount orders the results by timestamps count. +func ByTimestampsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newTimestampsStep(), opts...) + } +} + +// ByTimestamps orders the results by timestamps terms. +func ByTimestamps(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newTimestampsStep(), append([]sql.OrderTerm{term}, terms...)...) 
+ } +} +func newDsseStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(DsseInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, DsseTable, DsseColumn), + ) +} +func newTimestampsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(TimestampsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, TimestampsTable, TimestampsColumn), + ) +} diff --git a/ent/signature/where.go b/ent/signature/where.go index 5209ca94..451ed067 100644 --- a/ent/signature/where.go +++ b/ent/signature/where.go @@ -10,285 +10,187 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Signature(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Signature(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Signature(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Signature(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Signature(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Signature(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Signature(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Signature(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Signature(sql.FieldLTE(FieldID, id)) } // KeyID applies equality check predicate on the "key_id" field. It's identical to KeyIDEQ. func KeyID(v string) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldKeyID), v)) - }) + return predicate.Signature(sql.FieldEQ(FieldKeyID, v)) } // Signature applies equality check predicate on the "signature" field. It's identical to SignatureEQ. 
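Signature picks up the same OrderOption treatment, plus relation-aware ordering by timestamp count and the NamedTimestamps accessors that back entgql's named-edge loading (the WithNamedTimestamps counterpart on SignatureQuery is not shown in this hunk). A sketch of the count-based ordering, assuming an *ent.Client named client; signaturesByTimestampCount is an illustrative name:

import (
	"context"

	"entgo.io/ent/dialect/sql"

	"github.com/testifysec/archivista/ent"
	"github.com/testifysec/archivista/ent/signature"
)

// signaturesByTimestampCount returns signatures ordered by how many timestamps
// each one carries, most-timestamped first.
func signaturesByTimestampCount(ctx context.Context, client *ent.Client) ([]*ent.Signature, error) {
	return client.Signature.Query().
		Order(signature.ByTimestampsCount(sql.OrderDesc())).
		All(ctx)
}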
func Signature(v string) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSignature), v)) - }) + return predicate.Signature(sql.FieldEQ(FieldSignature, v)) } // KeyIDEQ applies the EQ predicate on the "key_id" field. func KeyIDEQ(v string) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldKeyID), v)) - }) + return predicate.Signature(sql.FieldEQ(FieldKeyID, v)) } // KeyIDNEQ applies the NEQ predicate on the "key_id" field. func KeyIDNEQ(v string) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldKeyID), v)) - }) + return predicate.Signature(sql.FieldNEQ(FieldKeyID, v)) } // KeyIDIn applies the In predicate on the "key_id" field. func KeyIDIn(vs ...string) predicate.Signature { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldKeyID), v...)) - }) + return predicate.Signature(sql.FieldIn(FieldKeyID, vs...)) } // KeyIDNotIn applies the NotIn predicate on the "key_id" field. func KeyIDNotIn(vs ...string) predicate.Signature { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldKeyID), v...)) - }) + return predicate.Signature(sql.FieldNotIn(FieldKeyID, vs...)) } // KeyIDGT applies the GT predicate on the "key_id" field. func KeyIDGT(v string) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldKeyID), v)) - }) + return predicate.Signature(sql.FieldGT(FieldKeyID, v)) } // KeyIDGTE applies the GTE predicate on the "key_id" field. func KeyIDGTE(v string) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldKeyID), v)) - }) + return predicate.Signature(sql.FieldGTE(FieldKeyID, v)) } // KeyIDLT applies the LT predicate on the "key_id" field. func KeyIDLT(v string) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldKeyID), v)) - }) + return predicate.Signature(sql.FieldLT(FieldKeyID, v)) } // KeyIDLTE applies the LTE predicate on the "key_id" field. func KeyIDLTE(v string) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldKeyID), v)) - }) + return predicate.Signature(sql.FieldLTE(FieldKeyID, v)) } // KeyIDContains applies the Contains predicate on the "key_id" field. func KeyIDContains(v string) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldKeyID), v)) - }) + return predicate.Signature(sql.FieldContains(FieldKeyID, v)) } // KeyIDHasPrefix applies the HasPrefix predicate on the "key_id" field. func KeyIDHasPrefix(v string) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldKeyID), v)) - }) + return predicate.Signature(sql.FieldHasPrefix(FieldKeyID, v)) } // KeyIDHasSuffix applies the HasSuffix predicate on the "key_id" field. func KeyIDHasSuffix(v string) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldKeyID), v)) - }) + return predicate.Signature(sql.FieldHasSuffix(FieldKeyID, v)) } // KeyIDEqualFold applies the EqualFold predicate on the "key_id" field. 
func KeyIDEqualFold(v string) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldKeyID), v)) - }) + return predicate.Signature(sql.FieldEqualFold(FieldKeyID, v)) } // KeyIDContainsFold applies the ContainsFold predicate on the "key_id" field. func KeyIDContainsFold(v string) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldKeyID), v)) - }) + return predicate.Signature(sql.FieldContainsFold(FieldKeyID, v)) } // SignatureEQ applies the EQ predicate on the "signature" field. func SignatureEQ(v string) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldSignature), v)) - }) + return predicate.Signature(sql.FieldEQ(FieldSignature, v)) } // SignatureNEQ applies the NEQ predicate on the "signature" field. func SignatureNEQ(v string) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldSignature), v)) - }) + return predicate.Signature(sql.FieldNEQ(FieldSignature, v)) } // SignatureIn applies the In predicate on the "signature" field. func SignatureIn(vs ...string) predicate.Signature { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldSignature), v...)) - }) + return predicate.Signature(sql.FieldIn(FieldSignature, vs...)) } // SignatureNotIn applies the NotIn predicate on the "signature" field. func SignatureNotIn(vs ...string) predicate.Signature { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldSignature), v...)) - }) + return predicate.Signature(sql.FieldNotIn(FieldSignature, vs...)) } // SignatureGT applies the GT predicate on the "signature" field. func SignatureGT(v string) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldSignature), v)) - }) + return predicate.Signature(sql.FieldGT(FieldSignature, v)) } // SignatureGTE applies the GTE predicate on the "signature" field. func SignatureGTE(v string) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldSignature), v)) - }) + return predicate.Signature(sql.FieldGTE(FieldSignature, v)) } // SignatureLT applies the LT predicate on the "signature" field. func SignatureLT(v string) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldSignature), v)) - }) + return predicate.Signature(sql.FieldLT(FieldSignature, v)) } // SignatureLTE applies the LTE predicate on the "signature" field. func SignatureLTE(v string) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldSignature), v)) - }) + return predicate.Signature(sql.FieldLTE(FieldSignature, v)) } // SignatureContains applies the Contains predicate on the "signature" field. func SignatureContains(v string) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldSignature), v)) - }) + return predicate.Signature(sql.FieldContains(FieldSignature, v)) } // SignatureHasPrefix applies the HasPrefix predicate on the "signature" field. 
func SignatureHasPrefix(v string) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldSignature), v)) - }) + return predicate.Signature(sql.FieldHasPrefix(FieldSignature, v)) } // SignatureHasSuffix applies the HasSuffix predicate on the "signature" field. func SignatureHasSuffix(v string) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldSignature), v)) - }) + return predicate.Signature(sql.FieldHasSuffix(FieldSignature, v)) } // SignatureEqualFold applies the EqualFold predicate on the "signature" field. func SignatureEqualFold(v string) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldSignature), v)) - }) + return predicate.Signature(sql.FieldEqualFold(FieldSignature, v)) } // SignatureContainsFold applies the ContainsFold predicate on the "signature" field. func SignatureContainsFold(v string) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldSignature), v)) - }) + return predicate.Signature(sql.FieldContainsFold(FieldSignature, v)) } // HasDsse applies the HasEdge predicate on the "dsse" edge. @@ -296,7 +198,6 @@ func HasDsse() predicate.Signature { return predicate.Signature(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(DsseTable, FieldID), sqlgraph.Edge(sqlgraph.M2O, true, DsseTable, DsseColumn), ) sqlgraph.HasNeighbors(s, step) @@ -306,11 +207,7 @@ func HasDsse() predicate.Signature { // HasDsseWith applies the HasEdge predicate on the "dsse" edge with a given conditions (other predicates). func HasDsseWith(preds ...predicate.Dsse) predicate.Signature { return predicate.Signature(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(DsseInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.M2O, true, DsseTable, DsseColumn), - ) + step := newDsseStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -324,7 +221,6 @@ func HasTimestamps() predicate.Signature { return predicate.Signature(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(TimestampsTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, TimestampsTable, TimestampsColumn), ) sqlgraph.HasNeighbors(s, step) @@ -334,11 +230,7 @@ func HasTimestamps() predicate.Signature { // HasTimestampsWith applies the HasEdge predicate on the "timestamps" edge with a given conditions (other predicates). func HasTimestampsWith(preds ...predicate.Timestamp) predicate.Signature { return predicate.Signature(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(TimestampsInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.O2M, false, TimestampsTable, TimestampsColumn), - ) + step := newTimestampsStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -349,32 +241,15 @@ func HasTimestampsWith(preds ...predicate.Timestamp) predicate.Signature { // And groups predicates with the AND operator between them. func And(predicates ...predicate.Signature) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Signature(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. 
func Or(predicates ...predicate.Signature) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Signature(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.Signature) predicate.Signature { - return predicate.Signature(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.Signature(sql.NotPredicates(p)) } diff --git a/ent/signature_create.go b/ent/signature_create.go index 565a7831..8ee21e81 100644 --- a/ent/signature_create.go +++ b/ent/signature_create.go @@ -74,49 +74,7 @@ func (sc *SignatureCreate) Mutation() *SignatureMutation { // Save creates the Signature in the database. func (sc *SignatureCreate) Save(ctx context.Context) (*Signature, error) { - var ( - err error - node *Signature - ) - if len(sc.hooks) == 0 { - if err = sc.check(); err != nil { - return nil, err - } - node, err = sc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*SignatureMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = sc.check(); err != nil { - return nil, err - } - sc.mutation = mutation - if node, err = sc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(sc.hooks) - 1; i >= 0; i-- { - if sc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = sc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, sc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Signature) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from SignatureMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, sc.sqlSave, sc.mutation, sc.hooks) } // SaveX calls Save and panics if Save returns an error. 
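Reviewer note: the long inlined hook/mutator loop in SignatureCreate.Save collapses into the generated withHooks helper, so hooks registered on the client still wrap sqlSave in the same order as before; only the plumbing moved. A minimal caller-side sketch, assuming an initialized *ent.Client named client (the hook body and field values are hypothetical, for illustration only):

import (
	"context"

	"github.com/testifysec/archivista/ent"
	"github.com/testifysec/archivista/ent/hook"
)

// createSignature registers an illustrative hook and then creates a row.
// The builder API is unchanged by this diff; Save simply routes through withHooks.
func createSignature(ctx context.Context, client *ent.Client) (*ent.Signature, error) {
	client.Signature.Use(func(next ent.Mutator) ent.Mutator {
		return hook.SignatureFunc(func(ctx context.Context, m *ent.SignatureMutation) (ent.Value, error) {
			// Runs before sqlSave, matching the behavior of the removed hand-rolled loop.
			return next.Mutate(ctx, m)
		})
	})
	return client.Signature.Create().
		SetKeyID("key-123").
		SetSignature("c2lnbmF0dXJlLWJ5dGVz").
		Save(ctx)
}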
@@ -163,6 +121,9 @@ func (sc *SignatureCreate) check() error { } func (sc *SignatureCreate) sqlSave(ctx context.Context) (*Signature, error) { + if err := sc.check(); err != nil { + return nil, err + } _node, _spec := sc.createSpec() if err := sqlgraph.CreateNode(ctx, sc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -172,34 +133,22 @@ func (sc *SignatureCreate) sqlSave(ctx context.Context) (*Signature, error) { } id := _spec.ID.Value.(int64) _node.ID = int(id) + sc.mutation.id = &_node.ID + sc.mutation.done = true return _node, nil } func (sc *SignatureCreate) createSpec() (*Signature, *sqlgraph.CreateSpec) { var ( _node = &Signature{config: sc.config} - _spec = &sqlgraph.CreateSpec{ - Table: signature.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: signature.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(signature.Table, sqlgraph.NewFieldSpec(signature.FieldID, field.TypeInt)) ) if value, ok := sc.mutation.KeyID(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: signature.FieldKeyID, - }) + _spec.SetField(signature.FieldKeyID, field.TypeString, value) _node.KeyID = value } if value, ok := sc.mutation.Signature(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: signature.FieldSignature, - }) + _spec.SetField(signature.FieldSignature, field.TypeString, value) _node.Signature = value } if nodes := sc.mutation.DsseIDs(); len(nodes) > 0 { @@ -210,10 +159,7 @@ func (sc *SignatureCreate) createSpec() (*Signature, *sqlgraph.CreateSpec) { Columns: []string{signature.DsseColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: dsse.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(dsse.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -230,10 +176,7 @@ func (sc *SignatureCreate) createSpec() (*Signature, *sqlgraph.CreateSpec) { Columns: []string{signature.TimestampsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: timestamp.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(timestamp.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -247,11 +190,15 @@ func (sc *SignatureCreate) createSpec() (*Signature, *sqlgraph.CreateSpec) { // SignatureCreateBulk is the builder for creating many Signature entities in bulk. type SignatureCreateBulk struct { config + err error builders []*SignatureCreate } // Save creates the Signature entities in the database. 
func (scb *SignatureCreateBulk) Save(ctx context.Context) ([]*Signature, error) { + if scb.err != nil { + return nil, scb.err + } specs := make([]*sqlgraph.CreateSpec, len(scb.builders)) nodes := make([]*Signature, len(scb.builders)) mutators := make([]Mutator, len(scb.builders)) @@ -267,8 +214,8 @@ func (scb *SignatureCreateBulk) Save(ctx context.Context) ([]*Signature, error) return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, scb.builders[i+1].mutation) } else { diff --git a/ent/signature_delete.go b/ent/signature_delete.go index 7be7002b..eb1a7100 100644 --- a/ent/signature_delete.go +++ b/ent/signature_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (sd *SignatureDelete) Where(ps ...predicate.Signature) *SignatureDelete { // Exec executes the deletion query and returns how many vertices were deleted. func (sd *SignatureDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(sd.hooks) == 0 { - affected, err = sd.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*SignatureMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - sd.mutation = mutation - affected, err = sd.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(sd.hooks) - 1; i >= 0; i-- { - if sd.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = sd.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, sd.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, sd.sqlExec, sd.mutation, sd.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (sd *SignatureDelete) ExecX(ctx context.Context) int { } func (sd *SignatureDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: signature.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: signature.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(signature.Table, sqlgraph.NewFieldSpec(signature.FieldID, field.TypeInt)) if ps := sd.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (sd *SignatureDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + sd.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type SignatureDeleteOne struct { sd *SignatureDelete } +// Where appends a list predicates to the SignatureDelete builder. +func (sdo *SignatureDeleteOne) Where(ps ...predicate.Signature) *SignatureDeleteOne { + sdo.sd.mutation.Where(ps...) + return sdo +} + // Exec executes the deletion query. func (sdo *SignatureDeleteOne) Exec(ctx context.Context) error { n, err := sdo.sd.Exec(ctx) @@ -111,5 +82,7 @@ func (sdo *SignatureDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. 
func (sdo *SignatureDeleteOne) ExecX(ctx context.Context) { - sdo.sd.ExecX(ctx) + if err := sdo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/ent/signature_query.go b/ent/signature_query.go index f3707d5f..275de31e 100644 --- a/ent/signature_query.go +++ b/ent/signature_query.go @@ -20,17 +20,16 @@ import ( // SignatureQuery is the builder for querying Signature entities. type SignatureQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string - predicates []predicate.Signature - withDsse *DsseQuery - withTimestamps *TimestampQuery - withFKs bool - modifiers []func(*sql.Selector) - loadTotal []func(context.Context, []*Signature) error + ctx *QueryContext + order []signature.OrderOption + inters []Interceptor + predicates []predicate.Signature + withDsse *DsseQuery + withTimestamps *TimestampQuery + withFKs bool + modifiers []func(*sql.Selector) + loadTotal []func(context.Context, []*Signature) error + withNamedTimestamps map[string]*TimestampQuery // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) @@ -42,34 +41,34 @@ func (sq *SignatureQuery) Where(ps ...predicate.Signature) *SignatureQuery { return sq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (sq *SignatureQuery) Limit(limit int) *SignatureQuery { - sq.limit = &limit + sq.ctx.Limit = &limit return sq } -// Offset adds an offset step to the query. +// Offset to start from. func (sq *SignatureQuery) Offset(offset int) *SignatureQuery { - sq.offset = &offset + sq.ctx.Offset = &offset return sq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (sq *SignatureQuery) Unique(unique bool) *SignatureQuery { - sq.unique = &unique + sq.ctx.Unique = &unique return sq } -// Order adds an order step to the query. -func (sq *SignatureQuery) Order(o ...OrderFunc) *SignatureQuery { +// Order specifies how the records should be ordered. +func (sq *SignatureQuery) Order(o ...signature.OrderOption) *SignatureQuery { sq.order = append(sq.order, o...) return sq } // QueryDsse chains the current query on the "dsse" edge. func (sq *SignatureQuery) QueryDsse() *DsseQuery { - query := &DsseQuery{config: sq.config} + query := (&DsseClient{config: sq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := sq.prepareQuery(ctx); err != nil { return nil, err @@ -91,7 +90,7 @@ func (sq *SignatureQuery) QueryDsse() *DsseQuery { // QueryTimestamps chains the current query on the "timestamps" edge. func (sq *SignatureQuery) QueryTimestamps() *TimestampQuery { - query := &TimestampQuery{config: sq.config} + query := (&TimestampClient{config: sq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := sq.prepareQuery(ctx); err != nil { return nil, err @@ -114,7 +113,7 @@ func (sq *SignatureQuery) QueryTimestamps() *TimestampQuery { // First returns the first Signature entity from the query. // Returns a *NotFoundError when no Signature was found. 
func (sq *SignatureQuery) First(ctx context.Context) (*Signature, error) { - nodes, err := sq.Limit(1).All(ctx) + nodes, err := sq.Limit(1).All(setContextOp(ctx, sq.ctx, "First")) if err != nil { return nil, err } @@ -137,7 +136,7 @@ func (sq *SignatureQuery) FirstX(ctx context.Context) *Signature { // Returns a *NotFoundError when no Signature ID was found. func (sq *SignatureQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = sq.Limit(1).IDs(ctx); err != nil { + if ids, err = sq.Limit(1).IDs(setContextOp(ctx, sq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -160,7 +159,7 @@ func (sq *SignatureQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Signature entity is found. // Returns a *NotFoundError when no Signature entities are found. func (sq *SignatureQuery) Only(ctx context.Context) (*Signature, error) { - nodes, err := sq.Limit(2).All(ctx) + nodes, err := sq.Limit(2).All(setContextOp(ctx, sq.ctx, "Only")) if err != nil { return nil, err } @@ -188,7 +187,7 @@ func (sq *SignatureQuery) OnlyX(ctx context.Context) *Signature { // Returns a *NotFoundError when no entities are found. func (sq *SignatureQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = sq.Limit(2).IDs(ctx); err != nil { + if ids, err = sq.Limit(2).IDs(setContextOp(ctx, sq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -213,10 +212,12 @@ func (sq *SignatureQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Signatures. func (sq *SignatureQuery) All(ctx context.Context) ([]*Signature, error) { + ctx = setContextOp(ctx, sq.ctx, "All") if err := sq.prepareQuery(ctx); err != nil { return nil, err } - return sq.sqlAll(ctx) + qr := querierAll[[]*Signature, *SignatureQuery]() + return withInterceptors[[]*Signature](ctx, sq, qr, sq.inters) } // AllX is like All, but panics if an error occurs. @@ -229,9 +230,12 @@ func (sq *SignatureQuery) AllX(ctx context.Context) []*Signature { } // IDs executes the query and returns a list of Signature IDs. -func (sq *SignatureQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := sq.Select(signature.FieldID).Scan(ctx, &ids); err != nil { +func (sq *SignatureQuery) IDs(ctx context.Context) (ids []int, err error) { + if sq.ctx.Unique == nil && sq.path != nil { + sq.Unique(true) + } + ctx = setContextOp(ctx, sq.ctx, "IDs") + if err = sq.Select(signature.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -248,10 +252,11 @@ func (sq *SignatureQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (sq *SignatureQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, sq.ctx, "Count") if err := sq.prepareQuery(ctx); err != nil { return 0, err } - return sq.sqlCount(ctx) + return withInterceptors[int](ctx, sq, querierCount[*SignatureQuery](), sq.inters) } // CountX is like Count, but panics if an error occurs. @@ -265,10 +270,15 @@ func (sq *SignatureQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. 
func (sq *SignatureQuery) Exist(ctx context.Context) (bool, error) { - if err := sq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, sq.ctx, "Exist") + switch _, err := sq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return sq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -288,23 +298,22 @@ func (sq *SignatureQuery) Clone() *SignatureQuery { } return &SignatureQuery{ config: sq.config, - limit: sq.limit, - offset: sq.offset, - order: append([]OrderFunc{}, sq.order...), + ctx: sq.ctx.Clone(), + order: append([]signature.OrderOption{}, sq.order...), + inters: append([]Interceptor{}, sq.inters...), predicates: append([]predicate.Signature{}, sq.predicates...), withDsse: sq.withDsse.Clone(), withTimestamps: sq.withTimestamps.Clone(), // clone intermediate query. - sql: sq.sql.Clone(), - path: sq.path, - unique: sq.unique, + sql: sq.sql.Clone(), + path: sq.path, } } // WithDsse tells the query-builder to eager-load the nodes that are connected to // the "dsse" edge. The optional arguments are used to configure the query builder of the edge. func (sq *SignatureQuery) WithDsse(opts ...func(*DsseQuery)) *SignatureQuery { - query := &DsseQuery{config: sq.config} + query := (&DsseClient{config: sq.config}).Query() for _, opt := range opts { opt(query) } @@ -315,7 +324,7 @@ func (sq *SignatureQuery) WithDsse(opts ...func(*DsseQuery)) *SignatureQuery { // WithTimestamps tells the query-builder to eager-load the nodes that are connected to // the "timestamps" edge. The optional arguments are used to configure the query builder of the edge. func (sq *SignatureQuery) WithTimestamps(opts ...func(*TimestampQuery)) *SignatureQuery { - query := &TimestampQuery{config: sq.config} + query := (&TimestampClient{config: sq.config}).Query() for _, opt := range opts { opt(query) } @@ -338,16 +347,11 @@ func (sq *SignatureQuery) WithTimestamps(opts ...func(*TimestampQuery)) *Signatu // Aggregate(ent.Count()). // Scan(ctx, &v) func (sq *SignatureQuery) GroupBy(field string, fields ...string) *SignatureGroupBy { - grbuild := &SignatureGroupBy{config: sq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := sq.prepareQuery(ctx); err != nil { - return nil, err - } - return sq.sqlQuery(ctx), nil - } + sq.ctx.Fields = append([]string{field}, fields...) + grbuild := &SignatureGroupBy{build: sq} + grbuild.flds = &sq.ctx.Fields grbuild.label = signature.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -364,15 +368,30 @@ func (sq *SignatureQuery) GroupBy(field string, fields ...string) *SignatureGrou // Select(signature.FieldKeyID). // Scan(ctx, &v) func (sq *SignatureQuery) Select(fields ...string) *SignatureSelect { - sq.fields = append(sq.fields, fields...) - selbuild := &SignatureSelect{SignatureQuery: sq} - selbuild.label = signature.Label - selbuild.flds, selbuild.scan = &sq.fields, selbuild.Scan - return selbuild + sq.ctx.Fields = append(sq.ctx.Fields, fields...) + sbuild := &SignatureSelect{SignatureQuery: sq} + sbuild.label = signature.Label + sbuild.flds, sbuild.scan = &sq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a SignatureSelect configured with the given aggregations. 
+func (sq *SignatureQuery) Aggregate(fns ...AggregateFunc) *SignatureSelect { + return sq.Select().Aggregate(fns...) } func (sq *SignatureQuery) prepareQuery(ctx context.Context) error { - for _, f := range sq.fields { + for _, inter := range sq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, sq); err != nil { + return err + } + } + } + for _, f := range sq.ctx.Fields { if !signature.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -437,6 +456,13 @@ func (sq *SignatureQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Si return nil, err } } + for name, query := range sq.withNamedTimestamps { + if err := sq.loadTimestamps(ctx, query, nodes, + func(n *Signature) { n.appendNamedTimestamps(name) }, + func(n *Signature, e *Timestamp) { n.appendNamedTimestamps(name, e) }); err != nil { + return nil, err + } + } for i := range sq.loadTotal { if err := sq.loadTotal[i](ctx, nodes); err != nil { return nil, err @@ -458,6 +484,9 @@ func (sq *SignatureQuery) loadDsse(ctx context.Context, query *DsseQuery, nodes } nodeids[fk] = append(nodeids[fk], nodes[i]) } + if len(ids) == 0 { + return nil + } query.Where(dsse.IDIn(ids...)) neighbors, err := query.All(ctx) if err != nil { @@ -486,7 +515,7 @@ func (sq *SignatureQuery) loadTimestamps(ctx context.Context, query *TimestampQu } query.withFKs = true query.Where(predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.InValues(signature.TimestampsColumn, fks...)) + s.Where(sql.InValues(s.C(signature.TimestampsColumn), fks...)) })) neighbors, err := query.All(ctx) if err != nil { @@ -499,7 +528,7 @@ func (sq *SignatureQuery) loadTimestamps(ctx context.Context, query *TimestampQu } node, ok := nodeids[*fk] if !ok { - return fmt.Errorf(`unexpected foreign-key "signature_timestamps" returned %v for node %v`, *fk, n.ID) + return fmt.Errorf(`unexpected referenced foreign-key "signature_timestamps" returned %v for node %v`, *fk, n.ID) } assign(node, n) } @@ -511,41 +540,22 @@ func (sq *SignatureQuery) sqlCount(ctx context.Context) (int, error) { if len(sq.modifiers) > 0 { _spec.Modifiers = sq.modifiers } - _spec.Node.Columns = sq.fields - if len(sq.fields) > 0 { - _spec.Unique = sq.unique != nil && *sq.unique + _spec.Node.Columns = sq.ctx.Fields + if len(sq.ctx.Fields) > 0 { + _spec.Unique = sq.ctx.Unique != nil && *sq.ctx.Unique } return sqlgraph.CountNodes(ctx, sq.driver, _spec) } -func (sq *SignatureQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := sq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (sq *SignatureQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: signature.Table, - Columns: signature.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: signature.FieldID, - }, - }, - From: sq.sql, - Unique: true, - } - if unique := sq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(signature.Table, signature.Columns, sqlgraph.NewFieldSpec(signature.FieldID, field.TypeInt)) + _spec.From = sq.sql + if unique := sq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if sq.path != nil { + _spec.Unique = true } - if fields := sq.fields; len(fields) > 0 { + if fields := sq.ctx.Fields; len(fields) > 0 { 
_spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, signature.FieldID) for i := range fields { @@ -561,10 +571,10 @@ func (sq *SignatureQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := sq.limit; limit != nil { + if limit := sq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := sq.offset; offset != nil { + if offset := sq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := sq.order; len(ps) > 0 { @@ -580,7 +590,7 @@ func (sq *SignatureQuery) querySpec() *sqlgraph.QuerySpec { func (sq *SignatureQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(sq.driver.Dialect()) t1 := builder.Table(signature.Table) - columns := sq.fields + columns := sq.ctx.Fields if len(columns) == 0 { columns = signature.Columns } @@ -589,7 +599,7 @@ func (sq *SignatureQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = sq.sql selector.Select(selector.Columns(columns...)...) } - if sq.unique != nil && *sq.unique { + if sq.ctx.Unique != nil && *sq.ctx.Unique { selector.Distinct() } for _, p := range sq.predicates { @@ -598,26 +608,35 @@ func (sq *SignatureQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range sq.order { p(selector) } - if offset := sq.offset; offset != nil { + if offset := sq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := sq.limit; limit != nil { + if limit := sq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector } +// WithNamedTimestamps tells the query-builder to eager-load the nodes that are connected to the "timestamps" +// edge with the given name. The optional arguments are used to configure the query builder of the edge. +func (sq *SignatureQuery) WithNamedTimestamps(name string, opts ...func(*TimestampQuery)) *SignatureQuery { + query := (&TimestampClient{config: sq.config}).Query() + for _, opt := range opts { + opt(query) + } + if sq.withNamedTimestamps == nil { + sq.withNamedTimestamps = make(map[string]*TimestampQuery) + } + sq.withNamedTimestamps[name] = query + return sq +} + // SignatureGroupBy is the group-by builder for Signature entities. type SignatureGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *SignatureQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -626,74 +645,77 @@ func (sgb *SignatureGroupBy) Aggregate(fns ...AggregateFunc) *SignatureGroupBy { return sgb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. 
func (sgb *SignatureGroupBy) Scan(ctx context.Context, v any) error { - query, err := sgb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, sgb.build.ctx, "GroupBy") + if err := sgb.build.prepareQuery(ctx); err != nil { return err } - sgb.sql = query - return sgb.sqlScan(ctx, v) + return scanWithInterceptors[*SignatureQuery, *SignatureGroupBy](ctx, sgb.build, sgb, sgb.build.inters, v) } -func (sgb *SignatureGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range sgb.fields { - if !signature.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (sgb *SignatureGroupBy) sqlScan(ctx context.Context, root *SignatureQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(sgb.fns)) + for _, fn := range sgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*sgb.flds)+len(sgb.fns)) + for _, f := range *sgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := sgb.sqlQuery() + selector.GroupBy(selector.Columns(*sgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := sgb.driver.Query(ctx, query, args, rows); err != nil { + if err := sgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (sgb *SignatureGroupBy) sqlQuery() *sql.Selector { - selector := sgb.sql.Select() - aggregation := make([]string, 0, len(sgb.fns)) - for _, fn := range sgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(sgb.fields)+len(sgb.fns)) - for _, f := range sgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(sgb.fields...)...) -} - // SignatureSelect is the builder for selecting fields of Signature entities. type SignatureSelect struct { *SignatureQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ss *SignatureSelect) Aggregate(fns ...AggregateFunc) *SignatureSelect { + ss.fns = append(ss.fns, fns...) + return ss } // Scan applies the selector query and scans the result into the given value. func (ss *SignatureSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ss.ctx, "Select") if err := ss.prepareQuery(ctx); err != nil { return err } - ss.sql = ss.SignatureQuery.sqlQuery(ctx) - return ss.sqlScan(ctx, v) + return scanWithInterceptors[*SignatureQuery, *SignatureSelect](ctx, ss.SignatureQuery, ss, ss.inters, v) } -func (ss *SignatureSelect) sqlScan(ctx context.Context, v any) error { +func (ss *SignatureSelect) sqlScan(ctx context.Context, root *SignatureQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ss.fns)) + for _, fn := range ss.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ss.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) 
+ case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := ss.sql.Query() + query, args := selector.Query() if err := ss.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/ent/signature_update.go b/ent/signature_update.go index 6cbc46ad..7776db94 100644 --- a/ent/signature_update.go +++ b/ent/signature_update.go @@ -109,40 +109,7 @@ func (su *SignatureUpdate) RemoveTimestamps(t ...*Timestamp) *SignatureUpdate { // Save executes the query and returns the number of nodes affected by the update operation. func (su *SignatureUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(su.hooks) == 0 { - if err = su.check(); err != nil { - return 0, err - } - affected, err = su.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*SignatureMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = su.check(); err != nil { - return 0, err - } - su.mutation = mutation - affected, err = su.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(su.hooks) - 1; i >= 0; i-- { - if su.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = su.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, su.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, su.sqlSave, su.mutation, su.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -183,16 +150,10 @@ func (su *SignatureUpdate) check() error { } func (su *SignatureUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: signature.Table, - Columns: signature.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: signature.FieldID, - }, - }, + if err := su.check(); err != nil { + return n, err } + _spec := sqlgraph.NewUpdateSpec(signature.Table, signature.Columns, sqlgraph.NewFieldSpec(signature.FieldID, field.TypeInt)) if ps := su.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -201,18 +162,10 @@ func (su *SignatureUpdate) sqlSave(ctx context.Context) (n int, err error) { } } if value, ok := su.mutation.KeyID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: signature.FieldKeyID, - }) + _spec.SetField(signature.FieldKeyID, field.TypeString, value) } if value, ok := su.mutation.Signature(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: signature.FieldSignature, - }) + _spec.SetField(signature.FieldSignature, field.TypeString, value) } if su.mutation.DsseCleared() { edge := &sqlgraph.EdgeSpec{ @@ -222,10 +175,7 @@ func (su *SignatureUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{signature.DsseColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: dsse.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(dsse.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -238,10 +188,7 @@ func (su *SignatureUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{signature.DsseColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - 
Column: dsse.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(dsse.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -257,10 +204,7 @@ func (su *SignatureUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{signature.TimestampsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: timestamp.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(timestamp.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -273,10 +217,7 @@ func (su *SignatureUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{signature.TimestampsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: timestamp.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(timestamp.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -292,10 +233,7 @@ func (su *SignatureUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{signature.TimestampsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: timestamp.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(timestamp.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -311,6 +249,7 @@ func (su *SignatureUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + su.mutation.done = true return n, nil } @@ -400,6 +339,12 @@ func (suo *SignatureUpdateOne) RemoveTimestamps(t ...*Timestamp) *SignatureUpdat return suo.RemoveTimestampIDs(ids...) } +// Where appends a list predicates to the SignatureUpdate builder. +func (suo *SignatureUpdateOne) Where(ps ...predicate.Signature) *SignatureUpdateOne { + suo.mutation.Where(ps...) + return suo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (suo *SignatureUpdateOne) Select(field string, fields ...string) *SignatureUpdateOne { @@ -409,46 +354,7 @@ func (suo *SignatureUpdateOne) Select(field string, fields ...string) *Signature // Save executes the query and returns the updated Signature entity. func (suo *SignatureUpdateOne) Save(ctx context.Context) (*Signature, error) { - var ( - err error - node *Signature - ) - if len(suo.hooks) == 0 { - if err = suo.check(); err != nil { - return nil, err - } - node, err = suo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*SignatureMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = suo.check(); err != nil { - return nil, err - } - suo.mutation = mutation - node, err = suo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(suo.hooks) - 1; i >= 0; i-- { - if suo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = suo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, suo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Signature) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from SignatureMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, suo.sqlSave, suo.mutation, suo.hooks) } // SaveX is like Save, but panics if an error occurs. 
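Reviewer note: the new Where method on SignatureUpdateOne lets a caller attach extra predicates to a single-row update; if the predicates exclude the row, Save is expected to report a not-found error instead of silently updating nothing. A hedged sketch, assuming an initialized *ent.Client named client and an existing row id (the expected key_id value is hypothetical):

import (
	"context"

	"github.com/testifysec/archivista/ent"
	"github.com/testifysec/archivista/ent/signature"
)

// rotateSignature overwrites the stored signature only when the row still
// carries the expected key_id, guarding the update with the new Where support.
func rotateSignature(ctx context.Context, client *ent.Client, id int, newSig string) (*ent.Signature, error) {
	return client.Signature.
		UpdateOneID(id).
		Where(signature.KeyIDEQ("expected-key-id")).
		SetSignature(newSig).
		Save(ctx)
}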
@@ -489,16 +395,10 @@ func (suo *SignatureUpdateOne) check() error { } func (suo *SignatureUpdateOne) sqlSave(ctx context.Context) (_node *Signature, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: signature.Table, - Columns: signature.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: signature.FieldID, - }, - }, + if err := suo.check(); err != nil { + return _node, err } + _spec := sqlgraph.NewUpdateSpec(signature.Table, signature.Columns, sqlgraph.NewFieldSpec(signature.FieldID, field.TypeInt)) id, ok := suo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Signature.id" for update`)} @@ -524,18 +424,10 @@ func (suo *SignatureUpdateOne) sqlSave(ctx context.Context) (_node *Signature, e } } if value, ok := suo.mutation.KeyID(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: signature.FieldKeyID, - }) + _spec.SetField(signature.FieldKeyID, field.TypeString, value) } if value, ok := suo.mutation.Signature(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: signature.FieldSignature, - }) + _spec.SetField(signature.FieldSignature, field.TypeString, value) } if suo.mutation.DsseCleared() { edge := &sqlgraph.EdgeSpec{ @@ -545,10 +437,7 @@ func (suo *SignatureUpdateOne) sqlSave(ctx context.Context) (_node *Signature, e Columns: []string{signature.DsseColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: dsse.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(dsse.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -561,10 +450,7 @@ func (suo *SignatureUpdateOne) sqlSave(ctx context.Context) (_node *Signature, e Columns: []string{signature.DsseColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: dsse.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(dsse.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -580,10 +466,7 @@ func (suo *SignatureUpdateOne) sqlSave(ctx context.Context) (_node *Signature, e Columns: []string{signature.TimestampsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: timestamp.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(timestamp.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -596,10 +479,7 @@ func (suo *SignatureUpdateOne) sqlSave(ctx context.Context) (_node *Signature, e Columns: []string{signature.TimestampsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: timestamp.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(timestamp.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -615,10 +495,7 @@ func (suo *SignatureUpdateOne) sqlSave(ctx context.Context) (_node *Signature, e Columns: []string{signature.TimestampsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: timestamp.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(timestamp.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -637,5 +514,6 @@ func (suo *SignatureUpdateOne) sqlSave(ctx context.Context) (_node *Signature, e } return nil, err } + suo.mutation.done = true return _node, nil } diff --git a/ent/statement.go b/ent/statement.go index e4ca31a2..79dedbb1 100644 --- 
a/ent/statement.go +++ b/ent/statement.go @@ -6,6 +6,7 @@ import ( "fmt" "strings" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/testifysec/archivista/ent/attestationcollection" "github.com/testifysec/archivista/ent/statement" @@ -20,7 +21,8 @@ type Statement struct { Predicate string `json:"predicate,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the StatementQuery when eager-loading is set. - Edges StatementEdges `json:"edges"` + Edges StatementEdges `json:"edges"` + selectValues sql.SelectValues } // StatementEdges holds the relations/edges for other nodes in the graph. @@ -35,7 +37,10 @@ type StatementEdges struct { // type was loaded (or requested) in eager-loading or not. loadedTypes [3]bool // totalCount holds the count of the edges above. - totalCount [3]*int + totalCount [3]map[string]int + + namedSubjects map[string][]*Subject + namedDsse map[string][]*Dsse } // SubjectsOrErr returns the Subjects value or an error if the edge @@ -79,7 +84,7 @@ func (*Statement) scanValues(columns []string) ([]any, error) { case statement.FieldPredicate: values[i] = new(sql.NullString) default: - return nil, fmt.Errorf("unexpected column %q for type Statement", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -105,31 +110,39 @@ func (s *Statement) assignValues(columns []string, values []any) error { } else if value.Valid { s.Predicate = value.String } + default: + s.selectValues.Set(columns[i], values[i]) } } return nil } +// Value returns the ent.Value that was dynamically selected and assigned to the Statement. +// This includes values selected through modifiers, order, etc. +func (s *Statement) Value(name string) (ent.Value, error) { + return s.selectValues.Get(name) +} + // QuerySubjects queries the "subjects" edge of the Statement entity. func (s *Statement) QuerySubjects() *SubjectQuery { - return (&StatementClient{config: s.config}).QuerySubjects(s) + return NewStatementClient(s.config).QuerySubjects(s) } // QueryAttestationCollections queries the "attestation_collections" edge of the Statement entity. func (s *Statement) QueryAttestationCollections() *AttestationCollectionQuery { - return (&StatementClient{config: s.config}).QueryAttestationCollections(s) + return NewStatementClient(s.config).QueryAttestationCollections(s) } // QueryDsse queries the "dsse" edge of the Statement entity. func (s *Statement) QueryDsse() *DsseQuery { - return (&StatementClient{config: s.config}).QueryDsse(s) + return NewStatementClient(s.config).QueryDsse(s) } // Update returns a builder for updating this Statement. // Note that you need to call Statement.Unwrap() before calling this method if this Statement // was returned from a transaction, and the transaction was committed or rolled back. func (s *Statement) Update() *StatementUpdateOne { - return (&StatementClient{config: s.config}).UpdateOne(s) + return NewStatementClient(s.config).UpdateOne(s) } // Unwrap unwraps the Statement entity that was returned from a transaction after it was closed, @@ -154,11 +167,53 @@ func (s *Statement) String() string { return builder.String() } -// Statements is a parsable slice of Statement. -type Statements []*Statement +// NamedSubjects returns the Subjects named value or an error if the edge was not +// loaded in eager-loading with this name. 
+func (s *Statement) NamedSubjects(name string) ([]*Subject, error) { + if s.Edges.namedSubjects == nil { + return nil, &NotLoadedError{edge: name} + } + nodes, ok := s.Edges.namedSubjects[name] + if !ok { + return nil, &NotLoadedError{edge: name} + } + return nodes, nil +} -func (s Statements) config(cfg config) { - for _i := range s { - s[_i].config = cfg +func (s *Statement) appendNamedSubjects(name string, edges ...*Subject) { + if s.Edges.namedSubjects == nil { + s.Edges.namedSubjects = make(map[string][]*Subject) + } + if len(edges) == 0 { + s.Edges.namedSubjects[name] = []*Subject{} + } else { + s.Edges.namedSubjects[name] = append(s.Edges.namedSubjects[name], edges...) } } + +// NamedDsse returns the Dsse named value or an error if the edge was not +// loaded in eager-loading with this name. +func (s *Statement) NamedDsse(name string) ([]*Dsse, error) { + if s.Edges.namedDsse == nil { + return nil, &NotLoadedError{edge: name} + } + nodes, ok := s.Edges.namedDsse[name] + if !ok { + return nil, &NotLoadedError{edge: name} + } + return nodes, nil +} + +func (s *Statement) appendNamedDsse(name string, edges ...*Dsse) { + if s.Edges.namedDsse == nil { + s.Edges.namedDsse = make(map[string][]*Dsse) + } + if len(edges) == 0 { + s.Edges.namedDsse[name] = []*Dsse{} + } else { + s.Edges.namedDsse[name] = append(s.Edges.namedDsse[name], edges...) + } +} + +// Statements is a parsable slice of Statement. +type Statements []*Statement diff --git a/ent/statement/statement.go b/ent/statement/statement.go index a76160c8..82e86df0 100644 --- a/ent/statement/statement.go +++ b/ent/statement/statement.go @@ -2,6 +2,11 @@ package statement +import ( + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + const ( // Label holds the string label denoting the statement type in the database. Label = "statement" @@ -60,3 +65,72 @@ var ( // PredicateValidator is a validator for the "predicate" field. It is called by the builders before save. PredicateValidator func(string) error ) + +// OrderOption defines the ordering options for the Statement queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByPredicate orders the results by the predicate field. +func ByPredicate(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPredicate, opts...).ToFunc() +} + +// BySubjectsCount orders the results by subjects count. +func BySubjectsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newSubjectsStep(), opts...) + } +} + +// BySubjects orders the results by subjects terms. +func BySubjects(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newSubjectsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByAttestationCollectionsField orders the results by attestation_collections field. +func ByAttestationCollectionsField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newAttestationCollectionsStep(), sql.OrderByField(field, opts...)) + } +} + +// ByDsseCount orders the results by dsse count. +func ByDsseCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newDsseStep(), opts...) + } +} + +// ByDsse orders the results by dsse terms. 
+func ByDsse(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newDsseStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} +func newSubjectsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(SubjectsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, SubjectsTable, SubjectsColumn), + ) +} +func newAttestationCollectionsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(AttestationCollectionsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2O, false, AttestationCollectionsTable, AttestationCollectionsColumn), + ) +} +func newDsseStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(DsseInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, true, DsseTable, DsseColumn), + ) +} diff --git a/ent/statement/where.go b/ent/statement/where.go index 3f72cc59..2771d7ae 100644 --- a/ent/statement/where.go +++ b/ent/statement/where.go @@ -10,179 +10,117 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.Statement { - return predicate.Statement(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Statement(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.Statement { - return predicate.Statement(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Statement(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.Statement { - return predicate.Statement(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Statement(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.Statement { - return predicate.Statement(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Statement(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.Statement { - return predicate.Statement(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Statement(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.Statement { - return predicate.Statement(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Statement(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.Statement { - return predicate.Statement(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Statement(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.Statement { - return predicate.Statement(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Statement(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.Statement { - return predicate.Statement(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Statement(sql.FieldLTE(FieldID, id)) } // Predicate applies equality check predicate on the "predicate" field. It's identical to PredicateEQ. 
func Predicate(v string) predicate.Statement { - return predicate.Statement(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldPredicate), v)) - }) + return predicate.Statement(sql.FieldEQ(FieldPredicate, v)) } // PredicateEQ applies the EQ predicate on the "predicate" field. func PredicateEQ(v string) predicate.Statement { - return predicate.Statement(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldPredicate), v)) - }) + return predicate.Statement(sql.FieldEQ(FieldPredicate, v)) } // PredicateNEQ applies the NEQ predicate on the "predicate" field. func PredicateNEQ(v string) predicate.Statement { - return predicate.Statement(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldPredicate), v)) - }) + return predicate.Statement(sql.FieldNEQ(FieldPredicate, v)) } // PredicateIn applies the In predicate on the "predicate" field. func PredicateIn(vs ...string) predicate.Statement { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Statement(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldPredicate), v...)) - }) + return predicate.Statement(sql.FieldIn(FieldPredicate, vs...)) } // PredicateNotIn applies the NotIn predicate on the "predicate" field. func PredicateNotIn(vs ...string) predicate.Statement { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Statement(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldPredicate), v...)) - }) + return predicate.Statement(sql.FieldNotIn(FieldPredicate, vs...)) } // PredicateGT applies the GT predicate on the "predicate" field. func PredicateGT(v string) predicate.Statement { - return predicate.Statement(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldPredicate), v)) - }) + return predicate.Statement(sql.FieldGT(FieldPredicate, v)) } // PredicateGTE applies the GTE predicate on the "predicate" field. func PredicateGTE(v string) predicate.Statement { - return predicate.Statement(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldPredicate), v)) - }) + return predicate.Statement(sql.FieldGTE(FieldPredicate, v)) } // PredicateLT applies the LT predicate on the "predicate" field. func PredicateLT(v string) predicate.Statement { - return predicate.Statement(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldPredicate), v)) - }) + return predicate.Statement(sql.FieldLT(FieldPredicate, v)) } // PredicateLTE applies the LTE predicate on the "predicate" field. func PredicateLTE(v string) predicate.Statement { - return predicate.Statement(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldPredicate), v)) - }) + return predicate.Statement(sql.FieldLTE(FieldPredicate, v)) } // PredicateContains applies the Contains predicate on the "predicate" field. func PredicateContains(v string) predicate.Statement { - return predicate.Statement(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldPredicate), v)) - }) + return predicate.Statement(sql.FieldContains(FieldPredicate, v)) } // PredicateHasPrefix applies the HasPrefix predicate on the "predicate" field. func PredicateHasPrefix(v string) predicate.Statement { - return predicate.Statement(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldPredicate), v)) - }) + return predicate.Statement(sql.FieldHasPrefix(FieldPredicate, v)) } // PredicateHasSuffix applies the HasSuffix predicate on the "predicate" field. 
func PredicateHasSuffix(v string) predicate.Statement { - return predicate.Statement(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldPredicate), v)) - }) + return predicate.Statement(sql.FieldHasSuffix(FieldPredicate, v)) } // PredicateEqualFold applies the EqualFold predicate on the "predicate" field. func PredicateEqualFold(v string) predicate.Statement { - return predicate.Statement(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldPredicate), v)) - }) + return predicate.Statement(sql.FieldEqualFold(FieldPredicate, v)) } // PredicateContainsFold applies the ContainsFold predicate on the "predicate" field. func PredicateContainsFold(v string) predicate.Statement { - return predicate.Statement(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldPredicate), v)) - }) + return predicate.Statement(sql.FieldContainsFold(FieldPredicate, v)) } // HasSubjects applies the HasEdge predicate on the "subjects" edge. @@ -190,7 +128,6 @@ func HasSubjects() predicate.Statement { return predicate.Statement(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(SubjectsTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, SubjectsTable, SubjectsColumn), ) sqlgraph.HasNeighbors(s, step) @@ -200,11 +137,7 @@ func HasSubjects() predicate.Statement { // HasSubjectsWith applies the HasEdge predicate on the "subjects" edge with a given conditions (other predicates). func HasSubjectsWith(preds ...predicate.Subject) predicate.Statement { return predicate.Statement(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(SubjectsInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.O2M, false, SubjectsTable, SubjectsColumn), - ) + step := newSubjectsStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -218,7 +151,6 @@ func HasAttestationCollections() predicate.Statement { return predicate.Statement(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(AttestationCollectionsTable, FieldID), sqlgraph.Edge(sqlgraph.O2O, false, AttestationCollectionsTable, AttestationCollectionsColumn), ) sqlgraph.HasNeighbors(s, step) @@ -228,11 +160,7 @@ func HasAttestationCollections() predicate.Statement { // HasAttestationCollectionsWith applies the HasEdge predicate on the "attestation_collections" edge with a given conditions (other predicates). func HasAttestationCollectionsWith(preds ...predicate.AttestationCollection) predicate.Statement { return predicate.Statement(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(AttestationCollectionsInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.O2O, false, AttestationCollectionsTable, AttestationCollectionsColumn), - ) + step := newAttestationCollectionsStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -246,7 +174,6 @@ func HasDsse() predicate.Statement { return predicate.Statement(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(DsseTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, true, DsseTable, DsseColumn), ) sqlgraph.HasNeighbors(s, step) @@ -256,11 +183,7 @@ func HasDsse() predicate.Statement { // HasDsseWith applies the HasEdge predicate on the "dsse" edge with a given conditions (other predicates). 
func HasDsseWith(preds ...predicate.Dsse) predicate.Statement { return predicate.Statement(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(DsseInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.O2M, true, DsseTable, DsseColumn), - ) + step := newDsseStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -271,32 +194,15 @@ func HasDsseWith(preds ...predicate.Dsse) predicate.Statement { // And groups predicates with the AND operator between them. func And(predicates ...predicate.Statement) predicate.Statement { - return predicate.Statement(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Statement(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. func Or(predicates ...predicate.Statement) predicate.Statement { - return predicate.Statement(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Statement(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.Statement) predicate.Statement { - return predicate.Statement(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.Statement(sql.NotPredicates(p)) } diff --git a/ent/statement_create.go b/ent/statement_create.go index bfa02f7c..86cf43de 100644 --- a/ent/statement_create.go +++ b/ent/statement_create.go @@ -84,49 +84,7 @@ func (sc *StatementCreate) Mutation() *StatementMutation { // Save creates the Statement in the database. func (sc *StatementCreate) Save(ctx context.Context) (*Statement, error) { - var ( - err error - node *Statement - ) - if len(sc.hooks) == 0 { - if err = sc.check(); err != nil { - return nil, err - } - node, err = sc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*StatementMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = sc.check(); err != nil { - return nil, err - } - sc.mutation = mutation - if node, err = sc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(sc.hooks) - 1; i >= 0; i-- { - if sc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = sc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, sc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Statement) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from StatementMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, sc.sqlSave, sc.mutation, sc.hooks) } // SaveX calls Save and panics if Save returns an error. 
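The predicate rewrite above changes only the generated bodies; call sites keep the same helper names and signatures, now delegating to sql.Field* and sql.*Predicates. A minimal sketch, assuming an initialized *ent.Client in a hypothetical examples package (the predicate strings are illustrative):

package examples // hypothetical package, for illustration only

import (
	"context"

	"github.com/testifysec/archivista/ent"
	"github.com/testifysec/archivista/ent/statement"
)

// listStatements filters with the generated helpers whose bodies now delegate
// to sql.FieldHasPrefix, sql.FieldContainsFold and sql.OrPredicates.
func listStatements(ctx context.Context, client *ent.Client) ([]*ent.Statement, error) {
	return client.Statement.Query().
		Where(statement.Or(
			statement.PredicateHasPrefix("https://slsa.dev/"), // illustrative value
			statement.PredicateContainsFold("provenance"),     // illustrative value
		)).
		All(ctx)
}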
@@ -165,6 +123,9 @@ func (sc *StatementCreate) check() error { } func (sc *StatementCreate) sqlSave(ctx context.Context) (*Statement, error) { + if err := sc.check(); err != nil { + return nil, err + } _node, _spec := sc.createSpec() if err := sqlgraph.CreateNode(ctx, sc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -174,26 +135,18 @@ func (sc *StatementCreate) sqlSave(ctx context.Context) (*Statement, error) { } id := _spec.ID.Value.(int64) _node.ID = int(id) + sc.mutation.id = &_node.ID + sc.mutation.done = true return _node, nil } func (sc *StatementCreate) createSpec() (*Statement, *sqlgraph.CreateSpec) { var ( _node = &Statement{config: sc.config} - _spec = &sqlgraph.CreateSpec{ - Table: statement.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: statement.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(statement.Table, sqlgraph.NewFieldSpec(statement.FieldID, field.TypeInt)) ) if value, ok := sc.mutation.Predicate(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: statement.FieldPredicate, - }) + _spec.SetField(statement.FieldPredicate, field.TypeString, value) _node.Predicate = value } if nodes := sc.mutation.SubjectsIDs(); len(nodes) > 0 { @@ -204,10 +157,7 @@ func (sc *StatementCreate) createSpec() (*Statement, *sqlgraph.CreateSpec) { Columns: []string{statement.SubjectsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subject.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(subject.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -223,10 +173,7 @@ func (sc *StatementCreate) createSpec() (*Statement, *sqlgraph.CreateSpec) { Columns: []string{statement.AttestationCollectionsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: attestationcollection.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(attestationcollection.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -242,10 +189,7 @@ func (sc *StatementCreate) createSpec() (*Statement, *sqlgraph.CreateSpec) { Columns: []string{statement.DsseColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: dsse.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(dsse.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -259,11 +203,15 @@ func (sc *StatementCreate) createSpec() (*Statement, *sqlgraph.CreateSpec) { // StatementCreateBulk is the builder for creating many Statement entities in bulk. type StatementCreateBulk struct { config + err error builders []*StatementCreate } // Save creates the Statement entities in the database. 
func (scb *StatementCreateBulk) Save(ctx context.Context) ([]*Statement, error) { + if scb.err != nil { + return nil, scb.err + } specs := make([]*sqlgraph.CreateSpec, len(scb.builders)) nodes := make([]*Statement, len(scb.builders)) mutators := make([]Mutator, len(scb.builders)) @@ -279,8 +227,8 @@ func (scb *StatementCreateBulk) Save(ctx context.Context) ([]*Statement, error) return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, scb.builders[i+1].mutation) } else { diff --git a/ent/statement_delete.go b/ent/statement_delete.go index e46120e6..0669b47e 100644 --- a/ent/statement_delete.go +++ b/ent/statement_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (sd *StatementDelete) Where(ps ...predicate.Statement) *StatementDelete { // Exec executes the deletion query and returns how many vertices were deleted. func (sd *StatementDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(sd.hooks) == 0 { - affected, err = sd.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*StatementMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - sd.mutation = mutation - affected, err = sd.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(sd.hooks) - 1; i >= 0; i-- { - if sd.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = sd.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, sd.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, sd.sqlExec, sd.mutation, sd.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (sd *StatementDelete) ExecX(ctx context.Context) int { } func (sd *StatementDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: statement.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: statement.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(statement.Table, sqlgraph.NewFieldSpec(statement.FieldID, field.TypeInt)) if ps := sd.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (sd *StatementDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + sd.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type StatementDeleteOne struct { sd *StatementDelete } +// Where appends a list predicates to the StatementDelete builder. +func (sdo *StatementDeleteOne) Where(ps ...predicate.Statement) *StatementDeleteOne { + sdo.sd.mutation.Where(ps...) + return sdo +} + // Exec executes the deletion query. func (sdo *StatementDeleteOne) Exec(ctx context.Context) error { n, err := sdo.sd.Exec(ctx) @@ -111,5 +82,7 @@ func (sdo *StatementDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. 
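The Where method newly generated on the *DeleteOne builder lets a single-row delete carry an extra predicate guard. A minimal sketch under the same assumptions (hypothetical package, initialized client, illustrative guard); when the guard filters the row out, Exec reports a not-found error:

package examples // hypothetical package, for illustration only

import (
	"context"

	"github.com/testifysec/archivista/ent"
	"github.com/testifysec/archivista/ent/statement"
)

// deleteStatement deletes a single row by ID, guarded by the Where method that
// is newly generated on StatementDeleteOne in this diff.
func deleteStatement(ctx context.Context, client *ent.Client, id int) error {
	return client.Statement.DeleteOneID(id).
		Where(statement.PredicateHasPrefix("https://")). // illustrative guard
		Exec(ctx)
}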
func (sdo *StatementDeleteOne) ExecX(ctx context.Context) { - sdo.sd.ExecX(ctx) + if err := sdo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/ent/statement_query.go b/ent/statement_query.go index 1784986d..a9fd8f54 100644 --- a/ent/statement_query.go +++ b/ent/statement_query.go @@ -21,17 +21,17 @@ import ( // StatementQuery is the builder for querying Statement entities. type StatementQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string + ctx *QueryContext + order []statement.OrderOption + inters []Interceptor predicates []predicate.Statement withSubjects *SubjectQuery withAttestationCollections *AttestationCollectionQuery withDsse *DsseQuery modifiers []func(*sql.Selector) loadTotal []func(context.Context, []*Statement) error + withNamedSubjects map[string]*SubjectQuery + withNamedDsse map[string]*DsseQuery // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) @@ -43,34 +43,34 @@ func (sq *StatementQuery) Where(ps ...predicate.Statement) *StatementQuery { return sq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (sq *StatementQuery) Limit(limit int) *StatementQuery { - sq.limit = &limit + sq.ctx.Limit = &limit return sq } -// Offset adds an offset step to the query. +// Offset to start from. func (sq *StatementQuery) Offset(offset int) *StatementQuery { - sq.offset = &offset + sq.ctx.Offset = &offset return sq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (sq *StatementQuery) Unique(unique bool) *StatementQuery { - sq.unique = &unique + sq.ctx.Unique = &unique return sq } -// Order adds an order step to the query. -func (sq *StatementQuery) Order(o ...OrderFunc) *StatementQuery { +// Order specifies how the records should be ordered. +func (sq *StatementQuery) Order(o ...statement.OrderOption) *StatementQuery { sq.order = append(sq.order, o...) return sq } // QuerySubjects chains the current query on the "subjects" edge. func (sq *StatementQuery) QuerySubjects() *SubjectQuery { - query := &SubjectQuery{config: sq.config} + query := (&SubjectClient{config: sq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := sq.prepareQuery(ctx); err != nil { return nil, err @@ -92,7 +92,7 @@ func (sq *StatementQuery) QuerySubjects() *SubjectQuery { // QueryAttestationCollections chains the current query on the "attestation_collections" edge. func (sq *StatementQuery) QueryAttestationCollections() *AttestationCollectionQuery { - query := &AttestationCollectionQuery{config: sq.config} + query := (&AttestationCollectionClient{config: sq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := sq.prepareQuery(ctx); err != nil { return nil, err @@ -114,7 +114,7 @@ func (sq *StatementQuery) QueryAttestationCollections() *AttestationCollectionQu // QueryDsse chains the current query on the "dsse" edge. func (sq *StatementQuery) QueryDsse() *DsseQuery { - query := &DsseQuery{config: sq.config} + query := (&DsseClient{config: sq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := sq.prepareQuery(ctx); err != nil { return nil, err @@ -137,7 +137,7 @@ func (sq *StatementQuery) QueryDsse() *DsseQuery { // First returns the first Statement entity from the query. 
// Returns a *NotFoundError when no Statement was found. func (sq *StatementQuery) First(ctx context.Context) (*Statement, error) { - nodes, err := sq.Limit(1).All(ctx) + nodes, err := sq.Limit(1).All(setContextOp(ctx, sq.ctx, "First")) if err != nil { return nil, err } @@ -160,7 +160,7 @@ func (sq *StatementQuery) FirstX(ctx context.Context) *Statement { // Returns a *NotFoundError when no Statement ID was found. func (sq *StatementQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = sq.Limit(1).IDs(ctx); err != nil { + if ids, err = sq.Limit(1).IDs(setContextOp(ctx, sq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -183,7 +183,7 @@ func (sq *StatementQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Statement entity is found. // Returns a *NotFoundError when no Statement entities are found. func (sq *StatementQuery) Only(ctx context.Context) (*Statement, error) { - nodes, err := sq.Limit(2).All(ctx) + nodes, err := sq.Limit(2).All(setContextOp(ctx, sq.ctx, "Only")) if err != nil { return nil, err } @@ -211,7 +211,7 @@ func (sq *StatementQuery) OnlyX(ctx context.Context) *Statement { // Returns a *NotFoundError when no entities are found. func (sq *StatementQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = sq.Limit(2).IDs(ctx); err != nil { + if ids, err = sq.Limit(2).IDs(setContextOp(ctx, sq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -236,10 +236,12 @@ func (sq *StatementQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Statements. func (sq *StatementQuery) All(ctx context.Context) ([]*Statement, error) { + ctx = setContextOp(ctx, sq.ctx, "All") if err := sq.prepareQuery(ctx); err != nil { return nil, err } - return sq.sqlAll(ctx) + qr := querierAll[[]*Statement, *StatementQuery]() + return withInterceptors[[]*Statement](ctx, sq, qr, sq.inters) } // AllX is like All, but panics if an error occurs. @@ -252,9 +254,12 @@ func (sq *StatementQuery) AllX(ctx context.Context) []*Statement { } // IDs executes the query and returns a list of Statement IDs. -func (sq *StatementQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := sq.Select(statement.FieldID).Scan(ctx, &ids); err != nil { +func (sq *StatementQuery) IDs(ctx context.Context) (ids []int, err error) { + if sq.ctx.Unique == nil && sq.path != nil { + sq.Unique(true) + } + ctx = setContextOp(ctx, sq.ctx, "IDs") + if err = sq.Select(statement.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -271,10 +276,11 @@ func (sq *StatementQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (sq *StatementQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, sq.ctx, "Count") if err := sq.prepareQuery(ctx); err != nil { return 0, err } - return sq.sqlCount(ctx) + return withInterceptors[int](ctx, sq, querierCount[*StatementQuery](), sq.inters) } // CountX is like Count, but panics if an error occurs. @@ -288,10 +294,15 @@ func (sq *StatementQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. 
func (sq *StatementQuery) Exist(ctx context.Context) (bool, error) { - if err := sq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, sq.ctx, "Exist") + switch _, err := sq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return sq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -311,24 +322,23 @@ func (sq *StatementQuery) Clone() *StatementQuery { } return &StatementQuery{ config: sq.config, - limit: sq.limit, - offset: sq.offset, - order: append([]OrderFunc{}, sq.order...), + ctx: sq.ctx.Clone(), + order: append([]statement.OrderOption{}, sq.order...), + inters: append([]Interceptor{}, sq.inters...), predicates: append([]predicate.Statement{}, sq.predicates...), withSubjects: sq.withSubjects.Clone(), withAttestationCollections: sq.withAttestationCollections.Clone(), withDsse: sq.withDsse.Clone(), // clone intermediate query. - sql: sq.sql.Clone(), - path: sq.path, - unique: sq.unique, + sql: sq.sql.Clone(), + path: sq.path, } } // WithSubjects tells the query-builder to eager-load the nodes that are connected to // the "subjects" edge. The optional arguments are used to configure the query builder of the edge. func (sq *StatementQuery) WithSubjects(opts ...func(*SubjectQuery)) *StatementQuery { - query := &SubjectQuery{config: sq.config} + query := (&SubjectClient{config: sq.config}).Query() for _, opt := range opts { opt(query) } @@ -339,7 +349,7 @@ func (sq *StatementQuery) WithSubjects(opts ...func(*SubjectQuery)) *StatementQu // WithAttestationCollections tells the query-builder to eager-load the nodes that are connected to // the "attestation_collections" edge. The optional arguments are used to configure the query builder of the edge. func (sq *StatementQuery) WithAttestationCollections(opts ...func(*AttestationCollectionQuery)) *StatementQuery { - query := &AttestationCollectionQuery{config: sq.config} + query := (&AttestationCollectionClient{config: sq.config}).Query() for _, opt := range opts { opt(query) } @@ -350,7 +360,7 @@ func (sq *StatementQuery) WithAttestationCollections(opts ...func(*AttestationCo // WithDsse tells the query-builder to eager-load the nodes that are connected to // the "dsse" edge. The optional arguments are used to configure the query builder of the edge. func (sq *StatementQuery) WithDsse(opts ...func(*DsseQuery)) *StatementQuery { - query := &DsseQuery{config: sq.config} + query := (&DsseClient{config: sq.config}).Query() for _, opt := range opts { opt(query) } @@ -373,16 +383,11 @@ func (sq *StatementQuery) WithDsse(opts ...func(*DsseQuery)) *StatementQuery { // Aggregate(ent.Count()). // Scan(ctx, &v) func (sq *StatementQuery) GroupBy(field string, fields ...string) *StatementGroupBy { - grbuild := &StatementGroupBy{config: sq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := sq.prepareQuery(ctx); err != nil { - return nil, err - } - return sq.sqlQuery(ctx), nil - } + sq.ctx.Fields = append([]string{field}, fields...) + grbuild := &StatementGroupBy{build: sq} + grbuild.flds = &sq.ctx.Fields grbuild.label = statement.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -399,15 +404,30 @@ func (sq *StatementQuery) GroupBy(field string, fields ...string) *StatementGrou // Select(statement.FieldPredicate). 
// Scan(ctx, &v) func (sq *StatementQuery) Select(fields ...string) *StatementSelect { - sq.fields = append(sq.fields, fields...) - selbuild := &StatementSelect{StatementQuery: sq} - selbuild.label = statement.Label - selbuild.flds, selbuild.scan = &sq.fields, selbuild.Scan - return selbuild + sq.ctx.Fields = append(sq.ctx.Fields, fields...) + sbuild := &StatementSelect{StatementQuery: sq} + sbuild.label = statement.Label + sbuild.flds, sbuild.scan = &sq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a StatementSelect configured with the given aggregations. +func (sq *StatementQuery) Aggregate(fns ...AggregateFunc) *StatementSelect { + return sq.Select().Aggregate(fns...) } func (sq *StatementQuery) prepareQuery(ctx context.Context) error { - for _, f := range sq.fields { + for _, inter := range sq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, sq); err != nil { + return err + } + } + } + for _, f := range sq.ctx.Fields { if !statement.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -473,6 +493,20 @@ func (sq *StatementQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*St return nil, err } } + for name, query := range sq.withNamedSubjects { + if err := sq.loadSubjects(ctx, query, nodes, + func(n *Statement) { n.appendNamedSubjects(name) }, + func(n *Statement, e *Subject) { n.appendNamedSubjects(name, e) }); err != nil { + return nil, err + } + } + for name, query := range sq.withNamedDsse { + if err := sq.loadDsse(ctx, query, nodes, + func(n *Statement) { n.appendNamedDsse(name) }, + func(n *Statement, e *Dsse) { n.appendNamedDsse(name, e) }); err != nil { + return nil, err + } + } for i := range sq.loadTotal { if err := sq.loadTotal[i](ctx, nodes); err != nil { return nil, err @@ -493,7 +527,7 @@ func (sq *StatementQuery) loadSubjects(ctx context.Context, query *SubjectQuery, } query.withFKs = true query.Where(predicate.Subject(func(s *sql.Selector) { - s.Where(sql.InValues(statement.SubjectsColumn, fks...)) + s.Where(sql.InValues(s.C(statement.SubjectsColumn), fks...)) })) neighbors, err := query.All(ctx) if err != nil { @@ -506,7 +540,7 @@ func (sq *StatementQuery) loadSubjects(ctx context.Context, query *SubjectQuery, } node, ok := nodeids[*fk] if !ok { - return fmt.Errorf(`unexpected foreign-key "statement_subjects" returned %v for node %v`, *fk, n.ID) + return fmt.Errorf(`unexpected referenced foreign-key "statement_subjects" returned %v for node %v`, *fk, n.ID) } assign(node, n) } @@ -521,7 +555,7 @@ func (sq *StatementQuery) loadAttestationCollections(ctx context.Context, query } query.withFKs = true query.Where(predicate.AttestationCollection(func(s *sql.Selector) { - s.Where(sql.InValues(statement.AttestationCollectionsColumn, fks...)) + s.Where(sql.InValues(s.C(statement.AttestationCollectionsColumn), fks...)) })) neighbors, err := query.All(ctx) if err != nil { @@ -534,7 +568,7 @@ func (sq *StatementQuery) loadAttestationCollections(ctx context.Context, query } node, ok := nodeids[*fk] if !ok { - return fmt.Errorf(`unexpected foreign-key "statement_attestation_collections" returned %v for node %v`, *fk, n.ID) + return fmt.Errorf(`unexpected referenced foreign-key "statement_attestation_collections" returned %v for node %v`, *fk, n.ID) } assign(node, n) } @@ -552,7 +586,7 @@ func (sq *StatementQuery) loadDsse(ctx context.Context, query *DsseQuery, 
nodes } query.withFKs = true query.Where(predicate.Dsse(func(s *sql.Selector) { - s.Where(sql.InValues(statement.DsseColumn, fks...)) + s.Where(sql.InValues(s.C(statement.DsseColumn), fks...)) })) neighbors, err := query.All(ctx) if err != nil { @@ -565,7 +599,7 @@ func (sq *StatementQuery) loadDsse(ctx context.Context, query *DsseQuery, nodes } node, ok := nodeids[*fk] if !ok { - return fmt.Errorf(`unexpected foreign-key "dsse_statement" returned %v for node %v`, *fk, n.ID) + return fmt.Errorf(`unexpected referenced foreign-key "dsse_statement" returned %v for node %v`, *fk, n.ID) } assign(node, n) } @@ -577,41 +611,22 @@ func (sq *StatementQuery) sqlCount(ctx context.Context) (int, error) { if len(sq.modifiers) > 0 { _spec.Modifiers = sq.modifiers } - _spec.Node.Columns = sq.fields - if len(sq.fields) > 0 { - _spec.Unique = sq.unique != nil && *sq.unique + _spec.Node.Columns = sq.ctx.Fields + if len(sq.ctx.Fields) > 0 { + _spec.Unique = sq.ctx.Unique != nil && *sq.ctx.Unique } return sqlgraph.CountNodes(ctx, sq.driver, _spec) } -func (sq *StatementQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := sq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (sq *StatementQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: statement.Table, - Columns: statement.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: statement.FieldID, - }, - }, - From: sq.sql, - Unique: true, - } - if unique := sq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(statement.Table, statement.Columns, sqlgraph.NewFieldSpec(statement.FieldID, field.TypeInt)) + _spec.From = sq.sql + if unique := sq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if sq.path != nil { + _spec.Unique = true } - if fields := sq.fields; len(fields) > 0 { + if fields := sq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, statement.FieldID) for i := range fields { @@ -627,10 +642,10 @@ func (sq *StatementQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := sq.limit; limit != nil { + if limit := sq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := sq.offset; offset != nil { + if offset := sq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := sq.order; len(ps) > 0 { @@ -646,7 +661,7 @@ func (sq *StatementQuery) querySpec() *sqlgraph.QuerySpec { func (sq *StatementQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(sq.driver.Dialect()) t1 := builder.Table(statement.Table) - columns := sq.fields + columns := sq.ctx.Fields if len(columns) == 0 { columns = statement.Columns } @@ -655,7 +670,7 @@ func (sq *StatementQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = sq.sql selector.Select(selector.Columns(columns...)...) } - if sq.unique != nil && *sq.unique { + if sq.ctx.Unique != nil && *sq.ctx.Unique { selector.Distinct() } for _, p := range sq.predicates { @@ -664,26 +679,49 @@ func (sq *StatementQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range sq.order { p(selector) } - if offset := sq.offset; offset != nil { + if offset := sq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. 
selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := sq.limit; limit != nil { + if limit := sq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector } +// WithNamedSubjects tells the query-builder to eager-load the nodes that are connected to the "subjects" +// edge with the given name. The optional arguments are used to configure the query builder of the edge. +func (sq *StatementQuery) WithNamedSubjects(name string, opts ...func(*SubjectQuery)) *StatementQuery { + query := (&SubjectClient{config: sq.config}).Query() + for _, opt := range opts { + opt(query) + } + if sq.withNamedSubjects == nil { + sq.withNamedSubjects = make(map[string]*SubjectQuery) + } + sq.withNamedSubjects[name] = query + return sq +} + +// WithNamedDsse tells the query-builder to eager-load the nodes that are connected to the "dsse" +// edge with the given name. The optional arguments are used to configure the query builder of the edge. +func (sq *StatementQuery) WithNamedDsse(name string, opts ...func(*DsseQuery)) *StatementQuery { + query := (&DsseClient{config: sq.config}).Query() + for _, opt := range opts { + opt(query) + } + if sq.withNamedDsse == nil { + sq.withNamedDsse = make(map[string]*DsseQuery) + } + sq.withNamedDsse[name] = query + return sq +} + // StatementGroupBy is the group-by builder for Statement entities. type StatementGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *StatementQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -692,74 +730,77 @@ func (sgb *StatementGroupBy) Aggregate(fns ...AggregateFunc) *StatementGroupBy { return sgb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. func (sgb *StatementGroupBy) Scan(ctx context.Context, v any) error { - query, err := sgb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, sgb.build.ctx, "GroupBy") + if err := sgb.build.prepareQuery(ctx); err != nil { return err } - sgb.sql = query - return sgb.sqlScan(ctx, v) + return scanWithInterceptors[*StatementQuery, *StatementGroupBy](ctx, sgb.build, sgb, sgb.build.inters, v) } -func (sgb *StatementGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range sgb.fields { - if !statement.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (sgb *StatementGroupBy) sqlScan(ctx context.Context, root *StatementQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(sgb.fns)) + for _, fn := range sgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*sgb.flds)+len(sgb.fns)) + for _, f := range *sgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := sgb.sqlQuery() + selector.GroupBy(selector.Columns(*sgb.flds...)...) 
if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := sgb.driver.Query(ctx, query, args, rows); err != nil { + if err := sgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (sgb *StatementGroupBy) sqlQuery() *sql.Selector { - selector := sgb.sql.Select() - aggregation := make([]string, 0, len(sgb.fns)) - for _, fn := range sgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(sgb.fields)+len(sgb.fns)) - for _, f := range sgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(sgb.fields...)...) -} - // StatementSelect is the builder for selecting fields of Statement entities. type StatementSelect struct { *StatementQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ss *StatementSelect) Aggregate(fns ...AggregateFunc) *StatementSelect { + ss.fns = append(ss.fns, fns...) + return ss } // Scan applies the selector query and scans the result into the given value. func (ss *StatementSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ss.ctx, "Select") if err := ss.prepareQuery(ctx); err != nil { return err } - ss.sql = ss.StatementQuery.sqlQuery(ctx) - return ss.sqlScan(ctx, v) + return scanWithInterceptors[*StatementQuery, *StatementSelect](ctx, ss.StatementQuery, ss, ss.inters, v) } -func (ss *StatementSelect) sqlScan(ctx context.Context, v any) error { +func (ss *StatementSelect) sqlScan(ctx context.Context, root *StatementQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ss.fns)) + for _, fn := range ss.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ss.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := ss.sql.Query() + query, args := selector.Query() if err := ss.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/ent/statement_update.go b/ent/statement_update.go index 5d4b9c19..0c3704e8 100644 --- a/ent/statement_update.go +++ b/ent/statement_update.go @@ -140,40 +140,7 @@ func (su *StatementUpdate) RemoveDsse(d ...*Dsse) *StatementUpdate { // Save executes the query and returns the number of nodes affected by the update operation. 
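The WithNamedSubjects/WithNamedDsse builders and the new Aggregate entry point added above can be exercised as follows. A minimal sketch, assuming an initialized *ent.Client in a hypothetical examples package; the edge names and the filter are illustrative, and the Named* read-back accessors on Statement (mirroring the NamedSubjectDigests accessor shown for Subject further down) are assumed rather than part of this hunk:

package examples // hypothetical package, for illustration only

import (
	"context"

	"github.com/testifysec/archivista/ent"
	"github.com/testifysec/archivista/ent/subject"
)

// loadNamedSubjects counts statements via the new Aggregate helper and then
// eager-loads the "subjects" edge twice under different names.
func loadNamedSubjects(ctx context.Context, client *ent.Client) ([]*ent.Statement, error) {
	// Total row count through Aggregate(ent.Count()) on the select builder.
	if _, err := client.Statement.Query().Aggregate(ent.Count()).Int(ctx); err != nil {
		return nil, err
	}
	// Each named load gets its own SubjectQuery; results are keyed by name and
	// read back through the generated Named* accessors (assumed, not in this hunk).
	return client.Statement.Query().
		WithNamedSubjects("all", func(q *ent.SubjectQuery) {}).
		WithNamedSubjects("https", func(q *ent.SubjectQuery) {
			q.Where(subject.NameHasPrefix("https://")) // illustrative filter
		}).
		All(ctx)
}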
func (su *StatementUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(su.hooks) == 0 { - if err = su.check(); err != nil { - return 0, err - } - affected, err = su.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*StatementMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = su.check(); err != nil { - return 0, err - } - su.mutation = mutation - affected, err = su.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(su.hooks) - 1; i >= 0; i-- { - if su.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = su.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, su.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, su.sqlSave, su.mutation, su.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -209,16 +176,10 @@ func (su *StatementUpdate) check() error { } func (su *StatementUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: statement.Table, - Columns: statement.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: statement.FieldID, - }, - }, + if err := su.check(); err != nil { + return n, err } + _spec := sqlgraph.NewUpdateSpec(statement.Table, statement.Columns, sqlgraph.NewFieldSpec(statement.FieldID, field.TypeInt)) if ps := su.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -227,11 +188,7 @@ func (su *StatementUpdate) sqlSave(ctx context.Context) (n int, err error) { } } if value, ok := su.mutation.Predicate(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: statement.FieldPredicate, - }) + _spec.SetField(statement.FieldPredicate, field.TypeString, value) } if su.mutation.SubjectsCleared() { edge := &sqlgraph.EdgeSpec{ @@ -241,10 +198,7 @@ func (su *StatementUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{statement.SubjectsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subject.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(subject.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -257,10 +211,7 @@ func (su *StatementUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{statement.SubjectsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subject.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(subject.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -276,10 +227,7 @@ func (su *StatementUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{statement.SubjectsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subject.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(subject.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -295,10 +243,7 @@ func (su *StatementUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{statement.AttestationCollectionsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: attestationcollection.FieldID, - }, + IDSpec: 
sqlgraph.NewFieldSpec(attestationcollection.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -311,10 +256,7 @@ func (su *StatementUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{statement.AttestationCollectionsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: attestationcollection.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(attestationcollection.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -330,10 +272,7 @@ func (su *StatementUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{statement.DsseColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: dsse.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(dsse.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -346,10 +285,7 @@ func (su *StatementUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{statement.DsseColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: dsse.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(dsse.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -365,10 +301,7 @@ func (su *StatementUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{statement.DsseColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: dsse.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(dsse.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -384,6 +317,7 @@ func (su *StatementUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + su.mutation.done = true return n, nil } @@ -503,6 +437,12 @@ func (suo *StatementUpdateOne) RemoveDsse(d ...*Dsse) *StatementUpdateOne { return suo.RemoveDsseIDs(ids...) } +// Where appends a list predicates to the StatementUpdate builder. +func (suo *StatementUpdateOne) Where(ps ...predicate.Statement) *StatementUpdateOne { + suo.mutation.Where(ps...) + return suo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (suo *StatementUpdateOne) Select(field string, fields ...string) *StatementUpdateOne { @@ -512,46 +452,7 @@ func (suo *StatementUpdateOne) Select(field string, fields ...string) *Statement // Save executes the query and returns the updated Statement entity. 
func (suo *StatementUpdateOne) Save(ctx context.Context) (*Statement, error) { - var ( - err error - node *Statement - ) - if len(suo.hooks) == 0 { - if err = suo.check(); err != nil { - return nil, err - } - node, err = suo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*StatementMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = suo.check(); err != nil { - return nil, err - } - suo.mutation = mutation - node, err = suo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(suo.hooks) - 1; i >= 0; i-- { - if suo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = suo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, suo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Statement) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from StatementMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, suo.sqlSave, suo.mutation, suo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -587,16 +488,10 @@ func (suo *StatementUpdateOne) check() error { } func (suo *StatementUpdateOne) sqlSave(ctx context.Context) (_node *Statement, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: statement.Table, - Columns: statement.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: statement.FieldID, - }, - }, + if err := suo.check(); err != nil { + return _node, err } + _spec := sqlgraph.NewUpdateSpec(statement.Table, statement.Columns, sqlgraph.NewFieldSpec(statement.FieldID, field.TypeInt)) id, ok := suo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Statement.id" for update`)} @@ -622,11 +517,7 @@ func (suo *StatementUpdateOne) sqlSave(ctx context.Context) (_node *Statement, e } } if value, ok := suo.mutation.Predicate(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: statement.FieldPredicate, - }) + _spec.SetField(statement.FieldPredicate, field.TypeString, value) } if suo.mutation.SubjectsCleared() { edge := &sqlgraph.EdgeSpec{ @@ -636,10 +527,7 @@ func (suo *StatementUpdateOne) sqlSave(ctx context.Context) (_node *Statement, e Columns: []string{statement.SubjectsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subject.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(subject.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -652,10 +540,7 @@ func (suo *StatementUpdateOne) sqlSave(ctx context.Context) (_node *Statement, e Columns: []string{statement.SubjectsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subject.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(subject.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -671,10 +556,7 @@ func (suo *StatementUpdateOne) sqlSave(ctx context.Context) (_node *Statement, e Columns: []string{statement.SubjectsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subject.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(subject.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -690,10 +572,7 @@ func (suo *StatementUpdateOne) sqlSave(ctx context.Context) (_node *Statement, e Columns: 
[]string{statement.AttestationCollectionsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: attestationcollection.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(attestationcollection.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -706,10 +585,7 @@ func (suo *StatementUpdateOne) sqlSave(ctx context.Context) (_node *Statement, e Columns: []string{statement.AttestationCollectionsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: attestationcollection.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(attestationcollection.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -725,10 +601,7 @@ func (suo *StatementUpdateOne) sqlSave(ctx context.Context) (_node *Statement, e Columns: []string{statement.DsseColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: dsse.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(dsse.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -741,10 +614,7 @@ func (suo *StatementUpdateOne) sqlSave(ctx context.Context) (_node *Statement, e Columns: []string{statement.DsseColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: dsse.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(dsse.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -760,10 +630,7 @@ func (suo *StatementUpdateOne) sqlSave(ctx context.Context) (_node *Statement, e Columns: []string{statement.DsseColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: dsse.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(dsse.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -782,5 +649,6 @@ func (suo *StatementUpdateOne) sqlSave(ctx context.Context) (_node *Statement, e } return nil, err } + suo.mutation.done = true return _node, nil } diff --git a/ent/subject.go b/ent/subject.go index 748aafa3..42500b1d 100644 --- a/ent/subject.go +++ b/ent/subject.go @@ -6,6 +6,7 @@ import ( "fmt" "strings" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/testifysec/archivista/ent/statement" "github.com/testifysec/archivista/ent/subject" @@ -22,6 +23,7 @@ type Subject struct { // The values are being populated by the SubjectQuery when eager-loading is set. Edges SubjectEdges `json:"edges"` statement_subjects *int + selectValues sql.SelectValues } // SubjectEdges holds the relations/edges for other nodes in the graph. @@ -34,7 +36,9 @@ type SubjectEdges struct { // type was loaded (or requested) in eager-loading or not. loadedTypes [2]bool // totalCount holds the count of the edges above. 
- totalCount [2]*int + totalCount [2]map[string]int + + namedSubjectDigests map[string][]*SubjectDigest } // SubjectDigestsOrErr returns the SubjectDigests value or an error if the edge @@ -71,7 +75,7 @@ func (*Subject) scanValues(columns []string) ([]any, error) { case subject.ForeignKeys[0]: // statement_subjects values[i] = new(sql.NullInt64) default: - return nil, fmt.Errorf("unexpected column %q for type Subject", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -104,26 +108,34 @@ func (s *Subject) assignValues(columns []string, values []any) error { s.statement_subjects = new(int) *s.statement_subjects = int(value.Int64) } + default: + s.selectValues.Set(columns[i], values[i]) } } return nil } +// Value returns the ent.Value that was dynamically selected and assigned to the Subject. +// This includes values selected through modifiers, order, etc. +func (s *Subject) Value(name string) (ent.Value, error) { + return s.selectValues.Get(name) +} + // QuerySubjectDigests queries the "subject_digests" edge of the Subject entity. func (s *Subject) QuerySubjectDigests() *SubjectDigestQuery { - return (&SubjectClient{config: s.config}).QuerySubjectDigests(s) + return NewSubjectClient(s.config).QuerySubjectDigests(s) } // QueryStatement queries the "statement" edge of the Subject entity. func (s *Subject) QueryStatement() *StatementQuery { - return (&SubjectClient{config: s.config}).QueryStatement(s) + return NewSubjectClient(s.config).QueryStatement(s) } // Update returns a builder for updating this Subject. // Note that you need to call Subject.Unwrap() before calling this method if this Subject // was returned from a transaction, and the transaction was committed or rolled back. func (s *Subject) Update() *SubjectUpdateOne { - return (&SubjectClient{config: s.config}).UpdateOne(s) + return NewSubjectClient(s.config).UpdateOne(s) } // Unwrap unwraps the Subject entity that was returned from a transaction after it was closed, @@ -148,11 +160,29 @@ func (s *Subject) String() string { return builder.String() } -// Subjects is a parsable slice of Subject. -type Subjects []*Subject +// NamedSubjectDigests returns the SubjectDigests named value or an error if the edge was not +// loaded in eager-loading with this name. +func (s *Subject) NamedSubjectDigests(name string) ([]*SubjectDigest, error) { + if s.Edges.namedSubjectDigests == nil { + return nil, &NotLoadedError{edge: name} + } + nodes, ok := s.Edges.namedSubjectDigests[name] + if !ok { + return nil, &NotLoadedError{edge: name} + } + return nodes, nil +} -func (s Subjects) config(cfg config) { - for _i := range s { - s[_i].config = cfg +func (s *Subject) appendNamedSubjectDigests(name string, edges ...*SubjectDigest) { + if s.Edges.namedSubjectDigests == nil { + s.Edges.namedSubjectDigests = make(map[string][]*SubjectDigest) + } + if len(edges) == 0 { + s.Edges.namedSubjectDigests[name] = []*SubjectDigest{} + } else { + s.Edges.namedSubjectDigests[name] = append(s.Edges.namedSubjectDigests[name], edges...) } } + +// Subjects is a parsable slice of Subject. +type Subjects []*Subject diff --git a/ent/subject/subject.go b/ent/subject/subject.go index 02a613ca..ec6aa416 100644 --- a/ent/subject/subject.go +++ b/ent/subject/subject.go @@ -2,6 +2,11 @@ package subject +import ( + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + const ( // Label holds the string label denoting the subject type in the database. Label = "subject" @@ -62,3 +67,51 @@ var ( // NameValidator is a validator for the "name" field. 
It is called by the builders before save. NameValidator func(string) error ) + +// OrderOption defines the ordering options for the Subject queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} + +// BySubjectDigestsCount orders the results by subject_digests count. +func BySubjectDigestsCount(opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborsCount(s, newSubjectDigestsStep(), opts...) + } +} + +// BySubjectDigests orders the results by subject_digests terms. +func BySubjectDigests(term sql.OrderTerm, terms ...sql.OrderTerm) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newSubjectDigestsStep(), append([]sql.OrderTerm{term}, terms...)...) + } +} + +// ByStatementField orders the results by statement field. +func ByStatementField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newStatementStep(), sql.OrderByField(field, opts...)) + } +} +func newSubjectDigestsStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(SubjectDigestsInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.O2M, false, SubjectDigestsTable, SubjectDigestsColumn), + ) +} +func newStatementStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(StatementInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, StatementTable, StatementColumn), + ) +} diff --git a/ent/subject/where.go b/ent/subject/where.go index c808fc12..cfc49197 100644 --- a/ent/subject/where.go +++ b/ent/subject/where.go @@ -10,179 +10,117 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.Subject { - return predicate.Subject(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Subject(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.Subject { - return predicate.Subject(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Subject(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.Subject { - return predicate.Subject(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Subject(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.Subject { - return predicate.Subject(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Subject(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.Subject { - return predicate.Subject(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Subject(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. 
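A minimal sketch of the typed OrderOption helpers generated for Subject above (hypothetical examples package, initialized *ent.Client; sql.OrderDesc comes from entgo.io/ent/dialect/sql):

package examples // hypothetical package, for illustration only

import (
	"context"

	"entgo.io/ent/dialect/sql"

	"github.com/testifysec/archivista/ent"
	"github.com/testifysec/archivista/ent/subject"
)

// topSubjects orders by the subject_digests edge count, then by name, using the
// typed OrderOption helpers generated above.
func topSubjects(ctx context.Context, client *ent.Client) ([]*ent.Subject, error) {
	return client.Subject.Query().
		Order(
			subject.BySubjectDigestsCount(sql.OrderDesc()),
			subject.ByName(),
		).
		Limit(5).
		All(ctx)
}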
func IDGT(id int) predicate.Subject { - return predicate.Subject(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Subject(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.Subject { - return predicate.Subject(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Subject(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.Subject { - return predicate.Subject(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Subject(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.Subject { - return predicate.Subject(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Subject(sql.FieldLTE(FieldID, id)) } // Name applies equality check predicate on the "name" field. It's identical to NameEQ. func Name(v string) predicate.Subject { - return predicate.Subject(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldName), v)) - }) + return predicate.Subject(sql.FieldEQ(FieldName, v)) } // NameEQ applies the EQ predicate on the "name" field. func NameEQ(v string) predicate.Subject { - return predicate.Subject(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldName), v)) - }) + return predicate.Subject(sql.FieldEQ(FieldName, v)) } // NameNEQ applies the NEQ predicate on the "name" field. func NameNEQ(v string) predicate.Subject { - return predicate.Subject(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldName), v)) - }) + return predicate.Subject(sql.FieldNEQ(FieldName, v)) } // NameIn applies the In predicate on the "name" field. func NameIn(vs ...string) predicate.Subject { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Subject(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldName), v...)) - }) + return predicate.Subject(sql.FieldIn(FieldName, vs...)) } // NameNotIn applies the NotIn predicate on the "name" field. func NameNotIn(vs ...string) predicate.Subject { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Subject(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldName), v...)) - }) + return predicate.Subject(sql.FieldNotIn(FieldName, vs...)) } // NameGT applies the GT predicate on the "name" field. func NameGT(v string) predicate.Subject { - return predicate.Subject(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldName), v)) - }) + return predicate.Subject(sql.FieldGT(FieldName, v)) } // NameGTE applies the GTE predicate on the "name" field. func NameGTE(v string) predicate.Subject { - return predicate.Subject(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldName), v)) - }) + return predicate.Subject(sql.FieldGTE(FieldName, v)) } // NameLT applies the LT predicate on the "name" field. func NameLT(v string) predicate.Subject { - return predicate.Subject(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldName), v)) - }) + return predicate.Subject(sql.FieldLT(FieldName, v)) } // NameLTE applies the LTE predicate on the "name" field. func NameLTE(v string) predicate.Subject { - return predicate.Subject(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldName), v)) - }) + return predicate.Subject(sql.FieldLTE(FieldName, v)) } // NameContains applies the Contains predicate on the "name" field. 
func NameContains(v string) predicate.Subject { - return predicate.Subject(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldName), v)) - }) + return predicate.Subject(sql.FieldContains(FieldName, v)) } // NameHasPrefix applies the HasPrefix predicate on the "name" field. func NameHasPrefix(v string) predicate.Subject { - return predicate.Subject(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldName), v)) - }) + return predicate.Subject(sql.FieldHasPrefix(FieldName, v)) } // NameHasSuffix applies the HasSuffix predicate on the "name" field. func NameHasSuffix(v string) predicate.Subject { - return predicate.Subject(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldName), v)) - }) + return predicate.Subject(sql.FieldHasSuffix(FieldName, v)) } // NameEqualFold applies the EqualFold predicate on the "name" field. func NameEqualFold(v string) predicate.Subject { - return predicate.Subject(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldName), v)) - }) + return predicate.Subject(sql.FieldEqualFold(FieldName, v)) } // NameContainsFold applies the ContainsFold predicate on the "name" field. func NameContainsFold(v string) predicate.Subject { - return predicate.Subject(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldName), v)) - }) + return predicate.Subject(sql.FieldContainsFold(FieldName, v)) } // HasSubjectDigests applies the HasEdge predicate on the "subject_digests" edge. @@ -190,7 +128,6 @@ func HasSubjectDigests() predicate.Subject { return predicate.Subject(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(SubjectDigestsTable, FieldID), sqlgraph.Edge(sqlgraph.O2M, false, SubjectDigestsTable, SubjectDigestsColumn), ) sqlgraph.HasNeighbors(s, step) @@ -200,11 +137,7 @@ func HasSubjectDigests() predicate.Subject { // HasSubjectDigestsWith applies the HasEdge predicate on the "subject_digests" edge with a given conditions (other predicates). func HasSubjectDigestsWith(preds ...predicate.SubjectDigest) predicate.Subject { return predicate.Subject(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(SubjectDigestsInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.O2M, false, SubjectDigestsTable, SubjectDigestsColumn), - ) + step := newSubjectDigestsStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -218,7 +151,6 @@ func HasStatement() predicate.Subject { return predicate.Subject(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(StatementTable, FieldID), sqlgraph.Edge(sqlgraph.M2O, true, StatementTable, StatementColumn), ) sqlgraph.HasNeighbors(s, step) @@ -228,11 +160,7 @@ func HasStatement() predicate.Subject { // HasStatementWith applies the HasEdge predicate on the "statement" edge with a given conditions (other predicates). func HasStatementWith(preds ...predicate.Statement) predicate.Subject { return predicate.Subject(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(StatementInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.M2O, true, StatementTable, StatementColumn), - ) + step := newStatementStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -243,32 +171,15 @@ func HasStatementWith(preds ...predicate.Statement) predicate.Subject { // And groups predicates with the AND operator between them. 
func And(predicates ...predicate.Subject) predicate.Subject { - return predicate.Subject(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Subject(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. func Or(predicates ...predicate.Subject) predicate.Subject { - return predicate.Subject(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Subject(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.Subject) predicate.Subject { - return predicate.Subject(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.Subject(sql.NotPredicates(p)) } diff --git a/ent/subject_create.go b/ent/subject_create.go index 3d3abcca..f25b56b8 100644 --- a/ent/subject_create.go +++ b/ent/subject_create.go @@ -68,49 +68,7 @@ func (sc *SubjectCreate) Mutation() *SubjectMutation { // Save creates the Subject in the database. func (sc *SubjectCreate) Save(ctx context.Context) (*Subject, error) { - var ( - err error - node *Subject - ) - if len(sc.hooks) == 0 { - if err = sc.check(); err != nil { - return nil, err - } - node, err = sc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*SubjectMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = sc.check(); err != nil { - return nil, err - } - sc.mutation = mutation - if node, err = sc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(sc.hooks) - 1; i >= 0; i-- { - if sc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = sc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, sc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Subject) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from SubjectMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, sc.sqlSave, sc.mutation, sc.hooks) } // SaveX calls Save and panics if Save returns an error. 
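The predicate rewrites above collapse each hand-rolled closure into a single sql.Field* helper, and Save on the create builder now delegates to the shared withHooks runner, so existing call sites keep compiling unchanged. A minimal sketch of how the regenerated Subject predicates compose at a call site; the helper name, the registry prefix, and the sha256 filter are illustrative assumptions, not taken from this change:

	// Illustrative sketch only; assumes an already-opened *ent.Client.
	package example

	import (
		"context"

		"github.com/testifysec/archivista/ent"
		"github.com/testifysec/archivista/ent/subject"
		"github.com/testifysec/archivista/ent/subjectdigest"
	)

	// findSubjects combines field predicates with an edge predicate, all of which
	// are now thin wrappers over sql.Field* helpers and the generated edge step.
	func findSubjects(ctx context.Context, client *ent.Client) ([]*ent.Subject, error) {
		return client.Subject.Query().
			Where(
				subject.Or(
					subject.NameHasPrefix("registry.example.com/"),
					subject.NameContainsFold("alpine"),
				),
				subject.HasSubjectDigestsWith(subjectdigest.AlgorithmEQ("sha256")),
			).
			All(ctx)
	}
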
@@ -149,6 +107,9 @@ func (sc *SubjectCreate) check() error { } func (sc *SubjectCreate) sqlSave(ctx context.Context) (*Subject, error) { + if err := sc.check(); err != nil { + return nil, err + } _node, _spec := sc.createSpec() if err := sqlgraph.CreateNode(ctx, sc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -158,26 +119,18 @@ func (sc *SubjectCreate) sqlSave(ctx context.Context) (*Subject, error) { } id := _spec.ID.Value.(int64) _node.ID = int(id) + sc.mutation.id = &_node.ID + sc.mutation.done = true return _node, nil } func (sc *SubjectCreate) createSpec() (*Subject, *sqlgraph.CreateSpec) { var ( _node = &Subject{config: sc.config} - _spec = &sqlgraph.CreateSpec{ - Table: subject.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subject.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(subject.Table, sqlgraph.NewFieldSpec(subject.FieldID, field.TypeInt)) ) if value, ok := sc.mutation.Name(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: subject.FieldName, - }) + _spec.SetField(subject.FieldName, field.TypeString, value) _node.Name = value } if nodes := sc.mutation.SubjectDigestsIDs(); len(nodes) > 0 { @@ -188,10 +141,7 @@ func (sc *SubjectCreate) createSpec() (*Subject, *sqlgraph.CreateSpec) { Columns: []string{subject.SubjectDigestsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subjectdigest.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(subjectdigest.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -207,10 +157,7 @@ func (sc *SubjectCreate) createSpec() (*Subject, *sqlgraph.CreateSpec) { Columns: []string{subject.StatementColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: statement.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(statement.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -225,11 +172,15 @@ func (sc *SubjectCreate) createSpec() (*Subject, *sqlgraph.CreateSpec) { // SubjectCreateBulk is the builder for creating many Subject entities in bulk. type SubjectCreateBulk struct { config + err error builders []*SubjectCreate } // Save creates the Subject entities in the database. func (scb *SubjectCreateBulk) Save(ctx context.Context) ([]*Subject, error) { + if scb.err != nil { + return nil, scb.err + } specs := make([]*sqlgraph.CreateSpec, len(scb.builders)) nodes := make([]*Subject, len(scb.builders)) mutators := make([]Mutator, len(scb.builders)) @@ -245,8 +196,8 @@ func (scb *SubjectCreateBulk) Save(ctx context.Context) ([]*Subject, error) { return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, scb.builders[i+1].mutation) } else { diff --git a/ent/subject_delete.go b/ent/subject_delete.go index a3874f5f..438c351a 100644 --- a/ent/subject_delete.go +++ b/ent/subject_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (sd *SubjectDelete) Where(ps ...predicate.Subject) *SubjectDelete { // Exec executes the deletion query and returns how many vertices were deleted. 
func (sd *SubjectDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(sd.hooks) == 0 { - affected, err = sd.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*SubjectMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - sd.mutation = mutation - affected, err = sd.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(sd.hooks) - 1; i >= 0; i-- { - if sd.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = sd.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, sd.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, sd.sqlExec, sd.mutation, sd.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (sd *SubjectDelete) ExecX(ctx context.Context) int { } func (sd *SubjectDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: subject.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subject.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(subject.Table, sqlgraph.NewFieldSpec(subject.FieldID, field.TypeInt)) if ps := sd.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (sd *SubjectDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + sd.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type SubjectDeleteOne struct { sd *SubjectDelete } +// Where appends a list predicates to the SubjectDelete builder. +func (sdo *SubjectDeleteOne) Where(ps ...predicate.Subject) *SubjectDeleteOne { + sdo.sd.mutation.Where(ps...) + return sdo +} + // Exec executes the deletion query. func (sdo *SubjectDeleteOne) Exec(ctx context.Context) error { n, err := sdo.sd.Exec(ctx) @@ -111,5 +82,7 @@ func (sdo *SubjectDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (sdo *SubjectDeleteOne) ExecX(ctx context.Context) { - sdo.sd.ExecX(ctx) + if err := sdo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/ent/subject_query.go b/ent/subject_query.go index 840da6a3..919d165e 100644 --- a/ent/subject_query.go +++ b/ent/subject_query.go @@ -20,17 +20,16 @@ import ( // SubjectQuery is the builder for querying Subject entities. type SubjectQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string - predicates []predicate.Subject - withSubjectDigests *SubjectDigestQuery - withStatement *StatementQuery - withFKs bool - modifiers []func(*sql.Selector) - loadTotal []func(context.Context, []*Subject) error + ctx *QueryContext + order []subject.OrderOption + inters []Interceptor + predicates []predicate.Subject + withSubjectDigests *SubjectDigestQuery + withStatement *StatementQuery + withFKs bool + modifiers []func(*sql.Selector) + loadTotal []func(context.Context, []*Subject) error + withNamedSubjectDigests map[string]*SubjectDigestQuery // intermediate query (i.e. traversal path). sql *sql.Selector path func(context.Context) (*sql.Selector, error) @@ -42,34 +41,34 @@ func (sq *SubjectQuery) Where(ps ...predicate.Subject) *SubjectQuery { return sq } -// Limit adds a limit step to the query. 
+// Limit the number of records to be returned by this query. func (sq *SubjectQuery) Limit(limit int) *SubjectQuery { - sq.limit = &limit + sq.ctx.Limit = &limit return sq } -// Offset adds an offset step to the query. +// Offset to start from. func (sq *SubjectQuery) Offset(offset int) *SubjectQuery { - sq.offset = &offset + sq.ctx.Offset = &offset return sq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (sq *SubjectQuery) Unique(unique bool) *SubjectQuery { - sq.unique = &unique + sq.ctx.Unique = &unique return sq } -// Order adds an order step to the query. -func (sq *SubjectQuery) Order(o ...OrderFunc) *SubjectQuery { +// Order specifies how the records should be ordered. +func (sq *SubjectQuery) Order(o ...subject.OrderOption) *SubjectQuery { sq.order = append(sq.order, o...) return sq } // QuerySubjectDigests chains the current query on the "subject_digests" edge. func (sq *SubjectQuery) QuerySubjectDigests() *SubjectDigestQuery { - query := &SubjectDigestQuery{config: sq.config} + query := (&SubjectDigestClient{config: sq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := sq.prepareQuery(ctx); err != nil { return nil, err @@ -91,7 +90,7 @@ func (sq *SubjectQuery) QuerySubjectDigests() *SubjectDigestQuery { // QueryStatement chains the current query on the "statement" edge. func (sq *SubjectQuery) QueryStatement() *StatementQuery { - query := &StatementQuery{config: sq.config} + query := (&StatementClient{config: sq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := sq.prepareQuery(ctx); err != nil { return nil, err @@ -114,7 +113,7 @@ func (sq *SubjectQuery) QueryStatement() *StatementQuery { // First returns the first Subject entity from the query. // Returns a *NotFoundError when no Subject was found. func (sq *SubjectQuery) First(ctx context.Context) (*Subject, error) { - nodes, err := sq.Limit(1).All(ctx) + nodes, err := sq.Limit(1).All(setContextOp(ctx, sq.ctx, "First")) if err != nil { return nil, err } @@ -137,7 +136,7 @@ func (sq *SubjectQuery) FirstX(ctx context.Context) *Subject { // Returns a *NotFoundError when no Subject ID was found. func (sq *SubjectQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = sq.Limit(1).IDs(ctx); err != nil { + if ids, err = sq.Limit(1).IDs(setContextOp(ctx, sq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -160,7 +159,7 @@ func (sq *SubjectQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Subject entity is found. // Returns a *NotFoundError when no Subject entities are found. func (sq *SubjectQuery) Only(ctx context.Context) (*Subject, error) { - nodes, err := sq.Limit(2).All(ctx) + nodes, err := sq.Limit(2).All(setContextOp(ctx, sq.ctx, "Only")) if err != nil { return nil, err } @@ -188,7 +187,7 @@ func (sq *SubjectQuery) OnlyX(ctx context.Context) *Subject { // Returns a *NotFoundError when no entities are found. func (sq *SubjectQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = sq.Limit(2).IDs(ctx); err != nil { + if ids, err = sq.Limit(2).IDs(setContextOp(ctx, sq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -213,10 +212,12 @@ func (sq *SubjectQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Subjects. 
func (sq *SubjectQuery) All(ctx context.Context) ([]*Subject, error) { + ctx = setContextOp(ctx, sq.ctx, "All") if err := sq.prepareQuery(ctx); err != nil { return nil, err } - return sq.sqlAll(ctx) + qr := querierAll[[]*Subject, *SubjectQuery]() + return withInterceptors[[]*Subject](ctx, sq, qr, sq.inters) } // AllX is like All, but panics if an error occurs. @@ -229,9 +230,12 @@ func (sq *SubjectQuery) AllX(ctx context.Context) []*Subject { } // IDs executes the query and returns a list of Subject IDs. -func (sq *SubjectQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := sq.Select(subject.FieldID).Scan(ctx, &ids); err != nil { +func (sq *SubjectQuery) IDs(ctx context.Context) (ids []int, err error) { + if sq.ctx.Unique == nil && sq.path != nil { + sq.Unique(true) + } + ctx = setContextOp(ctx, sq.ctx, "IDs") + if err = sq.Select(subject.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -248,10 +252,11 @@ func (sq *SubjectQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (sq *SubjectQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, sq.ctx, "Count") if err := sq.prepareQuery(ctx); err != nil { return 0, err } - return sq.sqlCount(ctx) + return withInterceptors[int](ctx, sq, querierCount[*SubjectQuery](), sq.inters) } // CountX is like Count, but panics if an error occurs. @@ -265,10 +270,15 @@ func (sq *SubjectQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (sq *SubjectQuery) Exist(ctx context.Context) (bool, error) { - if err := sq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, sq.ctx, "Exist") + switch _, err := sq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return sq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -288,23 +298,22 @@ func (sq *SubjectQuery) Clone() *SubjectQuery { } return &SubjectQuery{ config: sq.config, - limit: sq.limit, - offset: sq.offset, - order: append([]OrderFunc{}, sq.order...), + ctx: sq.ctx.Clone(), + order: append([]subject.OrderOption{}, sq.order...), + inters: append([]Interceptor{}, sq.inters...), predicates: append([]predicate.Subject{}, sq.predicates...), withSubjectDigests: sq.withSubjectDigests.Clone(), withStatement: sq.withStatement.Clone(), // clone intermediate query. - sql: sq.sql.Clone(), - path: sq.path, - unique: sq.unique, + sql: sq.sql.Clone(), + path: sq.path, } } // WithSubjectDigests tells the query-builder to eager-load the nodes that are connected to // the "subject_digests" edge. The optional arguments are used to configure the query builder of the edge. func (sq *SubjectQuery) WithSubjectDigests(opts ...func(*SubjectDigestQuery)) *SubjectQuery { - query := &SubjectDigestQuery{config: sq.config} + query := (&SubjectDigestClient{config: sq.config}).Query() for _, opt := range opts { opt(query) } @@ -315,7 +324,7 @@ func (sq *SubjectQuery) WithSubjectDigests(opts ...func(*SubjectDigestQuery)) *S // WithStatement tells the query-builder to eager-load the nodes that are connected to // the "statement" edge. The optional arguments are used to configure the query builder of the edge. 
func (sq *SubjectQuery) WithStatement(opts ...func(*StatementQuery)) *SubjectQuery { - query := &StatementQuery{config: sq.config} + query := (&StatementClient{config: sq.config}).Query() for _, opt := range opts { opt(query) } @@ -338,16 +347,11 @@ func (sq *SubjectQuery) WithStatement(opts ...func(*StatementQuery)) *SubjectQue // Aggregate(ent.Count()). // Scan(ctx, &v) func (sq *SubjectQuery) GroupBy(field string, fields ...string) *SubjectGroupBy { - grbuild := &SubjectGroupBy{config: sq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := sq.prepareQuery(ctx); err != nil { - return nil, err - } - return sq.sqlQuery(ctx), nil - } + sq.ctx.Fields = append([]string{field}, fields...) + grbuild := &SubjectGroupBy{build: sq} + grbuild.flds = &sq.ctx.Fields grbuild.label = subject.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -364,15 +368,30 @@ func (sq *SubjectQuery) GroupBy(field string, fields ...string) *SubjectGroupBy // Select(subject.FieldName). // Scan(ctx, &v) func (sq *SubjectQuery) Select(fields ...string) *SubjectSelect { - sq.fields = append(sq.fields, fields...) - selbuild := &SubjectSelect{SubjectQuery: sq} - selbuild.label = subject.Label - selbuild.flds, selbuild.scan = &sq.fields, selbuild.Scan - return selbuild + sq.ctx.Fields = append(sq.ctx.Fields, fields...) + sbuild := &SubjectSelect{SubjectQuery: sq} + sbuild.label = subject.Label + sbuild.flds, sbuild.scan = &sq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a SubjectSelect configured with the given aggregations. +func (sq *SubjectQuery) Aggregate(fns ...AggregateFunc) *SubjectSelect { + return sq.Select().Aggregate(fns...) 
} func (sq *SubjectQuery) prepareQuery(ctx context.Context) error { - for _, f := range sq.fields { + for _, inter := range sq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, sq); err != nil { + return err + } + } + } + for _, f := range sq.ctx.Fields { if !subject.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -437,6 +456,13 @@ func (sq *SubjectQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*Subj return nil, err } } + for name, query := range sq.withNamedSubjectDigests { + if err := sq.loadSubjectDigests(ctx, query, nodes, + func(n *Subject) { n.appendNamedSubjectDigests(name) }, + func(n *Subject, e *SubjectDigest) { n.appendNamedSubjectDigests(name, e) }); err != nil { + return nil, err + } + } for i := range sq.loadTotal { if err := sq.loadTotal[i](ctx, nodes); err != nil { return nil, err @@ -457,7 +483,7 @@ func (sq *SubjectQuery) loadSubjectDigests(ctx context.Context, query *SubjectDi } query.withFKs = true query.Where(predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.InValues(subject.SubjectDigestsColumn, fks...)) + s.Where(sql.InValues(s.C(subject.SubjectDigestsColumn), fks...)) })) neighbors, err := query.All(ctx) if err != nil { @@ -470,7 +496,7 @@ func (sq *SubjectQuery) loadSubjectDigests(ctx context.Context, query *SubjectDi } node, ok := nodeids[*fk] if !ok { - return fmt.Errorf(`unexpected foreign-key "subject_subject_digests" returned %v for node %v`, *fk, n.ID) + return fmt.Errorf(`unexpected referenced foreign-key "subject_subject_digests" returned %v for node %v`, *fk, n.ID) } assign(node, n) } @@ -489,6 +515,9 @@ func (sq *SubjectQuery) loadStatement(ctx context.Context, query *StatementQuery } nodeids[fk] = append(nodeids[fk], nodes[i]) } + if len(ids) == 0 { + return nil + } query.Where(statement.IDIn(ids...)) neighbors, err := query.All(ctx) if err != nil { @@ -511,41 +540,22 @@ func (sq *SubjectQuery) sqlCount(ctx context.Context) (int, error) { if len(sq.modifiers) > 0 { _spec.Modifiers = sq.modifiers } - _spec.Node.Columns = sq.fields - if len(sq.fields) > 0 { - _spec.Unique = sq.unique != nil && *sq.unique + _spec.Node.Columns = sq.ctx.Fields + if len(sq.ctx.Fields) > 0 { + _spec.Unique = sq.ctx.Unique != nil && *sq.ctx.Unique } return sqlgraph.CountNodes(ctx, sq.driver, _spec) } -func (sq *SubjectQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := sq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (sq *SubjectQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: subject.Table, - Columns: subject.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subject.FieldID, - }, - }, - From: sq.sql, - Unique: true, - } - if unique := sq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(subject.Table, subject.Columns, sqlgraph.NewFieldSpec(subject.FieldID, field.TypeInt)) + _spec.From = sq.sql + if unique := sq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if sq.path != nil { + _spec.Unique = true } - if fields := sq.fields; len(fields) > 0 { + if fields := sq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, subject.FieldID) 
for i := range fields { @@ -561,10 +571,10 @@ func (sq *SubjectQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := sq.limit; limit != nil { + if limit := sq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := sq.offset; offset != nil { + if offset := sq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := sq.order; len(ps) > 0 { @@ -580,7 +590,7 @@ func (sq *SubjectQuery) querySpec() *sqlgraph.QuerySpec { func (sq *SubjectQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(sq.driver.Dialect()) t1 := builder.Table(subject.Table) - columns := sq.fields + columns := sq.ctx.Fields if len(columns) == 0 { columns = subject.Columns } @@ -589,7 +599,7 @@ func (sq *SubjectQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = sq.sql selector.Select(selector.Columns(columns...)...) } - if sq.unique != nil && *sq.unique { + if sq.ctx.Unique != nil && *sq.ctx.Unique { selector.Distinct() } for _, p := range sq.predicates { @@ -598,26 +608,35 @@ func (sq *SubjectQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range sq.order { p(selector) } - if offset := sq.offset; offset != nil { + if offset := sq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := sq.limit; limit != nil { + if limit := sq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector } +// WithNamedSubjectDigests tells the query-builder to eager-load the nodes that are connected to the "subject_digests" +// edge with the given name. The optional arguments are used to configure the query builder of the edge. +func (sq *SubjectQuery) WithNamedSubjectDigests(name string, opts ...func(*SubjectDigestQuery)) *SubjectQuery { + query := (&SubjectDigestClient{config: sq.config}).Query() + for _, opt := range opts { + opt(query) + } + if sq.withNamedSubjectDigests == nil { + sq.withNamedSubjectDigests = make(map[string]*SubjectDigestQuery) + } + sq.withNamedSubjectDigests[name] = query + return sq +} + // SubjectGroupBy is the group-by builder for Subject entities. type SubjectGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *SubjectQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -626,74 +645,77 @@ func (sgb *SubjectGroupBy) Aggregate(fns ...AggregateFunc) *SubjectGroupBy { return sgb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. 
func (sgb *SubjectGroupBy) Scan(ctx context.Context, v any) error { - query, err := sgb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, sgb.build.ctx, "GroupBy") + if err := sgb.build.prepareQuery(ctx); err != nil { return err } - sgb.sql = query - return sgb.sqlScan(ctx, v) + return scanWithInterceptors[*SubjectQuery, *SubjectGroupBy](ctx, sgb.build, sgb, sgb.build.inters, v) } -func (sgb *SubjectGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range sgb.fields { - if !subject.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (sgb *SubjectGroupBy) sqlScan(ctx context.Context, root *SubjectQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(sgb.fns)) + for _, fn := range sgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*sgb.flds)+len(sgb.fns)) + for _, f := range *sgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := sgb.sqlQuery() + selector.GroupBy(selector.Columns(*sgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := sgb.driver.Query(ctx, query, args, rows); err != nil { + if err := sgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (sgb *SubjectGroupBy) sqlQuery() *sql.Selector { - selector := sgb.sql.Select() - aggregation := make([]string, 0, len(sgb.fns)) - for _, fn := range sgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(sgb.fields)+len(sgb.fns)) - for _, f := range sgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(sgb.fields...)...) -} - // SubjectSelect is the builder for selecting fields of Subject entities. type SubjectSelect struct { *SubjectQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ss *SubjectSelect) Aggregate(fns ...AggregateFunc) *SubjectSelect { + ss.fns = append(ss.fns, fns...) + return ss } // Scan applies the selector query and scans the result into the given value. func (ss *SubjectSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ss.ctx, "Select") if err := ss.prepareQuery(ctx); err != nil { return err } - ss.sql = ss.SubjectQuery.sqlQuery(ctx) - return ss.sqlScan(ctx, v) + return scanWithInterceptors[*SubjectQuery, *SubjectSelect](ctx, ss.SubjectQuery, ss, ss.inters, v) } -func (ss *SubjectSelect) sqlScan(ctx context.Context, v any) error { +func (ss *SubjectSelect) sqlScan(ctx context.Context, root *SubjectQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ss.fns)) + for _, fn := range ss.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ss.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) 
+ case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := ss.sql.Query() + query, args := selector.Query() if err := ss.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/ent/subject_update.go b/ent/subject_update.go index 8951f9c2..50daa4eb 100644 --- a/ent/subject_update.go +++ b/ent/subject_update.go @@ -103,40 +103,7 @@ func (su *SubjectUpdate) ClearStatement() *SubjectUpdate { // Save executes the query and returns the number of nodes affected by the update operation. func (su *SubjectUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(su.hooks) == 0 { - if err = su.check(); err != nil { - return 0, err - } - affected, err = su.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*SubjectMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = su.check(); err != nil { - return 0, err - } - su.mutation = mutation - affected, err = su.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(su.hooks) - 1; i >= 0; i-- { - if su.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = su.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, su.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, su.sqlSave, su.mutation, su.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -172,16 +139,10 @@ func (su *SubjectUpdate) check() error { } func (su *SubjectUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: subject.Table, - Columns: subject.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subject.FieldID, - }, - }, + if err := su.check(); err != nil { + return n, err } + _spec := sqlgraph.NewUpdateSpec(subject.Table, subject.Columns, sqlgraph.NewFieldSpec(subject.FieldID, field.TypeInt)) if ps := su.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -190,11 +151,7 @@ func (su *SubjectUpdate) sqlSave(ctx context.Context) (n int, err error) { } } if value, ok := su.mutation.Name(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: subject.FieldName, - }) + _spec.SetField(subject.FieldName, field.TypeString, value) } if su.mutation.SubjectDigestsCleared() { edge := &sqlgraph.EdgeSpec{ @@ -204,10 +161,7 @@ func (su *SubjectUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{subject.SubjectDigestsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subjectdigest.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(subjectdigest.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -220,10 +174,7 @@ func (su *SubjectUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{subject.SubjectDigestsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subjectdigest.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(subjectdigest.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -239,10 +190,7 @@ func (su *SubjectUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{subject.SubjectDigestsColumn}, 
Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subjectdigest.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(subjectdigest.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -258,10 +206,7 @@ func (su *SubjectUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{subject.StatementColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: statement.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(statement.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -274,10 +219,7 @@ func (su *SubjectUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{subject.StatementColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: statement.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(statement.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -293,6 +235,7 @@ func (su *SubjectUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + su.mutation.done = true return n, nil } @@ -376,6 +319,12 @@ func (suo *SubjectUpdateOne) ClearStatement() *SubjectUpdateOne { return suo } +// Where appends a list predicates to the SubjectUpdate builder. +func (suo *SubjectUpdateOne) Where(ps ...predicate.Subject) *SubjectUpdateOne { + suo.mutation.Where(ps...) + return suo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (suo *SubjectUpdateOne) Select(field string, fields ...string) *SubjectUpdateOne { @@ -385,46 +334,7 @@ func (suo *SubjectUpdateOne) Select(field string, fields ...string) *SubjectUpda // Save executes the query and returns the updated Subject entity. func (suo *SubjectUpdateOne) Save(ctx context.Context) (*Subject, error) { - var ( - err error - node *Subject - ) - if len(suo.hooks) == 0 { - if err = suo.check(); err != nil { - return nil, err - } - node, err = suo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*SubjectMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = suo.check(); err != nil { - return nil, err - } - suo.mutation = mutation - node, err = suo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(suo.hooks) - 1; i >= 0; i-- { - if suo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = suo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, suo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Subject) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from SubjectMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, suo.sqlSave, suo.mutation, suo.hooks) } // SaveX is like Save, but panics if an error occurs. 
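Besides routing Save through withHooks, SubjectUpdateOne gains a Where method, so a single-row update can carry extra guard predicates. A sketch of what that enables; the function name and the guard condition are illustrative assumptions, not part of this change:

	// Illustrative sketch only; assumes an already-opened *ent.Client and the
	// imports from the earlier sketch, including "github.com/testifysec/archivista/ent/subject".
	// renameSubject updates one Subject by ID, but only when the stored name differs;
	// if the predicate filters the row out, ent reports a not-found error for the update.
	func renameSubject(ctx context.Context, client *ent.Client, id int, newName string) (*ent.Subject, error) {
		return client.Subject.UpdateOneID(id).
			Where(subject.NameNEQ(newName)).
			SetName(newName).
			Save(ctx)
	}
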
@@ -460,16 +370,10 @@ func (suo *SubjectUpdateOne) check() error { } func (suo *SubjectUpdateOne) sqlSave(ctx context.Context) (_node *Subject, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: subject.Table, - Columns: subject.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subject.FieldID, - }, - }, + if err := suo.check(); err != nil { + return _node, err } + _spec := sqlgraph.NewUpdateSpec(subject.Table, subject.Columns, sqlgraph.NewFieldSpec(subject.FieldID, field.TypeInt)) id, ok := suo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Subject.id" for update`)} @@ -495,11 +399,7 @@ func (suo *SubjectUpdateOne) sqlSave(ctx context.Context) (_node *Subject, err e } } if value, ok := suo.mutation.Name(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: subject.FieldName, - }) + _spec.SetField(subject.FieldName, field.TypeString, value) } if suo.mutation.SubjectDigestsCleared() { edge := &sqlgraph.EdgeSpec{ @@ -509,10 +409,7 @@ func (suo *SubjectUpdateOne) sqlSave(ctx context.Context) (_node *Subject, err e Columns: []string{subject.SubjectDigestsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subjectdigest.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(subjectdigest.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -525,10 +422,7 @@ func (suo *SubjectUpdateOne) sqlSave(ctx context.Context) (_node *Subject, err e Columns: []string{subject.SubjectDigestsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subjectdigest.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(subjectdigest.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -544,10 +438,7 @@ func (suo *SubjectUpdateOne) sqlSave(ctx context.Context) (_node *Subject, err e Columns: []string{subject.SubjectDigestsColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subjectdigest.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(subjectdigest.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -563,10 +454,7 @@ func (suo *SubjectUpdateOne) sqlSave(ctx context.Context) (_node *Subject, err e Columns: []string{subject.StatementColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: statement.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(statement.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -579,10 +467,7 @@ func (suo *SubjectUpdateOne) sqlSave(ctx context.Context) (_node *Subject, err e Columns: []string{subject.StatementColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: statement.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(statement.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -601,5 +486,6 @@ func (suo *SubjectUpdateOne) sqlSave(ctx context.Context) (_node *Subject, err e } return nil, err } + suo.mutation.done = true return _node, nil } diff --git a/ent/subjectdigest.go b/ent/subjectdigest.go index d8323e1f..60e38bc1 100644 --- a/ent/subjectdigest.go +++ b/ent/subjectdigest.go @@ -6,6 +6,7 @@ import ( "fmt" "strings" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/testifysec/archivista/ent/subject" 
"github.com/testifysec/archivista/ent/subjectdigest" @@ -24,6 +25,7 @@ type SubjectDigest struct { // The values are being populated by the SubjectDigestQuery when eager-loading is set. Edges SubjectDigestEdges `json:"edges"` subject_subject_digests *int + selectValues sql.SelectValues } // SubjectDigestEdges holds the relations/edges for other nodes in the graph. @@ -34,7 +36,7 @@ type SubjectDigestEdges struct { // type was loaded (or requested) in eager-loading or not. loadedTypes [1]bool // totalCount holds the count of the edges above. - totalCount [1]*int + totalCount [1]map[string]int } // SubjectOrErr returns the Subject value or an error if the edge @@ -62,7 +64,7 @@ func (*SubjectDigest) scanValues(columns []string) ([]any, error) { case subjectdigest.ForeignKeys[0]: // subject_subject_digests values[i] = new(sql.NullInt64) default: - return nil, fmt.Errorf("unexpected column %q for type SubjectDigest", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -101,21 +103,29 @@ func (sd *SubjectDigest) assignValues(columns []string, values []any) error { sd.subject_subject_digests = new(int) *sd.subject_subject_digests = int(value.Int64) } + default: + sd.selectValues.Set(columns[i], values[i]) } } return nil } +// GetValue returns the ent.Value that was dynamically selected and assigned to the SubjectDigest. +// This includes values selected through modifiers, order, etc. +func (sd *SubjectDigest) GetValue(name string) (ent.Value, error) { + return sd.selectValues.Get(name) +} + // QuerySubject queries the "subject" edge of the SubjectDigest entity. func (sd *SubjectDigest) QuerySubject() *SubjectQuery { - return (&SubjectDigestClient{config: sd.config}).QuerySubject(sd) + return NewSubjectDigestClient(sd.config).QuerySubject(sd) } // Update returns a builder for updating this SubjectDigest. // Note that you need to call SubjectDigest.Unwrap() before calling this method if this SubjectDigest // was returned from a transaction, and the transaction was committed or rolled back. func (sd *SubjectDigest) Update() *SubjectDigestUpdateOne { - return (&SubjectDigestClient{config: sd.config}).UpdateOne(sd) + return NewSubjectDigestClient(sd.config).UpdateOne(sd) } // Unwrap unwraps the SubjectDigest entity that was returned from a transaction after it was closed, @@ -145,9 +155,3 @@ func (sd *SubjectDigest) String() string { // SubjectDigests is a parsable slice of SubjectDigest. type SubjectDigests []*SubjectDigest - -func (sd SubjectDigests) config(cfg config) { - for _i := range sd { - sd[_i].config = cfg - } -} diff --git a/ent/subjectdigest/subjectdigest.go b/ent/subjectdigest/subjectdigest.go index 8b546937..2d540aba 100644 --- a/ent/subjectdigest/subjectdigest.go +++ b/ent/subjectdigest/subjectdigest.go @@ -2,6 +2,11 @@ package subjectdigest +import ( + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + const ( // Label holds the string label denoting the subjectdigest type in the database. Label = "subject_digest" @@ -58,3 +63,35 @@ var ( // ValueValidator is a validator for the "value" field. It is called by the builders before save. ValueValidator func(string) error ) + +// OrderOption defines the ordering options for the SubjectDigest queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByAlgorithm orders the results by the algorithm field. 
+func ByAlgorithm(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldAlgorithm, opts...).ToFunc() +} + +// ByValue orders the results by the value field. +func ByValue(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldValue, opts...).ToFunc() +} + +// BySubjectField orders the results by subject field. +func BySubjectField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newSubjectStep(), sql.OrderByField(field, opts...)) + } +} +func newSubjectStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(SubjectInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, SubjectTable, SubjectColumn), + ) +} diff --git a/ent/subjectdigest/where.go b/ent/subjectdigest/where.go index a914f2e6..921d2bf1 100644 --- a/ent/subjectdigest/where.go +++ b/ent/subjectdigest/where.go @@ -10,285 +10,187 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.SubjectDigest(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.SubjectDigest(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.SubjectDigest(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.SubjectDigest(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.SubjectDigest(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. func IDGT(id int) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.SubjectDigest(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.SubjectDigest(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.SubjectDigest(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.SubjectDigest(sql.FieldLTE(FieldID, id)) } // Algorithm applies equality check predicate on the "algorithm" field. It's identical to AlgorithmEQ. 
func Algorithm(v string) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAlgorithm), v)) - }) + return predicate.SubjectDigest(sql.FieldEQ(FieldAlgorithm, v)) } // Value applies equality check predicate on the "value" field. It's identical to ValueEQ. func Value(v string) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.SubjectDigest(sql.FieldEQ(FieldValue, v)) } // AlgorithmEQ applies the EQ predicate on the "algorithm" field. func AlgorithmEQ(v string) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldAlgorithm), v)) - }) + return predicate.SubjectDigest(sql.FieldEQ(FieldAlgorithm, v)) } // AlgorithmNEQ applies the NEQ predicate on the "algorithm" field. func AlgorithmNEQ(v string) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldAlgorithm), v)) - }) + return predicate.SubjectDigest(sql.FieldNEQ(FieldAlgorithm, v)) } // AlgorithmIn applies the In predicate on the "algorithm" field. func AlgorithmIn(vs ...string) predicate.SubjectDigest { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldAlgorithm), v...)) - }) + return predicate.SubjectDigest(sql.FieldIn(FieldAlgorithm, vs...)) } // AlgorithmNotIn applies the NotIn predicate on the "algorithm" field. func AlgorithmNotIn(vs ...string) predicate.SubjectDigest { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldAlgorithm), v...)) - }) + return predicate.SubjectDigest(sql.FieldNotIn(FieldAlgorithm, vs...)) } // AlgorithmGT applies the GT predicate on the "algorithm" field. func AlgorithmGT(v string) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldAlgorithm), v)) - }) + return predicate.SubjectDigest(sql.FieldGT(FieldAlgorithm, v)) } // AlgorithmGTE applies the GTE predicate on the "algorithm" field. func AlgorithmGTE(v string) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldAlgorithm), v)) - }) + return predicate.SubjectDigest(sql.FieldGTE(FieldAlgorithm, v)) } // AlgorithmLT applies the LT predicate on the "algorithm" field. func AlgorithmLT(v string) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldAlgorithm), v)) - }) + return predicate.SubjectDigest(sql.FieldLT(FieldAlgorithm, v)) } // AlgorithmLTE applies the LTE predicate on the "algorithm" field. func AlgorithmLTE(v string) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldAlgorithm), v)) - }) + return predicate.SubjectDigest(sql.FieldLTE(FieldAlgorithm, v)) } // AlgorithmContains applies the Contains predicate on the "algorithm" field. func AlgorithmContains(v string) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldAlgorithm), v)) - }) + return predicate.SubjectDigest(sql.FieldContains(FieldAlgorithm, v)) } // AlgorithmHasPrefix applies the HasPrefix predicate on the "algorithm" field. 
func AlgorithmHasPrefix(v string) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldAlgorithm), v)) - }) + return predicate.SubjectDigest(sql.FieldHasPrefix(FieldAlgorithm, v)) } // AlgorithmHasSuffix applies the HasSuffix predicate on the "algorithm" field. func AlgorithmHasSuffix(v string) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldAlgorithm), v)) - }) + return predicate.SubjectDigest(sql.FieldHasSuffix(FieldAlgorithm, v)) } // AlgorithmEqualFold applies the EqualFold predicate on the "algorithm" field. func AlgorithmEqualFold(v string) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldAlgorithm), v)) - }) + return predicate.SubjectDigest(sql.FieldEqualFold(FieldAlgorithm, v)) } // AlgorithmContainsFold applies the ContainsFold predicate on the "algorithm" field. func AlgorithmContainsFold(v string) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldAlgorithm), v)) - }) + return predicate.SubjectDigest(sql.FieldContainsFold(FieldAlgorithm, v)) } // ValueEQ applies the EQ predicate on the "value" field. func ValueEQ(v string) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldValue), v)) - }) + return predicate.SubjectDigest(sql.FieldEQ(FieldValue, v)) } // ValueNEQ applies the NEQ predicate on the "value" field. func ValueNEQ(v string) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldValue), v)) - }) + return predicate.SubjectDigest(sql.FieldNEQ(FieldValue, v)) } // ValueIn applies the In predicate on the "value" field. func ValueIn(vs ...string) predicate.SubjectDigest { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldValue), v...)) - }) + return predicate.SubjectDigest(sql.FieldIn(FieldValue, vs...)) } // ValueNotIn applies the NotIn predicate on the "value" field. func ValueNotIn(vs ...string) predicate.SubjectDigest { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldValue), v...)) - }) + return predicate.SubjectDigest(sql.FieldNotIn(FieldValue, vs...)) } // ValueGT applies the GT predicate on the "value" field. func ValueGT(v string) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldValue), v)) - }) + return predicate.SubjectDigest(sql.FieldGT(FieldValue, v)) } // ValueGTE applies the GTE predicate on the "value" field. func ValueGTE(v string) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldValue), v)) - }) + return predicate.SubjectDigest(sql.FieldGTE(FieldValue, v)) } // ValueLT applies the LT predicate on the "value" field. func ValueLT(v string) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldValue), v)) - }) + return predicate.SubjectDigest(sql.FieldLT(FieldValue, v)) } // ValueLTE applies the LTE predicate on the "value" field. 
func ValueLTE(v string) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldValue), v)) - }) + return predicate.SubjectDigest(sql.FieldLTE(FieldValue, v)) } // ValueContains applies the Contains predicate on the "value" field. func ValueContains(v string) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldValue), v)) - }) + return predicate.SubjectDigest(sql.FieldContains(FieldValue, v)) } // ValueHasPrefix applies the HasPrefix predicate on the "value" field. func ValueHasPrefix(v string) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldValue), v)) - }) + return predicate.SubjectDigest(sql.FieldHasPrefix(FieldValue, v)) } // ValueHasSuffix applies the HasSuffix predicate on the "value" field. func ValueHasSuffix(v string) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldValue), v)) - }) + return predicate.SubjectDigest(sql.FieldHasSuffix(FieldValue, v)) } // ValueEqualFold applies the EqualFold predicate on the "value" field. func ValueEqualFold(v string) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldValue), v)) - }) + return predicate.SubjectDigest(sql.FieldEqualFold(FieldValue, v)) } // ValueContainsFold applies the ContainsFold predicate on the "value" field. func ValueContainsFold(v string) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldValue), v)) - }) + return predicate.SubjectDigest(sql.FieldContainsFold(FieldValue, v)) } // HasSubject applies the HasEdge predicate on the "subject" edge. @@ -296,7 +198,6 @@ func HasSubject() predicate.SubjectDigest { return predicate.SubjectDigest(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(SubjectTable, FieldID), sqlgraph.Edge(sqlgraph.M2O, true, SubjectTable, SubjectColumn), ) sqlgraph.HasNeighbors(s, step) @@ -306,11 +207,7 @@ func HasSubject() predicate.SubjectDigest { // HasSubjectWith applies the HasEdge predicate on the "subject" edge with a given conditions (other predicates). func HasSubjectWith(preds ...predicate.Subject) predicate.SubjectDigest { return predicate.SubjectDigest(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(SubjectInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.M2O, true, SubjectTable, SubjectColumn), - ) + step := newSubjectStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -321,32 +218,15 @@ func HasSubjectWith(preds ...predicate.Subject) predicate.SubjectDigest { // And groups predicates with the AND operator between them. func And(predicates ...predicate.SubjectDigest) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.SubjectDigest(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. 
func Or(predicates ...predicate.SubjectDigest) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.SubjectDigest(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.SubjectDigest) predicate.SubjectDigest { - return predicate.SubjectDigest(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.SubjectDigest(sql.NotPredicates(p)) } diff --git a/ent/subjectdigest_create.go b/ent/subjectdigest_create.go index 279e841a..a1a9867d 100644 --- a/ent/subjectdigest_create.go +++ b/ent/subjectdigest_create.go @@ -58,49 +58,7 @@ func (sdc *SubjectDigestCreate) Mutation() *SubjectDigestMutation { // Save creates the SubjectDigest in the database. func (sdc *SubjectDigestCreate) Save(ctx context.Context) (*SubjectDigest, error) { - var ( - err error - node *SubjectDigest - ) - if len(sdc.hooks) == 0 { - if err = sdc.check(); err != nil { - return nil, err - } - node, err = sdc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*SubjectDigestMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = sdc.check(); err != nil { - return nil, err - } - sdc.mutation = mutation - if node, err = sdc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(sdc.hooks) - 1; i >= 0; i-- { - if sdc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = sdc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, sdc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*SubjectDigest) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from SubjectDigestMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, sdc.sqlSave, sdc.mutation, sdc.hooks) } // SaveX calls Save and panics if Save returns an error. 
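The SubjectDigest builders get the same mechanical treatment: Save now delegates to withHooks above, and the hunks that follow swap the hand-built FieldSpec literals for sqlgraph.NewCreateSpec and SetField. A sketch combining the create builder with the new typed ordering helpers from subjectdigest.go; the edge setter, the placeholder digest value, and the descending order term are illustrative assumptions, not taken from this change:

	// Illustrative sketch only; assumes an already-opened *ent.Client.
	package example

	import (
		"context"

		"entgo.io/ent/dialect/sql"

		"github.com/testifysec/archivista/ent"
		"github.com/testifysec/archivista/ent/subjectdigest"
	)

	func storeAndListDigests(ctx context.Context, client *ent.Client, subjectID int) ([]*ent.SubjectDigest, error) {
		// Attach one digest to an existing Subject via the generated edge setter.
		_, err := client.SubjectDigest.Create().
			SetAlgorithm("sha256").
			SetValue("deadbeef"). // placeholder digest value
			SetSubjectID(subjectID).
			Save(ctx)
		if err != nil {
			return nil, err
		}
		// Read back using the generated OrderOption helpers instead of raw OrderFunc values.
		return client.SubjectDigest.Query().
			Order(subjectdigest.ByAlgorithm(), subjectdigest.ByValue(sql.OrderDesc())).
			All(ctx)
	}
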
@@ -147,6 +105,9 @@ func (sdc *SubjectDigestCreate) check() error { } func (sdc *SubjectDigestCreate) sqlSave(ctx context.Context) (*SubjectDigest, error) { + if err := sdc.check(); err != nil { + return nil, err + } _node, _spec := sdc.createSpec() if err := sqlgraph.CreateNode(ctx, sdc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -156,34 +117,22 @@ func (sdc *SubjectDigestCreate) sqlSave(ctx context.Context) (*SubjectDigest, er } id := _spec.ID.Value.(int64) _node.ID = int(id) + sdc.mutation.id = &_node.ID + sdc.mutation.done = true return _node, nil } func (sdc *SubjectDigestCreate) createSpec() (*SubjectDigest, *sqlgraph.CreateSpec) { var ( _node = &SubjectDigest{config: sdc.config} - _spec = &sqlgraph.CreateSpec{ - Table: subjectdigest.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subjectdigest.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(subjectdigest.Table, sqlgraph.NewFieldSpec(subjectdigest.FieldID, field.TypeInt)) ) if value, ok := sdc.mutation.Algorithm(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: subjectdigest.FieldAlgorithm, - }) + _spec.SetField(subjectdigest.FieldAlgorithm, field.TypeString, value) _node.Algorithm = value } if value, ok := sdc.mutation.Value(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: subjectdigest.FieldValue, - }) + _spec.SetField(subjectdigest.FieldValue, field.TypeString, value) _node.Value = value } if nodes := sdc.mutation.SubjectIDs(); len(nodes) > 0 { @@ -194,10 +143,7 @@ func (sdc *SubjectDigestCreate) createSpec() (*SubjectDigest, *sqlgraph.CreateSp Columns: []string{subjectdigest.SubjectColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subject.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(subject.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -212,11 +158,15 @@ func (sdc *SubjectDigestCreate) createSpec() (*SubjectDigest, *sqlgraph.CreateSp // SubjectDigestCreateBulk is the builder for creating many SubjectDigest entities in bulk. type SubjectDigestCreateBulk struct { config + err error builders []*SubjectDigestCreate } // Save creates the SubjectDigest entities in the database. func (sdcb *SubjectDigestCreateBulk) Save(ctx context.Context) ([]*SubjectDigest, error) { + if sdcb.err != nil { + return nil, sdcb.err + } specs := make([]*sqlgraph.CreateSpec, len(sdcb.builders)) nodes := make([]*SubjectDigest, len(sdcb.builders)) mutators := make([]Mutator, len(sdcb.builders)) @@ -232,8 +182,8 @@ func (sdcb *SubjectDigestCreateBulk) Save(ctx context.Context) ([]*SubjectDigest return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, sdcb.builders[i+1].mutation) } else { diff --git a/ent/subjectdigest_delete.go b/ent/subjectdigest_delete.go index 091226e2..e5dccfe6 100644 --- a/ent/subjectdigest_delete.go +++ b/ent/subjectdigest_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (sdd *SubjectDigestDelete) Where(ps ...predicate.SubjectDigest) *SubjectDig // Exec executes the deletion query and returns how many vertices were deleted. 
func (sdd *SubjectDigestDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(sdd.hooks) == 0 { - affected, err = sdd.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*SubjectDigestMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - sdd.mutation = mutation - affected, err = sdd.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(sdd.hooks) - 1; i >= 0; i-- { - if sdd.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = sdd.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, sdd.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, sdd.sqlExec, sdd.mutation, sdd.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (sdd *SubjectDigestDelete) ExecX(ctx context.Context) int { } func (sdd *SubjectDigestDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: subjectdigest.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subjectdigest.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(subjectdigest.Table, sqlgraph.NewFieldSpec(subjectdigest.FieldID, field.TypeInt)) if ps := sdd.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (sdd *SubjectDigestDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + sdd.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type SubjectDigestDeleteOne struct { sdd *SubjectDigestDelete } +// Where appends a list predicates to the SubjectDigestDelete builder. +func (sddo *SubjectDigestDeleteOne) Where(ps ...predicate.SubjectDigest) *SubjectDigestDeleteOne { + sddo.sdd.mutation.Where(ps...) + return sddo +} + // Exec executes the deletion query. func (sddo *SubjectDigestDeleteOne) Exec(ctx context.Context) error { n, err := sddo.sdd.Exec(ctx) @@ -111,5 +82,7 @@ func (sddo *SubjectDigestDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. func (sddo *SubjectDigestDeleteOne) ExecX(ctx context.Context) { - sddo.sdd.ExecX(ctx) + if err := sddo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/ent/subjectdigest_query.go b/ent/subjectdigest_query.go index c6aabe96..62dac7b5 100644 --- a/ent/subjectdigest_query.go +++ b/ent/subjectdigest_query.go @@ -18,11 +18,9 @@ import ( // SubjectDigestQuery is the builder for querying SubjectDigest entities. type SubjectDigestQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string + ctx *QueryContext + order []subjectdigest.OrderOption + inters []Interceptor predicates []predicate.SubjectDigest withSubject *SubjectQuery withFKs bool @@ -39,34 +37,34 @@ func (sdq *SubjectDigestQuery) Where(ps ...predicate.SubjectDigest) *SubjectDige return sdq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (sdq *SubjectDigestQuery) Limit(limit int) *SubjectDigestQuery { - sdq.limit = &limit + sdq.ctx.Limit = &limit return sdq } -// Offset adds an offset step to the query. +// Offset to start from. 
func (sdq *SubjectDigestQuery) Offset(offset int) *SubjectDigestQuery { - sdq.offset = &offset + sdq.ctx.Offset = &offset return sdq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (sdq *SubjectDigestQuery) Unique(unique bool) *SubjectDigestQuery { - sdq.unique = &unique + sdq.ctx.Unique = &unique return sdq } -// Order adds an order step to the query. -func (sdq *SubjectDigestQuery) Order(o ...OrderFunc) *SubjectDigestQuery { +// Order specifies how the records should be ordered. +func (sdq *SubjectDigestQuery) Order(o ...subjectdigest.OrderOption) *SubjectDigestQuery { sdq.order = append(sdq.order, o...) return sdq } // QuerySubject chains the current query on the "subject" edge. func (sdq *SubjectDigestQuery) QuerySubject() *SubjectQuery { - query := &SubjectQuery{config: sdq.config} + query := (&SubjectClient{config: sdq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := sdq.prepareQuery(ctx); err != nil { return nil, err @@ -89,7 +87,7 @@ func (sdq *SubjectDigestQuery) QuerySubject() *SubjectQuery { // First returns the first SubjectDigest entity from the query. // Returns a *NotFoundError when no SubjectDigest was found. func (sdq *SubjectDigestQuery) First(ctx context.Context) (*SubjectDigest, error) { - nodes, err := sdq.Limit(1).All(ctx) + nodes, err := sdq.Limit(1).All(setContextOp(ctx, sdq.ctx, "First")) if err != nil { return nil, err } @@ -112,7 +110,7 @@ func (sdq *SubjectDigestQuery) FirstX(ctx context.Context) *SubjectDigest { // Returns a *NotFoundError when no SubjectDigest ID was found. func (sdq *SubjectDigestQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = sdq.Limit(1).IDs(ctx); err != nil { + if ids, err = sdq.Limit(1).IDs(setContextOp(ctx, sdq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -135,7 +133,7 @@ func (sdq *SubjectDigestQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one SubjectDigest entity is found. // Returns a *NotFoundError when no SubjectDigest entities are found. func (sdq *SubjectDigestQuery) Only(ctx context.Context) (*SubjectDigest, error) { - nodes, err := sdq.Limit(2).All(ctx) + nodes, err := sdq.Limit(2).All(setContextOp(ctx, sdq.ctx, "Only")) if err != nil { return nil, err } @@ -163,7 +161,7 @@ func (sdq *SubjectDigestQuery) OnlyX(ctx context.Context) *SubjectDigest { // Returns a *NotFoundError when no entities are found. func (sdq *SubjectDigestQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = sdq.Limit(2).IDs(ctx); err != nil { + if ids, err = sdq.Limit(2).IDs(setContextOp(ctx, sdq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -188,10 +186,12 @@ func (sdq *SubjectDigestQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of SubjectDigests. func (sdq *SubjectDigestQuery) All(ctx context.Context) ([]*SubjectDigest, error) { + ctx = setContextOp(ctx, sdq.ctx, "All") if err := sdq.prepareQuery(ctx); err != nil { return nil, err } - return sdq.sqlAll(ctx) + qr := querierAll[[]*SubjectDigest, *SubjectDigestQuery]() + return withInterceptors[[]*SubjectDigest](ctx, sdq, qr, sdq.inters) } // AllX is like All, but panics if an error occurs. @@ -204,9 +204,12 @@ func (sdq *SubjectDigestQuery) AllX(ctx context.Context) []*SubjectDigest { } // IDs executes the query and returns a list of SubjectDigest IDs. 
-func (sdq *SubjectDigestQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := sdq.Select(subjectdigest.FieldID).Scan(ctx, &ids); err != nil { +func (sdq *SubjectDigestQuery) IDs(ctx context.Context) (ids []int, err error) { + if sdq.ctx.Unique == nil && sdq.path != nil { + sdq.Unique(true) + } + ctx = setContextOp(ctx, sdq.ctx, "IDs") + if err = sdq.Select(subjectdigest.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -223,10 +226,11 @@ func (sdq *SubjectDigestQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (sdq *SubjectDigestQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, sdq.ctx, "Count") if err := sdq.prepareQuery(ctx); err != nil { return 0, err } - return sdq.sqlCount(ctx) + return withInterceptors[int](ctx, sdq, querierCount[*SubjectDigestQuery](), sdq.inters) } // CountX is like Count, but panics if an error occurs. @@ -240,10 +244,15 @@ func (sdq *SubjectDigestQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (sdq *SubjectDigestQuery) Exist(ctx context.Context) (bool, error) { - if err := sdq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, sdq.ctx, "Exist") + switch _, err := sdq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return sdq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -263,22 +272,21 @@ func (sdq *SubjectDigestQuery) Clone() *SubjectDigestQuery { } return &SubjectDigestQuery{ config: sdq.config, - limit: sdq.limit, - offset: sdq.offset, - order: append([]OrderFunc{}, sdq.order...), + ctx: sdq.ctx.Clone(), + order: append([]subjectdigest.OrderOption{}, sdq.order...), + inters: append([]Interceptor{}, sdq.inters...), predicates: append([]predicate.SubjectDigest{}, sdq.predicates...), withSubject: sdq.withSubject.Clone(), // clone intermediate query. - sql: sdq.sql.Clone(), - path: sdq.path, - unique: sdq.unique, + sql: sdq.sql.Clone(), + path: sdq.path, } } // WithSubject tells the query-builder to eager-load the nodes that are connected to // the "subject" edge. The optional arguments are used to configure the query builder of the edge. func (sdq *SubjectDigestQuery) WithSubject(opts ...func(*SubjectQuery)) *SubjectDigestQuery { - query := &SubjectQuery{config: sdq.config} + query := (&SubjectClient{config: sdq.config}).Query() for _, opt := range opts { opt(query) } @@ -301,16 +309,11 @@ func (sdq *SubjectDigestQuery) WithSubject(opts ...func(*SubjectQuery)) *Subject // Aggregate(ent.Count()). // Scan(ctx, &v) func (sdq *SubjectDigestQuery) GroupBy(field string, fields ...string) *SubjectDigestGroupBy { - grbuild := &SubjectDigestGroupBy{config: sdq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := sdq.prepareQuery(ctx); err != nil { - return nil, err - } - return sdq.sqlQuery(ctx), nil - } + sdq.ctx.Fields = append([]string{field}, fields...) + grbuild := &SubjectDigestGroupBy{build: sdq} + grbuild.flds = &sdq.ctx.Fields grbuild.label = subjectdigest.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -327,15 +330,30 @@ func (sdq *SubjectDigestQuery) GroupBy(field string, fields ...string) *SubjectD // Select(subjectdigest.FieldAlgorithm). 
// Scan(ctx, &v) func (sdq *SubjectDigestQuery) Select(fields ...string) *SubjectDigestSelect { - sdq.fields = append(sdq.fields, fields...) - selbuild := &SubjectDigestSelect{SubjectDigestQuery: sdq} - selbuild.label = subjectdigest.Label - selbuild.flds, selbuild.scan = &sdq.fields, selbuild.Scan - return selbuild + sdq.ctx.Fields = append(sdq.ctx.Fields, fields...) + sbuild := &SubjectDigestSelect{SubjectDigestQuery: sdq} + sbuild.label = subjectdigest.Label + sbuild.flds, sbuild.scan = &sdq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a SubjectDigestSelect configured with the given aggregations. +func (sdq *SubjectDigestQuery) Aggregate(fns ...AggregateFunc) *SubjectDigestSelect { + return sdq.Select().Aggregate(fns...) } func (sdq *SubjectDigestQuery) prepareQuery(ctx context.Context) error { - for _, f := range sdq.fields { + for _, inter := range sdq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, sdq); err != nil { + return err + } + } + } + for _, f := range sdq.ctx.Fields { if !subjectdigest.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -413,6 +431,9 @@ func (sdq *SubjectDigestQuery) loadSubject(ctx context.Context, query *SubjectQu } nodeids[fk] = append(nodeids[fk], nodes[i]) } + if len(ids) == 0 { + return nil + } query.Where(subject.IDIn(ids...)) neighbors, err := query.All(ctx) if err != nil { @@ -435,41 +456,22 @@ func (sdq *SubjectDigestQuery) sqlCount(ctx context.Context) (int, error) { if len(sdq.modifiers) > 0 { _spec.Modifiers = sdq.modifiers } - _spec.Node.Columns = sdq.fields - if len(sdq.fields) > 0 { - _spec.Unique = sdq.unique != nil && *sdq.unique + _spec.Node.Columns = sdq.ctx.Fields + if len(sdq.ctx.Fields) > 0 { + _spec.Unique = sdq.ctx.Unique != nil && *sdq.ctx.Unique } return sqlgraph.CountNodes(ctx, sdq.driver, _spec) } -func (sdq *SubjectDigestQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := sdq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (sdq *SubjectDigestQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: subjectdigest.Table, - Columns: subjectdigest.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subjectdigest.FieldID, - }, - }, - From: sdq.sql, - Unique: true, - } - if unique := sdq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(subjectdigest.Table, subjectdigest.Columns, sqlgraph.NewFieldSpec(subjectdigest.FieldID, field.TypeInt)) + _spec.From = sdq.sql + if unique := sdq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if sdq.path != nil { + _spec.Unique = true } - if fields := sdq.fields; len(fields) > 0 { + if fields := sdq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, subjectdigest.FieldID) for i := range fields { @@ -485,10 +487,10 @@ func (sdq *SubjectDigestQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := sdq.limit; limit != nil { + if limit := sdq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := sdq.offset; offset != nil { + if offset := sdq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := sdq.order; len(ps) > 0 { @@ -504,7 +506,7 @@ func 
(sdq *SubjectDigestQuery) querySpec() *sqlgraph.QuerySpec { func (sdq *SubjectDigestQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(sdq.driver.Dialect()) t1 := builder.Table(subjectdigest.Table) - columns := sdq.fields + columns := sdq.ctx.Fields if len(columns) == 0 { columns = subjectdigest.Columns } @@ -513,7 +515,7 @@ func (sdq *SubjectDigestQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = sdq.sql selector.Select(selector.Columns(columns...)...) } - if sdq.unique != nil && *sdq.unique { + if sdq.ctx.Unique != nil && *sdq.ctx.Unique { selector.Distinct() } for _, p := range sdq.predicates { @@ -522,12 +524,12 @@ func (sdq *SubjectDigestQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range sdq.order { p(selector) } - if offset := sdq.offset; offset != nil { + if offset := sdq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := sdq.limit; limit != nil { + if limit := sdq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -535,13 +537,8 @@ func (sdq *SubjectDigestQuery) sqlQuery(ctx context.Context) *sql.Selector { // SubjectDigestGroupBy is the group-by builder for SubjectDigest entities. type SubjectDigestGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *SubjectDigestQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -550,74 +547,77 @@ func (sdgb *SubjectDigestGroupBy) Aggregate(fns ...AggregateFunc) *SubjectDigest return sdgb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. func (sdgb *SubjectDigestGroupBy) Scan(ctx context.Context, v any) error { - query, err := sdgb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, sdgb.build.ctx, "GroupBy") + if err := sdgb.build.prepareQuery(ctx); err != nil { return err } - sdgb.sql = query - return sdgb.sqlScan(ctx, v) + return scanWithInterceptors[*SubjectDigestQuery, *SubjectDigestGroupBy](ctx, sdgb.build, sdgb, sdgb.build.inters, v) } -func (sdgb *SubjectDigestGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range sdgb.fields { - if !subjectdigest.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (sdgb *SubjectDigestGroupBy) sqlScan(ctx context.Context, root *SubjectDigestQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(sdgb.fns)) + for _, fn := range sdgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*sdgb.flds)+len(sdgb.fns)) + for _, f := range *sdgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := sdgb.sqlQuery() + selector.GroupBy(selector.Columns(*sdgb.flds...)...) 
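	// Note (illustrative comment, not generated code): at this point the
	// selector carries the group-by columns plus any aggregation expressions
	// appended above, so GroupBy(subjectdigest.FieldAlgorithm) combined with
	// ent.Count() builds a query roughly of the form:
	//   SELECT algorithm, COUNT(*) FROM subject_digests GROUP BY algorithm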
if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := sdgb.driver.Query(ctx, query, args, rows); err != nil { + if err := sdgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (sdgb *SubjectDigestGroupBy) sqlQuery() *sql.Selector { - selector := sdgb.sql.Select() - aggregation := make([]string, 0, len(sdgb.fns)) - for _, fn := range sdgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(sdgb.fields)+len(sdgb.fns)) - for _, f := range sdgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) - selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(sdgb.fields...)...) -} - // SubjectDigestSelect is the builder for selecting fields of SubjectDigest entities. type SubjectDigestSelect struct { *SubjectDigestQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (sds *SubjectDigestSelect) Aggregate(fns ...AggregateFunc) *SubjectDigestSelect { + sds.fns = append(sds.fns, fns...) + return sds } // Scan applies the selector query and scans the result into the given value. func (sds *SubjectDigestSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, sds.ctx, "Select") if err := sds.prepareQuery(ctx); err != nil { return err } - sds.sql = sds.SubjectDigestQuery.sqlQuery(ctx) - return sds.sqlScan(ctx, v) + return scanWithInterceptors[*SubjectDigestQuery, *SubjectDigestSelect](ctx, sds.SubjectDigestQuery, sds, sds.inters, v) } -func (sds *SubjectDigestSelect) sqlScan(ctx context.Context, v any) error { +func (sds *SubjectDigestSelect) sqlScan(ctx context.Context, root *SubjectDigestQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(sds.fns)) + for _, fn := range sds.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*sds.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := sds.sql.Query() + query, args := selector.Query() if err := sds.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/ent/subjectdigest_update.go b/ent/subjectdigest_update.go index aee6a577..079fc1c5 100644 --- a/ent/subjectdigest_update.go +++ b/ent/subjectdigest_update.go @@ -72,40 +72,7 @@ func (sdu *SubjectDigestUpdate) ClearSubject() *SubjectDigestUpdate { // Save executes the query and returns the number of nodes affected by the update operation. 
func (sdu *SubjectDigestUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(sdu.hooks) == 0 { - if err = sdu.check(); err != nil { - return 0, err - } - affected, err = sdu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*SubjectDigestMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = sdu.check(); err != nil { - return 0, err - } - sdu.mutation = mutation - affected, err = sdu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(sdu.hooks) - 1; i >= 0; i-- { - if sdu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = sdu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, sdu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, sdu.sqlSave, sdu.mutation, sdu.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -146,16 +113,10 @@ func (sdu *SubjectDigestUpdate) check() error { } func (sdu *SubjectDigestUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: subjectdigest.Table, - Columns: subjectdigest.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subjectdigest.FieldID, - }, - }, + if err := sdu.check(); err != nil { + return n, err } + _spec := sqlgraph.NewUpdateSpec(subjectdigest.Table, subjectdigest.Columns, sqlgraph.NewFieldSpec(subjectdigest.FieldID, field.TypeInt)) if ps := sdu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -164,18 +125,10 @@ func (sdu *SubjectDigestUpdate) sqlSave(ctx context.Context) (n int, err error) } } if value, ok := sdu.mutation.Algorithm(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: subjectdigest.FieldAlgorithm, - }) + _spec.SetField(subjectdigest.FieldAlgorithm, field.TypeString, value) } if value, ok := sdu.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: subjectdigest.FieldValue, - }) + _spec.SetField(subjectdigest.FieldValue, field.TypeString, value) } if sdu.mutation.SubjectCleared() { edge := &sqlgraph.EdgeSpec{ @@ -185,10 +138,7 @@ func (sdu *SubjectDigestUpdate) sqlSave(ctx context.Context) (n int, err error) Columns: []string{subjectdigest.SubjectColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subject.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(subject.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -201,10 +151,7 @@ func (sdu *SubjectDigestUpdate) sqlSave(ctx context.Context) (n int, err error) Columns: []string{subjectdigest.SubjectColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subject.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(subject.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -220,6 +167,7 @@ func (sdu *SubjectDigestUpdate) sqlSave(ctx context.Context) (n int, err error) } return 0, err } + sdu.mutation.done = true return n, nil } @@ -273,6 +221,12 @@ func (sduo *SubjectDigestUpdateOne) ClearSubject() *SubjectDigestUpdateOne { return sduo } +// Where appends a list predicates to the SubjectDigestUpdate builder. 
+func (sduo *SubjectDigestUpdateOne) Where(ps ...predicate.SubjectDigest) *SubjectDigestUpdateOne { + sduo.mutation.Where(ps...) + return sduo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (sduo *SubjectDigestUpdateOne) Select(field string, fields ...string) *SubjectDigestUpdateOne { @@ -282,46 +236,7 @@ func (sduo *SubjectDigestUpdateOne) Select(field string, fields ...string) *Subj // Save executes the query and returns the updated SubjectDigest entity. func (sduo *SubjectDigestUpdateOne) Save(ctx context.Context) (*SubjectDigest, error) { - var ( - err error - node *SubjectDigest - ) - if len(sduo.hooks) == 0 { - if err = sduo.check(); err != nil { - return nil, err - } - node, err = sduo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*SubjectDigestMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = sduo.check(); err != nil { - return nil, err - } - sduo.mutation = mutation - node, err = sduo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(sduo.hooks) - 1; i >= 0; i-- { - if sduo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = sduo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, sduo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*SubjectDigest) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from SubjectDigestMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, sduo.sqlSave, sduo.mutation, sduo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -362,16 +277,10 @@ func (sduo *SubjectDigestUpdateOne) check() error { } func (sduo *SubjectDigestUpdateOne) sqlSave(ctx context.Context) (_node *SubjectDigest, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: subjectdigest.Table, - Columns: subjectdigest.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subjectdigest.FieldID, - }, - }, + if err := sduo.check(); err != nil { + return _node, err } + _spec := sqlgraph.NewUpdateSpec(subjectdigest.Table, subjectdigest.Columns, sqlgraph.NewFieldSpec(subjectdigest.FieldID, field.TypeInt)) id, ok := sduo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "SubjectDigest.id" for update`)} @@ -397,18 +306,10 @@ func (sduo *SubjectDigestUpdateOne) sqlSave(ctx context.Context) (_node *Subject } } if value, ok := sduo.mutation.Algorithm(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: subjectdigest.FieldAlgorithm, - }) + _spec.SetField(subjectdigest.FieldAlgorithm, field.TypeString, value) } if value, ok := sduo.mutation.Value(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: subjectdigest.FieldValue, - }) + _spec.SetField(subjectdigest.FieldValue, field.TypeString, value) } if sduo.mutation.SubjectCleared() { edge := &sqlgraph.EdgeSpec{ @@ -418,10 +319,7 @@ func (sduo *SubjectDigestUpdateOne) sqlSave(ctx context.Context) (_node *Subject Columns: []string{subjectdigest.SubjectColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subject.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(subject.FieldID, 
field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -434,10 +332,7 @@ func (sduo *SubjectDigestUpdateOne) sqlSave(ctx context.Context) (_node *Subject Columns: []string{subjectdigest.SubjectColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: subject.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(subject.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -456,5 +351,6 @@ func (sduo *SubjectDigestUpdateOne) sqlSave(ctx context.Context) (_node *Subject } return nil, err } + sduo.mutation.done = true return _node, nil } diff --git a/ent/timestamp.go b/ent/timestamp.go index 28ee6ca1..fe4dedff 100644 --- a/ent/timestamp.go +++ b/ent/timestamp.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "entgo.io/ent" "entgo.io/ent/dialect/sql" "github.com/testifysec/archivista/ent/signature" "github.com/testifysec/archivista/ent/timestamp" @@ -25,6 +26,7 @@ type Timestamp struct { // The values are being populated by the TimestampQuery when eager-loading is set. Edges TimestampEdges `json:"edges"` signature_timestamps *int + selectValues sql.SelectValues } // TimestampEdges holds the relations/edges for other nodes in the graph. @@ -35,7 +37,7 @@ type TimestampEdges struct { // type was loaded (or requested) in eager-loading or not. loadedTypes [1]bool // totalCount holds the count of the edges above. - totalCount [1]*int + totalCount [1]map[string]int } // SignatureOrErr returns the Signature value or an error if the edge @@ -65,7 +67,7 @@ func (*Timestamp) scanValues(columns []string) ([]any, error) { case timestamp.ForeignKeys[0]: // signature_timestamps values[i] = new(sql.NullInt64) default: - return nil, fmt.Errorf("unexpected column %q for type Timestamp", columns[i]) + values[i] = new(sql.UnknownType) } } return values, nil @@ -104,21 +106,29 @@ func (t *Timestamp) assignValues(columns []string, values []any) error { t.signature_timestamps = new(int) *t.signature_timestamps = int(value.Int64) } + default: + t.selectValues.Set(columns[i], values[i]) } } return nil } +// Value returns the ent.Value that was dynamically selected and assigned to the Timestamp. +// This includes values selected through modifiers, order, etc. +func (t *Timestamp) Value(name string) (ent.Value, error) { + return t.selectValues.Get(name) +} + // QuerySignature queries the "signature" edge of the Timestamp entity. func (t *Timestamp) QuerySignature() *SignatureQuery { - return (&TimestampClient{config: t.config}).QuerySignature(t) + return NewTimestampClient(t.config).QuerySignature(t) } // Update returns a builder for updating this Timestamp. // Note that you need to call Timestamp.Unwrap() before calling this method if this Timestamp // was returned from a transaction, and the transaction was committed or rolled back. func (t *Timestamp) Update() *TimestampUpdateOne { - return (&TimestampClient{config: t.config}).UpdateOne(t) + return NewTimestampClient(t.config).UpdateOne(t) } // Unwrap unwraps the Timestamp entity that was returned from a transaction after it was closed, @@ -148,9 +158,3 @@ func (t *Timestamp) String() string { // Timestamps is a parsable slice of Timestamp. 
type Timestamps []*Timestamp - -func (t Timestamps) config(cfg config) { - for _i := range t { - t[_i].config = cfg - } -} diff --git a/ent/timestamp/timestamp.go b/ent/timestamp/timestamp.go index f8ea1464..f73b7cb3 100644 --- a/ent/timestamp/timestamp.go +++ b/ent/timestamp/timestamp.go @@ -2,6 +2,11 @@ package timestamp +import ( + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" +) + const ( // Label holds the string label denoting the timestamp type in the database. Label = "timestamp" @@ -51,3 +56,35 @@ func ValidColumn(column string) bool { } return false } + +// OrderOption defines the ordering options for the Timestamp queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByType orders the results by the type field. +func ByType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldType, opts...).ToFunc() +} + +// ByTimestamp orders the results by the timestamp field. +func ByTimestamp(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTimestamp, opts...).ToFunc() +} + +// BySignatureField orders the results by signature field. +func BySignatureField(field string, opts ...sql.OrderTermOption) OrderOption { + return func(s *sql.Selector) { + sqlgraph.OrderByNeighborTerms(s, newSignatureStep(), sql.OrderByField(field, opts...)) + } +} +func newSignatureStep() *sqlgraph.Step { + return sqlgraph.NewStep( + sqlgraph.From(Table, FieldID), + sqlgraph.To(SignatureInverseTable, FieldID), + sqlgraph.Edge(sqlgraph.M2O, true, SignatureTable, SignatureColumn), + ) +} diff --git a/ent/timestamp/where.go b/ent/timestamp/where.go index 326cf27a..fb54acea 100644 --- a/ent/timestamp/where.go +++ b/ent/timestamp/where.go @@ -12,250 +12,162 @@ import ( // ID filters vertices based on their ID field. func ID(id int) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Timestamp(sql.FieldEQ(FieldID, id)) } // IDEQ applies the EQ predicate on the ID field. func IDEQ(id int) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldID), id)) - }) + return predicate.Timestamp(sql.FieldEQ(FieldID, id)) } // IDNEQ applies the NEQ predicate on the ID field. func IDNEQ(id int) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldID), id)) - }) + return predicate.Timestamp(sql.FieldNEQ(FieldID, id)) } // IDIn applies the In predicate on the ID field. func IDIn(ids ...int) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.In(s.C(FieldID), v...)) - }) + return predicate.Timestamp(sql.FieldIn(FieldID, ids...)) } // IDNotIn applies the NotIn predicate on the ID field. func IDNotIn(ids ...int) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - v := make([]any, len(ids)) - for i := range v { - v[i] = ids[i] - } - s.Where(sql.NotIn(s.C(FieldID), v...)) - }) + return predicate.Timestamp(sql.FieldNotIn(FieldID, ids...)) } // IDGT applies the GT predicate on the ID field. 
func IDGT(id int) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldID), id)) - }) + return predicate.Timestamp(sql.FieldGT(FieldID, id)) } // IDGTE applies the GTE predicate on the ID field. func IDGTE(id int) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldID), id)) - }) + return predicate.Timestamp(sql.FieldGTE(FieldID, id)) } // IDLT applies the LT predicate on the ID field. func IDLT(id int) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldID), id)) - }) + return predicate.Timestamp(sql.FieldLT(FieldID, id)) } // IDLTE applies the LTE predicate on the ID field. func IDLTE(id int) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldID), id)) - }) + return predicate.Timestamp(sql.FieldLTE(FieldID, id)) } // Type applies equality check predicate on the "type" field. It's identical to TypeEQ. func Type(v string) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldType), v)) - }) + return predicate.Timestamp(sql.FieldEQ(FieldType, v)) } // Timestamp applies equality check predicate on the "timestamp" field. It's identical to TimestampEQ. func Timestamp(v time.Time) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldTimestamp), v)) - }) + return predicate.Timestamp(sql.FieldEQ(FieldTimestamp, v)) } // TypeEQ applies the EQ predicate on the "type" field. func TypeEQ(v string) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldType), v)) - }) + return predicate.Timestamp(sql.FieldEQ(FieldType, v)) } // TypeNEQ applies the NEQ predicate on the "type" field. func TypeNEQ(v string) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldType), v)) - }) + return predicate.Timestamp(sql.FieldNEQ(FieldType, v)) } // TypeIn applies the In predicate on the "type" field. func TypeIn(vs ...string) predicate.Timestamp { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldType), v...)) - }) + return predicate.Timestamp(sql.FieldIn(FieldType, vs...)) } // TypeNotIn applies the NotIn predicate on the "type" field. func TypeNotIn(vs ...string) predicate.Timestamp { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldType), v...)) - }) + return predicate.Timestamp(sql.FieldNotIn(FieldType, vs...)) } // TypeGT applies the GT predicate on the "type" field. func TypeGT(v string) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldType), v)) - }) + return predicate.Timestamp(sql.FieldGT(FieldType, v)) } // TypeGTE applies the GTE predicate on the "type" field. func TypeGTE(v string) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldType), v)) - }) + return predicate.Timestamp(sql.FieldGTE(FieldType, v)) } // TypeLT applies the LT predicate on the "type" field. func TypeLT(v string) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldType), v)) - }) + return predicate.Timestamp(sql.FieldLT(FieldType, v)) } // TypeLTE applies the LTE predicate on the "type" field. 
func TypeLTE(v string) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldType), v)) - }) + return predicate.Timestamp(sql.FieldLTE(FieldType, v)) } // TypeContains applies the Contains predicate on the "type" field. func TypeContains(v string) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.Contains(s.C(FieldType), v)) - }) + return predicate.Timestamp(sql.FieldContains(FieldType, v)) } // TypeHasPrefix applies the HasPrefix predicate on the "type" field. func TypeHasPrefix(v string) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.HasPrefix(s.C(FieldType), v)) - }) + return predicate.Timestamp(sql.FieldHasPrefix(FieldType, v)) } // TypeHasSuffix applies the HasSuffix predicate on the "type" field. func TypeHasSuffix(v string) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.HasSuffix(s.C(FieldType), v)) - }) + return predicate.Timestamp(sql.FieldHasSuffix(FieldType, v)) } // TypeEqualFold applies the EqualFold predicate on the "type" field. func TypeEqualFold(v string) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.EqualFold(s.C(FieldType), v)) - }) + return predicate.Timestamp(sql.FieldEqualFold(FieldType, v)) } // TypeContainsFold applies the ContainsFold predicate on the "type" field. func TypeContainsFold(v string) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.ContainsFold(s.C(FieldType), v)) - }) + return predicate.Timestamp(sql.FieldContainsFold(FieldType, v)) } // TimestampEQ applies the EQ predicate on the "timestamp" field. func TimestampEQ(v time.Time) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.EQ(s.C(FieldTimestamp), v)) - }) + return predicate.Timestamp(sql.FieldEQ(FieldTimestamp, v)) } // TimestampNEQ applies the NEQ predicate on the "timestamp" field. func TimestampNEQ(v time.Time) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.NEQ(s.C(FieldTimestamp), v)) - }) + return predicate.Timestamp(sql.FieldNEQ(FieldTimestamp, v)) } // TimestampIn applies the In predicate on the "timestamp" field. func TimestampIn(vs ...time.Time) predicate.Timestamp { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.In(s.C(FieldTimestamp), v...)) - }) + return predicate.Timestamp(sql.FieldIn(FieldTimestamp, vs...)) } // TimestampNotIn applies the NotIn predicate on the "timestamp" field. func TimestampNotIn(vs ...time.Time) predicate.Timestamp { - v := make([]any, len(vs)) - for i := range v { - v[i] = vs[i] - } - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.NotIn(s.C(FieldTimestamp), v...)) - }) + return predicate.Timestamp(sql.FieldNotIn(FieldTimestamp, vs...)) } // TimestampGT applies the GT predicate on the "timestamp" field. func TimestampGT(v time.Time) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.GT(s.C(FieldTimestamp), v)) - }) + return predicate.Timestamp(sql.FieldGT(FieldTimestamp, v)) } // TimestampGTE applies the GTE predicate on the "timestamp" field. 
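//
// Illustrative usage (not part of the generated output), assuming an
// initialized *ent.Client named client and a time.Time value named since:
//
//	client.Timestamp.Query().
//		Where(timestamp.TimestampGTE(since)).
//		All(ctx)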
func TimestampGTE(v time.Time) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.GTE(s.C(FieldTimestamp), v)) - }) + return predicate.Timestamp(sql.FieldGTE(FieldTimestamp, v)) } // TimestampLT applies the LT predicate on the "timestamp" field. func TimestampLT(v time.Time) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.LT(s.C(FieldTimestamp), v)) - }) + return predicate.Timestamp(sql.FieldLT(FieldTimestamp, v)) } // TimestampLTE applies the LTE predicate on the "timestamp" field. func TimestampLTE(v time.Time) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s.Where(sql.LTE(s.C(FieldTimestamp), v)) - }) + return predicate.Timestamp(sql.FieldLTE(FieldTimestamp, v)) } // HasSignature applies the HasEdge predicate on the "signature" edge. @@ -263,7 +175,6 @@ func HasSignature() predicate.Timestamp { return predicate.Timestamp(func(s *sql.Selector) { step := sqlgraph.NewStep( sqlgraph.From(Table, FieldID), - sqlgraph.To(SignatureTable, FieldID), sqlgraph.Edge(sqlgraph.M2O, true, SignatureTable, SignatureColumn), ) sqlgraph.HasNeighbors(s, step) @@ -273,11 +184,7 @@ func HasSignature() predicate.Timestamp { // HasSignatureWith applies the HasEdge predicate on the "signature" edge with a given conditions (other predicates). func HasSignatureWith(preds ...predicate.Signature) predicate.Timestamp { return predicate.Timestamp(func(s *sql.Selector) { - step := sqlgraph.NewStep( - sqlgraph.From(Table, FieldID), - sqlgraph.To(SignatureInverseTable, FieldID), - sqlgraph.Edge(sqlgraph.M2O, true, SignatureTable, SignatureColumn), - ) + step := newSignatureStep() sqlgraph.HasNeighborsWith(s, step, func(s *sql.Selector) { for _, p := range preds { p(s) @@ -288,32 +195,15 @@ func HasSignatureWith(preds ...predicate.Signature) predicate.Timestamp { // And groups predicates with the AND operator between them. func And(predicates ...predicate.Timestamp) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for _, p := range predicates { - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Timestamp(sql.AndPredicates(predicates...)) } // Or groups predicates with the OR operator between them. func Or(predicates ...predicate.Timestamp) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - s1 := s.Clone().SetP(nil) - for i, p := range predicates { - if i > 0 { - s1.Or() - } - p(s1) - } - s.Where(s1.P()) - }) + return predicate.Timestamp(sql.OrPredicates(predicates...)) } // Not applies the not operator on the given predicate. func Not(p predicate.Timestamp) predicate.Timestamp { - return predicate.Timestamp(func(s *sql.Selector) { - p(s.Not()) - }) + return predicate.Timestamp(sql.NotPredicates(p)) } diff --git a/ent/timestamp_create.go b/ent/timestamp_create.go index 0caf5464..91aa479e 100644 --- a/ent/timestamp_create.go +++ b/ent/timestamp_create.go @@ -59,49 +59,7 @@ func (tc *TimestampCreate) Mutation() *TimestampMutation { // Save creates the Timestamp in the database. 
func (tc *TimestampCreate) Save(ctx context.Context) (*Timestamp, error) { - var ( - err error - node *Timestamp - ) - if len(tc.hooks) == 0 { - if err = tc.check(); err != nil { - return nil, err - } - node, err = tc.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*TimestampMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - if err = tc.check(); err != nil { - return nil, err - } - tc.mutation = mutation - if node, err = tc.sqlSave(ctx); err != nil { - return nil, err - } - mutation.id = &node.ID - mutation.done = true - return node, err - }) - for i := len(tc.hooks) - 1; i >= 0; i-- { - if tc.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = tc.hooks[i](mut) - } - v, err := mut.Mutate(ctx, tc.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Timestamp) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from TimestampMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, tc.sqlSave, tc.mutation, tc.hooks) } // SaveX calls Save and panics if Save returns an error. @@ -138,6 +96,9 @@ func (tc *TimestampCreate) check() error { } func (tc *TimestampCreate) sqlSave(ctx context.Context) (*Timestamp, error) { + if err := tc.check(); err != nil { + return nil, err + } _node, _spec := tc.createSpec() if err := sqlgraph.CreateNode(ctx, tc.driver, _spec); err != nil { if sqlgraph.IsConstraintError(err) { @@ -147,34 +108,22 @@ func (tc *TimestampCreate) sqlSave(ctx context.Context) (*Timestamp, error) { } id := _spec.ID.Value.(int64) _node.ID = int(id) + tc.mutation.id = &_node.ID + tc.mutation.done = true return _node, nil } func (tc *TimestampCreate) createSpec() (*Timestamp, *sqlgraph.CreateSpec) { var ( _node = &Timestamp{config: tc.config} - _spec = &sqlgraph.CreateSpec{ - Table: timestamp.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: timestamp.FieldID, - }, - } + _spec = sqlgraph.NewCreateSpec(timestamp.Table, sqlgraph.NewFieldSpec(timestamp.FieldID, field.TypeInt)) ) if value, ok := tc.mutation.GetType(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: timestamp.FieldType, - }) + _spec.SetField(timestamp.FieldType, field.TypeString, value) _node.Type = value } if value, ok := tc.mutation.Timestamp(); ok { - _spec.Fields = append(_spec.Fields, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: timestamp.FieldTimestamp, - }) + _spec.SetField(timestamp.FieldTimestamp, field.TypeTime, value) _node.Timestamp = value } if nodes := tc.mutation.SignatureIDs(); len(nodes) > 0 { @@ -185,10 +134,7 @@ func (tc *TimestampCreate) createSpec() (*Timestamp, *sqlgraph.CreateSpec) { Columns: []string{timestamp.SignatureColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: signature.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(signature.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -203,11 +149,15 @@ func (tc *TimestampCreate) createSpec() (*Timestamp, *sqlgraph.CreateSpec) { // TimestampCreateBulk is the builder for creating many Timestamp entities in bulk. type TimestampCreateBulk struct { config + err error builders []*TimestampCreate } // Save creates the Timestamp entities in the database. 
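//
// A usage sketch (illustrative, assuming the generated SetType and
// SetTimestamp setters on TimestampCreate):
//
//	ts, err := client.Timestamp.CreateBulk(
//		client.Timestamp.Create().SetType("tsp").SetTimestamp(time.Now()),
//	).Save(ctx)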
func (tcb *TimestampCreateBulk) Save(ctx context.Context) ([]*Timestamp, error) { + if tcb.err != nil { + return nil, tcb.err + } specs := make([]*sqlgraph.CreateSpec, len(tcb.builders)) nodes := make([]*Timestamp, len(tcb.builders)) mutators := make([]Mutator, len(tcb.builders)) @@ -223,8 +173,8 @@ func (tcb *TimestampCreateBulk) Save(ctx context.Context) ([]*Timestamp, error) return nil, err } builder.mutation = mutation - nodes[i], specs[i] = builder.createSpec() var err error + nodes[i], specs[i] = builder.createSpec() if i < len(mutators)-1 { _, err = mutators[i+1].Mutate(root, tcb.builders[i+1].mutation) } else { diff --git a/ent/timestamp_delete.go b/ent/timestamp_delete.go index 0b89a520..e1a941a8 100644 --- a/ent/timestamp_delete.go +++ b/ent/timestamp_delete.go @@ -4,7 +4,6 @@ package ent import ( "context" - "fmt" "entgo.io/ent/dialect/sql" "entgo.io/ent/dialect/sql/sqlgraph" @@ -28,34 +27,7 @@ func (td *TimestampDelete) Where(ps ...predicate.Timestamp) *TimestampDelete { // Exec executes the deletion query and returns how many vertices were deleted. func (td *TimestampDelete) Exec(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(td.hooks) == 0 { - affected, err = td.sqlExec(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*TimestampMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - td.mutation = mutation - affected, err = td.sqlExec(ctx) - mutation.done = true - return affected, err - }) - for i := len(td.hooks) - 1; i >= 0; i-- { - if td.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = td.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, td.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, td.sqlExec, td.mutation, td.hooks) } // ExecX is like Exec, but panics if an error occurs. @@ -68,15 +40,7 @@ func (td *TimestampDelete) ExecX(ctx context.Context) int { } func (td *TimestampDelete) sqlExec(ctx context.Context) (int, error) { - _spec := &sqlgraph.DeleteSpec{ - Node: &sqlgraph.NodeSpec{ - Table: timestamp.Table, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: timestamp.FieldID, - }, - }, - } + _spec := sqlgraph.NewDeleteSpec(timestamp.Table, sqlgraph.NewFieldSpec(timestamp.FieldID, field.TypeInt)) if ps := td.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -88,6 +52,7 @@ func (td *TimestampDelete) sqlExec(ctx context.Context) (int, error) { if err != nil && sqlgraph.IsConstraintError(err) { err = &ConstraintError{msg: err.Error(), wrap: err} } + td.mutation.done = true return affected, err } @@ -96,6 +61,12 @@ type TimestampDeleteOne struct { td *TimestampDelete } +// Where appends a list predicates to the TimestampDelete builder. +func (tdo *TimestampDeleteOne) Where(ps ...predicate.Timestamp) *TimestampDeleteOne { + tdo.td.mutation.Where(ps...) + return tdo +} + // Exec executes the deletion query. func (tdo *TimestampDeleteOne) Exec(ctx context.Context) error { n, err := tdo.td.Exec(ctx) @@ -111,5 +82,7 @@ func (tdo *TimestampDeleteOne) Exec(ctx context.Context) error { // ExecX is like Exec, but panics if an error occurs. 
func (tdo *TimestampDeleteOne) ExecX(ctx context.Context) { - tdo.td.ExecX(ctx) + if err := tdo.Exec(ctx); err != nil { + panic(err) + } } diff --git a/ent/timestamp_query.go b/ent/timestamp_query.go index a3d0047b..44d336b1 100644 --- a/ent/timestamp_query.go +++ b/ent/timestamp_query.go @@ -18,11 +18,9 @@ import ( // TimestampQuery is the builder for querying Timestamp entities. type TimestampQuery struct { config - limit *int - offset *int - unique *bool - order []OrderFunc - fields []string + ctx *QueryContext + order []timestamp.OrderOption + inters []Interceptor predicates []predicate.Timestamp withSignature *SignatureQuery withFKs bool @@ -39,34 +37,34 @@ func (tq *TimestampQuery) Where(ps ...predicate.Timestamp) *TimestampQuery { return tq } -// Limit adds a limit step to the query. +// Limit the number of records to be returned by this query. func (tq *TimestampQuery) Limit(limit int) *TimestampQuery { - tq.limit = &limit + tq.ctx.Limit = &limit return tq } -// Offset adds an offset step to the query. +// Offset to start from. func (tq *TimestampQuery) Offset(offset int) *TimestampQuery { - tq.offset = &offset + tq.ctx.Offset = &offset return tq } // Unique configures the query builder to filter duplicate records on query. // By default, unique is set to true, and can be disabled using this method. func (tq *TimestampQuery) Unique(unique bool) *TimestampQuery { - tq.unique = &unique + tq.ctx.Unique = &unique return tq } -// Order adds an order step to the query. -func (tq *TimestampQuery) Order(o ...OrderFunc) *TimestampQuery { +// Order specifies how the records should be ordered. +func (tq *TimestampQuery) Order(o ...timestamp.OrderOption) *TimestampQuery { tq.order = append(tq.order, o...) return tq } // QuerySignature chains the current query on the "signature" edge. func (tq *TimestampQuery) QuerySignature() *SignatureQuery { - query := &SignatureQuery{config: tq.config} + query := (&SignatureClient{config: tq.config}).Query() query.path = func(ctx context.Context) (fromU *sql.Selector, err error) { if err := tq.prepareQuery(ctx); err != nil { return nil, err @@ -89,7 +87,7 @@ func (tq *TimestampQuery) QuerySignature() *SignatureQuery { // First returns the first Timestamp entity from the query. // Returns a *NotFoundError when no Timestamp was found. func (tq *TimestampQuery) First(ctx context.Context) (*Timestamp, error) { - nodes, err := tq.Limit(1).All(ctx) + nodes, err := tq.Limit(1).All(setContextOp(ctx, tq.ctx, "First")) if err != nil { return nil, err } @@ -112,7 +110,7 @@ func (tq *TimestampQuery) FirstX(ctx context.Context) *Timestamp { // Returns a *NotFoundError when no Timestamp ID was found. func (tq *TimestampQuery) FirstID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = tq.Limit(1).IDs(ctx); err != nil { + if ids, err = tq.Limit(1).IDs(setContextOp(ctx, tq.ctx, "FirstID")); err != nil { return } if len(ids) == 0 { @@ -135,7 +133,7 @@ func (tq *TimestampQuery) FirstIDX(ctx context.Context) int { // Returns a *NotSingularError when more than one Timestamp entity is found. // Returns a *NotFoundError when no Timestamp entities are found. func (tq *TimestampQuery) Only(ctx context.Context) (*Timestamp, error) { - nodes, err := tq.Limit(2).All(ctx) + nodes, err := tq.Limit(2).All(setContextOp(ctx, tq.ctx, "Only")) if err != nil { return nil, err } @@ -163,7 +161,7 @@ func (tq *TimestampQuery) OnlyX(ctx context.Context) *Timestamp { // Returns a *NotFoundError when no entities are found. 
func (tq *TimestampQuery) OnlyID(ctx context.Context) (id int, err error) { var ids []int - if ids, err = tq.Limit(2).IDs(ctx); err != nil { + if ids, err = tq.Limit(2).IDs(setContextOp(ctx, tq.ctx, "OnlyID")); err != nil { return } switch len(ids) { @@ -188,10 +186,12 @@ func (tq *TimestampQuery) OnlyIDX(ctx context.Context) int { // All executes the query and returns a list of Timestamps. func (tq *TimestampQuery) All(ctx context.Context) ([]*Timestamp, error) { + ctx = setContextOp(ctx, tq.ctx, "All") if err := tq.prepareQuery(ctx); err != nil { return nil, err } - return tq.sqlAll(ctx) + qr := querierAll[[]*Timestamp, *TimestampQuery]() + return withInterceptors[[]*Timestamp](ctx, tq, qr, tq.inters) } // AllX is like All, but panics if an error occurs. @@ -204,9 +204,12 @@ func (tq *TimestampQuery) AllX(ctx context.Context) []*Timestamp { } // IDs executes the query and returns a list of Timestamp IDs. -func (tq *TimestampQuery) IDs(ctx context.Context) ([]int, error) { - var ids []int - if err := tq.Select(timestamp.FieldID).Scan(ctx, &ids); err != nil { +func (tq *TimestampQuery) IDs(ctx context.Context) (ids []int, err error) { + if tq.ctx.Unique == nil && tq.path != nil { + tq.Unique(true) + } + ctx = setContextOp(ctx, tq.ctx, "IDs") + if err = tq.Select(timestamp.FieldID).Scan(ctx, &ids); err != nil { return nil, err } return ids, nil @@ -223,10 +226,11 @@ func (tq *TimestampQuery) IDsX(ctx context.Context) []int { // Count returns the count of the given query. func (tq *TimestampQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, tq.ctx, "Count") if err := tq.prepareQuery(ctx); err != nil { return 0, err } - return tq.sqlCount(ctx) + return withInterceptors[int](ctx, tq, querierCount[*TimestampQuery](), tq.inters) } // CountX is like Count, but panics if an error occurs. @@ -240,10 +244,15 @@ func (tq *TimestampQuery) CountX(ctx context.Context) int { // Exist returns true if the query has elements in the graph. func (tq *TimestampQuery) Exist(ctx context.Context) (bool, error) { - if err := tq.prepareQuery(ctx); err != nil { - return false, err + ctx = setContextOp(ctx, tq.ctx, "Exist") + switch _, err := tq.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil } - return tq.sqlExist(ctx) } // ExistX is like Exist, but panics if an error occurs. @@ -263,22 +272,21 @@ func (tq *TimestampQuery) Clone() *TimestampQuery { } return &TimestampQuery{ config: tq.config, - limit: tq.limit, - offset: tq.offset, - order: append([]OrderFunc{}, tq.order...), + ctx: tq.ctx.Clone(), + order: append([]timestamp.OrderOption{}, tq.order...), + inters: append([]Interceptor{}, tq.inters...), predicates: append([]predicate.Timestamp{}, tq.predicates...), withSignature: tq.withSignature.Clone(), // clone intermediate query. - sql: tq.sql.Clone(), - path: tq.path, - unique: tq.unique, + sql: tq.sql.Clone(), + path: tq.path, } } // WithSignature tells the query-builder to eager-load the nodes that are connected to // the "signature" edge. The optional arguments are used to configure the query builder of the edge. 
func (tq *TimestampQuery) WithSignature(opts ...func(*SignatureQuery)) *TimestampQuery { - query := &SignatureQuery{config: tq.config} + query := (&SignatureClient{config: tq.config}).Query() for _, opt := range opts { opt(query) } @@ -301,16 +309,11 @@ func (tq *TimestampQuery) WithSignature(opts ...func(*SignatureQuery)) *Timestam // Aggregate(ent.Count()). // Scan(ctx, &v) func (tq *TimestampQuery) GroupBy(field string, fields ...string) *TimestampGroupBy { - grbuild := &TimestampGroupBy{config: tq.config} - grbuild.fields = append([]string{field}, fields...) - grbuild.path = func(ctx context.Context) (prev *sql.Selector, err error) { - if err := tq.prepareQuery(ctx); err != nil { - return nil, err - } - return tq.sqlQuery(ctx), nil - } + tq.ctx.Fields = append([]string{field}, fields...) + grbuild := &TimestampGroupBy{build: tq} + grbuild.flds = &tq.ctx.Fields grbuild.label = timestamp.Label - grbuild.flds, grbuild.scan = &grbuild.fields, grbuild.Scan + grbuild.scan = grbuild.Scan return grbuild } @@ -327,15 +330,30 @@ func (tq *TimestampQuery) GroupBy(field string, fields ...string) *TimestampGrou // Select(timestamp.FieldType). // Scan(ctx, &v) func (tq *TimestampQuery) Select(fields ...string) *TimestampSelect { - tq.fields = append(tq.fields, fields...) - selbuild := &TimestampSelect{TimestampQuery: tq} - selbuild.label = timestamp.Label - selbuild.flds, selbuild.scan = &tq.fields, selbuild.Scan - return selbuild + tq.ctx.Fields = append(tq.ctx.Fields, fields...) + sbuild := &TimestampSelect{TimestampQuery: tq} + sbuild.label = timestamp.Label + sbuild.flds, sbuild.scan = &tq.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a TimestampSelect configured with the given aggregations. +func (tq *TimestampQuery) Aggregate(fns ...AggregateFunc) *TimestampSelect { + return tq.Select().Aggregate(fns...) 
} func (tq *TimestampQuery) prepareQuery(ctx context.Context) error { - for _, f := range tq.fields { + for _, inter := range tq.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, tq); err != nil { + return err + } + } + } + for _, f := range tq.ctx.Fields { if !timestamp.ValidColumn(f) { return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} } @@ -413,6 +431,9 @@ func (tq *TimestampQuery) loadSignature(ctx context.Context, query *SignatureQue } nodeids[fk] = append(nodeids[fk], nodes[i]) } + if len(ids) == 0 { + return nil + } query.Where(signature.IDIn(ids...)) neighbors, err := query.All(ctx) if err != nil { @@ -435,41 +456,22 @@ func (tq *TimestampQuery) sqlCount(ctx context.Context) (int, error) { if len(tq.modifiers) > 0 { _spec.Modifiers = tq.modifiers } - _spec.Node.Columns = tq.fields - if len(tq.fields) > 0 { - _spec.Unique = tq.unique != nil && *tq.unique + _spec.Node.Columns = tq.ctx.Fields + if len(tq.ctx.Fields) > 0 { + _spec.Unique = tq.ctx.Unique != nil && *tq.ctx.Unique } return sqlgraph.CountNodes(ctx, tq.driver, _spec) } -func (tq *TimestampQuery) sqlExist(ctx context.Context) (bool, error) { - switch _, err := tq.FirstID(ctx); { - case IsNotFound(err): - return false, nil - case err != nil: - return false, fmt.Errorf("ent: check existence: %w", err) - default: - return true, nil - } -} - func (tq *TimestampQuery) querySpec() *sqlgraph.QuerySpec { - _spec := &sqlgraph.QuerySpec{ - Node: &sqlgraph.NodeSpec{ - Table: timestamp.Table, - Columns: timestamp.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: timestamp.FieldID, - }, - }, - From: tq.sql, - Unique: true, - } - if unique := tq.unique; unique != nil { + _spec := sqlgraph.NewQuerySpec(timestamp.Table, timestamp.Columns, sqlgraph.NewFieldSpec(timestamp.FieldID, field.TypeInt)) + _spec.From = tq.sql + if unique := tq.ctx.Unique; unique != nil { _spec.Unique = *unique + } else if tq.path != nil { + _spec.Unique = true } - if fields := tq.fields; len(fields) > 0 { + if fields := tq.ctx.Fields; len(fields) > 0 { _spec.Node.Columns = make([]string, 0, len(fields)) _spec.Node.Columns = append(_spec.Node.Columns, timestamp.FieldID) for i := range fields { @@ -485,10 +487,10 @@ func (tq *TimestampQuery) querySpec() *sqlgraph.QuerySpec { } } } - if limit := tq.limit; limit != nil { + if limit := tq.ctx.Limit; limit != nil { _spec.Limit = *limit } - if offset := tq.offset; offset != nil { + if offset := tq.ctx.Offset; offset != nil { _spec.Offset = *offset } if ps := tq.order; len(ps) > 0 { @@ -504,7 +506,7 @@ func (tq *TimestampQuery) querySpec() *sqlgraph.QuerySpec { func (tq *TimestampQuery) sqlQuery(ctx context.Context) *sql.Selector { builder := sql.Dialect(tq.driver.Dialect()) t1 := builder.Table(timestamp.Table) - columns := tq.fields + columns := tq.ctx.Fields if len(columns) == 0 { columns = timestamp.Columns } @@ -513,7 +515,7 @@ func (tq *TimestampQuery) sqlQuery(ctx context.Context) *sql.Selector { selector = tq.sql selector.Select(selector.Columns(columns...)...) 
} - if tq.unique != nil && *tq.unique { + if tq.ctx.Unique != nil && *tq.ctx.Unique { selector.Distinct() } for _, p := range tq.predicates { @@ -522,12 +524,12 @@ func (tq *TimestampQuery) sqlQuery(ctx context.Context) *sql.Selector { for _, p := range tq.order { p(selector) } - if offset := tq.offset; offset != nil { + if offset := tq.ctx.Offset; offset != nil { // limit is mandatory for offset clause. We start // with default value, and override it below if needed. selector.Offset(*offset).Limit(math.MaxInt32) } - if limit := tq.limit; limit != nil { + if limit := tq.ctx.Limit; limit != nil { selector.Limit(*limit) } return selector @@ -535,13 +537,8 @@ func (tq *TimestampQuery) sqlQuery(ctx context.Context) *sql.Selector { // TimestampGroupBy is the group-by builder for Timestamp entities. type TimestampGroupBy struct { - config selector - fields []string - fns []AggregateFunc - // intermediate query (i.e. traversal path). - sql *sql.Selector - path func(context.Context) (*sql.Selector, error) + build *TimestampQuery } // Aggregate adds the given aggregation functions to the group-by query. @@ -550,74 +547,77 @@ func (tgb *TimestampGroupBy) Aggregate(fns ...AggregateFunc) *TimestampGroupBy { return tgb } -// Scan applies the group-by query and scans the result into the given value. +// Scan applies the selector query and scans the result into the given value. func (tgb *TimestampGroupBy) Scan(ctx context.Context, v any) error { - query, err := tgb.path(ctx) - if err != nil { + ctx = setContextOp(ctx, tgb.build.ctx, "GroupBy") + if err := tgb.build.prepareQuery(ctx); err != nil { return err } - tgb.sql = query - return tgb.sqlScan(ctx, v) + return scanWithInterceptors[*TimestampQuery, *TimestampGroupBy](ctx, tgb.build, tgb, tgb.build.inters, v) } -func (tgb *TimestampGroupBy) sqlScan(ctx context.Context, v any) error { - for _, f := range tgb.fields { - if !timestamp.ValidColumn(f) { - return &ValidationError{Name: f, err: fmt.Errorf("invalid field %q for group-by", f)} +func (tgb *TimestampGroupBy) sqlScan(ctx context.Context, root *TimestampQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(tgb.fns)) + for _, fn := range tgb.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*tgb.flds)+len(tgb.fns)) + for _, f := range *tgb.flds { + columns = append(columns, selector.C(f)) } + columns = append(columns, aggregation...) + selector.Select(columns...) } - selector := tgb.sqlQuery() + selector.GroupBy(selector.Columns(*tgb.flds...)...) if err := selector.Err(); err != nil { return err } rows := &sql.Rows{} query, args := selector.Query() - if err := tgb.driver.Query(ctx, query, args, rows); err != nil { + if err := tgb.build.driver.Query(ctx, query, args, rows); err != nil { return err } defer rows.Close() return sql.ScanSlice(rows, v) } -func (tgb *TimestampGroupBy) sqlQuery() *sql.Selector { - selector := tgb.sql.Select() - aggregation := make([]string, 0, len(tgb.fns)) - for _, fn := range tgb.fns { - aggregation = append(aggregation, fn(selector)) - } - // If no columns were selected in a custom aggregation function, the default - // selection is the fields used for "group-by", and the aggregation functions. - if len(selector.SelectedColumns()) == 0 { - columns := make([]string, 0, len(tgb.fields)+len(tgb.fns)) - for _, f := range tgb.fields { - columns = append(columns, selector.C(f)) - } - columns = append(columns, aggregation...) 
- selector.Select(columns...) - } - return selector.GroupBy(selector.Columns(tgb.fields...)...) -} - // TimestampSelect is the builder for selecting fields of Timestamp entities. type TimestampSelect struct { *TimestampQuery selector - // intermediate query (i.e. traversal path). - sql *sql.Selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (ts *TimestampSelect) Aggregate(fns ...AggregateFunc) *TimestampSelect { + ts.fns = append(ts.fns, fns...) + return ts } // Scan applies the selector query and scans the result into the given value. func (ts *TimestampSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, ts.ctx, "Select") if err := ts.prepareQuery(ctx); err != nil { return err } - ts.sql = ts.TimestampQuery.sqlQuery(ctx) - return ts.sqlScan(ctx, v) + return scanWithInterceptors[*TimestampQuery, *TimestampSelect](ctx, ts.TimestampQuery, ts, ts.inters, v) } -func (ts *TimestampSelect) sqlScan(ctx context.Context, v any) error { +func (ts *TimestampSelect) sqlScan(ctx context.Context, root *TimestampQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(ts.fns)) + for _, fn := range ts.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*ts.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } rows := &sql.Rows{} - query, args := ts.sql.Query() + query, args := selector.Query() if err := ts.driver.Query(ctx, query, args, rows); err != nil { return err } diff --git a/ent/timestamp_update.go b/ent/timestamp_update.go index ab08a235..efcf8231 100644 --- a/ent/timestamp_update.go +++ b/ent/timestamp_update.go @@ -73,34 +73,7 @@ func (tu *TimestampUpdate) ClearSignature() *TimestampUpdate { // Save executes the query and returns the number of nodes affected by the update operation. func (tu *TimestampUpdate) Save(ctx context.Context) (int, error) { - var ( - err error - affected int - ) - if len(tu.hooks) == 0 { - affected, err = tu.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*TimestampMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - tu.mutation = mutation - affected, err = tu.sqlSave(ctx) - mutation.done = true - return affected, err - }) - for i := len(tu.hooks) - 1; i >= 0; i-- { - if tu.hooks[i] == nil { - return 0, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = tu.hooks[i](mut) - } - if _, err := mut.Mutate(ctx, tu.mutation); err != nil { - return 0, err - } - } - return affected, err + return withHooks(ctx, tu.sqlSave, tu.mutation, tu.hooks) } // SaveX is like Save, but panics if an error occurs. 
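The generated Save and Exec methods above now funnel hook execution through ent's shared withHooks helper instead of building the mutator chain inline. A minimal sketch of a runtime hook riding that path, assuming the default generated ent/hook package; the package name, helper name, and logging below are illustrative and not part of this diff:

package hooksexample

import (
	"context"
	"log"

	"github.com/testifysec/archivista/ent"
	"github.com/testifysec/archivista/ent/hook"
)

// registerTimestampHook wires a hypothetical audit hook into the Timestamp client.
// withHooks threads every update mutation through this chain, exactly as the
// removed inline mutator loop did.
func registerTimestampHook(client *ent.Client) {
	client.Timestamp.Use(hook.On(func(next ent.Mutator) ent.Mutator {
		return hook.TimestampFunc(func(ctx context.Context, m *ent.TimestampMutation) (ent.Value, error) {
			// GetType is the generated getter for the "type" field used above.
			if t, ok := m.GetType(); ok {
				log.Printf("mutating timestamp of type %q", t)
			}
			return next.Mutate(ctx, m)
		})
	}, ent.OpUpdate|ent.OpUpdateOne))
}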
@@ -126,16 +99,7 @@ func (tu *TimestampUpdate) ExecX(ctx context.Context) { } func (tu *TimestampUpdate) sqlSave(ctx context.Context) (n int, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: timestamp.Table, - Columns: timestamp.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: timestamp.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(timestamp.Table, timestamp.Columns, sqlgraph.NewFieldSpec(timestamp.FieldID, field.TypeInt)) if ps := tu.mutation.predicates; len(ps) > 0 { _spec.Predicate = func(selector *sql.Selector) { for i := range ps { @@ -144,18 +108,10 @@ func (tu *TimestampUpdate) sqlSave(ctx context.Context) (n int, err error) { } } if value, ok := tu.mutation.GetType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: timestamp.FieldType, - }) + _spec.SetField(timestamp.FieldType, field.TypeString, value) } if value, ok := tu.mutation.Timestamp(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: timestamp.FieldTimestamp, - }) + _spec.SetField(timestamp.FieldTimestamp, field.TypeTime, value) } if tu.mutation.SignatureCleared() { edge := &sqlgraph.EdgeSpec{ @@ -165,10 +121,7 @@ func (tu *TimestampUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{timestamp.SignatureColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: signature.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(signature.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -181,10 +134,7 @@ func (tu *TimestampUpdate) sqlSave(ctx context.Context) (n int, err error) { Columns: []string{timestamp.SignatureColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: signature.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(signature.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -200,6 +150,7 @@ func (tu *TimestampUpdate) sqlSave(ctx context.Context) (n int, err error) { } return 0, err } + tu.mutation.done = true return n, nil } @@ -253,6 +204,12 @@ func (tuo *TimestampUpdateOne) ClearSignature() *TimestampUpdateOne { return tuo } +// Where appends a list predicates to the TimestampUpdate builder. +func (tuo *TimestampUpdateOne) Where(ps ...predicate.Timestamp) *TimestampUpdateOne { + tuo.mutation.Where(ps...) + return tuo +} + // Select allows selecting one or more fields (columns) of the returned entity. // The default is selecting all fields defined in the entity schema. func (tuo *TimestampUpdateOne) Select(field string, fields ...string) *TimestampUpdateOne { @@ -262,40 +219,7 @@ func (tuo *TimestampUpdateOne) Select(field string, fields ...string) *Timestamp // Save executes the query and returns the updated Timestamp entity. 
func (tuo *TimestampUpdateOne) Save(ctx context.Context) (*Timestamp, error) { - var ( - err error - node *Timestamp - ) - if len(tuo.hooks) == 0 { - node, err = tuo.sqlSave(ctx) - } else { - var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { - mutation, ok := m.(*TimestampMutation) - if !ok { - return nil, fmt.Errorf("unexpected mutation type %T", m) - } - tuo.mutation = mutation - node, err = tuo.sqlSave(ctx) - mutation.done = true - return node, err - }) - for i := len(tuo.hooks) - 1; i >= 0; i-- { - if tuo.hooks[i] == nil { - return nil, fmt.Errorf("ent: uninitialized hook (forgotten import ent/runtime?)") - } - mut = tuo.hooks[i](mut) - } - v, err := mut.Mutate(ctx, tuo.mutation) - if err != nil { - return nil, err - } - nv, ok := v.(*Timestamp) - if !ok { - return nil, fmt.Errorf("unexpected node type %T returned from TimestampMutation", v) - } - node = nv - } - return node, err + return withHooks(ctx, tuo.sqlSave, tuo.mutation, tuo.hooks) } // SaveX is like Save, but panics if an error occurs. @@ -321,16 +245,7 @@ func (tuo *TimestampUpdateOne) ExecX(ctx context.Context) { } func (tuo *TimestampUpdateOne) sqlSave(ctx context.Context) (_node *Timestamp, err error) { - _spec := &sqlgraph.UpdateSpec{ - Node: &sqlgraph.NodeSpec{ - Table: timestamp.Table, - Columns: timestamp.Columns, - ID: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: timestamp.FieldID, - }, - }, - } + _spec := sqlgraph.NewUpdateSpec(timestamp.Table, timestamp.Columns, sqlgraph.NewFieldSpec(timestamp.FieldID, field.TypeInt)) id, ok := tuo.mutation.ID() if !ok { return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "Timestamp.id" for update`)} @@ -356,18 +271,10 @@ func (tuo *TimestampUpdateOne) sqlSave(ctx context.Context) (_node *Timestamp, e } } if value, ok := tuo.mutation.GetType(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeString, - Value: value, - Column: timestamp.FieldType, - }) + _spec.SetField(timestamp.FieldType, field.TypeString, value) } if value, ok := tuo.mutation.Timestamp(); ok { - _spec.Fields.Set = append(_spec.Fields.Set, &sqlgraph.FieldSpec{ - Type: field.TypeTime, - Value: value, - Column: timestamp.FieldTimestamp, - }) + _spec.SetField(timestamp.FieldTimestamp, field.TypeTime, value) } if tuo.mutation.SignatureCleared() { edge := &sqlgraph.EdgeSpec{ @@ -377,10 +284,7 @@ func (tuo *TimestampUpdateOne) sqlSave(ctx context.Context) (_node *Timestamp, e Columns: []string{timestamp.SignatureColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: signature.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(signature.FieldID, field.TypeInt), }, } _spec.Edges.Clear = append(_spec.Edges.Clear, edge) @@ -393,10 +297,7 @@ func (tuo *TimestampUpdateOne) sqlSave(ctx context.Context) (_node *Timestamp, e Columns: []string{timestamp.SignatureColumn}, Bidi: false, Target: &sqlgraph.EdgeTarget{ - IDSpec: &sqlgraph.FieldSpec{ - Type: field.TypeInt, - Column: signature.FieldID, - }, + IDSpec: sqlgraph.NewFieldSpec(signature.FieldID, field.TypeInt), }, } for _, k := range nodes { @@ -415,5 +316,6 @@ func (tuo *TimestampUpdateOne) sqlSave(ctx context.Context) (_node *Timestamp, e } return nil, err } + tuo.mutation.done = true return _node, nil } diff --git a/ent/tx.go b/ent/tx.go index dfc7ab29..e7ea4609 100644 --- a/ent/tx.go +++ b/ent/tx.go @@ -34,12 +34,6 @@ type Tx struct { // lazily loaded. client *Client clientOnce sync.Once - - // completion callbacks. 
- mu sync.Mutex - onCommit []CommitHook - onRollback []RollbackHook - // ctx lives for the life of the transaction. It is // the same context used by the underlying connection. ctx context.Context @@ -84,9 +78,9 @@ func (tx *Tx) Commit() error { var fn Committer = CommitFunc(func(context.Context, *Tx) error { return txDriver.tx.Commit() }) - tx.mu.Lock() - hooks := append([]CommitHook(nil), tx.onCommit...) - tx.mu.Unlock() + txDriver.mu.Lock() + hooks := append([]CommitHook(nil), txDriver.onCommit...) + txDriver.mu.Unlock() for i := len(hooks) - 1; i >= 0; i-- { fn = hooks[i](fn) } @@ -95,9 +89,10 @@ func (tx *Tx) Commit() error { // OnCommit adds a hook to call on commit. func (tx *Tx) OnCommit(f CommitHook) { - tx.mu.Lock() - defer tx.mu.Unlock() - tx.onCommit = append(tx.onCommit, f) + txDriver := tx.config.driver.(*txDriver) + txDriver.mu.Lock() + txDriver.onCommit = append(txDriver.onCommit, f) + txDriver.mu.Unlock() } type ( @@ -139,9 +134,9 @@ func (tx *Tx) Rollback() error { var fn Rollbacker = RollbackFunc(func(context.Context, *Tx) error { return txDriver.tx.Rollback() }) - tx.mu.Lock() - hooks := append([]RollbackHook(nil), tx.onRollback...) - tx.mu.Unlock() + txDriver.mu.Lock() + hooks := append([]RollbackHook(nil), txDriver.onRollback...) + txDriver.mu.Unlock() for i := len(hooks) - 1; i >= 0; i-- { fn = hooks[i](fn) } @@ -150,9 +145,10 @@ func (tx *Tx) Rollback() error { // OnRollback adds a hook to call on rollback. func (tx *Tx) OnRollback(f RollbackHook) { - tx.mu.Lock() - defer tx.mu.Unlock() - tx.onRollback = append(tx.onRollback, f) + txDriver := tx.config.driver.(*txDriver) + txDriver.mu.Lock() + txDriver.onRollback = append(txDriver.onRollback, f) + txDriver.mu.Unlock() } // Client returns a Client that binds to current transaction. @@ -192,6 +188,10 @@ type txDriver struct { drv dialect.Driver // tx is the underlying transaction. tx dialect.Tx + // completion hooks. + mu sync.Mutex + onCommit []CommitHook + onRollback []RollbackHook } // newTx creates a new transactional driver. 
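The ent/tx.go change above relocates the commit and rollback hook slices (and their guarding mutex) from the Tx value onto the shared txDriver, so hooks registered on any Tx bound to that driver are picked up when Commit or Rollback runs. A minimal sketch of the hook API whose storage moved; the helper name and logging are illustrative assumptions:

package txexample

import (
	"context"
	"log"

	"github.com/testifysec/archivista/ent"
)

// commitWithLogging opens a transaction and attaches a commit hook through the
// storage that now lives on txDriver.
func commitWithLogging(ctx context.Context, client *ent.Client) error {
	tx, err := client.Tx(ctx)
	if err != nil {
		return err
	}
	tx.OnCommit(func(next ent.Committer) ent.Committer {
		return ent.CommitFunc(func(ctx context.Context, tx *ent.Tx) error {
			// Runs immediately before the underlying transaction commits.
			log.Println("committing archivista transaction")
			return next.Commit(ctx, tx)
		})
	})
	// ... perform work with tx.Client() here ...
	return tx.Commit()
}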
diff --git a/generated.go b/generated.go index 88e35714..3a63f8df 100644 --- a/generated.go +++ b/generated.go @@ -12,6 +12,7 @@ import ( "sync/atomic" "time" + "entgo.io/contrib/entgql" "github.com/99designs/gqlgen/graphql" "github.com/99designs/gqlgen/graphql/introspection" "github.com/testifysec/archivista/ent" @@ -92,10 +93,10 @@ type ComplexityRoot struct { } Query struct { - Dsses func(childComplexity int, after *ent.Cursor, first *int, before *ent.Cursor, last *int, where *ent.DsseWhereInput) int + Dsses func(childComplexity int, after *entgql.Cursor[int], first *int, before *entgql.Cursor[int], last *int, where *ent.DsseWhereInput) int Node func(childComplexity int, id int) int Nodes func(childComplexity int, ids []int) int - Subjects func(childComplexity int, after *ent.Cursor, first *int, before *ent.Cursor, last *int, where *ent.SubjectWhereInput) int + Subjects func(childComplexity int, after *entgql.Cursor[int], first *int, before *entgql.Cursor[int], last *int, where *ent.SubjectWhereInput) int } Signature struct { @@ -111,7 +112,7 @@ type ComplexityRoot struct { Dsse func(childComplexity int) int ID func(childComplexity int) int Predicate func(childComplexity int) int - Subjects func(childComplexity int, after *ent.Cursor, first *int, before *ent.Cursor, last *int, where *ent.SubjectWhereInput) int + Subjects func(childComplexity int, after *entgql.Cursor[int], first *int, before *entgql.Cursor[int], last *int, where *ent.SubjectWhereInput) int } Subject struct { @@ -150,8 +151,8 @@ type ComplexityRoot struct { type QueryResolver interface { Node(ctx context.Context, id int) (ent.Noder, error) Nodes(ctx context.Context, ids []int) ([]ent.Noder, error) - Dsses(ctx context.Context, after *ent.Cursor, first *int, before *ent.Cursor, last *int, where *ent.DsseWhereInput) (*ent.DsseConnection, error) - Subjects(ctx context.Context, after *ent.Cursor, first *int, before *ent.Cursor, last *int, where *ent.SubjectWhereInput) (*ent.SubjectConnection, error) + Dsses(ctx context.Context, after *entgql.Cursor[int], first *int, before *entgql.Cursor[int], last *int, where *ent.DsseWhereInput) (*ent.DsseConnection, error) + Subjects(ctx context.Context, after *entgql.Cursor[int], first *int, before *entgql.Cursor[int], last *int, where *ent.SubjectWhereInput) (*ent.SubjectConnection, error) } type executableSchema struct { @@ -361,7 +362,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return 0, false } - return e.complexity.Query.Dsses(childComplexity, args["after"].(*ent.Cursor), args["first"].(*int), args["before"].(*ent.Cursor), args["last"].(*int), args["where"].(*ent.DsseWhereInput)), true + return e.complexity.Query.Dsses(childComplexity, args["after"].(*entgql.Cursor[int]), args["first"].(*int), args["before"].(*entgql.Cursor[int]), args["last"].(*int), args["where"].(*ent.DsseWhereInput)), true case "Query.node": if e.complexity.Query.Node == nil { @@ -397,7 +398,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return 0, false } - return e.complexity.Query.Subjects(childComplexity, args["after"].(*ent.Cursor), args["first"].(*int), args["before"].(*ent.Cursor), args["last"].(*int), args["where"].(*ent.SubjectWhereInput)), true + return e.complexity.Query.Subjects(childComplexity, args["after"].(*entgql.Cursor[int]), args["first"].(*int), args["before"].(*entgql.Cursor[int]), args["last"].(*int), args["where"].(*ent.SubjectWhereInput)), true case "Signature.dsse": if e.complexity.Signature.Dsse == nil { @@ 
-472,7 +473,7 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return 0, false } - return e.complexity.Statement.Subjects(childComplexity, args["after"].(*ent.Cursor), args["first"].(*int), args["before"].(*ent.Cursor), args["last"].(*int), args["where"].(*ent.SubjectWhereInput)), true + return e.complexity.Statement.Subjects(childComplexity, args["after"].(*entgql.Cursor[int]), args["first"].(*int), args["before"].(*entgql.Cursor[int]), args["last"].(*int), args["where"].(*ent.SubjectWhereInput)), true case "Subject.id": if e.complexity.Subject.ID == nil { @@ -1268,10 +1269,10 @@ func (ec *executionContext) field_Query___type_args(ctx context.Context, rawArgs func (ec *executionContext) field_Query_dsses_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} - var arg0 *ent.Cursor + var arg0 *entgql.Cursor[int] if tmp, ok := rawArgs["after"]; ok { ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("after")) - arg0, err = ec.unmarshalOCursor2ᚖgithubᚗcomᚋtestifysecᚋarchivistaᚋentᚐCursor(ctx, tmp) + arg0, err = ec.unmarshalOCursor2ᚖentgoᚗioᚋcontribᚋentgqlᚐCursor(ctx, tmp) if err != nil { return nil, err } @@ -1286,10 +1287,10 @@ func (ec *executionContext) field_Query_dsses_args(ctx context.Context, rawArgs } } args["first"] = arg1 - var arg2 *ent.Cursor + var arg2 *entgql.Cursor[int] if tmp, ok := rawArgs["before"]; ok { ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("before")) - arg2, err = ec.unmarshalOCursor2ᚖgithubᚗcomᚋtestifysecᚋarchivistaᚋentᚐCursor(ctx, tmp) + arg2, err = ec.unmarshalOCursor2ᚖentgoᚗioᚋcontribᚋentgqlᚐCursor(ctx, tmp) if err != nil { return nil, err } @@ -1349,10 +1350,10 @@ func (ec *executionContext) field_Query_nodes_args(ctx context.Context, rawArgs func (ec *executionContext) field_Query_subjects_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} - var arg0 *ent.Cursor + var arg0 *entgql.Cursor[int] if tmp, ok := rawArgs["after"]; ok { ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("after")) - arg0, err = ec.unmarshalOCursor2ᚖgithubᚗcomᚋtestifysecᚋarchivistaᚋentᚐCursor(ctx, tmp) + arg0, err = ec.unmarshalOCursor2ᚖentgoᚗioᚋcontribᚋentgqlᚐCursor(ctx, tmp) if err != nil { return nil, err } @@ -1367,10 +1368,10 @@ func (ec *executionContext) field_Query_subjects_args(ctx context.Context, rawAr } } args["first"] = arg1 - var arg2 *ent.Cursor + var arg2 *entgql.Cursor[int] if tmp, ok := rawArgs["before"]; ok { ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("before")) - arg2, err = ec.unmarshalOCursor2ᚖgithubᚗcomᚋtestifysecᚋarchivistaᚋentᚐCursor(ctx, tmp) + arg2, err = ec.unmarshalOCursor2ᚖentgoᚗioᚋcontribᚋentgqlᚐCursor(ctx, tmp) if err != nil { return nil, err } @@ -1400,10 +1401,10 @@ func (ec *executionContext) field_Query_subjects_args(ctx context.Context, rawAr func (ec *executionContext) field_Statement_subjects_args(ctx context.Context, rawArgs map[string]interface{}) (map[string]interface{}, error) { var err error args := map[string]interface{}{} - var arg0 *ent.Cursor + var arg0 *entgql.Cursor[int] if tmp, ok := rawArgs["after"]; ok { ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("after")) - arg0, err = ec.unmarshalOCursor2ᚖgithubᚗcomᚋtestifysecᚋarchivistaᚋentᚐCursor(ctx, tmp) + arg0, err = ec.unmarshalOCursor2ᚖentgoᚗioᚋcontribᚋentgqlᚐCursor(ctx, tmp) if err != nil { return nil, err } @@ 
-1418,10 +1419,10 @@ func (ec *executionContext) field_Statement_subjects_args(ctx context.Context, r } } args["first"] = arg1 - var arg2 *ent.Cursor + var arg2 *entgql.Cursor[int] if tmp, ok := rawArgs["before"]; ok { ctx := graphql.WithPathContext(ctx, graphql.NewPathWithField("before")) - arg2, err = ec.unmarshalOCursor2ᚖgithubᚗcomᚋtestifysecᚋarchivistaᚋentᚐCursor(ctx, tmp) + arg2, err = ec.unmarshalOCursor2ᚖentgoᚗioᚋcontribᚋentgqlᚐCursor(ctx, tmp) if err != nil { return nil, err } @@ -2183,9 +2184,9 @@ func (ec *executionContext) _DsseConnection_pageInfo(ctx context.Context, field } return graphql.Null } - res := resTmp.(ent.PageInfo) + res := resTmp.(entgql.PageInfo[int]) fc.Result = res - return ec.marshalNPageInfo2githubᚗcomᚋtestifysecᚋarchivistaᚋentᚐPageInfo(ctx, field.Selections, res) + return ec.marshalNPageInfo2entgoᚗioᚋcontribᚋentgqlᚐPageInfo(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_DsseConnection_pageInfo(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -2336,9 +2337,9 @@ func (ec *executionContext) _DsseEdge_cursor(ctx context.Context, field graphql. } return graphql.Null } - res := resTmp.(ent.Cursor) + res := resTmp.(entgql.Cursor[int]) fc.Result = res - return ec.marshalNCursor2githubᚗcomᚋtestifysecᚋarchivistaᚋentᚐCursor(ctx, field.Selections, res) + return ec.marshalNCursor2entgoᚗioᚋcontribᚋentgqlᚐCursor(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_DsseEdge_cursor(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -2354,7 +2355,7 @@ func (ec *executionContext) fieldContext_DsseEdge_cursor(ctx context.Context, fi return fc, nil } -func (ec *executionContext) _PageInfo_hasNextPage(ctx context.Context, field graphql.CollectedField, obj *ent.PageInfo) (ret graphql.Marshaler) { +func (ec *executionContext) _PageInfo_hasNextPage(ctx context.Context, field graphql.CollectedField, obj *entgql.PageInfo[int]) (ret graphql.Marshaler) { fc, err := ec.fieldContext_PageInfo_hasNextPage(ctx, field) if err != nil { return graphql.Null @@ -2398,7 +2399,7 @@ func (ec *executionContext) fieldContext_PageInfo_hasNextPage(ctx context.Contex return fc, nil } -func (ec *executionContext) _PageInfo_hasPreviousPage(ctx context.Context, field graphql.CollectedField, obj *ent.PageInfo) (ret graphql.Marshaler) { +func (ec *executionContext) _PageInfo_hasPreviousPage(ctx context.Context, field graphql.CollectedField, obj *entgql.PageInfo[int]) (ret graphql.Marshaler) { fc, err := ec.fieldContext_PageInfo_hasPreviousPage(ctx, field) if err != nil { return graphql.Null @@ -2442,7 +2443,7 @@ func (ec *executionContext) fieldContext_PageInfo_hasPreviousPage(ctx context.Co return fc, nil } -func (ec *executionContext) _PageInfo_startCursor(ctx context.Context, field graphql.CollectedField, obj *ent.PageInfo) (ret graphql.Marshaler) { +func (ec *executionContext) _PageInfo_startCursor(ctx context.Context, field graphql.CollectedField, obj *entgql.PageInfo[int]) (ret graphql.Marshaler) { fc, err := ec.fieldContext_PageInfo_startCursor(ctx, field) if err != nil { return graphql.Null @@ -2465,9 +2466,9 @@ func (ec *executionContext) _PageInfo_startCursor(ctx context.Context, field gra if resTmp == nil { return graphql.Null } - res := resTmp.(*ent.Cursor) + res := resTmp.(*entgql.Cursor[int]) fc.Result = res - return ec.marshalOCursor2ᚖgithubᚗcomᚋtestifysecᚋarchivistaᚋentᚐCursor(ctx, field.Selections, res) + return ec.marshalOCursor2ᚖentgoᚗioᚋcontribᚋentgqlᚐCursor(ctx, 
field.Selections, res) } func (ec *executionContext) fieldContext_PageInfo_startCursor(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -2483,7 +2484,7 @@ func (ec *executionContext) fieldContext_PageInfo_startCursor(ctx context.Contex return fc, nil } -func (ec *executionContext) _PageInfo_endCursor(ctx context.Context, field graphql.CollectedField, obj *ent.PageInfo) (ret graphql.Marshaler) { +func (ec *executionContext) _PageInfo_endCursor(ctx context.Context, field graphql.CollectedField, obj *entgql.PageInfo[int]) (ret graphql.Marshaler) { fc, err := ec.fieldContext_PageInfo_endCursor(ctx, field) if err != nil { return graphql.Null @@ -2506,9 +2507,9 @@ func (ec *executionContext) _PageInfo_endCursor(ctx context.Context, field graph if resTmp == nil { return graphql.Null } - res := resTmp.(*ent.Cursor) + res := resTmp.(*entgql.Cursor[int]) fc.Result = res - return ec.marshalOCursor2ᚖgithubᚗcomᚋtestifysecᚋarchivistaᚋentᚐCursor(ctx, field.Selections, res) + return ec.marshalOCursor2ᚖentgoᚗioᚋcontribᚋentgqlᚐCursor(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_PageInfo_endCursor(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -2832,7 +2833,7 @@ func (ec *executionContext) _Query_dsses(ctx context.Context, field graphql.Coll }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.Query().Dsses(rctx, fc.Args["after"].(*ent.Cursor), fc.Args["first"].(*int), fc.Args["before"].(*ent.Cursor), fc.Args["last"].(*int), fc.Args["where"].(*ent.DsseWhereInput)) + return ec.resolvers.Query().Dsses(rctx, fc.Args["after"].(*entgql.Cursor[int]), fc.Args["first"].(*int), fc.Args["before"].(*entgql.Cursor[int]), fc.Args["last"].(*int), fc.Args["where"].(*ent.DsseWhereInput)) }) if err != nil { ec.Error(ctx, err) @@ -2895,7 +2896,7 @@ func (ec *executionContext) _Query_subjects(ctx context.Context, field graphql.C }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return ec.resolvers.Query().Subjects(rctx, fc.Args["after"].(*ent.Cursor), fc.Args["first"].(*int), fc.Args["before"].(*ent.Cursor), fc.Args["last"].(*int), fc.Args["where"].(*ent.SubjectWhereInput)) + return ec.resolvers.Query().Subjects(rctx, fc.Args["after"].(*entgql.Cursor[int]), fc.Args["first"].(*int), fc.Args["before"].(*entgql.Cursor[int]), fc.Args["last"].(*int), fc.Args["where"].(*ent.SubjectWhereInput)) }) if err != nil { ec.Error(ctx, err) @@ -3413,7 +3414,7 @@ func (ec *executionContext) _Statement_subjects(ctx context.Context, field graph }() resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { ctx = rctx // use context from middleware stack in children - return obj.Subjects(ctx, fc.Args["after"].(*ent.Cursor), fc.Args["first"].(*int), fc.Args["before"].(*ent.Cursor), fc.Args["last"].(*int), fc.Args["where"].(*ent.SubjectWhereInput)) + return obj.Subjects(ctx, fc.Args["after"].(*entgql.Cursor[int]), fc.Args["first"].(*int), fc.Args["before"].(*entgql.Cursor[int]), fc.Args["last"].(*int), fc.Args["where"].(*ent.SubjectWhereInput)) }) if err != nil { ec.Error(ctx, err) @@ -3833,9 +3834,9 @@ func (ec *executionContext) _SubjectConnection_pageInfo(ctx context.Context, fie } return graphql.Null } - res := resTmp.(ent.PageInfo) + res := 
resTmp.(entgql.PageInfo[int]) fc.Result = res - return ec.marshalNPageInfo2githubᚗcomᚋtestifysecᚋarchivistaᚋentᚐPageInfo(ctx, field.Selections, res) + return ec.marshalNPageInfo2entgoᚗioᚋcontribᚋentgqlᚐPageInfo(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_SubjectConnection_pageInfo(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -4165,9 +4166,9 @@ func (ec *executionContext) _SubjectEdge_cursor(ctx context.Context, field graph } return graphql.Null } - res := resTmp.(ent.Cursor) + res := resTmp.(entgql.Cursor[int]) fc.Result = res - return ec.marshalNCursor2githubᚗcomᚋtestifysecᚋarchivistaᚋentᚐCursor(ctx, field.Selections, res) + return ec.marshalNCursor2entgoᚗioᚋcontribᚋentgqlᚐCursor(ctx, field.Selections, res) } func (ec *executionContext) fieldContext_SubjectEdge_cursor(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) { @@ -9095,7 +9096,7 @@ func (ec *executionContext) _DsseEdge(ctx context.Context, sel ast.SelectionSet, var pageInfoImplementors = []string{"PageInfo"} -func (ec *executionContext) _PageInfo(ctx context.Context, sel ast.SelectionSet, obj *ent.PageInfo) graphql.Marshaler { +func (ec *executionContext) _PageInfo(ctx context.Context, sel ast.SelectionSet, obj *entgql.PageInfo[int]) graphql.Marshaler { fields := graphql.CollectFields(ec.OperationContext, sel, pageInfoImplementors) out := graphql.NewFieldSet(fields) var invalids uint32 @@ -10112,13 +10113,13 @@ func (ec *executionContext) marshalNBoolean2bool(ctx context.Context, sel ast.Se return res } -func (ec *executionContext) unmarshalNCursor2githubᚗcomᚋtestifysecᚋarchivistaᚋentᚐCursor(ctx context.Context, v interface{}) (ent.Cursor, error) { - var res ent.Cursor +func (ec *executionContext) unmarshalNCursor2entgoᚗioᚋcontribᚋentgqlᚐCursor(ctx context.Context, v interface{}) (entgql.Cursor[int], error) { + var res entgql.Cursor[int] err := res.UnmarshalGQL(v) return res, graphql.ErrorOnPath(ctx, err) } -func (ec *executionContext) marshalNCursor2githubᚗcomᚋtestifysecᚋarchivistaᚋentᚐCursor(ctx context.Context, sel ast.SelectionSet, v ent.Cursor) graphql.Marshaler { +func (ec *executionContext) marshalNCursor2entgoᚗioᚋcontribᚋentgqlᚐCursor(ctx context.Context, sel ast.SelectionSet, v entgql.Cursor[int]) graphql.Marshaler { return v } @@ -10251,7 +10252,7 @@ func (ec *executionContext) marshalNNode2ᚕgithubᚗcomᚋtestifysecᚋarchivis return ret } -func (ec *executionContext) marshalNPageInfo2githubᚗcomᚋtestifysecᚋarchivistaᚋentᚐPageInfo(ctx context.Context, sel ast.SelectionSet, v ent.PageInfo) graphql.Marshaler { +func (ec *executionContext) marshalNPageInfo2entgoᚗioᚋcontribᚋentgqlᚐPageInfo(ctx context.Context, sel ast.SelectionSet, v entgql.PageInfo[int]) graphql.Marshaler { return ec._PageInfo(ctx, sel, &v) } @@ -10768,16 +10769,16 @@ func (ec *executionContext) marshalOBoolean2ᚖbool(ctx context.Context, sel ast return res } -func (ec *executionContext) unmarshalOCursor2ᚖgithubᚗcomᚋtestifysecᚋarchivistaᚋentᚐCursor(ctx context.Context, v interface{}) (*ent.Cursor, error) { +func (ec *executionContext) unmarshalOCursor2ᚖentgoᚗioᚋcontribᚋentgqlᚐCursor(ctx context.Context, v interface{}) (*entgql.Cursor[int], error) { if v == nil { return nil, nil } - var res = new(ent.Cursor) + var res = new(entgql.Cursor[int]) err := res.UnmarshalGQL(v) return res, graphql.ErrorOnPath(ctx, err) } -func (ec *executionContext) marshalOCursor2ᚖgithubᚗcomᚋtestifysecᚋarchivistaᚋentᚐCursor(ctx context.Context, sel ast.SelectionSet, v 
*ent.Cursor) graphql.Marshaler { +func (ec *executionContext) marshalOCursor2ᚖentgoᚗioᚋcontribᚋentgqlᚐCursor(ctx context.Context, sel ast.SelectionSet, v *entgql.Cursor[int]) graphql.Marshaler { if v == nil { return graphql.Null } diff --git a/go.mod b/go.mod index 03a354e8..4519bd83 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,8 @@ go 1.19 require ( ariga.io/sqlcomment v0.0.0-20211020114721-6bb67a62a61a - entgo.io/contrib v0.3.0 - entgo.io/ent v0.11.3 + entgo.io/contrib v0.4.5 + entgo.io/ent v0.12.4 github.com/99designs/gqlgen v0.17.5-0.20220428154617-9250f9ac1f90 github.com/antonfisher/nested-logrus-formatter v1.3.1 github.com/digitorus/timestamp v0.0.0-20230220124323-d542479a2425 @@ -18,17 +18,16 @@ require ( github.com/lib/pq v1.10.7 github.com/minio/minio-go v6.0.14+incompatible github.com/sirupsen/logrus v1.9.0 - github.com/spf13/cobra v1.5.0 + github.com/spf13/cobra v1.7.0 github.com/stretchr/testify v1.8.3 github.com/testifysec/archivista-api v0.0.0-20230303165309-a31a92afd132 github.com/testifysec/go-witness v0.1.16 github.com/vektah/gqlparser/v2 v2.4.7 - github.com/vmihailenco/msgpack/v5 v5.0.0-beta.9 golang.org/x/sync v0.1.0 ) require ( - ariga.io/atlas v0.7.2-0.20220927111110-867ee0cca56a // indirect + ariga.io/atlas v0.14.1-0.20230918065911-83ad451a4935 // indirect github.com/agext/levenshtein v1.2.1 // indirect github.com/agnivade/levenshtein v1.1.1 // indirect github.com/apparentlymart/go-textseg/v13 v13.0.0 // indirect @@ -47,25 +46,27 @@ require ( github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/golang-lru v0.5.4 // indirect github.com/hashicorp/hcl/v2 v2.13.0 // indirect - github.com/inconshreveable/mousetrap v1.0.0 // indirect + github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/kylelemons/godebug v1.1.0 // indirect github.com/mitchellh/go-homedir v1.1.0 // indirect github.com/mitchellh/go-wordwrap v0.0.0-20150314170334-ad45545899c7 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/vmihailenco/msgpack/v5 v5.0.0-beta.9 // indirect github.com/vmihailenco/tagparser v0.1.2 // indirect github.com/zclconf/go-cty v1.12.1 // indirect - go.opencensus.io v0.23.0 // indirect + go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/otel v1.16.0 // indirect go.opentelemetry.io/otel/metric v1.16.0 // indirect go.opentelemetry.io/otel/trace v1.16.0 // indirect golang.org/x/crypto v0.10.0 // indirect - golang.org/x/mod v0.8.0 // indirect + golang.org/x/exp v0.0.0-20221230185412-738e83a70c30 // indirect + golang.org/x/mod v0.10.0 // indirect golang.org/x/net v0.11.0 // indirect golang.org/x/sys v0.9.0 // indirect golang.org/x/text v0.10.0 // indirect - golang.org/x/tools v0.6.0 // indirect + golang.org/x/tools v0.8.1-0.20230428195545-5283a0178901 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index d19794c1..b3e22eb6 100644 --- a/go.sum +++ b/go.sum @@ -1,12 +1,12 @@ -ariga.io/atlas v0.7.2-0.20220927111110-867ee0cca56a h1:6/nt4DODfgxzHTTg3tYy7YkVzruGQGZ/kRvXpA45KUo= -ariga.io/atlas v0.7.2-0.20220927111110-867ee0cca56a/go.mod h1:ft47uSh5hWGDCmQC9DsztZg6Xk+KagM5Ts/mZYKb9JE= +ariga.io/atlas v0.14.1-0.20230918065911-83ad451a4935 h1:JnYs/y8RJ3+MiIUp+3RgyyeO48VHLAZimqiaZYnMKk8= +ariga.io/atlas v0.14.1-0.20230918065911-83ad451a4935/go.mod h1:isZrlzJ5cpoCoKFoY9knZug7Lq4pP1cm8g3XciLZ0Pw= ariga.io/sqlcomment 
v0.0.0-20211020114721-6bb67a62a61a h1:rF33ixWgZAVi2wg4oNljORY954kdaZxjhD2IStWr1JA= ariga.io/sqlcomment v0.0.0-20211020114721-6bb67a62a61a/go.mod h1:K6ubZj0yR5JIrCkzfKmZjz44O/dsCClLW8Yb5pttjzo= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -entgo.io/contrib v0.3.0 h1:dcWQAQkPVmOhYgZz20e18sgh3hkW8TV+0Y5QrFjFAsA= -entgo.io/contrib v0.3.0/go.mod h1:mc3iYCB4aKIL3rcnrXN5C/+N4pKoOVnVUPlZvTbcdNk= -entgo.io/ent v0.11.3 h1:F5FBGAWiDCGder7YT+lqMnyzXl6d0xU3xMBM/SO3CMc= -entgo.io/ent v0.11.3/go.mod h1:mvDhvynOzAsOe7anH7ynPPtMjA/eeXP96kAfweevyxc= +entgo.io/contrib v0.4.5 h1:BFaOHwFLE8WZjVJadP0XHCIaxgcC1BAtUvAyw7M/GHk= +entgo.io/contrib v0.4.5/go.mod h1:wpZyq2DJgthugFvDBlaqMXj9mV4/9ebyGEn7xlTVQqE= +entgo.io/ent v0.12.4 h1:LddPnAyxls/O7DTXZvUGDj0NZIdGSu317+aoNLJWbD8= +entgo.io/ent v0.12.4/go.mod h1:Y3JVAjtlIk8xVZYSn3t3mf8xlZIn5SAOXZQxD6kKI+Q= github.com/99designs/gqlgen v0.17.5-0.20220428154617-9250f9ac1f90 h1:nGGP+sUJ6D3guzjVBgoH1PrZxoU4lUdfR/Q8THYrAJI= github.com/99designs/gqlgen v0.17.5-0.20220428154617-9250f9ac1f90/go.mod h1:SNpLVzaF37rRLSAXtu8FKVp5I4zycneMmFX6NT4XGSU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= @@ -103,8 +103,8 @@ github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+l github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/hcl/v2 v2.13.0 h1:0Apadu1w6M11dyGFxWnmhhcMjkbAiKCv7G1r/2QgCNc= github.com/hashicorp/hcl/v2 v2.13.0/go.mod h1:e4z5nxYlWNPdDSNYX+ph14EvWYMFm3eP0zIUqPc2jr0= -github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= -github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8= github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg= github.com/kevinmbeaulieu/eq-go v1.0.0/go.mod h1:G3S8ajA56gKBZm4UB9AOyoOS37JO3roToPzKNM8dtdM= @@ -122,7 +122,7 @@ github.com/logrusorgru/aurora/v3 v3.0.0/go.mod h1:vsR12bk5grlLvLXAYrBsb5Oc/N+LxA github.com/matryer/moq v0.2.7/go.mod h1:kITsx543GOENm48TUAQyJ9+SAvFSr7iGQXPoth/VUBk= github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= -github.com/mattn/go-sqlite3 v1.14.15 h1:vfoHhTN1af61xCRSWzFIWzx2YskyMTwHLrExkBOjvxI= +github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= github.com/minio/minio-go v6.0.14+incompatible h1:fnV+GD28LeqdN6vT2XdGKW8Qe/IfjJDswNVuni6km9o= github.com/minio/minio-go v6.0.14+incompatible/go.mod h1:7guKYtitv8dktvNUGrhzmNlA5wrAABTQXCoesZdFQO8= github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= @@ -141,16 +141,20 @@ github.com/sergi/go-diff v1.1.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNX github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8= github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= -github.com/spf13/cobra v1.5.0 h1:X+jTBEBqF0bHN+9cSMgmfuvv2VHJ9ezmFNf9Y/XstYU= -github.com/spf13/cobra v1.5.0/go.mod 
h1:dWXEIy2H428czQCjInthrTRUg7yKbok+2Qi/yBIJoUM= +github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= +github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/testifysec/archivista-api v0.0.0-20230303165309-a31a92afd132 h1:w/8ACAtNVo201VSxtdStNvahmIYvgFzjfpeCdu0BkSo= @@ -168,8 +172,8 @@ github.com/vmihailenco/tagparser v0.1.2/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgq github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/zclconf/go-cty v1.12.1 h1:PcupnljUm9EIvbgSHQnHhUr3fO6oFmkOrvs2BAFNXXY= github.com/zclconf/go-cty v1.12.1/go.mod h1:s9IfD1LK5ccNMSWCVFCE2rJfHiZgi7JijgeWIMfhLvA= -go.opencensus.io v0.23.0 h1:gqCw0LfLxScz8irSi8exQc7fyQ0fKQU/qnC/X8+V/1M= -go.opencensus.io v0.23.0/go.mod h1:XItmlyltB5F7CS4xOC1DcqMoFqwtC6OG2xF7mCv7P7E= +go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= +go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= go.opentelemetry.io/otel v1.16.0 h1:Z7GVAX/UkAXPKsy94IU+i6thsQS4nb7LviLpnaNeW8s= go.opentelemetry.io/otel v1.16.0/go.mod h1:vl0h9NUa1D5s1nv3A5vZOYWn8av4K8Ml6JDeHrT/bx4= go.opentelemetry.io/otel/metric v1.16.0 h1:RbrpwVG1Hfv85LgnZ7+txXioPDoh6EdbZHo26Q3hqOo= @@ -183,13 +187,15 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20221230185412-738e83a70c30 h1:m9O6OTJ627iFnN2JIWfdqlZCzneRO6EEBsHXI25P8ws= +golang.org/x/exp v0.0.0-20221230185412-738e83a70c30/go.mod h1:CxIveKay+FTh1D0yPZemJVgC/95VzuuOLq5Qi4xnoYc= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod 
h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= -golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= +golang.org/x/mod v0.10.0 h1:lFO9qtOdlre5W1jxS3r/4szv2/6iXxScdzjoBMXNhYk= +golang.org/x/mod v0.10.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -236,8 +242,8 @@ golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBn golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.9/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU= golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= -golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= +golang.org/x/tools v0.8.1-0.20230428195545-5283a0178901 h1:0wxTF6pSjIIhNt7mo9GvjDfzyCOiWhmICgtO/Ah948s= +golang.org/x/tools v0.8.1-0.20230428195545-5283a0178901/go.mod h1:JxBZ99ISMI5ViVkT1tr6tdNmXeTrcpVSD3vZ1RsRdN4= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
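The go.mod and go.sum changes above move the project to entgo.io/ent v0.12.4 and entgo.io/contrib v0.4.5, whose entgql integration exposes generic pagination types; that is why the resolvers and generated.go now use entgql.Cursor[int] and entgql.PageInfo[int] in place of the old ent.Cursor and ent.PageInfo aliases. A minimal sketch of paginating with the generic cursor through the regenerated client; the page size, helper name, and printed fields are illustrative assumptions:

package paginateexample

import (
	"context"
	"fmt"

	"entgo.io/contrib/entgql"

	"github.com/testifysec/archivista/ent"
)

// firstDssePage fetches one page of DSSE envelopes and returns the cursor to
// resume from, now typed as *entgql.Cursor[int].
func firstDssePage(ctx context.Context, client *ent.Client) (*entgql.Cursor[int], error) {
	first := 10 // illustrative page size
	conn, err := client.Dsse.Query().Paginate(ctx, nil, &first, nil, nil)
	if err != nil {
		return nil, err
	}
	for _, edge := range conn.Edges {
		// edge.Cursor is an entgql.Cursor[int] rather than the old ent.Cursor.
		fmt.Printf("dsse %d at cursor %v\n", edge.Node.ID, edge.Cursor)
	}
	return conn.PageInfo.EndCursor, nil
}

After bumping these modules, re-running the project's code generation (typically go generate ./...) is what produces the regenerated ent and gqlgen files shown in this diff.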