diff --git a/.travis.yml b/.travis.yml index a7c98b7c4..6770f8622 100644 --- a/.travis.yml +++ b/.travis.yml @@ -3,3 +3,13 @@ language: go go: - 1.6.2 - 1.7beta1 + +before_script: + - echo $HOSTNAME + - mkdir -p $GOPATH/bin + - wget https://releases.hashicorp.com/consul/0.6.4/consul_0.6.4_linux_amd64.zip + - wget https://releases.hashicorp.com/vault/0.5.3/vault_0.5.3_linux_amd64.zip + - unzip -d $GOPATH/bin consul_0.6.4_linux_amd64.zip + - unzip -d $GOPATH/bin vault_0.5.3_linux_amd64.zip + - vault --version + - consul --version diff --git a/Makefile b/Makefile index 7b1947e7c..d7474cdcb 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ build: test: $(GO) test -i ./... - $(GO) test ./... + $(GO) test -test.timeout 5s `go list ./... | grep -v '/vendor/'` gofmt: gofmt -w `find . -type f -name '*.go' | grep -v vendor` diff --git a/cert/consul_source.go b/cert/consul_source.go new file mode 100644 index 000000000..680d11aee --- /dev/null +++ b/cert/consul_source.go @@ -0,0 +1,124 @@ +package cert + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "log" + "net/url" + "path" + "reflect" + "time" + + "github.com/hashicorp/consul/api" +) + +// ConsulSource implements a certificate source which loads +// TLS and client authentication certificates from the consul KV store. +// The CertURL/ClientCAURL must point to the base path of the certificates. +// The TLS certificates are updated automatically when the KV store +// changes. +type ConsulSource struct { + CertURL string + ClientCAURL string +} + +const kvURLPrefix = "/v1/kv/" + +func parseConsulURL(consulURL, stripPrefix string) (client *api.Client, key string, err error) { + u, err := url.Parse(consulURL) + if err != nil { + return nil, "", err + } + var token string + if len(u.Query()["token"]) > 0 { + token = u.Query()["token"][0] + } + client, err = api.NewClient(&api.Config{Address: u.Host, Scheme: u.Scheme, Token: token}) + if err != nil { + return nil, "", err + } + key = u.RequestURI()[len(stripPrefix):] + return client, key, nil +} + +func (s ConsulSource) LoadClientCAs() (*x509.CertPool, error) { + if s.ClientCAURL == "" { + return nil, nil + } + + client, key, err := parseConsulURL(s.ClientCAURL, kvURLPrefix) + if err != nil { + return nil, err + } + + load := func(key string) (map[string][]byte, error) { + pemBlocks, _, err := getCerts(client, key, 0) + return pemBlocks, err + } + return newCertPool(key, load) +} + +func (s ConsulSource) Certificates() chan []tls.Certificate { + if s.CertURL == "" { + return nil + } + + client, key, err := parseConsulURL(s.CertURL, kvURLPrefix) + if err != nil { + log.Printf("[ERROR] cert: Failed to create consul client. %s", err) + } + + pemBlocksCh := make(chan map[string][]byte, 1) + go watchKV(client, key, pemBlocksCh) + + ch := make(chan []tls.Certificate, 1) + go func() { + for pemBlocks := range pemBlocksCh { + certs, err := loadCertificates(pemBlocks) + if err != nil { + log.Printf("[ERROR] cert: Failed to load certificates. %s", err) + continue + } + ch <- certs + } + }() + return ch +} + +// watchKV monitors a key in the KV store for changes. +func watchKV(client *api.Client, key string, pemBlocks chan map[string][]byte) { + var lastIndex uint64 + var lastValue map[string][]byte + + for { + value, index, err := getCerts(client, key, lastIndex) + if err != nil { + log.Printf("[WARN] cert: Error fetching certificates from %s. %v", key, err) + time.Sleep(time.Second) + continue + } + + if !reflect.DeepEqual(value, lastValue) || index != lastIndex { + log.Printf("[INFO] cert: Certificate index changed to #%d", index) + pemBlocks <- value + lastValue, lastIndex = value, index + } + } +} + +func getCerts(client *api.Client, key string, waitIndex uint64) (pemBlocks map[string][]byte, lastIndex uint64, err error) { + q := &api.QueryOptions{RequireConsistent: true, WaitIndex: waitIndex} + kvpairs, meta, err := client.KV().List(key, q) + if err != nil { + return nil, 0, fmt.Errorf("consul: list: %s", err) + } + if len(kvpairs) == 0 { + return nil, meta.LastIndex, nil + } + pemBlocks = map[string][]byte{} + for _, kvpair := range kvpairs { + pemBlocks[path.Base(kvpair.Key)] = kvpair.Value + } + return pemBlocks, meta.LastIndex, nil +} diff --git a/cert/file_source.go b/cert/file_source.go new file mode 100644 index 000000000..9dd939c09 --- /dev/null +++ b/cert/file_source.go @@ -0,0 +1,53 @@ +package cert + +import ( + "crypto/tls" + "crypto/x509" + "io/ioutil" + "log" +) + +// FileSource implements a certificate source for one +// TLS and one client authentication certificate. +// The certificates are loaded during startup and are cached +// in memory until the program exits. +// It exists to support the legacy configuration only. The +// PathSource should be used instead. +type FileSource struct { + CertFile string + KeyFile string + ClientAuthFile string +} + +func (s FileSource) LoadClientCAs() (*x509.CertPool, error) { + return newCertPool(s.ClientAuthFile, func(path string) (map[string][]byte, error) { + if s.ClientAuthFile == "" { + return nil, nil + } + pemBlock, err := ioutil.ReadFile(path) + return map[string][]byte{path: pemBlock}, err + }) +} + +func (s FileSource) Certificates() chan []tls.Certificate { + ch := make(chan []tls.Certificate, 1) + ch <- []tls.Certificate{loadX509KeyPair(s.CertFile, s.KeyFile)} + close(ch) + return ch +} + +func loadX509KeyPair(certFile, keyFile string) tls.Certificate { + if certFile == "" { + log.Fatalf("[FATAL] cert: CertFile is required") + } + + if keyFile == "" { + keyFile = certFile + } + + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + log.Fatalf("[FATAL] cert: Error loading certificate. %s", err) + } + return cert +} diff --git a/cert/http_source.go b/cert/http_source.go new file mode 100644 index 000000000..5586662d9 --- /dev/null +++ b/cert/http_source.go @@ -0,0 +1,31 @@ +package cert + +import ( + "crypto/tls" + "crypto/x509" + "time" +) + +// HTTPSource implements a certificate source which loads +// TLS and client authentication certificates from an HTTP/HTTPS server. +// The CertURL/ClientCAURL must point to a text file in the directory +// of the certificates. The text file contains all files that should +// be loaded from this directory - one filename per line. +// The TLS certificates are updated automatically when Refresh +// is not zero. Refresh cannot be less than one second to prevent +// busy loops. +type HTTPSource struct { + CertURL string + ClientCAURL string + Refresh time.Duration +} + +func (s HTTPSource) LoadClientCAs() (*x509.CertPool, error) { + return newCertPool(s.ClientCAURL, loadURL) +} + +func (s HTTPSource) Certificates() chan []tls.Certificate { + ch := make(chan []tls.Certificate, 1) + go watch(ch, s.Refresh, s.CertURL, loadURL) + return ch +} diff --git a/cert/load.go b/cert/load.go new file mode 100644 index 000000000..ee9b25c32 --- /dev/null +++ b/cert/load.go @@ -0,0 +1,197 @@ +package cert + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "io/ioutil" + "log" + "net/http" + "net/url" + "os" + "path" + "path/filepath" + "sort" + "strings" +) + +const MaxSize = 1 << 20 // 1MB + +func loadURL(listURL string) (pemBlocks map[string][]byte, err error) { + if listURL == "" { + return nil, nil + } + + baseURL, err := base(listURL) + if err != nil { + return nil, fmt.Errorf("cert: %s", err) + } + + fetch := func(url string) (buf []byte, err error) { + resp, err := http.Get(url) + if err != nil { + return nil, err + } + defer resp.Body.Close() + return ioutil.ReadAll(resp.Body) + } + + // fetch the file with the list of filenames + list, err := fetch(listURL) + if err != nil { + return nil, fmt.Errorf("cert: %s", err) + } + + // fetch the individual files + pemBlocks = map[string][]byte{} + for _, p := range strings.Split(string(list), "\n") { + if p == "" { + continue + } + + path := baseURL + p + + buf, err := fetch(path) + if err != nil { + return nil, fmt.Errorf("cert: %s", err) + } + + pemBlocks[path] = buf + } + + return pemBlocks, nil +} + +func loadPath(root string) (pemBlocks map[string][]byte, err error) { + if root == "" { + return nil, nil + } + + pemBlocks = map[string][]byte{} + err = filepath.Walk(root, func(path string, info os.FileInfo, err error) error { + // check if the root directory exists + if _, ok := err.(*os.PathError); ok && path == root { + return nil + } + + if err != nil { + return err + } + + if info.IsDir() || filepath.Ext(info.Name()) != ".pem" || strings.HasPrefix(info.Name(), ".") { + return nil + } + + if info.Size() > MaxSize { + log.Print("[WARN] cert: File too large %s", info.Name) + return nil + } + + buf, err := ioutil.ReadFile(path) + if err != nil { + return fmt.Errorf("cert: %s", err) + } + + pemBlocks[path] = buf + return nil + }) + + if err != nil { + return nil, err + } + + return pemBlocks, nil +} + +func loadCertificates(pemBlocks map[string][]byte) ([]tls.Certificate, error) { + var n []string + x := map[string]tls.Certificate{} + + for name := range pemBlocks { + var certFile, keyFile string + switch { + case strings.HasSuffix(name, "-cert.pem"): + certFile, keyFile = name, replaceSuffix(name, "-cert.pem", "-key.pem") + case strings.HasSuffix(name, "-key.pem"): + certFile, keyFile = replaceSuffix(name, "-key.pem", "-cert.pem"), name + case strings.HasSuffix(name, ".pem"): + certFile, keyFile = name, name + default: + continue + } + + if _, exists := x[certFile]; exists { + continue + } + + cert, key := pemBlocks[certFile], pemBlocks[keyFile] + if cert == nil || key == nil { + return nil, fmt.Errorf("cert: cannot load certificate %s", name) + } + + c, err := tls.X509KeyPair(cert, key) + if err != nil { + return nil, fmt.Errorf("cert: invalid certificate %s. %s", name, err) + } + + x[certFile] = c + n = append(n, certFile) + } + + // append certificates in alphabetical order of the + // cert filenames. This determines which certificate + // becomes the default certificate (the first one) + sort.Strings(n) + var certs []tls.Certificate + for _, certFile := range n { + certs = append(certs, x[certFile]) + } + + return certs, nil +} + +// base returns the rawurl with the last element of the path +// removed. http://foo.com/x/y becomes http://foo.com/x +func base(rawurl string) (string, error) { + if rawurl == "" { + return "", nil + } + u, err := url.Parse(rawurl) + if err != nil { + return "", err + } + if u.Path != "/" { + u.Path = path.Dir(u.Path) + } + return u.String(), nil +} + +// replaceSuffix replaces oldSuffix with newSuffix in s. +// It is only valid when s has oldSuffix and oldSuffix is not empty. +func replaceSuffix(s string, oldSuffix, newSuffix string) string { + return s[:len(s)-len(oldSuffix)] + newSuffix +} + +// newCertPool creates a new x509.CertPool by loading the +// PEM blocks from loadFn(path) and adding them to a CertPool. +func newCertPool(path string, loadFn func(path string) (pemBlocks map[string][]byte, err error)) (*x509.CertPool, error) { + pemBlocks, err := loadFn(path) + if err != nil { + return nil, err + } + + if len(pemBlocks) == 0 { + return nil, nil + } + + x := x509.NewCertPool() + for name, pemBlock := range pemBlocks { + if !x.AppendCertsFromPEM(pemBlock) { + log.Printf("[WARN] cert: Could not add client CA certificate from %s", name) + continue + } + } + + log.Printf("[INFO] cert: Load client CA certs from %s", path) + return x, nil +} diff --git a/cert/load_test.go b/cert/load_test.go new file mode 100644 index 000000000..05147f0bd --- /dev/null +++ b/cert/load_test.go @@ -0,0 +1,36 @@ +package cert + +import "testing" + +func TestBase(t *testing.T) { + tests := []struct { + in, out, err string + }{ + {"", "", ""}, + {"http://foo.com/x/y", "http://foo.com/x", ""}, + {"http://foo.com/x/y?p=q", "http://foo.com/x?p=q", ""}, + } + + for i, tt := range tests { + u, err := base(tt.in) + if err != nil { + if got, want := err.Error(), tt.err; got != want { + t.Errorf("%d: got %v want %v", i, got, want) + continue + } + } + if tt.err != "" { + t.Errorf("%d: got nil want %v", i, tt.err) + continue + } + if got, want := u, tt.out; got != want { + t.Errorf("%d: got %v want %v", i, got, want) + } + } +} + +func TestReplaceSuffix(t *testing.T) { + if got, want := replaceSuffix("ab", "b", "c"), "ac"; got != want { + t.Errorf("got %q want %q", got, want) + } +} diff --git a/cert/path_source.go b/cert/path_source.go new file mode 100644 index 000000000..82afaac09 --- /dev/null +++ b/cert/path_source.go @@ -0,0 +1,39 @@ +package cert + +import ( + "crypto/tls" + "crypto/x509" + "path/filepath" + "time" +) + +const ( + DefaultCertPath = "cert" + DefaultClientCAPath = "clientca" +) + +type PathSource struct { + Path string + CertPath string + ClientCAPath string + Refresh time.Duration +} + +func (s PathSource) LoadClientCAs() (*x509.CertPool, error) { + path := makePath(s.Path, s.ClientCAPath, DefaultClientCAPath) + return newCertPool(path, loadPath) +} + +func (s PathSource) Certificates() chan []tls.Certificate { + path := makePath(s.Path, s.CertPath, DefaultCertPath) + ch := make(chan []tls.Certificate, 1) + go watch(ch, s.Refresh, path, loadPath) + return ch +} + +func makePath(parent, child, defaultChild string) string { + if child == "" { + return filepath.Join(parent, defaultChild) + } + return filepath.Join(parent, child) +} diff --git a/cert/source.go b/cert/source.go new file mode 100644 index 000000000..f3e8b8cb6 --- /dev/null +++ b/cert/source.go @@ -0,0 +1,51 @@ +package cert + +import ( + "crypto/tls" + "crypto/x509" +) + +// Source provides the interface for dynamic certificate sources. +// +// Certificates() loads certificates for TLS connections. +// The first certificate is used as the default certificate +// if the client does not support SNI or no matching certificate +// could be found. TLS certificates can be updated at runtime. +// +// LoadClientCAs() provides certificates for client certificate +// authentication. +type Source interface { + Certificates() chan []tls.Certificate + LoadClientCAs() (*x509.CertPool, error) +} + +// TLSConfig creates a tls.Config which sets the +// GetCertificate field to a certificate store +// which uses the given source to update the +// the certificates on demand. +// +// It also sets the ClientCAs field if +// src.LoadClientCAs returns a non-nil value +// and sets ClientAuth to RequireAndVerifyClientCert. +func TLSConfig(src Source) (*tls.Config, error) { + clientCAs, err := src.LoadClientCAs() + if err != nil { + return nil, err + } + + store := NewStore() + x := &tls.Config{GetCertificate: store.GetCertificate} + + if clientCAs != nil { + x.ClientCAs = clientCAs + x.ClientAuth = tls.RequireAndVerifyClientCert + } + + go func() { + for certs := range src.Certificates() { + store.SetCertificates(certs) + } + }() + + return x, nil +} diff --git a/cert/source_test.go b/cert/source_test.go new file mode 100644 index 000000000..947a35003 --- /dev/null +++ b/cert/source_test.go @@ -0,0 +1,335 @@ +package cert + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "io/ioutil" + "log" + "math/big" + "net/http" + "net/http/httptest" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "time" + + consulapi "github.com/hashicorp/consul/api" + vaultapi "github.com/hashicorp/vault/api" +) + +type StaticSource struct { + cert tls.Certificate +} + +func (s StaticSource) Certificates() chan []tls.Certificate { + ch := make(chan []tls.Certificate, 1) + ch <- []tls.Certificate{s.cert} + close(ch) + return ch +} + +func (s StaticSource) LoadClientCAs() (*x509.CertPool, error) { + return nil, nil +} + +func TestStaticSource(t *testing.T) { + certPEM, keyPEM := makeCert("localhost", time.Minute) + cert, err := tls.X509KeyPair(certPEM, keyPEM) + if err != nil { + t.Fatalf("X509KeyPair: got %s want nil", err) + } + testSource(t, StaticSource{cert}, makeCertPool(certPEM), 0) +} + +func TestFileSource(t *testing.T) { + dir := tempDir() + defer os.RemoveAll(dir) + certPEM, keyPEM := makeCert("localhost", time.Minute) + certFile, keyFile := saveCert(dir, "localhost", certPEM, keyPEM) + testSource(t, FileSource{CertFile: certFile, KeyFile: keyFile}, makeCertPool(certPEM), 0) +} + +func TestPathSource(t *testing.T) { + dir := tempDir() + defer os.RemoveAll(dir) + certPEM, keyPEM := makeCert("localhost", time.Minute) + saveCert(dir, "localhost", certPEM, keyPEM) + testSource(t, PathSource{CertPath: dir}, makeCertPool(certPEM), 0) +} + +func TestHTTPSource(t *testing.T) { + dir := tempDir() + defer os.RemoveAll(dir) + certPEM, keyPEM := makeCert("localhost", time.Minute) + certFile, keyFile := saveCert(dir, "localhost", certPEM, keyPEM) + listFile := filepath.Base(certFile) + "\n" + filepath.Base(keyFile) + "\n" + writeFile(filepath.Join(dir, "list"), []byte(listFile)) + + srv := httptest.NewServer(http.FileServer(http.Dir(dir))) + defer srv.Close() + + testSource(t, HTTPSource{CertURL: srv.URL + "/list"}, makeCertPool(certPEM), 50*time.Millisecond) +} + +func TestConsulSource(t *testing.T) { + const ( + certURL = "http://localhost:8500/v1/kv/fabio/test/consul-server" + dataDir = "/tmp/fabio-consul-source-test" + ) + + // run a consul server if it isn't already running + _, err := http.Get("http://localhost:8500/v1/status/leader") + if err != nil { + t.Log("Starting consul server") + consul := exec.Command("consul", "agent", "-server", "-bootstrap", "-data-dir", dataDir) + if err := consul.Start(); err != nil { + t.Fatalf("Failed to start consul server. %s", err) + } + defer func() { + consul.Process.Kill() + os.RemoveAll(dataDir) + }() + + isUp := func() bool { + resp, err := http.Get("http://localhost:8500/v1/status/leader") + return err == nil && resp.StatusCode == 200 + } + if !waitFor(time.Second, isUp) { + t.Fatal("Timeout waiting for consul server") + } + // give consul time to figure out that it is the only member + time.Sleep(3 * time.Second) + } else { + t.Log("Using existing consul server") + } + + client, key, err := parseConsulURL(certURL, kvURLPrefix) + if err != nil { + t.Fatalf("Failed to create consul client: %s", err) + } + defer func() { client.KV().DeleteTree(key, &consulapi.WriteOptions{}) }() + + write := func(name string, value []byte) { + p := &consulapi.KVPair{Key: key + "/" + name, Value: value} + _, err := client.KV().Put(p, &consulapi.WriteOptions{}) + if err != nil { + t.Fatalf("Failed to write %q to consul: %s", p.Key, err) + } + } + + certPEM, keyPEM := makeCert("localhost", time.Minute) + write("localhost-cert.pem", certPEM) + write("localhost-key.pem", keyPEM) + + testSource(t, ConsulSource{CertURL: certURL}, makeCertPool(certPEM), 50*time.Millisecond) +} + +func TestVaultSource(t *testing.T) { + const ( + addr = "127.0.0.1:58421" + rootToken = "token" + certPath = "secret/fabio/cert" + ) + + // run a vault server in dev mode + t.Log("Starting vault server") + vault := exec.Command("vault", "server", "-dev", "-dev-root-token-id="+rootToken, "-dev-listen-address="+addr) + if err := vault.Start(); err != nil { + t.Fatalf("Failed to start vault server. %s", err) + } + defer vault.Process.Kill() + + // create a vault client for that server + c, err := vaultapi.NewClient(&vaultapi.Config{Address: "http://" + addr}) + if err != nil { + t.Fatalf("NewClient failed: %s", err) + } + c.SetToken(rootToken) + + isUp := func() bool { + ok, err := c.Sys().InitStatus() + return err == nil && ok + } + if !waitFor(time.Second, isUp) { + t.Fatal("Timeout waiting for vault server") + } + + // create a renewable token since the vault source + // will renew the token on every request + tok, err := c.Auth().Token().Create(&vaultapi.TokenCreateRequest{NoParent: true, TTL: "1h"}) + if err != nil { + t.Fatalf("Token.Create failed: %s", err) + } + + // create a cert and store it in vault + certPEM, keyPEM := makeCert("localhost", time.Minute) + data := map[string]interface{}{"cert": string(certPEM), "key": string(keyPEM)} + if _, err := c.Logical().Write(certPath+"/localhost", data); err != nil { + t.Fatalf("logical.Write failed: %s", err) + } + + testSource(t, VaultSource{Addr: "http://" + addr, token: tok.Auth.ClientToken, CertPath: certPath}, makeCertPool(certPEM), 50*time.Millisecond) +} + +// testSource runs an integration test by making an HTTPS request +// to https://localhost/ expecting that the source provides a valid +// certificate for "localhost". rootCAs is expected to contain a +// valid root certificate or the server certificate itself so that +// the HTTPS client can validate the certificate presented by the +// server. +func testSource(t *testing.T, source Source, rootCAs *x509.CertPool, sleep time.Duration) { + srvConfig, err := TLSConfig(source) + if err != nil { + t.Fatalf("TLSConfig: got %q want nil", err) + } + + // give the source some time to initialize if necessary + time.Sleep(sleep) + + // create the https server and start it + // it will be listening on 127.0.0.1 + srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "OK") + })) + srv.TLS = srvConfig + srv.StartTLS() + defer srv.Close() + + // create an http client that will accept the root CAs + // otherwise the HTTPS client will not verify the + // certificate presented by the server. + client := http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: rootCAs, + }, + }, + } + + call := func(host string) (statusCode int, body string, err error) { + // for the certificate validation to work we need to put a hostname + // which resolves to 127.0.0.1 in the URL. Can't fake the hostname via + // the Host header. + resp, err := client.Get(strings.Replace(srv.URL, "127.0.0.1", host, 1)) + if err != nil { + return 0, "", err + } + defer resp.Body.Close() + + data, err := ioutil.ReadAll(resp.Body) + if err != nil { + return 0, "", err + } + + return resp.StatusCode, string(data), nil + } + + // disable log output for the next call to prevent + // confusing log messages since they are expected + // http: TLS handshake error from 127.0.0.1:55044: remote error: bad certificate + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stderr) + + // make a call for which certificate validation fails. + // localhost.org is external but resolves to 127.0.0.1 + _, _, err = call("localhost.org") + if got, want := err, "x509: certificate is valid for localhost, not localhost.org"; got == nil || !strings.Contains(got.Error(), want) { + t.Fatalf("got %q want %q", got, want) + } + + // now make the call that should succeed + statusCode, body, err := call("localhost") + if err != nil { + t.Fatalf("got %v want nil", err) + } + if got, want := statusCode, 200; got != want { + t.Fatalf("got %v want %v", got, want) + } + if got, want := body, "OK"; got != want { + t.Fatalf("got %v want %v", got, want) + } +} + +func tempDir() string { + dir, err := ioutil.TempDir("", "fabio") + if err != nil { + log.Fatal(err) + } + return dir +} + +func writeFile(filename string, data []byte) { + if err := ioutil.WriteFile(filename, data, 0644); err != nil { + log.Fatal(err) + } +} + +func makeCertPool(x ...[]byte) *x509.CertPool { + p := x509.NewCertPool() + for _, b := range x { + p.AppendCertsFromPEM(b) + } + return p +} + +func saveCert(dir, host string, certPEM, keyPEM []byte) (certFile, keyFile string) { + certFile, keyFile = filepath.Join(dir, host+"-cert.pem"), filepath.Join(dir, host+"-key.pem") + writeFile(certFile, certPEM) + writeFile(keyFile, keyPEM) + return certFile, keyFile +} + +// makeCert creates a self-signed RSA certificate. +// taken from crypto/tls/generate_cert.go +func makeCert(host string, validFor time.Duration) (certPEM, keyPEM []byte) { + const bits = 1024 + priv, err := rsa.GenerateKey(rand.Reader, bits) + if err != nil { + log.Fatalf("Failed to generate private key: %s", err) + } + + template := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Fabio Co"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(validFor), + IsCA: true, + DNSNames: []string{host}, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + log.Fatalf("Failed to create certificate: %s", err) + } + + var cert, key bytes.Buffer + pem.Encode(&cert, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}) + pem.Encode(&key, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + return cert.Bytes(), key.Bytes() +} + +func waitFor(timeout time.Duration, up func() bool) bool { + until := time.Now().Add(timeout) + for { + if time.Now().After(until) { + return false + } + if up() { + return true + } + time.Sleep(100 * time.Millisecond) + } +} diff --git a/cert/store.go b/cert/store.go new file mode 100644 index 000000000..106e67d23 --- /dev/null +++ b/cert/store.go @@ -0,0 +1,100 @@ +package cert + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "log" + "strings" + "sync/atomic" +) + +// Store provides a dynamic certificate store which can be updated at +// runtime and is safe for concurrent use. +type Store struct { + cfg atomic.Value +} + +// NewStore creates an empty certificate store. +func NewStore() *Store { + s := new(Store) + s.cfg.Store(config{}) + return s +} + +// SetCertificates replaces the certificates of the store. +func (s *Store) SetCertificates(certs []tls.Certificate) { + cfg := config{Certificates: certs} + cfg.BuildNameToCertificate() + s.cfg.Store(cfg) + var names []string + for name := range cfg.NameToCertificate { + names = append(names, name) + } + log.Printf("[INFO] cert: Store has certificates for [%q]", strings.Join(names, ",")) +} + +// GetCertificate returns a matching certificate for the given clientHello if possible +// or the first certificate from the store. +func (s *Store) GetCertificate(clientHello *tls.ClientHelloInfo) (cert *tls.Certificate, err error) { + return getCertificate(s.cfg.Load().(config), clientHello) +} + +func getCertificate(cfg config, clientHello *tls.ClientHelloInfo) (cert *tls.Certificate, err error) { + if len(cfg.Certificates) == 0 { + return nil, errors.New("cert: no certificates configured") + } + + if len(cfg.Certificates) == 1 || cfg.NameToCertificate == nil { + // There's only one choice, so no point doing any work. + return &cfg.Certificates[0], nil + } + + name := strings.ToLower(clientHello.ServerName) + for len(name) > 0 && name[len(name)-1] == '.' { + name = name[:len(name)-1] + } + + if cert, ok := cfg.NameToCertificate[name]; ok { + return cert, nil + } + + // try replacing labels in the name with wildcards until we get a + // match. + labels := strings.Split(name, ".") + for i := range labels { + labels[i] = "*" + candidate := strings.Join(labels, ".") + if cert, ok := cfg.NameToCertificate[candidate]; ok { + return cert, nil + } + } + + // If nothing matches, return the first certificate. + return &cfg.Certificates[0], nil +} + +type config struct { + Certificates []tls.Certificate + NameToCertificate map[string]*tls.Certificate +} + +// BuildNameToCertificate parses Certificates and builds NameToCertificate +// from the CommonName and SubjectAlternateName fields of each of the leaf +// certificates. +func (c *config) BuildNameToCertificate() { + c.NameToCertificate = make(map[string]*tls.Certificate) + for i := range c.Certificates { + cert := &c.Certificates[i] + x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + continue + } + if len(x509Cert.Subject.CommonName) > 0 { + c.NameToCertificate[x509Cert.Subject.CommonName] = cert + } + for _, san := range x509Cert.DNSNames { + c.NameToCertificate[san] = cert + } + } +} diff --git a/cert/vault_source.go b/cert/vault_source.go new file mode 100644 index 000000000..882222b09 --- /dev/null +++ b/cert/vault_source.go @@ -0,0 +1,111 @@ +package cert + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "log" + "time" + + "github.com/hashicorp/vault/api" +) + +// VaultSource implements a certificate source which loads +// TLS and client authorization certificates from a Vault server. +// The Vault token should be set through the VAULT_TOKEN environment +// variable. +// +// The TLS certificates are updated automatically when Refresh +// is not zero. Refresh cannot be less than one second to prevent +// busy loops. +type VaultSource struct { + Addr string + CertPath string + ClientCAPath string + Refresh time.Duration + + token string +} + +func (s VaultSource) client() (*api.Client, error) { + c, err := api.NewClient(&api.Config{Address: s.Addr}) + if err != nil { + return nil, err + } + c.SetToken(s.token) + return c, nil +} + +func (s VaultSource) LoadClientCAs() (*x509.CertPool, error) { + return newCertPool(s.ClientCAPath, s.load) +} + +func (s VaultSource) Certificates() chan []tls.Certificate { + ch := make(chan []tls.Certificate, 1) + go watch(ch, s.Refresh, s.CertPath, s.load) + return ch +} + +func (s VaultSource) load(path string) (pemBlocks map[string][]byte, err error) { + pemBlocks = map[string][]byte{} + + // get will read a key=value pair from the secret + // and store it as -{cert,key}.pem so that + // they are recognized by the post-processing function + // which assembles the certificates. + // The value can be stored either as string or []byte. + get := func(name, typ string, secret *api.Secret) { + v := secret.Data[typ] + if v == nil { + return + } + + var b []byte + switch v.(type) { + case string: + b = []byte(v.(string)) + case []byte: + b = v.([]byte) + default: + log.Printf("[WARN] cert: key %s has type %T", name, v) + return + } + + pemBlocks[name+"-"+typ+".pem"] = []byte(b) + } + + c, err := s.client() + if err != nil { + return nil, fmt.Errorf("vault: client: %s", err) + } + + // renew token + _, err = c.Auth().Token().RenewSelf(3600) + if err != nil { + return nil, fmt.Errorf("vault: renew-self: %s", err) + } + + // get the subkeys under 'path'. + // Each subkey refers to a certificate. + certs, err := c.Logical().List(path) + if err != nil { + return nil, fmt.Errorf("vault: list: %s", err) + } + if certs == nil { + return nil, nil + } + + for _, s := range certs.Data["keys"].([]interface{}) { + name := s.(string) + p := path + "/" + name + secret, err := c.Logical().Read(p) + if err != nil { + log.Printf("[WARN] cert: Failed to read %s from Vault: %s", p, err) + continue + } + get(name, "cert", secret) + get(name, "key", secret) + } + + return pemBlocks, nil +} diff --git a/cert/watch.go b/cert/watch.go new file mode 100644 index 000000000..f6cb8b1aa --- /dev/null +++ b/cert/watch.go @@ -0,0 +1,46 @@ +package cert + +import ( + "crypto/tls" + "log" + "reflect" + "time" +) + +// watch monitors the result of the loadFn function for changes. +func watch(ch chan []tls.Certificate, refresh time.Duration, path string, loadFn func(path string) (map[string][]byte, error)) { + once := refresh <= 0 + + // do not refresh more often than once a second to prevent busy loops + if refresh < time.Second { + refresh = time.Second + } + + var last map[string][]byte + for { + next, err := loadFn(path) + if err != nil { + log.Printf("[ERROR] cert: Cannot load certificates from %s. %s", path, err) + time.Sleep(refresh) + continue + } + + if reflect.DeepEqual(next, last) { + time.Sleep(refresh) + continue + } + + certs, err := loadCertificates(next) + if err != nil { + log.Printf("[ERROR] cert: Cannot make certificates: %s", err) + continue + } + + ch <- certs + last = next + + if once { + return + } + } +} diff --git a/config/config.go b/config/config.go index 0ddd6321b..d3d796178 100644 --- a/config/config.go +++ b/config/config.go @@ -1,24 +1,36 @@ package config -import "time" +import ( + "net/http" + "time" +) type Config struct { - Proxy Proxy - Registry Registry - Listen []Listen - Metrics Metrics - UI UI - Runtime Runtime + Proxy Proxy + Registry Registry + Listen []Listen + CertSources map[string]CertSource + Metrics Metrics + UI UI + Runtime Runtime +} + +type CertSource struct { + Name string + Type string + CertPath string + KeyPath string + ClientCAPath string + Refresh time.Duration + Header http.Header } type Listen struct { - Addr string - KeyFile string - CertFile string - ClientAuthFile string - TLS bool - ReadTimeout time.Duration - WriteTimeout time.Duration + Addr string + Scheme string + ReadTimeout time.Duration + WriteTimeout time.Duration + CertSource CertSource } type UI struct { @@ -43,6 +55,7 @@ type Proxy struct { TLSHeader string TLSHeaderValue string ListenerAddr string + CertSources string } type Runtime struct { diff --git a/config/default.go b/config/default.go index fcbf173bd..ca0a3def8 100644 --- a/config/default.go +++ b/config/default.go @@ -41,4 +41,5 @@ var Default = &Config{ Prefix: "default", Interval: 30 * time.Second, }, + CertSources: map[string]CertSource{}, } diff --git a/config/load.go b/config/load.go index bd1df5007..5e716ab73 100644 --- a/config/load.go +++ b/config/load.go @@ -5,6 +5,7 @@ import ( "flag" "fmt" "log" + "net/http" "os" "runtime" "strings" @@ -67,6 +68,7 @@ func load(p *properties.Properties) (cfg *Config, err error) { f.StringVar(&cfg.Proxy.TLSHeader, "proxy.header.tls", Default.Proxy.TLSHeader, "header for TLS connections") f.StringVar(&cfg.Proxy.TLSHeaderValue, "proxy.header.tls.value", Default.Proxy.TLSHeaderValue, "value for TLS connection header") f.StringVar(&cfg.Proxy.ListenerAddr, "proxy.addr", Default.Proxy.ListenerAddr, "listener config") + f.StringVar(&cfg.Proxy.CertSources, "proxy.cs", Default.Proxy.CertSources, "certificate sources") f.DurationVar(&cfg.Proxy.ReadTimeout, "proxy.readtimeout", Default.Proxy.ReadTimeout, "read timeout for incoming requests") f.DurationVar(&cfg.Proxy.WriteTimeout, "proxy.writetimeout", Default.Proxy.WriteTimeout, "write timeout for outgoing responses") f.StringVar(&cfg.Metrics.Target, "metrics.target", Default.Metrics.Target, "metrics backend") @@ -95,9 +97,18 @@ func load(p *properties.Properties) (cfg *Config, err error) { var proxyTimeout time.Duration f.DurationVar(&proxyTimeout, "proxy.timeout", time.Duration(0), "deprecated") + // filter out -test flags + var args []string + for _, a := range os.Args[1:] { + if strings.HasPrefix(a, "-test.") { + continue + } + args = append(args, a) + } + // parse configuration prefixes := []string{"FABIO_", ""} - if err := f.ParseFlags(os.Args[1:], os.Environ(), prefixes, p); err != nil { + if err := f.ParseFlags(args, os.Environ(), prefixes, p); err != nil { return nil, err } @@ -108,7 +119,12 @@ func load(p *properties.Properties) (cfg *Config, err error) { cfg.Registry.Consul.Scheme, cfg.Registry.Consul.Addr = parseScheme(cfg.Registry.Consul.Addr) - cfg.Listen, err = parseListen(cfg.Proxy.ListenerAddr, cfg.Proxy.ReadTimeout, cfg.Proxy.WriteTimeout) + cfg.CertSources, err = parseCertSources(cfg.Proxy.CertSources) + if err != nil { + return nil, err + } + + cfg.Listen, err = parseListeners(cfg.Proxy.ListenerAddr, cfg.CertSources, cfg.Proxy.ReadTimeout, cfg.Proxy.WriteTimeout) if err != nil { return nil, err } @@ -141,6 +157,8 @@ func load(p *properties.Properties) (cfg *Config, err error) { return cfg, nil } +// parseScheme splits a url into scheme and address and defaults +// to "http" if no scheme was given. func parseScheme(s string) (scheme, addr string) { s = strings.ToLower(s) if strings.HasPrefix(s, "https://") { @@ -152,33 +170,178 @@ func parseScheme(s string) (scheme, addr string) { return "http", s } -func parseListen(addrs string, readTimeout, writeTimeout time.Duration) ([]Listen, error) { - listen := []Listen{} - for _, addr := range strings.Split(addrs, ",") { - addr = strings.TrimSpace(addr) - if addr == "" { +// parseKV converts a "key1=val1;key2=val2;..." string into a map. +func parseKV(cfg string) map[string]string { + m := map[string]string{} + for _, s := range strings.Split(cfg, ";") { + p := strings.SplitN(s, "=", 2) + if len(p) == 1 { + m[p[0]] = "" + } else { + m[p[0]] = p[1] + } + } + return m +} + +func parseListeners(cfgs string, cs map[string]CertSource, readTimeout, writeTimeout time.Duration) (listen []Listen, err error) { + for _, cfg := range strings.Split(cfgs, ",") { + cfg = strings.TrimSpace(cfg) + if cfg == "" { continue } - var l Listen - p := strings.Split(addr, ";") - switch len(p) { - case 1: - l.Addr = p[0] - case 2: - l.Addr, l.CertFile, l.KeyFile, l.TLS = p[0], p[1], p[1], true - case 3: - l.Addr, l.CertFile, l.KeyFile, l.TLS = p[0], p[1], p[2], true - case 4: - l.Addr, l.CertFile, l.KeyFile, l.ClientAuthFile, l.TLS = p[0], p[1], p[2], p[3], true - default: - return nil, fmt.Errorf("invalid address %s", addr) + l, err := parseListen(cfg, cs, readTimeout, writeTimeout) + if err != nil { + return nil, err } - l.ReadTimeout = readTimeout - l.WriteTimeout = writeTimeout + listen = append(listen, l) } - return listen, nil + return +} + +func parseListen(cfg string, cs map[string]CertSource, readTimeout, writeTimeout time.Duration) (l Listen, err error) { + if cfg == "" { + return Listen{}, nil + } + + opts := strings.Split(cfg, ";") + if len(opts) > 1 && !strings.Contains(opts[1], "=") { + return parseLegacyListen(cfg, readTimeout, writeTimeout) + } + + l = Listen{ + Addr: opts[0], + Scheme: "http", + ReadTimeout: readTimeout, + WriteTimeout: writeTimeout, + } + + for k, v := range parseKV(cfg) { + switch k { + case "rt": // read timeout + d, err := time.ParseDuration(v) + if err != nil { + return Listen{}, err + } + l.ReadTimeout = d + case "wt": // write timeout + d, err := time.ParseDuration(v) + if err != nil { + return Listen{}, err + } + l.WriteTimeout = d + case "cs": // cert source + c, ok := cs[v] + if !ok { + return Listen{}, fmt.Errorf("unknown certificate source %s", v) + } + l.CertSource = c + l.Scheme = "https" + } + } + return +} + +func parseLegacyListen(cfg string, readTimeout, writeTimeout time.Duration) (l Listen, err error) { + opts := strings.Split(cfg, ";") + + l = Listen{ + Addr: opts[0], + Scheme: "http", + ReadTimeout: readTimeout, + WriteTimeout: writeTimeout, + } + + if len(opts) > 1 { + l.Scheme = "https" + l.CertSource.Type = "file" + l.CertSource.CertPath = opts[1] + } + if len(opts) > 2 { + l.CertSource.KeyPath = opts[2] + } + if len(opts) > 3 { + l.CertSource.ClientCAPath = opts[3] + } + if len(opts) > 4 { + return Listen{}, fmt.Errorf("invalid listener configuration") + } + + log.Printf("[WARN] proxy.addr legacy configuration for certificates is deprecated. Use cs=path configuration") + return l, nil +} + +func parseCertSources(cfgs string) (cs map[string]CertSource, err error) { + cs = map[string]CertSource{} + for _, cfg := range strings.Split(cfgs, ",") { + cfg = strings.TrimSpace(cfg) + if cfg == "" { + continue + } + + src, err := parseCertSource(cfg) + if err != nil { + return nil, err + } + cs[src.Name] = src + } + return +} + +func parseCertSource(cfg string) (c CertSource, err error) { + if cfg == "" { + return CertSource{}, nil + } + + c.Refresh = 3 * time.Second + + for k, v := range parseKV(cfg) { + switch k { + case "cs": + c.Name = v + case "type": + c.Type = v + case "cert": + c.CertPath = v + case "key": + c.KeyPath = v + case "clientca": + c.ClientCAPath = v + case "refresh": + d, err := time.ParseDuration(v) + if err != nil { + return CertSource{}, err + } + c.Refresh = d + case "hdr": + p := strings.SplitN(v, ": ", 2) + if len(p) != 2 { + return CertSource{}, fmt.Errorf("invalid header %s", v) + } + if c.Header == nil { + c.Header = http.Header{} + } + c.Header.Set(p[0], p[1]) + } + } + if c.Name == "" { + return CertSource{}, fmt.Errorf("missing 'cs' in %s", cfg) + } + if c.Type == "" { + return CertSource{}, fmt.Errorf("missing 'type' in %s", cfg) + } + if c.CertPath == "" { + return CertSource{}, fmt.Errorf("missing 'cert' in %s", cfg) + } + if c.Type != "file" && c.Type != "path" && c.Type != "http" && c.Type != "consul" && c.Type != "vault" { + return CertSource{}, fmt.Errorf("unknown cert source type %s", c.Type) + } + if c.Type == "file" { + c.Refresh = 0 + } + return } type tags []string diff --git a/config/load_test.go b/config/load_test.go index 73efe9dae..75d076de4 100644 --- a/config/load_test.go +++ b/config/load_test.go @@ -1,6 +1,7 @@ package config import ( + "net/http" "reflect" "testing" "time" @@ -11,6 +12,7 @@ import ( func TestFromProperties(t *testing.T) { in := ` +proxy.cs = cs=name;type=path;cert=foo;clientca=bar;refresh=99s;hdr=a: b proxy.addr = :1234 proxy.localip = 4.4.4.4 proxy.strategy = rr @@ -50,6 +52,16 @@ ui.color = fonzy ui.title = fabfab ` out := &Config{ + CertSources: map[string]CertSource{ + "name": CertSource{ + Name: "name", + Type: "path", + CertPath: "foo", + ClientCAPath: "bar", + Refresh: 99 * time.Second, + Header: http.Header{"A": []string{"b"}}, + }, + }, Proxy: Proxy{ MaxConn: 666, LocalIP: "4.4.4.4", @@ -66,6 +78,7 @@ ui.title = fabfab TLSHeader: "tls", TLSHeaderValue: "tls-true", ListenerAddr: ":1234", + CertSources: "cs=name;type=path;cert=foo;clientca=bar;refresh=99s;hdr=a: b", }, Registry: Registry{ Backend: "something", @@ -92,6 +105,7 @@ ui.title = fabfab Listen: []Listen{ { Addr: ":1234", + Scheme: "http", ReadTimeout: 5 * time.Second, WriteTimeout: 10 * time.Second, }, @@ -149,59 +163,65 @@ func TestParseScheme(t *testing.T) { } } -func TestParseAddr(t *testing.T) { +func TestParseListen(t *testing.T) { + cs := map[string]CertSource{ + "name": CertSource{Type: "foo"}, + } + tests := []struct { in string - out []Listen + out Listen err string }{ { "", - []Listen{}, + Listen{}, "", }, { ":123", - []Listen{ - {Addr: ":123"}, - }, + Listen{Addr: ":123", Scheme: "http"}, "", }, { - ":123;cert.pem", - []Listen{ - {Addr: ":123", CertFile: "cert.pem", KeyFile: "cert.pem", TLS: true}, - }, + ":123;rt=5s;wt=5s", + Listen{Addr: ":123", Scheme: "http", ReadTimeout: 5 * time.Second, WriteTimeout: 5 * time.Second}, "", }, { - ":123;cert.pem;key.pem", - []Listen{ - {Addr: ":123", CertFile: "cert.pem", KeyFile: "key.pem", TLS: true}, + ":123;pathA;pathB;pathC", + Listen{ + Addr: ":123", + Scheme: "https", + CertSource: CertSource{ + Type: "file", + CertPath: "pathA", + KeyPath: "pathB", + ClientCAPath: "pathC", + }, }, "", }, { - ":123;cert.pem;key.pem;client.pem", - []Listen{ - {Addr: ":123", CertFile: "cert.pem", KeyFile: "key.pem", ClientAuthFile: "client.pem", TLS: true}, + ":123;cs=name", + Listen{ + Addr: ":123", + Scheme: "https", + CertSource: CertSource{ + Type: "foo", + }, }, "", }, - { - ":123;cert.pem;key.pem;client.pem;", - nil, - "invalid address :123;cert.pem;key.pem;client.pem;", - }, } for i, tt := range tests { - l, err := parseListen(tt.in, time.Duration(0), time.Duration(0)) + l, err := parseListen(tt.in, cs, time.Duration(0), time.Duration(0)) if got, want := err, tt.err; (got != nil || want != "") && got.Error() != want { - t.Errorf("%d: got %v want %v", i, got, want) + t.Errorf("%d: got %+v want %+v", i, got, want) } if got, want := l, tt.out; !reflect.DeepEqual(got, want) { - t.Errorf("%d: got %v want %v", i, got, want) + t.Errorf("%d: got %+v want %+v", i, got, want) } } } diff --git a/fabio.properties b/fabio.properties index 636f20571..ac8719b71 100644 --- a/fabio.properties +++ b/fabio.properties @@ -1,33 +1,172 @@ -# proxy.addr configures the HTTP and HTTPS listeners as a comma separated list. +# proxy.cs configures one or more certificate sources. # -# To configure an HTTP listener provide [host]:port. -# To configure an HTTPS listener provide [host]:port;certFile;keyFile;clientAuthFile. -# certFile and keyFile contain the public/private key pair for that listener -# in PEM format. If certFile contains both the public and private key then -# keyFile can be omittted. -# clientAuthFile contains the root CAs for client certificate validation. -# When clientAuthFile is provided the TLS configuration is set to -# RequireAndVerifyClientCert. +# Each certificate source is configured with a list of +# key/value options. Each source must have a unique +# name which can then be referred to in a listener +# configuration. # -# Configure a single HTTP listener on port 9999: +# cs=;type=;opt=arg;opt[=arg];... # +# All certificates need to be provided in PEM format. +# +# The following types of certificate sources are available: +# +# File +# +# The file certificate source supports one certificate which is loaded at +# startup and is cached until the service exits. +# +# The 'cert' option contains the path to the certificate file. The 'key' +# option contains the path to the private key file. If the certificate file +# contains both the certificate and the private key the 'key' option can be +# omitted. The 'clientca' option contains the path to one or more client +# authentication certificates. +# +# cs=;type=file;cert=p/a-cert.pem;key=p/a-key.pem;clientca=p/clientAuth.pem +# +# Path +# +# The path certificate source loads certificates from a directory in +# alphabetical order and refreshes them periodically. +# +# The 'cert' option provides the path to the TLS certificates and the +# 'clientca' option provides the path to the certificates for client +# authentication. +# +# TLS certificates are stored either in one or two files: +# +# www.example.com.pem or www.example.com-{cert,key}.pem +# +# TLS certificates are loaded in alphabetical order and the first certificate +# is the default for clients which do not support SNI. +# +# The 'refresh' option can be set to specify the refresh interval for the TLS +# certificates. Client authentication certificates cannot be refreshed since +# Go does not provide a mechanism for that yet. +# +# The default refresh interval is 3 seconds and cannot be lower than 1 second +# to prevent busy loops. To load the certificates only once and disable +# automatic refreshing set 'refresh' to zero. +# +# cs=;type=path;cert=path/to/certs;clientca=path/to/clientcas;refresh=3s +# +# HTTP +# +# The http certificate source loads certificates from an HTTP/HTTPS server. +# +# The 'cert' option provides a URL to a text file which contains all files +# that should be loaded from this directory. The filenames follow the same +# rules as for the path source. The text file can be generated with: +# +# ls -1 *.pem > list +# +# The 'clientca' option provides a URL for the client authentication +# certificates analogous to the 'cert' option. +# +# Authentication credentials can be provided in the URL as request parameter, +# as basic authentication parameters or through a header. +# +# The 'refresh' option can be set to specify the refresh interval for the TLS +# certificates. Client authentication certificates cannot be refreshed since +# Go does not provide a mechanism for that yet. +# +# The default refresh interval is 3 seconds and cannot be lower than 1 second +# to prevent busy loops. To load the certificates only once and disable +# automatic refreshing set 'refresh' to zero. +# +# cs=;type=http;cert=https://host.com/path/to/cert/list&token=123 +# cs=;type=http;cert=https://user:pass@host.com/path/to/cert/list +# cs=;type=http;cert=https://host.com/path/to/cert/list;hdr=Authorization: Bearer 1234 +# +# Consul +# +# The consul certificate source loads certificates from consul. +# +# The 'cert' option provides a KV store URL where the the TLS certificates are +# stored. +# +# The 'clientca' option provides a URL to a path in the KV store where the the +# client authentication certificates are stored. +# +# The filenames follow the same rules as for the path source. +# +# The TLS certificates are updated automatically whenever the KV store +# changes. The client authentication certificates cannot be updated +# automatically since Go does not provide a mechanism for that yet. +# +# cs=;type=consul;cert=http://localhost:8500/v1/kv/path/to/cert&token=123 +# +# Vault +# +# The Vault certificate store uses HashiCorp Vault as the certificate +# store. +# +# The 'cert' option provides the path to the TLS certificates and the +# 'clientca' option provides the path to the certificates for client +# authentication. +# +# The token must be provided in the VAULT_TOKEN environment variable. +# +# cs=;type=vault;cert=https://host:port/secret/fabio/certs +# +# Examples: +# +# # file based certificate source +# proxy.cs = cs=some-name;type=file;cert=p/a-cert.pem;key=p/a-key.pem +# +# # path based certificate source +# proxy.cs = cs=some-name;type=path;path=path/to/certs +# +# # HTTP certificate source +# proxy.cs = cs=some-name;type=http;cert=https://user:pass@host:port/path/to/certs +# +# # Consul certificate source +# proxy.cs = cs=some-name;type=consul;cert=https://host:port/v1/kv/path/to/certs?token=abc123 +# +# # Vault certificate source +# proxy.cs = cs=some-name;type=vault;cert=https://host:port/secret/fabio/certs +# +# # Multiple certificate sources +# proxy.cs = cs=srcA;type=path;path=path/to/certs,\ +# cs=srcB;type=http;cert=https://user:pass@host:port/path/to/certs +# +# The default is +# +# proxy.cs = + + +# proxy.addr configures the HTTP and HTTPS listeners. +# +# Each listener is configured with and address and a +# list of optional arguments in the form of +# +# [host]:port;opt=arg;opt[=arg];... +# +# General options: +# +# read timeout: rt= +# write timeout: wt= +# +# HTTPS listeners require a certificate source which is +# configured by setting the 'cs' option to the name of +# a certificate source. +# +# Examples: +# +# # HTTP listener on port 9999 # proxy.addr = :9999 # -# Configure both an HTTP and HTTPS listener: +# # HTTP listener on IPv4 with read timeout +# proxy.addr = 1.2.3.4:9999;rt=3s # -# proxy.addr = :9999,:443;path/to/cert.pem;path/to/key.pem;path/to/clientauth.pem +# # HTTP listener on IPv6 with write timeout +# proxy.addr = [2001:DB8::A/32]:9999;wt=5s # -# Configure multiple HTTP and HTTPS listeners on IPv4 and IPv6: +# # Multiple listeners +# proxy.addr = 1.2.3.4:9999;rt=3s,[2001:DB8::A/32]:9999;wt=5s # -# proxy.addr = \ -# 1.2.3.4:9999, \ -# 5.6.7.8:9999, \ -# [2001:DB8::A/32]:9999, \ -# [2001:DB8::B/32]:9999, \ -# 1.2.3.4:443;path/to/certA.pem;path/to/keyA.pem, \ -# 5.6.7.8:443;path/to/certB.pem;path/to/keyB.pem, \ -# [2001:DB8::A/32]:443;path/to/certA.pem;path/to/keyA.pem, \ -# [2001:DB8::B/32]:443;path/to/certB.pem;path/to/keyB.pem +# # HTTPS listener on port 443 with certificate source +# proxy.addr = :443;cs=some-name # # The default is # diff --git a/listen.go b/listen.go index 527e70f6b..6c8b67e61 100644 --- a/listen.go +++ b/listen.go @@ -2,9 +2,7 @@ package main import ( "crypto/tls" - "crypto/x509" - "errors" - "io/ioutil" + "fmt" "log" "net" "net/http" @@ -13,6 +11,7 @@ import ( "time" "github.com/armon/go-proxyproto" + "github.com/eBay/fabio/cert" "github.com/eBay/fabio/config" "github.com/eBay/fabio/exit" "github.com/eBay/fabio/proxy" @@ -45,15 +44,29 @@ func startListeners(listen []config.Listen, wait time.Duration, h http.Handler) } func listenAndServe(l config.Listen, h http.Handler) { - srv, err := newServer(l, h) - if err != nil { - log.Fatal("[FATAL] ", err) + srv := &http.Server{ + Handler: h, + Addr: l.Addr, + ReadTimeout: l.ReadTimeout, + WriteTimeout: l.WriteTimeout, + } + + if l.Scheme == "https" { + src, err := makeCertSource(l.CertSource) + if err != nil { + log.Fatal("[FATAL] ", err) + } + + srv.TLSConfig, err = cert.TLSConfig(src) + if err != nil { + log.Fatal("[FATAL] ", err) + } } if srv.TLSConfig != nil { - log.Printf("[INFO] HTTPS proxy listening on %s with certificate %s", l.Addr, l.CertFile) + log.Printf("[INFO] HTTPS proxy listening on %s", l.Addr) if srv.TLSConfig.ClientAuth == tls.RequireAndVerifyClientCert { - log.Printf("[INFO] Client certificate authentication enabled on %s with certificates from %s", l.Addr, l.ClientAuthFile) + log.Printf("[INFO] Client certificate authentication enabled on %s", l.Addr) } } else { log.Printf("[INFO] HTTP proxy listening on %s", l.Addr) @@ -64,41 +77,46 @@ func listenAndServe(l config.Listen, h http.Handler) { } } -var tlsLoadX509KeyPair = tls.LoadX509KeyPair - -func newServer(l config.Listen, h http.Handler) (*http.Server, error) { - srv := &http.Server{ - Addr: l.Addr, - Handler: h, - ReadTimeout: l.ReadTimeout, - WriteTimeout: l.WriteTimeout, - } - - if l.CertFile != "" { - cert, err := tlsLoadX509KeyPair(l.CertFile, l.KeyFile) - if err != nil { - return nil, err - } - - srv.TLSConfig = &tls.Config{ - Certificates: []tls.Certificate{cert}, - } - - if l.ClientAuthFile != "" { - pemBlock, err := ioutil.ReadFile(l.ClientAuthFile) - if err != nil { - return nil, err - } - pool := x509.NewCertPool() - if !pool.AppendCertsFromPEM(pemBlock) { - return nil, errors.New("failed to add client auth certs") - } - srv.TLSConfig.ClientCAs = pool - srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert - } +func makeCertSource(cfg config.CertSource) (cert.Source, error) { + switch cfg.Type { + case "file": + return cert.FileSource{ + CertFile: cfg.CertPath, + KeyFile: cfg.KeyPath, + ClientAuthFile: cfg.ClientCAPath, + }, nil + + case "path": + return cert.PathSource{ + CertPath: cfg.CertPath, + ClientCAPath: cfg.ClientCAPath, + Refresh: cfg.Refresh, + }, nil + + case "http": + return cert.HTTPSource{ + CertURL: cfg.CertPath, + ClientCAURL: cfg.ClientCAPath, + Refresh: cfg.Refresh, + }, nil + + case "consul": + return cert.ConsulSource{ + CertURL: cfg.CertPath, + ClientCAURL: cfg.ClientCAPath, + }, nil + + case "vault": + return cert.VaultSource{ + // TODO(fs): configure Addr but not token + CertPath: cfg.CertPath, + ClientCAPath: cfg.ClientCAPath, + Refresh: cfg.Refresh, + }, nil + + default: + return nil, fmt.Errorf("invalid certificate source %q", cfg.Type) } - - return srv, nil } func serve(srv *http.Server) error { diff --git a/listen_test.go b/listen_test.go index 270dff35b..6ca76099c 100644 --- a/listen_test.go +++ b/listen_test.go @@ -1,10 +1,8 @@ package main import ( - "crypto/tls" "net/http" "net/http/httptest" - "reflect" "sync" "testing" "time" @@ -14,48 +12,6 @@ import ( "github.com/eBay/fabio/route" ) -func TestNewServer(t *testing.T) { - h := http.DefaultServeMux - cert := tls.Certificate{} - tlsLoadX509KeyPair = func(string, string) (tls.Certificate, error) { - return cert, nil - } - defer func() { tlsLoadX509KeyPair = tls.LoadX509KeyPair }() - - tests := []struct { - in config.Listen - out *http.Server - err string - }{ - { - config.Listen{Addr: ":123"}, - &http.Server{Addr: ":123", Handler: h}, - "", - }, - { - config.Listen{Addr: ":123", CertFile: "cert.pem"}, - &http.Server{ - Addr: ":123", - Handler: h, - TLSConfig: &tls.Config{ - Certificates: []tls.Certificate{cert}, - }, - }, - "", - }, - } - - for i, tt := range tests { - srv, err := newServer(tt.in, h) - if got, want := err, tt.err; (got != nil || want != "") && got.Error() != want { - t.Errorf("%d: got %v want %v", i, got, want) - } - if got, want := srv, tt.out; !reflect.DeepEqual(got, want) { - t.Errorf("%d: got %v want %v", i, got, want) - } - } -} - func TestGracefulShutdown(t *testing.T) { req := func(url string) int { resp, err := http.Get(url)