From ce8478fc7be123ecb0992c25664c69e4b24164c9 Mon Sep 17 00:00:00 2001 From: Chris Campo Date: Tue, 27 Aug 2024 12:44:27 -0400 Subject: [PATCH 1/2] Add concurrency control parameters and bump deps --- cmd/dmap/repo_scan.go | 51 ++++++---- go.mod | 38 +++---- go.sum | 38 +++++++ sql/config.go | 8 ++ sql/scanner.go | 226 ++++++++++++++++++++++++++++-------------- sql/scanner_test.go | 24 ++--- 6 files changed, 255 insertions(+), 130 deletions(-) diff --git a/cmd/dmap/repo_scan.go b/cmd/dmap/repo_scan.go index f864f08..7e8561e 100644 --- a/cmd/dmap/repo_scan.go +++ b/cmd/dmap/repo_scan.go @@ -6,6 +6,7 @@ import ( "fmt" "reflect" "strings" + "time" "github.com/alecthomas/kong" "github.com/gobwas/glob" @@ -15,21 +16,24 @@ import ( ) type RepoScanCmd struct { - Type string `help:"Type of repository to connect to (postgres|mysql|oracle|sqlserver|snowflake|redshift|denodo)." enum:"postgres,mysql,oracle,sqlserver,snowflake,redshift,denodo" required:""` - Host string `help:"Hostname of the repository." required:""` - Port uint16 `help:"Port of the repository." required:""` - User string `help:"Username to connect to the repository." required:""` - Password string `help:"Password to connect to the repository." required:""` - RepoID string `help:"The ID of the repository used by the Dmap service to identify the data repository. For RDS or Redshift, this is the ARN of the database. Optional, but required to publish the scan results Dmap service."` - Database string `help:"Name of the database to connect to. If not specified, the default database is used (if possible)."` - Advanced map[string]any `help:"Advanced configuration for the repository, semicolon separated (e.g. key1=value1;key2=value2). Please see the documentation for details on how to provide this argument for specific repository types."` - IncludePaths GlobFlag `help:"List of glob patterns to include when introspecting the database(s), semicolon separated (e.g. foo*;bar*;*.baz)." default:"*"` - ExcludePaths GlobFlag `help:"List of glob patterns to exclude when introspecting the database(s), semicolon separated (e.g. foo*;bar*;*.baz)."` - MaxOpenConns uint `help:"Maximum number of open connections to the database." default:"10"` - SampleSize uint `help:"Number of rows to sample from the repository (per table)." default:"5"` - Offset uint `help:"Offset to start sampling each table from." default:"0"` - LabelYamlFile string `help:"Filename of the yaml file containing the custom set of data labels (e.g. /path/to/labels.yaml). If omitted, a set of predefined labels is used."` - Silent bool `help:"Do not print the results to stdout." short:"s"` + Type string `help:"Type of repository to connect to (postgres|mysql|oracle|sqlserver|snowflake|redshift|denodo)." enum:"postgres,mysql,oracle,sqlserver,snowflake,redshift,denodo" required:""` + Host string `help:"Hostname of the repository." required:""` + Port uint16 `help:"Port of the repository." required:""` + User string `help:"Username to connect to the repository." required:""` + Password string `help:"Password to connect to the repository." required:""` + RepoID string `help:"The ID of the repository used by the Dmap service to identify the data repository. For RDS or Redshift, this is the ARN of the database. Optional, but required to publish the scan results Dmap service."` + Database string `help:"Name of the database to connect to. If not specified, the default database is used (if possible)."` + Advanced map[string]any `help:"Advanced configuration for the repository, semicolon separated (e.g. key1=value1;key2=value2). Please see the documentation for details on how to provide this argument for specific repository types."` + IncludePaths GlobFlag `help:"List of glob patterns to include when introspecting the database(s), semicolon separated (e.g. foo*;bar*;*.baz)." default:"*"` + ExcludePaths GlobFlag `help:"List of glob patterns to exclude when introspecting the database(s), semicolon separated (e.g. foo*;bar*;*.baz)."` + MaxOpenConns uint `help:"Maximum number of open connections to the database." default:"10"` + MaxParallelDbs uint `help:"Maximum number of parallel databases scanned at once. If zero, there is no limit." default:"0"` + MaxConcurrency uint `help:"Maximum number of concurrent query goroutines. If zero, there is no limit." default:"0"` + QueryTimeout time.Duration `help:"Maximum time a query can run before being cancelled. If zero, there is no timeout." default:"0s"` + SampleSize uint `help:"Number of rows to sample from the repository (per table)." default:"5"` + Offset uint `help:"Offset to start sampling each table from." default:"0"` + LabelYamlFile string `help:"Filename of the yaml file containing the custom set of data labels (e.g. /path/to/labels.yaml). If omitted, a set of predefined labels is used."` + Silent bool `help:"Do not print the results to stdout." short:"s"` } func (cmd *RepoScanCmd) Validate() error { @@ -69,13 +73,16 @@ func (cmd *RepoScanCmd) Run(globals *Globals) error { cfg := sql.ScannerConfig{ RepoType: cmd.Type, RepoConfig: sql.RepoConfig{ - Host: cmd.Host, - Port: cmd.Port, - User: cmd.User, - Password: cmd.Password, - Database: cmd.Database, - MaxOpenConns: cmd.MaxOpenConns, - Advanced: cmd.Advanced, + Host: cmd.Host, + Port: cmd.Port, + User: cmd.User, + Password: cmd.Password, + Database: cmd.Database, + MaxOpenConns: cmd.MaxOpenConns, + MaxParallelDbs: cmd.MaxParallelDbs, + MaxConcurrency: cmd.MaxConcurrency, + QueryTimeout: cmd.QueryTimeout, + Advanced: cmd.Advanced, }, IncludePaths: cmd.IncludePaths, ExcludePaths: cmd.ExcludePaths, diff --git a/go.mod b/go.mod index d3e7e02..12d3966 100644 --- a/go.mod +++ b/go.mod @@ -31,16 +31,16 @@ require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/99designs/go-keychain v0.0.0-20191008050251-8e49817e8af4 // indirect github.com/99designs/keyring v1.2.2 // indirect - github.com/Azure/azure-sdk-for-go/sdk/azcore v1.12.0 // indirect - github.com/Azure/azure-sdk-for-go/sdk/internal v1.9.1 // indirect - github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.3.2 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.4.0 // indirect github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c // indirect github.com/OneOfOne/xxhash v1.2.8 // indirect github.com/agnivade/levenshtein v1.1.1 // indirect github.com/apache/arrow/go/v15 v15.0.2 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.4 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.12 // indirect - github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.7 // indirect + github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.15 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.16 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16 // indirect github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect @@ -58,7 +58,7 @@ require ( github.com/danieljoos/wincred v1.2.2 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dvsekhvalnov/jose2go v1.7.0 // indirect - github.com/gabriel-vasile/mimetype v1.4.4 // indirect + github.com/gabriel-vasile/mimetype v1.4.5 // indirect github.com/go-ini/ini v1.67.0 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect @@ -79,7 +79,7 @@ require ( github.com/pierrec/lz4/v4 v4.1.21 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect - github.com/prometheus/client_golang v1.19.1 // indirect + github.com/prometheus/client_golang v1.20.2 // indirect github.com/prometheus/client_model v0.6.1 // indirect github.com/prometheus/common v0.55.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect @@ -90,19 +90,19 @@ require ( github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/yashtewari/glob-intersection v0.2.0 // indirect github.com/zeebo/xxh3 v1.0.2 // indirect - go.opentelemetry.io/otel v1.28.0 // indirect - go.opentelemetry.io/otel/metric v1.28.0 // indirect - go.opentelemetry.io/otel/sdk v1.28.0 // indirect - go.opentelemetry.io/otel/trace v1.28.0 // indirect - golang.org/x/crypto v0.25.0 // indirect - golang.org/x/exp v0.0.0-20240707233637-46b078467d37 // indirect - golang.org/x/mod v0.19.0 // indirect - golang.org/x/net v0.27.0 // indirect - golang.org/x/sys v0.22.0 // indirect - golang.org/x/term v0.22.0 // indirect - golang.org/x/text v0.16.0 // indirect - golang.org/x/tools v0.23.0 // indirect - golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect + go.opentelemetry.io/otel v1.29.0 // indirect + go.opentelemetry.io/otel/metric v1.29.0 // indirect + go.opentelemetry.io/otel/sdk v1.29.0 // indirect + go.opentelemetry.io/otel/trace v1.29.0 // indirect + golang.org/x/crypto v0.26.0 // indirect + golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 // indirect + golang.org/x/mod v0.20.0 // indirect + golang.org/x/net v0.28.0 // indirect + golang.org/x/sys v0.24.0 // indirect + golang.org/x/term v0.23.0 // indirect + golang.org/x/text v0.17.0 // indirect + golang.org/x/tools v0.24.0 // indirect + golang.org/x/xerrors v0.0.0-20240716161551-93cc26a95ae9 // indirect google.golang.org/protobuf v1.34.2 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect sigs.k8s.io/yaml v1.4.0 // indirect diff --git a/go.sum b/go.sum index 842393b..cc6b2c4 100644 --- a/go.sum +++ b/go.sum @@ -7,16 +7,22 @@ github.com/99designs/keyring v1.2.2/go.mod h1:wes/FrByc8j7lFOAGLGSNEg8f/PaI3cgTB github.com/Azure/azure-sdk-for-go/sdk/azcore v0.19.0/go.mod h1:h6H6c8enJmmocHUbLiiGY6sx7f9i+X3m1CHdd5c6Rdw= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.12.0 h1:1nGuui+4POelzDwI7RG56yfQJHCnKvwfMoU7VsEp+Zg= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.12.0/go.mod h1:99EvauvlcJ1U06amZiksfYz/3aFGyIhWGHVyiZXtBAI= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0 h1:nyQWyZvwGTvunIMxi1Y9uXkcyr+I7TeNrr/foo4Kpk8= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.14.0/go.mod h1:l38EPgmsp71HHLq9j7De57JcKOWPyhrsW1Awm1JS6K0= github.com/Azure/azure-sdk-for-go/sdk/azidentity v0.11.0/go.mod h1:HcM1YX14R7CJcghJGOYCgdezslRSVzqwLf/q+4Y2r/0= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0 h1:tfLQ34V6F7tVSwoTf/4lH5sE0o6eCJuNDTmH09nDpbc= github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.7.0/go.mod h1:9kIvujWAA58nmPmWB1m23fyWic1kYZMxD9CxaWn4Qpg= github.com/Azure/azure-sdk-for-go/sdk/internal v0.7.0/go.mod h1:yqy467j36fJxcRV2TzfVZ1pCb5vxm4BtZPUdYWe/Xo8= github.com/Azure/azure-sdk-for-go/sdk/internal v1.9.1 h1:Xy/qV1DyOhhqsU/z0PyFMJfYCxnzna+vBEUtFW0ksQo= github.com/Azure/azure-sdk-for-go/sdk/internal v1.9.1/go.mod h1:oib6iWdC+sILvNUoJbbBn3xv7TXow7mEp/WRcsYvmow= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0 h1:ywEEhmNahHBihViHepv3xPBn1663uRv2t2q/ESv9seY= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.10.0/go.mod h1:iZDifYGJTIgIIkYRNWPENUnqx6bJ2xnSDFI2tjwZNuY= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.5.0 h1:AifHbc4mg0x9zW52WOpKbsHaDKuRhlI7TVl47thgQ70= github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/storage/armstorage v1.5.0/go.mod h1:T5RfihdXtBDxt1Ch2wobif3TvzTdumDy29kahv6AV9A= github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.3.2 h1:YUUxeiOWgdAQE3pXt2H7QXzZs0q8UBjgRbl56qo8GYM= github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.3.2/go.mod h1:dmXQgZuiSubAecswZE+Sm8jkvEa7kQgTPVRvwL/nd0E= +github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.4.0 h1:Be6KInmFEKV81c0pOAEbRYehLMwmmGI1exuFj248AMk= +github.com/Azure/azure-sdk-for-go/sdk/storage/azblob v1.4.0/go.mod h1:WCPBHsOXfBVnivScjs2ypRfimjEW0qPVLGgJkZlrIOA= github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 h1:XHOnouVk1mxXfQidrMEnLlPk9UMeRtyBTnEFtxkV0kU= github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2/go.mod h1:wP83P5OoQ5p6ip3ScPr0BAq0BvuPAvacpEuSzyouqAI= github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= @@ -49,6 +55,8 @@ github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.12 h1:yjwoSyDZF8Jth+mUk5lSPJ github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.12/go.mod h1:fuR57fAgMk7ot3WcNQfb6rSEn+SUffl7ri+aa8uKysI= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.7 h1:kNemAUX+bJFBSfPkGVZ8HFOKIadjLoI2Ua1ZKivhGSo= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.7/go.mod h1:71S2C1g/Zjn+ANmyoOqJ586OrPF9uC9iiHt9ZAT+MOw= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.15 h1:ijB7hr56MngOiELJe0C5aQRaBQ11LveNgWFyG02AUto= +github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.17.15/go.mod h1:0QEmQSSWMVfiAk93l1/ayR9DQ9+jwni7gHS2NARZXB0= github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.16 h1:TNyt/+X43KJ9IJJMjKfa3bNTiZbUP7DeCxfbTROESwY= github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.16/go.mod h1:2DwJF39FlNAUiX5pAc0UNeiz16lK2t7IaFcm0LFHEgc= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16 h1:jYfy8UPmd+6kJW5YhY0L1/KftReOGxI/4NtVSTh9O/I= @@ -120,6 +128,8 @@ github.com/foxcpp/go-mockdns v1.1.0 h1:jI0rD8M0wuYAxL7r/ynTrCQQq0BVqfB99Vgk7Dlme github.com/foxcpp/go-mockdns v1.1.0/go.mod h1:IhLeSFGed3mJIAXPH2aiRQB+kqz7oqu8ld2qVbOu7Wk= github.com/gabriel-vasile/mimetype v1.4.4 h1:QjV6pZ7/XZ7ryI2KuyeEDE8wnh7fHP9YnQy+R0LnH8I= github.com/gabriel-vasile/mimetype v1.4.4/go.mod h1:JwLei5XPtWdGiMFB5Pjle1oEeoSeEuJfJE+TtfvdB/s= +github.com/gabriel-vasile/mimetype v1.4.5 h1:J7wGKdGu33ocBOhGy0z653k/lFKLFDPJMG8Gql0kxn4= +github.com/gabriel-vasile/mimetype v1.4.5/go.mod h1:ibHel+/kbxn9x2407k1izTA1S81ku1z/DlgOW2QE0M4= github.com/go-ini/ini v1.67.0 h1:z6ZrTEZqSWOTyH2FlglNbNgARyHG8oLW9gMELqKr06A= github.com/go-ini/ini v1.67.0/go.mod h1:ByCAeIL28uOIIG0E3PJtZPDL8WnHpFKFOtgjp+3Ies8= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= @@ -210,6 +220,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_golang v1.19.1 h1:wZWJDwK+NameRJuPGDhlnFgx8e8HN3XHQeLaYJFJBOE= github.com/prometheus/client_golang v1.19.1/go.mod h1:mP78NwGzrVks5S2H6ab8+ZZGJLZUq1hoULYBAYBw1Ho= +github.com/prometheus/client_golang v1.20.2 h1:5ctymQzZlyOON1666svgwn3s6IKWgfbjsejTMiXIyjg= +github.com/prometheus/client_golang v1.20.2/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= github.com/prometheus/common v0.55.0 h1:KEi6DK7lXW/m7Ig5i47x0vRzuBsHuvJdi5ee6Y3G1dc= @@ -251,16 +263,24 @@ go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.53.0 h1:4K4tsIX go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.53.0/go.mod h1:jjdQuTGVsXV4vSs+CJ2qYDeDPf9yIJV23qlIzBm73Vg= go.opentelemetry.io/otel v1.28.0 h1:/SqNcYk+idO0CxKEUOtKQClMK/MimZihKYMruSMViUo= go.opentelemetry.io/otel v1.28.0/go.mod h1:q68ijF8Fc8CnMHKyzqL6akLO46ePnjkgfIMIjUIX9z4= +go.opentelemetry.io/otel v1.29.0 h1:PdomN/Al4q/lN6iBJEN3AwPvUiHPMlt93c8bqTG5Llw= +go.opentelemetry.io/otel v1.29.0/go.mod h1:N/WtXPs1CNCUEx+Agz5uouwCba+i+bJGFicT8SR4NP8= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.28.0 h1:3Q/xZUyC1BBkualc9ROb4G8qkH90LXEIICcs5zv1OYY= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.28.0/go.mod h1:s75jGIWA9OfCMzF0xr+ZgfrB5FEbbV7UuYo32ahUiFI= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.28.0 h1:R3X6ZXmNPRR8ul6i3WgFURCHzaXjHdm0karRG/+dj3s= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.28.0/go.mod h1:QWFXnDavXWwMx2EEcZsf3yxgEKAqsxQ+Syjp+seyInw= go.opentelemetry.io/otel/metric v1.28.0 h1:f0HGvSl1KRAU1DLgLGFjrwVyismPlnuU6JD6bOeuA5Q= go.opentelemetry.io/otel/metric v1.28.0/go.mod h1:Fb1eVBFZmLVTMb6PPohq3TO9IIhUisDsbJoL/+uQW4s= +go.opentelemetry.io/otel/metric v1.29.0 h1:vPf/HFWTNkPu1aYeIsc98l4ktOQaL6LeSoeV2g+8YLc= +go.opentelemetry.io/otel/metric v1.29.0/go.mod h1:auu/QWieFVWx+DmQOUMgj0F8LHWdgalxXqvp7BII/W8= go.opentelemetry.io/otel/sdk v1.28.0 h1:b9d7hIry8yZsgtbmM0DKyPWMMUMlK9NEKuIG4aBqWyE= go.opentelemetry.io/otel/sdk v1.28.0/go.mod h1:oYj7ClPUA7Iw3m+r7GeEjz0qckQRJK2B8zjcZEfu7Pg= +go.opentelemetry.io/otel/sdk v1.29.0 h1:vkqKjk7gwhS8VaWb0POZKmIEDimRCMsopNYnriHyryo= +go.opentelemetry.io/otel/sdk v1.29.0/go.mod h1:pM8Dx5WKnvxLCb+8lG1PRNIDxu9g9b9g59Qr7hfAAok= go.opentelemetry.io/otel/trace v1.28.0 h1:GhQ9cUuQGmNDd5BTCP2dAvv75RdMxEfTmYejp+lkx9g= go.opentelemetry.io/otel/trace v1.28.0/go.mod h1:jPyXzNPg6da9+38HEwElrQiHlVMTnVfM3/yv2OlIHaI= +go.opentelemetry.io/otel/trace v1.29.0 h1:J/8ZNK4XgR7a21DZUAsbF8pZ5Jcw1VhACmnYt39JTi4= +go.opentelemetry.io/otel/trace v1.29.0/go.mod h1:eHl3w0sp3paPkYstJOmAimxhiFXPg+MMTlEh3nsQgWQ= go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0= go.opentelemetry.io/proto/otlp v1.3.1/go.mod h1:0X1WI4de4ZsLrrJNLAQbFeLCm3T7yBkR0XqQ7niQU+8= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -272,8 +292,12 @@ golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDf golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.25.0 h1:ypSNr+bnYL2YhwoMt2zPxHFmbAN1KZs/njMG3hxUp30= golang.org/x/crypto v0.25.0/go.mod h1:T+wALwcMOSE0kXgUAnPAHqTLW+XHgcELELW8VaDgm/M= +golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw= +golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54= golang.org/x/exp v0.0.0-20240707233637-46b078467d37 h1:uLDX+AfeFCct3a2C7uIWBKMJIR3CJMhcgfrUAqjRK6w= golang.org/x/exp v0.0.0-20240707233637-46b078467d37/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= +golang.org/x/exp v0.0.0-20240823005443-9b4947da3948 h1:kx6Ds3MlpiUHKj7syVnbp57++8WpuKPcR5yjLBjvLEA= +golang.org/x/exp v0.0.0-20240823005443-9b4947da3948/go.mod h1:akd2r19cwCdwSwWeIdzYQGa/EZZyqcOdwWiwj5L5eKQ= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= @@ -281,6 +305,8 @@ golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.19.0 h1:fEdghXQSo20giMthA7cd28ZC+jts4amQ3YMXiP5oMQ8= golang.org/x/mod v0.19.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0= +golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= @@ -294,6 +320,8 @@ golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= +golang.org/x/net v0.28.0 h1:a9JDOJc5GMUJ0+UDqmLT86WiEy7iWyIhz8gz8E4e5hE= +golang.org/x/net v0.28.0/go.mod h1:yqtgsTWOOnlGLG9GFRrK3++bGOUEkNBoHZc8MEDWPNg= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -318,6 +346,8 @@ golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.24.0 h1:Twjiwq9dn6R1fQcyiK+wQyHWfaz/BJB+YIpzU/Cv3Xg= +golang.org/x/sys v0.24.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= @@ -328,6 +358,8 @@ golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.22.0 h1:BbsgPEJULsl2fV/AT3v15Mjva5yXKQDyKf+TbDz7QJk= golang.org/x/term v0.22.0/go.mod h1:F3qCibpT5AMpCRfhfT53vVJwhLtIVHhB9XDjfFvnMI4= +golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU= +golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -339,6 +371,8 @@ golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +golang.org/x/text v0.17.0 h1:XtiM5bkSOt+ewxlOE/aE/AKEHibwj/6gvWMl9Rsh0Qc= +golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= golang.org/x/time v0.6.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -349,9 +383,13 @@ golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58 golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/tools v0.23.0 h1:SGsXPZ+2l4JsgaCKkx+FQ9YZ5XEtA1GZYuoDjenLjvg= golang.org/x/tools v0.23.0/go.mod h1:pnu6ufv6vQkll6szChhK3C3L/ruaIv5eBeztNG8wtsI= +golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24= +golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 h1:+cNy6SZtPcJQH3LJVLOSmiC7MMxXNOb3PU/VUEz+EhU= golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= +golang.org/x/xerrors v0.0.0-20240716161551-93cc26a95ae9 h1:LLhsEBxRTBLuKlQxFBYUOU8xyFgXv6cOTp2HASDlsDk= +golang.org/x/xerrors v0.0.0-20240716161551-93cc26a95ae9/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= gonum.org/v1/gonum v0.12.0 h1:xKuo6hzt+gMav00meVPUlXwSdoEJP46BR+wdxQEFK2o= gonum.org/v1/gonum v0.12.0/go.mod h1:73TDxJfAAHeA8Mk9mf8NlIppyhQNo5GLTcYeqgo2lvY= google.golang.org/genproto/googleapis/api v0.0.0-20240701130421-f6361c86f094 h1:0+ozOGcrp+Y8Aq8TLNN2Aliibms5LEzsq99ZZmAGYm0= diff --git a/sql/config.go b/sql/config.go index 648fa4a..b4af1b2 100644 --- a/sql/config.go +++ b/sql/config.go @@ -2,6 +2,7 @@ package sql import ( "fmt" + "time" ) // RepoConfig is the necessary configuration to connect to a data sql. @@ -18,6 +19,13 @@ type RepoConfig struct { Database string // MaxOpenConns is the maximum number of open connections to the database. MaxOpenConns uint + // MaxParallelDbs is the maximum number of parallel databases scanned at + // once. + MaxParallelDbs uint + // MaxConcurrency is the maximum number of concurrent query goroutines. + MaxConcurrency uint + // QueryTimeout is the maximum time a query can run before being cancelled. + QueryTimeout time.Duration // Advanced is a map of advanced configuration options. Advanced map[string]any } diff --git a/sql/scanner.go b/sql/scanner.go index 4201a8a..e7eae2c 100644 --- a/sql/scanner.go +++ b/sql/scanner.go @@ -16,6 +16,18 @@ import ( "github.com/cyralinc/dmap/scan" ) +// Pair type intended to be passed to a channel (see sampleAllDbs). +type samplesAndErr struct { + samples []Sample + err error +} + +// Pair type intended to be passed to a channel (see sampleDb). +type sampleAndErr struct { + sample Sample + err error +} + // ScannerConfig is the configuration for the Scanner. type ScannerConfig struct { RepoType string @@ -149,49 +161,96 @@ func (s *Scanner) sampleDb(ctx context.Context, db string) ([]Sample, error) { } defer func() { _ = repo.Close() }() // Introspect the repository to get the metadata. + introspectCtx := ctx + if s.config.RepoConfig.QueryTimeout > 0 { + var cancel context.CancelFunc + introspectCtx, cancel = context.WithTimeout(ctx, s.config.RepoConfig.QueryTimeout) + defer cancel() + } introspectParams := IntrospectParameters{ IncludePaths: s.config.IncludePaths, ExcludePaths: s.config.ExcludePaths, } - meta, err := repo.Introspect(ctx, introspectParams) + meta, err := repo.Introspect(introspectCtx, introspectParams) if err != nil { return nil, fmt.Errorf("error introspecting repository: %w", err) } - // This is a "pair" type intended to be passed to the channel below. - type sampleAndErr struct { - sample Sample - err error - } - // Fan out sample executions. + // This goroutine launches additional goroutines, one for each table, which + // sample the respective tables and send the results to the out channel. A + // semaphore is optionally used to limit the number of tables that are + // sampled concurrently. We do this on a dedicated goroutine so we can + // immediately read from the out channel on this goroutine, and avoid + // possible deadlocks due to the semaphore. out := make(chan sampleAndErr) - numTables := 0 - for _, schemaMeta := range meta.Schemas { - for _, tableMeta := range schemaMeta.Tables { - numTables++ - go func(meta *TableMetadata) { - params := SampleParameters{ - Metadata: meta, - SampleSize: s.config.SampleSize, - Offset: s.config.Offset, - } - sample, err := repo.SampleTable(ctx, params) - select { - case <-ctx.Done(): - return - case out <- sampleAndErr{sample: sample, err: err}: + go func() { + // Before we return, wait for all the goroutines we launch below to + // complete, and then close the out channel once they're all done so the + // main goroutine can aggregate the results and return them. + var wg sync.WaitGroup + defer func() { wg.Wait(); close(out) }() + // Optionally use a semaphore to limit the number of tables sampled + // concurrently. + var sema *semaphore.Weighted + if s.config.RepoConfig.MaxConcurrency > 0 { + sema = semaphore.NewWeighted(int64(s.config.RepoConfig.MaxConcurrency)) + } + for _, schemaMeta := range meta.Schemas { + for _, tableMeta := range schemaMeta.Tables { + if sema != nil { + // Acquire a semaphore slot before launching a goroutine to + // sample the table. This will block if the semaphore is + // full, and will unblock once a slot is available. An error + // means the context was cancelled. + if err := sema.Acquire(ctx, 1); err != nil { + log.WithError(err).Error("error acquiring semaphore") + return + } } - }(tableMeta) + wg.Add(1) + // Launch a goroutine to sample the table. + go func(ctx context.Context, meta *TableMetadata) { + sampleCtx := ctx + if s.config.RepoConfig.QueryTimeout > 0 { + var cancel context.CancelFunc + sampleCtx, cancel = context.WithTimeout(ctx, s.config.RepoConfig.QueryTimeout) + defer cancel() + } + params := SampleParameters{ + Metadata: meta, + SampleSize: s.config.SampleSize, + Offset: s.config.Offset, + } + sample, err := repo.SampleTable(sampleCtx, params) + select { + case <-ctx.Done(): + case out <- sampleAndErr{sample: sample, err: err}: + } + if sema != nil { + // Release the slot once the goroutine is done. + sema.Release(1) + } + wg.Done() + }(ctx, tableMeta) + } } - } + }() // Aggregate and return the results. var samples []Sample var errs error - for i := 0; i < numTables; i++ { + for { select { case <-ctx.Done(): - return samples, ctx.Err() - case res := <-out: + errs = errors.Join(errs, ctx.Err()) + return samples, fmt.Errorf("error(s) sampling repository: %w", errs) + case res, ok := <-out: + if !ok { + // The out channel has been closed, so we're done. + if errs != nil { + return samples, fmt.Errorf("error(s) sampling repository: %w", errs) + } + return samples, nil + } if res.err != nil { errs = errors.Join(errs, res.err) } else { @@ -199,11 +258,6 @@ func (s *Scanner) sampleDb(ctx context.Context, db string) ([]Sample, error) { } } } - close(out) - if errs != nil { - return samples, fmt.Errorf("error(s) while sampling repository: %w", errs) - } - return samples, nil } // sampleAllDbs samples all the databases on the server. It samples each @@ -230,58 +284,76 @@ func (s *Scanner) sampleAllDbs(ctx context.Context) ([]Sample, error) { // We assume that this repository will be connected to the default database // (or at least some database that can discover all the other databases). // Use it to discover all the other databases on the server. - dbs, err := repo.ListDatabases(ctx) + listDbCtx := ctx + if s.config.RepoConfig.QueryTimeout > 0 { + var cancel context.CancelFunc + listDbCtx, cancel = context.WithTimeout(ctx, s.config.RepoConfig.QueryTimeout) + defer cancel() + } + dbs, err := repo.ListDatabases(listDbCtx) if err != nil { return nil, fmt.Errorf("error listing databases: %w", err) } - // Sample each database on a separate goroutine, and send the samples to - // the 'out' channel. Each slice of samples will be aggregated below on this - // goroutine and returned. - var wg sync.WaitGroup - // This is a "pair" type intended to be passed to the channel below. - type samplesAndErr struct { - samples []Sample - err error - } + // This goroutine launches additional goroutines, one for each database, + // which sample the respective databases and send the results to the out + // channel. A semaphore is optionally used to limit the number of databases + // sampled concurrently. We do this on a dedicated goroutine so we can + // immediately read from the out channel on this goroutine, and avoid + // possible deadlocks due to the semaphore. + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) + defer cancel() out := make(chan samplesAndErr) - wg.Add(len(dbs)) - // Using a semaphore here ensures that we avoid opening more than the - // specified total number of connections, since we end up creating multiple - // database handles (one per database). - var sema *semaphore.Weighted - if s.config.RepoConfig.MaxOpenConns > 0 { - sema = semaphore.NewWeighted(int64(s.config.RepoConfig.MaxOpenConns)) - } - for _, db := range dbs { - go func(db string, cfg RepoConfig) { - defer wg.Done() + go func() { + // Before we return, wait for all the goroutines we launch below to + // complete, and then close the out channel once they're all done so the + // main goroutine can aggregate the results and return them. + var wg sync.WaitGroup + defer func() { wg.Wait(); close(out) }() + // Optionally use a semaphore to limit the number of databases sampled + // concurrently. + var sema *semaphore.Weighted + if s.config.RepoConfig.MaxParallelDbs > 0 { + sema = semaphore.NewWeighted(int64(s.config.RepoConfig.MaxParallelDbs)) + } + for _, db := range dbs { if sema != nil { - _ = sema.Acquire(ctx, 1) - defer sema.Release(1) - } - // Sample this specific database. - samples, err := s.sampleDb(ctx, db) - if err != nil && len(samples) == 0 { - log.WithError(err).Errorf("error gathering repository data samples for database %s", db) - return - } - // Send the samples for this database to the 'out' channel. The - // samples for each database will be aggregated into a single slice - // on the main goroutine and returned. - select { - case <-ctx.Done(): - return - case out <- samplesAndErr{samples: samples, err: err}: + // Acquire a semaphore slot before launching a goroutine to + // sample the database. This will block if the semaphore is + // full, and will unblock once a slot is available. An error + // means the context was cancelled. + if err := sema.Acquire(ctx, 1); err != nil { + log.WithError(err).Error("error acquiring semaphore") + return + } } - }(db, s.config.RepoConfig) - } - - // Start a goroutine to close the 'out' channel once all the goroutines we - // launched above are done. This will allow the aggregation range loop below - // to terminate properly. Note that this must start after the wg.Add call. - // See https://go.dev/blog/pipelines ("Fan-out, fan-in" section). - go func() { wg.Wait(); close(out) }() + // Launch a goroutine to sample the database. + wg.Add(1) + go func(db string, cfg RepoConfig) { + defer func() { + if sema != nil { + // Release the slot once the goroutine is done. + sema.Release(1) + } + wg.Done() + }() + // Sample this specific database. + samples, err := s.sampleDb(ctx, db) + if err != nil && len(samples) == 0 { + log.WithError(err).Errorf("error gathering repository data samples for database %s", db) + return + } + // Send the samples for this database to the 'out' channel. The + // samples for each database will be aggregated into a single + // slice on the main goroutine and returned. + select { + case <-ctx.Done(): + case out <- samplesAndErr{samples: samples, err: err}: + } + }(db, s.config.RepoConfig) + } + }() // Aggregate and return the results. var ret []Sample diff --git a/sql/scanner_test.go b/sql/scanner_test.go index 3962396..5d24cf4 100644 --- a/sql/scanner_test.go +++ b/sql/scanner_test.go @@ -324,9 +324,9 @@ func TestScanner_sampleAllDbs_Successful_TwoDatabases(t *testing.T) { }, } repo := NewMockRepository(t) - repo.EXPECT().ListDatabases(ctx).Return(dbs, nil) - repo.EXPECT().Introspect(ctx, mock.Anything).Return(&meta, nil) - repo.EXPECT().SampleTable(ctx, mock.Anything).Return(sample, nil) + repo.EXPECT().ListDatabases(mock.Anything).Return(dbs, nil) + repo.EXPECT().Introspect(mock.Anything, mock.Anything).Return(&meta, nil) + repo.EXPECT().SampleTable(mock.Anything, mock.Anything).Return(sample, nil) repo.EXPECT().Close().Return(nil) repoType := "mock" reg := NewRegistry() @@ -356,8 +356,8 @@ func TestScanner_sampleAllDbs_IntrospectError(t *testing.T) { dbs := []string{"db1", "db2"} introspectErr := errors.New("introspect error") repo := NewMockRepository(t) - repo.EXPECT().ListDatabases(ctx).Return(dbs, nil) - repo.EXPECT().Introspect(ctx, mock.Anything).Return(nil, introspectErr) + repo.EXPECT().ListDatabases(mock.Anything).Return(dbs, nil) + repo.EXPECT().Introspect(mock.Anything, mock.Anything).Return(nil, introspectErr) repo.EXPECT().Close().Return(nil) repoType := "mock" reg := NewRegistry() @@ -407,9 +407,9 @@ func TestScanner_sampleAllDbs_SampleError(t *testing.T) { } sampleErr := errors.New("sample error") repo := NewMockRepository(t) - repo.EXPECT().ListDatabases(ctx).Return(dbs, nil) - repo.EXPECT().Introspect(ctx, mock.Anything).Return(&meta, nil) - repo.EXPECT().SampleTable(ctx, mock.Anything).Return(Sample{}, sampleErr) + repo.EXPECT().ListDatabases(mock.Anything).Return(dbs, nil) + repo.EXPECT().Introspect(mock.Anything, mock.Anything).Return(&meta, nil) + repo.EXPECT().SampleTable(mock.Anything, mock.Anything).Return(Sample{}, sampleErr) repo.EXPECT().Close().Return(nil) repoType := "mock" reg := NewRegistry() @@ -467,10 +467,10 @@ func TestScanner_sampleAllDbs_TwoDatabases_OneSampleError(t *testing.T) { } sampleErr := errors.New("sample error") repo := NewMockRepository(t) - repo.EXPECT().ListDatabases(ctx).Return(dbs, nil) - repo.EXPECT().Introspect(ctx, mock.Anything).Return(&meta, nil) - repo.EXPECT().SampleTable(ctx, mock.Anything).Return(sample, nil).Once() - repo.EXPECT().SampleTable(ctx, mock.Anything).Return(Sample{}, sampleErr).Once() + repo.EXPECT().ListDatabases(mock.Anything).Return(dbs, nil) + repo.EXPECT().Introspect(mock.Anything, mock.Anything).Return(&meta, nil) + repo.EXPECT().SampleTable(mock.Anything, mock.Anything).Return(sample, nil).Once() + repo.EXPECT().SampleTable(mock.Anything, mock.Anything).Return(Sample{}, sampleErr).Once() repo.EXPECT().Close().Return(nil) repoType := "mock" reg := NewRegistry() From bb28f860874e28c1f91a5b5b79c37d01bed867ab Mon Sep 17 00:00:00 2001 From: Chris Campo Date: Tue, 27 Aug 2024 16:51:22 -0400 Subject: [PATCH 2/2] Address PR comments --- sql/scanner.go | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/scanner.go b/sql/scanner.go index e7eae2c..3d757fe 100644 --- a/sql/scanner.go +++ b/sql/scanner.go @@ -209,6 +209,13 @@ func (s *Scanner) sampleDb(ctx context.Context, db string) ([]Sample, error) { wg.Add(1) // Launch a goroutine to sample the table. go func(ctx context.Context, meta *TableMetadata) { + defer func() { + if sema != nil { + // Release the slot once the goroutine is done. + sema.Release(1) + } + wg.Done() + }() sampleCtx := ctx if s.config.RepoConfig.QueryTimeout > 0 { var cancel context.CancelFunc @@ -225,11 +232,6 @@ func (s *Scanner) sampleDb(ctx context.Context, db string) ([]Sample, error) { case <-ctx.Done(): case out <- sampleAndErr{sample: sample, err: err}: } - if sema != nil { - // Release the slot once the goroutine is done. - sema.Release(1) - } - wg.Done() }(ctx, tableMeta) } }