Skip to content

Commit

Permalink
Merge pull request #652 from upper/with-context
Browse files Browse the repository at this point in the history
Refactor WithContext to make it work without cloning the session and add `ConnMaxIdleTime`
  • Loading branch information
xiam authored Jun 17, 2022
2 parents b5aff2b + d8a05e0 commit 2147806
Show file tree
Hide file tree
Showing 9 changed files with 227 additions and 141 deletions.
1 change: 1 addition & 0 deletions adapter/postgresql/database_pgx.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//go:build !pq
// +build !pq

package postgresql
Expand Down
68 changes: 35 additions & 33 deletions internal/sqladapter/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,42 +67,43 @@ type condsFilter interface {

// collection is the implementation of Collection.
type collection struct {
name string
sess Session

name string
adapter CollectionAdapter
}

// NewCollection initializes a Collection by wrapping a CollectionAdapter.
func NewCollection(sess Session, name string, adapter CollectionAdapter) Collection {
type collectionWithSession struct {
*collection

session Session
}

func newCollection(name string, adapter CollectionAdapter) *collection {
if adapter == nil {
panic("upper: received nil adapter")
panic("upper: nil adapter")
}
c := &collection{
sess: sess,
return &collection{
name: name,
adapter: adapter,
}
return c
}

func (c *collection) SQL() db.SQL {
return c.sess.SQL()
func (c *collectionWithSession) SQL() db.SQL {
return c.session.SQL()
}

func (c *collection) Session() db.Session {
return c.sess
func (c *collectionWithSession) Session() db.Session {
return c.session
}

func (c *collection) Name() string {
func (c *collectionWithSession) Name() string {
return c.name
}

func (c *collection) Count() (uint64, error) {
func (c *collectionWithSession) Count() (uint64, error) {
return c.Find().Count()
}

func (c *collection) Insert(item interface{}) (db.InsertResult, error) {
func (c *collectionWithSession) Insert(item interface{}) (db.InsertResult, error) {
id, err := c.adapter.Insert(c, item)
if err != nil {
return nil, err
Expand All @@ -111,11 +112,11 @@ func (c *collection) Insert(item interface{}) (db.InsertResult, error) {
return db.NewInsertResult(id), nil
}

func (c *collection) PrimaryKeys() ([]string, error) {
return c.sess.PrimaryKeys(c.Name())
func (c *collectionWithSession) PrimaryKeys() ([]string, error) {
return c.session.PrimaryKeys(c.Name())
}

func (c *collection) filterConds(conds ...interface{}) ([]interface{}, error) {
func (c *collectionWithSession) filterConds(conds ...interface{}) ([]interface{}, error) {
pk, err := c.PrimaryKeys()
if err != nil {
return nil, err
Expand All @@ -131,15 +132,16 @@ func (c *collection) filterConds(conds ...interface{}) ([]interface{}, error) {
return conds, nil
}

func (c *collection) Find(conds ...interface{}) db.Result {
func (c *collectionWithSession) Find(conds ...interface{}) db.Result {
filteredConds, err := c.filterConds(conds...)
if err != nil {
res := &Result{}
res.setErr(err)
return res
}

res := NewResult(
c.sess.SQL(),
c.session.SQL(),
c.Name(),
filteredConds,
)
Expand All @@ -149,14 +151,14 @@ func (c *collection) Find(conds ...interface{}) db.Result {
return res
}

func (c *collection) Exists() (bool, error) {
if err := c.sess.TableExists(c.Name()); err != nil {
func (c *collectionWithSession) Exists() (bool, error) {
if err := c.session.TableExists(c.Name()); err != nil {
return false, err
}
return true, nil
}

func (c *collection) InsertReturning(item interface{}) error {
func (c *collectionWithSession) InsertReturning(item interface{}) error {
if item == nil || reflect.TypeOf(item).Kind() != reflect.Ptr {
return fmt.Errorf("Expecting a pointer but got %T", item)
}
Expand All @@ -175,12 +177,12 @@ func (c *collection) InsertReturning(item interface{}) error {
}

var tx Session
isTransaction := c.sess.IsTransaction()
isTransaction := c.session.IsTransaction()
if isTransaction {
tx = c.sess
tx = c.session
} else {
var err error
tx, err = c.sess.NewTransaction(c.sess.Context(), nil)
tx, err = c.session.NewTransaction(c.session.Context(), nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -261,7 +263,7 @@ cancel:
return err
}

func (c *collection) UpdateReturning(item interface{}) error {
func (c *collectionWithSession) UpdateReturning(item interface{}) error {
if item == nil || reflect.TypeOf(item).Kind() != reflect.Ptr {
return fmt.Errorf("Expecting a pointer but got %T", item)
}
Expand All @@ -280,14 +282,14 @@ func (c *collection) UpdateReturning(item interface{}) error {
}

var tx Session
isTransaction := c.sess.IsTransaction()
isTransaction := c.session.IsTransaction()

if isTransaction {
tx = c.sess
tx = c.session
} else {
// Not within a transaction, let's create one.
var err error
tx, err = c.sess.NewTransaction(c.sess.Context(), nil)
tx, err = c.session.NewTransaction(c.session.Context(), nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -355,12 +357,12 @@ cancel:
return err
}

func (c *collection) Truncate() error {
func (c *collectionWithSession) Truncate() error {
stmt := exql.Statement{
Type: exql.Truncate,
Table: exql.TableWithName(c.Name()),
}
if _, err := c.sess.SQL().Exec(&stmt); err != nil {
if _, err := c.session.SQL().Exec(&stmt); err != nil {
return err
}
return nil
Expand Down
1 change: 1 addition & 0 deletions internal/sqladapter/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ func (r *Result) One(dst interface{}) error {
r.setErr(err)
return err
}

err = query.Iterator().One(dst)
r.setErr(err)
return err
Expand Down
Loading

0 comments on commit 2147806

Please sign in to comment.