diff --git a/statefun-sdk-go/v3/pkg/statefun/handler.go b/statefun-sdk-go/v3/pkg/statefun/handler.go index 410955f5c..77f64227f 100644 --- a/statefun-sdk-go/v3/pkg/statefun/handler.go +++ b/statefun-sdk-go/v3/pkg/statefun/handler.go @@ -24,8 +24,9 @@ import ( "net/http" "sync" - "github.com/apache/flink-statefun/statefun-sdk-go/v3/pkg/statefun/internal/protocol" "google.golang.org/protobuf/proto" + + "github.com/apache/flink-statefun/statefun-sdk-go/v3/pkg/statefun/internal/protocol" ) // StatefulFunctions is a registry for multiple StatefulFunction's. A RequestReplyHandler @@ -72,6 +73,10 @@ type handler struct { } func (h *handler) WithSpec(spec StatefulFunctionSpec) error { + if spec.FunctionType == nil { + return fmt.Errorf("function type is required") + } + log.Printf("registering Stateful Function %v\n", spec.FunctionType) if _, exists := h.module[spec.FunctionType]; exists { err := fmt.Errorf("failed to register Stateful Function %s, there is already a spec registered under that type", spec.FunctionType) diff --git a/statefun-sdk-go/v3/pkg/statefun/handler_test.go b/statefun-sdk-go/v3/pkg/statefun/handler_test.go index e315c4c5d..a6708df0a 100644 --- a/statefun-sdk-go/v3/pkg/statefun/handler_test.go +++ b/statefun-sdk-go/v3/pkg/statefun/handler_test.go @@ -4,9 +4,10 @@ import ( "context" "testing" - "github.com/apache/flink-statefun/statefun-sdk-go/v3/pkg/statefun/internal/protocol" "github.com/stretchr/testify/assert" "google.golang.org/protobuf/proto" + + "github.com/apache/flink-statefun/statefun-sdk-go/v3/pkg/statefun/internal/protocol" ) // helper to create a protocol Address from an Address @@ -87,3 +88,20 @@ func TestStatefunHandler_WithCaller_ContextCallerIsCorrect(t *testing.T) { err := invokeStatefulFunction(context.Background(), &target, &caller, nil, StatefulFunctionPointer(statefulFunction)) assert.Nil(t, err) } + +func TestStatefulFunctionsBuilder_FunctionTypeRequired(t *testing.T) { + caller := Address{FunctionType: TypeNameFrom("namespace/function2"), Id: "2"} + + statefulFunction := func(ctx Context, message Message) error { + assert.Equal(t, caller.String(), ctx.Caller().String()) + return nil + } + + builder := StatefulFunctionsBuilder() + err := builder.WithSpec(StatefulFunctionSpec{ + Function: StatefulFunctionPointer(statefulFunction), + }) + + assert.Error(t, err) + assert.Contains(t, err.Error(), "function type is required") +}