Skip to content

Commit

Permalink
Add adapter tests
Browse files Browse the repository at this point in the history
  • Loading branch information
eugenetriguba committed Mar 11, 2024
1 parent 5e9a90e commit f0ee4ba
Show file tree
Hide file tree
Showing 8 changed files with 368 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ jobs:
run: |
env $(cat ".env.${{ matrix.db.name }}" | xargs) docker compose -f "docker-compose.${{ matrix.db.name }}.yml" up -d
- name: Run Tests
run: env $(cat ".env.${{ matrix.db.name }}" | xargs) go test -cover ./...
run: env $(cat ".env.${{ matrix.db.name }}" | xargs) go test -tags "${{ matrix.db.name }}" -cover ./...
12 changes: 6 additions & 6 deletions internal/bolttest/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@ import (

func NewTestDB(t *testing.T) storage.DB {
connectionConfig := NewTestConnectionConfig()
testdb, err := storage.NewDB(connectionConfig)
db, err := storage.NewDB(connectionConfig)
assert.Nil(t, err)
t.Cleanup(func() {
DropTable(t, testdb, "bolt_migrations")
assert.Nil(t, testdb.Close())
DropTable(t, db, "bolt_migrations")
assert.Nil(t, db.Close())
})
return testdb
return db
}

func NewTestConnectionConfig() configloader.ConnectionConfig {
Expand All @@ -32,7 +32,7 @@ func NewTestConnectionConfig() configloader.ConnectionConfig {
}
}

func DropTable(t *testing.T, testdb storage.DB, tableName string) {
_, err := testdb.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s;", tableName))
func DropTable(t *testing.T, db storage.DB, tableName string) {
_, err := db.Exec(fmt.Sprintf("DROP TABLE IF EXISTS %s;", tableName))
assert.Nil(t, err)
}
8 changes: 8 additions & 0 deletions internal/storage/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,16 @@ package storage
import "github.com/eugenetriguba/bolt/internal/configloader"

type DBAdapter interface {
// ConvertGenericPlaceholders replaces any generic `?` placeholders
// in `query` for the database driver specific placeholders and returns
// the updated query.
ConvertGenericPlaceholders(query string, argsCount int) string
// TableExists checks if the tableName exists within the
// database currently connected to.
TableExists(executor sqlExecutor, tableName string) (bool, error)
// DatabaseName retrieves the currently selected database name.
DatabaseName(executor sqlExecutor) (string, error)
// CreateDSN creates a DSN to be used with sql.Open in the database
// driver specific format.
CreateDSN(cfg configloader.ConnectionConfig) string
}
35 changes: 19 additions & 16 deletions internal/storage/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ import (
var postgresqlDriverName = "postgresql"
var mysqlDriverName = "mysql"

var supportedDrivers = map[string]string{
postgresqlDriverName: "pgx",
mysqlDriverName: "mysql",
var supportedDrivers = map[string]dbDriver{
postgresqlDriverName: {name: "pgx", adapter: PostgresqlAdapter{}},
mysqlDriverName: {name: "mysql", adapter: MySQLAdapter{}},
}

var driverAdapters = map[string]DBAdapter{
postgresqlDriverName: PostgresqlAdapter{},
mysqlDriverName: MySQLAdapter{},
type dbDriver struct {
name string
adapter DBAdapter
}

var (
Expand Down Expand Up @@ -56,17 +56,12 @@ type DB struct {
// the provided connection parameters.
// - ErrUnsupportedDriver: The provided driver is not supported.
func NewDB(cfg configloader.ConnectionConfig) (DB, error) {
driverName, exists := supportedDrivers[cfg.Driver]
driver, exists := supportedDrivers[cfg.Driver]
if !exists {
return DB{}, ErrUnsupportedDriver
}

adapter, exists := driverAdapters[cfg.Driver]
if !exists {
return DB{}, ErrUnsupportedDriver
}

db, err := sql.Open(driverName, adapter.CreateDSN(cfg))
db, err := sql.Open(driver.name, driver.adapter.CreateDSN(cfg))
if err != nil {
return DB{}, fmt.Errorf("%w: %v", ErrMalformedConnectionString, err)
}
Expand All @@ -79,7 +74,7 @@ func NewDB(cfg configloader.ConnectionConfig) (DB, error) {
return DB{}, fmt.Errorf("%w: %v", ErrUnableToConnect, err)
}

return DB{executor: db, sqlDB: db, adapter: adapter}, nil
return DB{executor: db, sqlDB: db, adapter: driver.adapter}, nil
}

// Close closes the database connection. Any further
Expand All @@ -89,27 +84,35 @@ func (db DB) Close() error {
return db.sqlDB.Close()
}

// Exec is a wrapper around the sql.DB Exec.
func (db DB) Exec(query string, args ...any) (sql.Result, error) {
newQuery := db.adapter.ConvertGenericPlaceholders(query, len(args))
return db.executor.Exec(newQuery, args...)
}

// Query is a wrapper around the sql.DB Query.
func (db DB) Query(query string, args ...any) (*sql.Rows, error) {
newQuery := db.adapter.ConvertGenericPlaceholders(query, len(args))
return db.executor.Query(newQuery, args...)
}

// QueryRow is a wrapper around the sql.DB QueryRow.
func (db DB) QueryRow(query string, args ...any) *sql.Row {
newQuery := db.adapter.ConvertGenericPlaceholders(query, len(args))
return db.executor.QueryRow(newQuery, args...)
}

// TableExists checks if the tableName exists within the
// database currently connected to.
func (db DB) TableExists(tableName string) (bool, error) {
return db.adapter.TableExists(db.executor, tableName)
}

type txFunc func(db DB) error

// Tx executes fn within a transaction block. If
// fn returns an error, the transaction will be rolled
// back. Otherwise, it will be committed.
func (db *DB) Tx(fn txFunc) error {
tx, err := db.sqlDB.Begin()
if err != nil {
Expand All @@ -127,12 +130,12 @@ func (db *DB) Tx(fn txFunc) error {

err = fn(txDB)
if err != nil {
return err
return fmt.Errorf("unable to execute transaction: %w", err)
}

err = tx.Commit()
if err != nil {
return err
return fmt.Errorf("unable to commit transaction: %w", err)
}

return nil
Expand Down
133 changes: 127 additions & 6 deletions internal/storage/db_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package storage_test

import (
"database/sql"
"errors"
"testing"

"github.com/eugenetriguba/bolt/internal/bolttest"
Expand All @@ -9,21 +11,140 @@ import (
)

func TestNewDB_Success(t *testing.T) {
db, err := storage.NewDB(bolttest.NewTestConnectionConfig())
assert.Nil(t, err)
t.Cleanup(func() {
assert.Nil(t, db.Close())
})
_, err = db.Exec("SELECT 1;")
assert.Nil(t, err)
}

func TestNewDB_UnsupportedDriver(t *testing.T) {
t.Setenv("BOLT_DB_CONN_DRIVER", "abc123")
_, err := storage.NewDB(bolttest.NewTestConnectionConfig())
assert.ErrorIs(t, err, storage.ErrUnsupportedDriver)
}

func TestNewDB_UnableToConnect(t *testing.T) {
t.Setenv("BOLT_DB_CONN_HOST", "")
t.Setenv("BOLT_DB_CONN_PORT", "")
_, err := storage.NewDB(bolttest.NewTestConnectionConfig())
assert.ErrorIs(t, err, storage.ErrUnableToConnect)
}

func TestClose_IsClosed(t *testing.T) {
db, err := storage.NewDB(bolttest.NewTestConnectionConfig())
assert.Nil(t, err)

err = db.Close()
assert.Nil(t, err)

_, err = db.Exec("SELECT 1;")
assert.ErrorContains(t, err, "sql: database is closed")
}

func TestTableExists_DoesExist(t *testing.T) {
db, err := storage.NewDB(bolttest.NewTestConnectionConfig())
assert.Nil(t, err)
t.Cleanup(func() {
bolttest.DropTable(t, db, "tmp")
assert.Nil(t, db.Close())
})
_, err = db.Exec("CREATE TABLE tmp(id int primary key);")
assert.Nil(t, err)

exists, err := db.TableExists("tmp")

assert.Nil(t, err)
assert.True(t, exists)
}

func TestTableExists_DoesNotExist(t *testing.T) {
cfg := bolttest.NewTestConnectionConfig()
db, err := storage.NewDB(cfg)
assert.Nil(t, err)
t.Cleanup(func() {
assert.Nil(t, db.Close())
})

exists, err := db.TableExists("tmp")

assert.Nil(t, err)
assert.False(t, exists)
}

func TestQueryPlaceholders(t *testing.T) {
cfg := bolttest.NewTestConnectionConfig()
db, err := storage.NewDB(cfg)
assert.Nil(t, err)
t.Cleanup(func() {
bolttest.DropTable(t, db, "tmp")
assert.Nil(t, db.Close())
})
_, err = db.Exec(`CREATE TABLE tmp(id int primary key);`)
assert.Nil(t, err)
_, err = db.Exec(`INSERT INTO tmp(id) VALUES(1);`)
assert.Nil(t, err)
_, err = db.Exec(`INSERT INTO tmp(id) VALUES(2);`)
assert.Nil(t, err)

var queryRowId int
err = db.QueryRow("SELECT id FROM tmp WHERE id = ?", 1).Scan(&queryRowId)
assert.Nil(t, err)
defer db.Close()
_, err = db.Exec("SELECT 1;")
assert.Equal(t, queryRowId, 1)

rows, err := db.Query("SELECT id FROM tmp WHERE id = ?", 1)
assert.Nil(t, err)
assert.True(t, rows.Next())
var queryId int
err = rows.Scan(&queryId)
assert.Nil(t, err)
assert.Equal(t, queryId, 1)
assert.False(t, rows.Next())
}

func TestNewDB_UnsupportedDriver(t *testing.T) {
t.Setenv("BOLT_DB_CONN_DRIVER", "abc123")
func TestTx_Commit(t *testing.T) {
cfg := bolttest.NewTestConnectionConfig()
db, err := storage.NewDB(cfg)
assert.Nil(t, err)
t.Cleanup(func() {
bolttest.DropTable(t, db, "tmp")
assert.Nil(t, db.Close())
})

_, err := storage.NewDB(cfg)
err = db.Tx(func(db storage.DB) error {
_, err = db.Exec(`CREATE TABLE tmp(id int primary key);`)
assert.Nil(t, err)
return nil
})
assert.Nil(t, err)

assert.ErrorIs(t, err, storage.ErrUnsupportedDriver)
exists, err := db.TableExists("tmp")
assert.Nil(t, err)
assert.True(t, exists)
}

func TestTx_Rollback(t *testing.T) {
cfg := bolttest.NewTestConnectionConfig()
db, err := storage.NewDB(cfg)
assert.Nil(t, err)
t.Cleanup(func() {
bolttest.DropTable(t, db, "tmp")
assert.Nil(t, db.Close())
})
_, err = db.Exec(`CREATE TABLE tmp(id INT PRIMARY KEY);`)
assert.Nil(t, err)
expectedErr := errors.New("error!")

err = db.Tx(func(db storage.DB) error {
_, err = db.Exec(`INSERT INTO tmp(id) VALUES(1)`)
assert.Nil(t, err)
return expectedErr
})
assert.ErrorIs(t, err, expectedErr)

var id int
err = db.QueryRow("SELECT id FROM tmp WHERE id = 1;").Scan(&id)
assert.ErrorIs(t, err, sql.ErrNoRows)
}
Loading

0 comments on commit f0ee4ba

Please sign in to comment.