Skip to content

Commit

Permalink
More refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
ccampo133 committed Apr 3, 2024
1 parent b5409e7 commit 476d814
Show file tree
Hide file tree
Showing 48 changed files with 1,849 additions and 2,200 deletions.
35 changes: 8 additions & 27 deletions classification/classification.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ package classification

import (
"context"
"maps"
"encoding/json"

"golang.org/x/exp/maps"
)

// Classifier is an interface that represents a data classifier. A classifier
Expand All @@ -22,35 +24,14 @@ type Classifier interface {
Classify(ctx context.Context, input map[string]any) (Result, error)
}

// ClassifiedTable represents a database table that has been classified. The
// classifications are stored in the Classifications field, which is a map of
// attribute names (i.e. columns) to the set of labels that attributes were
// classified as.
type ClassifiedTable struct {
Repo string `json:"repo"`
Database string `json:"database"`
Schema string `json:"schema"`
Table string `json:"table"`
Classifications Result `json:"classifications"`
}

// Result represents the classifications for a set of data attributes. The key
// is the attribute (i.e. column) name and the value is the set of labels
// that attribute was classified as.
type Result map[string]LabelSet

// Merge combines the given other Result into this Result (the receiver). If
// an attribute from other is already present in this Result, the existing
// labels for that attribute are merged with the labels from other, otherwise
// labels from other for the attribute are simply added to this Result.
func (c Result) Merge(other Result) {
if c == nil {
return
}
for attr, labelSet := range other {
if _, ok := c[attr]; !ok {
c[attr] = make(LabelSet)
}
maps.Copy(c[attr], labelSet)
}
// LabelSet is a set of unique labels.
type LabelSet map[string]struct{}

// MarshalJSON implements json.Marshaler. The set is encoded as a JSON array
// of its keys (the labels) instead of an object whose values are all empty
// structs. The keys come from maps.Keys (golang.org/x/exp/maps), which does
// not guarantee any particular order, so element order may differ between
// calls.
func (l LabelSet) MarshalJSON() ([]byte, error) {
	labels := maps.Keys(l)
	return json.Marshal(labels)
}
59 changes: 0 additions & 59 deletions classification/classification_test.go

This file was deleted.

50 changes: 20 additions & 30 deletions classification/label.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import (
"strings"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/rego"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"gopkg.in/yaml.v3"
)

Expand All @@ -20,18 +20,23 @@ var (

// Label represents a data classification label.
type Label struct {
Name string `yaml:"name" json:"name"`
Description string `yaml:"description" json:"description"`
Tags []string `yaml:"tags" json:"tags"`
ClassificationRule *rego.Rego `yaml:"-" json:"-"`
// Name is the name of the label.
Name string `yaml:"name" json:"name"`
// Description is a brief description of the label.
Description string `yaml:"description" json:"description"`
// Tags are a list of arbitrary tags associated with the label.
Tags []string `yaml:"tags" json:"tags"`
// ClassificationRule is the compiled Rego classification rule used to
// classify data.
ClassificationRule *ast.Module `yaml:"-" json:"-"`
}

// NewLabel creates a new Label with the given name, description, classification
// rule, and tags. The classification rule is expected to be the raw Rego code
// that will be used to classify data. If the classification rule is invalid, an
// error is returned.
func NewLabel(name, description, classificationRule string, tags []string) (Label, error) {
rule, err := ruleRego(classificationRule)
func NewLabel(name, description, classificationRule string, tags ...string) (Label, error) {
rule, err := parseRego(classificationRule)
if err != nil {
return Label{}, fmt.Errorf("error preparing classification rule for label %s: %w", name, err)
}
Expand All @@ -43,25 +48,12 @@ func NewLabel(name, description, classificationRule string, tags []string) (Labe
}, nil
}

// LabelSet is a set of unique labels. The key is the label name and the value
// is the label itself.
type LabelSet map[string]Label

// ToSlice returns the labels in the set as a slice.
func (s LabelSet) ToSlice() []Label {
var labels []Label
for _, label := range s {
labels = append(labels, label)
}
return labels
}

// GetEmbeddedLabels returns the predefined embedded labels and their
// classification rules. The labels are read from the embedded labels.yaml file
// and the classification rules are read from the embedded Rego files.
func GetEmbeddedLabels() (LabelSet, error) {
func GetEmbeddedLabels() ([]Label, error) {
labels := struct {
Labels LabelSet `yaml:"labels"`
Labels map[string]Label `yaml:"labels"`
}{}
if err := yaml.Unmarshal([]byte(labelsYaml), &labels); err != nil {
return nil, fmt.Errorf("error unmarshalling labels.yaml: %w", err)
Expand All @@ -72,24 +64,22 @@ func GetEmbeddedLabels() (LabelSet, error) {
if err != nil {
return nil, fmt.Errorf("error reading rego file %s: %w", fname, err)
}
rule, err := ruleRego(string(b))
rule, err := parseRego(string(b))
if err != nil {
return nil, fmt.Errorf("error preparing classification rule for label %s: %w", lbl.Name, err)
}
lbl.Name = name
lbl.ClassificationRule = rule
labels.Labels[name] = lbl
}
return labels.Labels, nil
return maps.Values(labels.Labels), nil
}

func ruleRego(code string) (*rego.Rego, error) {
func parseRego(code string) (*ast.Module, error) {
log.Tracef("classifier module code: '%s'", code)
moduleName := "classifier"
compiledRego, err := ast.CompileModules(map[string]string{moduleName: code})
module, err := ast.ParseModule("classifier", code)
if err != nil {
return nil, fmt.Errorf("error compiling rego code: %w", err)
return nil, fmt.Errorf("error parsing rego code: %w", err)
}
regoQuery := compiledRego.Modules[moduleName].Package.Path.String() + ".output"
return rego.New(rego.Query(regoQuery), rego.Compiler(compiledRego)), nil
return module, nil
}
44 changes: 24 additions & 20 deletions classification/label_classifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,59 +11,63 @@ import (
// LabelClassifier is a Classifier implementation that uses a set of labels and
// their classification rules to classify data.
type LabelClassifier struct {
labels LabelSet
queries map[string]*rego.Rego
}

// LabelClassifier implements Classifier
var _ Classifier = (*LabelClassifier)(nil)

// NewLabelClassifier creates a new LabelClassifier with the provided labels and
// classification rules.
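// NewLabelClassifier creates a new LabelClassifier with the provided labels.
// It returns an error if no labels are provided. Every label must have a
// non-nil ClassificationRule, since its package path is used to build the
// label's query.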
func NewLabelClassifier(labels ...Label) (*LabelClassifier, error) {
if len(labels) == 0 {
return nil, fmt.Errorf("labels cannot be empty")
}
l := make(LabelSet, len(labels))
queries := make(map[string]*rego.Rego, len(labels))
for _, lbl := range labels {
l[lbl.Name] = lbl
queries[lbl.Name] = rego.New(
// We only care about the 'output' variable.
rego.Query(lbl.ClassificationRule.Package.Path.String() + ".output"),
rego.ParsedModule(lbl.ClassificationRule),
)
}
return &LabelClassifier{labels: l}, nil
return &LabelClassifier{queries: queries}, nil
}

// Classify performs the classification of the provided attributes using the
// classifier's labels and classification rules. It returns a Result, which is
// a map of attribute names to the set of labels that the attribute was
// classified as.
// classifier's labels and their corresponding classification rules. It returns
// a Result, which is a map of attribute names to the set of labels that the
// attribute was classified as.
func (c *LabelClassifier) Classify(ctx context.Context, input map[string]any) (Result, error) {
result := make(Result, len(c.labels))
for _, lbl := range c.labels {
output, err := evalQuery(ctx, lbl.ClassificationRule, input)
result := make(Result, len(c.queries))
for lbl, query := range c.queries {
output, err := evalQuery(ctx, query, input)
if err != nil {
return nil, fmt.Errorf("error evaluating query for label %s: %w", lbl.Name, err)
return nil, fmt.Errorf("error evaluating query for label %s: %w", lbl, err)
}
log.Debugf("classification results for label %s: %v", lbl.Name, output)
log.Debugf("classification results for label %s: %v", lbl, output)
for attrName, classified := range output {
if classified {
attrLabels, ok := result[attrName]
if !ok {
attrLabels = make(LabelSet)
result[attrName] = attrLabels
}
attrLabels[lbl.Name] = lbl
// Add the label to the set of labels for the attribute.
attrLabels[lbl] = struct{}{}
}
}
}
return result, nil
}

// evalQuery evaluates the provided prepared Rego query with the given
// attributes as input, and returns the classification results. The output is a
// evalQuery evaluates the provided Rego query with the given attributes as
// input and returns the classification results. The output is a map of
// attribute names to boolean values, where the boolean indicates whether the
// attribute is classified as belonging to the label.
func evalQuery(ctx context.Context, rule *rego.Rego, input map[string]any) (map[string]bool, error) {
q, err := rule.PrepareForEval(ctx)
func evalQuery(ctx context.Context, query *rego.Rego, input map[string]any) (map[string]bool, error) {
q, err := query.PrepareForEval(ctx)
if err != nil {
return nil, fmt.Errorf("error preparing rule for evaluation: %w", err)
return nil, fmt.Errorf("error preparing query for evaluation: %w", err)
}
// Evaluate the prepared Rego query. This performs the actual classification
// logic.
Expand Down
27 changes: 14 additions & 13 deletions classification/label_classifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ import (
)

func TestNewLabelClassifier_Success(t *testing.T) {
classifier, err := NewLabelClassifier(Label{Name: "foo"})
lbl, err := NewLabel("foo", "test label", "package foo\noutput = true")
require.NoError(t, err)
classifier, err := NewLabelClassifier(lbl)
require.NoError(t, err)
require.NotNil(t, classifier)
}
Expand Down Expand Up @@ -40,7 +42,7 @@ func TestLabelClassifier_Classify(t *testing.T) {
input: map[string]any{"age": "42"},
want: Result{
"age": {
"AGE": Label{Name: "AGE"},
"AGE": {},
},
},
},
Expand All @@ -53,7 +55,7 @@ func TestLabelClassifier_Classify(t *testing.T) {
},
want: Result{
"age": {
"AGE": Label{Name: "AGE"},
"AGE": {},
},
},
},
Expand All @@ -63,7 +65,7 @@ func TestLabelClassifier_Classify(t *testing.T) {
input: map[string]any{"age": "42"},
want: Result{
"age": {
"AGE": Label{Name: "AGE"},
"AGE": {},
},
},
},
Expand All @@ -76,10 +78,10 @@ func TestLabelClassifier_Classify(t *testing.T) {
},
want: Result{
"age": {
"AGE": Label{Name: "AGE"},
"AGE": {},
},
"ccn": {
"CCN": Label{Name: "CCN"},
"CCN": {},
},
},
},
Expand All @@ -92,11 +94,11 @@ func TestLabelClassifier_Classify(t *testing.T) {
},
want: Result{
"age": {
"AGE": Label{Name: "AGE"},
"CVV": Label{Name: "CVV"},
"AGE": {},
"CVV": {},
},
"cvv": {
"CVV": Label{Name: "CVV"},
"CVV": {},
},
},
},
Expand Down Expand Up @@ -127,10 +129,9 @@ func requireResultEqual(t *testing.T, want, got Result) {

func requireLabelSetEqual(t *testing.T, want, got LabelSet) {
require.Len(t, got, len(want))
for k, v := range want {
gotLbl, ok := got[k]
for k := range want {
_, ok := got[k]
require.Truef(t, ok, "missing label %s", k)
require.Equal(t, v.Name, gotLbl.Name)
}
}

Expand All @@ -149,7 +150,7 @@ func newTestLabel(t *testing.T, lblName string) Label {
fin, err := regoFs.ReadFile(fname)
require.NoError(t, err)
classifierCode := string(fin)
lbl, err := NewLabel(lblName, "test label", classifierCode, nil)
lbl, err := NewLabel(lblName, "test label", classifierCode)
require.NoError(t, err)
return lbl
}
Loading

0 comments on commit 476d814

Please sign in to comment.