diff --git a/classification/classification.go b/classification/classification.go index 53d8fe8..a93bb7c 100644 --- a/classification/classification.go +++ b/classification/classification.go @@ -8,7 +8,9 @@ package classification import ( "context" - "maps" + "encoding/json" + + "golang.org/x/exp/maps" ) // Classifier is an interface that represents a data classifier. A classifier @@ -22,35 +24,14 @@ type Classifier interface { Classify(ctx context.Context, input map[string]any) (Result, error) } -// ClassifiedTable represents a database table that has been classified. The -// classifications are stored in the Classifications field, which is a map of -// attribute names (i.e. columns) to the set of labels that attributes were -// classified as. -type ClassifiedTable struct { - Repo string `json:"repo"` - Database string `json:"database"` - Schema string `json:"schema"` - Table string `json:"table"` - Classifications Result `json:"classifications"` -} - // Result represents the classifications for a set of data attributes. The key // is the attribute (i.e. column) name and the value is the set of labels // that attribute was classified as. type Result map[string]LabelSet -// Merge combines the given other Result into this Result (the receiver). If -// an attribute from other is already present in this Result, the existing -// labels for that attribute are merged with the labels from other, otherwise -// labels from other for the attribute are simply added to this Result. -func (c Result) Merge(other Result) { - if c == nil { - return - } - for attr, labelSet := range other { - if _, ok := c[attr]; !ok { - c[attr] = make(LabelSet) - } - maps.Copy(c[attr], labelSet) - } +// LabelSet is a set of unique labels. +type LabelSet map[string]struct{} + +func (l LabelSet) MarshalJSON() ([]byte, error) { + return json.Marshal(maps.Keys(l)) } diff --git a/classification/classification_test.go b/classification/classification_test.go deleted file mode 100644 index d6dde27..0000000 --- a/classification/classification_test.go +++ /dev/null @@ -1,59 +0,0 @@ -package classification - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestMerge_WhenCalledOnNilReceiver_ShouldNotPanic(t *testing.T) { - var result Result - require.NotPanics( - t, func() { - result.Merge(Result{"age": {"AGE": {Name: "AGE"}}}) - }, - ) -} - -func TestMerge_WhenCalledWithNonExistingAttributes_ShouldAddThem(t *testing.T) { - result := Result{ - "age": {"AGE": {Name: "AGE"}}, - } - other := Result{ - "social_sec_num": {"SSN": {Name: "SSN"}}, - } - expected := Result{ - "age": {"AGE": {Name: "AGE"}}, - "social_sec_num": {"SSN": {Name: "SSN"}}, - } - result.Merge(other) - require.Equal(t, expected, result) -} - -func TestMerge_WhenCalledWithExistingAttributes_ShouldMergeLabelSets(t *testing.T) { - result := Result{ - "age": {"AGE": {Name: "AGE"}}, - } - other := Result{ - "age": {"CVV": {Name: "CVV"}}, - } - expected := Result{ - "age": {"AGE": {Name: "AGE"}, "CVV": {Name: "CVV"}}, - } - result.Merge(other) - require.Equal(t, expected, result) -} - -func TestMerge_WhenCalledWithExistingAttributes_ShouldOverwrite(t *testing.T) { - result := Result{ - "age": {"AGE": {Name: "AGE", Description: "Foo"}}, - } - other := Result{ - "age": {"AGE": {Name: "AGE", Description: "Bar"}}, - } - expected := Result{ - "age": {"AGE": {Name: "AGE", Description: "Bar"}}, - } - result.Merge(other) - require.Equal(t, expected, result) -} diff --git a/classification/label.go b/classification/label.go index 88702e6..7aa38b5 100644 --- a/classification/label.go +++ b/classification/label.go @@ -6,8 +6,8 @@ import ( "strings" "github.com/open-policy-agent/opa/ast" - "github.com/open-policy-agent/opa/rego" log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" "gopkg.in/yaml.v3" ) @@ -20,18 +20,23 @@ var ( // Label represents a data classification label. type Label struct { - Name string `yaml:"name" json:"name"` - Description string `yaml:"description" json:"description"` - Tags []string `yaml:"tags" json:"tags"` - ClassificationRule *rego.Rego `yaml:"-" json:"-"` + // Name is the name of the label. + Name string `yaml:"name" json:"name"` + // Description is a brief description of the label. + Description string `yaml:"description" json:"description"` + // Tags are a list of arbitrary tags associated with the label. + Tags []string `yaml:"tags" json:"tags"` + // ClassificationRule is the compiled Rego classification rule used to + // classify data. + ClassificationRule *ast.Module `yaml:"-" json:"-"` } // NewLabel creates a new Label with the given name, description, classification // rule, and tags. The classification rule is expected to be the raw Rego code // that will be used to classify data. If the classification rule is invalid, an // error is returned. -func NewLabel(name, description, classificationRule string, tags []string) (Label, error) { - rule, err := ruleRego(classificationRule) +func NewLabel(name, description, classificationRule string, tags ...string) (Label, error) { + rule, err := parseRego(classificationRule) if err != nil { return Label{}, fmt.Errorf("error preparing classification rule for label %s: %w", name, err) } @@ -43,25 +48,12 @@ func NewLabel(name, description, classificationRule string, tags []string) (Labe }, nil } -// LabelSet is a set of unique labels. The key is the label name and the value -// is the label itself. -type LabelSet map[string]Label - -// ToSlice returns the labels in the set as a slice. -func (s LabelSet) ToSlice() []Label { - var labels []Label - for _, label := range s { - labels = append(labels, label) - } - return labels -} - // GetEmbeddedLabels returns the predefined embedded labels and their // classification rules. The labels are read from the embedded labels.yaml file // and the classification rules are read from the embedded Rego files. -func GetEmbeddedLabels() (LabelSet, error) { +func GetEmbeddedLabels() ([]Label, error) { labels := struct { - Labels LabelSet `yaml:"labels"` + Labels map[string]Label `yaml:"labels"` }{} if err := yaml.Unmarshal([]byte(labelsYaml), &labels); err != nil { return nil, fmt.Errorf("error unmarshalling labels.yaml: %w", err) @@ -72,7 +64,7 @@ func GetEmbeddedLabels() (LabelSet, error) { if err != nil { return nil, fmt.Errorf("error reading rego file %s: %w", fname, err) } - rule, err := ruleRego(string(b)) + rule, err := parseRego(string(b)) if err != nil { return nil, fmt.Errorf("error preparing classification rule for label %s: %w", lbl.Name, err) } @@ -80,16 +72,14 @@ func GetEmbeddedLabels() (LabelSet, error) { lbl.ClassificationRule = rule labels.Labels[name] = lbl } - return labels.Labels, nil + return maps.Values(labels.Labels), nil } -func ruleRego(code string) (*rego.Rego, error) { +func parseRego(code string) (*ast.Module, error) { log.Tracef("classifier module code: '%s'", code) - moduleName := "classifier" - compiledRego, err := ast.CompileModules(map[string]string{moduleName: code}) + module, err := ast.ParseModule("classifier", code) if err != nil { - return nil, fmt.Errorf("error compiling rego code: %w", err) + return nil, fmt.Errorf("error parsing rego code: %w", err) } - regoQuery := compiledRego.Modules[moduleName].Package.Path.String() + ".output" - return rego.New(rego.Query(regoQuery), rego.Compiler(compiledRego)), nil + return module, nil } diff --git a/classification/label_classifier.go b/classification/label_classifier.go index 37fcc29..c7a23b9 100644 --- a/classification/label_classifier.go +++ b/classification/label_classifier.go @@ -11,37 +11,41 @@ import ( // LabelClassifier is a Classifier implementation that uses a set of labels and // their classification rules to classify data. type LabelClassifier struct { - labels LabelSet + queries map[string]*rego.Rego } // LabelClassifier implements Classifier var _ Classifier = (*LabelClassifier)(nil) -// NewLabelClassifier creates a new LabelClassifier with the provided labels and -// classification rules. +// NewLabelClassifier creates a new LabelClassifier with the provided labels. +// func NewLabelClassifier(labels ...Label) (*LabelClassifier, error) { if len(labels) == 0 { return nil, fmt.Errorf("labels cannot be empty") } - l := make(LabelSet, len(labels)) + queries := make(map[string]*rego.Rego, len(labels)) for _, lbl := range labels { - l[lbl.Name] = lbl + queries[lbl.Name] = rego.New( + // We only care about the 'output' variable. + rego.Query(lbl.ClassificationRule.Package.Path.String() + ".output"), + rego.ParsedModule(lbl.ClassificationRule), + ) } - return &LabelClassifier{labels: l}, nil + return &LabelClassifier{queries: queries}, nil } // Classify performs the classification of the provided attributes using the -// classifier's labels and classification rules. It returns a Result, which is -// a map of attribute names to the set of labels that the attribute was -// classified as. +// classifier's labels and their corresponding classification rules. It returns +// a Result, which is a map of attribute names to the set of labels that the +// attribute was classified as. func (c *LabelClassifier) Classify(ctx context.Context, input map[string]any) (Result, error) { - result := make(Result, len(c.labels)) - for _, lbl := range c.labels { - output, err := evalQuery(ctx, lbl.ClassificationRule, input) + result := make(Result, len(c.queries)) + for lbl, query := range c.queries { + output, err := evalQuery(ctx, query, input) if err != nil { - return nil, fmt.Errorf("error evaluating query for label %s: %w", lbl.Name, err) + return nil, fmt.Errorf("error evaluating query for label %s: %w", lbl, err) } - log.Debugf("classification results for label %s: %v", lbl.Name, output) + log.Debugf("classification results for label %s: %v", lbl, output) for attrName, classified := range output { if classified { attrLabels, ok := result[attrName] @@ -49,21 +53,21 @@ func (c *LabelClassifier) Classify(ctx context.Context, input map[string]any) (R attrLabels = make(LabelSet) result[attrName] = attrLabels } - attrLabels[lbl.Name] = lbl + // Add the label to the set of labels for the attribute. + attrLabels[lbl] = struct{}{} } } } return result, nil } -// evalQuery evaluates the provided prepared Rego query with the given -// attributes as input, and returns the classification results. The output is a +// evalQuery evaluates the provided Rego query with the given attributes as input, and returns the classification results. The output is a // map of attribute names to boolean values, where the boolean indicates whether // the attribute is classified as belonging to the label. -func evalQuery(ctx context.Context, rule *rego.Rego, input map[string]any) (map[string]bool, error) { - q, err := rule.PrepareForEval(ctx) +func evalQuery(ctx context.Context, query *rego.Rego, input map[string]any) (map[string]bool, error) { + q, err := query.PrepareForEval(ctx) if err != nil { - return nil, fmt.Errorf("error preparing rule for evaluation: %w", err) + return nil, fmt.Errorf("error preparing query for evaluation: %w", err) } // Evaluate the prepared Rego query. This performs the actual classification // logic. diff --git a/classification/label_classifier_test.go b/classification/label_classifier_test.go index abea713..50fd1e0 100644 --- a/classification/label_classifier_test.go +++ b/classification/label_classifier_test.go @@ -9,7 +9,9 @@ import ( ) func TestNewLabelClassifier_Success(t *testing.T) { - classifier, err := NewLabelClassifier(Label{Name: "foo"}) + lbl, err := NewLabel("foo", "test label", "package foo\noutput = true") + require.NoError(t, err) + classifier, err := NewLabelClassifier(lbl) require.NoError(t, err) require.NotNil(t, classifier) } @@ -40,7 +42,7 @@ func TestLabelClassifier_Classify(t *testing.T) { input: map[string]any{"age": "42"}, want: Result{ "age": { - "AGE": Label{Name: "AGE"}, + "AGE": {}, }, }, }, @@ -53,7 +55,7 @@ func TestLabelClassifier_Classify(t *testing.T) { }, want: Result{ "age": { - "AGE": Label{Name: "AGE"}, + "AGE": {}, }, }, }, @@ -63,7 +65,7 @@ func TestLabelClassifier_Classify(t *testing.T) { input: map[string]any{"age": "42"}, want: Result{ "age": { - "AGE": Label{Name: "AGE"}, + "AGE": {}, }, }, }, @@ -76,10 +78,10 @@ func TestLabelClassifier_Classify(t *testing.T) { }, want: Result{ "age": { - "AGE": Label{Name: "AGE"}, + "AGE": {}, }, "ccn": { - "CCN": Label{Name: "CCN"}, + "CCN": {}, }, }, }, @@ -92,11 +94,11 @@ func TestLabelClassifier_Classify(t *testing.T) { }, want: Result{ "age": { - "AGE": Label{Name: "AGE"}, - "CVV": Label{Name: "CVV"}, + "AGE": {}, + "CVV": {}, }, "cvv": { - "CVV": Label{Name: "CVV"}, + "CVV": {}, }, }, }, @@ -127,10 +129,9 @@ func requireResultEqual(t *testing.T, want, got Result) { func requireLabelSetEqual(t *testing.T, want, got LabelSet) { require.Len(t, got, len(want)) - for k, v := range want { - gotLbl, ok := got[k] + for k := range want { + _, ok := got[k] require.Truef(t, ok, "missing label %s", k) - require.Equal(t, v.Name, gotLbl.Name) } } @@ -149,7 +150,7 @@ func newTestLabel(t *testing.T, lblName string) Label { fin, err := regoFs.ReadFile(fname) require.NoError(t, err) classifierCode := string(fin) - lbl, err := NewLabel(lblName, "test label", classifierCode, nil) + lbl, err := NewLabel(lblName, "test label", classifierCode) require.NoError(t, err) return lbl } diff --git a/classification/label_test.go b/classification/label_test.go index 8b718f2..35a0b69 100644 --- a/classification/label_test.go +++ b/classification/label_test.go @@ -1,8 +1,11 @@ package classification import ( + "context" + "fmt" "testing" + "github.com/open-policy-agent/opa/rego" "github.com/stretchr/testify/require" ) @@ -11,3 +14,52 @@ func TestGetEmbeddedLabels(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, got) } + +func TestRego(t *testing.T) { + + module := ` +package example.authz + +import rego.v1 + +default allow := false + +allow if { + input.method == "GET" + input.path == ["salary", input.subject.user] +} + +allow if is_admin + +is_admin if "admin" in input.subject.groups +` + + mod, err := parseRego(module) + require.NoError(t, err) + require.NotNil(t, mod) + path := mod.Package.Path.String() + fmt.Println(path) + + ctx := context.TODO() + + query, err := rego.New( + rego.Query("data.example.authz.allow"), + rego.Module("example.rego", module), + ).PrepareForEval(ctx) + require.NoError(t, err) + require.NotNil(t, query) + + input := map[string]interface{}{ + "method": "GET", + "path": []interface{}{"salary", "bob"}, + "subject": map[string]interface{}{ + "user": "bob", + "groups": []interface{}{"sales", "marketing"}, + }, + } + + results, err := query.Eval(ctx, rego.EvalInput(input)) + require.NoError(t, err) + require.NotEmpty(t, results) + require.True(t, results.Allowed()) +} diff --git a/classification/publisher.go b/classification/publisher.go deleted file mode 100644 index aed92f4..0000000 --- a/classification/publisher.go +++ /dev/null @@ -1,15 +0,0 @@ -package classification - -import ( - "context" -) - -// Publisher publishes classification and discovery results to some destination, -// which is left up to the implementer. -// TODO: add doc about labels -ccampo 2024-04-02 -type Publisher interface { - // PublishClassifications publishes a slice of ClassifiedTable to some - // destination. Any error(s) during publication should be returned. - // TODO: add labels -ccampo 2024-04-02 - PublishClassifications(ctx context.Context, repoId string, results []ClassifiedTable) error -} diff --git a/classification/stdout.go b/classification/stdout.go deleted file mode 100644 index 883dc7e..0000000 --- a/classification/stdout.go +++ /dev/null @@ -1,33 +0,0 @@ -package classification - -import ( - "context" - "encoding/json" - "fmt" -) - -// StdOutPublisher "publishes" classification results to stdout in JSON format. -type StdOutPublisher struct{} - -// StdOutPublisher implements Publisher -var _ Publisher = (*StdOutPublisher)(nil) - -// TODO: godoc -ccampo 2024-04-02 -func NewStdOutPublisher() *StdOutPublisher { - return &StdOutPublisher{} -} - -// TODO: godoc -ccampo 2024-04-02 -func (c *StdOutPublisher) PublishClassifications(_ context.Context, _ string, tables []ClassifiedTable) error { - results := struct { - Results []ClassifiedTable `json:"results"` - }{ - Results: tables, - } - b, err := json.MarshalIndent(results, "", " ") - if err != nil { - return fmt.Errorf("failed to marshal results: %w", err) - } - fmt.Println(string(b)) - return nil -} diff --git a/cmd/main.go b/cmd/main.go index 632b9c3..fee5fd2 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -8,9 +8,12 @@ import ( ) type Globals struct { - LogLevel logLevelFlag `help:"Set the logging level (trace|debug|info|warn|error|fatal)" enum:"trace,debug,info,warn,error,fatal" default:"info"` - LogFormat logFormatFlag `help:"Set the logging format (text|json)" enum:"text,json" default:"text"` - Version kong.VersionFlag `name:"version" help:"Print version information and quit"` + LogLevel logLevelFlag `help:"Set the logging level (trace|debug|info|warn|error|fatal)" enum:"trace,debug,info,warn,error,fatal" default:"info"` + LogFormat logFormatFlag `help:"Set the logging format (text|json)" enum:"text,json" default:"text"` + Version kong.VersionFlag `name:"version" help:"Print version information and quit"` + ApiBaseUrl string `help:"Base URL of the Dmap API." default:"https://api.dmap.cyral.io"` + ClientID string `help:"API client ID to access the Dmap API."` + ClientSecret string `help:"API client secret to access the Dmap API."` //#nosec G101 -- false positive } type logLevelFlag string diff --git a/cmd/repo_scan.go b/cmd/repo_scan.go index b5830d9..d033e13 100644 --- a/cmd/repo_scan.go +++ b/cmd/repo_scan.go @@ -2,21 +2,87 @@ package main import ( "context" + "encoding/json" "fmt" + "reflect" + "strings" + "github.com/alecthomas/kong" + "github.com/gobwas/glob" + + "github.com/cyralinc/dmap/discovery" "github.com/cyralinc/dmap/scan" ) type RepoScanCmd struct { - scan.RepoScannerConfig + Type string `help:"Type of repository to connect to (postgres|mysql|oracle|sqlserver|snowflake|redshift|denodo)." enum:"postgres,mysql,oracle,sqlserver,snowflake,redshift,denodo" required:""` + ExternalID string `help:"External ID of the repository." required:""` + Host string `help:"Hostname of the repository." required:""` + Port uint16 `help:"Port of the repository." required:""` + User string `help:"Username to connect to the sql." required:""` + Password string `help:"Password to connect to the sql." required:""` + Database string `help:"Name of the database to connect to. If not specified, the default database is used (if possible)."` + Advanced map[string]any `help:"Advanced configuration for the sql."` + IncludePaths GlobFlag `help:"List of glob patterns to include when introspecting the database(s)." default:"*"` + ExcludePaths GlobFlag `help:"List of glob patterns to exclude when introspecting the database(s)." default:""` + MaxOpenConns uint `help:"Maximum number of open connections to the database." default:"10"` + SampleSize uint `help:"Number of rows to sample from the repository (per table)." default:"5"` + Offset uint `help:"Offset to start sampling from." default:"0"` +} + +// GlobFlag is a kong.MapperValue implementation that represents a glob pattern. +type GlobFlag []glob.Glob + +// Decode parses the glob patterns and compiles them into glob.Glob objects. It +// is an implementation of kong.MapperValue's Decode method. +func (g GlobFlag) Decode(ctx *kong.DecodeContext) error { + var patterns string + if err := ctx.Scan.PopValueInto("string", &patterns); err != nil { + return err + } + var parsedPatterns []glob.Glob + for _, pattern := range strings.Split(patterns, ",") { + parsedPattern, err := glob.Compile(pattern) + if err != nil { + return fmt.Errorf("cannot compile %s pattern: %w", pattern, err) + } + parsedPatterns = append(parsedPatterns, parsedPattern) + } + ctx.Value.Target.Set(reflect.ValueOf(GlobFlag(parsedPatterns))) + return nil } func (cmd *RepoScanCmd) Run(_ *Globals) error { ctx := context.Background() - scanner, err := scan.NewRepoScanner(ctx, cmd.RepoScannerConfig) + cfg := scan.RepoScannerConfig{ + RepoType: cmd.Type, + RepoConfig: discovery.RepoConfig{ + Host: cmd.Host, + Port: cmd.Port, + User: cmd.User, + Password: cmd.Password, + Database: cmd.Database, + MaxOpenConns: cmd.MaxOpenConns, + Advanced: cmd.Advanced, + }, + IncludePaths: cmd.IncludePaths, + ExcludePaths: cmd.ExcludePaths, + SampleSize: cmd.SampleSize, + Offset: cmd.Offset, + } + scanner, err := scan.NewRepoScanner(cfg) if err != nil { return fmt.Errorf("error creating new scanner: %w", err) } - defer scanner.Cleanup() - return scanner.Scan(ctx) + results, err := scanner.Scan(ctx) + if err != nil { + return fmt.Errorf("error scanning repository: %w", err) + } + jsonResults, err := json.MarshalIndent(results, "", " ") + if err != nil { + return fmt.Errorf("error marshalling results: %w", err) + } + fmt.Println(string(jsonResults)) + // TODO: publish results to API -ccampo 2024-04-03 + return nil } diff --git a/discovery/config.go b/discovery/config.go index 925c31b..2145b5a 100644 --- a/discovery/config.go +++ b/discovery/config.go @@ -2,87 +2,26 @@ package discovery import ( "fmt" - "reflect" - "strings" - - "github.com/alecthomas/kong" - "github.com/gobwas/glob" ) const configConnOpts = "connection-string-args" // RepoConfig is the necessary configuration to connect to a data sql. type RepoConfig struct { - Type string `help:"Type of repository to connect to (postgres|mysql|oracle|sqlserver|snowflake|redshift|denodo)." enum:"postgres,mysql,oracle,sqlserver,snowflake,redshift,denodo" required:""` - Host string `help:"Hostname of the sql." required:""` - Port uint16 `help:"Port of the sql." required:""` - User string `help:"Username to connect to the sql." required:""` - Password string `help:"Password to connect to the sql." required:""` - Advanced map[string]any `help:"Advanced configuration for the sql."` - Database string `help:"Name of the database to connect to. If not specified, the default database is used."` - MaxOpenConns uint `help:"Maximum number of open connections to the sql." default:"10"` - SampleSize uint `help:"Number of rows to sample from the repository (per table)." default:"5"` - IncludePaths GlobFlag `help:"List of glob patterns to include when querying the database(s), as a comma separated list." default:"*"` - ExcludePaths GlobFlag `help:"List of glob patterns to exclude when querying the database(s), as a comma separated list." default:""` -} - -// GlobFlag is a kong.MapperValue implementation that represents a glob pattern. -type GlobFlag []glob.Glob - -// Decode parses the glob patterns and compiles them into glob.Glob objects. It -// is an implementation of kong.MapperValue's Decode method. -func (g GlobFlag) Decode(ctx *kong.DecodeContext) error { - var patterns string - if err := ctx.Scan.PopValueInto("string", &patterns); err != nil { - return err - } - var parsedPatterns []glob.Glob - for _, pattern := range strings.Split(patterns, ",") { - parsedPattern, err := glob.Compile(pattern) - if err != nil { - return fmt.Errorf("cannot compile %s pattern: %w", pattern, err) - } - parsedPatterns = append(parsedPatterns, parsedPattern) - } - ctx.Value.Target.Set(reflect.ValueOf(GlobFlag(parsedPatterns))) - return nil -} - -// BuildConnOptsStr parses the repo config to produce a string in the format -// "?option=value&option2=value2". Example: -// -// BuildConnOptsStr(RepoConfig{ -// Advanced: map[string]any{ -// "connection-string-args": []any{"sslmode=disable"}, -// }, -// }) -// -// returns ("?sslmode=disable", nil). -func BuildConnOptsStr(cfg RepoConfig) (string, error) { - connOptsMap, err := mapFromConnOpts(cfg) - if err != nil { - return "", fmt.Errorf("connection options: %w", err) - } - - connOptsStr := "" - for key, val := range connOptsMap { - // Don't add if the value is empty, since that would make the - // string malformed. - if val != "" { - if connOptsStr == "" { - connOptsStr += fmt.Sprintf("%s=%s", key, val) - } else { - // Need & for subsequent options - connOptsStr += fmt.Sprintf("&%s=%s", key, val) - } - } - } - // Only add ? if connection string is not empty - if connOptsStr != "" { - connOptsStr = "?" + connOptsStr - } - - return connOptsStr, nil + // Host is the hostname of the database. + Host string + // Port is the port of the database. + Port uint16 + // User is the username to connect to the database. + User string + // Password is the password to connect to the database. + Password string + // Database is the name of the database to connect to. + Database string + // MaxOpenConns is the maximum number of open connections to the database. + MaxOpenConns uint + // Advanced is a map of advanced configuration options. + Advanced map[string]any } // FetchAdvancedConfigString fetches a map in the repo advanced configuration, @@ -129,38 +68,6 @@ func FetchAdvancedConfigString( return repoSpecificMap, nil } -// mapFromConnOpts builds a map from the list of connection options given in the -// Each option has the format 'option=value'. Err only if the config is -// malformed, to inform user. -func mapFromConnOpts(cfg RepoConfig) (map[string]string, error) { - m := make(map[string]string) - connOptsInterface, ok := cfg.Advanced[configConnOpts] - if !ok { - return nil, nil - } - connOpts, ok := connOptsInterface.([]any) - if !ok { - return nil, fmt.Errorf("'%s' is not a list", configConnOpts) - } - for _, optInterface := range connOpts { - opt, ok := optInterface.(string) - if !ok { - return nil, fmt.Errorf("'%v' is not a string", optInterface) - } - splitOpt := strings.Split(opt, "=") - if len(splitOpt) != 2 { - return nil, fmt.Errorf( - "malformed '%s'. "+ - "Please follow the format 'option=value'", configConnOpts, - ) - } - key := splitOpt[0] - val := splitOpt[1] - m[key] = val - } - return m, nil -} - // getAdvancedConfig gets the Advanced field in a repo config and converts it to // a map[string]any. In every step, it checks for error and generates // nice messages. diff --git a/discovery/config_test.go b/discovery/config_test.go index c54ea99..faac440 100644 --- a/discovery/config_test.go +++ b/discovery/config_test.go @@ -6,79 +6,6 @@ import ( "github.com/stretchr/testify/require" ) -// Returns a correct repo config -func getSampleRepoConfig() RepoConfig { - return RepoConfig{ - Advanced: map[string]any{ - configConnOpts: []any{"sslmode=disable"}, - }, - } -} - -func TestBuildConnOptionsSucc(t *testing.T) { - sampleRepoCfg := getSampleRepoConfig() - connOptsStr, err := BuildConnOptsStr(sampleRepoCfg) - require.NoError(t, err) - require.Equal(t, connOptsStr, "?sslmode=disable") -} - -func TestBuildConnOptionsFail(t *testing.T) { - invalidRepoCfg := RepoConfig{ - Advanced: map[string]any{ - // Invalid: map instead of string - configConnOpts: []any{ - map[string]string{"sslmode": "disable"}, - }, - }, - } - connOptsStr, err := BuildConnOptsStr(invalidRepoCfg) - require.Error(t, err) - require.Empty(t, connOptsStr) -} - -func TestMapConnOptionsSucc(t *testing.T) { - sampleRepoCfg := getSampleRepoConfig() - connOptsMap, err := mapFromConnOpts(sampleRepoCfg) - require.NoError(t, err) - require.EqualValues( - t, connOptsMap, map[string]string{ - "sslmode": "disable", - }, - ) -} - -// The mapping should only fail if the config is malformed, not if it is missing -func TestMapConnOptionsMissing(t *testing.T) { - sampleCfg := RepoConfig{} - optsMap, err := mapFromConnOpts(sampleCfg) - require.NoError(t, err) - require.Empty(t, optsMap) -} - -func TestMapConnOptionsMalformedMap(t *testing.T) { - sampleCfg := RepoConfig{ - Advanced: map[string]any{ - // Let's put a map instead of the required list - configConnOpts: map[string]any{ - "testKey": "testValue", - }, - }, - } - _, err := mapFromConnOpts(sampleCfg) - require.Error(t, err) -} - -func TestMapConnOptionsMalformedColon(t *testing.T) { - sampleCfg := RepoConfig{ - Advanced: map[string]any{ - // Let's use a colon instead of '=' to divide options - configConnOpts: []string{"sslmode:disable"}, - }, - } - _, err := mapFromConnOpts(sampleCfg) - require.Error(t, err) -} - func TestAdvancedConfigSucc(t *testing.T) { sampleCfg := RepoConfig{ Advanced: map[string]any{ diff --git a/discovery/denodo.go b/discovery/denodo.go index 2075db5..259cc8a 100644 --- a/discovery/denodo.go +++ b/discovery/denodo.go @@ -24,11 +24,11 @@ const ( "CATALOG_VDP_METADATA_VIEWS()" ) -// DenodoRepository is a sql.SQLRepository implementation for Denodo. +// DenodoRepository is a SQLRepository implementation for Denodo. type DenodoRepository struct { - // The majority of the sql.SQLRepository functionality is delegated to + // The majority of the SQLRepository functionality is delegated to // a generic SQL repository instance. - genericSqlRepo *GenericRepository + generic *GenericRepository } // DenodoRepository implements sql.SQLRepository @@ -36,7 +36,7 @@ var _ SQLRepository = (*DenodoRepository)(nil) // NewDenodoRepository is the constructor for sql. func NewDenodoRepository(cfg RepoConfig) (*DenodoRepository, error) { - pgCfg, err := ParsePostgresConfig(cfg) + pgCfg, err := parsePostgresConfig(cfg) if err != nil { return nil, fmt.Errorf("unable to parse postgres config: %w", err) } @@ -52,19 +52,11 @@ func NewDenodoRepository(cfg RepoConfig) (*DenodoRepository, error) { cfg.Database, pgCfg.ConnOptsStr, ) - sqlRepo, err := NewGenericRepository( - cfg.Host, - RepoTypePostgres, - cfg.Database, - connStr, - cfg.MaxOpenConns, - cfg.IncludePaths, - cfg.ExcludePaths, - ) + generic, err := NewGenericRepository(RepoTypePostgres, cfg.Database, connStr, cfg.MaxOpenConns) if err != nil { return nil, fmt.Errorf("could not instantiate generic sql repository: %w", err) } - return &DenodoRepository{genericSqlRepo: sqlRepo}, nil + return &DenodoRepository{generic: generic}, nil } // ListDatabases is left unimplemented for Denodo, because Denodo doesn't have @@ -76,8 +68,8 @@ func (r *DenodoRepository) ListDatabases(_ context.Context) ([]string, error) { // Introspect delegates introspection to GenericRepository. See // SQLRepository.Introspect and GenericRepository.IntrospectWithQuery for more // details. -func (r *DenodoRepository) Introspect(ctx context.Context) (*Metadata, error) { - return r.genericSqlRepo.IntrospectWithQuery(ctx, DenodoIntrospectQuery) +func (r *DenodoRepository) Introspect(ctx context.Context, params IntrospectParameters) (*Metadata, error) { + return r.generic.IntrospectWithQuery(ctx, DenodoIntrospectQuery, params) } // SampleTable delegates sampling to GenericRepository, using a Denodo-specific @@ -85,30 +77,29 @@ func (r *DenodoRepository) Introspect(ctx context.Context) (*Metadata, error) { // GenericRepository.SampleTableWithQuery for more details. func (r *DenodoRepository) SampleTable( ctx context.Context, - meta *TableMetadata, params SampleParameters, ) (Sample, error) { // Denodo uses double-quotes to quote identifiers - attrStr := meta.QuotedAttributeNamesString("\"") + attrStr := params.Metadata.QuotedAttributeNamesString("\"") // The postgres driver is currently unable to properly send the // parameters of a prepared statement to Denodo. Therefore, instead of // building a prepared statement, we populate the query string before // sending it to the driver. query := fmt.Sprintf( "SELECT %s FROM %s.%s OFFSET %d ROWS LIMIT %d", - attrStr, meta.Schema, meta.Name, params.Offset, params.SampleSize, + attrStr, params.Metadata.Schema, params.Metadata.Name, params.Offset, params.SampleSize, ) - return r.genericSqlRepo.SampleTableWithQuery(ctx, meta, query) + return r.generic.SampleTableWithQuery(ctx, query, params) } // Ping delegates the ping to GenericRepository. See SQLRepository.Ping and // GenericRepository.Ping for more details. func (r *DenodoRepository) Ping(ctx context.Context) error { - return r.genericSqlRepo.Ping(ctx) + return r.generic.Ping(ctx) } // Close delegates the close to GenericRepository. See SQLRepository.Close and // GenericRepository.Close for more details. func (r *DenodoRepository) Close() error { - return r.genericSqlRepo.Close() + return r.generic.Close() } diff --git a/discovery/doc.go b/discovery/doc.go index 999e7e2..cdc59d4 100644 --- a/discovery/doc.go +++ b/discovery/doc.go @@ -1,7 +1,5 @@ // Package discovery provides mechanisms to perform database introspection and -// data discovery on various data repositories. It provides a RepoScanner type that -// can be used to scan a data repository for sensitive data, classify the data, -// and publish the results to external sources. +// data discovery on various SQL data repositories. // // Additionally, the SQLRepository interface provides an API for performing // database introspection and data discovery on SQL databases. It encapsulates diff --git a/discovery/gen.go b/discovery/gen.go deleted file mode 100644 index b295816..0000000 --- a/discovery/gen.go +++ /dev/null @@ -1,5 +0,0 @@ -package discovery - -// Mock generation - see https://vektra.github.io/mockery/ - -//go:generate mockery --testonly --inpackage --with-expecter --name=Repository --filename=mock_repository_test.go diff --git a/discovery/generic.go b/discovery/generic.go index 92db79c..6c80ee5 100644 --- a/discovery/generic.go +++ b/discovery/generic.go @@ -46,12 +46,9 @@ const ( // ListDatabasesWithQuery method, which dependent implementations can use to // provide a custom query to list databases. type GenericRepository struct { - repoName string - repoType string - database string - db *sql.DB - includePaths []glob.Glob - excludePaths []glob.Glob + repoType string + database string + db *sql.DB } var _ SQLRepository = (*GenericRepository)(nil) @@ -64,15 +61,7 @@ var _ SQLRepository = (*GenericRepository)(nil) // connections to the database. The repoIncludePaths and repoExcludePaths // parameters are used to filter the tables and columns that are introspected by // the repository. -func NewGenericRepository( - repoName, - repoType, - database, - connStr string, - maxOpenConns uint, - repoIncludePaths, - repoExcludePaths []glob.Glob, -) ( +func NewGenericRepository(repoType, database, connStr string, maxOpenConns uint) ( *GenericRepository, error, ) { @@ -81,20 +70,16 @@ func NewGenericRepository( return nil, fmt.Errorf("error retrieving DB handle for repo type %s: %w", repoType, err) } return &GenericRepository{ - repoName: repoName, - repoType: repoType, - database: database, - db: db, - includePaths: repoIncludePaths, - excludePaths: repoExcludePaths, + repoType: repoType, + database: database, + db: db, }, nil } // NewGenericRepositoryFromDB instantiate a new GenericRepository based on a // given sql.DB handle. -func NewGenericRepositoryFromDB(repoName, repoType, database string, db *sql.DB) *GenericRepository { +func NewGenericRepositoryFromDB(repoType, database string, db *sql.DB) *GenericRepository { return &GenericRepository{ - repoName: repoName, repoType: repoType, database: database, db: db, @@ -139,8 +124,11 @@ func (r *GenericRepository) ListDatabasesWithQuery( } // Introspect calls IntrospectWithQuery with a default query string -func (r *GenericRepository) Introspect(ctx context.Context) (*Metadata, error) { - return r.IntrospectWithQuery(ctx, GenericIntrospectQuery) +func (r *GenericRepository) Introspect( + ctx context.Context, + params IntrospectParameters, +) (*Metadata, error) { + return r.IntrospectWithQuery(ctx, GenericIntrospectQuery, params) } // IntrospectWithQuery executes a query against the information_schema table in @@ -156,22 +144,15 @@ func (r *GenericRepository) Introspect(ctx context.Context) (*Metadata, error) { func (r *GenericRepository) IntrospectWithQuery( ctx context.Context, query string, + params IntrospectParameters, ) (*Metadata, error) { log.Tracef("Query: %s", query) rows, err := r.db.QueryContext(ctx, query) if err != nil { - return nil, err + return nil, fmt.Errorf("error performing introspect query: %w", err) } defer func() { _ = rows.Close() }() - - repoMeta, err := newMetadataFromQueryResult( - r.repoType, r.repoName, - r.database, r.includePaths, r.excludePaths, rows, - ) - if err != nil { - return nil, err - } - return repoMeta, nil + return newMetadataFromQueryResult(r.database, params.IncludePaths, params.ExcludePaths, rows) } // SampleTable samples the table referenced by the TableMetadata meta parameter @@ -180,40 +161,38 @@ func (r *GenericRepository) IntrospectWithQuery( // SQLRepository.SampleTable for more details. func (r *GenericRepository) SampleTable( ctx context.Context, - meta *TableMetadata, params SampleParameters, ) (Sample, error) { // ANSI SQL uses double-quotes to quote identifiers - attrStr := meta.QuotedAttributeNamesString("\"") - query := fmt.Sprintf(GenericSampleQueryTemplate, attrStr, meta.Schema, meta.Name) - return r.SampleTableWithQuery(ctx, meta, query, params.SampleSize, params.Offset) + attrStr := params.Metadata.QuotedAttributeNamesString("\"") + query := fmt.Sprintf(GenericSampleQueryTemplate, attrStr, params.Metadata.Schema, params.Metadata.Name) + return r.SampleTableWithQuery(ctx, query, params) } // SampleTableWithQuery calls SampleTable with a custom SQL query. Any // placeholder parameters in the query should be passed via params. func (r *GenericRepository) SampleTableWithQuery( ctx context.Context, - meta *TableMetadata, query string, - params ...any, + params SampleParameters, ) (Sample, error) { log.Tracef("Query: %s", query) - rows, err := r.db.QueryContext(ctx, query, params...) + rows, err := r.db.QueryContext(ctx, query, params.SampleSize, params.Offset) if err != nil { return Sample{}, - fmt.Errorf("error sampling %s.%s.%s: %w", r.database, meta.Schema, meta.Name, err) + fmt.Errorf( + "error sampling database %s, schema %s, table %s: %w", + r.database, + params.Metadata.Schema, + params.Metadata.Name, + err, + ) } defer func() { _ = rows.Close() }() - sample := Sample{ - Metadata: SampleMetadata{ - Repo: r.repoName, - Database: r.database, - Schema: meta.Schema, - Table: meta.Name, - }, + TablePath: []string{r.database, params.Metadata.Schema, params.Metadata.Name}, } - + // Iterate the row set and append each row to the sample results. for rows.Next() { data, err := getCurrentRowAsMap(rows) if err != nil { @@ -221,13 +200,13 @@ func (r *GenericRepository) SampleTableWithQuery( } sample.Results = append(sample.Results, data) } - - // Something broke while iterating the row set - err = rows.Err() - if err != nil { + if err := rows.Err(); err != nil { + // Something broke while iterating the row set. return Sample{}, fmt.Errorf("error iterating sample data row set: %w", err) } - + if len(sample.Results) == 0 { + return Sample{}, nil + } return sample, nil } @@ -265,61 +244,6 @@ func newDbHandle(repoType, connStr string, maxOpenConns uint) (*sql.DB, error) { return db, nil } -// newMetadataFromQueryResult builds the repository metadata from the results -// of a query to the INFORMATION_SCHEMA columns view. -func newMetadataFromQueryResult( - repoType, repoName, db string, - includePaths, excludePaths []glob.Glob, rows *sql.Rows, -) ( - *Metadata, - error, -) { - repo := NewMetadata(repoType, repoName, db) - - for rows.Next() { - var attr AttributeMetadata - err := rows.Scan(&attr.Schema, &attr.Table, &attr.Name, &attr.DataType) - if err != nil { - return nil, err - } - - // skip tables that match excludePaths or does not match includePaths - log.Tracef("checking if %s.%s.%s matches excludePaths %s\n", db, attr.Schema, attr.Table, excludePaths) - if matchPathPatterns(db, attr.Schema, attr.Table, excludePaths) { - continue - } - log.Tracef("checking if %s.%s.%s matches includePaths: %s\n", db, attr.Schema, attr.Table, includePaths) - if !matchPathPatterns(db, attr.Schema, attr.Table, includePaths) { - continue - } - - // SchemaMetadata exists - add a table if necessary - if schema, ok := repo.Schemas[attr.Schema]; ok { - // TableMetadata exists - just append the attribute - if table, ok := schema.Tables[attr.Table]; ok { - table.Attributes = append(table.Attributes, &attr) - } else { // First time seeing this table - table := NewTableMetadata(attr.Schema, attr.Table) - table.Attributes = append(table.Attributes, &attr) - schema.Tables[attr.Table] = table - } - } else { // SchemaMetadata doesn't exist - create it - table := NewTableMetadata(attr.Schema, attr.Table) - table.Attributes = append(table.Attributes, &attr) - schema := NewSchemaMetadata(attr.Schema) - schema.Tables[attr.Table] = table - repo.Schemas[attr.Schema] = schema - } - } - - // Something broke while iterating the row set - if err := rows.Err(); err != nil { - return nil, err - } - - return repo, nil -} - // getCurrentRowAsMap transforms the current row referenced by a sql.Rows row // set into a map where the key is the column name and the value is the column // value. It is effectively an alternative to the sql.Rows.Scan method, where it diff --git a/discovery/generic_test.go b/discovery/generic_test.go index ac13df5..c490e86 100644 --- a/discovery/generic_test.go +++ b/discovery/generic_test.go @@ -7,7 +7,6 @@ import ( "github.com/DATA-DOG/go-sqlmock" "github.com/gobwas/glob" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -16,16 +15,13 @@ func Test_Introspect_IsSuccessful(t *testing.T) { require.NoError(t, err) defer func() { _ = db.Close() }() - repoName := "testRepo" repoType := "genericSql" database := "exampleDb" repo := GenericRepository{ - repoName: repoName, - repoType: repoType, - database: database, - db: db, - includePaths: []glob.Glob{glob.MustCompile("exampleDb.*")}, + repoType: repoType, + database: database, + db: db, } cols := []string{ @@ -46,11 +42,12 @@ func Test_Introspect_IsSuccessful(t *testing.T) { ctx := context.Background() - meta, err := repo.Introspect(ctx) + params := IntrospectParameters{ + IncludePaths: []glob.Glob{glob.MustCompile("exampleDb.*")}, + } + meta, err := repo.Introspect(ctx, params) expectedMetadata := Metadata{ - Name: repoName, - RepoType: repoType, Database: database, Schemas: map[string]*SchemaMetadata{ "schema1": { @@ -108,8 +105,8 @@ func Test_Introspect_IsSuccessful(t *testing.T) { }, } - assert.NoError(t, err) - assert.EqualValues(t, &expectedMetadata, meta) + require.NoError(t, err) + require.EqualValues(t, &expectedMetadata, meta) } func Test_Introspect_QueryError(t *testing.T) { @@ -117,12 +114,10 @@ func Test_Introspect_QueryError(t *testing.T) { require.NoError(t, err) defer func() { _ = db.Close() }() - repoName := "testRepo" repoType := "genericSql" database := "exampleDb" repo := GenericRepository{ - repoName: repoName, repoType: repoType, database: database, db: db, @@ -135,10 +130,10 @@ func Test_Introspect_QueryError(t *testing.T) { ctx := context.Background() - meta, err := repo.Introspect(ctx) + meta, err := repo.Introspect(ctx, IntrospectParameters{}) - assert.Nil(t, meta) - assert.EqualError(t, err, expectedErr.Error()) + require.Nil(t, meta) + require.ErrorIs(t, err, expectedErr) } func Test_Introspect_RowError(t *testing.T) { @@ -146,12 +141,10 @@ func Test_Introspect_RowError(t *testing.T) { require.NoError(t, err) defer func() { _ = db.Close() }() - repoName := "testRepo" repoType := "genericSql" database := "exampleDb" repo := GenericRepository{ - repoName: repoName, repoType: repoType, database: database, db: db, @@ -168,17 +161,17 @@ func Test_Introspect_RowError(t *testing.T) { rows := sqlmock.NewRows(cols). AddRow("schema1", "table1", "column1", "varchar"). - RowError(0, errors.New("dummy error")) + RowError(0, expectedErr) mock.ExpectQuery("SELECT (.+) FROM information_schema.columns WHERE (.+)"). WillReturnRows(rows) ctx := context.Background() - meta, err := repo.Introspect(ctx) + meta, err := repo.Introspect(ctx, IntrospectParameters{}) - assert.Nil(t, meta) - assert.EqualError(t, err, expectedErr.Error()) + require.Nil(t, meta) + require.ErrorIs(t, err, expectedErr) } func Test_Introspect_Filtered(t *testing.T) { @@ -186,22 +179,13 @@ func Test_Introspect_Filtered(t *testing.T) { require.NoError(t, err) defer func() { _ = db.Close() }() - repoName := "testRepo" repoType := "genericSql" database := "exampleDb" repo := GenericRepository{ - repoName: repoName, repoType: repoType, database: database, db: db, - includePaths: []glob.Glob{ - glob.MustCompile("exampleDb.schema1.*"), - glob.MustCompile("exampleDb.*.table1"), - }, - excludePaths: []glob.Glob{ - glob.MustCompile("exampleDb.schema3.*"), - }, } cols := []string{ @@ -226,11 +210,18 @@ func Test_Introspect_Filtered(t *testing.T) { ctx := context.Background() - meta, err := repo.Introspect(ctx) + params := IntrospectParameters{ + IncludePaths: []glob.Glob{ + glob.MustCompile("exampleDb.schema1.*"), + glob.MustCompile("exampleDb.*.table1"), + }, + ExcludePaths: []glob.Glob{ + glob.MustCompile("exampleDb.schema3.*"), + }, + } + meta, err := repo.Introspect(ctx, params) expectedMetadata := Metadata{ - Name: repoName, - RepoType: repoType, Database: database, Schemas: map[string]*SchemaMetadata{ "schema1": { @@ -294,6 +285,6 @@ func Test_Introspect_Filtered(t *testing.T) { }, } - assert.NoError(t, err) - assert.EqualValues(t, &expectedMetadata, meta) + require.NoError(t, err) + require.EqualValues(t, &expectedMetadata, meta) } diff --git a/discovery/metadata.go b/discovery/metadata.go index ca71935..ec56cc4 100644 --- a/discovery/metadata.go +++ b/discovery/metadata.go @@ -1,7 +1,12 @@ package discovery import ( + "database/sql" + "fmt" "strings" + + "github.com/gobwas/glob" + log "github.com/sirupsen/logrus" ) // Metadata represents the structure of a SQL database. The traditional @@ -10,7 +15,6 @@ import ( // those cases, the 'Database' field is expected to be an empty string. See: // https://stackoverflow.com/a/17943883 type Metadata struct { - Name string RepoType string Database string Schemas map[string]*SchemaMetadata @@ -60,15 +64,66 @@ type AttributeMetadata struct { // NewMetadata creates a new Metadata object with the given repository type, // repository name, and database name, with an empty map of schemas. -func NewMetadata(repoType, repoName, database string) *Metadata { +func NewMetadata(database string) *Metadata { return &Metadata{ - Name: repoName, - RepoType: repoType, Database: database, Schemas: make(map[string]*SchemaMetadata), } } +// newMetadataFromQueryResult builds the repository metadata from the results +// of a query to the INFORMATION_SCHEMA columns view. +func newMetadataFromQueryResult( + db string, + includePaths, excludePaths []glob.Glob, + rows *sql.Rows, +) ( + *Metadata, + error, +) { + repo := NewMetadata(db) + for rows.Next() { + var attr AttributeMetadata + if err := rows.Scan(&attr.Schema, &attr.Table, &attr.Name, &attr.DataType); err != nil { + return nil, fmt.Errorf("error scanning metadata query result row: %w", err) + } + + // Skip tables that match excludePaths or does not match includePaths. + log.Tracef("checking if %s.%s.%s matches excludePaths %s\n", db, attr.Schema, attr.Table, excludePaths) + if matchPathPatterns(db, attr.Schema, attr.Table, excludePaths) { + continue + } + log.Tracef("checking if %s.%s.%s matches includePaths: %s\n", db, attr.Schema, attr.Table, includePaths) + if !matchPathPatterns(db, attr.Schema, attr.Table, includePaths) { + continue + } + + // SchemaMetadata exists - add a table if necessary. + if schema, ok := repo.Schemas[attr.Schema]; ok { + // TableMetadata exists - just append the attribute. + if table, ok := schema.Tables[attr.Table]; ok { + table.Attributes = append(table.Attributes, &attr) + } else { // First time seeing this table. + table := NewTableMetadata(attr.Schema, attr.Table) + table.Attributes = append(table.Attributes, &attr) + schema.Tables[attr.Table] = table + } + } else { // SchemaMetadata doesn't exist - create it. + table := NewTableMetadata(attr.Schema, attr.Table) + table.Attributes = append(table.Attributes, &attr) + schema := NewSchemaMetadata(attr.Schema) + schema.Tables[attr.Table] = table + repo.Schemas[attr.Schema] = schema + } + } + + // Something broke while iterating the row set. + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating metadata query rows: %w", err) + } + return repo, nil +} + // NewSchemaMetadata creates a new SchemaMetadata object with the given schema // name and an empty map of tables. func NewSchemaMetadata(schemaName string) *SchemaMetadata { diff --git a/discovery/metadata_test.go b/discovery/metadata_test.go index c787e03..4ac9ec7 100644 --- a/discovery/metadata_test.go +++ b/discovery/metadata_test.go @@ -3,7 +3,7 @@ package discovery import ( "testing" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestAttributeNames(t *testing.T) { @@ -34,7 +34,7 @@ func TestAttributeNames(t *testing.T) { names := []string{"name1", "name2", "name3"} - assert.ElementsMatch(t, table.AttributeNames(), names) + require.ElementsMatch(t, table.AttributeNames(), names) } func TestQuotedAttributeNamesString(t *testing.T) { @@ -67,5 +67,5 @@ func TestQuotedAttributeNamesString(t *testing.T) { expected := "`name1`,`name2`,`name3`" namesStr := table.QuotedAttributeNamesString(quoteChar) - assert.Equal(t, expected, namesStr) + require.Equal(t, expected, namesStr) } diff --git a/discovery/mock_repository_test.go b/discovery/mock_repository_test.go deleted file mode 100644 index 492bcb9..0000000 --- a/discovery/mock_repository_test.go +++ /dev/null @@ -1,342 +0,0 @@ -// Code generated by mockery v2.42.1. DO NOT EDIT. - -package discovery - -import ( - context "context" - - mock "github.com/stretchr/testify/mock" -) - -// MockRepository is an autogenerated mock type for the SQLRepository type -type MockRepository struct { - mock.Mock -} - -type MockRepository_Expecter struct { - mock *mock.Mock -} - -func (_m *MockRepository) EXPECT() *MockRepository_Expecter { - return &MockRepository_Expecter{mock: &_m.Mock} -} - -// Close provides a mock function with given fields: -func (_m *MockRepository) Close() error { - ret := _m.Called() - - if len(ret) == 0 { - panic("no return value specified for Close") - } - - var r0 error - if rf, ok := ret.Get(0).(func() error); ok { - r0 = rf() - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// MockRepository_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' -type MockRepository_Close_Call struct { - *mock.Call -} - -// Close is a helper method to define mock.On call -func (_e *MockRepository_Expecter) Close() *MockRepository_Close_Call { - return &MockRepository_Close_Call{Call: _e.mock.On("Close")} -} - -func (_c *MockRepository_Close_Call) Run(run func()) *MockRepository_Close_Call { - _c.Call.Run( - func(args mock.Arguments) { - run() - }, - ) - return _c -} - -func (_c *MockRepository_Close_Call) Return(_a0 error) *MockRepository_Close_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *MockRepository_Close_Call) RunAndReturn(run func() error) *MockRepository_Close_Call { - _c.Call.Return(run) - return _c -} - -// Introspect provides a mock function with given fields: ctx -func (_m *MockRepository) Introspect(ctx context.Context) (*Metadata, error) { - ret := _m.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for Introspect") - } - - var r0 *Metadata - var r1 error - if rf, ok := ret.Get(0).(func(context.Context) (*Metadata, error)); ok { - return rf(ctx) - } - if rf, ok := ret.Get(0).(func(context.Context) *Metadata); ok { - r0 = rf(ctx) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(*Metadata) - } - } - - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// MockRepository_Introspect_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Introspect' -type MockRepository_Introspect_Call struct { - *mock.Call -} - -// Introspect is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockRepository_Expecter) Introspect(ctx interface{}) *MockRepository_Introspect_Call { - return &MockRepository_Introspect_Call{Call: _e.mock.On("Introspect", ctx)} -} - -func (_c *MockRepository_Introspect_Call) Run(run func(ctx context.Context)) *MockRepository_Introspect_Call { - _c.Call.Run( - func(args mock.Arguments) { - run(args[0].(context.Context)) - }, - ) - return _c -} - -func (_c *MockRepository_Introspect_Call) Return(_a0 *Metadata, _a1 error) *MockRepository_Introspect_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockRepository_Introspect_Call) RunAndReturn( - run func(context.Context) ( - *Metadata, - error, - ), -) *MockRepository_Introspect_Call { - _c.Call.Return(run) - return _c -} - -// ListDatabases provides a mock function with given fields: ctx -func (_m *MockRepository) ListDatabases(ctx context.Context) ([]string, error) { - ret := _m.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for ListDatabases") - } - - var r0 []string - var r1 error - if rf, ok := ret.Get(0).(func(context.Context) ([]string, error)); ok { - return rf(ctx) - } - if rf, ok := ret.Get(0).(func(context.Context) []string); ok { - r0 = rf(ctx) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).([]string) - } - } - - if rf, ok := ret.Get(1).(func(context.Context) error); ok { - r1 = rf(ctx) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// MockRepository_ListDatabases_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListDatabases' -type MockRepository_ListDatabases_Call struct { - *mock.Call -} - -// ListDatabases is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockRepository_Expecter) ListDatabases(ctx interface{}) *MockRepository_ListDatabases_Call { - return &MockRepository_ListDatabases_Call{Call: _e.mock.On("ListDatabases", ctx)} -} - -func (_c *MockRepository_ListDatabases_Call) Run(run func(ctx context.Context)) *MockRepository_ListDatabases_Call { - _c.Call.Run( - func(args mock.Arguments) { - run(args[0].(context.Context)) - }, - ) - return _c -} - -func (_c *MockRepository_ListDatabases_Call) Return(_a0 []string, _a1 error) *MockRepository_ListDatabases_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockRepository_ListDatabases_Call) RunAndReturn( - run func(context.Context) ( - []string, - error, - ), -) *MockRepository_ListDatabases_Call { - _c.Call.Return(run) - return _c -} - -// Ping provides a mock function with given fields: ctx -func (_m *MockRepository) Ping(ctx context.Context) error { - ret := _m.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for Ping") - } - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context) error); ok { - r0 = rf(ctx) - } else { - r0 = ret.Error(0) - } - - return r0 -} - -// MockRepository_Ping_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Ping' -type MockRepository_Ping_Call struct { - *mock.Call -} - -// Ping is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockRepository_Expecter) Ping(ctx interface{}) *MockRepository_Ping_Call { - return &MockRepository_Ping_Call{Call: _e.mock.On("Ping", ctx)} -} - -func (_c *MockRepository_Ping_Call) Run(run func(ctx context.Context)) *MockRepository_Ping_Call { - _c.Call.Run( - func(args mock.Arguments) { - run(args[0].(context.Context)) - }, - ) - return _c -} - -func (_c *MockRepository_Ping_Call) Return(_a0 error) *MockRepository_Ping_Call { - _c.Call.Return(_a0) - return _c -} - -func (_c *MockRepository_Ping_Call) RunAndReturn(run func(context.Context) error) *MockRepository_Ping_Call { - _c.Call.Return(run) - return _c -} - -// SampleTable provides a mock function with given fields: ctx, meta, params -func (_m *MockRepository) SampleTable(ctx context.Context, meta *TableMetadata, params SampleParameters) ( - Sample, - error, -) { - ret := _m.Called(ctx, meta, params) - - if len(ret) == 0 { - panic("no return value specified for SampleTable") - } - - var r0 Sample - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, *TableMetadata, SampleParameters) (Sample, error)); ok { - return rf(ctx, meta, params) - } - if rf, ok := ret.Get(0).(func(context.Context, *TableMetadata, SampleParameters) Sample); ok { - r0 = rf(ctx, meta, params) - } else { - r0 = ret.Get(0).(Sample) - } - - if rf, ok := ret.Get(1).(func(context.Context, *TableMetadata, SampleParameters) error); ok { - r1 = rf(ctx, meta, params) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -// MockRepository_SampleTable_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SampleTable' -type MockRepository_SampleTable_Call struct { - *mock.Call -} - -// SampleTable is a helper method to define mock.On call -// - ctx context.Context -// - meta *TableMetadata -// - params SampleParameters -func (_e *MockRepository_Expecter) SampleTable( - ctx interface{}, - meta interface{}, - params interface{}, -) *MockRepository_SampleTable_Call { - return &MockRepository_SampleTable_Call{Call: _e.mock.On("SampleTable", ctx, meta, params)} -} - -func (_c *MockRepository_SampleTable_Call) Run( - run func( - ctx context.Context, - meta *TableMetadata, - params SampleParameters, - ), -) *MockRepository_SampleTable_Call { - _c.Call.Run( - func(args mock.Arguments) { - run(args[0].(context.Context), args[1].(*TableMetadata), args[2].(SampleParameters)) - }, - ) - return _c -} - -func (_c *MockRepository_SampleTable_Call) Return(_a0 Sample, _a1 error) *MockRepository_SampleTable_Call { - _c.Call.Return(_a0, _a1) - return _c -} - -func (_c *MockRepository_SampleTable_Call) RunAndReturn( - run func( - context.Context, - *TableMetadata, - SampleParameters, - ) (Sample, error), -) *MockRepository_SampleTable_Call { - _c.Call.Return(run) - return _c -} - -// NewMockRepository creates a new instance of Mocksql. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewMockRepository( - t interface { - mock.TestingT - Cleanup(func()) - }, -) *MockRepository { - mock := &MockRepository{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} diff --git a/discovery/mysql.go b/discovery/mysql.go index 7a7afc8..4def031 100644 --- a/discovery/mysql.go +++ b/discovery/mysql.go @@ -25,8 +25,8 @@ WHERE // MySqlRepository is a SQLRepository implementation for MySQL databases. type MySqlRepository struct { // The majority of the SQLRepository functionality is delegated to - // a generic SQL repository instance (genericSqlRepo). - genericSqlRepo *GenericRepository + // a generic SQL repository instance. + generic *GenericRepository } // MySqlRepository implements SQLRepository @@ -44,33 +44,25 @@ func NewMySqlRepository(cfg RepoConfig) (*MySqlRepository, error) { // https://github.com/go-sql-driver/mysql#dsn-data-source-name cfg.Database, ) - sqlRepo, err := NewGenericRepository( - cfg.Host, - RepoTypeMysql, - cfg.Database, - connStr, - cfg.MaxOpenConns, - cfg.IncludePaths, - cfg.ExcludePaths, - ) + generic, err := NewGenericRepository(RepoTypeMysql, cfg.Database, connStr, cfg.MaxOpenConns) if err != nil { return nil, fmt.Errorf("could not instantiate generic sql repository: %w", err) } - return &MySqlRepository{genericSqlRepo: sqlRepo}, nil + return &MySqlRepository{generic: generic}, nil } // ListDatabases returns a list of the names of all databases on the server by // using a MySQL-specific database query. It delegates the actual work to // GenericRepository.ListDatabasesWithQuery - see that method for more details. func (r *MySqlRepository) ListDatabases(ctx context.Context) ([]string, error) { - return r.genericSqlRepo.ListDatabasesWithQuery(ctx, MySqlDatabaseQuery) + return r.generic.ListDatabasesWithQuery(ctx, MySqlDatabaseQuery) } // Introspect delegates introspection to GenericRepository. See // SQLRepository.Introspect and GenericRepository.IntrospectWithQuery for more // details. -func (r *MySqlRepository) Introspect(ctx context.Context) (*Metadata, error) { - return r.genericSqlRepo.Introspect(ctx) +func (r *MySqlRepository) Introspect(ctx context.Context, params IntrospectParameters) (*Metadata, error) { + return r.generic.Introspect(ctx, params) } // SampleTable delegates sampling to GenericRepository, using a MySQL-specific @@ -78,25 +70,24 @@ func (r *MySqlRepository) Introspect(ctx context.Context) (*Metadata, error) { // GenericRepository.SampleTableWithQuery for more details. func (r *MySqlRepository) SampleTable( ctx context.Context, - meta *TableMetadata, params SampleParameters, ) (Sample, error) { // MySQL uses backticks to quote identifiers. - attrStr := meta.QuotedAttributeNamesString("`") + attrStr := params.Metadata.QuotedAttributeNamesString("`") // The generic select/limit/offset query and ? placeholders work fine with // MySQL. - query := fmt.Sprintf(GenericSampleQueryTemplate, attrStr, meta.Schema, meta.Name) - return r.genericSqlRepo.SampleTableWithQuery(ctx, meta, query, params.SampleSize, params.Offset) + query := fmt.Sprintf(GenericSampleQueryTemplate, attrStr, params.Metadata.Schema, params.Metadata.Name) + return r.generic.SampleTableWithQuery(ctx, query, params) } // Ping delegates the ping to GenericRepository. See SQLRepository.Ping and // GenericRepository.Ping for more details. func (r *MySqlRepository) Ping(ctx context.Context) error { - return r.genericSqlRepo.Ping(ctx) + return r.generic.Ping(ctx) } // Close delegates the close to GenericRepository. See SQLRepository.Close and // GenericRepository.Close for more details. func (r *MySqlRepository) Close() error { - return r.genericSqlRepo.Close() + return r.generic.Close() } diff --git a/discovery/mysql_test.go b/discovery/mysql_test.go index 9361a5f..52a12cf 100644 --- a/discovery/mysql_test.go +++ b/discovery/mysql_test.go @@ -24,6 +24,6 @@ func initMySqlRepoTest(t *testing.T) (context.Context, *sql.DB, sqlmock.Sqlmock, db, mock, err := sqlmock.New() require.NoError(t, err) return ctx, db, mock, &MySqlRepository{ - genericSqlRepo: NewGenericRepositoryFromDB("repoName", RepoTypeMysql, "dbName", db), + generic: NewGenericRepositoryFromDB(RepoTypeMysql, "dbName", db), } } diff --git a/discovery/oracle.go b/discovery/oracle.go index a3dfcb1..e011ea2 100644 --- a/discovery/oracle.go +++ b/discovery/oracle.go @@ -39,8 +39,8 @@ ON // OracleRepository is a SQLRepository implementation for Oracle databases. type OracleRepository struct { // The majority of the OracleRepository functionality is delegated to - // a generic SQL repository instance (genericSqlRepo). - genericSqlRepo *GenericRepository + // a generic SQL repository instance. + generic *GenericRepository } // OracleRepository implements SQLRepository @@ -60,19 +60,11 @@ func NewOracleRepository(cfg RepoConfig) (*OracleRepository, error) { cfg.Port, oracleCfg.ServiceName, ) - sqlRepo, err := NewGenericRepository( - cfg.Host, - RepoTypeOracle, - cfg.Database, - connStr, - cfg.MaxOpenConns, - cfg.IncludePaths, - cfg.ExcludePaths, - ) + generic, err := NewGenericRepository(RepoTypeOracle, cfg.Database, connStr, cfg.MaxOpenConns) if err != nil { return nil, fmt.Errorf("could not instantiate generic sql repository: %w", err) } - return &OracleRepository{genericSqlRepo: sqlRepo}, nil + return &OracleRepository{generic: generic}, nil } // ListDatabases is left unimplemented for Oracle, because Oracle doesn't have @@ -85,8 +77,8 @@ func (r *OracleRepository) ListDatabases(_ context.Context) ([]string, error) { // Introspect delegates introspection to GenericRepository, using an // Oracle-specific introspection query. See SQLRepository.Introspect and // GenericRepository.IntrospectWithQuery for more details. -func (r *OracleRepository) Introspect(ctx context.Context) (*Metadata, error) { - return r.genericSqlRepo.IntrospectWithQuery(ctx, OracleIntrospectQuery) +func (r *OracleRepository) Introspect(ctx context.Context, params IntrospectParameters) (*Metadata, error) { + return r.generic.IntrospectWithQuery(ctx, OracleIntrospectQuery, params) } // SampleTable delegates sampling to GenericRepository, using an Oracle-specific @@ -94,17 +86,16 @@ func (r *OracleRepository) Introspect(ctx context.Context) (*Metadata, error) { // GenericRepository.SampleTableWithQuery for more details. func (r *OracleRepository) SampleTable( ctx context.Context, - meta *TableMetadata, params SampleParameters, ) (Sample, error) { // Oracle uses double-quotes to quote identifiers. - attrStr := meta.QuotedAttributeNamesString("\"") + attrStr := params.Metadata.QuotedAttributeNamesString("\"") // Oracle uses :x for placeholders. query := fmt.Sprintf( "SELECT %s FROM %s.%s OFFSET :1 ROWS FETCH NEXT :2 ROWS ONLY", - attrStr, meta.Schema, meta.Name, + attrStr, params.Metadata.Schema, params.Metadata.Name, ) - return r.genericSqlRepo.SampleTableWithQuery(ctx, meta, query, params.Offset, params.SampleSize) + return r.generic.SampleTableWithQuery(ctx, query, params) } // Ping verifies the connection to Oracle database used by this Oracle @@ -113,13 +104,13 @@ func (r *OracleRepository) SampleTable( // Oracle being Oracle does not like this. Instead, we defer to the native // Ping method implemented by the Oracle DB driver. func (r *OracleRepository) Ping(ctx context.Context) error { - return r.genericSqlRepo.GetDb().PingContext(ctx) + return r.generic.GetDb().PingContext(ctx) } // Close delegates the close to GenericRepository. See SQLRepository.Close and // GenericRepository.Close for more details. func (r *OracleRepository) Close() error { - return r.genericSqlRepo.Close() + return r.generic.Close() } // OracleConfig is a struct to hold Oracle-specific configuration. diff --git a/discovery/postgres.go b/discovery/postgres.go index df9603a..618b812 100644 --- a/discovery/postgres.go +++ b/discovery/postgres.go @@ -3,6 +3,7 @@ package discovery import ( "context" "fmt" + "strings" // Postgresql DB driver _ "github.com/lib/pq" @@ -26,8 +27,8 @@ WHERE // PostgresRepository is a SQLRepository implementation for Postgres databases. type PostgresRepository struct { // The majority of the SQLRepository functionality is delegated to - // a generic SQL repository instance (genericSqlRepo). - genericSqlRepo *GenericRepository + // a generic SQL repository instance. + generic *GenericRepository } // PostgresRepository implements SQLRepository @@ -35,7 +36,7 @@ var _ SQLRepository = (*PostgresRepository)(nil) // NewPostgresRepository creates a new PostgresRepository. func NewPostgresRepository(cfg RepoConfig) (*PostgresRepository, error) { - pgCfg, err := ParsePostgresConfig(cfg) + pgCfg, err := parsePostgresConfig(cfg) if err != nil { return nil, fmt.Errorf("error parsing postgres config: %w", err) } @@ -53,33 +54,25 @@ func NewPostgresRepository(cfg RepoConfig) (*PostgresRepository, error) { database, pgCfg.ConnOptsStr, ) - sqlRepo, err := NewGenericRepository( - cfg.Host, - RepoTypePostgres, - cfg.Database, - connStr, - cfg.MaxOpenConns, - cfg.IncludePaths, - cfg.ExcludePaths, - ) + generic, err := NewGenericRepository(RepoTypePostgres, cfg.Database, connStr, cfg.MaxOpenConns) if err != nil { return nil, fmt.Errorf("could not instantiate generic sql repository: %w", err) } - return &PostgresRepository{genericSqlRepo: sqlRepo}, nil + return &PostgresRepository{generic: generic}, nil } // ListDatabases returns a list of the names of all databases on the server by // using a Postgres-specific database query. It delegates the actual work to // GenericRepository.ListDatabasesWithQuery - see that method for more details. func (r *PostgresRepository) ListDatabases(ctx context.Context) ([]string, error) { - return r.genericSqlRepo.ListDatabasesWithQuery(ctx, PostgresDatabaseQuery) + return r.generic.ListDatabasesWithQuery(ctx, PostgresDatabaseQuery) } // Introspect delegates introspection to GenericRepository. See // SQLRepository.Introspect and GenericRepository.IntrospectWithQuery for more // details. -func (r *PostgresRepository) Introspect(ctx context.Context) (*Metadata, error) { - return r.genericSqlRepo.Introspect(ctx) +func (r *PostgresRepository) Introspect(ctx context.Context, params IntrospectParameters) (*Metadata, error) { + return r.generic.Introspect(ctx, params) } // SampleTable delegates sampling to GenericRepository, using a @@ -87,26 +80,30 @@ func (r *PostgresRepository) Introspect(ctx context.Context) (*Metadata, error) // GenericRepository.SampleTableWithQuery for more details. func (r *PostgresRepository) SampleTable( ctx context.Context, - meta *TableMetadata, params SampleParameters, ) (Sample, error) { // Postgres uses double-quotes to quote identifiers - attrStr := meta.QuotedAttributeNamesString("\"") + attrStr := params.Metadata.QuotedAttributeNamesString("\"") // Postgres uses $x for placeholders - query := fmt.Sprintf("SELECT %s FROM %s.%s LIMIT $1 OFFSET $2", attrStr, meta.Schema, meta.Name) - return r.genericSqlRepo.SampleTableWithQuery(ctx, meta, query, params.SampleSize, params.Offset) + query := fmt.Sprintf( + "SELECT %s FROM %s.%s LIMIT $1 OFFSET $2", + attrStr, + params.Metadata.Schema, + params.Metadata.Name, + ) + return r.generic.SampleTableWithQuery(ctx, query, params) } // Ping delegates the ping to GenericRepository. See SQLRepository.Ping and // GenericRepository.Ping for more details. func (r *PostgresRepository) Ping(ctx context.Context) error { - return r.genericSqlRepo.Ping(ctx) + return r.generic.Ping(ctx) } // Close delegates the close to GenericRepository. See SQLRepository.Close and // GenericRepository.Close for more details. func (r *PostgresRepository) Close() error { - return r.genericSqlRepo.Close() + return r.generic.Close() } // PostgresConfig contains Postgres-specific configuration parameters. @@ -115,13 +112,80 @@ type PostgresConfig struct { ConnOptsStr string } -// ParsePostgresConfig parses the Postgres-specific configuration parameters +// parsePostgresConfig parses the Postgres-specific configuration parameters // from the given The Postgres connection options are built from the // config and stored in the ConnOptsStr field of the returned Postgres -func ParsePostgresConfig(cfg RepoConfig) (*PostgresConfig, error) { - connOptsStr, err := BuildConnOptsStr(cfg) +func parsePostgresConfig(cfg RepoConfig) (*PostgresConfig, error) { + connOptsStr, err := buildConnOptsStr(cfg) if err != nil { return nil, fmt.Errorf("error building connection options string: %w", err) } return &PostgresConfig{ConnOptsStr: connOptsStr}, nil } + +// buildConnOptsStr parses the repo config to produce a string in the format +// "?option=value&option2=value2". Example: +// +// buildConnOptsStr(RepoConfig{ +// Advanced: map[string]any{ +// "connection-string-args": []any{"sslmode=disable"}, +// }, +// }) +// +// returns ("?sslmode=disable", nil). +func buildConnOptsStr(cfg RepoConfig) (string, error) { + connOptsMap, err := mapFromConnOpts(cfg) + if err != nil { + return "", fmt.Errorf("connection options: %w", err) + } + connOptsStr := "" + for key, val := range connOptsMap { + // Don't add if the value is empty, since that would make the + // string malformed. + if val != "" { + if connOptsStr == "" { + connOptsStr += fmt.Sprintf("%s=%s", key, val) + } else { + // Need & for subsequent options + connOptsStr += fmt.Sprintf("&%s=%s", key, val) + } + } + } + // Only add ? if connection string is not empty + if connOptsStr != "" { + connOptsStr = "?" + connOptsStr + } + return connOptsStr, nil +} + +// mapFromConnOpts builds a map from the list of connection options given. Each +// option has the format 'option=value'. An error is returned if the config is +// malformed. +func mapFromConnOpts(cfg RepoConfig) (map[string]string, error) { + m := make(map[string]string) + connOptsInterface, ok := cfg.Advanced[configConnOpts] + if !ok { + return nil, nil + } + connOpts, ok := connOptsInterface.([]any) + if !ok { + return nil, fmt.Errorf("'%s' is not a list", configConnOpts) + } + for _, optInterface := range connOpts { + opt, ok := optInterface.(string) + if !ok { + return nil, fmt.Errorf("'%v' is not a string", optInterface) + } + splitOpt := strings.Split(opt, "=") + if len(splitOpt) != 2 { + return nil, fmt.Errorf( + "malformed '%s'. "+ + "Please follow the format 'option=value'", configConnOpts, + ) + } + key := splitOpt[0] + val := splitOpt[1] + m[key] = val + } + return m, nil +} diff --git a/discovery/postgres_test.go b/discovery/postgres_test.go index 1794a17..622bd01 100644 --- a/discovery/postgres_test.go +++ b/discovery/postgres_test.go @@ -19,11 +19,84 @@ func TestPostgresRepository_ListDatabases(t *testing.T) { require.ElementsMatch(t, []string{"db1", "db2"}, dbs) } +func TestBuildConnOptionsSucc(t *testing.T) { + sampleRepoCfg := getSampleRepoConfig() + connOptsStr, err := buildConnOptsStr(sampleRepoCfg) + require.NoError(t, err) + require.Equal(t, connOptsStr, "?sslmode=disable") +} + +func TestBuildConnOptionsFail(t *testing.T) { + invalidRepoCfg := RepoConfig{ + Advanced: map[string]any{ + // Invalid: map instead of string + configConnOpts: []any{ + map[string]string{"sslmode": "disable"}, + }, + }, + } + connOptsStr, err := buildConnOptsStr(invalidRepoCfg) + require.Error(t, err) + require.Empty(t, connOptsStr) +} + +func TestMapConnOptionsSucc(t *testing.T) { + sampleRepoCfg := getSampleRepoConfig() + connOptsMap, err := mapFromConnOpts(sampleRepoCfg) + require.NoError(t, err) + require.EqualValues( + t, connOptsMap, map[string]string{ + "sslmode": "disable", + }, + ) +} + +// The mapping should only fail if the config is malformed, not if it is missing +func TestMapConnOptionsMissing(t *testing.T) { + sampleCfg := RepoConfig{} + optsMap, err := mapFromConnOpts(sampleCfg) + require.NoError(t, err) + require.Empty(t, optsMap) +} + +func TestMapConnOptionsMalformedMap(t *testing.T) { + sampleCfg := RepoConfig{ + Advanced: map[string]any{ + // Let's put a map instead of the required list + configConnOpts: map[string]any{ + "testKey": "testValue", + }, + }, + } + _, err := mapFromConnOpts(sampleCfg) + require.Error(t, err) +} + +func TestMapConnOptionsMalformedColon(t *testing.T) { + sampleCfg := RepoConfig{ + Advanced: map[string]any{ + // Let's use a colon instead of '=' to divide options + configConnOpts: []string{"sslmode:disable"}, + }, + } + _, err := mapFromConnOpts(sampleCfg) + require.Error(t, err) +} + func initPostgresRepoTest(t *testing.T) (context.Context, *sql.DB, sqlmock.Sqlmock, *PostgresRepository) { ctx := context.Background() db, mock, err := sqlmock.New() require.NoError(t, err) return ctx, db, mock, &PostgresRepository{ - genericSqlRepo: NewGenericRepositoryFromDB("repoName", RepoTypePostgres, "dbName", db), + generic: NewGenericRepositoryFromDB(RepoTypePostgres, "dbName", db), + } +} + +// Returns a correct repo config +func getSampleRepoConfig() RepoConfig { + return RepoConfig{ + Advanced: map[string]any{ + configConnOpts: []any{"sslmode=disable"}, + }, } } diff --git a/discovery/redshift.go b/discovery/redshift.go index 57027d8..7586a90 100644 --- a/discovery/redshift.go +++ b/discovery/redshift.go @@ -15,8 +15,8 @@ const ( // RedshiftRepository is a SQLRepository implementation for Redshift databases. type RedshiftRepository struct { // The majority of the RedshiftRepository functionality is delegated to - // a generic SQL repository instance (genericSqlRepo). - genericSqlRepo *GenericRepository + // a generic SQL repository instance. + generic *GenericRepository } // RedshiftRepository implements SQLRepository @@ -24,7 +24,7 @@ var _ SQLRepository = (*RedshiftRepository)(nil) // NewRedshiftRepository creates a new RedshiftRepository. func NewRedshiftRepository(cfg RepoConfig) (*RedshiftRepository, error) { - pgCfg, err := ParsePostgresConfig(cfg) + pgCfg, err := parsePostgresConfig(cfg) if err != nil { return nil, fmt.Errorf("unable to parse postgres config: %w", err) } @@ -42,19 +42,11 @@ func NewRedshiftRepository(cfg RepoConfig) (*RedshiftRepository, error) { database, pgCfg.ConnOptsStr, ) - sqlRepo, err := NewGenericRepository( - cfg.Host, - RepoTypePostgres, - cfg.Database, - connStr, - cfg.MaxOpenConns, - cfg.IncludePaths, - cfg.ExcludePaths, - ) + generic, err := NewGenericRepository(RepoTypePostgres, cfg.Database, connStr, cfg.MaxOpenConns) if err != nil { return nil, fmt.Errorf("could not instantiate generic sql repository: %w", err) } - return &RedshiftRepository{genericSqlRepo: sqlRepo}, nil + return &RedshiftRepository{generic: generic}, nil } // ListDatabases returns a list of the names of all databases on the server by @@ -62,14 +54,14 @@ func NewRedshiftRepository(cfg RepoConfig) (*RedshiftRepository, error) { // GenericRepository.ListDatabasesWithQuery - see that method for more details. func (r *RedshiftRepository) ListDatabases(ctx context.Context) ([]string, error) { // Redshift and Postgres use the same query to list the server databases. - return r.genericSqlRepo.ListDatabasesWithQuery(ctx, PostgresDatabaseQuery) + return r.generic.ListDatabasesWithQuery(ctx, PostgresDatabaseQuery) } // Introspect delegates introspection to GenericRepository. See // SQLRepository.Introspect and GenericRepository.IntrospectWithQuery for more // details. -func (r *RedshiftRepository) Introspect(ctx context.Context) (*Metadata, error) { - return r.genericSqlRepo.Introspect(ctx) +func (r *RedshiftRepository) Introspect(ctx context.Context, params IntrospectParameters) (*Metadata, error) { + return r.generic.Introspect(ctx, params) } // SampleTable delegates sampling to GenericRepository, using a @@ -77,24 +69,23 @@ func (r *RedshiftRepository) Introspect(ctx context.Context) (*Metadata, error) // GenericRepository.SampleTableWithQuery for more details. func (r *RedshiftRepository) SampleTable( ctx context.Context, - meta *TableMetadata, params SampleParameters, ) (Sample, error) { // Redshift uses double-quotes to quote identifiers - attrStr := meta.QuotedAttributeNamesString("\"") + attrStr := params.Metadata.QuotedAttributeNamesString("\"") // Redshift uses $x for placeholders - query := fmt.Sprintf("SELECT %s FROM %s.%s LIMIT $1 OFFSET $2", attrStr, meta.Schema, meta.Name) - return r.genericSqlRepo.SampleTableWithQuery(ctx, meta, query, params.SampleSize, params.Offset) + query := fmt.Sprintf("SELECT %s FROM %s.%s LIMIT $1 OFFSET $2", attrStr, params.Metadata.Schema, params.Metadata.Name) + return r.generic.SampleTableWithQuery(ctx, query, params) } // Ping delegates the ping to GenericRepository. See SQLRepository.Ping and // GenericRepository.Ping for more details. func (r *RedshiftRepository) Ping(ctx context.Context) error { - return r.genericSqlRepo.Ping(ctx) + return r.generic.Ping(ctx) } // Close delegates the close to GenericRepository. See SQLRepository.Close and // GenericRepository.Close for more details. func (r *RedshiftRepository) Close() error { - return r.genericSqlRepo.Close() + return r.generic.Close() } diff --git a/discovery/redshift_test.go b/discovery/redshift_test.go index 07df14c..1f188a0 100644 --- a/discovery/redshift_test.go +++ b/discovery/redshift_test.go @@ -24,6 +24,6 @@ func initRedshiftRepoTest(t *testing.T) (context.Context, *sql.DB, sqlmock.Sqlmo db, mock, err := sqlmock.New() require.NoError(t, err) return ctx, db, mock, &RedshiftRepository{ - genericSqlRepo: NewGenericRepositoryFromDB("repoName", RepoTypeRedshift, "dbName", db), + generic: NewGenericRepositoryFromDB(RepoTypeRedshift, "dbName", db), } } diff --git a/discovery/registry.go b/discovery/registry.go index da18fc2..bc42c07 100644 --- a/discovery/registry.go +++ b/discovery/registry.go @@ -57,15 +57,21 @@ func (r *Registry) MustRegister(repoType string, constructor RepoConstructor) { } } +// Unregister removes a repository type from the registry. If the repository +// type is not registered, this method is a no-op. +func (r *Registry) Unregister(repoType string) { + delete(r.constructors, repoType) +} + // NewRepository is a factory method to return a concrete SQLRepository // implementation based on the specified type, e.g. MySQL, Postgres, SQL Server, // etc., which must be registered with the registry. If the repository type is // not registered, an error is returned. A new instance of the repository is // returned each time this method is called. -func (r *Registry) NewRepository(ctx context.Context, cfg RepoConfig) (SQLRepository, error) { - constructor, ok := r.constructors[cfg.Type] +func (r *Registry) NewRepository(ctx context.Context, repoType string, cfg RepoConfig) (SQLRepository, error) { + constructor, ok := r.constructors[repoType] if !ok { - return nil, errors.New("unsupported repo type " + cfg.Type) + return nil, errors.New("unsupported repo type " + repoType) } repo, err := constructor(ctx, cfg) if err != nil { @@ -86,10 +92,16 @@ func MustRegister(repoType string, constructor RepoConstructor) { DefaultRegistry.MustRegister(repoType, constructor) } +// Unregister is a convenience function that delegates to DefaultRegistry. See +// Registry.Unregister for more details. +func Unregister(repoType string) { + DefaultRegistry.Unregister(repoType) +} + // NewRepository is a convenience function that delegates to DefaultRegistry. // See Registry.NewRepository for more details. -func NewRepository(ctx context.Context, cfg RepoConfig) (SQLRepository, error) { - return DefaultRegistry.NewRepository(ctx, cfg) +func NewRepository(ctx context.Context, repoType string, cfg RepoConfig) (SQLRepository, error) { + return DefaultRegistry.NewRepository(ctx, repoType, cfg) } // init registers all out-of-the-box repository types and their respective diff --git a/discovery/registry_test.go b/discovery/registry_test.go index d459e36..094aa14 100644 --- a/discovery/registry_test.go +++ b/discovery/registry_test.go @@ -5,12 +5,9 @@ import ( "errors" "testing" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -// TODO: refactor tests to use registry instance, not default registry -ccampo 2024-04-02 - func TestRegistry_Register_Successful(t *testing.T) { repoType := "repoType" constructor := func(context.Context, RepoConfig) (SQLRepository, error) { @@ -19,12 +16,12 @@ func TestRegistry_Register_Successful(t *testing.T) { reg := NewRegistry() err := reg.Register(repoType, constructor) require.NoError(t, err) - assert.Contains(t, reg.constructors, repoType) + require.Contains(t, reg.constructors, repoType) } func TestRegistry_MustRegister_NilConstructor(t *testing.T) { reg := NewRegistry() - assert.Panics(t, func() { reg.MustRegister("repoType", nil) }) + require.Panics(t, func() { reg.MustRegister("repoType", nil) }) } func TestRegistry_MustRegister_TwoCalls_Panics(t *testing.T) { @@ -34,14 +31,14 @@ func TestRegistry_MustRegister_TwoCalls_Panics(t *testing.T) { } reg := NewRegistry() reg.MustRegister(repoType, constructor) - assert.Contains(t, reg.constructors, repoType) - assert.Panics(t, func() { reg.MustRegister(repoType, constructor) }) + require.Contains(t, reg.constructors, repoType) + require.Panics(t, func() { reg.MustRegister(repoType, constructor) }) } func TestRegistry_NewRepository_IsSuccessful(t *testing.T) { repoType := "repoType" called := false - expectedRepo := dummyRepo{} + expectedRepo := (SQLRepository)(nil) constructor := func(context.Context, RepoConfig) (SQLRepository, error) { called = true return expectedRepo, nil @@ -49,13 +46,12 @@ func TestRegistry_NewRepository_IsSuccessful(t *testing.T) { reg := NewRegistry() err := reg.Register(repoType, constructor) require.NoError(t, err) - assert.Contains(t, reg.constructors, repoType) + require.Contains(t, reg.constructors, repoType) - cfg := RepoConfig{Type: repoType} - repo, err := reg.NewRepository(context.Background(), cfg) - assert.NoError(t, err) - assert.Equal(t, expectedRepo, repo) - assert.True(t, called, "Constructor was not called") + repo, err := reg.NewRepository(context.Background(), repoType, RepoConfig{}) + require.NoError(t, err) + require.Equal(t, expectedRepo, repo) + require.True(t, called, "Constructor was not called") } func TestRegistry_NewRepository_ConstructorError(t *testing.T) { @@ -69,45 +65,18 @@ func TestRegistry_NewRepository_ConstructorError(t *testing.T) { reg := NewRegistry() err := reg.Register(repoType, constructor) require.NoError(t, err) - assert.Contains(t, reg.constructors, repoType) + require.Contains(t, reg.constructors, repoType) - cfg := RepoConfig{Type: repoType} - repo, err := reg.NewRepository(context.Background(), cfg) - assert.ErrorIs(t, err, expectedErr) - assert.Nil(t, repo) - assert.True(t, called, "Constructor was not called") + repo, err := reg.NewRepository(context.Background(), repoType, RepoConfig{}) + require.ErrorIs(t, err, expectedErr) + require.Nil(t, repo) + require.True(t, called, "Constructor was not called") } func TestRegistry_NewRepository_UnsupportedRepoType(t *testing.T) { repoType := "repoType" - cfg := RepoConfig{Type: repoType} reg := NewRegistry() - repo, err := reg.NewRepository(context.Background(), cfg) - assert.Error(t, err) - assert.Nil(t, repo) -} - -type dummyRepo struct{} - -func (d dummyRepo) SampleTable(context.Context, *TableMetadata, SampleParameters) ( - Sample, - error, -) { - panic("not implemented") -} - -func (d dummyRepo) ListDatabases(context.Context) ([]string, error) { - panic("not implemented") -} - -func (d dummyRepo) Introspect(context.Context) (*Metadata, error) { - panic("not implemented") -} - -func (d dummyRepo) Ping(context.Context) error { - panic("not implemented") -} - -func (d dummyRepo) Close() error { - panic("not implemented") + repo, err := reg.NewRepository(context.Background(), repoType, RepoConfig{}) + require.Error(t, err) + require.Nil(t, repo) } diff --git a/discovery/sample_all_databases_test.go b/discovery/sample_all_databases_test.go deleted file mode 100644 index 0539d05..0000000 --- a/discovery/sample_all_databases_test.go +++ /dev/null @@ -1,226 +0,0 @@ -package discovery - -import ( - "context" - "errors" - "testing" - - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" -) - -const mockRepoType = "mockRepo" - -func setup(t *testing.T) *MockRepository { - repo := NewMockRepository(t) - MustRegister( - mockRepoType, - func(ctx context.Context, cfg RepoConfig) (SQLRepository, error) { - return repo, nil - }, - ) - return repo -} - -func cleanup() { - delete(DefaultRegistry.constructors, mockRepoType) -} - -func TestSampleAllDatabases_Error(t *testing.T) { - repo := setup(t) - t.Cleanup(cleanup) - - ctx := context.Background() - listDbErr := errors.New("error listing databases") - repo.On("ListDatabases", ctx).Return(nil, listDbErr) - cfg := RepoConfig{Type: mockRepoType} - sampleParams := SampleParameters{SampleSize: 5} - samples, err := SampleAllDatabases(ctx, repo, cfg, sampleParams) - require.Nil(t, samples) - require.ErrorIs(t, err, listDbErr) -} - -func TestSampleAllDatabases_Successful_TwoDatabases(t *testing.T) { - repo := setup(t) - t.Cleanup(cleanup) - - ctx := context.Background() - dbs := []string{"db1", "db2"} - repo.On("ListDatabases", ctx).Return(dbs, nil) - // Dummy metadata returned for each Introspect call - meta := Metadata{ - Name: "test", - RepoType: mockRepoType, - Database: "db", - Schemas: map[string]*SchemaMetadata{ - "schema": { - Name: "schema", - Tables: map[string]*TableMetadata{ - "table": { - Schema: "schema", - Name: "table", - Attributes: []*AttributeMetadata{ - { - Schema: "schema", - Table: "table", - Name: "attr", - DataType: "string", - }, - }, - }, - }, - }, - }, - } - repo.On("Introspect", mock.Anything).Return(&meta, nil) - sample := Sample{ - Metadata: SampleMetadata{ - Repo: "repo", - Database: "db", - Schema: "schema", - Table: "table", - }, - Results: []SampleResult{ - { - "attr": "foo", - }, - }, - } - repo.On("SampleTable", mock.Anything, mock.Anything, mock.Anything). - Return(sample, nil) - repo.On("Close").Return(nil) - - cfg := RepoConfig{Type: mockRepoType} - sampleParams := SampleParameters{SampleSize: 5} - samples, err := SampleAllDatabases(ctx, repo, cfg, sampleParams) - require.NoError(t, err) - // Two databases should be sampled, and our mock will return the sample for - // each sample call. This really just asserts that we've sampled the correct - // number of times. - require.ElementsMatch(t, samples, []Sample{sample, sample}) -} - -func TestSampleAllDatabases_IntrospectError(t *testing.T) { - repo := setup(t) - t.Cleanup(cleanup) - - ctx := context.Background() - dbs := []string{"db1", "db2"} - repo.On("ListDatabases", ctx).Return(dbs, nil) - introspectErr := errors.New("introspect error") - repo.On("Introspect", mock.Anything).Return(nil, introspectErr) - repo.On("Close").Return(nil) - - cfg := RepoConfig{Type: mockRepoType} - sampleParams := SampleParameters{SampleSize: 5} - samples, err := SampleAllDatabases(ctx, repo, cfg, sampleParams) - require.Empty(t, samples) - require.NoError(t, err) -} - -func TestSampleAllDatabases_SampleError(t *testing.T) { - repo := setup(t) - t.Cleanup(cleanup) - - ctx := context.Background() - dbs := []string{"db1", "db2"} - repo.On("ListDatabases", ctx).Return(dbs, nil) - // Dummy metadata returned for each Introspect call - meta := Metadata{ - Name: "test", - RepoType: mockRepoType, - Database: "db", - Schemas: map[string]*SchemaMetadata{ - "schema": { - Name: "schema", - Tables: map[string]*TableMetadata{ - "table": { - Schema: "schema", - Name: "table", - Attributes: []*AttributeMetadata{ - { - Schema: "schema", - Table: "table", - Name: "attr", - DataType: "string", - }, - }, - }, - }, - }, - }, - } - repo.On("Introspect", mock.Anything).Return(&meta, nil) - sampleErr := errors.New("sample error") - repo.On("SampleTable", mock.Anything, mock.Anything, mock.Anything). - Return(Sample{}, sampleErr) - repo.On("Close").Return(nil) - - cfg := RepoConfig{Type: mockRepoType} - sampleParams := SampleParameters{SampleSize: 5} - samples, err := SampleAllDatabases(ctx, repo, cfg, sampleParams) - require.NoError(t, err) - require.Empty(t, samples) -} - -func TestSampleAllDatabases_TwoDatabases_OneSampleError(t *testing.T) { - repo := setup(t) - t.Cleanup(cleanup) - - ctx := context.Background() - dbs := []string{"db1", "db2"} - repo.On("ListDatabases", ctx).Return(dbs, nil) - // Dummy metadata returned for each Introspect call - meta := Metadata{ - Name: "test", - RepoType: mockRepoType, - Database: "db", - Schemas: map[string]*SchemaMetadata{ - "schema": { - Name: "schema", - Tables: map[string]*TableMetadata{ - "table": { - Schema: "schema", - Name: "table", - Attributes: []*AttributeMetadata{ - { - Schema: "schema", - Table: "table", - Name: "attr", - DataType: "string", - }, - }, - }, - }, - }, - }, - } - repo.On("Introspect", mock.Anything).Return(&meta, nil) - sample := Sample{ - Metadata: SampleMetadata{ - Repo: "repo", - Database: "db", - Schema: "schema", - Table: "table", - }, - Results: []SampleResult{ - { - "attr": "foo", - }, - }, - } - repo.On("SampleTable", mock.Anything, mock.Anything, mock.Anything). - Return(sample, nil).Once() - sampleErr := errors.New("sample error") - repo.On("SampleTable", mock.Anything, mock.Anything, mock.Anything). - Return(Sample{}, sampleErr).Once() - repo.On("Close").Return(nil) - - cfg := RepoConfig{Type: mockRepoType} - sampleParams := SampleParameters{SampleSize: 5} - samples, err := SampleAllDatabases(ctx, repo, cfg, sampleParams) - require.NoError(t, err) - // Because of a single sample error, we expect only one database was - // sampled. - require.ElementsMatch(t, samples, []Sample{sample}) -} diff --git a/discovery/sample_repository_test.go b/discovery/sample_repository_test.go deleted file mode 100644 index e1323e3..0000000 --- a/discovery/sample_repository_test.go +++ /dev/null @@ -1,181 +0,0 @@ -package discovery - -import ( - "context" - "errors" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestSampleRepository(t *testing.T) { - ctx := context.Background() - repo := fakeRepo{} - params := SampleParameters{SampleSize: 2} - samples, err := SampleRepository(ctx, repo, params) - require.NoError(t, err) - - // Order is not important and is actually non-deterministic due to concurrency - assert.ElementsMatch(t, samples, []Sample{table1Sample, table2Sample}) -} - -func TestSampleRepository_PartialError(t *testing.T) { - ctx := context.Background() - repo := fakeRepo{includeForbiddenTables: true} - params := SampleParameters{SampleSize: 2} - samples, err := SampleRepository(ctx, repo, params) - require.ErrorContains(t, err, "forbidden table") - - // Order is not important and is actually non-deterministic due to concurrency - assert.ElementsMatch(t, samples, []Sample{table1Sample, table2Sample}) -} - -var repoMeta = Metadata{ - Name: "name", - RepoType: "repoType", - Database: "database", - Schemas: map[string]*SchemaMetadata{ - "schema1": { - Name: "", - Tables: map[string]*TableMetadata{ - "table1": { - Schema: "schema1", - Name: "table1", - Attributes: []*AttributeMetadata{ - { - Schema: "schema1", - Table: "table1", - Name: "name1", - DataType: "varchar", - }, - { - Schema: "schema1", - Table: "table1", - Name: "name2", - DataType: "decimal", - }, - }, - }, - }, - }, - "schema2": { - Name: "", - Tables: map[string]*TableMetadata{ - "table2": { - Schema: "schema2", - Name: "table2", - Attributes: []*AttributeMetadata{ - { - Schema: "schema2", - Table: "table2", - Name: "name3", - DataType: "int", - }, - { - Schema: "schema2", - Table: "table2", - Name: "name4", - DataType: "timestamp", - }, - }, - }, - }, - }, - }, -} - -var schema1ForbiddenTable = TableMetadata{ - Schema: "schema1", - Name: "forbidden", - Attributes: []*AttributeMetadata{ - { - Schema: "schema1", - Table: "forbidden", - Name: "name1", - DataType: "varchar", - }, - { - Schema: "schema1", - Table: "forbidden", - Name: "name2", - DataType: "decimal", - }, - }, -} - -var table1Sample = Sample{ - Metadata: SampleMetadata{ - Repo: "name", - Database: "database", - Schema: "schema1", - Table: "table1", - }, - Results: []SampleResult{ - { - "name1": "foo", - "name2": "bar", - }, - { - "name1": "baz", - "name2": "qux", - }, - }, -} - -var table2Sample = Sample{ - Metadata: SampleMetadata{ - Repo: "name", - Database: "database", - Schema: "schema2", - Table: "table2", - }, - Results: []SampleResult{ - { - "name3": "foo1", - "name4": "bar1", - }, - { - "name3": "baz1", - "name4": "qux1", - }, - }, -} - -type fakeRepo struct { - includeForbiddenTables bool -} - -func (f fakeRepo) Introspect(context.Context) (*Metadata, error) { - if f.includeForbiddenTables { - repoMeta.Schemas["schema1"].Tables["forbidden"] = &schema1ForbiddenTable - } - return &repoMeta, nil -} - -func (f fakeRepo) SampleTable(_ context.Context, meta *TableMetadata, _ SampleParameters) ( - Sample, - error, -) { - if meta.Name == "table1" { - return table1Sample, nil - } else if meta.Name == "table2" { - return table2Sample, nil - } else if meta.Name == "forbidden" { - return Sample{}, errors.New("forbidden table") - } else { - return Sample{}, errors.New("unrecognized table") - } -} - -func (f fakeRepo) ListDatabases(context.Context) ([]string, error) { - panic("not implemented") -} - -func (f fakeRepo) Ping(context.Context) error { - panic("not implemented") -} - -func (f fakeRepo) Close() error { - panic("not implemented") -} diff --git a/discovery/sampling.go b/discovery/sampling.go deleted file mode 100644 index 24b3a14..0000000 --- a/discovery/sampling.go +++ /dev/null @@ -1,216 +0,0 @@ -package discovery - -import ( - "context" - "fmt" - "sync" - - "github.com/hashicorp/go-multierror" - log "github.com/sirupsen/logrus" - "golang.org/x/sync/semaphore" -) - -// SampleParameters contains all parameters necessary to sample a table. -type SampleParameters struct { - SampleSize uint - Offset uint -} - -// Sample represents a sample of a database table. The Metadata field contains -// metadata about the sample itself. The actual results of the sample, which -// are represented by a set of database rows, are contained in the Results -// field. -type Sample struct { - Metadata SampleMetadata - Results []SampleResult -} - -// SampleMetadata contains the metadata associated with a given sample, such as -// repo name, database name, table, schema, and query (if applicable). This can -// be used for diagnostic and informational purposes when analyzing a -// particular sample. -type SampleMetadata struct { - Repo string - Database string - Schema string - Table string -} - -// SampleResult stores the results from a single database sample. It is -// equivalent to a database row, where the map key is the column name and the -// map value is the column value. -type SampleResult map[string]any - -// sampleAndErr is a "pair" type intended to be passed to a channel (see -// SampleRepository) -type sampleAndErr struct { - sample Sample - err error -} - -// samplesAndErr is a "pair" type intended to be passed to a channel (see -// SampleAllDatabases) -type samplesAndErr struct { - samples []Sample - err error -} - -// GetAttributeNamesAndValues splits a SampleResult map into two slices and -// returns them. The first slice contains all the keys of SampleResult, -// representing the table's attribute names, and the second slice is the map's -// corresponding values. -func (result SampleResult) GetAttributeNamesAndValues() ([]string, []string) { - names := make([]string, 0, len(result)) - vals := make([]string, 0, len(result)) - for name, val := range result { - names = append(names, name) - var v string - if b, ok := val.([]byte); ok { - v = string(b) - } else { - v = fmt.Sprint(val) - } - vals = append(vals, v) - } - return names, vals -} - -// SampleRepository is a helper function which will sample every table in a -// given repository and return them as a collection of Sample. First the -// repository is introspected by calling sql.Introspect to return the -// repository metadata (Metadata). Then, for each schema and table in the -// metadata, it calls sql.SampleTable in a new goroutine. Once all the -// sampling goroutines are finished, their results are collected and returned -// as a slice of Sample. -func SampleRepository(ctx context.Context, repo SQLRepository, params SampleParameters) ( - []Sample, - error, -) { - meta, err := repo.Introspect(ctx) - if err != nil { - return nil, fmt.Errorf("error introspecting repository: %w", err) - } - - // Fan out sample executions - out := make(chan sampleAndErr) - numTables := 0 - for _, schemaMeta := range meta.Schemas { - for _, tableMeta := range schemaMeta.Tables { - numTables++ - go func(meta *TableMetadata, params SampleParameters) { - sample, err := repo.SampleTable(ctx, meta, params) - out <- sampleAndErr{sample: sample, err: err} - }(tableMeta, params) - } - } - - var samples []Sample - var errs error - for i := 0; i < numTables; i++ { - res := <-out - if res.err != nil { - errs = multierror.Append(errs, res.err) - } else { - samples = append(samples, res.sample) - } - } - close(out) - - if errs != nil { - return samples, fmt.Errorf("error(s) while sampling repository: %w", errs) - } - - return samples, nil -} - -// SampleAllDatabases uses the given repository to list all the databases on the -// server, and samples each one in parallel by calling SampleRepository for each -// database. The repository is intended to be configured to connect to the -// default database on the server, or at least some database which can be used -// to enumerate the full set of databases on the server. An error will be -// returned if the set of databases cannot be listed. If there is an error -// connecting to or sampling a database, the error will be logged and no samples -// will be returned for that database. Therefore, the returned slice of samples -// contains samples for only the databases which could be discovered and -// successfully sampled, and could potentially be empty if no databases were -// sampled. -func SampleAllDatabases( - ctx context.Context, - repo SQLRepository, - repoCfg RepoConfig, - sampleParams SampleParameters, -) ( - []Sample, - error, -) { - // We assume that this repository will be connected to the default database - // (or at least some database that can discover all the other databases), - // and we use that to discover all other databases. - dbs, err := repo.ListDatabases(ctx) - if err != nil { - return nil, fmt.Errorf("error listing databases: %w", err) - } - - // Sample each database on a separate goroutine, and send the samples - // to the 'out' channel. Each slice of samples will be aggregated below - // on the main goroutine and returned. - var wg sync.WaitGroup - out := make(chan samplesAndErr) - wg.Add(len(dbs)) - // Ensures that we avoid opening more than the specified number of - // connections. - var sema *semaphore.Weighted - if repoCfg.MaxOpenConns > 0 { - sema = semaphore.NewWeighted(int64(repoCfg.MaxOpenConns)) - } - for _, db := range dbs { - go func(db string, cfg RepoConfig) { - defer wg.Done() - if sema != nil { - _ = sema.Acquire(ctx, 1) - defer sema.Release(1) - } - cfg.Database = db - // Create a repository instance for this database. It will be used - // to connect and sample the database. - // TODO: this is ugly - there's gotta be a better way to do this -ccampo 2024-04-02 - repo, err := NewRepository(ctx, cfg) - if err != nil { - log.WithError(err).Errorf("error creating repository instance for database %s", db) - return - } - // Close this repository and free up unused resources since we don't - // need it any longer. - defer func() { _ = repo.Close() }() - s, err := SampleRepository(ctx, repo, sampleParams) - if err != nil && len(s) == 0 { - log.WithError(err).Errorf("error gathering repository data samples for database %s", db) - return - } - // Send the samples for this database to the 'out' channel. The - // samples for each database will be aggregated into a single slice - // on the main goroutine and returned. - out <- samplesAndErr{samples: s, err: err} - }(db, repoCfg) - } - - // Start a goroutine to close the 'out' channel once all the goroutines - // we launched above are done. This will allow the aggregation range loop - // below to terminate properly. Note that this must start after the wg.Add - // call. See https://go.dev/blog/pipelines ("Fan-out, fan-in" section). - go func() { - wg.Wait() - close(out) - }() - - // Aggregate and return the results. - var ret []Sample - var errs error - for res := range out { - ret = append(ret, res.samples...) - if res.err != nil { - errs = multierror.Append(errs, res.err) - } - } - return ret, errs -} diff --git a/discovery/sampling_test.go b/discovery/sampling_test.go deleted file mode 100644 index f964a4f..0000000 --- a/discovery/sampling_test.go +++ /dev/null @@ -1,77 +0,0 @@ -package discovery - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestSampleResult_GetAttributeNamesAndValues(t *testing.T) { - tests := []struct { - name string - result SampleResult - wantNames, wantVals []string - }{ - { - name: "string", - result: SampleResult{ - "foo": "fooVal", - }, - wantNames: []string{"foo"}, - wantVals: []string{"fooVal"}, - }, - { - name: "int", - result: SampleResult{ - "foo": 123, - }, - wantNames: []string{"foo"}, - wantVals: []string{"123"}, - }, - { - name: "float", - result: SampleResult{ - "foo": 12.3, - }, - wantNames: []string{"foo"}, - wantVals: []string{"12.3"}, - }, - { - name: "bytes", - result: SampleResult{ - "foo": []byte("fooVal"), - }, - wantNames: []string{"foo"}, - wantVals: []string{"fooVal"}, - }, - { - name: "bool", - result: SampleResult{ - "foo": false, - }, - wantNames: []string{"foo"}, - wantVals: []string{"false"}, - }, - { - name: "varied types", - result: SampleResult{ - "foo": "fooVal", - "bar": 123, - "baz": 12.3, - "qux": []byte("quxVal"), - "quxx": true, - }, - wantNames: []string{"foo", "bar", "baz", "qux", "quxx"}, - wantVals: []string{"fooVal", "123", "12.3", "quxVal", "true"}, - }, - } - for _, tt := range tests { - t.Run( - tt.name, func(t *testing.T) { - gotNames, gotVals := tt.result.GetAttributeNamesAndValues() - require.ElementsMatch(t, tt.wantNames, gotNames) - require.ElementsMatch(t, tt.wantVals, gotVals) - }, - ) - } -} diff --git a/discovery/snowflake.go b/discovery/snowflake.go index cf6d3a0..ba47af2 100644 --- a/discovery/snowflake.go +++ b/discovery/snowflake.go @@ -26,8 +26,8 @@ WHERE // SnowflakeRepository is a SQLRepository implementation for Snowflake databases. type SnowflakeRepository struct { // The majority of the SQLRepository functionality is delegated to - // a generic SQL repository instance (genericSqlRepo). - genericSqlRepo *GenericRepository + // a generic SQL repository instance. + generic *GenericRepository } // SnowflakeRepository implements SQLRepository @@ -53,55 +53,46 @@ func NewSnowflakeRepository(cfg RepoConfig) (*SnowflakeRepository, error) { snowflakeCfg.Role, snowflakeCfg.Warehouse, ) - sqlRepo, err := NewGenericRepository( - cfg.Host, - RepoTypeSnowflake, - database, - connStr, - cfg.MaxOpenConns, - cfg.IncludePaths, - cfg.ExcludePaths, - ) + generic, err := NewGenericRepository(RepoTypeSnowflake, database, connStr, cfg.MaxOpenConns) if err != nil { return nil, fmt.Errorf("could not instantiate generic sql repository: %w", err) } - return &SnowflakeRepository{genericSqlRepo: sqlRepo}, nil + return &SnowflakeRepository{generic: generic}, nil } // ListDatabases returns a list of the names of all databases on the server by // using a Snowflake-specific database query. It delegates the actual work to // GenericRepository.ListDatabasesWithQuery - see that method for more details. func (r *SnowflakeRepository) ListDatabases(ctx context.Context) ([]string, error) { - return r.genericSqlRepo.ListDatabasesWithQuery(ctx, SnowflakeDatabaseQuery) + return r.generic.ListDatabasesWithQuery(ctx, SnowflakeDatabaseQuery) } // Introspect delegates introspection to GenericRepository. See // SQLRepository.Introspect and GenericRepository.IntrospectWithQuery for more // details. -func (r *SnowflakeRepository) Introspect(ctx context.Context) (*Metadata, error) { - return r.genericSqlRepo.Introspect(ctx) +func (r *SnowflakeRepository) Introspect(ctx context.Context, params IntrospectParameters) (*Metadata, error) { + return r.generic.Introspect(ctx, params) } // SampleTable delegates sampling to GenericRepository. See // SQLRepository.SampleTable and GenericRepository.SampleTable for more details. func (r *SnowflakeRepository) SampleTable( ctx context.Context, - meta *TableMetadata, params SampleParameters, ) (Sample, error) { - return r.genericSqlRepo.SampleTable(ctx, meta, params) + return r.generic.SampleTable(ctx, params) } // Ping delegates the ping to GenericRepository. See SQLRepository.Ping and // GenericRepository.Ping for more details. func (r *SnowflakeRepository) Ping(ctx context.Context) error { - return r.genericSqlRepo.Ping(ctx) + return r.generic.Ping(ctx) } // Close delegates the close to GenericRepository. See SQLRepository.Close and // GenericRepository.Close for more details. func (r *SnowflakeRepository) Close() error { - return r.genericSqlRepo.Close() + return r.generic.Close() } // SnowflakeConfig holds Snowflake-specific configuration parameters. diff --git a/discovery/snowflake_test.go b/discovery/snowflake_test.go index 6499a11..2cb3df6 100644 --- a/discovery/snowflake_test.go +++ b/discovery/snowflake_test.go @@ -24,6 +24,6 @@ func initSnowflakeRepoTest(t *testing.T) (context.Context, *sql.DB, sqlmock.Sqlm db, mock, err := sqlmock.New() require.NoError(t, err) return ctx, db, mock, &SnowflakeRepository{ - genericSqlRepo: NewGenericRepositoryFromDB("repoName", RepoTypeSnowflake, "dbName", db), + generic: NewGenericRepositoryFromDB(RepoTypeSnowflake, "dbName", db), } } diff --git a/discovery/sqlrepository.go b/discovery/sqlrepository.go index 8f0eeb8..724633e 100644 --- a/discovery/sqlrepository.go +++ b/discovery/sqlrepository.go @@ -2,6 +2,8 @@ package discovery import ( "context" + + "github.com/gobwas/glob" ) // SQLRepository represents a Dmap data SQL repository, and provides functionality @@ -12,7 +14,7 @@ type SQLRepository interface { // Introspect will read and analyze the basic properties of the repository // and return as a Metadata instance. This includes all the repository's // databases, schemas, tables, columns, and attributes. - Introspect(ctx context.Context) (*Metadata, error) + Introspect(ctx context.Context, params IntrospectParameters) (*Metadata, error) // SampleTable samples the table referenced by the TableMetadata meta // parameter and returns the sample as a slice of Sample. The parameters for // the sample, such as sample size, are passed via the params parameter (see @@ -22,7 +24,7 @@ type SQLRepository interface { // than the specified sample size, it is because the table in question had a // row count less than the sample size. Prefer small sample sizes to limit // impact on the database. - SampleTable(ctx context.Context, meta *TableMetadata, params SampleParameters) (Sample, error) + SampleTable(ctx context.Context, params SampleParameters) (Sample, error) // Ping is meant to be used as a general purpose connectivity test. It // should be invoked e.g. in the dry-run mode. Ping(ctx context.Context) error @@ -30,3 +32,45 @@ type SQLRepository interface { // invoked when the SQLRepository is no longer used. Close() error } + +// IntrospectParameters is a struct that holds the parameters for the Introspect +// method of the SQLRepository interface. +type IntrospectParameters struct { + // IncludePaths is a list of glob patterns that will be used to filter + // the tables that will be introspected. If a table name matches any of + // the patterns in this list, it will be included in the repository + // metadata. + IncludePaths []glob.Glob + // ExcludePaths is a list of glob patterns that will be used to filter + // the tables that will be introspected. If a table name matches any of + // the patterns in this list, it will be excluded from the repository + // metadata. + ExcludePaths []glob.Glob +} + +// Sample represents a sample of data from a database table. +type Sample struct { + // TablePath is the full path of the data repository table that was sampled. + // Each element corresponds to a component, in increasing order of + // granularity (e.g. [database, schema, table]). + TablePath []string + // Results is the set of sample results. Each SampleResult is equivalent to + // a database row, where the map key is the column name and the map value is + // the column value. + Results []SampleResult +} + +// SampleParameters contains all parameters necessary to sample a table. +type SampleParameters struct { + // Metadata is the metadata for the table to be sampled. + Metadata *TableMetadata + // SampleSize is the number of rows to sample from the table. + SampleSize uint + // Offset is the number of rows to skip before starting the sample. + Offset uint +} + +// SampleResult stores the results from a single database sample. It is +// equivalent to a database row, where the map key is the column name and the +// map value is the column value. +type SampleResult map[string]any diff --git a/discovery/sqlserver.go b/discovery/sqlserver.go index 15320e7..e08dc55 100644 --- a/discovery/sqlserver.go +++ b/discovery/sqlserver.go @@ -25,8 +25,8 @@ const ( // databases. type SqlServerRepository struct { // The majority of the SQLRepository functionality is delegated to a generic - // SQL repository instance (genericSqlRepo). - genericSqlRepo *GenericRepository + // SQL repository instance. + generic *GenericRepository } // SqlServerRepository implements SQLRepository @@ -45,33 +45,25 @@ func NewSqlServerRepository(cfg RepoConfig) (*SqlServerRepository, error) { if cfg.Database != "" { connStr = fmt.Sprintf(connStr+"?database=%s", cfg.Database) } - genericSqlRepo, err := NewGenericRepository( - cfg.Host, - RepoTypeSqlServer, - cfg.Database, - connStr, - cfg.MaxOpenConns, - cfg.IncludePaths, - cfg.IncludePaths, - ) + generic, err := NewGenericRepository(RepoTypeSqlServer, cfg.Database, connStr, cfg.MaxOpenConns) if err != nil { return nil, fmt.Errorf("could not instantiate generic sql repository: %w", err) } - return &SqlServerRepository{genericSqlRepo: genericSqlRepo}, nil + return &SqlServerRepository{generic: generic}, nil } // ListDatabases returns a list of the names of all databases on the server by // using a SQL Server-specific database query. It delegates the actual work to // GenericRepository.ListDatabasesWithQuery - see that method for more details. func (r *SqlServerRepository) ListDatabases(ctx context.Context) ([]string, error) { - return r.genericSqlRepo.ListDatabasesWithQuery(ctx, SqlServerDatabaseQuery) + return r.generic.ListDatabasesWithQuery(ctx, SqlServerDatabaseQuery) } // Introspect delegates introspection to GenericRepository. See // SQLRepository.Introspect and GenericRepository.IntrospectWithQuery for more // details. -func (r *SqlServerRepository) Introspect(ctx context.Context) (*Metadata, error) { - return r.genericSqlRepo.Introspect(ctx) +func (r *SqlServerRepository) Introspect(ctx context.Context, params IntrospectParameters) (*Metadata, error) { + return r.generic.Introspect(ctx, params) } // SampleTable delegates sampling to GenericRepository, using a @@ -79,23 +71,22 @@ func (r *SqlServerRepository) Introspect(ctx context.Context) (*Metadata, error) // GenericRepository.SampleTableWithQuery for more details. func (r *SqlServerRepository) SampleTable( ctx context.Context, - meta *TableMetadata, params SampleParameters, ) (Sample, error) { // Sqlserver uses double-quotes to quote identifiers - attrStr := meta.QuotedAttributeNamesString("\"") - query := fmt.Sprintf(SqlServerSampleQueryTemplate, attrStr, meta.Schema, meta.Name) - return r.genericSqlRepo.SampleTableWithQuery(ctx, meta, query, params.SampleSize) + attrStr := params.Metadata.QuotedAttributeNamesString("\"") + query := fmt.Sprintf(SqlServerSampleQueryTemplate, attrStr, params.Metadata.Schema, params.Metadata.Name) + return r.generic.SampleTableWithQuery(ctx, query, params) } // Ping delegates the ping to GenericRepository. See SQLRepository.Ping and // GenericRepository.Ping for more details. func (r *SqlServerRepository) Ping(ctx context.Context) error { - return r.genericSqlRepo.Ping(ctx) + return r.generic.Ping(ctx) } // Close delegates the close to GenericRepository. See SQLRepository.Close and // GenericRepository.Close for more details. func (r *SqlServerRepository) Close() error { - return r.genericSqlRepo.Close() + return r.generic.Close() } diff --git a/discovery/sqlserver_test.go b/discovery/sqlserver_test.go index 16f2ae5..d89f944 100644 --- a/discovery/sqlserver_test.go +++ b/discovery/sqlserver_test.go @@ -24,6 +24,6 @@ func initSqlServerRepoTest(t *testing.T) (context.Context, *sql.DB, sqlmock.Sqlm db, mock, err := sqlmock.New() require.NoError(t, err) return ctx, db, mock, &SqlServerRepository{ - genericSqlRepo: NewGenericRepositoryFromDB("repoName", RepoTypeSqlServer, "dbName", db), + generic: NewGenericRepositoryFromDB(RepoTypeSqlServer, "dbName", db), } } diff --git a/scan/classify.go b/scan/classify.go new file mode 100644 index 0000000..6ce39f9 --- /dev/null +++ b/scan/classify.go @@ -0,0 +1,61 @@ +package scan + +import ( + "context" + "fmt" + "maps" + "strings" + + "github.com/cyralinc/dmap/classification" + "github.com/cyralinc/dmap/discovery" +) + +// classifySamples uses the provided classifiers to classify the sample data +// passed via the "samples" parameter. It is mostly a helper function which +// loops through each repository.Sample, retrieves the attribute names and +// values of that sample, passes them to Classifier.Classify, and then +// aggregates the results. Please see the documentation for Classifier and its +// Classify method for more details. The returned slice represents all the +// unique classification results for a given sample set. +func classifySamples( + ctx context.Context, + samples []discovery.Sample, + classifier classification.Classifier, +) ([]Classification, error) { + uniqueResults := make(map[string]Classification) + for _, sample := range samples { + // Classify each sampled row and combine the results. + for _, sampleResult := range sample.Results { + res, err := classifier.Classify(ctx, sampleResult) + if err != nil { + return nil, fmt.Errorf("error classifying sample: %w", err) + } + for attr, labels := range res { + attrPath := append(sample.TablePath, attr) + key := pathKey(attrPath) + result, ok := uniqueResults[key] + if !ok { + uniqueResults[key] = Classification{ + AttributePath: attrPath, + Labels: labels, + } + } else { + // Merge the labels from the new result into the existing result. + maps.Copy(result.Labels, labels) + } + } + } + } + // Convert the map of unique results to a slice. + results := make([]Classification, 0, len(uniqueResults)) + for _, result := range uniqueResults { + results = append(results, result) + } + return results, nil +} + +func pathKey(path []string) string { + // U+2063 is an invisible separator. It is used here to ensure that the + // pathKey is unique and does not conflict with any of the path elements. + return strings.Join(path, "\u2063") +} diff --git a/scan/classify_test.go b/scan/classify_test.go new file mode 100644 index 0000000..2ddb900 --- /dev/null +++ b/scan/classify_test.go @@ -0,0 +1,159 @@ +package scan + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/cyralinc/dmap/classification" + "github.com/cyralinc/dmap/discovery" + "github.com/cyralinc/dmap/scan/mocks" +) + +func Test_classifySamples_SingleSample(t *testing.T) { + ctx := context.Background() + sample := discovery.Sample{ + Path: []string{"db", "schema", "table"}, + Results: []discovery.SampleResult{ + { + "age": "52", + "social_sec_num": "512-23-4258", + "credit_card_num": "4111111111111111", + }, + { + "age": "101", + "social_sec_num": "foobarbaz", + "credit_card_num": "4111111111111111", + }, + }, + } + classifier := mocks.NewClassifier(t) + // Need to explicitly convert it to a map because Mockery isn't smart enough + // to infer the type. + classifier.EXPECT().Classify(ctx, map[string]any(sample.Results[0])).Return( + classification.Result{ + "age": lblSet("AGE"), + "social_sec_num": lblSet("SSN"), + "credit_card_num": lblSet("CCN"), + }, + nil, + ) + classifier.EXPECT().Classify(ctx, map[string]any(sample.Results[1])).Return( + classification.Result{ + "age": lblSet("AGE", "CVV"), + "credit_card_num": lblSet("CCN"), + }, + nil, + ) + + expected := []RepoScanResult{ + { + AttributePath: append(sample.Path, "age"), + Labels: lblSet("AGE", "CVV"), + }, + { + AttributePath: append(sample.Path, "social_sec_num"), + Labels: lblSet("SSN"), + }, + { + AttributePath: append(sample.Path, "credit_card_num"), + Labels: lblSet("CCN"), + }, + } + actual, err := classifySamples(ctx, []discovery.Sample{sample}, classifier) + require.NoError(t, err) + require.ElementsMatch(t, expected, actual) +} + +func Test_classifySamples_MultipleSamples(t *testing.T) { + ctx := context.Background() + samples := []discovery.Sample{ + { + Path: []string{"db1", "schema1", "table1"}, + Results: []discovery.SampleResult{ + { + "age": "52", + "social_sec_num": "512-23-4258", + "credit_card_num": "4111111111111111", + }, + { + "age": "101", + "social_sec_num": "foobarbaz", + "credit_card_num": "4111111111111111", + }, + }, + }, + { + Path: []string{"db2", "schema2", "table2"}, + Results: []discovery.SampleResult{ + { + "fullname": "John Doe", + "dob": "2000-01-01", + "random": "foobarbaz", + }, + }, + }, + } + + classifier := mocks.NewClassifier(t) + // Need to explicitly convert it to a map because Mockery isn't smart enough + // to infer the type. + classifier.EXPECT().Classify(ctx, map[string]any(samples[0].Results[0])).Return( + classification.Result{ + "age": lblSet("AGE"), + "social_sec_num": lblSet("SSN"), + "credit_card_num": lblSet("CCN"), + }, + nil, + ) + classifier.EXPECT().Classify(ctx, map[string]any(samples[0].Results[1])).Return( + classification.Result{ + "age": lblSet("AGE", "CVV"), + "credit_card_num": lblSet("CCN"), + }, + nil, + ) + classifier.EXPECT().Classify(ctx, map[string]any(samples[1].Results[0])).Return( + classification.Result{ + "fullname": lblSet("FULL_NAME"), + "dob": lblSet("DOB"), + }, + nil, + ) + + expected := []RepoScanResult{ + { + AttributePath: append(samples[0].Path, "age"), + Labels: lblSet("AGE", "CVV"), + }, + { + AttributePath: append(samples[0].Path, "social_sec_num"), + Labels: lblSet("SSN"), + }, + { + AttributePath: append(samples[0].Path, "credit_card_num"), + Labels: lblSet("CCN"), + }, + { + AttributePath: append(samples[1].Path, "fullname"), + Labels: lblSet("FULL_NAME"), + }, + { + AttributePath: append(samples[1].Path, "dob"), + Labels: lblSet("DOB"), + }, + } + actual, err := classifySamples(ctx, samples, classifier) + require.NoError(t, err) + require.ElementsMatch(t, expected, actual) +} + +func lblSet(labels ...string) classification.LabelSet { + set := make(classification.LabelSet) + for _, label := range labels { + set[label] = struct { + }{} + } + return set +} diff --git a/scan/doc.go b/scan/doc.go index f0557aa..d3b226c 100644 --- a/scan/doc.go +++ b/scan/doc.go @@ -1,4 +1,3 @@ -// Package scan provides a EnvironmentScanner interface which can be used for scanning -// cloud environments and performing data repository discovery. -// TODO: fix doc -ccampo 2024-04-02 +// Package scan provides an API to scan cloud environments for data +// repositories, and to scan those repositories for sensitive data. package scan diff --git a/scan/gen.go b/scan/gen.go index a6b6284..f1f00ab 100644 --- a/scan/gen.go +++ b/scan/gen.go @@ -3,3 +3,4 @@ package scan // Mock generation - see https://vektra.github.io/mockery/ //go:generate mockery --with-expecter --srcpkg=github.com/cyralinc/dmap/classification --name=Classifier --filename=mock_classifier.go +//go:generate mockery --with-expecter --srcpkg=github.com/cyralinc/dmap/discovery --name=SQLRepository --filename=mock_sql_repository.go diff --git a/scan/mocks/mock_sql_repository.go b/scan/mocks/mock_sql_repository.go new file mode 100644 index 0000000..706bbdb --- /dev/null +++ b/scan/mocks/mock_sql_repository.go @@ -0,0 +1,302 @@ +// Code generated by mockery v2.42.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + discovery "github.com/cyralinc/dmap/discovery" + mock "github.com/stretchr/testify/mock" +) + +// SQLRepository is an autogenerated mock type for the SQLRepository type +type SQLRepository struct { + mock.Mock +} + +type SQLRepository_Expecter struct { + mock *mock.Mock +} + +func (_m *SQLRepository) EXPECT() *SQLRepository_Expecter { + return &SQLRepository_Expecter{mock: &_m.Mock} +} + +// Close provides a mock function with given fields: +func (_m *SQLRepository) Close() error { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Close") + } + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SQLRepository_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type SQLRepository_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +func (_e *SQLRepository_Expecter) Close() *SQLRepository_Close_Call { + return &SQLRepository_Close_Call{Call: _e.mock.On("Close")} +} + +func (_c *SQLRepository_Close_Call) Run(run func()) *SQLRepository_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *SQLRepository_Close_Call) Return(_a0 error) *SQLRepository_Close_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *SQLRepository_Close_Call) RunAndReturn(run func() error) *SQLRepository_Close_Call { + _c.Call.Return(run) + return _c +} + +// Introspect provides a mock function with given fields: ctx, params +func (_m *SQLRepository) Introspect(ctx context.Context, params discovery.IntrospectParameters) (*discovery.Metadata, error) { + ret := _m.Called(ctx, params) + + if len(ret) == 0 { + panic("no return value specified for Introspect") + } + + var r0 *discovery.Metadata + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, discovery.IntrospectParameters) (*discovery.Metadata, error)); ok { + return rf(ctx, params) + } + if rf, ok := ret.Get(0).(func(context.Context, discovery.IntrospectParameters) *discovery.Metadata); ok { + r0 = rf(ctx, params) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*discovery.Metadata) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, discovery.IntrospectParameters) error); ok { + r1 = rf(ctx, params) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SQLRepository_Introspect_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Introspect' +type SQLRepository_Introspect_Call struct { + *mock.Call +} + +// Introspect is a helper method to define mock.On call +// - ctx context.Context +// - params discovery.IntrospectParameters +func (_e *SQLRepository_Expecter) Introspect(ctx interface{}, params interface{}) *SQLRepository_Introspect_Call { + return &SQLRepository_Introspect_Call{Call: _e.mock.On("Introspect", ctx, params)} +} + +func (_c *SQLRepository_Introspect_Call) Run(run func(ctx context.Context, params discovery.IntrospectParameters)) *SQLRepository_Introspect_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(discovery.IntrospectParameters)) + }) + return _c +} + +func (_c *SQLRepository_Introspect_Call) Return(_a0 *discovery.Metadata, _a1 error) *SQLRepository_Introspect_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *SQLRepository_Introspect_Call) RunAndReturn(run func(context.Context, discovery.IntrospectParameters) (*discovery.Metadata, error)) *SQLRepository_Introspect_Call { + _c.Call.Return(run) + return _c +} + +// ListDatabases provides a mock function with given fields: ctx +func (_m *SQLRepository) ListDatabases(ctx context.Context) ([]string, error) { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for ListDatabases") + } + + var r0 []string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) ([]string, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) []string); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SQLRepository_ListDatabases_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ListDatabases' +type SQLRepository_ListDatabases_Call struct { + *mock.Call +} + +// ListDatabases is a helper method to define mock.On call +// - ctx context.Context +func (_e *SQLRepository_Expecter) ListDatabases(ctx interface{}) *SQLRepository_ListDatabases_Call { + return &SQLRepository_ListDatabases_Call{Call: _e.mock.On("ListDatabases", ctx)} +} + +func (_c *SQLRepository_ListDatabases_Call) Run(run func(ctx context.Context)) *SQLRepository_ListDatabases_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *SQLRepository_ListDatabases_Call) Return(_a0 []string, _a1 error) *SQLRepository_ListDatabases_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *SQLRepository_ListDatabases_Call) RunAndReturn(run func(context.Context) ([]string, error)) *SQLRepository_ListDatabases_Call { + _c.Call.Return(run) + return _c +} + +// Ping provides a mock function with given fields: ctx +func (_m *SQLRepository) Ping(ctx context.Context) error { + ret := _m.Called(ctx) + + if len(ret) == 0 { + panic("no return value specified for Ping") + } + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SQLRepository_Ping_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Ping' +type SQLRepository_Ping_Call struct { + *mock.Call +} + +// Ping is a helper method to define mock.On call +// - ctx context.Context +func (_e *SQLRepository_Expecter) Ping(ctx interface{}) *SQLRepository_Ping_Call { + return &SQLRepository_Ping_Call{Call: _e.mock.On("Ping", ctx)} +} + +func (_c *SQLRepository_Ping_Call) Run(run func(ctx context.Context)) *SQLRepository_Ping_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *SQLRepository_Ping_Call) Return(_a0 error) *SQLRepository_Ping_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *SQLRepository_Ping_Call) RunAndReturn(run func(context.Context) error) *SQLRepository_Ping_Call { + _c.Call.Return(run) + return _c +} + +// SampleTable provides a mock function with given fields: ctx, params +func (_m *SQLRepository) SampleTable(ctx context.Context, params discovery.SampleParameters) (discovery.Sample, error) { + ret := _m.Called(ctx, params) + + if len(ret) == 0 { + panic("no return value specified for SampleTable") + } + + var r0 discovery.Sample + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, discovery.SampleParameters) (discovery.Sample, error)); ok { + return rf(ctx, params) + } + if rf, ok := ret.Get(0).(func(context.Context, discovery.SampleParameters) discovery.Sample); ok { + r0 = rf(ctx, params) + } else { + r0 = ret.Get(0).(discovery.Sample) + } + + if rf, ok := ret.Get(1).(func(context.Context, discovery.SampleParameters) error); ok { + r1 = rf(ctx, params) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SQLRepository_SampleTable_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SampleTable' +type SQLRepository_SampleTable_Call struct { + *mock.Call +} + +// SampleTable is a helper method to define mock.On call +// - ctx context.Context +// - params discovery.SampleParameters +func (_e *SQLRepository_Expecter) SampleTable(ctx interface{}, params interface{}) *SQLRepository_SampleTable_Call { + return &SQLRepository_SampleTable_Call{Call: _e.mock.On("SampleTable", ctx, params)} +} + +func (_c *SQLRepository_SampleTable_Call) Run(run func(ctx context.Context, params discovery.SampleParameters)) *SQLRepository_SampleTable_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(discovery.SampleParameters)) + }) + return _c +} + +func (_c *SQLRepository_SampleTable_Call) Return(_a0 discovery.Sample, _a1 error) *SQLRepository_SampleTable_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *SQLRepository_SampleTable_Call) RunAndReturn(run func(context.Context, discovery.SampleParameters) (discovery.Sample, error)) *SQLRepository_SampleTable_Call { + _c.Call.Return(run) + return _c +} + +// NewSQLRepository creates a new instance of SQLRepository. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewSQLRepository(t interface { + mock.TestingT + Cleanup(func()) +}) *SQLRepository { + mock := &SQLRepository{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/scan/repo_scanner.go b/scan/repo_scanner.go index 87bcf83..23bbeb9 100644 --- a/scan/repo_scanner.go +++ b/scan/repo_scanner.go @@ -4,23 +4,26 @@ import ( "context" "fmt" + "github.com/gobwas/glob" log "github.com/sirupsen/logrus" "github.com/cyralinc/dmap/classification" "github.com/cyralinc/dmap/discovery" ) -// RepoScannerConfig is the configuration for the RepoScanner. -type RepoScannerConfig struct { - Dmap DmapConfig `embed:""` - Repo discovery.RepoConfig `embed:""` +// RepoScanResults is the result of a repository scan. +type RepoScanResults struct { + Labels []classification.Label `json:"labels"` + Classifications []Classification `json:"classifications"` } -// DmapConfig is the necessary configuration to connect to the Dmap API. -type DmapConfig struct { - ApiBaseUrl string `help:"Base URL of the Dmap API." default:"https://api.dmap.cyral.io"` - ClientID string `help:"API client ID to access the Dmap API."` - ClientSecret string `help:"API client secret to access the Dmap API."` //#nosec G101 -- false positive +type Classification struct { + // AttributePath is the full path of the data repository attribute + // (e.g. the column). Each element corresponds to a component, in increasing + // order of granularity (e.g. [database, schema, table, column]). + AttributePath []string `json:"attributePath"` + // Labels is the set of labels that the attribute was classified as. + Labels classification.LabelSet `json:"labels"` } // RepoScanner is a data discovery scanner that scans a data repository for @@ -28,171 +31,103 @@ type DmapConfig struct { // the configured external sources. It currently only supports SQL-based // repositories. type RepoScanner struct { - config RepoScannerConfig - repository discovery.SQLRepository + Config RepoScannerConfig + labels []classification.Label classifier classification.Classifier - publisher classification.Publisher -} - -// RepoScannerOption is a functional option type for the RepoScanner type. -type RepoScannerOption func(*RepoScanner) - -// WithSQLRepository is a functional option that sets the SQLRepository for the -// RepoScanner. -func WithSQLRepository(r discovery.SQLRepository) RepoScannerOption { - return func(s *RepoScanner) { s.repository = r } -} - -// WithClassifier is a functional option that sets the classifier for the -// RepoScanner. -func WithClassifier(c classification.Classifier) RepoScannerOption { - return func(s *RepoScanner) { s.classifier = c } } -// WithPublisher is a functional option that sets the publisher for the RepoScanner. -func WithPublisher(p classification.Publisher) RepoScannerOption { - return func(s *RepoScanner) { s.publisher = p } +// RepoScannerConfig is the configuration for the RepoScanner. +type RepoScannerConfig struct { + RepoType string + RepoConfig discovery.RepoConfig + Registry *discovery.Registry + IncludePaths, ExcludePaths []glob.Glob + SampleSize uint + Offset uint } // NewRepoScanner creates a new RepoScanner instance with the provided configuration. -func NewRepoScanner(ctx context.Context, cfg RepoScannerConfig, opts ...RepoScannerOption) (*RepoScanner, error) { - s := &RepoScanner{config: cfg} - // Apply options. - for _, opt := range opts { - opt(s) +func NewRepoScanner(cfg RepoScannerConfig) (*RepoScanner, error) { + if cfg.RepoType == "" { + return nil, fmt.Errorf("repository type not specified") } - if s.publisher == nil { - // Default to stdout publisher. - s.publisher = classification.NewStdOutPublisher() + if cfg.Registry == nil { + cfg.Registry = discovery.DefaultRegistry } - if s.classifier == nil { - // Create a new label classifier with the embedded labels. - lbls, err := classification.GetEmbeddedLabels() - if err != nil { - return nil, fmt.Errorf("error getting embedded labels: %w", err) - } - c, err := classification.NewLabelClassifier(lbls.ToSlice()...) - if err != nil { - return nil, fmt.Errorf("error creating new label classifier: %w", err) - } - s.classifier = c + // Create a new label classifier with the embedded labels. + lbls, err := classification.GetEmbeddedLabels() + if err != nil { + return nil, fmt.Errorf("error getting embedded labels: %w", err) } - if s.repository == nil { - // Get a repository instance from the default registry. - repo, err := discovery.NewRepository(ctx, s.config.Repo) - if err != nil { - return nil, fmt.Errorf("error connecting to database: %w", err) - } - s.repository = repo + c, err := classification.NewLabelClassifier(lbls...) + if err != nil { + return nil, fmt.Errorf("error creating new label classifier: %w", err) } - return s, nil + return &RepoScanner{Config: cfg, labels: lbls, classifier: c}, nil } // Scan performs the data repository scan. It introspects and samples the // repository, classifies the sampled data, and publishes the results to the // configured classification publisher. -func (s *RepoScanner) Scan(ctx context.Context) error { - sampleParams := discovery.SampleParameters{SampleSize: s.config.Repo.SampleSize} - var samples []discovery.Sample - // The name of the database to connect to has been left unspecified by - // the user, so we try to connect and sample all databases instead. Note - // that Oracle doesn't really have the concept of "databases", and thus - // the RepoScanner always scans the entire database, so only the single - // (default) repository instance is required in that case. - if s.config.Repo.Database == "" && s.config.Repo.Type != discovery.RepoTypeOracle { - var err error - samples, err = discovery.SampleAllDatabases( - ctx, - s.repository, - s.config.Repo, - sampleParams, - ) - if err != nil { - err = fmt.Errorf("error sampling databases: %w", err) - // If we didn't get any samples, just return the error. - if len(samples) == 0 { - return err - } - // There were error(s) during sampling, but we still got some - // samples. Just warn and continue. - log.WithError(err).Warn("error sampling databases") - } - } else { - // User specified a database (or this is an Oracle DB), therefore - // we already have a repository instance for it. Just use it to - // sample that database only. - var err error - samples, err = discovery.SampleRepository(ctx, s.repository, sampleParams) - if err != nil { - err = fmt.Errorf("error gathering repository data samples: %w", err) - // If we didn't get any samples, just return the error. - if len(samples) == 0 { - return err - } - // There were error(s) during sampling, but we still got some - // samples. Just warn and continue. - log.WithError(err).Warn("error gathering repository data samples") +func (s *RepoScanner) Scan(ctx context.Context) (*RepoScanResults, error) { + // Introspect and sample the data repository. + samples, err := s.sample(ctx) + if err != nil { + msg := "error sampling repository" + // If we didn't get any samples, just return the error. + if len(samples) == 0 { + return nil, fmt.Errorf("%s: %w", msg, err) } + // There were error(s) during sampling, but we still got some samples. + // Just warn and continue. + log.WithError(err).Warn(msg) } - - // Classify sampled data + // Classify the sampled data. classifications, err := classifySamples(ctx, samples, s.classifier) if err != nil { - return fmt.Errorf("error classifying samples: %w", err) + return nil, fmt.Errorf("error classifying samples: %w", err) } - - // Publish classifications if necessary - if len(classifications) == 0 { - log.Info("No discovered classifications") - } else if err := s.publisher.PublishClassifications(ctx, s.config.Repo.Host, classifications); err != nil { - return fmt.Errorf("error publishing classifications: %w", err) - } - - // Done! - return nil + return &RepoScanResults{ + Labels: s.labels, + Classifications: classifications, + }, nil } -// Cleanup performs cleanup operations for the RepoScanner. -func (s *RepoScanner) Cleanup() { - // Nil checks are prevent panics if deps are not yet initialized. - if s.repository != nil { - _ = s.repository.Close() +func (s *RepoScanner) sample(ctx context.Context) ([]discovery.Sample, error) { + // This closure is used to create a new repository instance for each + // database that is sampled. When there are multiple databases to sample, + // it is passed to sampleAllDbs to create the necessary repository instances + // for each database. When there is only a single database to sample, it is + // used directly below to create the repository instance for that database, + // which is passed to sampleDb to sample the database. + newRepo := func(ctx context.Context, cfg discovery.RepoConfig) (discovery.SQLRepository, error) { + return s.Config.Registry.NewRepository(ctx, s.Config.RepoType, cfg) } -} - -// classifySamples uses the provided classifiers to classify the sample data -// passed via the "samples" parameter. It is mostly a helper function which -// loops through each repository.Sample, retrieves the attribute names and -// values of that sample, passes them to Classifier.Classify, and then -// aggregates the results. Please see the documentation for Classifier and its -// Classify method for more details. The returned slice represents all the -// unique classification results for a given sample set. -func classifySamples( - ctx context.Context, - samples []discovery.Sample, - classifier classification.Classifier, -) ([]classification.ClassifiedTable, error) { - tables := make([]classification.ClassifiedTable, 0, len(samples)) - for _, sample := range samples { - // Classify each sampled row and combine the results. - result := make(classification.Result) - for _, sampleResult := range sample.Results { - res, err := classifier.Classify(ctx, sampleResult) - if err != nil { - return nil, fmt.Errorf("error classifying sample: %w", err) - } - result.Merge(res) - } - if len(result) > 0 { - table := classification.ClassifiedTable{ - Repo: sample.Metadata.Repo, - Database: sample.Metadata.Database, - Schema: sample.Metadata.Schema, - Table: sample.Metadata.Table, - Classifications: result, - } - tables = append(tables, table) + introspectParams := discovery.IntrospectParameters{ + IncludePaths: s.Config.IncludePaths, + ExcludePaths: s.Config.ExcludePaths, + } + // Check if the user specified a single database, or told us to scan an + // Oracle DB. In that case, therefore we only need to sample that single + // database. Note that Oracle doesn't really have the concept of + // "databases", therefore a single repository instance will always scan the + // entire database. + if s.Config.RepoConfig.Database != "" || s.Config.RepoType == discovery.RepoTypeOracle { + repo, err := newRepo(ctx, s.Config.RepoConfig) + if err != nil { + return nil, fmt.Errorf("error creating repository: %w", err) } + defer func() { _ = repo.Close() }() + return sampleDb(ctx, repo, introspectParams, s.Config.SampleSize, s.Config.Offset) } - return tables, nil + // The name of the database to connect to has been left unspecified by the + // user, so we try to connect and sample all databases instead. + return sampleAllDbs( + ctx, + newRepo, + s.Config.RepoConfig, + introspectParams, + s.Config.SampleSize, + s.Config.Offset, + ) } diff --git a/scan/repo_scanner_test.go b/scan/repo_scanner_test.go index 95c4e50..040485b 100644 --- a/scan/repo_scanner_test.go +++ b/scan/repo_scanner_test.go @@ -1,247 +1 @@ package scan - -import ( - "context" - "testing" - - "github.com/cyralinc/dmap/classification" - "github.com/cyralinc/dmap/discovery" - "github.com/cyralinc/dmap/scan/mocks" - - "github.com/stretchr/testify/require" -) - -func Test_classifySamples_SingleTable(t *testing.T) { - ctx := context.Background() - meta := discovery.SampleMetadata{ - Repo: "repo", - Database: "db", - Schema: "schema", - Table: "table", - } - - sample := discovery.Sample{ - Metadata: meta, - Results: []discovery.SampleResult{ - { - "age": "52", - "social_sec_num": "512-23-4258", - "credit_card_num": "4111111111111111", - }, - { - "age": "101", - "social_sec_num": "foobarbaz", - "credit_card_num": "4111111111111111", - }, - }, - } - - classifier := mocks.NewClassifier(t) - // Need to explicitly convert it to a map because Mockery isn't smart enough - // to infer the type. - classifier.EXPECT().Classify(ctx, map[string]any(sample.Results[0])).Return( - classification.Result{ - "age": { - "AGE": {Name: "AGE"}, - }, - "social_sec_num": { - "SSN": {Name: "SSN"}, - }, - "credit_card_num": { - "CCN": {Name: "CCN"}, - }, - }, - nil, - ) - classifier.EXPECT().Classify(ctx, map[string]any(sample.Results[1])).Return( - classification.Result{ - "age": { - "AGE": {Name: "AGE"}, - "CVV": {Name: "CVV"}, - }, - "credit_card_num": { - "CCN": {Name: "CCN"}, - }, - }, - nil, - ) - - expected := []classification.ClassifiedTable{ - { - Repo: meta.Repo, - Database: meta.Database, - Schema: meta.Schema, - Table: meta.Table, - Classifications: classification.Result{ - "age": { - "AGE": {Name: "AGE"}, - "CVV": {Name: "CVV"}, - }, - "social_sec_num": { - "SSN": {Name: "SSN"}, - }, - "credit_card_num": { - "CCN": {Name: "CCN"}, - }, - }, - }, - } - actual, err := classifySamples(ctx, []discovery.Sample{sample}, classifier) - require.NoError(t, err) - require.Len(t, actual, len(expected)) - for i := range actual { - requireClassifiedTableEqual(t, expected[i], actual[i]) - } -} - -func Test_classifySamples_MultipleTables(t *testing.T) { - ctx := context.Background() - meta1 := discovery.SampleMetadata{ - Repo: "repo1", - Database: "db1", - Schema: "schema1", - Table: "table1", - } - meta2 := discovery.SampleMetadata{ - Repo: "repo2", - Database: "db2", - Schema: "schema2", - Table: "table2", - } - - samples := []discovery.Sample{ - { - Metadata: meta1, - Results: []discovery.SampleResult{ - { - "age": "52", - "social_sec_num": "512-23-4258", - "credit_card_num": "4111111111111111", - }, - { - "age": "101", - "social_sec_num": "foobarbaz", - "credit_card_num": "4111111111111111", - }, - }, - }, - { - Metadata: meta2, - Results: []discovery.SampleResult{ - { - "fullname": "John Doe", - "dob": "2000-01-01", - "random": "foobarbaz", - }, - }, - }, - } - - classifier := mocks.NewClassifier(t) - // Need to explicitly convert it to a map because Mockery isn't smart enough - // to infer the type. - classifier.EXPECT().Classify(ctx, map[string]any(samples[0].Results[0])).Return( - classification.Result{ - "age": { - "AGE": {Name: "AGE"}, - }, - "social_sec_num": { - "SSN": {Name: "SSN"}, - }, - "credit_card_num": { - "CCN": {Name: "CCN"}, - }, - }, - nil, - ) - classifier.EXPECT().Classify(ctx, map[string]any(samples[0].Results[1])).Return( - classification.Result{ - "age": { - "AGE": {Name: "AGE"}, - "CVV": {Name: "CVV"}, - }, - "credit_card_num": { - "CCN": {Name: "CCN"}, - }, - }, - nil, - ) - classifier.EXPECT().Classify(ctx, map[string]any(samples[1].Results[0])).Return( - classification.Result{ - "fullname": { - "FULL_NAME": {Name: "FULL_NAME"}, - }, - "dob": { - "DOB": {Name: "DOB"}, - }, - }, - nil, - ) - - expected := []classification.ClassifiedTable{ - { - Repo: meta1.Repo, - Database: meta1.Database, - Schema: meta1.Schema, - Table: meta1.Table, - Classifications: classification.Result{ - "age": { - "AGE": {Name: "AGE"}, - "CVV": {Name: "CVV"}, - }, - "social_sec_num": { - "SSN": {Name: "SSN"}, - }, - "credit_card_num": { - "CCN": {Name: "CCN"}, - }, - }, - }, - { - Repo: meta2.Repo, - Database: meta2.Database, - Schema: meta2.Schema, - Table: meta2.Table, - Classifications: classification.Result{ - "fullname": { - "FULL_NAME": {Name: "FULL_NAME"}, - }, - "dob": { - "DOB": {Name: "DOB"}, - }, - }, - }, - } - actual, err := classifySamples(ctx, samples, classifier) - require.NoError(t, err) - require.Len(t, actual, len(expected)) - for i := range actual { - requireClassifiedTableEqual(t, expected[i], actual[i]) - } -} - -func requireClassifiedTableEqual(t *testing.T, expected, actual classification.ClassifiedTable) { - require.Equal(t, expected.Repo, actual.Repo) - require.Equal(t, expected.Database, actual.Database) - require.Equal(t, expected.Schema, actual.Schema) - require.Equal(t, expected.Table, actual.Table) - requireResultEqual(t, expected.Classifications, actual.Classifications) -} - -func requireResultEqual(t *testing.T, want, got classification.Result) { - require.Len(t, got, len(want)) - for k, v := range want { - gotSet, ok := got[k] - require.Truef(t, ok, "missing attribute %s", k) - requireLabelSetEqual(t, v, gotSet) - } -} - -func requireLabelSetEqual(t *testing.T, want, got classification.LabelSet) { - require.Len(t, got, len(want)) - for k, v := range want { - gotLbl, ok := got[k] - require.Truef(t, ok, "missing label %s", k) - require.Equal(t, v.Name, gotLbl.Name) - } -} diff --git a/scan/sample.go b/scan/sample.go new file mode 100644 index 0000000..68b6718 --- /dev/null +++ b/scan/sample.go @@ -0,0 +1,192 @@ +package scan + +import ( + "context" + "fmt" + "sync" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + "golang.org/x/sync/semaphore" + + "github.com/cyralinc/dmap/discovery" +) + +// sampleAndErr is a "pair" type intended to be passed to a channel (see +// sampleDb) +type sampleAndErr struct { + sample discovery.Sample + err error +} + +// samplesAndErr is a "pair" type intended to be passed to a channel (see +// sampleAllDbs) +type samplesAndErr struct { + samples []discovery.Sample + err error +} + +// sampleAllDbs uses the given SQLRepository to list all the +// databases on the server, and samples each one in parallel by calling +// sampleDb for each database. The repository is intended to be +// configured to connect to the default database on the server, or at least some +// database which can be used to enumerate the full set of databases on the +// server. An error will be returned if the set of databases cannot be listed. +// If there is an error connecting to or sampling a database, the error will be +// logged and no samples will be returned for that database. Therefore, the +// returned slice of samples contains samples for only the databases which could +// be discovered and successfully sampled, and could potentially be empty if no +// databases were sampled. +func sampleAllDbs( + ctx context.Context, + ctor discovery.RepoConstructor, + cfg discovery.RepoConfig, + introspectParams discovery.IntrospectParameters, + sampleSize, offset uint, +) ( + []discovery.Sample, + error, +) { + // Create a repository instance that will be used to list all the databases + // on the server. + repo, err := ctor(ctx, cfg) + if err != nil { + return nil, fmt.Errorf("error creating repository instance: %w", err) + } + defer func() { _ = repo.Close() }() + + // We assume that this repository will be connected to the default database + // (or at least some database that can discover all the other databases), + // and we use that to discover all other databases. + dbs, err := repo.ListDatabases(ctx) + if err != nil { + return nil, fmt.Errorf("error listing databases: %w", err) + } + + // Sample each database on a separate goroutine, and send the samples to + // the 'out' channel. Each slice of samples will be aggregated below on the + // main goroutine and returned. + var wg sync.WaitGroup + out := make(chan samplesAndErr) + wg.Add(len(dbs)) + // Ensures that we avoid opening more than the specified number of + // connections. + var sema *semaphore.Weighted + if cfg.MaxOpenConns > 0 { + sema = semaphore.NewWeighted(int64(cfg.MaxOpenConns)) + } + for _, db := range dbs { + go func(db string, cfg discovery.RepoConfig) { + defer wg.Done() + if sema != nil { + _ = sema.Acquire(ctx, 1) + defer sema.Release(1) + } + // Create a repository instance for this specific database. It will + // be used to connect to and sample the database. + cfg.Database = db + repo, err := ctor(ctx, cfg) + if err != nil { + log.WithError(err).Errorf("error creating repository instance for database %s", db) + return + } + defer func() { _ = repo.Close() }() + // Sample the database. + s, err := sampleDb(ctx, repo, introspectParams, sampleSize, offset) + if err != nil && len(s) == 0 { + log.WithError(err).Errorf("error gathering repository data samples for database %s", db) + return + } + // Send the samples for this database to the 'out' channel. The + // samples for each database will be aggregated into a single slice + // on the main goroutine and returned. + out <- samplesAndErr{samples: s, err: err} + }(db, cfg) + } + + // Start a goroutine to close the 'out' channel once all the goroutines + // we launched above are done. This will allow the aggregation range loop + // below to terminate properly. Note that this must start after the wg.Add + // call. See https://go.dev/blog/pipelines ("Fan-out, fan-in" section). + go func() { + wg.Wait() + close(out) + }() + + // Aggregate and return the results. + var ret []discovery.Sample + var errs error + for res := range out { + ret = append(ret, res.samples...) + if res.err != nil { + errs = multierror.Append(errs, res.err) + } + } + return ret, errs +} + +// sampleDb is a helper function which will sample every table in a +// given repository and return them as a collection of Sample. First the +// repository is introspected by calling sql.Introspect to return the +// repository metadata (Metadata). Then, for each schema and table in the +// metadata, it calls sql.SampleTable in a new goroutine. Once all the +// sampling goroutines are finished, their results are collected and returned +// as a slice of Sample. +func sampleDb( + ctx context.Context, + repo discovery.SQLRepository, + introspectParams discovery.IntrospectParameters, + sampleSize, offset uint, +) ( + []discovery.Sample, + error, +) { + // Introspect the repository to get the metadata. + meta, err := repo.Introspect(ctx, introspectParams) + if err != nil { + return nil, fmt.Errorf("error introspecting repository: %w", err) + } + + // Fan out sample executions. + out := make(chan sampleAndErr) + numTables := 0 + for _, schemaMeta := range meta.Schemas { + for _, tableMeta := range schemaMeta.Tables { + numTables++ + go func(meta *discovery.TableMetadata) { + params := discovery.SampleParameters{ + Metadata: meta, + SampleSize: sampleSize, + Offset: offset, + } + sample, err := repo.SampleTable(ctx, params) + select { + case <-ctx.Done(): + return + case out <- sampleAndErr{sample: sample, err: err}: + } + }(tableMeta) + } + } + + // Aggregate and return the results. + var samples []discovery.Sample + var errs error + for i := 0; i < numTables; i++ { + select { + case <-ctx.Done(): + return samples, ctx.Err() + case res:= <-out: + if res.err != nil { + errs = multierror.Append(errs, res.err) + } else { + samples = append(samples, res.sample) + } + } + } + close(out) + if errs != nil { + return samples, fmt.Errorf("error(s) while sampling repository: %w", errs) + } + return samples, nil +} diff --git a/scan/sample_test.go b/scan/sample_test.go new file mode 100644 index 0000000..216f18d --- /dev/null +++ b/scan/sample_test.go @@ -0,0 +1,208 @@ +package scan + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/cyralinc/dmap/discovery" + "github.com/cyralinc/dmap/scan/mocks" +) + +var ( + table1Sample = discovery.Sample{ + TablePath: []string{"database", "schema1", "table1"}, + Results: []discovery.SampleResult{ + { + "name1": "foo", + "name2": "bar", + }, + { + "name1": "baz", + "name2": "qux", + }, + }, + } + + table2Sample = discovery.Sample{ + TablePath: []string{"database", "schema2", "table2"}, + Results: []discovery.SampleResult{ + { + "name3": "foo1", + "name4": "bar1", + }, + { + "name3": "baz1", + "name4": "qux1", + }, + }, + } +) + +func Test_sampleDb_Success(t *testing.T) { + ctx := context.Background() + repo := mocks.NewSQLRepository(t) + meta := discovery.Metadata{ + Database: "database", + Schemas: map[string]*discovery.SchemaMetadata{ + "schema1": { + Name: "", + Tables: map[string]*discovery.TableMetadata{ + "table1": { + Schema: "schema1", + Name: "table1", + Attributes: []*discovery.AttributeMetadata{ + { + Schema: "schema1", + Table: "table1", + Name: "name1", + DataType: "varchar", + }, + { + Schema: "schema1", + Table: "table1", + Name: "name2", + DataType: "decimal", + }, + }, + }, + }, + }, + "schema2": { + Name: "", + Tables: map[string]*discovery.TableMetadata{ + "table2": { + Schema: "schema2", + Name: "table2", + Attributes: []*discovery.AttributeMetadata{ + { + Schema: "schema2", + Table: "table2", + Name: "name3", + DataType: "int", + }, + { + Schema: "schema2", + Table: "table2", + Name: "name4", + DataType: "timestamp", + }, + }, + }, + }, + }, + }, + } + repo.EXPECT().Introspect(ctx, mock.Anything).Return(&meta, nil) + sampleParams1 := discovery.SampleParameters{ + Metadata: meta.Schemas["schema1"].Tables["table1"], + } + sampleParams2 := discovery.SampleParameters{ + Metadata: meta.Schemas["schema2"].Tables["table2"], + } + repo.EXPECT().SampleTable(ctx, sampleParams1).Return(table1Sample, nil) + repo.EXPECT().SampleTable(ctx, sampleParams2).Return(table2Sample, nil) + samples, err := sampleDb(ctx, repo, discovery.IntrospectParameters{}, 0, 0) + require.NoError(t, err) + // Order is not important and is actually non-deterministic due to concurrency + expected := []discovery.Sample{table1Sample, table2Sample} + require.ElementsMatch(t, expected, samples) +} + +func Test_sampleDb_PartialError(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + repo := mocks.NewSQLRepository(t) + meta := discovery.Metadata{ + Database: "database", + Schemas: map[string]*discovery.SchemaMetadata{ + "schema1": { + Name: "", + Tables: map[string]*discovery.TableMetadata{ + "table1": { + Schema: "schema1", + Name: "table1", + Attributes: []*discovery.AttributeMetadata{ + { + Schema: "schema1", + Table: "table1", + Name: "name1", + DataType: "varchar", + }, + { + Schema: "schema1", + Table: "table1", + Name: "name2", + DataType: "decimal", + }, + }, + }, + "forbidden": { + Schema: "schema1", + Name: "forbidden", + Attributes: []*discovery.AttributeMetadata{ + { + Schema: "schema1", + Table: "forbidden", + Name: "name1", + DataType: "varchar", + }, + { + Schema: "schema1", + Table: "forbidden", + Name: "name2", + DataType: "decimal", + }, + }, + }, + }, + }, + "schema2": { + Name: "", + Tables: map[string]*discovery.TableMetadata{ + "table2": { + Schema: "schema2", + Name: "table2", + Attributes: []*discovery.AttributeMetadata{ + { + Schema: "schema2", + Table: "table2", + Name: "name3", + DataType: "int", + }, + { + Schema: "schema2", + Table: "table2", + Name: "name4", + DataType: "timestamp", + }, + }, + }, + }, + }, + }, + } + repo.EXPECT().Introspect(ctx, mock.Anything).Return(&meta, nil) + sampleParams1 := discovery.SampleParameters{ + Metadata: meta.Schemas["schema1"].Tables["table1"], + } + sampleParams2 := discovery.SampleParameters{ + Metadata: meta.Schemas["schema1"].Tables["forbidden"], + } + sampleParamsForbidden := discovery.SampleParameters{ + Metadata: meta.Schemas["schema2"].Tables["table2"], + } + repo.EXPECT().SampleTable(ctx, sampleParams1).Return(table1Sample, nil) + errForbidden := errors.New("forbidden table") + repo.EXPECT().SampleTable(ctx, sampleParamsForbidden).Return(discovery.Sample{}, errForbidden) + repo.EXPECT().SampleTable(ctx, sampleParams2).Return(table2Sample, nil) + + samples, err := sampleDb(ctx, repo, discovery.IntrospectParameters{}, 0, 0) + require.ErrorIs(t, err, errForbidden) + // Order is not important and is actually non-deterministic due to concurrency + expected := []discovery.Sample{table1Sample, table2Sample} + require.ElementsMatch(t, expected, samples) +} diff --git a/scan/samplealldatabases_test.go b/scan/samplealldatabases_test.go new file mode 100644 index 0000000..02c2701 --- /dev/null +++ b/scan/samplealldatabases_test.go @@ -0,0 +1,185 @@ +package scan + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/cyralinc/dmap/discovery" + "github.com/cyralinc/dmap/scan/mocks" +) + +func Test_sampleAllDbs_Error(t *testing.T) { + ctx := context.Background() + listDbErr := errors.New("error listing databases") + repo := mocks.NewSQLRepository(t) + repo.EXPECT().ListDatabases(ctx).Return(nil, listDbErr) + repo.EXPECT().Close().Return(nil) + ctor := func(ctx context.Context, cfg discovery.RepoConfig) (discovery.SQLRepository, error) { + return repo, nil + } + cfg := discovery.RepoConfig{} + samples, err := sampleAllDbs(ctx, ctor, cfg, discovery.IntrospectParameters{}, 0, 0) + require.Nil(t, samples) + require.ErrorIs(t, err, listDbErr) +} + +func Test_sampleAllDbs_Successful_TwoDatabases(t *testing.T) { + ctx := context.Background() + dbs := []string{"db1", "db2"} + // Dummy metadata returned for each Introspect call + meta := discovery.Metadata{ + Database: "db", + Schemas: map[string]*discovery.SchemaMetadata{ + "schema": { + Name: "schema", + Tables: map[string]*discovery.TableMetadata{ + "table": { + Schema: "schema", + Name: "table", + Attributes: []*discovery.AttributeMetadata{ + { + Schema: "schema", + Table: "table", + Name: "attr", + DataType: "string", + }, + }, + }, + }, + }, + }, + } + sample := discovery.Sample{ + TablePath: []string{"db", "schema", "table"}, + Results: []discovery.SampleResult{ + { + "attr": "foo", + }, + }, + } + repo := mocks.NewSQLRepository(t) + repo.EXPECT().ListDatabases(ctx).Return(dbs, nil) + repo.EXPECT().Introspect(ctx, mock.Anything).Return(&meta, nil) + repo.EXPECT().SampleTable(ctx, mock.Anything).Return(sample, nil) + repo.EXPECT().Close().Return(nil) + ctor := func(ctx context.Context, cfg discovery.RepoConfig) (discovery.SQLRepository, error) { + return repo, nil + } + samples, err := sampleAllDbs(ctx, ctor, discovery.RepoConfig{}, discovery.IntrospectParameters{}, 0, 0) + require.NoError(t, err) + // Two databases should be sampled, and our mock will return the sample for + // each sample call. This really just asserts that we've sampled the correct + // number of times. + require.ElementsMatch(t, samples, []discovery.Sample{sample, sample}) +} + +func Test_sampleAllDbs_IntrospectError(t *testing.T) { + ctx := context.Background() + dbs := []string{"db1", "db2"} + introspectErr := errors.New("introspect error") + repo := mocks.NewSQLRepository(t) + repo.EXPECT().ListDatabases(ctx).Return(dbs, nil) + repo.EXPECT().Introspect(ctx, mock.Anything).Return(nil, introspectErr) + repo.EXPECT().Close().Return(nil) + ctor := func(ctx context.Context, cfg discovery.RepoConfig) (discovery.SQLRepository, error) { + return repo, nil + } + samples, err := sampleAllDbs(ctx, ctor, discovery.RepoConfig{}, discovery.IntrospectParameters{}, 0, 0) + require.Empty(t, samples) + require.NoError(t, err) +} + +func Test_sampleAllDbs_SampleError(t *testing.T) { + ctx := context.Background() + dbs := []string{"db1", "db2"} + // Dummy metadata returned for each Introspect call + meta := discovery.Metadata{ + Database: "db", + Schemas: map[string]*discovery.SchemaMetadata{ + "schema": { + Name: "schema", + Tables: map[string]*discovery.TableMetadata{ + "table": { + Schema: "schema", + Name: "table", + Attributes: []*discovery.AttributeMetadata{ + { + Schema: "schema", + Table: "table", + Name: "attr", + DataType: "string", + }, + }, + }, + }, + }, + }, + } + sampleErr := errors.New("sample error") + repo := mocks.NewSQLRepository(t) + repo.EXPECT().ListDatabases(ctx).Return(dbs, nil) + repo.EXPECT().Introspect(ctx, mock.Anything).Return(&meta, nil) + repo.EXPECT().SampleTable(ctx, mock.Anything).Return(discovery.Sample{}, sampleErr) + repo.EXPECT().Close().Return(nil) + ctor := func(ctx context.Context, cfg discovery.RepoConfig) (discovery.SQLRepository, error) { + return repo, nil + } + samples, err := sampleAllDbs(ctx, ctor, discovery.RepoConfig{}, discovery.IntrospectParameters{}, 0, 0) + require.NoError(t, err) + require.Empty(t, samples) +} + +func Test_sampleAllDbs_TwoDatabases_OneSampleError(t *testing.T) { + ctx := context.Background() + dbs := []string{"db1", "db2"} + // Dummy metadata returned for each Introspect call + meta := discovery.Metadata{ + Database: "db", + Schemas: map[string]*discovery.SchemaMetadata{ + "schema": { + Name: "schema", + Tables: map[string]*discovery.TableMetadata{ + "table": { + Schema: "schema", + Name: "table", + Attributes: []*discovery.AttributeMetadata{ + { + Schema: "schema", + Table: "table", + Name: "attr", + DataType: "string", + }, + }, + }, + }, + }, + }, + } + sample := discovery.Sample{ + TablePath: []string{"db", "schema", "table"}, + Results: []discovery.SampleResult{ + { + "attr": "foo", + }, + }, + } + sampleErr := errors.New("sample error") + repo := mocks.NewSQLRepository(t) + repo.EXPECT().ListDatabases(ctx).Return(dbs, nil) + repo.EXPECT().Introspect(ctx, mock.Anything).Return(&meta, nil) + repo.EXPECT().SampleTable(ctx, mock.Anything).Return(sample, nil).Once() + repo.EXPECT().SampleTable(ctx, mock.Anything).Return(discovery.Sample{}, sampleErr).Once() + repo.EXPECT().Close().Return(nil) + ctor := func(ctx context.Context, cfg discovery.RepoConfig) (discovery.SQLRepository, error) { + return repo, nil + } + samples, err := sampleAllDbs(ctx, ctor, discovery.RepoConfig{}, discovery.IntrospectParameters{}, 0, 0) + require.NoError(t, err) + // Because of a single sample error, we expect only one database was + // sampled. + require.ElementsMatch(t, samples, []discovery.Sample{sample}) +}