Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

issue-144 Using aws-lambda-go-api-proxy to enable echo Framework and Lambda with ALB #146

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 156 additions & 0 deletions core/requestalb.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
// Package core provides utility methods that help convert proxy events
// into an http.Request and http.ResponseWriter
package core

import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"net/url"
"strings"

"github.com/aws/aws-lambda-go/events"
"github.com/aws/aws-lambda-go/lambdacontext"
)

const (
// ALBTgContextHeader is the custom header key used to store the
// ALB Target Group Request context. To access the Context properties use the
// RequestAccessorALB method of the RequestAccessorALB object.
ALBTgContextHeader = "X-Golambdaproxy-Albtargetgroup-Context"
)

// RequestAccessorALB objects give access to custom ALB Target Group properties
// in the request.
type RequestAccessorALB struct{}

// GetALBTargetGroupRequestContext extracts the ALB Target Group Responce context object from a
// request's custom header.
// Returns a populated events.ALBTargetGroupRequestContext object from
// the request.
func (r *RequestAccessorALB) GetALBTargetGroupRequestContext(req *http.Request) (events.ALBTargetGroupRequestContext, error) {
if req.Header.Get(ALBTgContextHeader) == "" {
return events.ALBTargetGroupRequestContext{}, errors.New("No context header in request")
}
context := events.ALBTargetGroupRequestContext{}
err := json.Unmarshal([]byte(req.Header.Get(ALBTgContextHeader)), &context)
if err != nil {
log.Println("Erorr while unmarshalling context")
log.Println(err)
return events.ALBTargetGroupRequestContext{}, err
}
return context, nil
}

// ProxyEventToHTTPRequest converts an ALB Target Group proxy event into a http.Request object.
// Returns the populated http request with additional two custom headers for ALB Tg Req context.
// To access these properties use the GetALBTargetGroupRequestContext method of the RequestAccessor object.
func (r *RequestAccessorALB) ProxyEventToHTTPRequest(req events.ALBTargetGroupRequest) (*http.Request, error) {
httpRequest, err := r.EventToRequest(req)
if err != nil {
log.Println(err)
return nil, err
}
return addToHeaderALB(httpRequest, req)
}

// EventToRequestWithContext converts an ALB Target Group proxy event and context into an http.Request object.
// Returns the populated http request with lambda context, ALBTargetGroupRequestContext as part of its context.
// Access those using GetRuntimeContextFromContextALB and GetRuntimeContextFromContext functions in this package.
func (r *RequestAccessorALB) EventToRequestWithContext(ctx context.Context, req events.ALBTargetGroupRequest) (*http.Request, error) {
httpRequest, err := r.EventToRequest(req)
if err != nil {
log.Println(err)
return nil, err
}
return addToContextALB(ctx, httpRequest, req), nil
}

// EventToRequest converts an ALB Target group proxy event into an http.Request object.
// Returns the populated request maintaining headers
func (r *RequestAccessorALB) EventToRequest(req events.ALBTargetGroupRequest) (*http.Request, error) {
decodedBody := []byte(req.Body)
if req.IsBase64Encoded {
base64Body, err := base64.StdEncoding.DecodeString(req.Body)
if err != nil {
return nil, err
}
decodedBody = base64Body
}

path := req.Path

if !strings.HasPrefix(path, "/") {
path = "/" + path
}

if len(req.QueryStringParameters) > 0 {
values := url.Values{}
for key, value := range req.QueryStringParameters {
values.Add(key, value)
}
path += "?" + values.Encode()
}

httpRequest, err := http.NewRequest(
strings.ToUpper(req.HTTPMethod),
path,
bytes.NewReader(decodedBody),
)

if err != nil {
fmt.Printf("Could not convert request %s:%s to http.Request\n", req.HTTPMethod, req.Path)
log.Println(err)
return nil, err
}

for headerKey, headerValue := range req.Headers {
for _, val := range strings.Split(headerValue, ",") {
httpRequest.Header.Add(headerKey, strings.Trim(val, " "))
}
}

for headerKey, headerValue := range req.MultiValueHeaders {
for _, arrVal := range headerValue {
for _, val := range strings.Split(arrVal, ",") {
httpRequest.Header.Add(headerKey, strings.Trim(val, " "))
}
}
}

httpRequest.RequestURI = httpRequest.URL.RequestURI()

return httpRequest, nil
}

func addToHeaderALB(req *http.Request, albTgRequest events.ALBTargetGroupRequest) (*http.Request, error) {
albTgContext, err := json.Marshal(albTgRequest.RequestContext)
if err != nil {
log.Println("Could not Marshal ALB Tg context for custom header")
return req, err
}
req.Header.Add(ALBTgContextHeader, string(albTgContext))
return req, nil
}

func addToContextALB(ctx context.Context, req *http.Request, albTgRequest events.ALBTargetGroupRequest) *http.Request {
lc, _ := lambdacontext.FromContext(ctx)
rc := requestContextALB{lambdaContext: lc, gatewayProxyContext: albTgRequest.RequestContext}
ctx = context.WithValue(ctx, ctxKey{}, rc)
return req.WithContext(ctx)
}

