Skip to content

Commit

Permalink
Add namespace_map param, fix incorrect case statements, fix nesting r…
Browse files Browse the repository at this point in the history
…ecords unnecessarily
  • Loading branch information
dorner committed Nov 23, 2023
1 parent a4c9460 commit d54de3b
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 9 deletions.
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,22 @@ protoc --avro_out=. --avro_opt=emit_only=Foo,Bar:. *.proto

This will generate only `Foo.avsc` and `Bar.avsc` files.

You can also change the namespaces being mapped:

```bash
protoc --avro_out=. --avro_opt=namespace_map=foo:bar,baz:spam *.proto
```

...will change the output namespace for `foo` to `bar` and `baz` to `spam`.

---

To Do List:

* Add tests
* Map is currently outputting as Array<Map> due to how Protobuf handles [maps](https://protobuf.com/docs/descriptors#map-fields) (as repeated entries). Need to fix.
* Need to decide on how to truly differentiate between optional and required fields (technically all fields are optional on Protobuf, but maybe we should use the actual `optional` keyword and only have those be optional in Avro?)

---

This project supported by [Flipp](https://corp.flipp.com/).
8 changes: 8 additions & 0 deletions avro/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,23 @@ func BasicFieldTypeFromProto(proto *descriptorpb.FieldDescriptorProto) Type {
case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE:
return Bare("double")
case descriptorpb.FieldDescriptorProto_TYPE_INT64:
return Bare("long")
case descriptorpb.FieldDescriptorProto_TYPE_UINT64:
return Bare("long")
case descriptorpb.FieldDescriptorProto_TYPE_FIXED64:
return Bare("long")
case descriptorpb.FieldDescriptorProto_TYPE_SINT64:
return Bare("long")
case descriptorpb.FieldDescriptorProto_TYPE_SFIXED64:
return Bare("long")
case descriptorpb.FieldDescriptorProto_TYPE_INT32:
return Bare("int")
case descriptorpb.FieldDescriptorProto_TYPE_UINT32:
return Bare("int")
case descriptorpb.FieldDescriptorProto_TYPE_FIXED32:
return Bare("int")
case descriptorpb.FieldDescriptorProto_TYPE_SINT32:
return Bare("int")
case descriptorpb.FieldDescriptorProto_TYPE_SFIXED32:
return Bare("int")
case descriptorpb.FieldDescriptorProto_TYPE_BOOL:
Expand Down
3 changes: 2 additions & 1 deletion avro/record.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ func (t Record) GetNamespace() string {
}

func (t Record) ToJSON(types *TypeRepo) (any, error) {
types.SeenType(t)
jsonMap := orderedmap.New()
jsonMap.Set("type", "record")
jsonMap.Set("name", t.Name)
jsonMap.Set("namespace", t.Namespace)
jsonMap.Set("namespace", types.MappedNamespace(t.Namespace))
fields := make([]any, len(t.Fields))
for i, field := range t.Fields {
fieldJson, err := field.ToJSON(types)
Expand Down
3 changes: 3 additions & 0 deletions avro/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,18 @@ func DefaultValue(t Type) any {
case Bare("boolean"):
return false
case Bare("int"):
return 0
case Bare("long"):
return 0
case Bare("float"):
return 0.0
case Bare("double"):
return 0.0
}

switch typedT := t.(type) {
case Record:
return map[string]any{}
case Map:
return map[string]any{}
case Array:
Expand Down
25 changes: 20 additions & 5 deletions avro/typeRepo.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,20 @@ package avro

import (
"fmt"
"strings"
)

type TypeRepo struct {
Types map[string]NamedType
seenTypes map[string]bool // go "set"
NamespaceMap map[string]string
}

func NewTypeRepo() *TypeRepo {
return &TypeRepo{Types: make(map[string]NamedType)}
func NewTypeRepo(namespaceMap map[string]string) *TypeRepo {
return &TypeRepo{
Types: make(map[string]NamedType),
NamespaceMap: namespaceMap,
}
}

func (r *TypeRepo) AddType(t NamedType) {
Expand All @@ -27,16 +32,19 @@ func (r *TypeRepo) GetTypeByBareName(name string) Type {
return nil
}

func (r *TypeRepo) SeenType(t NamedType) {
r.seenTypes[FullName(t)] = true
}

func (r *TypeRepo) GetType(name string) (Type, error) {
if r.seenTypes[name] {
return Bare(name[1:]), nil
return Bare(r.MappedNamespace(name[1:])), nil
}
t, ok := r.Types[name]
if !ok {
// r.LogTypes()
return nil, fmt.Errorf("type %s not found", name)
}
r.seenTypes[FullName(t)] = true
r.SeenType(t)
return t, nil
}

Expand All @@ -52,3 +60,10 @@ func (r *TypeRepo) LogTypes() {
LogObj(keys)
}

func (r *TypeRepo) MappedNamespace(namespace string) string {
out := namespace
for k, v := range r.NamespaceMap {
out = strings.Replace(out, k, v, -1)
}
return out
}
28 changes: 25 additions & 3 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ import (
"strings"
)

type Params struct {
EmitOnly []string
NamespaceMap map[string]string
}

var params Params
var typeRepo *avro.TypeRepo

func readRequest() (*pluginpb.CodeGeneratorRequest, error) {
Expand Down Expand Up @@ -104,19 +110,34 @@ func processAll(fileProto *descriptorpb.FileDescriptorProto) {
}
}

func writeResponse(req *pluginpb.CodeGeneratorRequest) {
func parseParams(req *pluginpb.CodeGeneratorRequest) Params {
var recordsToEmit []string
namespaceMap := map[string]string{}
param := req.GetParameter()
if len(param) > 0 {
paramTokens := strings.Split(param, " ")
for _, token := range paramTokens {
paramStrings := strings.Split(token, "=")
if len(paramStrings) == 2 && paramStrings[0] == "emit_only" {
recordsToEmit = strings.Split(paramStrings[1], ",")
} else if len(paramStrings) == 2 && paramStrings[0] == "namespace_map" {
namespaces := strings.Split(paramStrings[1], ",")
for _, namespaceMapToken := range namespaces {
namespaceTokens := strings.Split(namespaceMapToken, ":")
namespaceMap[namespaceTokens[0]] = namespaceTokens[1]
}
}
}
}
response := generateResponse(recordsToEmit)
return Params{
EmitOnly: recordsToEmit,
NamespaceMap: namespaceMap,
}

}

func writeResponse(req *pluginpb.CodeGeneratorRequest) {
response := generateResponse(params.EmitOnly)
out, err := proto.Marshal(response)
if err != nil {
log.Fatalf("%s", fmt.Errorf("error marshalling response: %w", err))
Expand All @@ -128,11 +149,12 @@ func writeResponse(req *pluginpb.CodeGeneratorRequest) {
}

func main() {
typeRepo = avro.NewTypeRepo()
req, err := readRequest()
if err != nil {
log.Fatalf("%s", fmt.Errorf("error reading request: %w", err))
}
params = parseParams(req)
typeRepo = avro.NewTypeRepo(params.NamespaceMap)

for _, file := range req.ProtoFile {
if !slices.Contains(req.FileToGenerate, *file.Name) {
Expand Down

0 comments on commit d54de3b

Please sign in to comment.