Skip to content

Commit

Permalink
More refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
ccampo133 committed Apr 2, 2024
1 parent 05827ac commit e066b04
Show file tree
Hide file tree
Showing 51 changed files with 989 additions and 968 deletions.
16 changes: 8 additions & 8 deletions aws/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
"github.com/cyralinc/dmap/scan"
)

// AWSScanner is an implementation of the Scanner interface for the AWS cloud
// AWSScanner is an implementation of the EnvironmentScanner interface for the AWS cloud
// provider. It supports scanning data repositories from multiple AWS regions,
// including RDS clusters and instances, Redshift clusters and DynamoDB tables.
type AWSScanner struct {
Expand All @@ -23,8 +23,8 @@ type AWSScanner struct {
awsClientConstructor awsClientConstructor
}

// AWSScanner implements scan.Scanner
var _ scan.Scanner = (*AWSScanner)(nil)
// AWSScanner implements scan.EnvironmentScanner
var _ scan.EnvironmentScanner = (*AWSScanner)(nil)

// NewAWSScanner creates a new instance of AWSScanner based on the ScannerConfig.
// If AssumeRoleConfig is specified, the AWSScanner will assume this IAM Role
Expand Down Expand Up @@ -57,7 +57,7 @@ func NewAWSScanner(
// Scan performs a scan across all the AWS regions configured and return a scan
// results, containing a list of data repositories that includes: RDS clusters
// and instances, Redshift clusters and DynamoDB tables.
func (s *AWSScanner) Scan(ctx context.Context) (*scan.ScanResults, error) {
func (s *AWSScanner) Scan(ctx context.Context) (*scan.EnvironmentScanResults, error) {
responseChan := make(chan scanResponse)
var wg sync.WaitGroup
wg.Add(len(s.scannerConfig.Regions))
Expand Down Expand Up @@ -100,18 +100,18 @@ func (s *AWSScanner) Scan(ctx context.Context) (*scan.ScanResults, error) {
select {
case <-ctx.Done():
scanErrors = append(scanErrors, ctx.Err())
return &scan.ScanResults{
return &scan.EnvironmentScanResults{
Repositories: repositories,
}, &scan.ScanError{Errs: scanErrors}
}, &scan.EnvironmentScanError{Errs: scanErrors}

case response, ok := <-responseChan:
if !ok {
// Channel closed, all scans finished.
var scanErr error
if len(scanErrors) > 0 {
scanErr = &scan.ScanError{Errs: scanErrors}
scanErr = &scan.EnvironmentScanError{Errs: scanErrors}
}
return &scan.ScanResults{
return &scan.EnvironmentScanResults{
Repositories: repositories,
}, scanErr

Expand Down
4 changes: 2 additions & 2 deletions aws/scanner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func (s *AWSScannerTestSuite) TestScan() {
ctx := context.Background()
results, err := awsScanner.Scan(ctx)

expectedResults := &scan.ScanResults{
expectedResults := &scan.EnvironmentScanResults{
Repositories: map[string]scan.Repository{
*s.dummyRDSClusters[0].DBClusterArn: {
Id: *s.dummyRDSClusters[0].DBClusterArn,
Expand Down Expand Up @@ -359,7 +359,7 @@ func (s *AWSScannerTestSuite) TestScan_WithErrors() {
ctx := context.Background()
results, err := awsScanner.Scan(ctx)

expectedResults := &scan.ScanResults{
expectedResults := &scan.EnvironmentScanResults{
Repositories: nil,
}

Expand Down
40 changes: 0 additions & 40 deletions classification/classification.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@ package classification

import (
"context"
"fmt"
"maps"

"github.com/cyralinc/dmap/discovery/repository"
)

// Classifier is an interface that represents a data classifier. A classifier
Expand Down Expand Up @@ -57,40 +54,3 @@ func (c Result) Merge(other Result) {
maps.Copy(c[attr], labelSet)
}
}

// ClassifySamples uses the provided classifiers to classify the sample data
// passed via the "samples" parameter. It is mostly a helper function which
// loops through each repository.Sample, retrieves the attribute names and
// values of that sample, passes them to Classifier.Classify, and then
// aggregates the results. Please see the documentation for Classifier and its
// Classify method for more details. The returned slice represents all the
// unique classification results for a given sample set.
func ClassifySamples(
ctx context.Context,
samples []sql.Sample,
classifier Classifier,
) ([]ClassifiedTable, error) {
tables := make([]ClassifiedTable, 0, len(samples))
for _, sample := range samples {
// Classify each sampled row and combine the results.
result := make(Result)
for _, sampleResult := range sample.Results {
res, err := classifier.Classify(ctx, sampleResult)
if err != nil {
return nil, fmt.Errorf("error classifying sample: %w", err)
}
result.Merge(res)
}
if len(result) > 0 {
table := ClassifiedTable{
Repo: sample.Metadata.Repo,
Database: sample.Metadata.Database,
Schema: sample.Metadata.Schema,
Table: sample.Metadata.Table,
Classifications: result,
}
tables = append(tables, table)
}
}
return tables, nil
}
220 changes: 0 additions & 220 deletions classification/classification_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
package classification

import (
"context"
"testing"

"github.com/stretchr/testify/require"

"github.com/cyralinc/dmap/discovery/repository"
)

func TestMerge_WhenCalledOnNilReceiver_ShouldNotPanic(t *testing.T) {
Expand Down Expand Up @@ -60,220 +57,3 @@ func TestMerge_WhenCalledWithExistingAttributes_ShouldOverwrite(t *testing.T) {
result.Merge(other)
require.Equal(t, expected, result)
}

func TestClassifySamples_SingleTable(t *testing.T) {
ctx := context.Background()
meta := sql.SampleMetadata{
Repo: "repo",
Database: "db",
Schema: "schema",
Table: "table",
}

sample := sql.Sample{
Metadata: meta,
Results: []sql.SampleResult{
{
"age": "52",
"social_sec_num": "512-23-4258",
"credit_card_num": "4111111111111111",
},
{
"age": "101",
"social_sec_num": "foobarbaz",
"credit_card_num": "4111111111111111",
},
},
}

classifier := NewMockClassifier(t)
// Need to explicitly convert it to a map because Mockery isn't smart enough
// to infer the type.
classifier.EXPECT().Classify(ctx, map[string]any(sample.Results[0])).Return(
Result{
"age": {
"AGE": {Name: "AGE"},
},
"social_sec_num": {
"SSN": {Name: "SSN"},
},
"credit_card_num": {
"CCN": {Name: "CCN"},
},
},
nil,
)
classifier.EXPECT().Classify(ctx, map[string]any(sample.Results[1])).Return(
Result{
"age": {
"AGE": {Name: "AGE"},
"CVV": {Name: "CVV"},
},
"credit_card_num": {
"CCN": {Name: "CCN"},
},
},
nil,
)

expected := []ClassifiedTable{
{
Repo: meta.Repo,
Database: meta.Database,
Schema: meta.Schema,
Table: meta.Table,
Classifications: Result{
"age": {
"AGE": {Name: "AGE"},
"CVV": {Name: "CVV"},
},
"social_sec_num": {
"SSN": {Name: "SSN"},
},
"credit_card_num": {
"CCN": {Name: "CCN"},
},
},
},
}
actual, err := ClassifySamples(ctx, []sql.Sample{sample}, classifier)
require.NoError(t, err)
require.Len(t, actual, len(expected))
for i := range actual {
requireClassifiedTableEqual(t, expected[i], actual[i])
}
}

func TestClassifySamples_MultipleTables(t *testing.T) {
ctx := context.Background()
meta1 := sql.SampleMetadata{
Repo: "repo1",
Database: "db1",
Schema: "schema1",
Table: "table1",
}
meta2 := sql.SampleMetadata{
Repo: "repo2",
Database: "db2",
Schema: "schema2",
Table: "table2",
}

samples := []sql.Sample{
{
Metadata: meta1,
Results: []sql.SampleResult{
{
"age": "52",
"social_sec_num": "512-23-4258",
"credit_card_num": "4111111111111111",
},
{
"age": "101",
"social_sec_num": "foobarbaz",
"credit_card_num": "4111111111111111",
},
},
},
{
Metadata: meta2,
Results: []sql.SampleResult{
{
"fullname": "John Doe",
"dob": "2000-01-01",
"random": "foobarbaz",
},
},
},
}

classifier := NewMockClassifier(t)
// Need to explicitly convert it to a map because Mockery isn't smart enough
// to infer the type.
classifier.EXPECT().Classify(ctx, map[string]any(samples[0].Results[0])).Return(
Result{
"age": {
"AGE": {Name: "AGE"},
},
"social_sec_num": {
"SSN": {Name: "SSN"},
},
"credit_card_num": {
"CCN": {Name: "CCN"},
},
},
nil,
)
classifier.EXPECT().Classify(ctx, map[string]any(samples[0].Results[1])).Return(
Result{
"age": {
"AGE": {Name: "AGE"},
"CVV": {Name: "CVV"},
},
"credit_card_num": {
"CCN": {Name: "CCN"},
},
},
nil,
)
classifier.EXPECT().Classify(ctx, map[string]any(samples[1].Results[0])).Return(
Result{
"fullname": {
"FULL_NAME": {Name: "FULL_NAME"},
},
"dob": {
"DOB": {Name: "DOB"},
},
},
nil,
)

expected := []ClassifiedTable{
{
Repo: meta1.Repo,
Database: meta1.Database,
Schema: meta1.Schema,
Table: meta1.Table,
Classifications: Result{
"age": {
"AGE": {Name: "AGE"},
"CVV": {Name: "CVV"},
},
"social_sec_num": {
"SSN": {Name: "SSN"},
},
"credit_card_num": {
"CCN": {Name: "CCN"},
},
},
},
{
Repo: meta2.Repo,
Database: meta2.Database,
Schema: meta2.Schema,
Table: meta2.Table,
Classifications: Result{
"fullname": {
"FULL_NAME": {Name: "FULL_NAME"},
},
"dob": {
"DOB": {Name: "DOB"},
},
},
},
}
actual, err := ClassifySamples(ctx, samples, classifier)
require.NoError(t, err)
require.Len(t, actual, len(expected))
for i := range actual {
requireClassifiedTableEqual(t, expected[i], actual[i])
}
}

func requireClassifiedTableEqual(t *testing.T, expected, actual ClassifiedTable) {
require.Equal(t, expected.Repo, actual.Repo)
require.Equal(t, expected.Database, actual.Database)
require.Equal(t, expected.Schema, actual.Schema)
require.Equal(t, expected.Table, actual.Table)
requireResultEqual(t, expected.Classifications, actual.Classifications)
}
5 changes: 0 additions & 5 deletions classification/gen.go

This file was deleted.

2 changes: 1 addition & 1 deletion classification/label.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func GetEmbeddedLabels() (LabelSet, error) {
return nil, fmt.Errorf("error unmarshalling labels.yaml: %w", err)
}
for name, lbl := range labels.Labels {
fname := "rego/" + strings.ReplaceAll(strings.ToLower(name), " ", "_") + ".rego"
fname := "labels/" + strings.ReplaceAll(strings.ToLower(name), " ", "_") + ".rego"
b, err := regoFs.ReadFile(fname)
if err != nil {
return nil, fmt.Errorf("error reading rego file %s: %w", fname, err)
Expand Down
2 changes: 1 addition & 1 deletion classification/label_classifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ func newTestLabelClassifier(t *testing.T, lblNames ...string) *LabelClassifier {
}

func newTestLabel(t *testing.T, lblName string) Label {
fname := "rego/" + strings.ReplaceAll(strings.ToLower(lblName), " ", "_") + ".rego"
fname := "labels/" + strings.ReplaceAll(strings.ToLower(lblName), " ", "_") + ".rego"
fin, err := regoFs.ReadFile(fname)
require.NoError(t, err)
classifierCode := string(fin)
Expand Down
Loading

0 comments on commit e066b04

Please sign in to comment.