diff --git a/cmd/dmap/repo_scan.go b/cmd/dmap/repo_scan.go
index f864f08..7e8561e 100644
--- a/cmd/dmap/repo_scan.go
+++ b/cmd/dmap/repo_scan.go
@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"reflect"
 	"strings"
+	"time"
 
 	"github.com/alecthomas/kong"
 	"github.com/gobwas/glob"
@@ -15,21 +16,24 @@ import (
 )
 
 type RepoScanCmd struct {
-	Type          string         `help:"Type of repository to connect to (postgres|mysql|oracle|sqlserver|snowflake|redshift|denodo)." enum:"postgres,mysql,oracle,sqlserver,snowflake,redshift,denodo" required:""`
-	Host          string         `help:"Hostname of the repository." required:""`
-	Port          uint16         `help:"Port of the repository." required:""`
-	User          string         `help:"Username to connect to the repository." required:""`
-	Password      string         `help:"Password to connect to the repository." required:""`
-	RepoID        string         `help:"The ID of the repository used by the Dmap service to identify the data repository. For RDS or Redshift, this is the ARN of the database. Optional, but required to publish the scan results Dmap service."`
-	Database      string         `help:"Name of the database to connect to. If not specified, the default database is used (if possible)."`
-	Advanced      map[string]any `help:"Advanced configuration for the repository, semicolon separated (e.g. key1=value1;key2=value2). Please see the documentation for details on how to provide this argument for specific repository types."`
-	IncludePaths  GlobFlag       `help:"List of glob patterns to include when introspecting the database(s), semicolon separated (e.g. foo*;bar*;*.baz)." default:"*"`
-	ExcludePaths  GlobFlag       `help:"List of glob patterns to exclude when introspecting the database(s), semicolon separated (e.g. foo*;bar*;*.baz)."`
-	MaxOpenConns  uint           `help:"Maximum number of open connections to the database." default:"10"`
-	SampleSize    uint           `help:"Number of rows to sample from the repository (per table)." default:"5"`
-	Offset        uint           `help:"Offset to start sampling each table from." default:"0"`
-	LabelYamlFile string         `help:"Filename of the yaml file containing the custom set of data labels (e.g. /path/to/labels.yaml). If omitted, a set of predefined labels is used."`
-	Silent        bool           `help:"Do not print the results to stdout." short:"s"`
+	Type           string         `help:"Type of repository to connect to (postgres|mysql|oracle|sqlserver|snowflake|redshift|denodo)." enum:"postgres,mysql,oracle,sqlserver,snowflake,redshift,denodo" required:""`
+	Host           string         `help:"Hostname of the repository." required:""`
+	Port           uint16         `help:"Port of the repository." required:""`
+	User           string         `help:"Username to connect to the repository." required:""`
+	Password       string         `help:"Password to connect to the repository." required:""`
+	RepoID         string         `help:"The ID of the repository used by the Dmap service to identify the data repository. For RDS or Redshift, this is the ARN of the database. Optional, but required to publish the scan results Dmap service."`
+	Database       string         `help:"Name of the database to connect to. If not specified, the default database is used (if possible)."`
+	Advanced       map[string]any `help:"Advanced configuration for the repository, semicolon separated (e.g. key1=value1;key2=value2). Please see the documentation for details on how to provide this argument for specific repository types."`
+	IncludePaths   GlobFlag       `help:"List of glob patterns to include when introspecting the database(s), semicolon separated (e.g. foo*;bar*;*.baz)." default:"*"`
+	ExcludePaths   GlobFlag       `help:"List of glob patterns to exclude when introspecting the database(s), semicolon separated (e.g. foo*;bar*;*.baz)."`
+	MaxOpenConns   uint           `help:"Maximum number of open connections to the database." default:"10"`
+	MaxParallelDbs uint           `help:"Maximum number of parallel databases scanned at once. If zero, there is no limit." default:"0"`
+	MaxConcurrency uint           `help:"Maximum number of concurrent query goroutines. If zero, there is no limit." default:"0"`
+	QueryTimeout   time.Duration  `help:"Maximum time a query can run before being cancelled. If zero, there is no timeout." default:"0s"`
+	SampleSize     uint           `help:"Number of rows to sample from the repository (per table)." default:"5"`
+	Offset         uint           `help:"Offset to start sampling each table from." default:"0"`
+	LabelYamlFile  string         `help:"Filename of the yaml file containing the custom set of data labels (e.g. /path/to/labels.yaml). If omitted, a set of predefined labels is used."`
+	Silent         bool           `help:"Do not print the results to stdout." short:"s"`
 }
 
 func (cmd *RepoScanCmd) Validate() error {
@@ -69,13 +73,16 @@ func (cmd *RepoScanCmd) Run(globals *Globals) error {
 	cfg := sql.ScannerConfig{
 		RepoType: cmd.Type,
 		RepoConfig: sql.RepoConfig{
-			Host:         cmd.Host,
-			Port:         cmd.Port,
-			User:         cmd.User,
-			Password:     cmd.Password,
-			Database:     cmd.Database,
-			MaxOpenConns: cmd.MaxOpenConns,
-			Advanced:     cmd.Advanced,
+			Host:           cmd.Host,
+			Port:           cmd.Port,
+			User:           cmd.User,
+			Password:       cmd.Password,
+			Database:       cmd.Database,
+			MaxOpenConns:   cmd.MaxOpenConns,
+			MaxParallelDbs: cmd.MaxParallelDbs,
+			MaxConcurrency: cmd.MaxConcurrency,
+			QueryTimeout:   cmd.QueryTimeout,
+			Advanced:       cmd.Advanced,
 		},
 		IncludePaths: cmd.IncludePaths,
 		ExcludePaths: cmd.ExcludePaths,
diff --git a/sql/config.go b/sql/config.go
index 648fa4a..b4af1b2 100644
--- a/sql/config.go
+++ b/sql/config.go
@@ -2,6 +2,7 @@ package sql
 
 import (
 	"fmt"
+	"time"
 )
 
 // RepoConfig is the necessary configuration to connect to a data sql.
@@ -18,6 +19,13 @@ type RepoConfig struct {
 	Database string
 	// MaxOpenConns is the maximum number of open connections to the database.
 	MaxOpenConns uint
+	// MaxParallelDbs is the maximum number of parallel databases scanned at
+	// once.
+	MaxParallelDbs uint
+	// MaxConcurrency is the maximum number of concurrent query goroutines.
+	MaxConcurrency uint
+	// QueryTimeout is the maximum time a query can run before being cancelled.
+	QueryTimeout time.Duration
 	// Advanced is a map of advanced configuration options.
 	Advanced map[string]any
 }
diff --git a/sql/scanner.go b/sql/scanner.go
index 4201a8a..21d8105 100644
--- a/sql/scanner.go
+++ b/sql/scanner.go
@@ -16,6 +16,18 @@ import (
 	"github.com/cyralinc/dmap/scan"
 )
 
+// Pair type intended to be passed to a channel (see sampleAllDbs).
+type samplesAndErr struct {
+	samples []Sample
+	err     error
+}
+
+// Pair type intended to be passed to a channel (see sampleDb).
+type sampleAndErr struct {
+	sample Sample
+	err    error
+}
+
 // ScannerConfig is the configuration for the Scanner.
 type ScannerConfig struct {
 	RepoType string
@@ -149,49 +161,100 @@ func (s *Scanner) sampleDb(ctx context.Context, db string) ([]Sample, error) {
 	}
 	defer func() { _ = repo.Close() }()
 	// Introspect the repository to get the metadata.
+	introspectCtx := ctx
+	if s.config.RepoConfig.QueryTimeout > 0 {
+		var cancel context.CancelFunc
+		introspectCtx, cancel = context.WithTimeout(ctx, s.config.RepoConfig.QueryTimeout)
+		defer cancel()
+	}
 	introspectParams := IntrospectParameters{
 		IncludePaths: s.config.IncludePaths,
 		ExcludePaths: s.config.ExcludePaths,
 	}
-	meta, err := repo.Introspect(ctx, introspectParams)
+	meta, err := repo.Introspect(introspectCtx, introspectParams)
 	if err != nil {
 		return nil, fmt.Errorf("error introspecting repository: %w", err)
 	}
-	// This is a "pair" type intended to be passed to the channel below.
-	type sampleAndErr struct {
-		sample Sample
-		err    error
-	}
-	// Fan out sample executions.
+	// This goroutine launches additional goroutines, one for each table, which
+	// sample the respective tables and send the results to the out channel. A
+	// semaphore is optionally used to limit the number of tables that are
+	// sampled concurrently. We do this on a dedicated goroutine so we can
+	// immediately read from the out channel on this goroutine, and avoid
+	// possible deadlocks due to the semaphore.
 	out := make(chan sampleAndErr)
-	numTables := 0
-	for _, schemaMeta := range meta.Schemas {
-		for _, tableMeta := range schemaMeta.Tables {
-			numTables++
-			go func(meta *TableMetadata) {
-				params := SampleParameters{
-					Metadata:   meta,
-					SampleSize: s.config.SampleSize,
-					Offset:     s.config.Offset,
-				}
-				sample, err := repo.SampleTable(ctx, params)
-				select {
-				case <-ctx.Done():
-					return
-				case out <- sampleAndErr{sample: sample, err: err}:
+	go func() {
+		// Before we return, wait for all the goroutines we launch below to
+		// complete, and then close the out channel once they're all done so the
+		// main goroutine can aggregate the results and return them.
+		var wg sync.WaitGroup
+		defer func() { wg.Wait(); close(out) }()
+		// Optionally use a semaphore to limit the number of tables sampled
+		// concurrently.
+		var sema *semaphore.Weighted
+		if s.config.RepoConfig.MaxConcurrency > 0 {
+			sema = semaphore.NewWeighted(int64(s.config.RepoConfig.MaxConcurrency))
+		}
+		for _, schemaMeta := range meta.Schemas {
+			for _, tableMeta := range schemaMeta.Tables {
+				if sema != nil {
+					// Acquire a semaphore slot before launching a goroutine to
+					// sample the table. This will block if the semaphore is
+					// full, and will unblock once a slot is available. An error
+					// means the context was cancelled.
+					if err := sema.Acquire(ctx, 1); err != nil {
+						log.WithError(err).Error("error acquiring semaphore")
+						return
+					}
 				}
-			}(tableMeta)
+				wg.Add(1)
+				// Launch a goroutine to sample the table.
+				go func(ctx context.Context, meta *TableMetadata) {
+					// Release the semaphore slot (if any) and signal the
+					// WaitGroup on exit; without this, wg.Wait never returns,
+					// out is never closed, and sampleDb deadlocks.
+					defer func() {
+						if sema != nil {
+							sema.Release(1)
+						}
+						wg.Done()
+					}()
+					sampleCtx := ctx
+					if s.config.RepoConfig.QueryTimeout > 0 {
+						var cancel context.CancelFunc
+						sampleCtx, cancel = context.WithTimeout(ctx, s.config.RepoConfig.QueryTimeout)
+						defer cancel()
+					}
+					params := SampleParameters{
+						Metadata:   meta,
+						SampleSize: s.config.SampleSize,
+						Offset:     s.config.Offset,
+					}
+					sample, err := repo.SampleTable(sampleCtx, params)
+					select {
+					case <-ctx.Done():
+					case out <- sampleAndErr{sample: sample, err: err}:
+					}
+				}(ctx, tableMeta)
+			}
 		}
-	}
+	}()
 	// Aggregate and return the results.
 	var samples []Sample
 	var errs error
-	for i := 0; i < numTables; i++ {
+	for {
 		select {
 		case <-ctx.Done():
-			return samples, ctx.Err()
-		case res := <-out:
+			errs = errors.Join(errs, ctx.Err())
+			return samples, fmt.Errorf("error(s) sampling repository: %w", errs)
+		case res, ok := <-out:
+			if !ok {
+				// The out channel has been closed, so we're done.
+				if errs != nil {
+					return samples, fmt.Errorf("error(s) sampling repository: %w", errs)
+				}
+				return samples, nil
+			}
 			if res.err != nil {
 				errs = errors.Join(errs, res.err)
 			} else {
@@ -199,11 +262,6 @@ func (s *Scanner) sampleDb(ctx context.Context, db string) ([]Sample, error) {
 			}
 		}
 	}
-	close(out)
-	if errs != nil {
-		return samples, fmt.Errorf("error(s) while sampling repository: %w", errs)
-	}
-	return samples, nil
 }
 
 // sampleAllDbs samples all the databases on the server. It samples each
@@ -230,58 +288,76 @@ func (s *Scanner) sampleAllDbs(ctx context.Context) ([]Sample, error) {
 	// We assume that this repository will be connected to the default database
 	// (or at least some database that can discover all the other databases).
 	// Use it to discover all the other databases on the server.
-	dbs, err := repo.ListDatabases(ctx)
+	listDbCtx := ctx
+	if s.config.RepoConfig.QueryTimeout > 0 {
+		var cancel context.CancelFunc
+		listDbCtx, cancel = context.WithTimeout(ctx, s.config.RepoConfig.QueryTimeout)
+		defer cancel()
+	}
+	dbs, err := repo.ListDatabases(listDbCtx)
 	if err != nil {
 		return nil, fmt.Errorf("error listing databases: %w", err)
 	}
-	// Sample each database on a separate goroutine, and send the samples to
-	// the 'out' channel. Each slice of samples will be aggregated below on this
-	// goroutine and returned.
-	var wg sync.WaitGroup
-	// This is a "pair" type intended to be passed to the channel below.
-	type samplesAndErr struct {
-		samples []Sample
-		err     error
-	}
+	// This goroutine launches additional goroutines, one for each database,
+	// which sample the respective databases and send the results to the out
+	// channel. A semaphore is optionally used to limit the number of databases
+	// sampled concurrently. We do this on a dedicated goroutine so we can
+	// immediately read from the out channel on this goroutine, and avoid
+	// possible deadlocks due to the semaphore.
+	var cancel context.CancelFunc
+	ctx, cancel = context.WithCancel(ctx)
+	defer cancel()
 	out := make(chan samplesAndErr)
-	wg.Add(len(dbs))
-	// Using a semaphore here ensures that we avoid opening more than the
-	// specified total number of connections, since we end up creating multiple
-	// database handles (one per database).
-	var sema *semaphore.Weighted
-	if s.config.RepoConfig.MaxOpenConns > 0 {
-		sema = semaphore.NewWeighted(int64(s.config.RepoConfig.MaxOpenConns))
-	}
-	for _, db := range dbs {
-		go func(db string, cfg RepoConfig) {
-			defer wg.Done()
+	go func() {
+		// Before we return, wait for all the goroutines we launch below to
+		// complete, and then close the out channel once they're all done so the
+		// main goroutine can aggregate the results and return them.
+		var wg sync.WaitGroup
+		defer func() { wg.Wait(); close(out) }()
+		// Optionally use a semaphore to limit the number of databases sampled
+		// concurrently.
+		var sema *semaphore.Weighted
+		if s.config.RepoConfig.MaxParallelDbs > 0 {
+			sema = semaphore.NewWeighted(int64(s.config.RepoConfig.MaxParallelDbs))
+		}
+		for _, db := range dbs {
 			if sema != nil {
-				_ = sema.Acquire(ctx, 1)
-				defer sema.Release(1)
-			}
-			// Sample this specific database.
-			samples, err := s.sampleDb(ctx, db)
-			if err != nil && len(samples) == 0 {
-				log.WithError(err).Errorf("error gathering repository data samples for database %s", db)
-				return
-			}
-			// Send the samples for this database to the 'out' channel. The
-			// samples for each database will be aggregated into a single slice
-			// on the main goroutine and returned.
-			select {
-			case <-ctx.Done():
-				return
-			case out <- samplesAndErr{samples: samples, err: err}:
+				// Acquire a semaphore slot before launching a goroutine to
+				// sample the database. This will block if the semaphore is
+				// full, and will unblock once a slot is available. An error
+				// means the context was cancelled.
+				if err := sema.Acquire(ctx, 1); err != nil {
+					log.WithError(err).Error("error acquiring semaphore")
+					return
+				}
 			}
-		}(db, s.config.RepoConfig)
-	}
-
-	// Start a goroutine to close the 'out' channel once all the goroutines we
-	// launched above are done. This will allow the aggregation range loop below
-	// to terminate properly. Note that this must start after the wg.Add call.
-	// See https://go.dev/blog/pipelines ("Fan-out, fan-in" section).
-	go func() { wg.Wait(); close(out) }()
+			// Launch a goroutine to sample the database.
+			wg.Add(1)
+			go func(db string, cfg RepoConfig) {
+				defer func() {
+					if sema != nil {
+						// Release the slot once the goroutine is done.
+						sema.Release(1)
+					}
+					wg.Done()
+				}()
+				// Sample this specific database.
+				samples, err := s.sampleDb(ctx, db)
+				if err != nil && len(samples) == 0 {
+					log.WithError(err).Errorf("error gathering repository data samples for database %s", db)
+					return
+				}
+				// Send the samples for this database to the 'out' channel. The
+				// samples for each database will be aggregated into a single
+				// slice on the main goroutine and returned.
+				select {
+				case <-ctx.Done():
+				case out <- samplesAndErr{samples: samples, err: err}:
+				}
+			}(db, s.config.RepoConfig)
+		}
+	}()
 	// Aggregate and return the results.
 	var ret []Sample