func GetRuntimeContextFromContextALB(ctx context.Context) (*lambdacontext.LambdaContext, bool) {
v, ok := ctx.Value(ctxKey{}).(requestContextALB)
return v.lambdaContext, ok
}

type requestContextALB struct {
lambdaContext *lambdacontext.LambdaContext
gatewayProxyContext events.ALBTargetGroupRequestContext
}
239 changes: 239 additions & 0 deletions core/requestalb_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
package core_test

import (
"context"
"encoding/base64"
"encoding/json"
"io/ioutil"
"math/rand"
"strings"

"github.com/aws/aws-lambda-go/events"
"github.com/aws/aws-lambda-go/lambdacontext"
"github.com/awslabs/aws-lambda-go-api-proxy/core"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)

var _ = Describe("RequestAccessorALB tests", func() {
Context("event conversion", func() {
accessor := core.RequestAccessorALB{}
basicRequest := getProxyRequestALB("/hello", "GET")
It("Correctly converts a basic event", func() {
httpReq, err := accessor.EventToRequestWithContext(context.Background(), basicRequest)
Expect(err).To(BeNil())
Expect("/hello").To(Equal(httpReq.URL.Path))
Expect("/hello").To(Equal(httpReq.RequestURI))
Expect("GET").To(Equal(httpReq.Method))
})

basicRequest = getProxyRequestALB("/hello", "get")
It("Converts method to uppercase", func() {
// calling old method to verify reverse compatibility
httpReq, err := accessor.ProxyEventToHTTPRequest(basicRequest)
Expect(err).To(BeNil())
Expect("/hello").To(Equal(httpReq.URL.Path))
Expect("/hello").To(Equal(httpReq.RequestURI))
Expect("GET").To(Equal(httpReq.Method))
})

binaryBody := make([]byte, 256)
_, err := rand.Read(binaryBody)
if err != nil {
Fail("Could not generate random binary body")
}

encodedBody := base64.StdEncoding.EncodeToString(binaryBody)

binaryRequest := getProxyRequestALB("/hello", "POST")
binaryRequest.Body = encodedBody
binaryRequest.IsBase64Encoded = true

It("Decodes a base64 encoded body", func() {
httpReq, err := accessor.EventToRequestWithContext(context.Background(), binaryRequest)
Expect(err).To(BeNil())
Expect("/hello").To(Equal(httpReq.URL.Path))
Expect("/hello").To(Equal(httpReq.RequestURI))
Expect("POST").To(Equal(httpReq.Method))

bodyBytes, err := ioutil.ReadAll(httpReq.Body)

Expect(err).To(BeNil())
Expect(len(binaryBody)).To(Equal(len(bodyBytes)))
Expect(binaryBody).To(Equal(bodyBytes))
})

mqsRequest := getProxyRequestALB("/hello", "GET")
mqsRequest.QueryStringParameters = map[string]string{
"hello": "1",
"world": "2",
}
It("Populates multiple value query string correctly", func() {
httpReq, err := accessor.EventToRequestWithContext(context.Background(), mqsRequest)
Expect(err).To(BeNil())
Expect("/hello").To(Equal(httpReq.URL.Path))
Expect(httpReq.RequestURI).To(ContainSubstring("hello=1"))
Expect(httpReq.RequestURI).To(ContainSubstring("world=2"))
Expect("GET").To(Equal(httpReq.Method))

query := httpReq.URL.Query()
Expect(2).To(Equal(len(query)))
Expect(query["hello"]).ToNot(BeNil())
Expect(query["world"]).ToNot(BeNil())
Expect(1).To(Equal(len(query["hello"])))
Expect(1).To(Equal(len(query["world"])))
Expect("1").To(Equal(query["hello"][0]))
Expect("2").To(Equal(query["world"][0]))
})

// Support `QueryStringParameters` for backward compatibility.
// https://github.com/awslabs/aws-lambda-go-api-proxy/issues/37
qsRequest := getProxyRequestALB("/hello", "GET")
qsRequest.QueryStringParameters = map[string]string{
"hello": "1",
"world": "2",
}
It("Populates query string correctly", func() {
httpReq, err := accessor.EventToRequestWithContext(context.Background(), qsRequest)
Expect(err).To(BeNil())
Expect("/hello").To(Equal(httpReq.URL.Path))
Expect(httpReq.RequestURI).To(ContainSubstring("hello=1"))
Expect(httpReq.RequestURI).To(ContainSubstring("world=2"))
Expect("GET").To(Equal(httpReq.Method))

query := httpReq.URL.Query()
Expect(2).To(Equal(len(query)))
Expect(query["hello"]).ToNot(BeNil())
Expect(query["world"]).ToNot(BeNil())
Expect(1).To(Equal(len(query["hello"])))
Expect(1).To(Equal(len(query["world"])))
Expect("1").To(Equal(query["hello"][0]))
Expect("2").To(Equal(query["world"][0]))
})

mvhRequest := getProxyRequestALB("/hello", "GET")
mvhRequest.Headers = map[string]string{
"hello": "1",
"world": "2,3",
}
mvhRequest.MultiValueHeaders = map[string][]string{
"hello world": []string{"4", "5", "6"},
}

It("Populates multiple value headers correctly", func() {
httpReq, err := accessor.EventToRequestWithContext(context.Background(), mvhRequest)
Expect(err).To(BeNil())
Expect("/hello").To(Equal(httpReq.URL.Path))
Expect("GET").To(Equal(httpReq.Method))

headers := httpReq.Header
Expect(3).To(Equal(len(headers)))

for k, value := range headers {
if mvhRequest.Headers[strings.ToLower(k)] != "" {
Expect(strings.Join(value, ",")).To(Equal(mvhRequest.Headers[strings.ToLower(k)]))
} else {
Expect(strings.Join(value, ",")).To(Equal(strings.Join(mvhRequest.MultiValueHeaders[strings.ToLower(k)], ",")))
}
}
})

svhRequest := getProxyRequestALB("/hello", "GET")
svhRequest.Headers = map[string]string{
"hello": "1",
"world": "2",
}
It("Populates single value headers correctly", func() {
httpReq, err := accessor.EventToRequestWithContext(context.Background(), svhRequest)
Expect(err).To(BeNil())
Expect("/hello").To(Equal(httpReq.URL.Path))
Expect("GET").To(Equal(httpReq.Method))

headers := httpReq.Header
Expect(2).To(Equal(len(headers)))

for k, value := range headers {
Expect(value[0]).To(Equal(svhRequest.Headers[strings.ToLower(k)]))
}
})

basePathRequest := getProxyRequestALB("/orders", "GET")

It("Stips the base path correct", func() {
httpReq, err := accessor.EventToRequestWithContext(context.Background(), basePathRequest)

Expect(err).To(BeNil())
Expect("/orders").To(Equal(httpReq.URL.Path))
Expect("/orders").To(Equal(httpReq.RequestURI))
})

contextRequest := getProxyRequestALB("/orders", "GET")
contextRequest.RequestContext = getRequestContextALB()

It("Populates context header correctly", func() {
// calling old method to verify reverse compatibility
httpReq, err := accessor.ProxyEventToHTTPRequest(contextRequest)
Expect(err).To(BeNil())
Expect(1).To(Equal(len(httpReq.Header)))
Expect(httpReq.Header.Get(core.ALBTgContextHeader)).ToNot(BeNil())
})
})

Context("Retrieves ALB Target Group context", func() {
It("Returns a correctly unmarshalled object", func() {
contextRequest := getProxyRequestALB("/orders", "GET")
contextRequest.RequestContext = getRequestContextALB()

accessor := core.RequestAccessorALB{}
httpReq, err := accessor.ProxyEventToHTTPRequest(contextRequest)
Expect(err).To(BeNil())
ctx := httpReq.Header[core.ALBTgContextHeader][0]
var parsedCtx events.ALBTargetGroupRequestContext
json.Unmarshal([]byte(ctx), &parsedCtx)
Expect("foo").To(Equal(parsedCtx.ELB.TargetGroupArn))

headerContext, err := accessor.GetALBTargetGroupRequestContext(httpReq)
Expect(err).To(BeNil())
Expect("foo").To(Equal(headerContext.ELB.TargetGroupArn))

httpReq, err = accessor.EventToRequestWithContext(context.Background(), contextRequest)
Expect(err).To(BeNil())
Expect("/orders").To(Equal(httpReq.RequestURI))
runtimeContext, ok := core.GetRuntimeContextFromContextALB(httpReq.Context())
Expect(ok).To(BeTrue())
Expect(runtimeContext).To(BeNil())

lambdaContext := lambdacontext.NewContext(context.Background(), &lambdacontext.LambdaContext{AwsRequestID: "abc123"})
httpReq, err = accessor.EventToRequestWithContext(lambdaContext, contextRequest)
Expect(err).To(BeNil())
Expect("/orders").To(Equal(httpReq.RequestURI))

headerContext, err = accessor.GetALBTargetGroupRequestContext(httpReq)
// should fail as new context method doesn't populate headers
Expect(err).ToNot(BeNil())
Expect("").To(Equal(headerContext.ELB.TargetGroupArn))
runtimeContext, ok = core.GetRuntimeContextFromContextALB(httpReq.Context())
Expect(ok).To(BeTrue())
Expect(runtimeContext).ToNot(BeNil())
Expect("abc123").To(Equal(runtimeContext.AwsRequestID))
})
})
})

func getProxyRequestALB(path string, method string) events.ALBTargetGroupRequest {
return events.ALBTargetGroupRequest{
RequestContext: events.ALBTargetGroupRequestContext{},
Path: path,
HTTPMethod: method,
Headers: map[string]string{},
}
}

func getRequestContextALB() events.ALBTargetGroupRequestContext {
return events.ALBTargetGroupRequestContext{
ELB: events.ELBContext{
TargetGroupArn: "foo",
},
}
}
Loading