Skip to content

Commit

Permalink
More refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
ccampo133 committed Apr 3, 2024
1 parent b5409e7 commit 46cf074
Show file tree
Hide file tree
Showing 60 changed files with 2,114 additions and 2,464 deletions.
16 changes: 8 additions & 8 deletions aws/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"github.com/cyralinc/dmap/scan"
)

// AWSScanner is an implementation of the EnvironmentScanner interface for the AWS cloud
// AWSScanner is an implementation of the Scanner interface for the AWS cloud
// provider. It supports scanning data repositories from multiple AWS regions,
// including RDS clusters and instances, Redshift clusters and DynamoDB tables.
type AWSScanner struct {
Expand All @@ -23,8 +23,8 @@ type AWSScanner struct {
awsClientConstructor awsClientConstructor
}

// AWSScanner implements scan.EnvironmentScanner
var _ scan.EnvironmentScanner = (*AWSScanner)(nil)
// AWSScanner implements scan.Scanner
var _ scan.Scanner = (*AWSScanner)(nil)

// NewAWSScanner creates a new instance of AWSScanner based on the ScannerConfig.
// If AssumeRoleConfig is specified, the AWSScanner will assume this IAM Role
Expand Down Expand Up @@ -57,7 +57,7 @@ func NewAWSScanner(
// Scan performs a scan across all the AWS regions configured and return a scan
// results, containing a list of data repositories that includes: RDS clusters
// and instances, Redshift clusters and DynamoDB tables.
func (s *AWSScanner) Scan(ctx context.Context) (*scan.EnvironmentScanResults, error) {
func (s *AWSScanner) Scan(ctx context.Context) (*scan.ScanResults, error) {
responseChan := make(chan scanResponse)
var wg sync.WaitGroup
wg.Add(len(s.scannerConfig.Regions))
Expand Down Expand Up @@ -100,18 +100,18 @@ func (s *AWSScanner) Scan(ctx context.Context) (*scan.EnvironmentScanResults, er
select {
case <-ctx.Done():
scanErrors = append(scanErrors, ctx.Err())
return &scan.EnvironmentScanResults{
return &scan.ScanResults{
Repositories: repositories,
}, &scan.EnvironmentScanError{Errs: scanErrors}
}, &scan.ScanError{Errs: scanErrors}

case response, ok := <-responseChan:
if !ok {
// Channel closed, all scans finished.
var scanErr error
if len(scanErrors) > 0 {
scanErr = &scan.EnvironmentScanError{Errs: scanErrors}
scanErr = &scan.ScanError{Errs: scanErrors}
}
return &scan.EnvironmentScanResults{
return &scan.ScanResults{
Repositories: repositories,
}, scanErr

Expand Down
4 changes: 2 additions & 2 deletions aws/scanner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func (s *AWSScannerTestSuite) TestScan() {
ctx := context.Background()
results, err := awsScanner.Scan(ctx)

expectedResults := &scan.EnvironmentScanResults{
expectedResults := &scan.ScanResults{
Repositories: map[string]scan.Repository{
*s.dummyRDSClusters[0].DBClusterArn: {
Id: *s.dummyRDSClusters[0].DBClusterArn,
Expand Down Expand Up @@ -359,7 +359,7 @@ func (s *AWSScannerTestSuite) TestScan_WithErrors() {
ctx := context.Background()
results, err := awsScanner.Scan(ctx)

expectedResults := &scan.EnvironmentScanResults{
expectedResults := &scan.ScanResults{
Repositories: nil,
}

Expand Down
35 changes: 8 additions & 27 deletions classification/classification.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ package classification

import (
"context"
"maps"
"encoding/json"

"golang.org/x/exp/maps"
)

// Classifier is an interface that represents a data classifier. A classifier
Expand All @@ -22,35 +24,14 @@ type Classifier interface {
Classify(ctx context.Context, input map[string]any) (Result, error)
}

// ClassifiedTable represents a database table that has been classified. The
// classifications are stored in the Classifications field, which is a map of
// attribute names (i.e. columns) to the set of labels that attributes were
// classified as.
type ClassifiedTable struct {
Repo string `json:"repo"`
Database string `json:"database"`
Schema string `json:"schema"`
Table string `json:"table"`
Classifications Result `json:"classifications"`
}

// Result represents the classifications for a set of data attributes. The key
// is the attribute (i.e. column) name and the value is the set of labels
// that attribute was classified as.
type Result map[string]LabelSet

// Merge combines the given other Result into this Result (the receiver). If
// an attribute from other is already present in this Result, the existing
// labels for that attribute are merged with the labels from other, otherwise
// labels from other for the attribute are simply added to this Result.
func (c Result) Merge(other Result) {
if c == nil {
return
}
for attr, labelSet := range other {
if _, ok := c[attr]; !ok {
c[attr] = make(LabelSet)
}
maps.Copy(c[attr], labelSet)
}
// LabelSet is a set of unique labels.
type LabelSet map[string]struct{}

func (l LabelSet) MarshalJSON() ([]byte, error) {
return json.Marshal(maps.Keys(l))
}
59 changes: 0 additions & 59 deletions classification/classification_test.go

This file was deleted.

50 changes: 20 additions & 30 deletions classification/label.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import (
"strings"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"gopkg.in/yaml.v3"
)

Expand All @@ -20,18 +20,23 @@ var (

// Label represents a data classification label.
type Label struct {
Name string `yaml:"name" json:"name"`
Description string `yaml:"description" json:"description"`
Tags []string `yaml:"tags" json:"tags"`
ClassificationRule *rego.Rego `yaml:"-" json:"-"`
// Name is the name of the label.
Name string `yaml:"name" json:"name"`
// Description is a brief description of the label.
Description string `yaml:"description" json:"description"`
// Tags are a list of arbitrary tags associated with the label.
Tags []string `yaml:"tags" json:"tags"`
// ClassificationRule is the compiled Rego classification rule used to
// classify data.
ClassificationRule *ast.Module `yaml:"-" json:"-"`
}

// NewLabel creates a new Label with the given name, description, classification
// rule, and tags. The classification rule is expected to be the raw Rego code
// that will be used to classify data. If the classification rule is invalid, an
// error is returned.
func NewLabel(name, description, classificationRule string, tags []string) (Label, error) {
rule, err := ruleRego(classificationRule)
func NewLabel(name, description, classificationRule string, tags ...string) (Label, error) {
rule, err := parseRego(classificationRule)
if err != nil {
return Label{}, fmt.Errorf("error preparing classification rule for label %s: %w", name, err)
}
Expand All @@ -43,25 +48,12 @@ func NewLabel(name, description, classificationRule string, tags []string) (Labe
}, nil
}

// LabelSet is a set of unique labels. The key is the label name and the value
// is the label itself.
type LabelSet map[string]Label

// ToSlice returns the labels in the set as a slice.
func (s LabelSet) ToSlice() []Label {
var labels []Label
for _, label := range s {
labels = append(labels, label)
}
return labels
}

// GetEmbeddedLabels returns the predefined embedded labels and their
// classification rules. The labels are read from the embedded labels.yaml file
// and the classification rules are read from the embedded Rego files.
func GetEmbeddedLabels() (LabelSet, error) {
func GetEmbeddedLabels() ([]Label, error) {
labels := struct {
Labels LabelSet `yaml:"labels"`
Labels map[string]Label `yaml:"labels"`
}{}
if err := yaml.Unmarshal([]byte(labelsYaml), &labels); err != nil {
return nil, fmt.Errorf("error unmarshalling labels.yaml: %w", err)
Expand All @@ -72,24 +64,22 @@ func GetEmbeddedLabels() (LabelSet, error) {
if err != nil {
return nil, fmt.Errorf("error reading rego file %s: %w", fname, err)
}
rule, err := ruleRego(string(b))
rule, err := parseRego(string(b))
if err != nil {
return nil, fmt.Errorf("error preparing classification rule for label %s: %w", lbl.Name, err)
}
lbl.Name = name
lbl.ClassificationRule = rule
labels.Labels[name] = lbl
}
return labels.Labels, nil
return maps.Values(labels.Labels), nil
}

func ruleRego(code string) (*rego.Rego, error) {
func parseRego(code string) (*ast.Module, error) {
log.Tracef("classifier module code: '%s'", code)
moduleName := "classifier"
compiledRego, err := ast.CompileModules(map[string]string{moduleName: code})
module, err := ast.ParseModule("classifier", code)
if err != nil {
return nil, fmt.Errorf("error compiling rego code: %w", err)
return nil, fmt.Errorf("error parsing rego code: %w", err)
}
regoQuery := compiledRego.Modules[moduleName].Package.Path.String() + ".output"
return rego.New(rego.Query(regoQuery), rego.Compiler(compiledRego)), nil
return module, nil
}
44 changes: 24 additions & 20 deletions classification/label_classifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,59 +11,63 @@ import (
// LabelClassifier is a Classifier implementation that uses a set of labels and
// their classification rules to classify data.
type LabelClassifier struct {
labels LabelSet
queries map[string]*rego.Rego
}

// LabelClassifier implements Classifier
var _ Classifier = (*LabelClassifier)(nil)

// NewLabelClassifier creates a new LabelClassifier with the provided labels and
// classification rules.
// NewLabelClassifier creates a new LabelClassifier with the provided labels.
//
func NewLabelClassifier(labels ...Label) (*LabelClassifier, error) {
if len(labels) == 0 {
return nil, fmt.Errorf("labels cannot be empty")
}
l := make(LabelSet, len(labels))
queries := make(map[string]*rego.Rego, len(labels))
for _, lbl := range labels {
l[lbl.Name] = lbl
queries[lbl.Name] = rego.New(
// We only care about the 'output' variable.
rego.Query(lbl.ClassificationRule.Package.Path.String() + ".output"),
rego.ParsedModule(lbl.ClassificationRule),
)
}
return &LabelClassifier{labels: l}, nil
return &LabelClassifier{queries: queries}, nil
}

// Classify performs the classification of the provided attributes using the
// classifier's labels and classification rules. It returns a Result, which is
// a map of attribute names to the set of labels that the attribute was
// classified as.
// classifier's labels and their corresponding classification rules. It returns
// a Result, which is a map of attribute names to the set of labels that the
// attribute was classified as.
func (c *LabelClassifier) Classify(ctx context.Context, input map[string]any) (Result, error) {
result := make(Result, len(c.labels))
for _, lbl := range c.labels {
output, err := evalQuery(ctx, lbl.ClassificationRule, input)
result := make(Result, len(c.queries))
for lbl, query := range c.queries {
output, err := evalQuery(ctx, query, input)
if err != nil {
return nil, fmt.Errorf("error evaluating query for label %s: %w", lbl.Name, err)
return nil, fmt.Errorf("error evaluating query for label %s: %w", lbl, err)
}
log.Debugf("classification results for label %s: %v", lbl.Name, output)
log.Debugf("classification results for label %s: %v", lbl, output)
for attrName, classified := range output {
if classified {
attrLabels, ok := result[attrName]
if !ok {
attrLabels = make(LabelSet)
result[attrName] = attrLabels
}
attrLabels[lbl.Name] = lbl
// Add the label to the set of labels for the attribute.
attrLabels[lbl] = struct{}{}
}
}
}
return result, nil
}

// evalQuery evaluates the provided prepared Rego query with the given
// attributes as input, and returns the classification results. The output is a
// evalQuery evaluates the provided Rego query with the given attributes as input, and returns the classification results. The output is a
// map of attribute names to boolean values, where the boolean indicates whether
// the attribute is classified as belonging to the label.
func evalQuery(ctx context.Context, rule *rego.Rego, input map[string]any) (map[string]bool, error) {
q, err := rule.PrepareForEval(ctx)
func evalQuery(ctx context.Context, query *rego.Rego, input map[string]any) (map[string]bool, error) {
q, err := query.PrepareForEval(ctx)
if err != nil {
return nil, fmt.Errorf("error preparing rule for evaluation: %w", err)
return nil, fmt.Errorf("error preparing query for evaluation: %w", err)
}
// Evaluate the prepared Rego query. This performs the actual classification
// logic.
Expand Down
Loading

0 comments on commit 46cf074

Please sign in to comment.