Skip to content

Commit

Permalink
Add concurrency control parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
ccampo133 committed Aug 27, 2024
1 parent f616f1c commit 50c033a
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 99 deletions.
51 changes: 29 additions & 22 deletions cmd/dmap/repo_scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"reflect"
"strings"
"time"

"github.com/alecthomas/kong"
"github.com/gobwas/glob"
Expand All @@ -15,21 +16,24 @@ import (
)

// RepoScanCmd holds the CLI arguments for the "repo scan" command, which
// introspects and samples a single SQL data repository. Field tags are Kong
// struct tags: `help` is the flag description, `default` the default value,
// and `required`/`enum` add validation.
//
// NOTE(review): the diff rendering duplicated every field (old + new lines);
// this is the reconstructed post-commit struct with the three new
// concurrency-control flags (MaxParallelDbs, MaxConcurrency, QueryTimeout).
type RepoScanCmd struct {
	Type     string `help:"Type of repository to connect to (postgres|mysql|oracle|sqlserver|snowflake|redshift|denodo)." enum:"postgres,mysql,oracle,sqlserver,snowflake,redshift,denodo" required:""`
	Host     string `help:"Hostname of the repository." required:""`
	Port     uint16 `help:"Port of the repository." required:""`
	User     string `help:"Username to connect to the repository." required:""`
	Password string `help:"Password to connect to the repository." required:""`
	RepoID   string `help:"The ID of the repository used by the Dmap service to identify the data repository. For RDS or Redshift, this is the ARN of the database. Optional, but required to publish the scan results to the Dmap service."`
	Database string `help:"Name of the database to connect to. If not specified, the default database is used (if possible)."`
	Advanced map[string]any `help:"Advanced configuration for the repository, semicolon separated (e.g. key1=value1;key2=value2). Please see the documentation for details on how to provide this argument for specific repository types."`
	IncludePaths GlobFlag `help:"List of glob patterns to include when introspecting the database(s), semicolon separated (e.g. foo*;bar*;*.baz)." default:"*"`
	ExcludePaths GlobFlag `help:"List of glob patterns to exclude when introspecting the database(s), semicolon separated (e.g. foo*;bar*;*.baz)."`
	MaxOpenConns uint `help:"Maximum number of open connections to the database." default:"10"`
	// MaxParallelDbs bounds how many databases are scanned concurrently;
	// zero means unbounded (see sql/scanner.go).
	MaxParallelDbs uint `help:"Maximum number of parallel databases scanned at once. If zero, there is no limit." default:"0"`
	// MaxConcurrency bounds the number of concurrent table-sampling
	// goroutines per database; zero means unbounded.
	MaxConcurrency uint `help:"Maximum number of concurrent query goroutines. If zero, there is no limit." default:"0"`
	// QueryTimeout applies per query (introspection, sampling, listing);
	// zero disables the timeout.
	QueryTimeout  time.Duration `help:"Maximum time a query can run before being cancelled. If zero, there is no timeout." default:"0s"`
	SampleSize    uint          `help:"Number of rows to sample from the repository (per table)." default:"5"`
	Offset        uint          `help:"Offset to start sampling each table from." default:"0"`
	LabelYamlFile string        `help:"Filename of the yaml file containing the custom set of data labels (e.g. /path/to/labels.yaml). If omitted, a set of predefined labels is used."`
	Silent        bool          `help:"Do not print the results to stdout." short:"s"`
}

