Skip to content

Commit

Permalink
Merge pull request #9 from sendgrid/default-client
Browse files Browse the repository at this point in the history
Allow for setting a custom HTTP client
  • Loading branch information
thinkingserious authored Jul 28, 2016
2 parents 2eefcc3 + 67fe53f commit 6741dbc
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 6 deletions.
32 changes: 26 additions & 6 deletions rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
// Method contains the supported HTTP verbs.
type Method string

// Supported HTTP verbs.
const (
Get Method = "GET"
Post Method = "POST"
Expand All @@ -28,6 +29,16 @@ type Request struct {
Body []byte
}

// DefaultClient is used if no custom HTTP client is defined
var DefaultClient = &Client{HTTPClient: http.DefaultClient}

// Client allows modification of client headers, redirect policy
// and other settings
// See https://golang.org/pkg/net/http
type Client struct {
HTTPClient *http.Client
}

// Response holds the response from an API call.
type Response struct {
StatusCode int // e.g. 200
Expand Down Expand Up @@ -59,11 +70,7 @@ func BuildRequestObject(request Request) (*http.Request, error) {

// MakeRequest makes the API call.
func MakeRequest(req *http.Request) (*http.Response, error) {
var Client = &http.Client{
Transport: http.DefaultTransport,
}
res, err := Client.Do(req)
return res, err
return DefaultClient.HTTPClient.Do(req)
}

// BuildResponse builds the response struct.
Expand All @@ -83,6 +90,19 @@ func BuildResponse(res *http.Response) (*Response, error) {

// API is the main interface to the API.
func API(request Request) (*Response, error) {
return DefaultClient.API(request)
}

// The following functions enable the ability to define a
// custom HTTP Client

// MakeRequest makes the API call.
func (c *Client) MakeRequest(req *http.Request) (*http.Response, error) {
return c.HTTPClient.Do(req)
}

// API is the main interface to the API.
func (c *Client) API(request Request) (*Response, error) {
// Add any query parameters to the URL.
if len(request.QueryParams) != 0 {
request.BaseURL = AddQueryParameters(request.BaseURL, request.QueryParams)
Expand All @@ -95,7 +115,7 @@ func API(request Request) (*Response, error) {
}

// Build the HTTP client and make the request.
res, err := MakeRequest(req)
res, err := c.MakeRequest(req)
if err != nil {
return nil, err
}
Expand Down
28 changes: 28 additions & 0 deletions rest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)

func TestBuildURL(t *testing.T) {
Expand Down Expand Up @@ -47,6 +49,7 @@ func TestBuildResponse(t *testing.T) {
fakeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "{\"message\": \"success\"}")
}))
defer fakeServer.Close()
baseURL := fakeServer.URL
method := Get
request := Request{
Expand Down Expand Up @@ -74,6 +77,7 @@ func TestRest(t *testing.T) {
fakeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "{\"message\": \"success\"}")
}))
defer fakeServer.Close()
host := fakeServer.URL
endpoint := "/test_endpoint"
baseURL := host + endpoint
Expand Down Expand Up @@ -105,3 +109,27 @@ func TestRest(t *testing.T) {
t.Errorf("Rest failed to make a valid API request. Returned error: %v", e)
}
}

func TestCustomHTTPClient(t *testing.T) {
fakeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Millisecond * 20)
fmt.Fprintln(w, "{\"message\": \"success\"}")
}))
defer fakeServer.Close()
host := fakeServer.URL
endpoint := "/test_endpoint"
baseURL := host + endpoint
method := Get
request := Request{
Method: method,
BaseURL: baseURL,
}
customClient := &Client{&http.Client{Timeout: time.Millisecond * 10}}
_, err := customClient.API(request)
if err == nil {
t.Error("A timeout did not trigger as expected")
}
if strings.Contains(err.Error(), "Client.Timeout exceeded while awaiting headers") == false {
t.Error("We did not receive the Timeout error")
}
}

0 comments on commit 6741dbc

Please sign in to comment.