diff --git a/aws/scanner.go b/aws/scanner.go index 2e9eccf..15c1343 100644 --- a/aws/scanner.go +++ b/aws/scanner.go @@ -14,7 +14,7 @@ import ( "github.com/cyralinc/dmap/scan" ) -// AWSScanner is an implementation of the EnvironmentScanner interface for the AWS cloud +// AWSScanner is an implementation of the Scanner interface for the AWS cloud // provider. It supports scanning data repositories from multiple AWS regions, // including RDS clusters and instances, Redshift clusters and DynamoDB tables. type AWSScanner struct { @@ -23,8 +23,8 @@ type AWSScanner struct { awsClientConstructor awsClientConstructor } -// AWSScanner implements scan.EnvironmentScanner -var _ scan.EnvironmentScanner = (*AWSScanner)(nil) +// AWSScanner implements scan.Scanner +var _ scan.Scanner = (*AWSScanner)(nil) // NewAWSScanner creates a new instance of AWSScanner based on the ScannerConfig. // If AssumeRoleConfig is specified, the AWSScanner will assume this IAM Role @@ -57,7 +57,7 @@ func NewAWSScanner( // Scan performs a scan across all the AWS regions configured and return a scan // results, containing a list of data repositories that includes: RDS clusters // and instances, Redshift clusters and DynamoDB tables. -func (s *AWSScanner) Scan(ctx context.Context) (*scan.EnvironmentScanResults, error) { +func (s *AWSScanner) Scan(ctx context.Context) (*scan.ScanResults, error) { responseChan := make(chan scanResponse) var wg sync.WaitGroup wg.Add(len(s.scannerConfig.Regions)) @@ -100,18 +100,18 @@ func (s *AWSScanner) Scan(ctx context.Context) (*scan.EnvironmentScanResults, er select { case <-ctx.Done(): scanErrors = append(scanErrors, ctx.Err()) - return &scan.EnvironmentScanResults{ + return &scan.ScanResults{ Repositories: repositories, - }, &scan.EnvironmentScanError{Errs: scanErrors} + }, &scan.ScanError{Errs: scanErrors} case response, ok := <-responseChan: if !ok { // Channel closed, all scans finished. 
var scanErr error if len(scanErrors) > 0 { - scanErr = &scan.EnvironmentScanError{Errs: scanErrors} + scanErr = &scan.ScanError{Errs: scanErrors} } - return &scan.EnvironmentScanResults{ + return &scan.ScanResults{ Repositories: repositories, }, scanErr diff --git a/aws/scanner_test.go b/aws/scanner_test.go index 07d4aa3..37c70c8 100644 --- a/aws/scanner_test.go +++ b/aws/scanner_test.go @@ -181,7 +181,7 @@ func (s *AWSScannerTestSuite) TestScan() { ctx := context.Background() results, err := awsScanner.Scan(ctx) - expectedResults := &scan.EnvironmentScanResults{ + expectedResults := &scan.ScanResults{ Repositories: map[string]scan.Repository{ *s.dummyRDSClusters[0].DBClusterArn: { Id: *s.dummyRDSClusters[0].DBClusterArn, @@ -359,7 +359,7 @@ func (s *AWSScannerTestSuite) TestScan_WithErrors() { ctx := context.Background() results, err := awsScanner.Scan(ctx) - expectedResults := &scan.EnvironmentScanResults{ + expectedResults := &scan.ScanResults{ Repositories: nil, } diff --git a/classification/classification.go b/classification/classification.go index 53d8fe8..a93bb7c 100644 --- a/classification/classification.go +++ b/classification/classification.go @@ -8,7 +8,9 @@ package classification import ( "context" - "maps" + "encoding/json" + + "golang.org/x/exp/maps" ) // Classifier is an interface that represents a data classifier. A classifier @@ -22,35 +24,14 @@ type Classifier interface { Classify(ctx context.Context, input map[string]any) (Result, error) } -// ClassifiedTable represents a database table that has been classified. The -// classifications are stored in the Classifications field, which is a map of -// attribute names (i.e. columns) to the set of labels that attributes were -// classified as. 
-type ClassifiedTable struct { - Repo string `json:"repo"` - Database string `json:"database"` - Schema string `json:"schema"` - Table string `json:"table"` - Classifications Result `json:"classifications"` -} - // Result represents the classifications for a set of data attributes. The key // is the attribute (i.e. column) name and the value is the set of labels // that attribute was classified as. type Result map[string]LabelSet -// Merge combines the given other Result into this Result (the receiver). If -// an attribute from other is already present in this Result, the existing -// labels for that attribute are merged with the labels from other, otherwise -// labels from other for the attribute are simply added to this Result. -func (c Result) Merge(other Result) { - if c == nil { - return - } - for attr, labelSet := range other { - if _, ok := c[attr]; !ok { - c[attr] = make(LabelSet) - } - maps.Copy(c[attr], labelSet) - } +// LabelSet is a set of unique labels. +type LabelSet map[string]struct{} + +func (l LabelSet) MarshalJSON() ([]byte, error) { + return json.Marshal(maps.Keys(l)) } diff --git a/classification/classification_test.go b/classification/classification_test.go deleted file mode 100644 index d6dde27..0000000 --- a/classification/classification_test.go +++ /dev/null @@ -1,59 +0,0 @@ -package classification - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestMerge_WhenCalledOnNilReceiver_ShouldNotPanic(t *testing.T) { - var result Result - require.NotPanics( - t, func() { - result.Merge(Result{"age": {"AGE": {Name: "AGE"}}}) - }, - ) -} - -func TestMerge_WhenCalledWithNonExistingAttributes_ShouldAddThem(t *testing.T) { - result := Result{ - "age": {"AGE": {Name: "AGE"}}, - } - other := Result{ - "social_sec_num": {"SSN": {Name: "SSN"}}, - } - expected := Result{ - "age": {"AGE": {Name: "AGE"}}, - "social_sec_num": {"SSN": {Name: "SSN"}}, - } - result.Merge(other) - require.Equal(t, expected, result) -} - -func 
TestMerge_WhenCalledWithExistingAttributes_ShouldMergeLabelSets(t *testing.T) { - result := Result{ - "age": {"AGE": {Name: "AGE"}}, - } - other := Result{ - "age": {"CVV": {Name: "CVV"}}, - } - expected := Result{ - "age": {"AGE": {Name: "AGE"}, "CVV": {Name: "CVV"}}, - } - result.Merge(other) - require.Equal(t, expected, result) -} - -func TestMerge_WhenCalledWithExistingAttributes_ShouldOverwrite(t *testing.T) { - result := Result{ - "age": {"AGE": {Name: "AGE", Description: "Foo"}}, - } - other := Result{ - "age": {"AGE": {Name: "AGE", Description: "Bar"}}, - } - expected := Result{ - "age": {"AGE": {Name: "AGE", Description: "Bar"}}, - } - result.Merge(other) - require.Equal(t, expected, result) -} diff --git a/classification/label.go b/classification/label.go index 88702e6..7aa38b5 100644 --- a/classification/label.go +++ b/classification/label.go @@ -6,8 +6,8 @@ import ( "strings" "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/rego" log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" "gopkg.in/yaml.v3" ) @@ -20,18 +20,23 @@ var ( // Label represents a data classification label. type Label struct { - Name string `yaml:"name" json:"name"` - Description string `yaml:"description" json:"description"` - Tags []string `yaml:"tags" json:"tags"` - ClassificationRule *rego.Rego `yaml:"-" json:"-"` + // Name is the name of the label. + Name string `yaml:"name" json:"name"` + // Description is a brief description of the label. + Description string `yaml:"description" json:"description"` + // Tags are a list of arbitrary tags associated with the label. + Tags []string `yaml:"tags" json:"tags"` + // ClassificationRule is the compiled Rego classification rule used to + // classify data. + ClassificationRule *ast.Module `yaml:"-" json:"-"` } // NewLabel creates a new Label with the given name, description, classification // rule, and tags. The classification rule is expected to be the raw Rego code // that will be used to classify data. 
If the classification rule is invalid, an // error is returned. -func NewLabel(name, description, classificationRule string, tags []string) (Label, error) { - rule, err := ruleRego(classificationRule) +func NewLabel(name, description, classificationRule string, tags ...string) (Label, error) { + rule, err := parseRego(classificationRule) if err != nil { return Label{}, fmt.Errorf("error preparing classification rule for label %s: %w", name, err) } @@ -43,25 +48,12 @@ func NewLabel(name, description, classificationRule string, tags []string) (Labe }, nil } -// LabelSet is a set of unique labels. The key is the label name and the value -// is the label itself. -type LabelSet map[string]Label - -// ToSlice returns the labels in the set as a slice. -func (s LabelSet) ToSlice() []Label { - var labels []Label - for _, label := range s { - labels = append(labels, label) - } - return labels -} - // GetEmbeddedLabels returns the predefined embedded labels and their // classification rules. The labels are read from the embedded labels.yaml file // and the classification rules are read from the embedded Rego files. 
-func GetEmbeddedLabels() (LabelSet, error) { +func GetEmbeddedLabels() ([]Label, error) { labels := struct { - Labels LabelSet `yaml:"labels"` + Labels map[string]Label `yaml:"labels"` }{} if err := yaml.Unmarshal([]byte(labelsYaml), &labels); err != nil { return nil, fmt.Errorf("error unmarshalling labels.yaml: %w", err) @@ -72,7 +64,7 @@ func GetEmbeddedLabels() (LabelSet, error) { if err != nil { return nil, fmt.Errorf("error reading rego file %s: %w", fname, err) } - rule, err := ruleRego(string(b)) + rule, err := parseRego(string(b)) if err != nil { return nil, fmt.Errorf("error preparing classification rule for label %s: %w", lbl.Name, err) } @@ -80,16 +72,14 @@ func GetEmbeddedLabels() (LabelSet, error) { lbl.ClassificationRule = rule labels.Labels[name] = lbl } - return labels.Labels, nil + return maps.Values(labels.Labels), nil } -func ruleRego(code string) (*rego.Rego, error) { +func parseRego(code string) (*ast.Module, error) { log.Tracef("classifier module code: '%s'", code) - moduleName := "classifier" - compiledRego, err := ast.CompileModules(map[string]string{moduleName: code}) + module, err := ast.ParseModule("classifier", code) if err != nil { - return nil, fmt.Errorf("error compiling rego code: %w", err) + return nil, fmt.Errorf("error parsing rego code: %w", err) } - regoQuery := compiledRego.Modules[moduleName].Package.Path.String() + ".output" - return rego.New(rego.Query(regoQuery), rego.Compiler(compiledRego)), nil + return module, nil } diff --git a/classification/label_classifier.go b/classification/label_classifier.go index 37fcc29..c7a23b9 100644 --- a/classification/label_classifier.go +++ b/classification/label_classifier.go @@ -11,37 +11,41 @@ import ( // LabelClassifier is a Classifier implementation that uses a set of labels and // their classification rules to classify data. 
type LabelClassifier struct { - labels LabelSet + queries map[string]*rego.Rego } // LabelClassifier implements Classifier var _ Classifier = (*LabelClassifier)(nil) -// NewLabelClassifier creates a new LabelClassifier with the provided labels and -// classification rules. +// NewLabelClassifier creates a new LabelClassifier with the provided labels. +// func NewLabelClassifier(labels ...Label) (*LabelClassifier, error) { if len(labels) == 0 { return nil, fmt.Errorf("labels cannot be empty") } - l := make(LabelSet, len(labels)) + queries := make(map[string]*rego.Rego, len(labels)) for _, lbl := range labels { - l[lbl.Name] = lbl + queries[lbl.Name] = rego.New( + // We only care about the 'output' variable. + rego.Query(lbl.ClassificationRule.Package.Path.String() + ".output"), + rego.ParsedModule(lbl.ClassificationRule), + ) } - return &LabelClassifier{labels: l}, nil + return &LabelClassifier{queries: queries}, nil } // Classify performs the classification of the provided attributes using the -// classifier's labels and classification rules. It returns a Result, which is -// a map of attribute names to the set of labels that the attribute was -// classified as. +// classifier's labels and their corresponding classification rules. It returns +// a Result, which is a map of attribute names to the set of labels that the +// attribute was classified as. 
func (c *LabelClassifier) Classify(ctx context.Context, input map[string]any) (Result, error) { - result := make(Result, len(c.labels)) - for _, lbl := range c.labels { - output, err := evalQuery(ctx, lbl.ClassificationRule, input) + result := make(Result, len(c.queries)) + for lbl, query := range c.queries { + output, err := evalQuery(ctx, query, input) if err != nil { - return nil, fmt.Errorf("error evaluating query for label %s: %w", lbl.Name, err) + return nil, fmt.Errorf("error evaluating query for label %s: %w", lbl, err) } - log.Debugf("classification results for label %s: %v", lbl.Name, output) + log.Debugf("classification results for label %s: %v", lbl, output) for attrName, classified := range output { if classified { attrLabels, ok := result[attrName] @@ -49,21 +53,21 @@ func (c *LabelClassifier) Classify(ctx context.Context, input map[string]any) (R attrLabels = make(LabelSet) result[attrName] = attrLabels } - attrLabels[lbl.Name] = lbl + // Add the label to the set of labels for the attribute. + attrLabels[lbl] = struct{}{} } } } return result, nil } -// evalQuery evaluates the provided prepared Rego query with the given -// attributes as input, and returns the classification results. The output is a +// evalQuery evaluates the provided Rego query with the given attributes as input, and returns the classification results. The output is a // map of attribute names to boolean values, where the boolean indicates whether // the attribute is classified as belonging to the label. -func evalQuery(ctx context.Context, rule *rego.Rego, input map[string]any) (map[string]bool, error) { - q, err := rule.PrepareForEval(ctx) +func evalQuery(ctx context.Context, query *rego.Rego, input map[string]any) (map[string]bool, error) { + q, err := query.PrepareForEval(ctx) if err != nil { - return nil, fmt.Errorf("error preparing rule for evaluation: %w", err) + return nil, fmt.Errorf("error preparing query for evaluation: %w", err) } // Evaluate the prepared Rego query. 
This performs the actual classification // logic. diff --git a/classification/label_classifier_test.go b/classification/label_classifier_test.go index abea713..50fd1e0 100644 --- a/classification/label_classifier_test.go +++ b/classification/label_classifier_test.go @@ -9,7 +9,9 @@ import ( ) func TestNewLabelClassifier_Success(t *testing.T) { - classifier, err := NewLabelClassifier(Label{Name: "foo"}) + lbl, err := NewLabel("foo", "test label", "package foo\noutput = true") + require.NoError(t, err) + classifier, err := NewLabelClassifier(lbl) require.NoError(t, err) require.NotNil(t, classifier) } @@ -40,7 +42,7 @@ func TestLabelClassifier_Classify(t *testing.T) { input: map[string]any{"age": "42"}, want: Result{ "age": { - "AGE": Label{Name: "AGE"}, + "AGE": {}, }, }, }, @@ -53,7 +55,7 @@ func TestLabelClassifier_Classify(t *testing.T) { }, want: Result{ "age": { - "AGE": Label{Name: "AGE"}, + "AGE": {}, }, }, }, @@ -63,7 +65,7 @@ func TestLabelClassifier_Classify(t *testing.T) { input: map[string]any{"age": "42"}, want: Result{ "age": { - "AGE": Label{Name: "AGE"}, + "AGE": {}, }, }, }, @@ -76,10 +78,10 @@ func TestLabelClassifier_Classify(t *testing.T) { }, want: Result{ "age": { - "AGE": Label{Name: "AGE"}, + "AGE": {}, }, "ccn": { - "CCN": Label{Name: "CCN"}, + "CCN": {}, }, }, }, @@ -92,11 +94,11 @@ func TestLabelClassifier_Classify(t *testing.T) { }, want: Result{ "age": { - "AGE": Label{Name: "AGE"}, - "CVV": Label{Name: "CVV"}, + "AGE": {}, + "CVV": {}, }, "cvv": { - "CVV": Label{Name: "CVV"}, + "CVV": {}, }, }, }, @@ -127,10 +129,9 @@ func requireResultEqual(t *testing.T, want, got Result) { func requireLabelSetEqual(t *testing.T, want, got LabelSet) { require.Len(t, got, len(want)) - for k, v := range want { - gotLbl, ok := got[k] + for k := range want { + _, ok := got[k] require.Truef(t, ok, "missing label %s", k) - require.Equal(t, v.Name, gotLbl.Name) } } @@ -149,7 +150,7 @@ func newTestLabel(t *testing.T, lblName string) Label { fin, err := 
regoFs.ReadFile(fname) require.NoError(t, err) classifierCode := string(fin) - lbl, err := NewLabel(lblName, "test label", classifierCode, nil) + lbl, err := NewLabel(lblName, "test label", classifierCode) require.NoError(t, err) return lbl } diff --git a/classification/label_test.go b/classification/label_test.go index 8b718f2..35a0b69 100644 --- a/classification/label_test.go +++ b/classification/label_test.go @@ -1,8 +1,11 @@ package classification import ( + "context" + "fmt" "testing" + "github.com/open-policy-agent/opa/rego" "github.com/stretchr/testify/require" ) @@ -11,3 +14,52 @@ func TestGetEmbeddedLabels(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, got) } + +func TestRego(t *testing.T) { + + module := ` +package example.authz + +import rego.v1 + +default allow := false + +allow if { + input.method == "GET" + input.path == ["salary", input.subject.user] +} + +allow if is_admin + +is_admin if "admin" in input.subject.groups +` + + mod, err := parseRego(module) + require.NoError(t, err) + require.NotNil(t, mod) + path := mod.Package.Path.String() + fmt.Println(path) + + ctx := context.TODO() + + query, err := rego.New( + rego.Query("data.example.authz.allow"), + rego.Module("example.rego", module), + ).PrepareForEval(ctx) + require.NoError(t, err) + require.NotNil(t, query) + + input := map[string]interface{}{ + "method": "GET", + "path": []interface{}{"salary", "bob"}, + "subject": map[string]interface{}{ + "user": "bob", + "groups": []interface{}{"sales", "marketing"}, + }, + } + + results, err := query.Eval(ctx, rego.EvalInput(input)) + require.NoError(t, err) + require.NotEmpty(t, results) + require.True(t, results.Allowed()) +} diff --git a/classification/publisher.go b/classification/publisher.go deleted file mode 100644 index aed92f4..0000000 --- a/classification/publisher.go +++ /dev/null @@ -1,15 +0,0 @@ -package classification - -import ( - "context" -) - -// Publisher publishes classification and discovery results to some 
destination, -// which is left up to the implementer. -// TODO: add doc about labels -ccampo 2024-04-02 -type Publisher interface { - // PublishClassifications publishes a slice of ClassifiedTable to some - // destination. Any error(s) during publication should be returned. - // TODO: add labels -ccampo 2024-04-02 - PublishClassifications(ctx context.Context, repoId string, results []ClassifiedTable) error -} diff --git a/classification/stdout.go b/classification/stdout.go deleted file mode 100644 index 883dc7e..0000000 --- a/classification/stdout.go +++ /dev/null @@ -1,33 +0,0 @@ -package classification - -import ( - "context" - "encoding/json" - "fmt" -) - -// StdOutPublisher "publishes" classification results to stdout in JSON format. -type StdOutPublisher struct{} - -// StdOutPublisher implements Publisher -var _ Publisher = (*StdOutPublisher)(nil) - -// TODO: godoc -ccampo 2024-04-02 -func NewStdOutPublisher() *StdOutPublisher { - return &StdOutPublisher{} -} - -// TODO: godoc -ccampo 2024-04-02 -func (c *StdOutPublisher) PublishClassifications(_ context.Context, _ string, tables []ClassifiedTable) error { - results := struct { - Results []ClassifiedTable `json:"results"` - }{ - Results: tables, - } - b, err := json.MarshalIndent(results, "", " ") - if err != nil { - return fmt.Errorf("failed to marshal results: %w", err) - } - fmt.Println(string(b)) - return nil -} diff --git a/cmd/main.go b/cmd/main.go index 632b9c3..fee5fd2 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -8,9 +8,12 @@ import ( ) type Globals struct { - LogLevel logLevelFlag `help:"Set the logging level (trace|debug|info|warn|error|fatal)" enum:"trace,debug,info,warn,error,fatal" default:"info"` - LogFormat logFormatFlag `help:"Set the logging format (text|json)" enum:"text,json" default:"text"` - Version kong.VersionFlag `name:"version" help:"Print version information and quit"` + LogLevel logLevelFlag `help:"Set the logging level (trace|debug|info|warn|error|fatal)" 
enum:"trace,debug,info,warn,error,fatal" default:"info"` + LogFormat logFormatFlag `help:"Set the logging format (text|json)" enum:"text,json" default:"text"` + Version kong.VersionFlag `name:"version" help:"Print version information and quit"` + ApiBaseUrl string `help:"Base URL of the Dmap API." default:"https://api.dmap.cyral.io"` + ClientID string `help:"API client ID to access the Dmap API."` + ClientSecret string `help:"API client secret to access the Dmap API."` //#nosec G101 -- false positive } type logLevelFlag string diff --git a/cmd/repo_scan.go b/cmd/repo_scan.go index b5830d9..ce4f327 100644 --- a/cmd/repo_scan.go +++ b/cmd/repo_scan.go @@ -2,21 +2,86 @@ package main import ( "context" + "encoding/json" "fmt" + "reflect" + "strings" - "github.com/cyralinc/dmap/scan" + "github.com/alecthomas/kong" + "github.com/gobwas/glob" + + "github.com/cyralinc/dmap/sql" ) type RepoScanCmd struct { - scan.RepoScannerConfig + Type string `help:"Type of repository to connect to (postgres|mysql|oracle|sqlserver|snowflake|redshift|denodo)." enum:"postgres,mysql,oracle,sqlserver,snowflake,redshift,denodo" required:""` + ExternalID string `help:"External ID of the repository." required:""` + Host string `help:"Hostname of the repository." required:""` + Port uint16 `help:"Port of the repository." required:""` + User string `help:"Username to connect to the sql." required:""` + Password string `help:"Password to connect to the sql." required:""` + Database string `help:"Name of the database to connect to. If not specified, the default database is used (if possible)."` + Advanced map[string]any `help:"Advanced configuration for the sql."` + IncludePaths GlobFlag `help:"List of glob patterns to include when introspecting the database(s)." default:"*"` + ExcludePaths GlobFlag `help:"List of glob patterns to exclude when introspecting the database(s)." default:""` + MaxOpenConns uint `help:"Maximum number of open connections to the database." 
default:"10"` + SampleSize uint `help:"Number of rows to sample from the repository (per table)." default:"5"` + Offset uint `help:"Offset to start sampling from." default:"0"` +} + +// GlobFlag is a kong.MapperValue implementation that represents a glob pattern. +type GlobFlag []glob.Glob + +// Decode parses the glob patterns and compiles them into glob.Glob objects. It +// is an implementation of kong.MapperValue's Decode method. +func (g GlobFlag) Decode(ctx *kong.DecodeContext) error { + var patterns string + if err := ctx.Scan.PopValueInto("string", &patterns); err != nil { + return err + } + var parsedPatterns []glob.Glob + for _, pattern := range strings.Split(patterns, ",") { + parsedPattern, err := glob.Compile(pattern) + if err != nil { + return fmt.Errorf("cannot compile %s pattern: %w", pattern, err) + } + parsedPatterns = append(parsedPatterns, parsedPattern) + } + ctx.Value.Target.Set(reflect.ValueOf(GlobFlag(parsedPatterns))) + return nil } func (cmd *RepoScanCmd) Run(_ *Globals) error { ctx := context.Background() - scanner, err := scan.NewRepoScanner(ctx, cmd.RepoScannerConfig) + cfg := sql.ScannerConfig{ + RepoType: cmd.Type, + RepoConfig: sql.RepoConfig{ + Host: cmd.Host, + Port: cmd.Port, + User: cmd.User, + Password: cmd.Password, + Database: cmd.Database, + MaxOpenConns: cmd.MaxOpenConns, + Advanced: cmd.Advanced, + }, + IncludePaths: cmd.IncludePaths, + ExcludePaths: cmd.ExcludePaths, + SampleSize: cmd.SampleSize, + Offset: cmd.Offset, + } + scanner, err := sql.NewScanner(cfg) if err != nil { return fmt.Errorf("error creating new scanner: %w", err) } - defer scanner.Cleanup() - return scanner.Scan(ctx) + results, err := scanner.Scan(ctx) + if err != nil { + return fmt.Errorf("error scanning repository: %w", err) + } + jsonResults, err := json.MarshalIndent(results, "", " ") + if err != nil { + return fmt.Errorf("error marshalling results: %w", err) + } + fmt.Println(string(jsonResults)) + // TODO: publish results to API -ccampo 2024-04-03 + 
return nil } diff --git a/discovery/config.go b/discovery/config.go deleted file mode 100644 index 925c31b..0000000 --- a/discovery/config.go +++ /dev/null @@ -1,177 +0,0 @@ -package discovery - -import ( - "fmt" - "reflect" - "strings" - - "github.com/alecthomas/kong" - "github.com/gobwas/glob" -) - -const configConnOpts = "connection-string-args" - -// RepoConfig is the necessary configuration to connect to a data sql. -type RepoConfig struct { - Type string `help:"Type of repository to connect to (postgres|mysql|oracle|sqlserver|snowflake|redshift|denodo)." enum:"postgres,mysql,oracle,sqlserver,snowflake,redshift,denodo" required:""` - Host string `help:"Hostname of the sql." required:""` - Port uint16 `help:"Port of the sql." required:""` - User string `help:"Username to connect to the sql." required:""` - Password string `help:"Password to connect to the sql." required:""` - Advanced map[string]any `help:"Advanced configuration for the sql."` - Database string `help:"Name of the database to connect to. If not specified, the default database is used."` - MaxOpenConns uint `help:"Maximum number of open connections to the sql." default:"10"` - SampleSize uint `help:"Number of rows to sample from the repository (per table)." default:"5"` - IncludePaths GlobFlag `help:"List of glob patterns to include when querying the database(s), as a comma separated list." default:"*"` - ExcludePaths GlobFlag `help:"List of glob patterns to exclude when querying the database(s), as a comma separated list." default:""` -} - -// GlobFlag is a kong.MapperValue implementation that represents a glob pattern. -type GlobFlag []glob.Glob - -// Decode parses the glob patterns and compiles them into glob.Glob objects. It -// is an implementation of kong.MapperValue's Decode method. 
-func (g GlobFlag) Decode(ctx *kong.DecodeContext) error { - var patterns string - if err := ctx.Scan.PopValueInto("string", &patterns); err != nil { - return err - } - var parsedPatterns []glob.Glob - for _, pattern := range strings.Split(patterns, ",") { - parsedPattern, err := glob.Compile(pattern) - if err != nil { - return fmt.Errorf("cannot compile %s pattern: %w", pattern, err) - } - parsedPatterns = append(parsedPatterns, parsedPattern) - } - ctx.Value.Target.Set(reflect.ValueOf(GlobFlag(parsedPatterns))) - return nil -} - -// BuildConnOptsStr parses the repo config to produce a string in the format -// "?option=value&option2=value2". Example: -// -// BuildConnOptsStr(RepoConfig{ -// Advanced: map[string]any{ -// "connection-string-args": []any{"sslmode=disable"}, -// }, -// }) -// -// returns ("?sslmode=disable", nil). -func BuildConnOptsStr(cfg RepoConfig) (string, error) { - connOptsMap, err := mapFromConnOpts(cfg) - if err != nil { - return "", fmt.Errorf("connection options: %w", err) - } - - connOptsStr := "" - for key, val := range connOptsMap { - // Don't add if the value is empty, since that would make the - // string malformed. - if val != "" { - if connOptsStr == "" { - connOptsStr += fmt.Sprintf("%s=%s", key, val) - } else { - // Need & for subsequent options - connOptsStr += fmt.Sprintf("&%s=%s", key, val) - } - } - } - // Only add ? if connection string is not empty - if connOptsStr != "" { - connOptsStr = "?" + connOptsStr - } - - return connOptsStr, nil -} - -// FetchAdvancedConfigString fetches a map in the repo advanced configuration, -// for a given repo and set of parameters. 
Example: -// -// repo-advanced: -// -// snowflake: -// account: exampleAccount -// role: exampleRole -// warehouse: exampleWarehouse -// -// Calling FetchAdvancedMapConfig(, "snowflake", -// []string{"account", "role", "warehouse"}) returns the map -// -// {"account": "exampleAccount", "role": "exampleRole", "warehouse": -// "exampleWarehouse"} -// -// The suffix 'String' means that the values of the map are strings. This gives -// room to have FetchAdvancedConfigList or FetchAdvancedConfigMap, for example, -// without name conflicts. -func FetchAdvancedConfigString( - cfg RepoConfig, - repo string, - parameters []string, -) (map[string]string, error) { - advancedCfg, err := getAdvancedConfig(cfg, repo) - if err != nil { - return nil, err - } - repoSpecificMap := make(map[string]string) - for _, key := range parameters { - var valInterface any - var val string - var ok bool - if valInterface, ok = advancedCfg[key]; !ok { - return nil, fmt.Errorf("unable to find '%s' in %s advanced config", key, repo) - } - if val, ok = valInterface.(string); !ok { - return nil, fmt.Errorf("'%s' in %s config must be a string", key, repo) - } - repoSpecificMap[key] = val - } - return repoSpecificMap, nil -} - -// mapFromConnOpts builds a map from the list of connection options given in the -// Each option has the format 'option=value'. Err only if the config is -// malformed, to inform user. -func mapFromConnOpts(cfg RepoConfig) (map[string]string, error) { - m := make(map[string]string) - connOptsInterface, ok := cfg.Advanced[configConnOpts] - if !ok { - return nil, nil - } - connOpts, ok := connOptsInterface.([]any) - if !ok { - return nil, fmt.Errorf("'%s' is not a list", configConnOpts) - } - for _, optInterface := range connOpts { - opt, ok := optInterface.(string) - if !ok { - return nil, fmt.Errorf("'%v' is not a string", optInterface) - } - splitOpt := strings.Split(opt, "=") - if len(splitOpt) != 2 { - return nil, fmt.Errorf( - "malformed '%s'. 
"+ - "Please follow the format 'option=value'", configConnOpts, - ) - } - key := splitOpt[0] - val := splitOpt[1] - m[key] = val - } - return m, nil -} - -// getAdvancedConfig gets the Advanced field in a repo config and converts it to -// a map[string]any. In every step, it checks for error and generates -// nice messages. -func getAdvancedConfig(cfg RepoConfig, repo string) (map[string]any, error) { - advancedCfgInterface, ok := cfg.Advanced[repo] - if !ok { - return nil, fmt.Errorf("unable to find '%s' in advanced config", repo) - } - advancedCfg, ok := advancedCfgInterface.(map[string]any) - if !ok { - return nil, fmt.Errorf("'%s' in advanced config is not a map", repo) - } - return advancedCfg, nil -} diff --git a/discovery/doc.go b/discovery/doc.go deleted file mode 100644 index 999e7e2..0000000 --- a/discovery/doc.go +++ /dev/null @@ -1,16 +0,0 @@ -// Package discovery provides mechanisms to perform database introspection and -// data discovery on various data repositories. It provides a RepoScanner type that -// can be used to scan a data repository for sensitive data, classify the data, -// and publish the results to external sources. -// -// Additionally, the SQLRepository interface provides an API for performing -// database introspection and data discovery on SQL databases. It encapsulates -// the concept of a Dmap data SQL repository. All out-of-the-box SQLRepository -// implementations are included in their own files named after the repository -// type, e.g. mysql.go, postgres.go, etc. -// -// Registry provides an API for registering and constructing SQLRepository -// implementations within an application. There is a global DefaultRegistry -// which has all-out-of-the-box SQLRepository implementations registered to it -// by default. 
-package discovery diff --git a/discovery/gen.go b/discovery/gen.go deleted file mode 100644 index b295816..0000000 --- a/discovery/gen.go +++ /dev/null @@ -1,5 +0,0 @@ -package discovery - -// Mock generation - see https://vektra.github.io/mockery/ - -//go:generate mockery --testonly --inpackage --with-expecter --name=Repository --filename=mock_repository_test.go diff --git a/discovery/postgres.go b/discovery/postgres.go deleted file mode 100644 index df9603a..0000000 --- a/discovery/postgres.go +++ /dev/null @@ -1,127 +0,0 @@ -package discovery - -import ( - "context" - "fmt" - - // Postgresql DB driver - _ "github.com/lib/pq" -) - -const ( - RepoTypePostgres = "postgres" - - PostgresDatabaseQuery = ` -SELECT - datname -FROM - pg_database -WHERE - datistemplate = false - AND datallowconn = true - AND datname <> 'rdsadmin' -` -) - -// PostgresRepository is a SQLRepository implementation for Postgres databases. -type PostgresRepository struct { - // The majority of the SQLRepository functionality is delegated to - // a generic SQL repository instance (genericSqlRepo). - genericSqlRepo *GenericRepository -} - -// PostgresRepository implements SQLRepository -var _ SQLRepository = (*PostgresRepository)(nil) - -// NewPostgresRepository creates a new PostgresRepository. -func NewPostgresRepository(cfg RepoConfig) (*PostgresRepository, error) { - pgCfg, err := ParsePostgresConfig(cfg) - if err != nil { - return nil, fmt.Errorf("error parsing postgres config: %w", err) - } - database := cfg.Database - // Connect to the default database, if unspecified. 
- if database == "" { - database = "postgres" - } - connStr := fmt.Sprintf( - "postgresql://%s:%s@%s:%d/%s%s", - cfg.User, - cfg.Password, - cfg.Host, - cfg.Port, - database, - pgCfg.ConnOptsStr, - ) - sqlRepo, err := NewGenericRepository( - cfg.Host, - RepoTypePostgres, - cfg.Database, - connStr, - cfg.MaxOpenConns, - cfg.IncludePaths, - cfg.ExcludePaths, - ) - if err != nil { - return nil, fmt.Errorf("could not instantiate generic sql repository: %w", err) - } - return &PostgresRepository{genericSqlRepo: sqlRepo}, nil -} - -// ListDatabases returns a list of the names of all databases on the server by -// using a Postgres-specific database query. It delegates the actual work to -// GenericRepository.ListDatabasesWithQuery - see that method for more details. -func (r *PostgresRepository) ListDatabases(ctx context.Context) ([]string, error) { - return r.genericSqlRepo.ListDatabasesWithQuery(ctx, PostgresDatabaseQuery) -} - -// Introspect delegates introspection to GenericRepository. See -// SQLRepository.Introspect and GenericRepository.IntrospectWithQuery for more -// details. -func (r *PostgresRepository) Introspect(ctx context.Context) (*Metadata, error) { - return r.genericSqlRepo.Introspect(ctx) -} - -// SampleTable delegates sampling to GenericRepository, using a -// Postgres-specific table sample query. See SQLRepository.SampleTable and -// GenericRepository.SampleTableWithQuery for more details. -func (r *PostgresRepository) SampleTable( - ctx context.Context, - meta *TableMetadata, - params SampleParameters, -) (Sample, error) { - // Postgres uses double-quotes to quote identifiers - attrStr := meta.QuotedAttributeNamesString("\"") - // Postgres uses $x for placeholders - query := fmt.Sprintf("SELECT %s FROM %s.%s LIMIT $1 OFFSET $2", attrStr, meta.Schema, meta.Name) - return r.genericSqlRepo.SampleTableWithQuery(ctx, meta, query, params.SampleSize, params.Offset) -} - -// Ping delegates the ping to GenericRepository. 
See SQLRepository.Ping and -// GenericRepository.Ping for more details. -func (r *PostgresRepository) Ping(ctx context.Context) error { - return r.genericSqlRepo.Ping(ctx) -} - -// Close delegates the close to GenericRepository. See SQLRepository.Close and -// GenericRepository.Close for more details. -func (r *PostgresRepository) Close() error { - return r.genericSqlRepo.Close() -} - -// PostgresConfig contains Postgres-specific configuration parameters. -type PostgresConfig struct { - // ConnOptsStr is a string containing Postgres-specific connection options. - ConnOptsStr string -} - -// ParsePostgresConfig parses the Postgres-specific configuration parameters -// from the given The Postgres connection options are built from the -// config and stored in the ConnOptsStr field of the returned Postgres -func ParsePostgresConfig(cfg RepoConfig) (*PostgresConfig, error) { - connOptsStr, err := BuildConnOptsStr(cfg) - if err != nil { - return nil, fmt.Errorf("error building connection options string: %w", err) - } - return &PostgresConfig{ConnOptsStr: connOptsStr}, nil -} diff --git a/discovery/postgres_test.go b/discovery/postgres_test.go deleted file mode 100644 index 1794a17..0000000 --- a/discovery/postgres_test.go +++ /dev/null @@ -1,29 +0,0 @@ -package discovery - -import ( - "context" - "database/sql" - "testing" - - "github.com/DATA-DOG/go-sqlmock" - "github.com/stretchr/testify/require" -) - -func TestPostgresRepository_ListDatabases(t *testing.T) { - ctx, db, mock, r := initPostgresRepoTest(t) - defer func() { _ = db.Close() }() - dbRows := sqlmock.NewRows([]string{"name"}).AddRow("db1").AddRow("db2") - mock.ExpectQuery(PostgresDatabaseQuery).WillReturnRows(dbRows) - dbs, err := r.ListDatabases(ctx) - require.NoError(t, err) - require.ElementsMatch(t, []string{"db1", "db2"}, dbs) -} - -func initPostgresRepoTest(t *testing.T) (context.Context, *sql.DB, sqlmock.Sqlmock, *PostgresRepository) { - ctx := context.Background() - db, mock, err := sqlmock.New() - 
require.NoError(t, err) - return ctx, db, mock, &PostgresRepository{ - genericSqlRepo: NewGenericRepositoryFromDB("repoName", RepoTypePostgres, "dbName", db), - } -} diff --git a/discovery/registry_test.go b/discovery/registry_test.go deleted file mode 100644 index d459e36..0000000 --- a/discovery/registry_test.go +++ /dev/null @@ -1,113 +0,0 @@ -package discovery - -import ( - "context" - "errors" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// TODO: refactor tests to use registry instance, not default registry -ccampo 2024-04-02 - -func TestRegistry_Register_Successful(t *testing.T) { - repoType := "repoType" - constructor := func(context.Context, RepoConfig) (SQLRepository, error) { - return nil, nil - } - reg := NewRegistry() - err := reg.Register(repoType, constructor) - require.NoError(t, err) - assert.Contains(t, reg.constructors, repoType) -} - -func TestRegistry_MustRegister_NilConstructor(t *testing.T) { - reg := NewRegistry() - assert.Panics(t, func() { reg.MustRegister("repoType", nil) }) -} - -func TestRegistry_MustRegister_TwoCalls_Panics(t *testing.T) { - repoType := "repoType" - constructor := func(context.Context, RepoConfig) (SQLRepository, error) { - return nil, nil - } - reg := NewRegistry() - reg.MustRegister(repoType, constructor) - assert.Contains(t, reg.constructors, repoType) - assert.Panics(t, func() { reg.MustRegister(repoType, constructor) }) -} - -func TestRegistry_NewRepository_IsSuccessful(t *testing.T) { - repoType := "repoType" - called := false - expectedRepo := dummyRepo{} - constructor := func(context.Context, RepoConfig) (SQLRepository, error) { - called = true - return expectedRepo, nil - } - reg := NewRegistry() - err := reg.Register(repoType, constructor) - require.NoError(t, err) - assert.Contains(t, reg.constructors, repoType) - - cfg := RepoConfig{Type: repoType} - repo, err := reg.NewRepository(context.Background(), cfg) - assert.NoError(t, err) - assert.Equal(t, 
expectedRepo, repo) - assert.True(t, called, "Constructor was not called") -} - -func TestRegistry_NewRepository_ConstructorError(t *testing.T) { - repoType := "repoType" - called := false - expectedErr := errors.New("dummy error") - constructor := func(context.Context, RepoConfig) (SQLRepository, error) { - called = true - return nil, expectedErr - } - reg := NewRegistry() - err := reg.Register(repoType, constructor) - require.NoError(t, err) - assert.Contains(t, reg.constructors, repoType) - - cfg := RepoConfig{Type: repoType} - repo, err := reg.NewRepository(context.Background(), cfg) - assert.ErrorIs(t, err, expectedErr) - assert.Nil(t, repo) - assert.True(t, called, "Constructor was not called") -} - -func TestRegistry_NewRepository_UnsupportedRepoType(t *testing.T) { - repoType := "repoType" - cfg := RepoConfig{Type: repoType} - reg := NewRegistry() - repo, err := reg.NewRepository(context.Background(), cfg) - assert.Error(t, err) - assert.Nil(t, repo) -} - -type dummyRepo struct{} - -func (d dummyRepo) SampleTable(context.Context, *TableMetadata, SampleParameters) ( - Sample, - error, -) { - panic("not implemented") -} - -func (d dummyRepo) ListDatabases(context.Context) ([]string, error) { - panic("not implemented") -} - -func (d dummyRepo) Introspect(context.Context) (*Metadata, error) { - panic("not implemented") -} - -func (d dummyRepo) Ping(context.Context) error { - panic("not implemented") -} - -func (d dummyRepo) Close() error { - panic("not implemented") -} diff --git a/discovery/sample_all_databases_test.go b/discovery/sample_all_databases_test.go deleted file mode 100644 index 0539d05..0000000 --- a/discovery/sample_all_databases_test.go +++ /dev/null @@ -1,226 +0,0 @@ -package discovery - -import ( - "context" - "errors" - "testing" - - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" -) - -const mockRepoType = "mockRepo" - -func setup(t *testing.T) *MockRepository { - repo := NewMockRepository(t) - MustRegister( - 
mockRepoType, - func(ctx context.Context, cfg RepoConfig) (SQLRepository, error) { - return repo, nil - }, - ) - return repo -} - -func cleanup() { - delete(DefaultRegistry.constructors, mockRepoType) -} - -func TestSampleAllDatabases_Error(t *testing.T) { - repo := setup(t) - t.Cleanup(cleanup) - - ctx := context.Background() - listDbErr := errors.New("error listing databases") - repo.On("ListDatabases", ctx).Return(nil, listDbErr) - cfg := RepoConfig{Type: mockRepoType} - sampleParams := SampleParameters{SampleSize: 5} - samples, err := SampleAllDatabases(ctx, repo, cfg, sampleParams) - require.Nil(t, samples) - require.ErrorIs(t, err, listDbErr) -} - -func TestSampleAllDatabases_Successful_TwoDatabases(t *testing.T) { - repo := setup(t) - t.Cleanup(cleanup) - - ctx := context.Background() - dbs := []string{"db1", "db2"} - repo.On("ListDatabases", ctx).Return(dbs, nil) - // Dummy metadata returned for each Introspect call - meta := Metadata{ - Name: "test", - RepoType: mockRepoType, - Database: "db", - Schemas: map[string]*SchemaMetadata{ - "schema": { - Name: "schema", - Tables: map[string]*TableMetadata{ - "table": { - Schema: "schema", - Name: "table", - Attributes: []*AttributeMetadata{ - { - Schema: "schema", - Table: "table", - Name: "attr", - DataType: "string", - }, - }, - }, - }, - }, - }, - } - repo.On("Introspect", mock.Anything).Return(&meta, nil) - sample := Sample{ - Metadata: SampleMetadata{ - Repo: "repo", - Database: "db", - Schema: "schema", - Table: "table", - }, - Results: []SampleResult{ - { - "attr": "foo", - }, - }, - } - repo.On("SampleTable", mock.Anything, mock.Anything, mock.Anything). - Return(sample, nil) - repo.On("Close").Return(nil) - - cfg := RepoConfig{Type: mockRepoType} - sampleParams := SampleParameters{SampleSize: 5} - samples, err := SampleAllDatabases(ctx, repo, cfg, sampleParams) - require.NoError(t, err) - // Two databases should be sampled, and our mock will return the sample for - // each sample call. 
This really just asserts that we've sampled the correct - // number of times. - require.ElementsMatch(t, samples, []Sample{sample, sample}) -} - -func TestSampleAllDatabases_IntrospectError(t *testing.T) { - repo := setup(t) - t.Cleanup(cleanup) - - ctx := context.Background() - dbs := []string{"db1", "db2"} - repo.On("ListDatabases", ctx).Return(dbs, nil) - introspectErr := errors.New("introspect error") - repo.On("Introspect", mock.Anything).Return(nil, introspectErr) - repo.On("Close").Return(nil) - - cfg := RepoConfig{Type: mockRepoType} - sampleParams := SampleParameters{SampleSize: 5} - samples, err := SampleAllDatabases(ctx, repo, cfg, sampleParams) - require.Empty(t, samples) - require.NoError(t, err) -} - -func TestSampleAllDatabases_SampleError(t *testing.T) { - repo := setup(t) - t.Cleanup(cleanup) - - ctx := context.Background() - dbs := []string{"db1", "db2"} - repo.On("ListDatabases", ctx).Return(dbs, nil) - // Dummy metadata returned for each Introspect call - meta := Metadata{ - Name: "test", - RepoType: mockRepoType, - Database: "db", - Schemas: map[string]*SchemaMetadata{ - "schema": { - Name: "schema", - Tables: map[string]*TableMetadata{ - "table": { - Schema: "schema", - Name: "table", - Attributes: []*AttributeMetadata{ - { - Schema: "schema", - Table: "table", - Name: "attr", - DataType: "string", - }, - }, - }, - }, - }, - }, - } - repo.On("Introspect", mock.Anything).Return(&meta, nil) - sampleErr := errors.New("sample error") - repo.On("SampleTable", mock.Anything, mock.Anything, mock.Anything). 
- Return(Sample{}, sampleErr) - repo.On("Close").Return(nil) - - cfg := RepoConfig{Type: mockRepoType} - sampleParams := SampleParameters{SampleSize: 5} - samples, err := SampleAllDatabases(ctx, repo, cfg, sampleParams) - require.NoError(t, err) - require.Empty(t, samples) -} - -func TestSampleAllDatabases_TwoDatabases_OneSampleError(t *testing.T) { - repo := setup(t) - t.Cleanup(cleanup) - - ctx := context.Background() - dbs := []string{"db1", "db2"} - repo.On("ListDatabases", ctx).Return(dbs, nil) - // Dummy metadata returned for each Introspect call - meta := Metadata{ - Name: "test", - RepoType: mockRepoType, - Database: "db", - Schemas: map[string]*SchemaMetadata{ - "schema": { - Name: "schema", - Tables: map[string]*TableMetadata{ - "table": { - Schema: "schema", - Name: "table", - Attributes: []*AttributeMetadata{ - { - Schema: "schema", - Table: "table", - Name: "attr", - DataType: "string", - }, - }, - }, - }, - }, - }, - } - repo.On("Introspect", mock.Anything).Return(&meta, nil) - sample := Sample{ - Metadata: SampleMetadata{ - Repo: "repo", - Database: "db", - Schema: "schema", - Table: "table", - }, - Results: []SampleResult{ - { - "attr": "foo", - }, - }, - } - repo.On("SampleTable", mock.Anything, mock.Anything, mock.Anything). - Return(sample, nil).Once() - sampleErr := errors.New("sample error") - repo.On("SampleTable", mock.Anything, mock.Anything, mock.Anything). - Return(Sample{}, sampleErr).Once() - repo.On("Close").Return(nil) - - cfg := RepoConfig{Type: mockRepoType} - sampleParams := SampleParameters{SampleSize: 5} - samples, err := SampleAllDatabases(ctx, repo, cfg, sampleParams) - require.NoError(t, err) - // Because of a single sample error, we expect only one database was - // sampled. 
- require.ElementsMatch(t, samples, []Sample{sample}) -} diff --git a/discovery/sample_repository_test.go b/discovery/sample_repository_test.go deleted file mode 100644 index e1323e3..0000000 --- a/discovery/sample_repository_test.go +++ /dev/null @@ -1,181 +0,0 @@ -package discovery - -import ( - "context" - "errors" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestSampleRepository(t *testing.T) { - ctx := context.Background() - repo := fakeRepo{} - params := SampleParameters{SampleSize: 2} - samples, err := SampleRepository(ctx, repo, params) - require.NoError(t, err) - - // Order is not important and is actually non-deterministic due to concurrency - assert.ElementsMatch(t, samples, []Sample{table1Sample, table2Sample}) -} - -func TestSampleRepository_PartialError(t *testing.T) { - ctx := context.Background() - repo := fakeRepo{includeForbiddenTables: true} - params := SampleParameters{SampleSize: 2} - samples, err := SampleRepository(ctx, repo, params) - require.ErrorContains(t, err, "forbidden table") - - // Order is not important and is actually non-deterministic due to concurrency - assert.ElementsMatch(t, samples, []Sample{table1Sample, table2Sample}) -} - -var repoMeta = Metadata{ - Name: "name", - RepoType: "repoType", - Database: "database", - Schemas: map[string]*SchemaMetadata{ - "schema1": { - Name: "", - Tables: map[string]*TableMetadata{ - "table1": { - Schema: "schema1", - Name: "table1", - Attributes: []*AttributeMetadata{ - { - Schema: "schema1", - Table: "table1", - Name: "name1", - DataType: "varchar", - }, - { - Schema: "schema1", - Table: "table1", - Name: "name2", - DataType: "decimal", - }, - }, - }, - }, - }, - "schema2": { - Name: "", - Tables: map[string]*TableMetadata{ - "table2": { - Schema: "schema2", - Name: "table2", - Attributes: []*AttributeMetadata{ - { - Schema: "schema2", - Table: "table2", - Name: "name3", - DataType: "int", - }, - { - Schema: "schema2", - Table: 
"table2", - Name: "name4", - DataType: "timestamp", - }, - }, - }, - }, - }, - }, -} - -var schema1ForbiddenTable = TableMetadata{ - Schema: "schema1", - Name: "forbidden", - Attributes: []*AttributeMetadata{ - { - Schema: "schema1", - Table: "forbidden", - Name: "name1", - DataType: "varchar", - }, - { - Schema: "schema1", - Table: "forbidden", - Name: "name2", - DataType: "decimal", - }, - }, -} - -var table1Sample = Sample{ - Metadata: SampleMetadata{ - Repo: "name", - Database: "database", - Schema: "schema1", - Table: "table1", - }, - Results: []SampleResult{ - { - "name1": "foo", - "name2": "bar", - }, - { - "name1": "baz", - "name2": "qux", - }, - }, -} - -var table2Sample = Sample{ - Metadata: SampleMetadata{ - Repo: "name", - Database: "database", - Schema: "schema2", - Table: "table2", - }, - Results: []SampleResult{ - { - "name3": "foo1", - "name4": "bar1", - }, - { - "name3": "baz1", - "name4": "qux1", - }, - }, -} - -type fakeRepo struct { - includeForbiddenTables bool -} - -func (f fakeRepo) Introspect(context.Context) (*Metadata, error) { - if f.includeForbiddenTables { - repoMeta.Schemas["schema1"].Tables["forbidden"] = &schema1ForbiddenTable - } - return &repoMeta, nil -} - -func (f fakeRepo) SampleTable(_ context.Context, meta *TableMetadata, _ SampleParameters) ( - Sample, - error, -) { - if meta.Name == "table1" { - return table1Sample, nil - } else if meta.Name == "table2" { - return table2Sample, nil - } else if meta.Name == "forbidden" { - return Sample{}, errors.New("forbidden table") - } else { - return Sample{}, errors.New("unrecognized table") - } -} - -func (f fakeRepo) ListDatabases(context.Context) ([]string, error) { - panic("not implemented") -} - -func (f fakeRepo) Ping(context.Context) error { - panic("not implemented") -} - -func (f fakeRepo) Close() error { - panic("not implemented") -} diff --git a/discovery/sampling.go b/discovery/sampling.go deleted file mode 100644 index 24b3a14..0000000 --- a/discovery/sampling.go +++ 
/dev/null @@ -1,216 +0,0 @@ -package discovery - -import ( - "context" - "fmt" - "sync" - - "github.com/hashicorp/go-multierror" - log "github.com/sirupsen/logrus" - "golang.org/x/sync/semaphore" -) - -// SampleParameters contains all parameters necessary to sample a table. -type SampleParameters struct { - SampleSize uint - Offset uint -} - -// Sample represents a sample of a database table. The Metadata field contains -// metadata about the sample itself. The actual results of the sample, which -// are represented by a set of database rows, are contained in the Results -// field. -type Sample struct { - Metadata SampleMetadata - Results []SampleResult -} - -// SampleMetadata contains the metadata associated with a given sample, such as -// repo name, database name, table, schema, and query (if applicable). This can -// be used for diagnostic and informational purposes when analyzing a -// particular sample. -type SampleMetadata struct { - Repo string - Database string - Schema string - Table string -} - -// SampleResult stores the results from a single database sample. It is -// equivalent to a database row, where the map key is the column name and the -// map value is the column value. -type SampleResult map[string]any - -// sampleAndErr is a "pair" type intended to be passed to a channel (see -// SampleRepository) -type sampleAndErr struct { - sample Sample - err error -} - -// samplesAndErr is a "pair" type intended to be passed to a channel (see -// SampleAllDatabases) -type samplesAndErr struct { - samples []Sample - err error -} - -// GetAttributeNamesAndValues splits a SampleResult map into two slices and -// returns them. The first slice contains all the keys of SampleResult, -// representing the table's attribute names, and the second slice is the map's -// corresponding values. 
-func (result SampleResult) GetAttributeNamesAndValues() ([]string, []string) { - names := make([]string, 0, len(result)) - vals := make([]string, 0, len(result)) - for name, val := range result { - names = append(names, name) - var v string - if b, ok := val.([]byte); ok { - v = string(b) - } else { - v = fmt.Sprint(val) - } - vals = append(vals, v) - } - return names, vals -} - -// SampleRepository is a helper function which will sample every table in a -// given repository and return them as a collection of Sample. First the -// repository is introspected by calling sql.Introspect to return the -// repository metadata (Metadata). Then, for each schema and table in the -// metadata, it calls sql.SampleTable in a new goroutine. Once all the -// sampling goroutines are finished, their results are collected and returned -// as a slice of Sample. -func SampleRepository(ctx context.Context, repo SQLRepository, params SampleParameters) ( - []Sample, - error, -) { - meta, err := repo.Introspect(ctx) - if err != nil { - return nil, fmt.Errorf("error introspecting repository: %w", err) - } - - // Fan out sample executions - out := make(chan sampleAndErr) - numTables := 0 - for _, schemaMeta := range meta.Schemas { - for _, tableMeta := range schemaMeta.Tables { - numTables++ - go func(meta *TableMetadata, params SampleParameters) { - sample, err := repo.SampleTable(ctx, meta, params) - out <- sampleAndErr{sample: sample, err: err} - }(tableMeta, params) - } - } - - var samples []Sample - var errs error - for i := 0; i < numTables; i++ { - res := <-out - if res.err != nil { - errs = multierror.Append(errs, res.err) - } else { - samples = append(samples, res.sample) - } - } - close(out) - - if errs != nil { - return samples, fmt.Errorf("error(s) while sampling repository: %w", errs) - } - - return samples, nil -} - -// SampleAllDatabases uses the given repository to list all the databases on the -// server, and samples each one in parallel by calling SampleRepository for 
each -// database. The repository is intended to be configured to connect to the -// default database on the server, or at least some database which can be used -// to enumerate the full set of databases on the server. An error will be -// returned if the set of databases cannot be listed. If there is an error -// connecting to or sampling a database, the error will be logged and no samples -// will be returned for that database. Therefore, the returned slice of samples -// contains samples for only the databases which could be discovered and -// successfully sampled, and could potentially be empty if no databases were -// sampled. -func SampleAllDatabases( - ctx context.Context, - repo SQLRepository, - repoCfg RepoConfig, - sampleParams SampleParameters, -) ( - []Sample, - error, -) { - // We assume that this repository will be connected to the default database - // (or at least some database that can discover all the other databases), - // and we use that to discover all other databases. - dbs, err := repo.ListDatabases(ctx) - if err != nil { - return nil, fmt.Errorf("error listing databases: %w", err) - } - - // Sample each database on a separate goroutine, and send the samples - // to the 'out' channel. Each slice of samples will be aggregated below - // on the main goroutine and returned. - var wg sync.WaitGroup - out := make(chan samplesAndErr) - wg.Add(len(dbs)) - // Ensures that we avoid opening more than the specified number of - // connections. - var sema *semaphore.Weighted - if repoCfg.MaxOpenConns > 0 { - sema = semaphore.NewWeighted(int64(repoCfg.MaxOpenConns)) - } - for _, db := range dbs { - go func(db string, cfg RepoConfig) { - defer wg.Done() - if sema != nil { - _ = sema.Acquire(ctx, 1) - defer sema.Release(1) - } - cfg.Database = db - // Create a repository instance for this database. It will be used - // to connect and sample the database. 
- // TODO: this is ugly - there's gotta be a better way to do this -ccampo 2024-04-02 - repo, err := NewRepository(ctx, cfg) - if err != nil { - log.WithError(err).Errorf("error creating repository instance for database %s", db) - return - } - // Close this repository and free up unused resources since we don't - // need it any longer. - defer func() { _ = repo.Close() }() - s, err := SampleRepository(ctx, repo, sampleParams) - if err != nil && len(s) == 0 { - log.WithError(err).Errorf("error gathering repository data samples for database %s", db) - return - } - // Send the samples for this database to the 'out' channel. The - // samples for each database will be aggregated into a single slice - // on the main goroutine and returned. - out <- samplesAndErr{samples: s, err: err} - }(db, repoCfg) - } - - // Start a goroutine to close the 'out' channel once all the goroutines - // we launched above are done. This will allow the aggregation range loop - // below to terminate properly. Note that this must start after the wg.Add - // call. See https://go.dev/blog/pipelines ("Fan-out, fan-in" section). - go func() { - wg.Wait() - close(out) - }() - - // Aggregate and return the results. - var ret []Sample - var errs error - for res := range out { - ret = append(ret, res.samples...) 
- if res.err != nil { - errs = multierror.Append(errs, res.err) - } - } - return ret, errs -} diff --git a/discovery/sampling_test.go b/discovery/sampling_test.go deleted file mode 100644 index f964a4f..0000000 --- a/discovery/sampling_test.go +++ /dev/null @@ -1,77 +0,0 @@ -package discovery - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestSampleResult_GetAttributeNamesAndValues(t *testing.T) { - tests := []struct { - name string - result SampleResult - wantNames, wantVals []string - }{ - { - name: "string", - result: SampleResult{ - "foo": "fooVal", - }, - wantNames: []string{"foo"}, - wantVals: []string{"fooVal"}, - }, - { - name: "int", - result: SampleResult{ - "foo": 123, - }, - wantNames: []string{"foo"}, - wantVals: []string{"123"}, - }, - { - name: "float", - result: SampleResult{ - "foo": 12.3, - }, - wantNames: []string{"foo"}, - wantVals: []string{"12.3"}, - }, - { - name: "bytes", - result: SampleResult{ - "foo": []byte("fooVal"), - }, - wantNames: []string{"foo"}, - wantVals: []string{"fooVal"}, - }, - { - name: "bool", - result: SampleResult{ - "foo": false, - }, - wantNames: []string{"foo"}, - wantVals: []string{"false"}, - }, - { - name: "varied types", - result: SampleResult{ - "foo": "fooVal", - "bar": 123, - "baz": 12.3, - "qux": []byte("quxVal"), - "quxx": true, - }, - wantNames: []string{"foo", "bar", "baz", "qux", "quxx"}, - wantVals: []string{"fooVal", "123", "12.3", "quxVal", "true"}, - }, - } - for _, tt := range tests { - t.Run( - tt.name, func(t *testing.T) { - gotNames, gotVals := tt.result.GetAttributeNamesAndValues() - require.ElementsMatch(t, tt.wantNames, gotNames) - require.ElementsMatch(t, tt.wantVals, gotVals) - }, - ) - } -} diff --git a/discovery/sqlrepository.go b/discovery/sqlrepository.go deleted file mode 100644 index 8f0eeb8..0000000 --- a/discovery/sqlrepository.go +++ /dev/null @@ -1,32 +0,0 @@ -package discovery - -import ( - "context" -) - -// SQLRepository represents a Dmap data SQL 
repository, and provides functionality -// to introspect its corresponding schema. -type SQLRepository interface { - // ListDatabases returns a list of the names of all databases on the server. - ListDatabases(ctx context.Context) ([]string, error) - // Introspect will read and analyze the basic properties of the repository - // and return as a Metadata instance. This includes all the repository's - // databases, schemas, tables, columns, and attributes. - Introspect(ctx context.Context) (*Metadata, error) - // SampleTable samples the table referenced by the TableMetadata meta - // parameter and returns the sample as a slice of Sample. The parameters for - // the sample, such as sample size, are passed via the params parameter (see - // SampleParameters for more details). The returned sample result set - // contains one Sample for each table row sampled. The length of the results - // will be less than or equal to the sample size. If there are fewer results - // than the specified sample size, it is because the table in question had a - // row count less than the sample size. Prefer small sample sizes to limit - // impact on the database. - SampleTable(ctx context.Context, meta *TableMetadata, params SampleParameters) (Sample, error) - // Ping is meant to be used as a general purpose connectivity test. It - // should be invoked e.g. in the dry-run mode. - Ping(ctx context.Context) error - // Close is meant to be used as a general purpose cleanup. It should be - // invoked when the SQLRepository is no longer used. - Close() error -} diff --git a/scan/doc.go b/scan/doc.go index f0557aa..38346fc 100644 --- a/scan/doc.go +++ b/scan/doc.go @@ -1,4 +1,3 @@ -// Package scan provides a EnvironmentScanner interface which can be used for scanning -// cloud environments and performing data repository discovery. 
-// TODO: fix doc -ccampo 2024-04-02 +// Package scan provides an API to scan cloud environments for data +// repositories and an API to scan those repositories for sensitive data. +package scan diff --git a/scan/env_scanner.go b/scan/env_scanner.go deleted file mode 100644 index fdb90eb..0000000 --- a/scan/env_scanner.go +++ /dev/null @@ -1,66 +0,0 @@ -package scan - -import ( - "context" - "errors" - "time" -) - -// EnvironmentScanner provides an API to scan cloud environments. It should be -// implemented for a specific cloud provider (e.g. AWS, GCP, etc.). It defines -// the Scan method responsible for scanning the existing data repositories of -// the corresponding cloud provider environment. -type EnvironmentScanner interface { - Scan(ctx context.Context) (*EnvironmentScanResults, error) -} - -// RepoType defines the AWS data repository types supported (e.g. RDS, Redshift, -// DynamoDB, etc). -type RepoType string - -const ( - RepoTypeRDS RepoType = "TYPE_RDS" - RepoTypeRedshift RepoType = "TYPE_REDSHIFT" - RepoTypeDynamoDB RepoType = "TYPE_DYNAMODB" - RepoTypeS3 RepoType = "TYPE_S3" - RepoTypeDocumentDB RepoType = "TYPE_DOCUMENTDB" -) - -// Repository represents a scanned data repository. -type Repository struct { - Id string - Name string - Type RepoType - CreatedAt time.Time - Tags []string - Properties any -} - -// EnvironmentScanResults represents the results of a repository scan, including -// all the data repositories that were scanned. The map key is the repository ID -// and the value is the repository itself. -type EnvironmentScanResults struct { - Repositories map[string]Repository -} - -// EnvironmentScanError is an error type that represents a collection of errors -// that occurred during the scanning process. -type EnvironmentScanError struct { - Errs []error -} - -// Error returns a string representation of the error. 
-func (e *EnvironmentScanError) Error() string { - if e == nil { - return "" - } - return errors.Join(e.Errs...).Error() -} - -// Unwrap returns the list of errors that occurred during the scanning process. -func (e *EnvironmentScanError) Unwrap() []error { - if e == nil { - return nil - } - return e.Errs -} diff --git a/scan/repo_scanner.go b/scan/repo_scanner.go deleted file mode 100644 index 87bcf83..0000000 --- a/scan/repo_scanner.go +++ /dev/null @@ -1,198 +0,0 @@ -package scan - -import ( - "context" - "fmt" - - log "github.com/sirupsen/logrus" - - "github.com/cyralinc/dmap/classification" - "github.com/cyralinc/dmap/discovery" -) - -// RepoScannerConfig is the configuration for the RepoScanner. -type RepoScannerConfig struct { - Dmap DmapConfig `embed:""` - Repo discovery.RepoConfig `embed:""` -} - -// DmapConfig is the necessary configuration to connect to the Dmap API. -type DmapConfig struct { - ApiBaseUrl string `help:"Base URL of the Dmap API." default:"https://api.dmap.cyral.io"` - ClientID string `help:"API client ID to access the Dmap API."` - ClientSecret string `help:"API client secret to access the Dmap API."` //#nosec G101 -- false positive -} - -// RepoScanner is a data discovery scanner that scans a data repository for -// sensitive data. It also classifies the data and publishes the results to -// the configured external sources. It currently only supports SQL-based -// repositories. -type RepoScanner struct { - config RepoScannerConfig - repository discovery.SQLRepository - classifier classification.Classifier - publisher classification.Publisher -} - -// RepoScannerOption is a functional option type for the RepoScanner type. -type RepoScannerOption func(*RepoScanner) - -// WithSQLRepository is a functional option that sets the SQLRepository for the -// RepoScanner. 
-func WithSQLRepository(r discovery.SQLRepository) RepoScannerOption { - return func(s *RepoScanner) { s.repository = r } -} - -// WithClassifier is a functional option that sets the classifier for the -// RepoScanner. -func WithClassifier(c classification.Classifier) RepoScannerOption { - return func(s *RepoScanner) { s.classifier = c } -} - -// WithPublisher is a functional option that sets the publisher for the RepoScanner. -func WithPublisher(p classification.Publisher) RepoScannerOption { - return func(s *RepoScanner) { s.publisher = p } -} - -// NewRepoScanner creates a new RepoScanner instance with the provided configuration. -func NewRepoScanner(ctx context.Context, cfg RepoScannerConfig, opts ...RepoScannerOption) (*RepoScanner, error) { - s := &RepoScanner{config: cfg} - // Apply options. - for _, opt := range opts { - opt(s) - } - if s.publisher == nil { - // Default to stdout publisher. - s.publisher = classification.NewStdOutPublisher() - } - if s.classifier == nil { - // Create a new label classifier with the embedded labels. - lbls, err := classification.GetEmbeddedLabels() - if err != nil { - return nil, fmt.Errorf("error getting embedded labels: %w", err) - } - c, err := classification.NewLabelClassifier(lbls.ToSlice()...) - if err != nil { - return nil, fmt.Errorf("error creating new label classifier: %w", err) - } - s.classifier = c - } - if s.repository == nil { - // Get a repository instance from the default registry. - repo, err := discovery.NewRepository(ctx, s.config.Repo) - if err != nil { - return nil, fmt.Errorf("error connecting to database: %w", err) - } - s.repository = repo - } - return s, nil -} - -// Scan performs the data repository scan. It introspects and samples the -// repository, classifies the sampled data, and publishes the results to the -// configured classification publisher. 
-func (s *RepoScanner) Scan(ctx context.Context) error { - sampleParams := discovery.SampleParameters{SampleSize: s.config.Repo.SampleSize} - var samples []discovery.Sample - // The name of the database to connect to has been left unspecified by - // the user, so we try to connect and sample all databases instead. Note - // that Oracle doesn't really have the concept of "databases", and thus - // the RepoScanner always scans the entire database, so only the single - // (default) repository instance is required in that case. - if s.config.Repo.Database == "" && s.config.Repo.Type != discovery.RepoTypeOracle { - var err error - samples, err = discovery.SampleAllDatabases( - ctx, - s.repository, - s.config.Repo, - sampleParams, - ) - if err != nil { - err = fmt.Errorf("error sampling databases: %w", err) - // If we didn't get any samples, just return the error. - if len(samples) == 0 { - return err - } - // There were error(s) during sampling, but we still got some - // samples. Just warn and continue. - log.WithError(err).Warn("error sampling databases") - } - } else { - // User specified a database (or this is an Oracle DB), therefore - // we already have a repository instance for it. Just use it to - // sample that database only. - var err error - samples, err = discovery.SampleRepository(ctx, s.repository, sampleParams) - if err != nil { - err = fmt.Errorf("error gathering repository data samples: %w", err) - // If we didn't get any samples, just return the error. - if len(samples) == 0 { - return err - } - // There were error(s) during sampling, but we still got some - // samples. Just warn and continue. 
- log.WithError(err).Warn("error gathering repository data samples") - } - } - - // Classify sampled data - classifications, err := classifySamples(ctx, samples, s.classifier) - if err != nil { - return fmt.Errorf("error classifying samples: %w", err) - } - - // Publish classifications if necessary - if len(classifications) == 0 { - log.Info("No discovered classifications") - } else if err := s.publisher.PublishClassifications(ctx, s.config.Repo.Host, classifications); err != nil { - return fmt.Errorf("error publishing classifications: %w", err) - } - - // Done! - return nil -} - -// Cleanup performs cleanup operations for the RepoScanner. -func (s *RepoScanner) Cleanup() { - // Nil checks are prevent panics if deps are not yet initialized. - if s.repository != nil { - _ = s.repository.Close() - } -} - -// classifySamples uses the provided classifiers to classify the sample data -// passed via the "samples" parameter. It is mostly a helper function which -// loops through each repository.Sample, retrieves the attribute names and -// values of that sample, passes them to Classifier.Classify, and then -// aggregates the results. Please see the documentation for Classifier and its -// Classify method for more details. The returned slice represents all the -// unique classification results for a given sample set. -func classifySamples( - ctx context.Context, - samples []discovery.Sample, - classifier classification.Classifier, -) ([]classification.ClassifiedTable, error) { - tables := make([]classification.ClassifiedTable, 0, len(samples)) - for _, sample := range samples { - // Classify each sampled row and combine the results. 
- result := make(classification.Result) - for _, sampleResult := range sample.Results { - res, err := classifier.Classify(ctx, sampleResult) - if err != nil { - return nil, fmt.Errorf("error classifying sample: %w", err) - } - result.Merge(res) - } - if len(result) > 0 { - table := classification.ClassifiedTable{ - Repo: sample.Metadata.Repo, - Database: sample.Metadata.Database, - Schema: sample.Metadata.Schema, - Table: sample.Metadata.Table, - Classifications: result, - } - tables = append(tables, table) - } - } - return tables, nil -} diff --git a/scan/repo_scanner_test.go b/scan/repo_scanner_test.go deleted file mode 100644 index 95c4e50..0000000 --- a/scan/repo_scanner_test.go +++ /dev/null @@ -1,247 +0,0 @@ -package scan - -import ( - "context" - "testing" - - "github.com/cyralinc/dmap/classification" - "github.com/cyralinc/dmap/discovery" - "github.com/cyralinc/dmap/scan/mocks" - - "github.com/stretchr/testify/require" -) - -func Test_classifySamples_SingleTable(t *testing.T) { - ctx := context.Background() - meta := discovery.SampleMetadata{ - Repo: "repo", - Database: "db", - Schema: "schema", - Table: "table", - } - - sample := discovery.Sample{ - Metadata: meta, - Results: []discovery.SampleResult{ - { - "age": "52", - "social_sec_num": "512-23-4258", - "credit_card_num": "4111111111111111", - }, - { - "age": "101", - "social_sec_num": "foobarbaz", - "credit_card_num": "4111111111111111", - }, - }, - } - - classifier := mocks.NewClassifier(t) - // Need to explicitly convert it to a map because Mockery isn't smart enough - // to infer the type. 
- classifier.EXPECT().Classify(ctx, map[string]any(sample.Results[0])).Return( - classification.Result{ - "age": { - "AGE": {Name: "AGE"}, - }, - "social_sec_num": { - "SSN": {Name: "SSN"}, - }, - "credit_card_num": { - "CCN": {Name: "CCN"}, - }, - }, - nil, - ) - classifier.EXPECT().Classify(ctx, map[string]any(sample.Results[1])).Return( - classification.Result{ - "age": { - "AGE": {Name: "AGE"}, - "CVV": {Name: "CVV"}, - }, - "credit_card_num": { - "CCN": {Name: "CCN"}, - }, - }, - nil, - ) - - expected := []classification.ClassifiedTable{ - { - Repo: meta.Repo, - Database: meta.Database, - Schema: meta.Schema, - Table: meta.Table, - Classifications: classification.Result{ - "age": { - "AGE": {Name: "AGE"}, - "CVV": {Name: "CVV"}, - }, - "social_sec_num": { - "SSN": {Name: "SSN"}, - }, - "credit_card_num": { - "CCN": {Name: "CCN"}, - }, - }, - }, - } - actual, err := classifySamples(ctx, []discovery.Sample{sample}, classifier) - require.NoError(t, err) - require.Len(t, actual, len(expected)) - for i := range actual { - requireClassifiedTableEqual(t, expected[i], actual[i]) - } -} - -func Test_classifySamples_MultipleTables(t *testing.T) { - ctx := context.Background() - meta1 := discovery.SampleMetadata{ - Repo: "repo1", - Database: "db1", - Schema: "schema1", - Table: "table1", - } - meta2 := discovery.SampleMetadata{ - Repo: "repo2", - Database: "db2", - Schema: "schema2", - Table: "table2", - } - - samples := []discovery.Sample{ - { - Metadata: meta1, - Results: []discovery.SampleResult{ - { - "age": "52", - "social_sec_num": "512-23-4258", - "credit_card_num": "4111111111111111", - }, - { - "age": "101", - "social_sec_num": "foobarbaz", - "credit_card_num": "4111111111111111", - }, - }, - }, - { - Metadata: meta2, - Results: []discovery.SampleResult{ - { - "fullname": "John Doe", - "dob": "2000-01-01", - "random": "foobarbaz", - }, - }, - }, - } - - classifier := mocks.NewClassifier(t) - // Need to explicitly convert it to a map because Mockery isn't smart 
enough - // to infer the type. - classifier.EXPECT().Classify(ctx, map[string]any(samples[0].Results[0])).Return( - classification.Result{ - "age": { - "AGE": {Name: "AGE"}, - }, - "social_sec_num": { - "SSN": {Name: "SSN"}, - }, - "credit_card_num": { - "CCN": {Name: "CCN"}, - }, - }, - nil, - ) - classifier.EXPECT().Classify(ctx, map[string]any(samples[0].Results[1])).Return( - classification.Result{ - "age": { - "AGE": {Name: "AGE"}, - "CVV": {Name: "CVV"}, - }, - "credit_card_num": { - "CCN": {Name: "CCN"}, - }, - }, - nil, - ) - classifier.EXPECT().Classify(ctx, map[string]any(samples[1].Results[0])).Return( - classification.Result{ - "fullname": { - "FULL_NAME": {Name: "FULL_NAME"}, - }, - "dob": { - "DOB": {Name: "DOB"}, - }, - }, - nil, - ) - - expected := []classification.ClassifiedTable{ - { - Repo: meta1.Repo, - Database: meta1.Database, - Schema: meta1.Schema, - Table: meta1.Table, - Classifications: classification.Result{ - "age": { - "AGE": {Name: "AGE"}, - "CVV": {Name: "CVV"}, - }, - "social_sec_num": { - "SSN": {Name: "SSN"}, - }, - "credit_card_num": { - "CCN": {Name: "CCN"}, - }, - }, - }, - { - Repo: meta2.Repo, - Database: meta2.Database, - Schema: meta2.Schema, - Table: meta2.Table, - Classifications: classification.Result{ - "fullname": { - "FULL_NAME": {Name: "FULL_NAME"}, - }, - "dob": { - "DOB": {Name: "DOB"}, - }, - }, - }, - } - actual, err := classifySamples(ctx, samples, classifier) - require.NoError(t, err) - require.Len(t, actual, len(expected)) - for i := range actual { - requireClassifiedTableEqual(t, expected[i], actual[i]) - } -} - -func requireClassifiedTableEqual(t *testing.T, expected, actual classification.ClassifiedTable) { - require.Equal(t, expected.Repo, actual.Repo) - require.Equal(t, expected.Database, actual.Database) - require.Equal(t, expected.Schema, actual.Schema) - require.Equal(t, expected.Table, actual.Table) - requireResultEqual(t, expected.Classifications, actual.Classifications) -} - -func 
requireResultEqual(t *testing.T, want, got classification.Result) { - require.Len(t, got, len(want)) - for k, v := range want { - gotSet, ok := got[k] - require.Truef(t, ok, "missing attribute %s", k) - requireLabelSetEqual(t, v, gotSet) - } -} - -func requireLabelSetEqual(t *testing.T, want, got classification.LabelSet) { - require.Len(t, got, len(want)) - for k, v := range want { - gotLbl, ok := got[k] - require.Truef(t, ok, "missing label %s", k) - require.Equal(t, v.Name, gotLbl.Name) - } -} diff --git a/scan/scanner.go b/scan/scanner.go new file mode 100644 index 0000000..8a39f4f --- /dev/null +++ b/scan/scanner.go @@ -0,0 +1,89 @@ +package scan + +import ( + "context" + "errors" + "time" + + "github.com/cyralinc/dmap/classification" +) + +// Scanner provides an API to scan cloud environments. It should be +// implemented for a specific cloud provider (e.g. AWS, GCP, etc.). It defines +// the Scan method responsible for scanning the existing data repositories of +// the corresponding cloud provider environment. +type Scanner interface { + Scan(ctx context.Context) (*ScanResults, error) +} + +// RepoScanner is a scanner that scans a data repository for sensitive data. +type RepoScanner interface { + Scan(ctx context.Context) (*RepoScanResults, error) +} + +// RepoScanResults is the result of a repository scan. +type RepoScanResults struct { + Labels []classification.Label `json:"labels"` + Classifications []Classification `json:"classifications"` +} + +// TODO: godoc -ccampo 2024-04-03 +type Classification struct { + // AttributePath is the full path of the data repository attribute + // (e.g. the column). Each element corresponds to a component, in increasing + // order of granularity (e.g. [database, schema, table, column]). + AttributePath []string `json:"attributePath"` + // Labels is the set of labels that the attribute was classified as. 
+ Labels classification.LabelSet `json:"labels"` +} + +// RepoType defines the AWS data repository types supported (e.g. RDS, Redshift, +// DynamoDB, etc). +type RepoType string + +const ( + RepoTypeRDS RepoType = "TYPE_RDS" + RepoTypeRedshift RepoType = "TYPE_REDSHIFT" + RepoTypeDynamoDB RepoType = "TYPE_DYNAMODB" + RepoTypeS3 RepoType = "TYPE_S3" + RepoTypeDocumentDB RepoType = "TYPE_DOCUMENTDB" +) + +// Repository represents a scanned data repository. +type Repository struct { + Id string + Name string + Type RepoType + CreatedAt time.Time + Tags []string + Properties any +} + +// ScanResults represents the results of a repository scan, including all the +// data repositories that were scanned. The map key is the repository ID and the +// value is the repository itself. +type ScanResults struct { + Repositories map[string]Repository +} + +// ScanError is an error type that represents a collection of errors that +// occurred during the scanning process. +type ScanError struct { + Errs []error +} + +// Error returns a string representation of the error. +func (e *ScanError) Error() string { + if e == nil { + return "" + } + return errors.Join(e.Errs...).Error() +} + +// Unwrap returns the list of errors that occurred during the scanning process. +func (e *ScanError) Unwrap() []error { + if e == nil { + return nil + } + return e.Errs +} diff --git a/sql/classify.go b/sql/classify.go new file mode 100644 index 0000000..30dd914 --- /dev/null +++ b/sql/classify.go @@ -0,0 +1,61 @@ +package sql + +import ( + "context" + "fmt" + "maps" + "strings" + + "github.com/cyralinc/dmap/classification" + "github.com/cyralinc/dmap/scan" +) + +// classifySamples uses the provided classifiers to classify the sample data +// passed via the "samples" parameter. It is mostly a helper function which +// loops through each repository.Sample, retrieves the attribute names and +// values of that sample, passes them to Classifier.Classify, and then +// aggregates the results. 
Please see the documentation for Classifier and its +// Classify method for more details. The returned slice represents all the +// unique classification results for a given sample set. +func classifySamples( + ctx context.Context, + samples []Sample, + classifier classification.Classifier, +) ([]scan.Classification, error) { + uniqueResults := make(map[string]scan.Classification) + for _, sample := range samples { + // Classify each sampled row and combine the results. + for _, sampleResult := range sample.Results { + res, err := classifier.Classify(ctx, sampleResult) + if err != nil { + return nil, fmt.Errorf("error classifying sample: %w", err) + } + for attr, labels := range res { + attrPath := append(sample.TablePath, attr) + key := pathKey(attrPath) + result, ok := uniqueResults[key] + if !ok { + uniqueResults[key] = scan.Classification{ + AttributePath: attrPath, + Labels: labels, + } + } else { + // Merge the labels from the new result into the existing result. + maps.Copy(result.Labels, labels) + } + } + } + } + // Convert the map of unique results to a slice. + results := make([]scan.Classification, 0, len(uniqueResults)) + for _, result := range uniqueResults { + results = append(results, result) + } + return results, nil +} + +func pathKey(path []string) string { + // U+2063 is an invisible separator. It is used here to ensure that the + // pathKey is unique and does not conflict with any of the path elements. 
+ return strings.Join(path, "\u2063") +} diff --git a/sql/classify_test.go b/sql/classify_test.go new file mode 100644 index 0000000..ee63af6 --- /dev/null +++ b/sql/classify_test.go @@ -0,0 +1,159 @@ +package sql + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/cyralinc/dmap/classification" + "github.com/cyralinc/dmap/scan" + "github.com/cyralinc/dmap/sql/mocks" +) + +func Test_classifySamples_SingleSample(t *testing.T) { + ctx := context.Background() + sample := Sample{ + TablePath: []string{"db", "schema", "table"}, + Results: []SampleResult{ + { + "age": "52", + "social_sec_num": "512-23-4258", + "credit_card_num": "4111111111111111", + }, + { + "age": "101", + "social_sec_num": "foobarbaz", + "credit_card_num": "4111111111111111", + }, + }, + } + classifier := mocks.NewClassifier(t) + // Need to explicitly convert it to a map because Mockery isn't smart enough + // to infer the type. + classifier.EXPECT().Classify(ctx, map[string]any(sample.Results[0])).Return( + classification.Result{ + "age": lblSet("AGE"), + "social_sec_num": lblSet("SSN"), + "credit_card_num": lblSet("CCN"), + }, + nil, + ) + classifier.EXPECT().Classify(ctx, map[string]any(sample.Results[1])).Return( + classification.Result{ + "age": lblSet("AGE", "CVV"), + "credit_card_num": lblSet("CCN"), + }, + nil, + ) + + expected := []scan.Classification{ + { + AttributePath: append(sample.TablePath, "age"), + Labels: lblSet("AGE", "CVV"), + }, + { + AttributePath: append(sample.TablePath, "social_sec_num"), + Labels: lblSet("SSN"), + }, + { + AttributePath: append(sample.TablePath, "credit_card_num"), + Labels: lblSet("CCN"), + }, + } + actual, err := classifySamples(ctx, []Sample{sample}, classifier) + require.NoError(t, err) + require.ElementsMatch(t, expected, actual) +} + +func Test_classifySamples_MultipleSamples(t *testing.T) { + ctx := context.Background() + samples := []Sample{ + { + TablePath: []string{"db1", "schema1", "table1"}, + Results: 
[]SampleResult{ + { + "age": "52", + "social_sec_num": "512-23-4258", + "credit_card_num": "4111111111111111", + }, + { + "age": "101", + "social_sec_num": "foobarbaz", + "credit_card_num": "4111111111111111", + }, + }, + }, + { + TablePath: []string{"db2", "schema2", "table2"}, + Results: []SampleResult{ + { + "fullname": "John Doe", + "dob": "2000-01-01", + "random": "foobarbaz", + }, + }, + }, + } + + classifier := mocks.NewClassifier(t) + // Need to explicitly convert it to a map because Mockery isn't smart enough + // to infer the type. + classifier.EXPECT().Classify(ctx, map[string]any(samples[0].Results[0])).Return( + classification.Result{ + "age": lblSet("AGE"), + "social_sec_num": lblSet("SSN"), + "credit_card_num": lblSet("CCN"), + }, + nil, + ) + classifier.EXPECT().Classify(ctx, map[string]any(samples[0].Results[1])).Return( + classification.Result{ + "age": lblSet("AGE", "CVV"), + "credit_card_num": lblSet("CCN"), + }, + nil, + ) + classifier.EXPECT().Classify(ctx, map[string]any(samples[1].Results[0])).Return( + classification.Result{ + "fullname": lblSet("FULL_NAME"), + "dob": lblSet("DOB"), + }, + nil, + ) + + expected := []scan.Classification{ + { + AttributePath: append(samples[0].TablePath, "age"), + Labels: lblSet("AGE", "CVV"), + }, + { + AttributePath: append(samples[0].TablePath, "social_sec_num"), + Labels: lblSet("SSN"), + }, + { + AttributePath: append(samples[0].TablePath, "credit_card_num"), + Labels: lblSet("CCN"), + }, + { + AttributePath: append(samples[1].TablePath, "fullname"), + Labels: lblSet("FULL_NAME"), + }, + { + AttributePath: append(samples[1].TablePath, "dob"), + Labels: lblSet("DOB"), + }, + } + actual, err := classifySamples(ctx, samples, classifier) + require.NoError(t, err) + require.ElementsMatch(t, expected, actual) +} + +func lblSet(labels ...string) classification.LabelSet { + set := make(classification.LabelSet) + for _, label := range labels { + set[label] = struct { + }{} + } + return set +} diff --git 
a/sql/config.go b/sql/config.go new file mode 100644 index 0000000..a8785b9 --- /dev/null +++ b/sql/config.go @@ -0,0 +1,84 @@ +package sql + +import ( + "fmt" +) + +const configConnOpts = "connection-string-args" + +// RepoConfig is the necessary configuration to connect to a SQL data repository. +type RepoConfig struct { + // Host is the hostname of the database. + Host string + // Port is the port of the database. + Port uint16 + // User is the username to connect to the database. + User string + // Password is the password to connect to the database. + Password string + // Database is the name of the database to connect to. + Database string + // MaxOpenConns is the maximum number of open connections to the database. + MaxOpenConns uint + // Advanced is a map of advanced configuration options. + Advanced map[string]any +} + +// FetchAdvancedConfigString fetches a map in the repo advanced configuration, +// for a given repo and set of parameters. Example: +// +// repo-advanced: +// +// snowflake: +// account: exampleAccount +// role: exampleRole +// warehouse: exampleWarehouse +// +// Calling FetchAdvancedConfigString(cfg, "snowflake", +// []string{"account", "role", "warehouse"}) returns the map +// +// {"account": "exampleAccount", "role": "exampleRole", "warehouse": +// "exampleWarehouse"} +// +// The suffix 'String' means that the values of the map are strings. This gives +// room to have FetchAdvancedConfigList or FetchAdvancedConfigMap, for example, +// without name conflicts. 
+func FetchAdvancedConfigString( + cfg RepoConfig, + repo string, + parameters []string, +) (map[string]string, error) { + advancedCfg, err := getAdvancedConfig(cfg, repo) + if err != nil { + return nil, err + } + repoSpecificMap := make(map[string]string) + for _, key := range parameters { + var valInterface any + var val string + var ok bool + if valInterface, ok = advancedCfg[key]; !ok { + return nil, fmt.Errorf("unable to find '%s' in %s advanced config", key, repo) + } + if val, ok = valInterface.(string); !ok { + return nil, fmt.Errorf("'%s' in %s config must be a string", key, repo) + } + repoSpecificMap[key] = val + } + return repoSpecificMap, nil +} + +// getAdvancedConfig gets the Advanced field in a repo config and converts it to +// a map[string]any. In every step, it checks for error and generates +// nice messages. +func getAdvancedConfig(cfg RepoConfig, repo string) (map[string]any, error) { + advancedCfgInterface, ok := cfg.Advanced[repo] + if !ok { + return nil, fmt.Errorf("unable to find '%s' in advanced config", repo) + } + advancedCfg, ok := advancedCfgInterface.(map[string]any) + if !ok { + return nil, fmt.Errorf("'%s' in advanced config is not a map", repo) + } + return advancedCfg, nil +} diff --git a/sql/config_test.go b/sql/config_test.go new file mode 100644 index 0000000..e92a5c8 --- /dev/null +++ b/sql/config_test.go @@ -0,0 +1,77 @@ +package sql + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAdvancedConfigSucc(t *testing.T) { + sampleCfg := RepoConfig{ + Advanced: map[string]any{ + "snowflake": map[string]any{ + "account": "exampleAccount", + "role": "exampleRole", + "warehouse": "exampleWarehouse", + }, + }, + } + repoSpecificMap, err := FetchAdvancedConfigString( + sampleCfg, + "snowflake", []string{"account", "role", "warehouse"}, + ) + require.NoError(t, err) + require.EqualValues( + t, repoSpecificMap, map[string]string{ + "account": "exampleAccount", + "role": "exampleRole", + "warehouse": 
"exampleWarehouse", + }, + ) +} + +func TestAdvancedConfigMissing(t *testing.T) { + // Without the snowflake config at all + sampleCfg := RepoConfig{ + Advanced: map[string]any{}, + } + _, err := FetchAdvancedConfigString( + sampleCfg, + "snowflake", []string{"account", "role", "warehouse"}, + ) + require.Error(t, err) + + sampleCfg = RepoConfig{ + Advanced: map[string]any{ + "snowflake": map[string]any{ + // Missing account + + "role": "exampleRole", + "warehouse": "exampleWarehouse", + }, + }, + } + _, err = FetchAdvancedConfigString( + sampleCfg, + "snowflake", []string{"account", "role", "warehouse"}, + ) + require.Error(t, err) +} + +func TestAdvancedConfigMalformed(t *testing.T) { + sampleCfg := RepoConfig{ + Advanced: map[string]any{ + "snowflake": map[string]any{ + // Let's give a _list_ of things + "account": []string{"account1", "account2"}, + "role": []string{"role1", "role2"}, + "warehouse": []string{"warehouse1", "warehouse2"}, + }, + }, + } + _, err := FetchAdvancedConfigString( + sampleCfg, + "snowflake", []string{"account", "role", "warehouse"}, + ) + require.Error(t, err) +} diff --git a/discovery/denodo.go b/sql/denodo.go similarity index 65% rename from discovery/denodo.go rename to sql/denodo.go index 2075db5..9e7a730 100644 --- a/discovery/denodo.go +++ b/sql/denodo.go @@ -1,4 +1,4 @@ -package discovery +package sql import ( "context" @@ -24,19 +24,19 @@ const ( "CATALOG_VDP_METADATA_VIEWS()" ) -// DenodoRepository is a sql.SQLRepository implementation for Denodo. +// DenodoRepository is a Repository implementation for Denodo. type DenodoRepository struct { - // The majority of the sql.SQLRepository functionality is delegated to + // The majority of the Repository functionality is delegated to // a generic SQL repository instance. 
- genericSqlRepo *GenericRepository + generic *GenericRepository } -// DenodoRepository implements sql.SQLRepository -var _ SQLRepository = (*DenodoRepository)(nil) +// DenodoRepository implements sql.Repository +var _ Repository = (*DenodoRepository)(nil) // NewDenodoRepository is the constructor for sql. func NewDenodoRepository(cfg RepoConfig) (*DenodoRepository, error) { - pgCfg, err := ParsePostgresConfig(cfg) + pgCfg, err := parsePostgresConfig(cfg) if err != nil { return nil, fmt.Errorf("unable to parse postgres config: %w", err) } @@ -52,19 +52,11 @@ func NewDenodoRepository(cfg RepoConfig) (*DenodoRepository, error) { cfg.Database, pgCfg.ConnOptsStr, ) - sqlRepo, err := NewGenericRepository( - cfg.Host, - RepoTypePostgres, - cfg.Database, - connStr, - cfg.MaxOpenConns, - cfg.IncludePaths, - cfg.ExcludePaths, - ) + generic, err := NewGenericRepository(RepoTypePostgres, cfg.Database, connStr, cfg.MaxOpenConns) if err != nil { return nil, fmt.Errorf("could not instantiate generic sql repository: %w", err) } - return &DenodoRepository{genericSqlRepo: sqlRepo}, nil + return &DenodoRepository{generic: generic}, nil } // ListDatabases is left unimplemented for Denodo, because Denodo doesn't have @@ -74,41 +66,40 @@ func (r *DenodoRepository) ListDatabases(_ context.Context) ([]string, error) { } // Introspect delegates introspection to GenericRepository. See -// SQLRepository.Introspect and GenericRepository.IntrospectWithQuery for more +// Repository.Introspect and GenericRepository.IntrospectWithQuery for more // details. -func (r *DenodoRepository) Introspect(ctx context.Context) (*Metadata, error) { - return r.genericSqlRepo.IntrospectWithQuery(ctx, DenodoIntrospectQuery) +func (r *DenodoRepository) Introspect(ctx context.Context, params IntrospectParameters) (*Metadata, error) { + return r.generic.IntrospectWithQuery(ctx, DenodoIntrospectQuery, params) } // SampleTable delegates sampling to GenericRepository, using a Denodo-specific -// table sample query. 
See SQLRepository.SampleTable and +// table sample query. See Repository.SampleTable and // GenericRepository.SampleTableWithQuery for more details. func (r *DenodoRepository) SampleTable( ctx context.Context, - meta *TableMetadata, params SampleParameters, ) (Sample, error) { // Denodo uses double-quotes to quote identifiers - attrStr := meta.QuotedAttributeNamesString("\"") + attrStr := params.Metadata.QuotedAttributeNamesString("\"") // The postgres driver is currently unable to properly send the // parameters of a prepared statement to Denodo. Therefore, instead of // building a prepared statement, we populate the query string before // sending it to the driver. query := fmt.Sprintf( "SELECT %s FROM %s.%s OFFSET %d ROWS LIMIT %d", - attrStr, meta.Schema, meta.Name, params.Offset, params.SampleSize, + attrStr, params.Metadata.Schema, params.Metadata.Name, params.Offset, params.SampleSize, ) - return r.genericSqlRepo.SampleTableWithQuery(ctx, meta, query) + return r.generic.SampleTableWithQuery(ctx, query, params) } -// Ping delegates the ping to GenericRepository. See SQLRepository.Ping and +// Ping delegates the ping to GenericRepository. See Repository.Ping and // GenericRepository.Ping for more details. func (r *DenodoRepository) Ping(ctx context.Context) error { - return r.genericSqlRepo.Ping(ctx) + return r.generic.Ping(ctx) } -// Close delegates the close to GenericRepository. See SQLRepository.Close and +// Close delegates the close to GenericRepository. See Repository.Close and // GenericRepository.Close for more details. func (r *DenodoRepository) Close() error { - return r.genericSqlRepo.Close() + return r.generic.Close() } diff --git a/sql/doc.go b/sql/doc.go new file mode 100644 index 0000000..e1f8447 --- /dev/null +++ b/sql/doc.go @@ -0,0 +1,14 @@ +// Package sql provides mechanisms to perform database introspection and +// data discovery on various SQL data repositories. 
+// +// Additionally, the Repository interface provides an API for performing +// database introspection and data discovery on SQL databases. It encapsulates +// the concept of a Dmap data SQL repository. All out-of-the-box Repository +// implementations are included in their own files named after the repository +// type, e.g. mysql.go, postgres.go, etc. +// +// Registry provides an API for registering and constructing Repository +// implementations within an application. There is a global DefaultRegistry +// which has all out-of-the-box Repository implementations registered to it +// by default. +package sql diff --git a/scan/gen.go b/sql/gen.go similarity index 63% rename from scan/gen.go rename to sql/gen.go index a6b6284..c544e77 100644 --- a/scan/gen.go +++ b/sql/gen.go @@ -1,5 +1,6 @@ -package scan +package sql // Mock generation - see https://vektra.github.io/mockery/ //go:generate mockery --with-expecter --srcpkg=github.com/cyralinc/dmap/classification --name=Classifier --filename=mock_classifier.go +//go:generate mockery --with-expecter --inpackage --name=Repository --filename=mock_repository_test.go diff --git a/discovery/generic.go b/sql/generic.go similarity index 68% rename from discovery/generic.go rename to sql/generic.go index 92db79c..1c870fd 100644 --- a/discovery/generic.go +++ b/sql/generic.go @@ -1,4 +1,4 @@ -package discovery +package sql import ( "context" @@ -34,27 +34,24 @@ const ( ) // GenericRepository implements generic SQL functionalities that work for a -// subset of ANSI SQL compatible databases. Many SQLRepository implementations may +// subset of ANSI SQL compatible databases. Many Repository implementations may // partially or fully delegate to this implementation. In that respect, it acts // somewhat as a base implementation which can be used by SQL-compatible // repositories. 
Note that while GenericRepository is an implementation of -// the SQLRepository interface, GenericRepository is meant to be used as a building -// block for other SQLRepository implementations, rather than as a standalone -// implementation. Specifically, the SQLRepository.ListDatabases method is left +// the Repository interface, GenericRepository is meant to be used as a building +// block for other Repository implementations, rather than as a standalone +// implementation. Specifically, the Repository.ListDatabases method is left // un-implemented, since there is no standard way to list databases across // different SQL database platforms. It does however provide the // ListDatabasesWithQuery method, which dependent implementations can use to // provide a custom query to list databases. type GenericRepository struct { - repoName string - repoType string - database string - db *sql.DB - includePaths []glob.Glob - excludePaths []glob.Glob + repoType string + database string + db *sql.DB } -var _ SQLRepository = (*GenericRepository)(nil) +var _ Repository = (*GenericRepository)(nil) // NewGenericRepository is a constructor for the GenericRepository type. It // opens a database handle for a given repoType and returns a pointer to a new @@ -64,15 +61,7 @@ var _ SQLRepository = (*GenericRepository)(nil) // connections to the database. The repoIncludePaths and repoExcludePaths // parameters are used to filter the tables and columns that are introspected by // the repository. 
-func NewGenericRepository( - repoName, - repoType, - database, - connStr string, - maxOpenConns uint, - repoIncludePaths, - repoExcludePaths []glob.Glob, -) ( +func NewGenericRepository(repoType, database, connStr string, maxOpenConns uint) ( *GenericRepository, error, ) { @@ -81,20 +70,16 @@ func NewGenericRepository( return nil, fmt.Errorf("error retrieving DB handle for repo type %s: %w", repoType, err) } return &GenericRepository{ - repoName: repoName, - repoType: repoType, - database: database, - db: db, - includePaths: repoIncludePaths, - excludePaths: repoExcludePaths, + repoType: repoType, + database: database, + db: db, }, nil } // NewGenericRepositoryFromDB instantiate a new GenericRepository based on a // given sql.DB handle. -func NewGenericRepositoryFromDB(repoName, repoType, database string, db *sql.DB) *GenericRepository { +func NewGenericRepositoryFromDB(repoType, database string, db *sql.DB) *GenericRepository { return &GenericRepository{ - repoName: repoName, repoType: repoType, database: database, db: db, @@ -139,8 +124,11 @@ func (r *GenericRepository) ListDatabasesWithQuery( } // Introspect calls IntrospectWithQuery with a default query string -func (r *GenericRepository) Introspect(ctx context.Context) (*Metadata, error) { - return r.IntrospectWithQuery(ctx, GenericIntrospectQuery) +func (r *GenericRepository) Introspect( + ctx context.Context, + params IntrospectParameters, +) (*Metadata, error) { + return r.IntrospectWithQuery(ctx, GenericIntrospectQuery, params) } // IntrospectWithQuery executes a query against the information_schema table in @@ -156,64 +144,55 @@ func (r *GenericRepository) Introspect(ctx context.Context) (*Metadata, error) { func (r *GenericRepository) IntrospectWithQuery( ctx context.Context, query string, + params IntrospectParameters, ) (*Metadata, error) { log.Tracef("Query: %s", query) rows, err := r.db.QueryContext(ctx, query) if err != nil { - return nil, err + return nil, fmt.Errorf("error performing introspect 
query: %w", err) } defer func() { _ = rows.Close() }() - - repoMeta, err := newMetadataFromQueryResult( - r.repoType, r.repoName, - r.database, r.includePaths, r.excludePaths, rows, - ) - if err != nil { - return nil, err - } - return repoMeta, nil + return newMetadataFromQueryResult(r.database, params.IncludePaths, params.ExcludePaths, rows) } // SampleTable samples the table referenced by the TableMetadata meta parameter // by issuing a standard, ANSI-compatible SELECT query to the database. All // attributes of the table are selected, and are quoted using double quotes. See -// SQLRepository.SampleTable for more details. +// Repository.SampleTable for more details. func (r *GenericRepository) SampleTable( ctx context.Context, - meta *TableMetadata, params SampleParameters, ) (Sample, error) { // ANSI SQL uses double-quotes to quote identifiers - attrStr := meta.QuotedAttributeNamesString("\"") - query := fmt.Sprintf(GenericSampleQueryTemplate, attrStr, meta.Schema, meta.Name) - return r.SampleTableWithQuery(ctx, meta, query, params.SampleSize, params.Offset) + attrStr := params.Metadata.QuotedAttributeNamesString("\"") + query := fmt.Sprintf(GenericSampleQueryTemplate, attrStr, params.Metadata.Schema, params.Metadata.Name) + return r.SampleTableWithQuery(ctx, query, params) } // SampleTableWithQuery calls SampleTable with a custom SQL query. Any // placeholder parameters in the query should be passed via params. func (r *GenericRepository) SampleTableWithQuery( ctx context.Context, - meta *TableMetadata, query string, - params ...any, + params SampleParameters, ) (Sample, error) { log.Tracef("Query: %s", query) - rows, err := r.db.QueryContext(ctx, query, params...) 
+ rows, err := r.db.QueryContext(ctx, query, params.SampleSize, params.Offset) if err != nil { return Sample{}, - fmt.Errorf("error sampling %s.%s.%s: %w", r.database, meta.Schema, meta.Name, err) + fmt.Errorf( + "error sampling database %s, schema %s, table %s: %w", + r.database, + params.Metadata.Schema, + params.Metadata.Name, + err, + ) } defer func() { _ = rows.Close() }() - sample := Sample{ - Metadata: SampleMetadata{ - Repo: r.repoName, - Database: r.database, - Schema: meta.Schema, - Table: meta.Name, - }, + TablePath: []string{r.database, params.Metadata.Schema, params.Metadata.Name}, } - + // Iterate the row set and append each row to the sample results. for rows.Next() { data, err := getCurrentRowAsMap(rows) if err != nil { @@ -221,13 +200,13 @@ func (r *GenericRepository) SampleTableWithQuery( } sample.Results = append(sample.Results, data) } - - // Something broke while iterating the row set - err = rows.Err() - if err != nil { + if err := rows.Err(); err != nil { + // Something broke while iterating the row set. return Sample{}, fmt.Errorf("error iterating sample data row set: %w", err) } - + if len(sample.Results) == 0 { + return Sample{}, nil + } return sample, nil } @@ -265,61 +244,6 @@ func newDbHandle(repoType, connStr string, maxOpenConns uint) (*sql.DB, error) { return db, nil } -// newMetadataFromQueryResult builds the repository metadata from the results -// of a query to the INFORMATION_SCHEMA columns view. 
-func newMetadataFromQueryResult( - repoType, repoName, db string, - includePaths, excludePaths []glob.Glob, rows *sql.Rows, -) ( - *Metadata, - error, -) { - repo := NewMetadata(repoType, repoName, db) - - for rows.Next() { - var attr AttributeMetadata - err := rows.Scan(&attr.Schema, &attr.Table, &attr.Name, &attr.DataType) - if err != nil { - return nil, err - } - - // skip tables that match excludePaths or does not match includePaths - log.Tracef("checking if %s.%s.%s matches excludePaths %s\n", db, attr.Schema, attr.Table, excludePaths) - if matchPathPatterns(db, attr.Schema, attr.Table, excludePaths) { - continue - } - log.Tracef("checking if %s.%s.%s matches includePaths: %s\n", db, attr.Schema, attr.Table, includePaths) - if !matchPathPatterns(db, attr.Schema, attr.Table, includePaths) { - continue - } - - // SchemaMetadata exists - add a table if necessary - if schema, ok := repo.Schemas[attr.Schema]; ok { - // TableMetadata exists - just append the attribute - if table, ok := schema.Tables[attr.Table]; ok { - table.Attributes = append(table.Attributes, &attr) - } else { // First time seeing this table - table := NewTableMetadata(attr.Schema, attr.Table) - table.Attributes = append(table.Attributes, &attr) - schema.Tables[attr.Table] = table - } - } else { // SchemaMetadata doesn't exist - create it - table := NewTableMetadata(attr.Schema, attr.Table) - table.Attributes = append(table.Attributes, &attr) - schema := NewSchemaMetadata(attr.Schema) - schema.Tables[attr.Table] = table - repo.Schemas[attr.Schema] = schema - } - } - - // Something broke while iterating the row set - if err := rows.Err(); err != nil { - return nil, err - } - - return repo, nil -} - // getCurrentRowAsMap transforms the current row referenced by a sql.Rows row // set into a map where the key is the column name and the value is the column // value. 
It is effectively an alternative to the sql.Rows.Scan method, where it diff --git a/discovery/generic_test.go b/sql/generic_test.go similarity index 84% rename from discovery/generic_test.go rename to sql/generic_test.go index ac13df5..86bab72 100644 --- a/discovery/generic_test.go +++ b/sql/generic_test.go @@ -1,4 +1,4 @@ -package discovery +package sql import ( "context" @@ -7,7 +7,6 @@ import ( "github.com/DATA-DOG/go-sqlmock" "github.com/gobwas/glob" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -16,16 +15,13 @@ func Test_Introspect_IsSuccessful(t *testing.T) { require.NoError(t, err) defer func() { _ = db.Close() }() - repoName := "testRepo" repoType := "genericSql" database := "exampleDb" repo := GenericRepository{ - repoName: repoName, - repoType: repoType, - database: database, - db: db, - includePaths: []glob.Glob{glob.MustCompile("exampleDb.*")}, + repoType: repoType, + database: database, + db: db, } cols := []string{ @@ -46,11 +42,12 @@ func Test_Introspect_IsSuccessful(t *testing.T) { ctx := context.Background() - meta, err := repo.Introspect(ctx) + params := IntrospectParameters{ + IncludePaths: []glob.Glob{glob.MustCompile("exampleDb.*")}, + } + meta, err := repo.Introspect(ctx, params) expectedMetadata := Metadata{ - Name: repoName, - RepoType: repoType, Database: database, Schemas: map[string]*SchemaMetadata{ "schema1": { @@ -108,8 +105,8 @@ func Test_Introspect_IsSuccessful(t *testing.T) { }, } - assert.NoError(t, err) - assert.EqualValues(t, &expectedMetadata, meta) + require.NoError(t, err) + require.EqualValues(t, &expectedMetadata, meta) } func Test_Introspect_QueryError(t *testing.T) { @@ -117,12 +114,10 @@ func Test_Introspect_QueryError(t *testing.T) { require.NoError(t, err) defer func() { _ = db.Close() }() - repoName := "testRepo" repoType := "genericSql" database := "exampleDb" repo := GenericRepository{ - repoName: repoName, repoType: repoType, database: database, db: db, @@ -135,10 +130,10 @@ func 
Test_Introspect_QueryError(t *testing.T) { ctx := context.Background() - meta, err := repo.Introspect(ctx) + meta, err := repo.Introspect(ctx, IntrospectParameters{}) - assert.Nil(t, meta) - assert.EqualError(t, err, expectedErr.Error()) + require.Nil(t, meta) + require.ErrorIs(t, err, expectedErr) } func Test_Introspect_RowError(t *testing.T) { @@ -146,12 +141,10 @@ func Test_Introspect_RowError(t *testing.T) { require.NoError(t, err) defer func() { _ = db.Close() }() - repoName := "testRepo" repoType := "genericSql" database := "exampleDb" repo := GenericRepository{ - repoName: repoName, repoType: repoType, database: database, db: db, @@ -168,17 +161,17 @@ func Test_Introspect_RowError(t *testing.T) { rows := sqlmock.NewRows(cols). AddRow("schema1", "table1", "column1", "varchar"). - RowError(0, errors.New("dummy error")) + RowError(0, expectedErr) mock.ExpectQuery("SELECT (.+) FROM information_schema.columns WHERE (.+)"). WillReturnRows(rows) ctx := context.Background() - meta, err := repo.Introspect(ctx) + meta, err := repo.Introspect(ctx, IntrospectParameters{}) - assert.Nil(t, meta) - assert.EqualError(t, err, expectedErr.Error()) + require.Nil(t, meta) + require.ErrorIs(t, err, expectedErr) } func Test_Introspect_Filtered(t *testing.T) { @@ -186,22 +179,13 @@ func Test_Introspect_Filtered(t *testing.T) { require.NoError(t, err) defer func() { _ = db.Close() }() - repoName := "testRepo" repoType := "genericSql" database := "exampleDb" repo := GenericRepository{ - repoName: repoName, repoType: repoType, database: database, db: db, - includePaths: []glob.Glob{ - glob.MustCompile("exampleDb.schema1.*"), - glob.MustCompile("exampleDb.*.table1"), - }, - excludePaths: []glob.Glob{ - glob.MustCompile("exampleDb.schema3.*"), - }, } cols := []string{ @@ -226,11 +210,18 @@ func Test_Introspect_Filtered(t *testing.T) { ctx := context.Background() - meta, err := repo.Introspect(ctx) + params := IntrospectParameters{ + IncludePaths: []glob.Glob{ + 
glob.MustCompile("exampleDb.schema1.*"), + glob.MustCompile("exampleDb.*.table1"), + }, + ExcludePaths: []glob.Glob{ + glob.MustCompile("exampleDb.schema3.*"), + }, + } + meta, err := repo.Introspect(ctx, params) expectedMetadata := Metadata{ - Name: repoName, - RepoType: repoType, Database: database, Schemas: map[string]*SchemaMetadata{ "schema1": { @@ -294,6 +285,6 @@ func Test_Introspect_Filtered(t *testing.T) { }, } - assert.NoError(t, err) - assert.EqualValues(t, &expectedMetadata, meta) + require.NoError(t, err) + require.EqualValues(t, &expectedMetadata, meta) } diff --git a/discovery/metadata.go b/sql/metadata.go similarity index 58% rename from discovery/metadata.go rename to sql/metadata.go index ca71935..622e8ef 100644 --- a/discovery/metadata.go +++ b/sql/metadata.go @@ -1,7 +1,12 @@ -package discovery +package sql import ( + "database/sql" + "fmt" "strings" + + "github.com/gobwas/glob" + log "github.com/sirupsen/logrus" ) // Metadata represents the structure of a SQL database. The traditional @@ -10,7 +15,6 @@ import ( // those cases, the 'Database' field is expected to be an empty string. See: // https://stackoverflow.com/a/17943883 type Metadata struct { - Name string RepoType string Database string Schemas map[string]*SchemaMetadata @@ -60,15 +64,66 @@ type AttributeMetadata struct { // NewMetadata creates a new Metadata object with the given repository type, // repository name, and database name, with an empty map of schemas. -func NewMetadata(repoType, repoName, database string) *Metadata { +func NewMetadata(database string) *Metadata { return &Metadata{ - Name: repoName, - RepoType: repoType, Database: database, Schemas: make(map[string]*SchemaMetadata), } } +// newMetadataFromQueryResult builds the repository metadata from the results +// of a query to the INFORMATION_SCHEMA columns view. 
+func newMetadataFromQueryResult( + db string, + includePaths, excludePaths []glob.Glob, + rows *sql.Rows, +) ( + *Metadata, + error, +) { + repo := NewMetadata(db) + for rows.Next() { + var attr AttributeMetadata + if err := rows.Scan(&attr.Schema, &attr.Table, &attr.Name, &attr.DataType); err != nil { + return nil, fmt.Errorf("error scanning metadata query result row: %w", err) + } + + // Skip tables that match excludePaths or does not match includePaths. + log.Tracef("checking if %s.%s.%s matches excludePaths %s\n", db, attr.Schema, attr.Table, excludePaths) + if matchPathPatterns(db, attr.Schema, attr.Table, excludePaths) { + continue + } + log.Tracef("checking if %s.%s.%s matches includePaths: %s\n", db, attr.Schema, attr.Table, includePaths) + if !matchPathPatterns(db, attr.Schema, attr.Table, includePaths) { + continue + } + + // SchemaMetadata exists - add a table if necessary. + if schema, ok := repo.Schemas[attr.Schema]; ok { + // TableMetadata exists - just append the attribute. + if table, ok := schema.Tables[attr.Table]; ok { + table.Attributes = append(table.Attributes, &attr) + } else { // First time seeing this table. + table := NewTableMetadata(attr.Schema, attr.Table) + table.Attributes = append(table.Attributes, &attr) + schema.Tables[attr.Table] = table + } + } else { // SchemaMetadata doesn't exist - create it. + table := NewTableMetadata(attr.Schema, attr.Table) + table.Attributes = append(table.Attributes, &attr) + schema := NewSchemaMetadata(attr.Schema) + schema.Tables[attr.Table] = table + repo.Schemas[attr.Schema] = schema + } + } + + // Something broke while iterating the row set. + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating metadata query rows: %w", err) + } + return repo, nil +} + // NewSchemaMetadata creates a new SchemaMetadata object with the given schema // name and an empty map of tables. 
func NewSchemaMetadata(schemaName string) *SchemaMetadata { diff --git a/discovery/metadata_test.go b/sql/metadata_test.go similarity index 88% rename from discovery/metadata_test.go rename to sql/metadata_test.go index c787e03..fbd5d12 100644 --- a/discovery/metadata_test.go +++ b/sql/metadata_test.go @@ -1,9 +1,9 @@ -package discovery +package sql import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestAttributeNames(t *testing.T) { @@ -34,7 +34,7 @@ func TestAttributeNames(t *testing.T) { names := []string{"name1", "name2", "name3"} - assert.ElementsMatch(t, table.AttributeNames(), names) + require.ElementsMatch(t, table.AttributeNames(), names) } func TestQuotedAttributeNamesString(t *testing.T) { @@ -67,5 +67,5 @@ func TestQuotedAttributeNamesString(t *testing.T) { expected := "`name1`,`name2`,`name3`" namesStr := table.QuotedAttributeNamesString(quoteChar) - assert.Equal(t, expected, namesStr) + require.Equal(t, expected, namesStr) } diff --git a/discovery/mock_repository_test.go b/sql/mock_repository_test.go similarity index 67% rename from discovery/mock_repository_test.go rename to sql/mock_repository_test.go index 492bcb9..13ff093 100644 --- a/discovery/mock_repository_test.go +++ b/sql/mock_repository_test.go @@ -1,6 +1,6 @@ // Code generated by mockery v2.42.1. DO NOT EDIT. 
-package discovery +package sql import ( context "context" @@ -8,7 +8,7 @@ import ( mock "github.com/stretchr/testify/mock" ) -// MockRepository is an autogenerated mock type for the SQLRepository type +// MockRepository is an autogenerated mock type for the Repository type type MockRepository struct { mock.Mock } @@ -50,11 +50,9 @@ func (_e *MockRepository_Expecter) Close() *MockRepository_Close_Call { } func (_c *MockRepository_Close_Call) Run(run func()) *MockRepository_Close_Call { - _c.Call.Run( - func(args mock.Arguments) { - run() - }, - ) + _c.Call.Run(func(args mock.Arguments) { + run() + }) return _c } @@ -68,9 +66,9 @@ func (_c *MockRepository_Close_Call) RunAndReturn(run func() error) *MockReposit return _c } -// Introspect provides a mock function with given fields: ctx -func (_m *MockRepository) Introspect(ctx context.Context) (*Metadata, error) { - ret := _m.Called(ctx) +// Introspect provides a mock function with given fields: ctx, params +func (_m *MockRepository) Introspect(ctx context.Context, params IntrospectParameters) (*Metadata, error) { + ret := _m.Called(ctx, params) if len(ret) == 0 { panic("no return value specified for Introspect") @@ -78,19 +76,19 @@ func (_m *MockRepository) Introspect(ctx context.Context) (*Metadata, error) { var r0 *Metadata var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (*Metadata, error)); ok { - return rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, IntrospectParameters) (*Metadata, error)); ok { + return rf(ctx, params) } - if rf, ok := ret.Get(0).(func(context.Context) *Metadata); ok { - r0 = rf(ctx) + if rf, ok := ret.Get(0).(func(context.Context, IntrospectParameters) *Metadata); ok { + r0 = rf(ctx, params) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*Metadata) } } - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) + if rf, ok := ret.Get(1).(func(context.Context, IntrospectParameters) error); ok { + r1 = rf(ctx, params) } else { r1 = ret.Error(1) } @@ 
-105,16 +103,15 @@ type MockRepository_Introspect_Call struct { // Introspect is a helper method to define mock.On call // - ctx context.Context -func (_e *MockRepository_Expecter) Introspect(ctx interface{}) *MockRepository_Introspect_Call { - return &MockRepository_Introspect_Call{Call: _e.mock.On("Introspect", ctx)} +// - params IntrospectParameters +func (_e *MockRepository_Expecter) Introspect(ctx interface{}, params interface{}) *MockRepository_Introspect_Call { + return &MockRepository_Introspect_Call{Call: _e.mock.On("Introspect", ctx, params)} } -func (_c *MockRepository_Introspect_Call) Run(run func(ctx context.Context)) *MockRepository_Introspect_Call { - _c.Call.Run( - func(args mock.Arguments) { - run(args[0].(context.Context)) - }, - ) +func (_c *MockRepository_Introspect_Call) Run(run func(ctx context.Context, params IntrospectParameters)) *MockRepository_Introspect_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(IntrospectParameters)) + }) return _c } @@ -123,12 +120,7 @@ func (_c *MockRepository_Introspect_Call) Return(_a0 *Metadata, _a1 error) *Mock return _c } -func (_c *MockRepository_Introspect_Call) RunAndReturn( - run func(context.Context) ( - *Metadata, - error, - ), -) *MockRepository_Introspect_Call { +func (_c *MockRepository_Introspect_Call) RunAndReturn(run func(context.Context, IntrospectParameters) (*Metadata, error)) *MockRepository_Introspect_Call { _c.Call.Return(run) return _c } @@ -175,11 +167,9 @@ func (_e *MockRepository_Expecter) ListDatabases(ctx interface{}) *MockRepositor } func (_c *MockRepository_ListDatabases_Call) Run(run func(ctx context.Context)) *MockRepository_ListDatabases_Call { - _c.Call.Run( - func(args mock.Arguments) { - run(args[0].(context.Context)) - }, - ) + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) return _c } @@ -188,12 +178,7 @@ func (_c *MockRepository_ListDatabases_Call) Return(_a0 []string, _a1 error) *Mo return _c } -func 
(_c *MockRepository_ListDatabases_Call) RunAndReturn( - run func(context.Context) ( - []string, - error, - ), -) *MockRepository_ListDatabases_Call { +func (_c *MockRepository_ListDatabases_Call) RunAndReturn(run func(context.Context) ([]string, error)) *MockRepository_ListDatabases_Call { _c.Call.Return(run) return _c } @@ -228,11 +213,9 @@ func (_e *MockRepository_Expecter) Ping(ctx interface{}) *MockRepository_Ping_Ca } func (_c *MockRepository_Ping_Call) Run(run func(ctx context.Context)) *MockRepository_Ping_Call { - _c.Call.Run( - func(args mock.Arguments) { - run(args[0].(context.Context)) - }, - ) + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) return _c } @@ -246,12 +229,9 @@ func (_c *MockRepository_Ping_Call) RunAndReturn(run func(context.Context) error return _c } -// SampleTable provides a mock function with given fields: ctx, meta, params -func (_m *MockRepository) SampleTable(ctx context.Context, meta *TableMetadata, params SampleParameters) ( - Sample, - error, -) { - ret := _m.Called(ctx, meta, params) +// SampleTable provides a mock function with given fields: ctx, params +func (_m *MockRepository) SampleTable(ctx context.Context, params SampleParameters) (Sample, error) { + ret := _m.Called(ctx, params) if len(ret) == 0 { panic("no return value specified for SampleTable") @@ -259,17 +239,17 @@ func (_m *MockRepository) SampleTable(ctx context.Context, meta *TableMetadata, var r0 Sample var r1 error - if rf, ok := ret.Get(0).(func(context.Context, *TableMetadata, SampleParameters) (Sample, error)); ok { - return rf(ctx, meta, params) + if rf, ok := ret.Get(0).(func(context.Context, SampleParameters) (Sample, error)); ok { + return rf(ctx, params) } - if rf, ok := ret.Get(0).(func(context.Context, *TableMetadata, SampleParameters) Sample); ok { - r0 = rf(ctx, meta, params) + if rf, ok := ret.Get(0).(func(context.Context, SampleParameters) Sample); ok { + r0 = rf(ctx, params) } else { r0 = ret.Get(0).(Sample) } - if 
rf, ok := ret.Get(1).(func(context.Context, *TableMetadata, SampleParameters) error); ok { - r1 = rf(ctx, meta, params) + if rf, ok := ret.Get(1).(func(context.Context, SampleParameters) error); ok { + r1 = rf(ctx, params) } else { r1 = ret.Error(1) } @@ -284,28 +264,15 @@ type MockRepository_SampleTable_Call struct { // SampleTable is a helper method to define mock.On call // - ctx context.Context -// - meta *TableMetadata // - params SampleParameters -func (_e *MockRepository_Expecter) SampleTable( - ctx interface{}, - meta interface{}, - params interface{}, -) *MockRepository_SampleTable_Call { - return &MockRepository_SampleTable_Call{Call: _e.mock.On("SampleTable", ctx, meta, params)} -} - -func (_c *MockRepository_SampleTable_Call) Run( - run func( - ctx context.Context, - meta *TableMetadata, - params SampleParameters, - ), -) *MockRepository_SampleTable_Call { - _c.Call.Run( - func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(*TableMetadata), args[2].(SampleParameters)) - }, - ) +func (_e *MockRepository_Expecter) SampleTable(ctx interface{}, params interface{}) *MockRepository_SampleTable_Call { + return &MockRepository_SampleTable_Call{Call: _e.mock.On("SampleTable", ctx, params)} +} + +func (_c *MockRepository_SampleTable_Call) Run(run func(ctx context.Context, params SampleParameters)) *MockRepository_SampleTable_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(SampleParameters)) + }) return _c } @@ -314,25 +281,17 @@ func (_c *MockRepository_SampleTable_Call) Return(_a0 Sample, _a1 error) *MockRe return _c } -func (_c *MockRepository_SampleTable_Call) RunAndReturn( - run func( - context.Context, - *TableMetadata, - SampleParameters, - ) (Sample, error), -) *MockRepository_SampleTable_Call { +func (_c *MockRepository_SampleTable_Call) RunAndReturn(run func(context.Context, SampleParameters) (Sample, error)) *MockRepository_SampleTable_Call { _c.Call.Return(run) return _c } -// 
NewMockRepository creates a new instance of Mocksql. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// NewMockRepository creates a new instance of MockRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. -func NewMockRepository( - t interface { - mock.TestingT - Cleanup(func()) - }, -) *MockRepository { +func NewMockRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *MockRepository { mock := &MockRepository{} mock.Mock.Test(t) diff --git a/scan/mocks/mock_classifier.go b/sql/mocks/mock_classifier.go similarity index 100% rename from scan/mocks/mock_classifier.go rename to sql/mocks/mock_classifier.go diff --git a/discovery/mysql.go b/sql/mysql.go similarity index 56% rename from discovery/mysql.go rename to sql/mysql.go index 7a7afc8..0973c8b 100644 --- a/discovery/mysql.go +++ b/sql/mysql.go @@ -1,4 +1,4 @@ -package discovery +package sql import ( "context" @@ -22,15 +22,15 @@ WHERE ` ) -// MySqlRepository is a SQLRepository implementation for MySQL databases. +// MySqlRepository is a Repository implementation for MySQL databases. type MySqlRepository struct { - // The majority of the SQLRepository functionality is delegated to - // a generic SQL repository instance (genericSqlRepo). - genericSqlRepo *GenericRepository + // The majority of the Repository functionality is delegated to + // a generic SQL repository instance. + generic *GenericRepository } -// MySqlRepository implements SQLRepository -var _ SQLRepository = (*MySqlRepository)(nil) +// MySqlRepository implements Repository +var _ Repository = (*MySqlRepository)(nil) // NewMySqlRepository creates a new MySQL sql. 
func NewMySqlRepository(cfg RepoConfig) (*MySqlRepository, error) { @@ -44,59 +44,50 @@ func NewMySqlRepository(cfg RepoConfig) (*MySqlRepository, error) { // https://github.com/go-sql-driver/mysql#dsn-data-source-name cfg.Database, ) - sqlRepo, err := NewGenericRepository( - cfg.Host, - RepoTypeMysql, - cfg.Database, - connStr, - cfg.MaxOpenConns, - cfg.IncludePaths, - cfg.ExcludePaths, - ) + generic, err := NewGenericRepository(RepoTypeMysql, cfg.Database, connStr, cfg.MaxOpenConns) if err != nil { return nil, fmt.Errorf("could not instantiate generic sql repository: %w", err) } - return &MySqlRepository{genericSqlRepo: sqlRepo}, nil + return &MySqlRepository{generic: generic}, nil } // ListDatabases returns a list of the names of all databases on the server by // using a MySQL-specific database query. It delegates the actual work to // GenericRepository.ListDatabasesWithQuery - see that method for more details. func (r *MySqlRepository) ListDatabases(ctx context.Context) ([]string, error) { - return r.genericSqlRepo.ListDatabasesWithQuery(ctx, MySqlDatabaseQuery) + return r.generic.ListDatabasesWithQuery(ctx, MySqlDatabaseQuery) } // Introspect delegates introspection to GenericRepository. See -// SQLRepository.Introspect and GenericRepository.IntrospectWithQuery for more +// Repository.Introspect and GenericRepository.IntrospectWithQuery for more // details. -func (r *MySqlRepository) Introspect(ctx context.Context) (*Metadata, error) { - return r.genericSqlRepo.Introspect(ctx) +func (r *MySqlRepository) Introspect(ctx context.Context, params IntrospectParameters) (*Metadata, error) { + return r.generic.Introspect(ctx, params) } // SampleTable delegates sampling to GenericRepository, using a MySQL-specific -// table sample query. See SQLRepository.SampleTable and +// table sample query. See Repository.SampleTable and // GenericRepository.SampleTableWithQuery for more details. 
func (r *MySqlRepository) SampleTable( ctx context.Context, - meta *TableMetadata, params SampleParameters, ) (Sample, error) { // MySQL uses backticks to quote identifiers. - attrStr := meta.QuotedAttributeNamesString("`") + attrStr := params.Metadata.QuotedAttributeNamesString("`") // The generic select/limit/offset query and ? placeholders work fine with // MySQL. - query := fmt.Sprintf(GenericSampleQueryTemplate, attrStr, meta.Schema, meta.Name) - return r.genericSqlRepo.SampleTableWithQuery(ctx, meta, query, params.SampleSize, params.Offset) + query := fmt.Sprintf(GenericSampleQueryTemplate, attrStr, params.Metadata.Schema, params.Metadata.Name) + return r.generic.SampleTableWithQuery(ctx, query, params) } -// Ping delegates the ping to GenericRepository. See SQLRepository.Ping and +// Ping delegates the ping to GenericRepository. See Repository.Ping and // GenericRepository.Ping for more details. func (r *MySqlRepository) Ping(ctx context.Context) error { - return r.genericSqlRepo.Ping(ctx) + return r.generic.Ping(ctx) } -// Close delegates the close to GenericRepository. See SQLRepository.Close and +// Close delegates the close to GenericRepository. See Repository.Close and // GenericRepository.Close for more details. 
func (r *MySqlRepository) Close() error { - return r.genericSqlRepo.Close() + return r.generic.Close() } diff --git a/discovery/mysql_test.go b/sql/mysql_test.go similarity index 87% rename from discovery/mysql_test.go rename to sql/mysql_test.go index 9361a5f..346d05b 100644 --- a/discovery/mysql_test.go +++ b/sql/mysql_test.go @@ -1,4 +1,4 @@ -package discovery +package sql import ( "context" @@ -24,6 +24,6 @@ func initMySqlRepoTest(t *testing.T) (context.Context, *sql.DB, sqlmock.Sqlmock, db, mock, err := sqlmock.New() require.NoError(t, err) return ctx, db, mock, &MySqlRepository{ - genericSqlRepo: NewGenericRepositoryFromDB("repoName", RepoTypeMysql, "dbName", db), + generic: NewGenericRepositoryFromDB(RepoTypeMysql, "dbName", db), } } diff --git a/discovery/oracle.go b/sql/oracle.go similarity index 73% rename from discovery/oracle.go rename to sql/oracle.go index a3dfcb1..e4c0394 100644 --- a/discovery/oracle.go +++ b/sql/oracle.go @@ -1,4 +1,4 @@ -package discovery +package sql import ( "context" @@ -36,15 +36,15 @@ ON configServiceName = "service-name" ) -// OracleRepository is a SQLRepository implementation for Oracle databases. +// OracleRepository is a Repository implementation for Oracle databases. type OracleRepository struct { // The majority of the OracleRepository functionality is delegated to - // a generic SQL repository instance (genericSqlRepo). - genericSqlRepo *GenericRepository + // a generic SQL repository instance. + generic *GenericRepository } -// OracleRepository implements SQLRepository -var _ SQLRepository = (*OracleRepository)(nil) +// OracleRepository implements Repository +var _ Repository = (*OracleRepository)(nil) // NewOracleRepository creates a new Oracle repository. 
func NewOracleRepository(cfg RepoConfig) (*OracleRepository, error) { @@ -60,19 +60,11 @@ func NewOracleRepository(cfg RepoConfig) (*OracleRepository, error) { cfg.Port, oracleCfg.ServiceName, ) - sqlRepo, err := NewGenericRepository( - cfg.Host, - RepoTypeOracle, - cfg.Database, - connStr, - cfg.MaxOpenConns, - cfg.IncludePaths, - cfg.ExcludePaths, - ) + generic, err := NewGenericRepository(RepoTypeOracle, cfg.Database, connStr, cfg.MaxOpenConns) if err != nil { return nil, fmt.Errorf("could not instantiate generic sql repository: %w", err) } - return &OracleRepository{genericSqlRepo: sqlRepo}, nil + return &OracleRepository{generic: generic}, nil } // ListDatabases is left unimplemented for Oracle, because Oracle doesn't have @@ -83,28 +75,27 @@ func (r *OracleRepository) ListDatabases(_ context.Context) ([]string, error) { } // Introspect delegates introspection to GenericRepository, using an -// Oracle-specific introspection query. See SQLRepository.Introspect and +// Oracle-specific introspection query. See Repository.Introspect and // GenericRepository.IntrospectWithQuery for more details. -func (r *OracleRepository) Introspect(ctx context.Context) (*Metadata, error) { - return r.genericSqlRepo.IntrospectWithQuery(ctx, OracleIntrospectQuery) +func (r *OracleRepository) Introspect(ctx context.Context, params IntrospectParameters) (*Metadata, error) { + return r.generic.IntrospectWithQuery(ctx, OracleIntrospectQuery, params) } // SampleTable delegates sampling to GenericRepository, using an Oracle-specific -// table sample query. See SQLRepository.SampleTable and +// table sample query. See Repository.SampleTable and // GenericRepository.SampleTableWithQuery for more details. func (r *OracleRepository) SampleTable( ctx context.Context, - meta *TableMetadata, params SampleParameters, ) (Sample, error) { // Oracle uses double-quotes to quote identifiers. 
- attrStr := meta.QuotedAttributeNamesString("\"") + attrStr := params.Metadata.QuotedAttributeNamesString("\"") // Oracle uses :x for placeholders. query := fmt.Sprintf( "SELECT %s FROM %s.%s OFFSET :1 ROWS FETCH NEXT :2 ROWS ONLY", - attrStr, meta.Schema, meta.Name, + attrStr, params.Metadata.Schema, params.Metadata.Name, ) - return r.genericSqlRepo.SampleTableWithQuery(ctx, meta, query, params.Offset, params.SampleSize) + return r.generic.SampleTableWithQuery(ctx, query, params) } // Ping verifies the connection to Oracle database used by this Oracle @@ -113,13 +104,13 @@ func (r *OracleRepository) SampleTable( // Oracle being Oracle does not like this. Instead, we defer to the native // Ping method implemented by the Oracle DB driver. func (r *OracleRepository) Ping(ctx context.Context) error { - return r.genericSqlRepo.GetDb().PingContext(ctx) + return r.generic.GetDb().PingContext(ctx) } -// Close delegates the close to GenericRepository. See SQLRepository.Close and +// Close delegates the close to GenericRepository. See Repository.Close and // GenericRepository.Close for more details. func (r *OracleRepository) Close() error { - return r.genericSqlRepo.Close() + return r.generic.Close() } // OracleConfig is a struct to hold Oracle-specific configuration. diff --git a/sql/postgres.go b/sql/postgres.go new file mode 100644 index 0000000..e5a81ee --- /dev/null +++ b/sql/postgres.go @@ -0,0 +1,191 @@ +package sql + +import ( + "context" + "fmt" + "strings" + + // Postgresql DB driver + _ "github.com/lib/pq" +) + +const ( + RepoTypePostgres = "postgres" + + PostgresDatabaseQuery = ` +SELECT + datname +FROM + pg_database +WHERE + datistemplate = false + AND datallowconn = true + AND datname <> 'rdsadmin' +` +) + +// PostgresRepository is a Repository implementation for Postgres databases. +type PostgresRepository struct { + // The majority of the Repository functionality is delegated to + // a generic SQL repository instance. 
+ generic *GenericRepository +} + +// PostgresRepository implements Repository +var _ Repository = (*PostgresRepository)(nil) + +// NewPostgresRepository creates a new PostgresRepository. +func NewPostgresRepository(cfg RepoConfig) (*PostgresRepository, error) { + pgCfg, err := parsePostgresConfig(cfg) + if err != nil { + return nil, fmt.Errorf("error parsing postgres config: %w", err) + } + database := cfg.Database + // Connect to the default database, if unspecified. + if database == "" { + database = "postgres" + } + connStr := fmt.Sprintf( + "postgresql://%s:%s@%s:%d/%s%s", + cfg.User, + cfg.Password, + cfg.Host, + cfg.Port, + database, + pgCfg.ConnOptsStr, + ) + generic, err := NewGenericRepository(RepoTypePostgres, cfg.Database, connStr, cfg.MaxOpenConns) + if err != nil { + return nil, fmt.Errorf("could not instantiate generic sql repository: %w", err) + } + return &PostgresRepository{generic: generic}, nil +} + +// ListDatabases returns a list of the names of all databases on the server by +// using a Postgres-specific database query. It delegates the actual work to +// GenericRepository.ListDatabasesWithQuery - see that method for more details. +func (r *PostgresRepository) ListDatabases(ctx context.Context) ([]string, error) { + return r.generic.ListDatabasesWithQuery(ctx, PostgresDatabaseQuery) +} + +// Introspect delegates introspection to GenericRepository. See +// Repository.Introspect and GenericRepository.IntrospectWithQuery for more +// details. +func (r *PostgresRepository) Introspect(ctx context.Context, params IntrospectParameters) (*Metadata, error) { + return r.generic.Introspect(ctx, params) +} + +// SampleTable delegates sampling to GenericRepository, using a +// Postgres-specific table sample query. See Repository.SampleTable and +// GenericRepository.SampleTableWithQuery for more details. 
+func (r *PostgresRepository) SampleTable( + ctx context.Context, + params SampleParameters, +) (Sample, error) { + // Postgres uses double-quotes to quote identifiers + attrStr := params.Metadata.QuotedAttributeNamesString("\"") + // Postgres uses $x for placeholders + query := fmt.Sprintf( + "SELECT %s FROM %s.%s LIMIT $1 OFFSET $2", + attrStr, + params.Metadata.Schema, + params.Metadata.Name, + ) + return r.generic.SampleTableWithQuery(ctx, query, params) +} + +// Ping delegates the ping to GenericRepository. See Repository.Ping and +// GenericRepository.Ping for more details. +func (r *PostgresRepository) Ping(ctx context.Context) error { + return r.generic.Ping(ctx) +} + +// Close delegates the close to GenericRepository. See Repository.Close and +// GenericRepository.Close for more details. +func (r *PostgresRepository) Close() error { + return r.generic.Close() +} + +// PostgresConfig contains Postgres-specific configuration parameters. +type PostgresConfig struct { + // ConnOptsStr is a string containing Postgres-specific connection options. + ConnOptsStr string +} + +// parsePostgresConfig parses the Postgres-specific configuration parameters +// from the given RepoConfig. The Postgres connection options are built from +// the config and stored in the ConnOptsStr field of the returned PostgresConfig. +func parsePostgresConfig(cfg RepoConfig) (*PostgresConfig, error) { + connOptsStr, err := buildConnOptsStr(cfg) + if err != nil { + return nil, fmt.Errorf("error building connection options string: %w", err) + } + return &PostgresConfig{ConnOptsStr: connOptsStr}, nil +} + +// buildConnOptsStr parses the repo config to produce a string in the format +// "?option=value&option2=value2". Example: +// +// buildConnOptsStr(RepoConfig{ +// Advanced: map[string]any{ +// "connection-string-args": []any{"sslmode=disable"}, +// }, +// }) +// +// returns ("?sslmode=disable", nil).
+func buildConnOptsStr(cfg RepoConfig) (string, error) { + connOptsMap, err := mapFromConnOpts(cfg) + if err != nil { + return "", fmt.Errorf("connection options: %w", err) + } + connOptsStr := "" + for key, val := range connOptsMap { + // Don't add if the value is empty, since that would make the + // string malformed. + if val != "" { + if connOptsStr == "" { + connOptsStr += fmt.Sprintf("%s=%s", key, val) + } else { + // Need & for subsequent options + connOptsStr += fmt.Sprintf("&%s=%s", key, val) + } + } + } + // Only add ? if connection string is not empty + if connOptsStr != "" { + connOptsStr = "?" + connOptsStr + } + return connOptsStr, nil +} + +// mapFromConnOpts builds a map from the list of connection options given. Each +// option has the format 'option=value'. An error is returned if the config is +// malformed. +func mapFromConnOpts(cfg RepoConfig) (map[string]string, error) { + m := make(map[string]string) + connOptsInterface, ok := cfg.Advanced[configConnOpts] + if !ok { + return nil, nil + } + connOpts, ok := connOptsInterface.([]any) + if !ok { + return nil, fmt.Errorf("'%s' is not a list", configConnOpts) + } + for _, optInterface := range connOpts { + opt, ok := optInterface.(string) + if !ok { + return nil, fmt.Errorf("'%v' is not a string", optInterface) + } + splitOpt := strings.Split(opt, "=") + if len(splitOpt) != 2 { + return nil, fmt.Errorf( + "malformed '%s'. 
"+ + "Please follow the format 'option=value'", configConnOpts, + ) + } + key := splitOpt[0] + val := splitOpt[1] + m[key] = val + } + return m, nil +} diff --git a/discovery/config_test.go b/sql/postgres_test.go similarity index 50% rename from discovery/config_test.go rename to sql/postgres_test.go index c54ea99..84a1974 100644 --- a/discovery/config_test.go +++ b/sql/postgres_test.go @@ -1,23 +1,27 @@ -package discovery +package sql import ( + "context" + "database/sql" "testing" + "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/require" ) -// Returns a correct repo config -func getSampleRepoConfig() RepoConfig { - return RepoConfig{ - Advanced: map[string]any{ - configConnOpts: []any{"sslmode=disable"}, - }, - } +func TestPostgresRepository_ListDatabases(t *testing.T) { + ctx, db, mock, r := initPostgresRepoTest(t) + defer func() { _ = db.Close() }() + dbRows := sqlmock.NewRows([]string{"name"}).AddRow("db1").AddRow("db2") + mock.ExpectQuery(PostgresDatabaseQuery).WillReturnRows(dbRows) + dbs, err := r.ListDatabases(ctx) + require.NoError(t, err) + require.ElementsMatch(t, []string{"db1", "db2"}, dbs) } func TestBuildConnOptionsSucc(t *testing.T) { sampleRepoCfg := getSampleRepoConfig() - connOptsStr, err := BuildConnOptsStr(sampleRepoCfg) + connOptsStr, err := buildConnOptsStr(sampleRepoCfg) require.NoError(t, err) require.Equal(t, connOptsStr, "?sslmode=disable") } @@ -31,7 +35,7 @@ func TestBuildConnOptionsFail(t *testing.T) { }, }, } - connOptsStr, err := BuildConnOptsStr(invalidRepoCfg) + connOptsStr, err := buildConnOptsStr(invalidRepoCfg) require.Error(t, err) require.Empty(t, connOptsStr) } @@ -79,72 +83,20 @@ func TestMapConnOptionsMalformedColon(t *testing.T) { require.Error(t, err) } -func TestAdvancedConfigSucc(t *testing.T) { - sampleCfg := RepoConfig{ - Advanced: map[string]any{ - "snowflake": map[string]any{ - "account": "exampleAccount", - "role": "exampleRole", - "warehouse": "exampleWarehouse", - }, - }, - } - repoSpecificMap, 
err := FetchAdvancedConfigString( - sampleCfg, - "snowflake", []string{"account", "role", "warehouse"}, - ) +func initPostgresRepoTest(t *testing.T) (context.Context, *sql.DB, sqlmock.Sqlmock, *PostgresRepository) { + ctx := context.Background() + db, mock, err := sqlmock.New() require.NoError(t, err) - require.EqualValues( - t, repoSpecificMap, map[string]string{ - "account": "exampleAccount", - "role": "exampleRole", - "warehouse": "exampleWarehouse", - }, - ) -} - -func TestAdvancedConfigMissing(t *testing.T) { - // Without the snowflake config at all - sampleCfg := RepoConfig{ - Advanced: map[string]any{}, - } - _, err := FetchAdvancedConfigString( - sampleCfg, - "snowflake", []string{"account", "role", "warehouse"}, - ) - require.Error(t, err) - - sampleCfg = RepoConfig{ - Advanced: map[string]any{ - "snowflake": map[string]any{ - // Missing account - - "role": "exampleRole", - "warehouse": "exampleWarehouse", - }, - }, + return ctx, db, mock, &PostgresRepository{ + generic: NewGenericRepositoryFromDB(RepoTypePostgres, "dbName", db), } - _, err = FetchAdvancedConfigString( - sampleCfg, - "snowflake", []string{"account", "role", "warehouse"}, - ) - require.Error(t, err) } -func TestAdvancedConfigMalformed(t *testing.T) { - sampleCfg := RepoConfig{ +// Returns a correct repo config +func getSampleRepoConfig() RepoConfig { + return RepoConfig{ Advanced: map[string]any{ - "snowflake": map[string]any{ - // Let's give a _list_ of things - "account": []string{"account1", "account2"}, - "role": []string{"role1", "role2"}, - "warehouse": []string{"warehouse1", "warehouse2"}, - }, + configConnOpts: []any{"sslmode=disable"}, }, } - _, err := FetchAdvancedConfigString( - sampleCfg, - "snowflake", []string{"account", "role", "warehouse"}, - ) - require.Error(t, err) } diff --git a/discovery/redshift.go b/sql/redshift.go similarity index 59% rename from discovery/redshift.go rename to sql/redshift.go index 57027d8..ae847e8 100644 --- a/discovery/redshift.go +++ 
b/sql/redshift.go @@ -1,4 +1,4 @@ -package discovery +package sql import ( "context" @@ -12,19 +12,19 @@ const ( RepoTypeRedshift = "redshift" ) -// RedshiftRepository is a SQLRepository implementation for Redshift databases. +// RedshiftRepository is a Repository implementation for Redshift databases. type RedshiftRepository struct { // The majority of the RedshiftRepository functionality is delegated to - // a generic SQL repository instance (genericSqlRepo). - genericSqlRepo *GenericRepository + // a generic SQL repository instance. + generic *GenericRepository } -// RedshiftRepository implements SQLRepository -var _ SQLRepository = (*RedshiftRepository)(nil) +// RedshiftRepository implements Repository +var _ Repository = (*RedshiftRepository)(nil) // NewRedshiftRepository creates a new RedshiftRepository. func NewRedshiftRepository(cfg RepoConfig) (*RedshiftRepository, error) { - pgCfg, err := ParsePostgresConfig(cfg) + pgCfg, err := parsePostgresConfig(cfg) if err != nil { return nil, fmt.Errorf("unable to parse postgres config: %w", err) } @@ -42,19 +42,11 @@ func NewRedshiftRepository(cfg RepoConfig) (*RedshiftRepository, error) { database, pgCfg.ConnOptsStr, ) - sqlRepo, err := NewGenericRepository( - cfg.Host, - RepoTypePostgres, - cfg.Database, - connStr, - cfg.MaxOpenConns, - cfg.IncludePaths, - cfg.ExcludePaths, - ) + generic, err := NewGenericRepository(RepoTypePostgres, cfg.Database, connStr, cfg.MaxOpenConns) if err != nil { return nil, fmt.Errorf("could not instantiate generic sql repository: %w", err) } - return &RedshiftRepository{genericSqlRepo: sqlRepo}, nil + return &RedshiftRepository{generic: generic}, nil } // ListDatabases returns a list of the names of all databases on the server by @@ -62,39 +54,38 @@ func NewRedshiftRepository(cfg RepoConfig) (*RedshiftRepository, error) { // GenericRepository.ListDatabasesWithQuery - see that method for more details. 
func (r *RedshiftRepository) ListDatabases(ctx context.Context) ([]string, error) { // Redshift and Postgres use the same query to list the server databases. - return r.genericSqlRepo.ListDatabasesWithQuery(ctx, PostgresDatabaseQuery) + return r.generic.ListDatabasesWithQuery(ctx, PostgresDatabaseQuery) } // Introspect delegates introspection to GenericRepository. See -// SQLRepository.Introspect and GenericRepository.IntrospectWithQuery for more +// Repository.Introspect and GenericRepository.IntrospectWithQuery for more // details. -func (r *RedshiftRepository) Introspect(ctx context.Context) (*Metadata, error) { - return r.genericSqlRepo.Introspect(ctx) +func (r *RedshiftRepository) Introspect(ctx context.Context, params IntrospectParameters) (*Metadata, error) { + return r.generic.Introspect(ctx, params) } // SampleTable delegates sampling to GenericRepository, using a -// Redshift-specific table sample query. See SQLRepository.SampleTable and +// Redshift-specific table sample query. See Repository.SampleTable and // GenericRepository.SampleTableWithQuery for more details. func (r *RedshiftRepository) SampleTable( ctx context.Context, - meta *TableMetadata, params SampleParameters, ) (Sample, error) { // Redshift uses double-quotes to quote identifiers - attrStr := meta.QuotedAttributeNamesString("\"") + attrStr := params.Metadata.QuotedAttributeNamesString("\"") // Redshift uses $x for placeholders - query := fmt.Sprintf("SELECT %s FROM %s.%s LIMIT $1 OFFSET $2", attrStr, meta.Schema, meta.Name) - return r.genericSqlRepo.SampleTableWithQuery(ctx, meta, query, params.SampleSize, params.Offset) + query := fmt.Sprintf("SELECT %s FROM %s.%s LIMIT $1 OFFSET $2", attrStr, params.Metadata.Schema, params.Metadata.Name) + return r.generic.SampleTableWithQuery(ctx, query, params) } -// Ping delegates the ping to GenericRepository. See SQLRepository.Ping and +// Ping delegates the ping to GenericRepository. 
See Repository.Ping and // GenericRepository.Ping for more details. func (r *RedshiftRepository) Ping(ctx context.Context) error { - return r.genericSqlRepo.Ping(ctx) + return r.generic.Ping(ctx) } -// Close delegates the close to GenericRepository. See SQLRepository.Close and +// Close delegates the close to GenericRepository. See Repository.Close and // GenericRepository.Close for more details. func (r *RedshiftRepository) Close() error { - return r.genericSqlRepo.Close() + return r.generic.Close() } diff --git a/discovery/redshift_test.go b/sql/redshift_test.go similarity index 87% rename from discovery/redshift_test.go rename to sql/redshift_test.go index 07df14c..3b45d75 100644 --- a/discovery/redshift_test.go +++ b/sql/redshift_test.go @@ -1,4 +1,4 @@ -package discovery +package sql import ( "context" @@ -24,6 +24,6 @@ func initRedshiftRepoTest(t *testing.T) (context.Context, *sql.DB, sqlmock.Sqlmo db, mock, err := sqlmock.New() require.NoError(t, err) return ctx, db, mock, &RedshiftRepository{ - genericSqlRepo: NewGenericRepositoryFromDB("repoName", RepoTypeRedshift, "dbName", db), + generic: NewGenericRepositoryFromDB(RepoTypeRedshift, "dbName", db), } } diff --git a/discovery/registry.go b/sql/registry.go similarity index 72% rename from discovery/registry.go rename to sql/registry.go index da18fc2..5ea7136 100644 --- a/discovery/registry.go +++ b/sql/registry.go @@ -1,4 +1,4 @@ -package discovery +package sql import ( "context" @@ -11,7 +11,7 @@ var ( // package of which a number of convenience functions in this package act // on. All currently out-of-the-box repository types are registered to this // registry by this package's init function. Users who want to use custom - // SQLRepository implementations, or just avoid global state altogether, should + // Repository implementations, or just avoid global state altogether, should // use their own instance of Registry, instead of using DefaultRegistry and // the corresponding convenience functions. 
DefaultRegistry = NewRegistry() @@ -26,7 +26,7 @@ type Registry struct { // RepoConstructor represents the function signature that all repository // implementations should use for their constructor functions. -type RepoConstructor func(ctx context.Context, cfg RepoConfig) (SQLRepository, error) +type RepoConstructor func(ctx context.Context, cfg RepoConfig) (Repository, error) // NewRegistry creates a new Registry instance. func NewRegistry() *Registry { @@ -57,15 +57,21 @@ func (r *Registry) MustRegister(repoType string, constructor RepoConstructor) { } } -// NewRepository is a factory method to return a concrete SQLRepository +// Unregister removes a repository type from the registry. If the repository +// type is not registered, this method is a no-op. +func (r *Registry) Unregister(repoType string) { + delete(r.constructors, repoType) +} + +// NewRepository is a factory method to return a concrete Repository // implementation based on the specified type, e.g. MySQL, Postgres, SQL Server, // etc., which must be registered with the registry. If the repository type is // not registered, an error is returned. A new instance of the repository is // returned each time this method is called. -func (r *Registry) NewRepository(ctx context.Context, cfg RepoConfig) (SQLRepository, error) { - constructor, ok := r.constructors[cfg.Type] +func (r *Registry) NewRepository(ctx context.Context, repoType string, cfg RepoConfig) (Repository, error) { + constructor, ok := r.constructors[repoType] if !ok { - return nil, errors.New("unsupported repo type " + cfg.Type) + return nil, errors.New("unsupported repo type " + repoType) } repo, err := constructor(ctx, cfg) if err != nil { @@ -86,10 +92,16 @@ func MustRegister(repoType string, constructor RepoConstructor) { DefaultRegistry.MustRegister(repoType, constructor) } +// Unregister is a convenience function that delegates to DefaultRegistry. See +// Registry.Unregister for more details. 
+func Unregister(repoType string) { + DefaultRegistry.Unregister(repoType) +} + // NewRepository is a convenience function that delegates to DefaultRegistry. // See Registry.NewRepository for more details. -func NewRepository(ctx context.Context, cfg RepoConfig) (SQLRepository, error) { - return DefaultRegistry.NewRepository(ctx, cfg) +func NewRepository(ctx context.Context, repoType string, cfg RepoConfig) (Repository, error) { + return DefaultRegistry.NewRepository(ctx, repoType, cfg) } // init registers all out-of-the-box repository types and their respective @@ -97,43 +109,43 @@ func NewRepository(ctx context.Context, cfg RepoConfig) (SQLRepository, error) { func init() { MustRegister( RepoTypeDenodo, - func(_ context.Context, cfg RepoConfig) (SQLRepository, error) { + func(_ context.Context, cfg RepoConfig) (Repository, error) { return NewDenodoRepository(cfg) }, ) MustRegister( RepoTypeMysql, - func(_ context.Context, cfg RepoConfig) (SQLRepository, error) { + func(_ context.Context, cfg RepoConfig) (Repository, error) { return NewMySqlRepository(cfg) }, ) MustRegister( RepoTypeOracle, - func(_ context.Context, cfg RepoConfig) (SQLRepository, error) { + func(_ context.Context, cfg RepoConfig) (Repository, error) { return NewOracleRepository(cfg) }, ) MustRegister( RepoTypePostgres, - func(_ context.Context, cfg RepoConfig) (SQLRepository, error) { + func(_ context.Context, cfg RepoConfig) (Repository, error) { return NewPostgresRepository(cfg) }, ) MustRegister( RepoTypeRedshift, - func(_ context.Context, cfg RepoConfig) (SQLRepository, error) { + func(_ context.Context, cfg RepoConfig) (Repository, error) { return NewRedshiftRepository(cfg) }, ) MustRegister( RepoTypeSnowflake, - func(ctx context.Context, cfg RepoConfig) (SQLRepository, error) { + func(ctx context.Context, cfg RepoConfig) (Repository, error) { return NewSnowflakeRepository(cfg) }, ) MustRegister( RepoTypeSqlServer, - func(_ context.Context, cfg RepoConfig) (SQLRepository, error) { + func(_ 
context.Context, cfg RepoConfig) (Repository, error) { return NewSqlServerRepository(cfg) }, ) diff --git a/sql/registry_test.go b/sql/registry_test.go new file mode 100644 index 0000000..9a55d76 --- /dev/null +++ b/sql/registry_test.go @@ -0,0 +1,82 @@ +package sql + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRegistry_Register_Successful(t *testing.T) { + repoType := "repoType" + constructor := func(context.Context, RepoConfig) (Repository, error) { + return nil, nil + } + reg := NewRegistry() + err := reg.Register(repoType, constructor) + require.NoError(t, err) + require.Contains(t, reg.constructors, repoType) +} + +func TestRegistry_MustRegister_NilConstructor(t *testing.T) { + reg := NewRegistry() + require.Panics(t, func() { reg.MustRegister("repoType", nil) }) +} + +func TestRegistry_MustRegister_TwoCalls_Panics(t *testing.T) { + repoType := "repoType" + constructor := func(context.Context, RepoConfig) (Repository, error) { + return nil, nil + } + reg := NewRegistry() + reg.MustRegister(repoType, constructor) + require.Contains(t, reg.constructors, repoType) + require.Panics(t, func() { reg.MustRegister(repoType, constructor) }) +} + +func TestRegistry_NewRepository_IsSuccessful(t *testing.T) { + repoType := "repoType" + called := false + expectedRepo := (Repository)(nil) + constructor := func(context.Context, RepoConfig) (Repository, error) { + called = true + return expectedRepo, nil + } + reg := NewRegistry() + err := reg.Register(repoType, constructor) + require.NoError(t, err) + require.Contains(t, reg.constructors, repoType) + + repo, err := reg.NewRepository(context.Background(), repoType, RepoConfig{}) + require.NoError(t, err) + require.Equal(t, expectedRepo, repo) + require.True(t, called, "Constructor was not called") +} + +func TestRegistry_NewRepository_ConstructorError(t *testing.T) { + repoType := "repoType" + called := false + expectedErr := errors.New("dummy error") + constructor := 
func(context.Context, RepoConfig) (Repository, error) { + called = true + return nil, expectedErr + } + reg := NewRegistry() + err := reg.Register(repoType, constructor) + require.NoError(t, err) + require.Contains(t, reg.constructors, repoType) + + repo, err := reg.NewRepository(context.Background(), repoType, RepoConfig{}) + require.ErrorIs(t, err, expectedErr) + require.Nil(t, repo) + require.True(t, called, "Constructor was not called") +} + +func TestRegistry_NewRepository_UnsupportedRepoType(t *testing.T) { + repoType := "repoType" + reg := NewRegistry() + repo, err := reg.NewRepository(context.Background(), repoType, RepoConfig{}) + require.Error(t, err) + require.Nil(t, repo) +} diff --git a/sql/repository.go b/sql/repository.go new file mode 100644 index 0000000..a751c9a --- /dev/null +++ b/sql/repository.go @@ -0,0 +1,76 @@ +package sql + +import ( + "context" + + "github.com/gobwas/glob" +) + +// Repository represents a Dmap data SQL repository, and provides functionality +// to introspect its corresponding schema. +type Repository interface { + // ListDatabases returns a list of the names of all databases on the server. + ListDatabases(ctx context.Context) ([]string, error) + // Introspect will read and analyze the basic properties of the repository + // and return as a Metadata instance. This includes all the repository's + // databases, schemas, tables, columns, and attributes. + Introspect(ctx context.Context, params IntrospectParameters) (*Metadata, error) + // SampleTable samples the table referenced by the TableMetadata meta + // parameter and returns the sample as a slice of Sample. The parameters for + // the sample, such as sample size, are passed via the params parameter (see + // SampleParameters for more details). The returned sample result set + // contains one Sample for each table row sampled. The length of the results + // will be less than or equal to the sample size. 
If there are fewer results + // than the specified sample size, it is because the table in question had a + // row count less than the sample size. Prefer small sample sizes to limit + // impact on the database. + SampleTable(ctx context.Context, params SampleParameters) (Sample, error) + // Ping is meant to be used as a general purpose connectivity test. It + // should be invoked e.g. in the dry-run mode. + Ping(ctx context.Context) error + // Close is meant to be used as a general purpose cleanup. It should be + // invoked when the Repository is no longer used. + Close() error +} + +// IntrospectParameters is a struct that holds the parameters for the Introspect +// method of the Repository interface. +type IntrospectParameters struct { + // IncludePaths is a list of glob patterns that will be used to filter + // the tables that will be introspected. If a table name matches any of + // the patterns in this list, it will be included in the repository + // metadata. + IncludePaths []glob.Glob + // ExcludePaths is a list of glob patterns that will be used to filter + // the tables that will be introspected. If a table name matches any of + // the patterns in this list, it will be excluded from the repository + // metadata. + ExcludePaths []glob.Glob +} + +// Sample represents a sample of data from a database table. +type Sample struct { + // TablePath is the full path of the data repository table that was sampled. + // Each element corresponds to a component, in increasing order of + // granularity (e.g. [database, schema, table]). + TablePath []string + // Results is the set of sample results. Each SampleResult is equivalent to + // a database row, where the map key is the column name and the map value is + // the column value. + Results []SampleResult +} + +// SampleParameters contains all parameters necessary to sample a table. +type SampleParameters struct { + // Metadata is the metadata for the table to be sampled. 
+ Metadata *TableMetadata + // SampleSize is the number of rows to sample from the table. + SampleSize uint + // Offset is the number of rows to skip before starting the sample. + Offset uint +} + +// SampleResult stores the results from a single database sample. It is +// equivalent to a database row, where the map key is the column name and the +// map value is the column value. +type SampleResult map[string]any diff --git a/sql/sample.go b/sql/sample.go new file mode 100644 index 0000000..5497f7c --- /dev/null +++ b/sql/sample.go @@ -0,0 +1,190 @@ +package sql + +import ( + "context" + "fmt" + "sync" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + "golang.org/x/sync/semaphore" +) + +// sampleAndErr is a "pair" type intended to be passed to a channel (see +// sampleDb) +type sampleAndErr struct { + sample Sample + err error +} + +// samplesAndErr is a "pair" type intended to be passed to a channel (see +// sampleAllDbs) +type samplesAndErr struct { + samples []Sample + err error +} + +// sampleAllDbs uses the given Repository to list all the +// databases on the server, and samples each one in parallel by calling +// sampleDb for each database. The repository is intended to be +// configured to connect to the default database on the server, or at least some +// database which can be used to enumerate the full set of databases on the +// server. An error will be returned if the set of databases cannot be listed. +// If there is an error connecting to or sampling a database, the error will be +// logged and no samples will be returned for that database. Therefore, the +// returned slice of samples contains samples for only the databases which could +// be discovered and successfully sampled, and could potentially be empty if no +// databases were sampled. 
+func sampleAllDbs( + ctx context.Context, + ctor RepoConstructor, + cfg RepoConfig, + introspectParams IntrospectParameters, + sampleSize, offset uint, +) ( + []Sample, + error, +) { + // Create a repository instance that will be used to list all the databases + // on the server. + repo, err := ctor(ctx, cfg) + if err != nil { + return nil, fmt.Errorf("error creating repository instance: %w", err) + } + defer func() { _ = repo.Close() }() + + // We assume that this repository will be connected to the default database + // (or at least some database that can discover all the other databases), + // and we use that to discover all other databases. + dbs, err := repo.ListDatabases(ctx) + if err != nil { + return nil, fmt.Errorf("error listing databases: %w", err) + } + + // Sample each database on a separate goroutine, and send the samples to + // the 'out' channel. Each slice of samples will be aggregated below on the + // main goroutine and returned. + var wg sync.WaitGroup + out := make(chan samplesAndErr) + wg.Add(len(dbs)) + // Ensures that we avoid opening more than the specified number of + // connections. + var sema *semaphore.Weighted + if cfg.MaxOpenConns > 0 { + sema = semaphore.NewWeighted(int64(cfg.MaxOpenConns)) + } + for _, db := range dbs { + go func(db string, cfg RepoConfig) { + defer wg.Done() + if sema != nil { + _ = sema.Acquire(ctx, 1) + defer sema.Release(1) + } + // Create a repository instance for this specific database. It will + // be used to connect to and sample the database. + cfg.Database = db + repo, err := ctor(ctx, cfg) + if err != nil { + log.WithError(err).Errorf("error creating repository instance for database %s", db) + return + } + defer func() { _ = repo.Close() }() + // Sample the database. 
+ s, err := sampleDb(ctx, repo, introspectParams, sampleSize, offset) + if err != nil && len(s) == 0 { + log.WithError(err).Errorf("error gathering repository data samples for database %s", db) + return + } + // Send the samples for this database to the 'out' channel. The + // samples for each database will be aggregated into a single slice + // on the main goroutine and returned. + out <- samplesAndErr{samples: s, err: err} + }(db, cfg) + } + + // Start a goroutine to close the 'out' channel once all the goroutines + // we launched above are done. This will allow the aggregation range loop + // below to terminate properly. Note that this must start after the wg.Add + // call. See https://go.dev/blog/pipelines ("Fan-out, fan-in" section). + go func() { + wg.Wait() + close(out) + }() + + // Aggregate and return the results. + var ret []Sample + var errs error + for res := range out { + ret = append(ret, res.samples...) + if res.err != nil { + errs = multierror.Append(errs, res.err) + } + } + return ret, errs +} + +// sampleDb is a helper function which will sample every table in a +// given repository and return them as a collection of Sample. First the +// repository is introspected by calling Introspect to return the +// repository metadata (Metadata). Then, for each schema and table in the +// metadata, it calls SampleTable in a new goroutine. Once all the +// sampling goroutines are finished, their results are collected and returned +// as a slice of Sample. +func sampleDb( + ctx context.Context, + repo Repository, + introspectParams IntrospectParameters, + sampleSize, offset uint, +) ( + []Sample, + error, +) { + // Introspect the repository to get the metadata. + meta, err := repo.Introspect(ctx, introspectParams) + if err != nil { + return nil, fmt.Errorf("error introspecting repository: %w", err) + } + + // Fan out sample executions. 
+	out := make(chan sampleAndErr)
+	numTables := 0
+	for _, schemaMeta := range meta.Schemas {
+		for _, tableMeta := range schemaMeta.Tables {
+			numTables++
+			go func(meta *TableMetadata) {
+				params := SampleParameters{
+					Metadata:   meta,
+					SampleSize: sampleSize,
+					Offset:     offset,
+				}
+				sample, err := repo.SampleTable(ctx, params)
+				select {
+				case <-ctx.Done():
+					return
+				case out <- sampleAndErr{sample: sample, err: err}:
+				}
+			}(tableMeta)
+		}
+	}
+
+	// Aggregate and return the results.
+	var samples []Sample
+	var errs error
+	for i := 0; i < numTables; i++ {
+		select {
+		case <-ctx.Done():
+			return samples, ctx.Err()
+		case res := <-out:
+			if res.err != nil {
+				errs = multierror.Append(errs, res.err)
+			} else {
+				samples = append(samples, res.sample)
+			}
+		}
+	}
+	close(out)
+	if errs != nil {
+		return samples, fmt.Errorf("error(s) while sampling repository: %w", errs)
+	}
+	return samples, nil
+}
diff --git a/sql/sample_test.go b/sql/sample_test.go
new file mode 100644
index 0000000..811cfa5
--- /dev/null
+++ b/sql/sample_test.go
@@ -0,0 +1,205 @@
+package sql
+
+import (
+	"context"
+	"errors"
+	"testing"
+
+	"github.com/stretchr/testify/mock"
+	"github.com/stretchr/testify/require"
+)
+
+var (
+	table1Sample = Sample{
+		TablePath: []string{"database", "schema1", "table1"},
+		Results: []SampleResult{
+			{
+				"name1": "foo",
+				"name2": "bar",
+			},
+			{
+				"name1": "baz",
+				"name2": "qux",
+			},
+		},
+	}
+
+	table2Sample = Sample{
+		TablePath: []string{"database", "schema2", "table2"},
+		Results: []SampleResult{
+			{
+				"name3": "foo1",
+				"name4": "bar1",
+			},
+			{
+				"name3": "baz1",
+				"name4": "qux1",
+			},
+		},
+	}
+)
+
+func Test_sampleDb_Success(t *testing.T) {
+	ctx := context.Background()
+	repo := NewMockRepository(t)
+	meta := Metadata{
+		Database: "database",
+		Schemas: map[string]*SchemaMetadata{
+			"schema1": {
+				Name: "",
+				Tables: map[string]*TableMetadata{
+					"table1": {
+						Schema: "schema1",
+						Name:   "table1",
+						Attributes: []*AttributeMetadata{
+							{
+								Schema: 
"schema1", + Table: "table1", + Name: "name1", + DataType: "varchar", + }, + { + Schema: "schema1", + Table: "table1", + Name: "name2", + DataType: "decimal", + }, + }, + }, + }, + }, + "schema2": { + Name: "", + Tables: map[string]*TableMetadata{ + "table2": { + Schema: "schema2", + Name: "table2", + Attributes: []*AttributeMetadata{ + { + Schema: "schema2", + Table: "table2", + Name: "name3", + DataType: "int", + }, + { + Schema: "schema2", + Table: "table2", + Name: "name4", + DataType: "timestamp", + }, + }, + }, + }, + }, + }, + } + repo.EXPECT().Introspect(ctx, mock.Anything).Return(&meta, nil) + sampleParams1 := SampleParameters{ + Metadata: meta.Schemas["schema1"].Tables["table1"], + } + sampleParams2 := SampleParameters{ + Metadata: meta.Schemas["schema2"].Tables["table2"], + } + repo.EXPECT().SampleTable(ctx, sampleParams1).Return(table1Sample, nil) + repo.EXPECT().SampleTable(ctx, sampleParams2).Return(table2Sample, nil) + samples, err := sampleDb(ctx, repo, IntrospectParameters{}, 0, 0) + require.NoError(t, err) + // Order is not important and is actually non-deterministic due to concurrency + expected := []Sample{table1Sample, table2Sample} + require.ElementsMatch(t, expected, samples) +} + +func Test_sampleDb_PartialError(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + repo := NewMockRepository(t) + meta := Metadata{ + Database: "database", + Schemas: map[string]*SchemaMetadata{ + "schema1": { + Name: "", + Tables: map[string]*TableMetadata{ + "table1": { + Schema: "schema1", + Name: "table1", + Attributes: []*AttributeMetadata{ + { + Schema: "schema1", + Table: "table1", + Name: "name1", + DataType: "varchar", + }, + { + Schema: "schema1", + Table: "table1", + Name: "name2", + DataType: "decimal", + }, + }, + }, + "forbidden": { + Schema: "schema1", + Name: "forbidden", + Attributes: []*AttributeMetadata{ + { + Schema: "schema1", + Table: "forbidden", + Name: "name1", + DataType: "varchar", + }, + { + 
Schema:   "schema1",
+									Table:    "forbidden",
+									Name:     "name2",
+									DataType: "decimal",
+								},
+							},
+						},
+					},
+				},
+			},
+			"schema2": {
+				Name: "",
+				Tables: map[string]*TableMetadata{
+					"table2": {
+						Schema: "schema2",
+						Name:   "table2",
+						Attributes: []*AttributeMetadata{
+							{
+								Schema:   "schema2",
+								Table:    "table2",
+								Name:     "name3",
+								DataType: "int",
+							},
+							{
+								Schema:   "schema2",
+								Table:    "table2",
+								Name:     "name4",
+								DataType: "timestamp",
+							},
+						},
+					},
+				},
+			},
+		},
+	}
+	repo.EXPECT().Introspect(ctx, mock.Anything).Return(&meta, nil)
+	sampleParams1 := SampleParameters{
+		Metadata: meta.Schemas["schema1"].Tables["table1"],
+	}
+	sampleParams2 := SampleParameters{
+		Metadata: meta.Schemas["schema2"].Tables["table2"],
+	}
+	sampleParamsForbidden := SampleParameters{
+		Metadata: meta.Schemas["schema1"].Tables["forbidden"],
+	}
+	repo.EXPECT().SampleTable(ctx, sampleParams1).Return(table1Sample, nil)
+	errForbidden := errors.New("forbidden table")
+	repo.EXPECT().SampleTable(ctx, sampleParamsForbidden).Return(Sample{}, errForbidden)
+	repo.EXPECT().SampleTable(ctx, sampleParams2).Return(table2Sample, nil)
+
+	samples, err := sampleDb(ctx, repo, IntrospectParameters{}, 0, 0)
+	require.ErrorIs(t, err, errForbidden)
+	// Order is not important and is actually non-deterministic due to concurrency
+	expected := []Sample{table1Sample, table2Sample}
+	require.ElementsMatch(t, expected, samples)
+}
diff --git a/sql/samplealldatabases_test.go b/sql/samplealldatabases_test.go
new file mode 100644
index 0000000..831c7df
--- /dev/null
+++ b/sql/samplealldatabases_test.go
@@ -0,0 +1,182 @@
+package sql
+
+import (
+	"context"
+	"errors"
+	"testing"
+
+	"github.com/stretchr/testify/mock"
+	"github.com/stretchr/testify/require"
+)
+
+func Test_sampleAllDbs_Error(t *testing.T) {
+	ctx := context.Background()
+	listDbErr := errors.New("error listing databases")
+	repo := NewMockRepository(t)
+	repo.EXPECT().ListDatabases(ctx).Return(nil, listDbErr)
+	repo.EXPECT().Close().Return(nil)
+	ctor := 
func(ctx context.Context, cfg RepoConfig) (Repository, error) { + return repo, nil + } + cfg := RepoConfig{} + samples, err := sampleAllDbs(ctx, ctor, cfg, IntrospectParameters{}, 0, 0) + require.Nil(t, samples) + require.ErrorIs(t, err, listDbErr) +} + +func Test_sampleAllDbs_Successful_TwoDatabases(t *testing.T) { + ctx := context.Background() + dbs := []string{"db1", "db2"} + // Dummy metadata returned for each Introspect call + meta := Metadata{ + Database: "db", + Schemas: map[string]*SchemaMetadata{ + "schema": { + Name: "schema", + Tables: map[string]*TableMetadata{ + "table": { + Schema: "schema", + Name: "table", + Attributes: []*AttributeMetadata{ + { + Schema: "schema", + Table: "table", + Name: "attr", + DataType: "string", + }, + }, + }, + }, + }, + }, + } + sample := Sample{ + TablePath: []string{"db", "schema", "table"}, + Results: []SampleResult{ + { + "attr": "foo", + }, + }, + } + repo := NewMockRepository(t) + repo.EXPECT().ListDatabases(ctx).Return(dbs, nil) + repo.EXPECT().Introspect(ctx, mock.Anything).Return(&meta, nil) + repo.EXPECT().SampleTable(ctx, mock.Anything).Return(sample, nil) + repo.EXPECT().Close().Return(nil) + ctor := func(ctx context.Context, cfg RepoConfig) (Repository, error) { + return repo, nil + } + samples, err := sampleAllDbs(ctx, ctor, RepoConfig{}, IntrospectParameters{}, 0, 0) + require.NoError(t, err) + // Two databases should be sampled, and our mock will return the sample for + // each sample call. This really just asserts that we've sampled the correct + // number of times. 
+ require.ElementsMatch(t, samples, []Sample{sample, sample}) +} + +func Test_sampleAllDbs_IntrospectError(t *testing.T) { + ctx := context.Background() + dbs := []string{"db1", "db2"} + introspectErr := errors.New("introspect error") + repo := NewMockRepository(t) + repo.EXPECT().ListDatabases(ctx).Return(dbs, nil) + repo.EXPECT().Introspect(ctx, mock.Anything).Return(nil, introspectErr) + repo.EXPECT().Close().Return(nil) + ctor := func(ctx context.Context, cfg RepoConfig) (Repository, error) { + return repo, nil + } + samples, err := sampleAllDbs(ctx, ctor, RepoConfig{}, IntrospectParameters{}, 0, 0) + require.Empty(t, samples) + require.NoError(t, err) +} + +func Test_sampleAllDbs_SampleError(t *testing.T) { + ctx := context.Background() + dbs := []string{"db1", "db2"} + // Dummy metadata returned for each Introspect call + meta := Metadata{ + Database: "db", + Schemas: map[string]*SchemaMetadata{ + "schema": { + Name: "schema", + Tables: map[string]*TableMetadata{ + "table": { + Schema: "schema", + Name: "table", + Attributes: []*AttributeMetadata{ + { + Schema: "schema", + Table: "table", + Name: "attr", + DataType: "string", + }, + }, + }, + }, + }, + }, + } + sampleErr := errors.New("sample error") + repo := NewMockRepository(t) + repo.EXPECT().ListDatabases(ctx).Return(dbs, nil) + repo.EXPECT().Introspect(ctx, mock.Anything).Return(&meta, nil) + repo.EXPECT().SampleTable(ctx, mock.Anything).Return(Sample{}, sampleErr) + repo.EXPECT().Close().Return(nil) + ctor := func(ctx context.Context, cfg RepoConfig) (Repository, error) { + return repo, nil + } + samples, err := sampleAllDbs(ctx, ctor, RepoConfig{}, IntrospectParameters{}, 0, 0) + require.NoError(t, err) + require.Empty(t, samples) +} + +func Test_sampleAllDbs_TwoDatabases_OneSampleError(t *testing.T) { + ctx := context.Background() + dbs := []string{"db1", "db2"} + // Dummy metadata returned for each Introspect call + meta := Metadata{ + Database: "db", + Schemas: map[string]*SchemaMetadata{ + 
"schema": { + Name: "schema", + Tables: map[string]*TableMetadata{ + "table": { + Schema: "schema", + Name: "table", + Attributes: []*AttributeMetadata{ + { + Schema: "schema", + Table: "table", + Name: "attr", + DataType: "string", + }, + }, + }, + }, + }, + }, + } + sample := Sample{ + TablePath: []string{"db", "schema", "table"}, + Results: []SampleResult{ + { + "attr": "foo", + }, + }, + } + sampleErr := errors.New("sample error") + repo := NewMockRepository(t) + repo.EXPECT().ListDatabases(ctx).Return(dbs, nil) + repo.EXPECT().Introspect(ctx, mock.Anything).Return(&meta, nil) + repo.EXPECT().SampleTable(ctx, mock.Anything).Return(sample, nil).Once() + repo.EXPECT().SampleTable(ctx, mock.Anything).Return(Sample{}, sampleErr).Once() + repo.EXPECT().Close().Return(nil) + ctor := func(ctx context.Context, cfg RepoConfig) (Repository, error) { + return repo, nil + } + samples, err := sampleAllDbs(ctx, ctor, RepoConfig{}, IntrospectParameters{}, 0, 0) + require.NoError(t, err) + // Because of a single sample error, we expect only one database was + // sampled. + require.ElementsMatch(t, samples, []Sample{sample}) +} diff --git a/sql/scanner.go b/sql/scanner.go new file mode 100644 index 0000000..72e7710 --- /dev/null +++ b/sql/scanner.go @@ -0,0 +1,121 @@ +package sql + +import ( + "context" + "fmt" + + "github.com/gobwas/glob" + log "github.com/sirupsen/logrus" + + "github.com/cyralinc/dmap/classification" + "github.com/cyralinc/dmap/scan" +) + +// ScannerConfig is the configuration for the Scanner. +type ScannerConfig struct { + RepoType string + RepoConfig RepoConfig + Registry *Registry + IncludePaths, ExcludePaths []glob.Glob + SampleSize uint + Offset uint +} + +// Scanner is a data discovery scanner that scans a data repository for +// sensitive data. It also classifies the data and publishes the results to +// the configured external sources. It currently only supports SQL-based +// repositories. 
+type Scanner struct { + Config ScannerConfig + labels []classification.Label + classifier classification.Classifier +} + +// RepoScanner implements the scan.RepoScanner interface. +var _ scan.RepoScanner = (*Scanner)(nil) + +// NewScanner creates a new Scanner instance with the provided configuration. +func NewScanner(cfg ScannerConfig) (*Scanner, error) { + if cfg.RepoType == "" { + return nil, fmt.Errorf("repository type not specified") + } + if cfg.Registry == nil { + cfg.Registry = DefaultRegistry + } + // Create a new label classifier with the embedded labels. + lbls, err := classification.GetEmbeddedLabels() + if err != nil { + return nil, fmt.Errorf("error getting embedded labels: %w", err) + } + c, err := classification.NewLabelClassifier(lbls...) + if err != nil { + return nil, fmt.Errorf("error creating new label classifier: %w", err) + } + return &Scanner{Config: cfg, labels: lbls, classifier: c}, nil +} + +// Scan performs the data repository scan. It introspects and samples the +// repository, classifies the sampled data, and publishes the results to the +// configured classification publisher. +func (s *Scanner) Scan(ctx context.Context) (*scan.RepoScanResults, error) { + // Introspect and sample the data repository. + samples, err := s.sample(ctx) + if err != nil { + msg := "error sampling repository" + // If we didn't get any samples, just return the error. + if len(samples) == 0 { + return nil, fmt.Errorf("%s: %w", msg, err) + } + // There were error(s) during sampling, but we still got some samples. + // Just warn and continue. + log.WithError(err).Warn(msg) + } + // Classify the sampled data. 
+ classifications, err := classifySamples(ctx, samples, s.classifier) + if err != nil { + return nil, fmt.Errorf("error classifying samples: %w", err) + } + return &scan.RepoScanResults{ + Labels: s.labels, + Classifications: classifications, + }, nil +} + +func (s *Scanner) sample(ctx context.Context) ([]Sample, error) { + // This closure is used to create a new repository instance for each + // database that is sampled. When there are multiple databases to sample, + // it is passed to sampleAllDbs to create the necessary repository instances + // for each database. When there is only a single database to sample, it is + // used directly below to create the repository instance for that database, + // which is passed to sampleDb to sample the database. + newRepo := func(ctx context.Context, cfg RepoConfig) (Repository, error) { + return s.Config.Registry.NewRepository(ctx, s.Config.RepoType, cfg) + } + introspectParams := IntrospectParameters{ + IncludePaths: s.Config.IncludePaths, + ExcludePaths: s.Config.ExcludePaths, + } + // Check if the user specified a single database, or told us to scan an + // Oracle DB. In that case, therefore we only need to sample that single + // database. Note that Oracle doesn't really have the concept of + // "databases", therefore a single repository instance will always scan the + // entire database. + if s.Config.RepoConfig.Database != "" || s.Config.RepoType == RepoTypeOracle { + repo, err := newRepo(ctx, s.Config.RepoConfig) + if err != nil { + return nil, fmt.Errorf("error creating repository: %w", err) + } + defer func() { _ = repo.Close() }() + return sampleDb(ctx, repo, introspectParams, s.Config.SampleSize, s.Config.Offset) + } + // The name of the database to connect to has been left unspecified by the + // user, so we try to connect and sample all databases instead. 
+ return sampleAllDbs( + ctx, + newRepo, + s.Config.RepoConfig, + introspectParams, + s.Config.SampleSize, + s.Config.Offset, + ) +} diff --git a/sql/scanner_test.go b/sql/scanner_test.go new file mode 100644 index 0000000..e4b317b --- /dev/null +++ b/sql/scanner_test.go @@ -0,0 +1 @@ +package sql diff --git a/discovery/snowflake.go b/sql/snowflake.go similarity index 69% rename from discovery/snowflake.go rename to sql/snowflake.go index cf6d3a0..d788013 100644 --- a/discovery/snowflake.go +++ b/sql/snowflake.go @@ -1,4 +1,4 @@ -package discovery +package sql import ( "context" @@ -23,15 +23,15 @@ WHERE configWarehouse = "warehouse" ) -// SnowflakeRepository is a SQLRepository implementation for Snowflake databases. +// SnowflakeRepository is a Repository implementation for Snowflake databases. type SnowflakeRepository struct { - // The majority of the SQLRepository functionality is delegated to - // a generic SQL repository instance (genericSqlRepo). - genericSqlRepo *GenericRepository + // The majority of the Repository functionality is delegated to + // a generic SQL repository instance. + generic *GenericRepository } -// SnowflakeRepository implements SQLRepository -var _ SQLRepository = (*SnowflakeRepository)(nil) +// SnowflakeRepository implements Repository +var _ Repository = (*SnowflakeRepository)(nil) // NewSnowflakeRepository creates a new SnowflakeRepository. 
func NewSnowflakeRepository(cfg RepoConfig) (*SnowflakeRepository, error) { @@ -53,55 +53,46 @@ func NewSnowflakeRepository(cfg RepoConfig) (*SnowflakeRepository, error) { snowflakeCfg.Role, snowflakeCfg.Warehouse, ) - sqlRepo, err := NewGenericRepository( - cfg.Host, - RepoTypeSnowflake, - database, - connStr, - cfg.MaxOpenConns, - cfg.IncludePaths, - cfg.ExcludePaths, - ) + generic, err := NewGenericRepository(RepoTypeSnowflake, database, connStr, cfg.MaxOpenConns) if err != nil { return nil, fmt.Errorf("could not instantiate generic sql repository: %w", err) } - return &SnowflakeRepository{genericSqlRepo: sqlRepo}, nil + return &SnowflakeRepository{generic: generic}, nil } // ListDatabases returns a list of the names of all databases on the server by // using a Snowflake-specific database query. It delegates the actual work to // GenericRepository.ListDatabasesWithQuery - see that method for more details. func (r *SnowflakeRepository) ListDatabases(ctx context.Context) ([]string, error) { - return r.genericSqlRepo.ListDatabasesWithQuery(ctx, SnowflakeDatabaseQuery) + return r.generic.ListDatabasesWithQuery(ctx, SnowflakeDatabaseQuery) } // Introspect delegates introspection to GenericRepository. See -// SQLRepository.Introspect and GenericRepository.IntrospectWithQuery for more +// Repository.Introspect and GenericRepository.IntrospectWithQuery for more // details. -func (r *SnowflakeRepository) Introspect(ctx context.Context) (*Metadata, error) { - return r.genericSqlRepo.Introspect(ctx) +func (r *SnowflakeRepository) Introspect(ctx context.Context, params IntrospectParameters) (*Metadata, error) { + return r.generic.Introspect(ctx, params) } // SampleTable delegates sampling to GenericRepository. See -// SQLRepository.SampleTable and GenericRepository.SampleTable for more details. +// Repository.SampleTable and GenericRepository.SampleTable for more details. 
func (r *SnowflakeRepository) SampleTable( ctx context.Context, - meta *TableMetadata, params SampleParameters, ) (Sample, error) { - return r.genericSqlRepo.SampleTable(ctx, meta, params) + return r.generic.SampleTable(ctx, params) } -// Ping delegates the ping to GenericRepository. See SQLRepository.Ping and +// Ping delegates the ping to GenericRepository. See Repository.Ping and // GenericRepository.Ping for more details. func (r *SnowflakeRepository) Ping(ctx context.Context) error { - return r.genericSqlRepo.Ping(ctx) + return r.generic.Ping(ctx) } -// Close delegates the close to GenericRepository. See SQLRepository.Close and +// Close delegates the close to GenericRepository. See Repository.Close and // GenericRepository.Close for more details. func (r *SnowflakeRepository) Close() error { - return r.genericSqlRepo.Close() + return r.generic.Close() } // SnowflakeConfig holds Snowflake-specific configuration parameters. diff --git a/discovery/snowflake_test.go b/sql/snowflake_test.go similarity index 87% rename from discovery/snowflake_test.go rename to sql/snowflake_test.go index 6499a11..bfbc97d 100644 --- a/discovery/snowflake_test.go +++ b/sql/snowflake_test.go @@ -1,4 +1,4 @@ -package discovery +package sql import ( "context" @@ -24,6 +24,6 @@ func initSnowflakeRepoTest(t *testing.T) (context.Context, *sql.DB, sqlmock.Sqlm db, mock, err := sqlmock.New() require.NoError(t, err) return ctx, db, mock, &SnowflakeRepository{ - genericSqlRepo: NewGenericRepositoryFromDB("repoName", RepoTypeSnowflake, "dbName", db), + generic: NewGenericRepositoryFromDB(RepoTypeSnowflake, "dbName", db), } } diff --git a/discovery/sqlserver.go b/sql/sqlserver.go similarity index 52% rename from discovery/sqlserver.go rename to sql/sqlserver.go index 15320e7..de8d501 100644 --- a/discovery/sqlserver.go +++ b/sql/sqlserver.go @@ -1,4 +1,4 @@ -package discovery +package sql import ( "context" @@ -21,19 +21,19 @@ const ( SqlServerDatabaseQuery = "SELECT name FROM sys.databases 
WHERE name != 'model' AND name != 'tempdb'" ) -// SqlServerRepository is a SQLRepository implementation for MS SQL Server +// SQLServerRepository is a Repository implementation for MS SQL Server // databases. -type SqlServerRepository struct { - // The majority of the SQLRepository functionality is delegated to a generic - // SQL repository instance (genericSqlRepo). - genericSqlRepo *GenericRepository +type SQLServerRepository struct { + // The majority of the Repository functionality is delegated to a generic + // SQL repository instance. + generic *GenericRepository } -// SqlServerRepository implements SQLRepository -var _ SQLRepository = (*SqlServerRepository)(nil) +// SQLServerRepository implements Repository +var _ Repository = (*SQLServerRepository)(nil) // NewSqlServerRepository creates a new MS SQL Server sql. -func NewSqlServerRepository(cfg RepoConfig) (*SqlServerRepository, error) { +func NewSqlServerRepository(cfg RepoConfig) (*SQLServerRepository, error) { connStr := fmt.Sprintf( "sqlserver://%s:%s@%s:%d", cfg.User, @@ -45,57 +45,48 @@ func NewSqlServerRepository(cfg RepoConfig) (*SqlServerRepository, error) { if cfg.Database != "" { connStr = fmt.Sprintf(connStr+"?database=%s", cfg.Database) } - genericSqlRepo, err := NewGenericRepository( - cfg.Host, - RepoTypeSqlServer, - cfg.Database, - connStr, - cfg.MaxOpenConns, - cfg.IncludePaths, - cfg.IncludePaths, - ) + generic, err := NewGenericRepository(RepoTypeSqlServer, cfg.Database, connStr, cfg.MaxOpenConns) if err != nil { return nil, fmt.Errorf("could not instantiate generic sql repository: %w", err) } - return &SqlServerRepository{genericSqlRepo: genericSqlRepo}, nil + return &SQLServerRepository{generic: generic}, nil } // ListDatabases returns a list of the names of all databases on the server by // using a SQL Server-specific database query. It delegates the actual work to // GenericRepository.ListDatabasesWithQuery - see that method for more details. 
-func (r *SqlServerRepository) ListDatabases(ctx context.Context) ([]string, error) { - return r.genericSqlRepo.ListDatabasesWithQuery(ctx, SqlServerDatabaseQuery) +func (r *SQLServerRepository) ListDatabases(ctx context.Context) ([]string, error) { + return r.generic.ListDatabasesWithQuery(ctx, SqlServerDatabaseQuery) } // Introspect delegates introspection to GenericRepository. See -// SQLRepository.Introspect and GenericRepository.IntrospectWithQuery for more +// Repository.Introspect and GenericRepository.IntrospectWithQuery for more // details. -func (r *SqlServerRepository) Introspect(ctx context.Context) (*Metadata, error) { - return r.genericSqlRepo.Introspect(ctx) +func (r *SQLServerRepository) Introspect(ctx context.Context, params IntrospectParameters) (*Metadata, error) { + return r.generic.Introspect(ctx, params) } // SampleTable delegates sampling to GenericRepository, using a -// SQL Server-specific table sample query. See SQLRepository.SampleTable and +// SQL Server-specific table sample query. See Repository.SampleTable and // GenericRepository.SampleTableWithQuery for more details. -func (r *SqlServerRepository) SampleTable( +func (r *SQLServerRepository) SampleTable( ctx context.Context, - meta *TableMetadata, params SampleParameters, ) (Sample, error) { // Sqlserver uses double-quotes to quote identifiers - attrStr := meta.QuotedAttributeNamesString("\"") - query := fmt.Sprintf(SqlServerSampleQueryTemplate, attrStr, meta.Schema, meta.Name) - return r.genericSqlRepo.SampleTableWithQuery(ctx, meta, query, params.SampleSize) + attrStr := params.Metadata.QuotedAttributeNamesString("\"") + query := fmt.Sprintf(SqlServerSampleQueryTemplate, attrStr, params.Metadata.Schema, params.Metadata.Name) + return r.generic.SampleTableWithQuery(ctx, query, params) } -// Ping delegates the ping to GenericRepository. See SQLRepository.Ping and +// Ping delegates the ping to GenericRepository. See Repository.Ping and // GenericRepository.Ping for more details. 
-func (r *SqlServerRepository) Ping(ctx context.Context) error { - return r.genericSqlRepo.Ping(ctx) +func (r *SQLServerRepository) Ping(ctx context.Context) error { + return r.generic.Ping(ctx) } -// Close delegates the close to GenericRepository. See SQLRepository.Close and +// Close delegates the close to GenericRepository. See Repository.Close and // GenericRepository.Close for more details. -func (r *SqlServerRepository) Close() error { - return r.genericSqlRepo.Close() +func (r *SQLServerRepository) Close() error { + return r.generic.Close() } diff --git a/discovery/sqlserver_test.go b/sql/sqlserver_test.go similarity index 77% rename from discovery/sqlserver_test.go rename to sql/sqlserver_test.go index 16f2ae5..223dfec 100644 --- a/discovery/sqlserver_test.go +++ b/sql/sqlserver_test.go @@ -1,4 +1,4 @@ -package discovery +package sql import ( "context" @@ -19,11 +19,11 @@ func TestSqlServerRepository_ListDatabases(t *testing.T) { require.ElementsMatch(t, []string{"db1", "db2"}, dbs) } -func initSqlServerRepoTest(t *testing.T) (context.Context, *sql.DB, sqlmock.Sqlmock, *SqlServerRepository) { +func initSqlServerRepoTest(t *testing.T) (context.Context, *sql.DB, sqlmock.Sqlmock, *SQLServerRepository) { ctx := context.Background() db, mock, err := sqlmock.New() require.NoError(t, err) - return ctx, db, mock, &SqlServerRepository{ - genericSqlRepo: NewGenericRepositoryFromDB("repoName", RepoTypeSqlServer, "dbName", db), + return ctx, db, mock, &SQLServerRepository{ + generic: NewGenericRepositoryFromDB(RepoTypeSqlServer, "dbName", db), } }