Skip to content
This repository has been archived by the owner on Jul 23, 2024. It is now read-only.

Merge HTTP and GRPC traffic into one port #171

Closed
wants to merge 12 commits into from
1 change: 1 addition & 0 deletions authenticator/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ require (
github.com/google/uuid v1.1.1
github.com/hashicorp/vault/api v1.0.4
github.com/sirupsen/logrus v1.6.0
github.com/soheilhy/cmux v0.1.4
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.7.0
go.opencensus.io v0.22.4
Expand Down
1 change: 1 addition & 0 deletions authenticator/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykE
github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc=
github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s=
github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA=
github.com/soheilhy/cmux v0.1.4 h1:0HKaf1o97UwFjHH9o5XsHUOF+tqmdA7KEzXLpiyaw0E=
github.com/soheilhy/cmux v0.1.4/go.mod h1:IM3LyeVVIOuxMH7sFAkER9+bJ4dT7Ms6E4xg4kGIyLM=
github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA=
github.com/spf13/afero v1.1.2 h1:m8/z1t7/fwjysjQRYbP0RD+bUIF/8tJwPdEZsI83ACI=
Expand Down
16 changes: 6 additions & 10 deletions authenticator/server/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,33 @@ package api

import (
"crypto/tls"
"fmt"
"io/ioutil"
"net"
"net/http"
"strconv"
"time"

"contrib.go.opencensus.io/exporter/prometheus"
"github.com/cyralinc/approzium/authenticator/server/config"
log "github.com/sirupsen/logrus"
)

func Start(logger *log.Logger, config config.Config) error {
func Start(logger *log.Logger, listener net.Listener, config config.Config) error {

if err := loadEndpoints(logger, config); err != nil {
return err
}

serviceAddress := config.Host + ":" + strconv.Itoa(config.HTTPPort)
server := &http.Server{
ReadTimeout: 30 * time.Second,
WriteTimeout: 30 * time.Second,
}

if config.DisableTLS {
server.Addr = serviceAddress
go func() {
logger.Fatal(server.ListenAndServe())
logger.Fatal(server.Serve(listener))
}()
logger.Infof("api starting on http://%s", serviceAddress)
logger.Infof("api server starting without TLS")
} else {
server.Addr = fmt.Sprintf(":%d", config.HTTPPort)
crt, err := ioutil.ReadFile(config.PathToTLSCert)
if err != nil {
return err
Expand All @@ -51,9 +47,9 @@ func Start(logger *log.Logger, config config.Config) error {
}

go func() {
logger.Fatal(server.ListenAndServeTLS("", ""))
logger.Fatal(server.ServeTLS(listener, "", ""))
}()
logger.Infof("api starting on https://%s", serviceAddress)
logger.Infof("api server starting with TLS")
}
return nil
}
Expand Down
4 changes: 2 additions & 2 deletions authenticator/server/api/health_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (

func TestHealthChecker(t *testing.T) {
checker := newHealthChecker(testtools.TestLogger(), config.Config{
Host: "127.0.0.1",
GRPCPort: 6001,
Host: "127.0.0.1",
Port: 6001,
})
testWriter := &testtools.TestResponseWriter{}
checker.ServeHTTP(testWriter, nil)
Expand Down
33 changes: 21 additions & 12 deletions authenticator/server/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"io"
"net"
"strconv"
"strings"

"github.com/aws/aws-sdk-go/aws/arn"
Expand All @@ -30,6 +31,7 @@ import (
"github.com/cyralinc/approzium/authenticator/server/identity"
pb "github.com/cyralinc/approzium/authenticator/server/protos"
log "github.com/sirupsen/logrus"
"github.com/soheilhy/cmux"
"golang.org/x/crypto/pbkdf2"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
Expand All @@ -53,16 +55,29 @@ var maxIterations = uint32(15000 * 10)
// an error is returned from either, terminating the application. Both servers
// respond to CTRL+C shutdowns.
func Start(logger *log.Logger, config config.Config) error {
if err := api.Start(logger, config); err != nil {
serviceAddress := config.Host + ":" + strconv.Itoa(config.Port)
l, err := net.Listen("tcp", serviceAddress)
if err != nil {
return err
}

m := cmux.New(l)
grpcListener := m.MatchWithWriters(cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"))
// All the rest is assumed to be HTTP
httpListener := m.Match(cmux.Any())
if err := api.Start(logger, httpListener, config); err != nil {
return err
}
svr, err := buildServer(logger, config)
if err != nil {
return err
}
if err := startGrpc(logger, config, svr); err != nil {
if err := startGrpc(logger, grpcListener, config, svr); err != nil {
return err
}
go func() {
m.Serve()
}()
return nil
}

Expand All @@ -84,28 +99,22 @@ func buildServer(logger *log.Logger, config config.Config) (pb.AuthenticatorServ
return svr, nil
}

func startGrpc(logger *log.Logger, config config.Config, authenticatorServer pb.AuthenticatorServer) error {
serviceAddress := fmt.Sprintf("%s:%d", config.Host, config.GRPCPort)
lis, err := net.Listen("tcp", serviceAddress)
if err != nil {
return err
}

func startGrpc(logger *log.Logger, listener net.Listener, config config.Config, authenticatorServer pb.AuthenticatorServer) error {
var grpcServer *grpc.Server
if config.DisableTLS {
grpcServer = grpc.NewServer()
logger.Infof("grpc starting on http://%s", serviceAddress)
logger.Infof("grpc starting without TLS")
} else {
creds, err := credentials.NewServerTLSFromFile(config.PathToTLSCert, config.PathToTLSKey)
if err != nil {
return err
}
grpcServer = grpc.NewServer(grpc.Creds(creds))
logger.Infof("grpc starting on https://%s", serviceAddress)
logger.Infof("grpc starting with TLS")
}
pb.RegisterAuthenticatorServer(grpcServer, authenticatorServer)
go func() {
logger.Fatal(grpcServer.Serve(lis))
logger.Fatal(grpcServer.Serve(listener))
}()
return nil
}
Expand Down
11 changes: 9 additions & 2 deletions authenticator/server/authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package server

import (
"context"
"fmt"
"io/ioutil"
"net"
"net/http"
"os"
"reflect"
Expand Down Expand Up @@ -397,10 +399,15 @@ func TestMetrics(t *testing.T) {
// Start the API, which includes an endpoint for Prometheus to mine metrics.
config := config.Config{
Host: "127.0.0.1",
HTTPPort: 6000,
Port: 6000,
DisableTLS: true,
}
_ = api.Start(testtools.TestLogger(), config)
serviceAddress := fmt.Sprintf("%s:%d", config.Host, config.Port)
lis, err := net.Listen("tcp", serviceAddress)
if err != nil {
t.Fatal(err)
}
_ = api.Start(testtools.TestLogger(), lis, config)

// Make some calls to increment the metrics.
svr, err := buildServer(testtools.TestLogger(), config)
Expand Down
19 changes: 6 additions & 13 deletions authenticator/server/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@ import (
// Please see https://approzium.org/configuration for elaboration upon
// each parameter.
type Config struct {
Host string
HTTPPort int
GRPCPort int
Host string
Port int

DisableTLS bool
PathToTLSCert string
Expand Down Expand Up @@ -51,8 +50,7 @@ func Parse() (Config, error) {
os.Unsetenv("VAULT_ADDR")
config = Config{
Host: "127.0.0.1",
HTTPPort: 6000,
GRPCPort: 6001,
Port: 6000,
DisableTLS: true,
LogLevel: "debug",
LogFormat: "text",
Expand Down Expand Up @@ -106,10 +104,7 @@ func setConfigEnvVars() error {
if err := viper.BindEnv("Host", "APPROZIUM_HOST"); err != nil {
return err
}
if err := viper.BindEnv("HTTPPort", "APPROZIUM_HTTP_PORT"); err != nil {
return err
}
if err := viper.BindEnv("GRPCPort", "APPROZIUM_GRPC_PORT"); err != nil {
if err := viper.BindEnv("Port", "APPROZIUM_PORT"); err != nil {
return err
}

Expand Down Expand Up @@ -147,8 +142,7 @@ func setConfigEnvVars() error {

func setConfigDefaults() {
viper.SetDefault("Host", "127.0.0.1")
viper.SetDefault("HTTPPort", "6000")
viper.SetDefault("GRPCPort", "6001")
viper.SetDefault("Port", "6000")

viper.SetDefault("DisableTLS", false)

Expand All @@ -162,8 +156,7 @@ func setConfigFlags() {
// avoid redefining flags because it leads to panic
if pflag.Lookup("host") == nil {
pflag.String("host", "", "")
pflag.String("httpport", "", "")
pflag.String("grpcport", "", "")
pflag.String("port", "", "")

pflag.Bool("disabletls", false, "")
pflag.String("tlscertpath", "", "")
Expand Down
12 changes: 6 additions & 6 deletions authenticator/server/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (

func TestParseConfig(t *testing.T) {
os.Unsetenv("APPROZIUM_HOST")
os.Unsetenv("APPROZIUM_HTTP_PORT")
os.Unsetenv("APPROZIUM_PORT")
os.Unsetenv("APPROZIUM_LOG_LEVEL")
os.Setenv("APPROZIUM_DISABLE_TLS", "true")
config, err := Parse()
Expand All @@ -17,15 +17,15 @@ func TestParseConfig(t *testing.T) {
if config.Host != "127.0.0.1" {
t.Fatalf("expected %s, received %s", "127.0.0.1", config.Host)
}
if config.HTTPPort != 6000 {
t.Fatalf("expected %d, received %d", 6000, config.HTTPPort)
if config.Port != 6000 {
t.Fatalf("expected %d, received %d", 6000, config.Port)
}
if config.LogLevel != "info" {
t.Fatalf("expected %s, received %s", "info", config.LogLevel)
}

os.Setenv("APPROZIUM_HOST", "0.0.0.0")
os.Setenv("APPROZIUM_HTTP_PORT", "6001")
os.Setenv("APPROZIUM_PORT", "6001")
os.Setenv("APPROZIUM_LOG_LEVEL", "debug")
os.Setenv("APPROZIUM_DISABLE_TLS", "true")
config, err = Parse()
Expand All @@ -35,8 +35,8 @@ func TestParseConfig(t *testing.T) {
if config.Host != "0.0.0.0" {
t.Fatalf("expected %s, received %s", "0.0.0.0", config.Host)
}
if config.HTTPPort != 6001 {
t.Fatalf("expected %d, received %d", 6001, config.HTTPPort)
if config.Port != 6001 {
t.Fatalf("expected %d, received %d", 6001, config.Port)
}
if config.LogLevel != "debug" {
t.Fatalf("expected %s, received %s", "debug", config.LogLevel)
Expand Down
5 changes: 2 additions & 3 deletions docs/docs/configuration.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@ title: Configuration
| Name, shorthand | Environment variable | Default | Description |
|--------------------|----------------------------|------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| --host | APPROZIUM_HOST | 127.0.0.1 | Set to 0.0.0.0 to listen on all interfaces. |
| --httpport | APPROZIUM_HTTP_PORT | 6000 | Port for HTTP(S) API endpoints. |
| --grpcport | APPROZIUM_GRPC_PORT | 6001 | Port for authenticator endpoint for clients. |
| --secretsmanager | APPROZIUM_SECRETS_MANAGER | | Supported options are "vault" (Hashicorp Vault), "asm" (AWS Secrets Manager), and "local" (Local YAML file) |
| --port | APPROZIUM_PORT | 6000 | Port for HTTP(S) API endpoints and authentication requests. |
| --secretsmanager | APPROZIUM_SECRETS_MANAGER | | Supported options are "vault" (Hashicorp Vault), "asm" (AWS Secrets Manager), and "local" (Local YAML file) |
| --disabletls | APPROZIUM_DISABLE_TLS | false | When false, Approzium comes up as an `"https"` server. When `"true"` disables TLS, and plain "http" is used. Setting to `"true"` means the Approzium authentication server will send database connection information in plain text, making it vulnerable to [man-in-the-middle attacks](https://en.wikipedia.org/wiki/Man-in-the-middle_attack). **Do not set to `"true"` in production environments.** |
| --tlscertpath | APPROZIUM_PATH_TO_TLS_CERT | | The path to the TLS certificate the Approzium authentication server has been issued to prove its identity. Curious about how to generate a valid cert? See [this walkthrough](https://itnext.io/practical-guide-to-securing-grpc-connections-with-go-and-tls-part-1-f63058e9d6d1). This certificate would correspond to the `service.pem` generated in the walkthrough. However, ideally this would not be a certificate issued by your own [Certificate Authority (CA)](https://en.wikipedia.org/wiki/Certificate_authority), and instead it might be issued by your company's internal CA and/or a widely recognized one. However, even a self-created CA is better than none. |
| --tlskeypath | APPROZIUM_PATH_TO_TLS_KEY | | The path to the TLS key the Approzium authentication server can use to prove its identity. In the above walkthrough, this would correspond to the `service.key`. |
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/examples/asyncpg_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from approzium.asyncpg.pool import create_pool

auth = AuthClient(
"authenticator:6001",
"authenticator:6000",
# This is insecure, see https://approzium.org/configuration for proper use.
disable_tls=True,
)
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/examples/mysql_connector_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from approzium.mysql.connector.pooling import MySQLConnectionPool

auth = AuthClient(
"authenticator:6001",
"authenticator:6000",
# This is insecure, see https://approzium.org/configuration for proper use.
disable_tls=True,
)
Expand Down
6 changes: 3 additions & 3 deletions sdk/python/examples/psycopg2_attribution_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
import approzium
from approzium.psycopg2 import connect

auth = approzium.AuthClient("authenticator:6001")
auth = approzium.AuthClient("authenticator:6000")
print(auth.attribution_info)
# {'authenticator_address': 'authenticator:6001',
# {'authenticator_address': 'authenticator:6000',
# 'iam_arn': 'arn:aws:iam::*******:user/****',
# 'authenticated': False,
# 'num_connections': 0
Expand All @@ -17,7 +17,7 @@
dsn = "host=dbmd5 dbname=db user=bob"
conn = connect(dsn)
print(auth.attribution_info)
# {'authenticator_address': 'authenticator:6001',
# {'authenticator_address': 'authenticator:6000',
# 'iam_arn': 'arn:aws:iam::*******:user/****',
# 'authenticated': True,
# 'num_connections': 1
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/examples/psycopg2_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from approzium.psycopg2.pool import ThreadedConnectionPool

auth = AuthClient(
"authenticator:6001",
"authenticator:6000",
# This is insecure, see https://approzium.org/configuration for proper use.
disable_tls=True,
)
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/examples/psycopg2_opentelemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import approzium.opentelemetry
from approzium.psycopg2 import connect

auth = approzium.AuthClient("authenticator:6001")
auth = approzium.AuthClient("authenticator:6000")
approzium.default_auth_client = auth

trace.set_tracer_provider(TracerProvider())
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/examples/pymysql_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from approzium.pymysql import connect

auth = AuthClient(
"authenticator:6001",
"authenticator:6000",
# This is insecure, see https://approzium.org/configuration for proper use.
disable_tls=True,
)
Expand Down
4 changes: 2 additions & 2 deletions sdk/python/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def determine_authclients():
pytest.authclients = []
for host in authenticatorhosts:
base_aws_auth = approzium.AuthClient(
"%s:6001" % host,
"%s:6000" % host,
tls_config=approzium.TLSConfig(
trusted_certs=environ.get("TEST_CERT_DIR") + "/approzium.pem",
client_cert=environ.get("TEST_CERT_DIR") + "/client.pem",
Expand All @@ -28,7 +28,7 @@ def determine_authclients():
if environ.get("TEST_ASSUMABLE_ARN"):
for host in authenticatorhosts:
role_aws_auth = approzium.AuthClient(
"%s:6001" % host,
"%s:6000" % host,
tls_config=approzium.TLSConfig(
trusted_certs=environ.get("TEST_CERT_DIR") + "/approzium.pem",
client_cert=environ.get("TEST_CERT_DIR") + "/client.pem",
Expand Down