func (cmd *RepoScanCmd) Validate() error {
Expand Down Expand Up @@ -69,13 +73,16 @@ func (cmd *RepoScanCmd) Run(globals *Globals) error {
cfg := sql.ScannerConfig{
RepoType: cmd.Type,
RepoConfig: sql.RepoConfig{
Host: cmd.Host,
Port: cmd.Port,
User: cmd.User,
Password: cmd.Password,
Database: cmd.Database,
MaxOpenConns: cmd.MaxOpenConns,
Advanced: cmd.Advanced,
Host: cmd.Host,
Port: cmd.Port,
User: cmd.User,
Password: cmd.Password,
Database: cmd.Database,
MaxOpenConns: cmd.MaxOpenConns,
MaxParallelDbs: cmd.MaxParallelDbs,
MaxConcurrency: cmd.MaxConcurrency,
QueryTimeout: cmd.QueryTimeout,
Advanced: cmd.Advanced,
},
IncludePaths: cmd.IncludePaths,
ExcludePaths: cmd.ExcludePaths,
Expand Down
8 changes: 8 additions & 0 deletions sql/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sql

import (
"fmt"
"time"
)

// RepoConfig is the necessary configuration to connect to a data repository.
Expand All @@ -18,6 +19,13 @@ type RepoConfig struct {
Database string
// MaxOpenConns is the maximum number of open connections to the database.
MaxOpenConns uint
// MaxParallelDbs is the maximum number of parallel databases scanned at
// once.
MaxParallelDbs uint
// MaxConcurrency is the maximum number of concurrent query goroutines.
MaxConcurrency uint
// QueryTimeout is the maximum time a query can run before being cancelled.
QueryTimeout time.Duration
// Advanced is a map of advanced configuration options.
Advanced map[string]any
}
Expand Down
221 changes: 144 additions & 77 deletions sql/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,18 @@ import (
"github.com/cyralinc/dmap/scan"
)

// Pair type intended to be passed to a channel (see sampleAllDbs).
type samplesAndErr struct {
samples []Sample
err error
}

// Pair type intended to be passed to a channel (see sampleDb).
type sampleAndErr struct {
sample Sample
err error
}

// ScannerConfig is the configuration for the Scanner.
type ScannerConfig struct {
RepoType string
Expand Down Expand Up @@ -149,61 +161,98 @@ func (s *Scanner) sampleDb(ctx context.Context, db string) ([]Sample, error) {
}
defer func() { _ = repo.Close() }()
// Introspect the repository to get the metadata.
introspectCtx := ctx
if s.config.RepoConfig.QueryTimeout > 0 {
var cancel context.CancelFunc
introspectCtx, cancel = context.WithTimeout(ctx, s.config.RepoConfig.QueryTimeout)
defer cancel()
}
introspectParams := IntrospectParameters{
IncludePaths: s.config.IncludePaths,
ExcludePaths: s.config.ExcludePaths,
}
meta, err := repo.Introspect(ctx, introspectParams)
meta, err := repo.Introspect(introspectCtx, introspectParams)
if err != nil {
return nil, fmt.Errorf("error introspecting repository: %w", err)
}
// This is a "pair" type intended to be passed to the channel below.
type sampleAndErr struct {
sample Sample
err error
}
// Fan out sample executions.
// This goroutine launches additional goroutines, one for each table, which
// sample the respective tables and send the results to the out channel. A
// semaphore is optionally used to limit the number of tables that are
// sampled concurrently. We do this on a dedicated goroutine so we can
// immediately read from the out channel on this goroutine, and avoid
// possible deadlocks due to the semaphore.
out := make(chan sampleAndErr)
numTables := 0
for _, schemaMeta := range meta.Schemas {
for _, tableMeta := range schemaMeta.Tables {
numTables++
go func(meta *TableMetadata) {
params := SampleParameters{
Metadata: meta,
SampleSize: s.config.SampleSize,
Offset: s.config.Offset,
}
sample, err := repo.SampleTable(ctx, params)
select {
case <-ctx.Done():
return
case out <- sampleAndErr{sample: sample, err: err}:
go func() {
// Before we return, wait for all the goroutines we launch below to
// complete, and then close the out channel once they're all done so the
// main goroutine can aggregate the results and return them.
var wg sync.WaitGroup
defer func() { wg.Wait(); close(out) }()
// Optionally use a semaphore to limit the number of tables sampled
// concurrently.
var sema *semaphore.Weighted
if s.config.RepoConfig.MaxConcurrency > 0 {
sema = semaphore.NewWeighted(int64(s.config.RepoConfig.MaxConcurrency))
}
for _, schemaMeta := range meta.Schemas {
for _, tableMeta := range schemaMeta.Tables {
if sema != nil {
// Acquire a semaphore slot before launching a goroutine to
// sample the table. This will block if the semaphore is
// full, and will unblock once a slot is available. An error
// means the context was cancelled.
if err := sema.Acquire(ctx, 1); err != nil {
log.WithError(err).Error("error acquiring semaphore")
return
}
}
}(tableMeta)
wg.Add(1)
// Launch a goroutine to sample the table.
go func(ctx context.Context, meta *TableMetadata) {
sampleCtx := ctx
if s.config.RepoConfig.QueryTimeout > 0 {
var cancel context.CancelFunc
sampleCtx, cancel = context.WithTimeout(ctx, s.config.RepoConfig.QueryTimeout)
defer cancel()
}
params := SampleParameters{
Metadata: meta,
SampleSize: s.config.SampleSize,
Offset: s.config.Offset,
}
sample, err := repo.SampleTable(sampleCtx, params)
select {
case <-ctx.Done():
case out <- sampleAndErr{sample: sample, err: err}:
}
}(ctx, tableMeta)
}
}
}
}()

// Aggregate and return the results.
var samples []Sample
var errs error
for i := 0; i < numTables; i++ {
for {
select {
case <-ctx.Done():
return samples, ctx.Err()
case res := <-out:
errs = errors.Join(errs, ctx.Err())
return samples, fmt.Errorf("error(s) sampling repository: %w", errs)
case res, ok := <-out:
if !ok {
// The out channel has been closed, so we're done.
if errs != nil {
return samples, fmt.Errorf("error(s) sampling repository: %w", errs)
}
return samples, nil
}
if res.err != nil {
errs = errors.Join(errs, res.err)
} else {
samples = append(samples, res.sample)
}
}
}
close(out)
if errs != nil {
return samples, fmt.Errorf("error(s) while sampling repository: %w", errs)
}
return samples, nil
}

// sampleAllDbs samples all the databases on the server. It samples each
Expand All @@ -230,58 +279,76 @@ func (s *Scanner) sampleAllDbs(ctx context.Context) ([]Sample, error) {
// We assume that this repository will be connected to the default database
// (or at least some database that can discover all the other databases).
// Use it to discover all the other databases on the server.
dbs, err := repo.ListDatabases(ctx)
listDbCtx := ctx
if s.config.RepoConfig.QueryTimeout > 0 {
var cancel context.CancelFunc
listDbCtx, cancel = context.WithTimeout(ctx, s.config.RepoConfig.QueryTimeout)
defer cancel()
}
dbs, err := repo.ListDatabases(listDbCtx)
if err != nil {
return nil, fmt.Errorf("error listing databases: %w", err)
}

// Sample each database on a separate goroutine, and send the samples to
// the 'out' channel. Each slice of samples will be aggregated below on this
// goroutine and returned.
var wg sync.WaitGroup
// This is a "pair" type intended to be passed to the channel below.
type samplesAndErr struct {
samples []Sample
err error
}
// This goroutine launches additional goroutines, one for each database,
// which sample the respective databases and send the results to the out
// channel. A semaphore is optionally used to limit the number of databases
// sampled concurrently. We do this on a dedicated goroutine so we can
// immediately read from the out channel on this goroutine, and avoid
// possible deadlocks due to the semaphore.
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
defer cancel()
out := make(chan samplesAndErr)
wg.Add(len(dbs))
// Using a semaphore here ensures that we avoid opening more than the
// specified total number of connections, since we end up creating multiple
// database handles (one per database).
var sema *semaphore.Weighted
if s.config.RepoConfig.MaxOpenConns > 0 {
sema = semaphore.NewWeighted(int64(s.config.RepoConfig.MaxOpenConns))
}
for _, db := range dbs {
go func(db string, cfg RepoConfig) {
defer wg.Done()
go func() {
// Before we return, wait for all the goroutines we launch below to
// complete, and then close the out channel once they're all done so the
// main goroutine can aggregate the results and return them.
var wg sync.WaitGroup
defer func() { wg.Wait(); close(out) }()
// Optionally use a semaphore to limit the number of databases sampled
// concurrently.
var sema *semaphore.Weighted
if s.config.RepoConfig.MaxParallelDbs > 0 {
sema = semaphore.NewWeighted(int64(s.config.RepoConfig.MaxParallelDbs))
}
for _, db := range dbs {
if sema != nil {
_ = sema.Acquire(ctx, 1)
defer sema.Release(1)
}
// Sample this specific database.
samples, err := s.sampleDb(ctx, db)
if err != nil && len(samples) == 0 {
log.WithError(err).Errorf("error gathering repository data samples for database %s", db)
return
}
// Send the samples for this database to the 'out' channel. The
// samples for each database will be aggregated into a single slice
// on the main goroutine and returned.
select {
case <-ctx.Done():
return
case out <- samplesAndErr{samples: samples, err: err}:
// Acquire a semaphore slot before launching a goroutine to
// sample the database. This will block if the semaphore is
// full, and will unblock once a slot is available. An error
// means the context was cancelled.
if err := sema.Acquire(ctx, 1); err != nil {
log.WithError(err).Error("error acquiring semaphore")
return
}
}
}(db, s.config.RepoConfig)
}

// Start a goroutine to close the 'out' channel once all the goroutines we
// launched above are done. This will allow the aggregation range loop below
// to terminate properly. Note that this must start after the wg.Add call.
// See https://go.dev/blog/pipelines ("Fan-out, fan-in" section).
go func() { wg.Wait(); close(out) }()
// Launch a goroutine to sample the database.
wg.Add(1)
go func(db string, cfg RepoConfig) {
defer func() {
if sema != nil {
// Release the slot once the goroutine is done.
sema.Release(1)
}
wg.Done()
}()
// Sample this specific database.
samples, err := s.sampleDb(ctx, db)
if err != nil && len(samples) == 0 {
log.WithError(err).Errorf("error gathering repository data samples for database %s", db)
return
}
// Send the samples for this database to the 'out' channel. The
// samples for each database will be aggregated into a single
// slice on the main goroutine and returned.
select {
case <-ctx.Done():
case out <- samplesAndErr{samples: samples, err: err}:
}
}(db, s.config.RepoConfig)
}
}()

// Aggregate and return the results.
var ret []Sample
Expand Down

0 comments on commit 50c033a

Please sign in to comment.