Skip to content

Commit

Permalink
Issue #199: Refactor config loader test
Browse files Browse the repository at this point in the history
The config loader test was not testing the full
roundtrip of the command line arguments, environment
variables and config files from files and URLs.
Various smaller tests were only testing individual parts
of the config loader. This made updating and refactoring
the config loading code more vulnerable to regressions.

This refactor replaces all smaller tests with a single
test suite that tests all aspects of the config loading
in the same way as it is used by the main() function.
New tests can be added easily and consistently.

At the same time, variables which store temporary
values for the config loader have been removed
from the main config structure as they are
redundant and confusing since no part of the
application needs access to the raw argument
values.
  • Loading branch information
magiconair committed Dec 6, 2016
1 parent e259825 commit 2a2f96d
Show file tree
Hide file tree
Showing 5 changed files with 829 additions and 432 deletions.
19 changes: 6 additions & 13 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,12 @@ import (
)

type Config struct {
Proxy Proxy
Registry Registry
Listen []Listen
CertSources map[string]CertSource
Metrics Metrics
UI UI
Runtime Runtime

ListenerValue []string
CertSourcesValue []map[string]string
Proxy Proxy
Registry Registry
Listen []Listen
Metrics Metrics
UI UI
Runtime Runtime
}

type CertSource struct {
Expand Down Expand Up @@ -54,14 +50,11 @@ type Proxy struct {
DialTimeout time.Duration
ResponseHeaderTimeout time.Duration
KeepAliveTimeout time.Duration
ReadTimeout time.Duration
WriteTimeout time.Duration
FlushInterval time.Duration
LocalIP string
ClientIPHeader string
TLSHeader string
TLSHeaderValue string
GZIPContentTypesValue string
GZIPContentTypes *regexp.Regexp
}

Expand Down
15 changes: 12 additions & 3 deletions config/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,18 @@ import (
"time"
)

var Default = &Config{
ListenerValue: []string{":9999"},
var defaultValues = struct {
ListenerValue []string
CertSourcesValue []map[string]string
ReadTimeout time.Duration
WriteTimeout time.Duration
GZIPContentTypesValue string
}{
ListenerValue: []string{":9999"},
CertSourcesValue: []map[string]string{},
}

var defaultConfig = &Config{
Proxy: Proxy{
MaxConn: 10000,
Strategy: "rnd",
Expand Down Expand Up @@ -45,5 +55,4 @@ var Default = &Config{
Interval: 30 * time.Second,
CirconusAPIApp: "fabio",
},
CertSources: map[string]CertSource{},
}
260 changes: 140 additions & 120 deletions config/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"log"
"net/http"
"os"
"regexp"
"runtime"
"strings"
Expand All @@ -15,145 +14,159 @@ import (
"github.com/magiconair/properties"
)

func Load() (cfg *Config, err error) {
var path string
for i, arg := range os.Args {
if arg == "-v" || arg == "--version" {
return nil, nil
func Load(args, environ []string) (cfg *Config, err error) {
var props *properties.Properties

cmdline, path, version, err := parse(args)
switch {
case err != nil:
return nil, err
case version:
return nil, nil
case path != "":
switch {
case strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://"):
props, err = properties.LoadURL(path)
case path != "":
props, err = properties.LoadFile(path, properties.UTF8)
}
path, err = parseCfg(os.Args, i)
if err != nil {
return nil, err
}
if path != "" {
break
}
}
p, err := loadProperties(path)
if err != nil {
return nil, err
}
return load(p)
envprefix := []string{"FABIO_", ""}
return load(cmdline, environ, envprefix, props)
}

var errInvalidConfig = errors.New("invalid or missing path to config file")

func parseCfg(args []string, i int) (path string, err error) {
if len(args) == 0 || i >= len(args) || !strings.HasPrefix(args[i], "-cfg") {
return "", nil
}
arg := args[i]
if arg == "-cfg" {
if i >= len(args)-1 {
return "", errInvalidConfig
}
return args[i+1], nil
// parse extracts the version and config file flags from the command
// line arguments and returns the individual parts. Test flags are
// ignored.
func parse(args []string) (cmdline []string, path string, version bool, err error) {
if len(args) < 1 {
panic("missing exec name")
}

if !strings.HasPrefix(arg, "-cfg=") {
return "", errInvalidConfig
}
// always copy the name of the executable
cmdline = args[:1]

path = arg[len("-cfg="):]
switch {
case path == "":
return "", errInvalidConfig
case path[0] == '\'':
path = strings.Trim(path, "'")
case path[0] == '"':
path = strings.Trim(path, "\"")
}
if path == "" {
return "", errInvalidConfig
}
return path, nil
}
// parse rest of the arguments
for i := 1; i < len(args); i++ {
arg := args[i]

func loadProperties(path string) (p *properties.Properties, err error) {
if path == "" {
return properties.NewProperties(), nil
}
if strings.HasPrefix(path, "http://") || strings.HasPrefix(path, "https://") {
return properties.LoadURL(path)
switch {
// version flag
case arg == "-v" || arg == "--version":
return nil, "", true, nil

// config file without '='
case arg == "-cfg" || arg == "--cfg":
if i >= len(args)-1 {
return nil, "", false, errInvalidConfig
}
path = args[i+1]
i++

// config file with '='. needs unquoting
case strings.HasPrefix(arg, "-cfg=") || strings.HasPrefix(arg, "--cfg="):
if strings.HasPrefix(arg, "-cfg=") {
path = arg[len("-cfg="):]
} else {
path = arg[len("--cfg="):]
}
switch {
case path == "":
return nil, "", false, errInvalidConfig
case path[0] == '\'':
path = strings.Trim(path, "'")
case path[0] == '"':
path = strings.Trim(path, "\"")
}
if path == "" {
return nil, "", false, errInvalidConfig
}

// ignore test flags
case strings.HasPrefix(arg, "-test."):
continue

default:
cmdline = append(cmdline, arg)
}
}
return properties.LoadFile(path, properties.UTF8)
return cmdline, path, false, nil
}

func load(p *properties.Properties) (cfg *Config, err error) {
func load(cmdline, environ, envprefix []string, props *properties.Properties) (cfg *Config, err error) {
cfg = &Config{}

f := NewFlagSet(os.Args[0], flag.ExitOnError)
f := NewFlagSet(cmdline[0], flag.ExitOnError)

// dummy values which were parsed earlier
f.String("cfg", "", "Path or URL to config file")
f.Bool("v", false, "Show version")
f.Bool("version", false, "Show version")

// config values
f.IntVar(&cfg.Proxy.MaxConn, "proxy.maxconn", Default.Proxy.MaxConn, "maximum number of cached connections")
f.StringVar(&cfg.Proxy.Strategy, "proxy.strategy", Default.Proxy.Strategy, "load balancing strategy")
f.StringVar(&cfg.Proxy.Matcher, "proxy.matcher", Default.Proxy.Matcher, "path matching algorithm")
f.IntVar(&cfg.Proxy.NoRouteStatus, "proxy.noroutestatus", Default.Proxy.NoRouteStatus, "status code for invalid route")
f.DurationVar(&cfg.Proxy.ShutdownWait, "proxy.shutdownwait", Default.Proxy.ShutdownWait, "time for graceful shutdown")
f.DurationVar(&cfg.Proxy.DialTimeout, "proxy.dialtimeout", Default.Proxy.DialTimeout, "connection timeout for backend connections")
f.DurationVar(&cfg.Proxy.ResponseHeaderTimeout, "proxy.responseheadertimeout", Default.Proxy.ResponseHeaderTimeout, "response header timeout")
f.DurationVar(&cfg.Proxy.KeepAliveTimeout, "proxy.keepalivetimeout", Default.Proxy.KeepAliveTimeout, "keep-alive timeout")
f.StringVar(&cfg.Proxy.LocalIP, "proxy.localip", Default.Proxy.LocalIP, "fabio address in Forward headers")
f.StringVar(&cfg.Proxy.ClientIPHeader, "proxy.header.clientip", Default.Proxy.ClientIPHeader, "header for the request ip")
f.StringVar(&cfg.Proxy.TLSHeader, "proxy.header.tls", Default.Proxy.TLSHeader, "header for TLS connections")
f.StringVar(&cfg.Proxy.TLSHeaderValue, "proxy.header.tls.value", Default.Proxy.TLSHeaderValue, "value for TLS connection header")
f.StringVar(&cfg.Proxy.GZIPContentTypesValue, "proxy.gzip.contenttype", Default.Proxy.GZIPContentTypesValue, "regexp of content types to compress")
f.StringSliceVar(&cfg.ListenerValue, "proxy.addr", Default.ListenerValue, "listener config")
f.KVSliceVar(&cfg.CertSourcesValue, "proxy.cs", Default.CertSourcesValue, "certificate sources")
f.DurationVar(&cfg.Proxy.ReadTimeout, "proxy.readtimeout", Default.Proxy.ReadTimeout, "read timeout for incoming requests")
f.DurationVar(&cfg.Proxy.WriteTimeout, "proxy.writetimeout", Default.Proxy.WriteTimeout, "write timeout for outgoing responses")
f.DurationVar(&cfg.Proxy.FlushInterval, "proxy.flushinterval", Default.Proxy.FlushInterval, "flush interval for streaming responses")
f.StringVar(&cfg.Metrics.Target, "metrics.target", Default.Metrics.Target, "metrics backend")
f.StringVar(&cfg.Metrics.Prefix, "metrics.prefix", Default.Metrics.Prefix, "prefix for reported metrics")
f.StringVar(&cfg.Metrics.Names, "metrics.names", Default.Metrics.Names, "route metric name template")
f.DurationVar(&cfg.Metrics.Interval, "metrics.interval", Default.Metrics.Interval, "metrics reporting interval")
f.StringVar(&cfg.Metrics.GraphiteAddr, "metrics.graphite.addr", Default.Metrics.GraphiteAddr, "graphite server address")
f.StringVar(&cfg.Metrics.StatsDAddr, "metrics.statsd.addr", Default.Metrics.StatsDAddr, "statsd server address")
f.StringVar(&cfg.Metrics.CirconusAPIKey, "metrics.circonus.apikey", Default.Metrics.CirconusAPIKey, "Circonus API token key")
f.StringVar(&cfg.Metrics.CirconusAPIApp, "metrics.circonus.apiapp", Default.Metrics.CirconusAPIApp, "Circonus API token app")
f.StringVar(&cfg.Metrics.CirconusAPIURL, "metrics.circonus.apiurl", Default.Metrics.CirconusAPIURL, "Circonus API URL")
f.StringVar(&cfg.Metrics.CirconusBrokerID, "metrics.circonus.brokerid", Default.Metrics.CirconusBrokerID, "Circonus Broker ID")
f.StringVar(&cfg.Metrics.CirconusCheckID, "metrics.circonus.checkid", Default.Metrics.CirconusCheckID, "Circonus Check ID")
f.StringVar(&cfg.Registry.Backend, "registry.backend", Default.Registry.Backend, "registry backend")
f.StringVar(&cfg.Registry.File.Path, "registry.file.path", Default.Registry.File.Path, "path to file based routing table")
f.StringVar(&cfg.Registry.Static.Routes, "registry.static.routes", Default.Registry.Static.Routes, "static routes")
f.StringVar(&cfg.Registry.Consul.Addr, "registry.consul.addr", Default.Registry.Consul.Addr, "address of the consul agent")
f.StringVar(&cfg.Registry.Consul.Token, "registry.consul.token", Default.Registry.Consul.Token, "token for consul agent")
f.StringVar(&cfg.Registry.Consul.KVPath, "registry.consul.kvpath", Default.Registry.Consul.KVPath, "consul KV path for manual overrides")
f.StringVar(&cfg.Registry.Consul.TagPrefix, "registry.consul.tagprefix", Default.Registry.Consul.TagPrefix, "prefix for consul tags")
f.BoolVar(&cfg.Registry.Consul.Register, "registry.consul.register.enabled", Default.Registry.Consul.Register, "register fabio in consul")
f.StringVar(&cfg.Registry.Consul.ServiceAddr, "registry.consul.register.addr", Default.Registry.Consul.ServiceAddr, "service registration address")
f.StringVar(&cfg.Registry.Consul.ServiceName, "registry.consul.register.name", Default.Registry.Consul.ServiceName, "service registration name")
f.StringSliceVar(&cfg.Registry.Consul.ServiceTags, "registry.consul.register.tags", Default.Registry.Consul.ServiceTags, "service registration tags")
f.StringSliceVar(&cfg.Registry.Consul.ServiceStatus, "registry.consul.service.status", Default.Registry.Consul.ServiceStatus, "valid service status values")
f.DurationVar(&cfg.Registry.Consul.CheckInterval, "registry.consul.register.checkInterval", Default.Registry.Consul.CheckInterval, "service check interval")
f.DurationVar(&cfg.Registry.Consul.CheckTimeout, "registry.consul.register.checkTimeout", Default.Registry.Consul.CheckTimeout, "service check timeout")
f.IntVar(&cfg.Runtime.GOGC, "runtime.gogc", Default.Runtime.GOGC, "sets runtime.GOGC")
f.IntVar(&cfg.Runtime.GOMAXPROCS, "runtime.gomaxprocs", Default.Runtime.GOMAXPROCS, "sets runtime.GOMAXPROCS")
f.StringVar(&cfg.UI.Addr, "ui.addr", Default.UI.Addr, "address the UI/API is listening on")
f.StringVar(&cfg.UI.Color, "ui.color", Default.UI.Color, "background color of the UI")
f.StringVar(&cfg.UI.Title, "ui.title", Default.UI.Title, "optional title for the UI")
var listenerValue []string
var certSourcesValue []map[string]string
var readTimeout, writeTimeout time.Duration
var gzipContentTypesValue string

f.IntVar(&cfg.Proxy.MaxConn, "proxy.maxconn", defaultConfig.Proxy.MaxConn, "maximum number of cached connections")
f.StringVar(&cfg.Proxy.Strategy, "proxy.strategy", defaultConfig.Proxy.Strategy, "load balancing strategy")
f.StringVar(&cfg.Proxy.Matcher, "proxy.matcher", defaultConfig.Proxy.Matcher, "path matching algorithm")
f.IntVar(&cfg.Proxy.NoRouteStatus, "proxy.noroutestatus", defaultConfig.Proxy.NoRouteStatus, "status code for invalid route")
f.DurationVar(&cfg.Proxy.ShutdownWait, "proxy.shutdownwait", defaultConfig.Proxy.ShutdownWait, "time for graceful shutdown")
f.DurationVar(&cfg.Proxy.DialTimeout, "proxy.dialtimeout", defaultConfig.Proxy.DialTimeout, "connection timeout for backend connections")
f.DurationVar(&cfg.Proxy.ResponseHeaderTimeout, "proxy.responseheadertimeout", defaultConfig.Proxy.ResponseHeaderTimeout, "response header timeout")
f.DurationVar(&cfg.Proxy.KeepAliveTimeout, "proxy.keepalivetimeout", defaultConfig.Proxy.KeepAliveTimeout, "keep-alive timeout")
f.StringVar(&cfg.Proxy.LocalIP, "proxy.localip", defaultConfig.Proxy.LocalIP, "fabio address in Forward headers")
f.StringVar(&cfg.Proxy.ClientIPHeader, "proxy.header.clientip", defaultConfig.Proxy.ClientIPHeader, "header for the request ip")
f.StringVar(&cfg.Proxy.TLSHeader, "proxy.header.tls", defaultConfig.Proxy.TLSHeader, "header for TLS connections")
f.StringVar(&cfg.Proxy.TLSHeaderValue, "proxy.header.tls.value", defaultConfig.Proxy.TLSHeaderValue, "value for TLS connection header")
f.StringVar(&gzipContentTypesValue, "proxy.gzip.contenttype", defaultValues.GZIPContentTypesValue, "regexp of content types to compress")
f.StringSliceVar(&listenerValue, "proxy.addr", defaultValues.ListenerValue, "listener config")
f.KVSliceVar(&certSourcesValue, "proxy.cs", defaultValues.CertSourcesValue, "certificate sources")
f.DurationVar(&readTimeout, "proxy.readtimeout", defaultValues.ReadTimeout, "read timeout for incoming requests")
f.DurationVar(&writeTimeout, "proxy.writetimeout", defaultValues.WriteTimeout, "write timeout for outgoing responses")
f.DurationVar(&cfg.Proxy.FlushInterval, "proxy.flushinterval", defaultConfig.Proxy.FlushInterval, "flush interval for streaming responses")
f.StringVar(&cfg.Metrics.Target, "metrics.target", defaultConfig.Metrics.Target, "metrics backend")
f.StringVar(&cfg.Metrics.Prefix, "metrics.prefix", defaultConfig.Metrics.Prefix, "prefix for reported metrics")
f.StringVar(&cfg.Metrics.Names, "metrics.names", defaultConfig.Metrics.Names, "route metric name template")
f.DurationVar(&cfg.Metrics.Interval, "metrics.interval", defaultConfig.Metrics.Interval, "metrics reporting interval")
f.StringVar(&cfg.Metrics.GraphiteAddr, "metrics.graphite.addr", defaultConfig.Metrics.GraphiteAddr, "graphite server address")
f.StringVar(&cfg.Metrics.StatsDAddr, "metrics.statsd.addr", defaultConfig.Metrics.StatsDAddr, "statsd server address")
f.StringVar(&cfg.Metrics.CirconusAPIKey, "metrics.circonus.apikey", defaultConfig.Metrics.CirconusAPIKey, "Circonus API token key")
f.StringVar(&cfg.Metrics.CirconusAPIApp, "metrics.circonus.apiapp", defaultConfig.Metrics.CirconusAPIApp, "Circonus API token app")
f.StringVar(&cfg.Metrics.CirconusAPIURL, "metrics.circonus.apiurl", defaultConfig.Metrics.CirconusAPIURL, "Circonus API URL")
f.StringVar(&cfg.Metrics.CirconusBrokerID, "metrics.circonus.brokerid", defaultConfig.Metrics.CirconusBrokerID, "Circonus Broker ID")
f.StringVar(&cfg.Metrics.CirconusCheckID, "metrics.circonus.checkid", defaultConfig.Metrics.CirconusCheckID, "Circonus Check ID")
f.StringVar(&cfg.Registry.Backend, "registry.backend", defaultConfig.Registry.Backend, "registry backend")
f.StringVar(&cfg.Registry.File.Path, "registry.file.path", defaultConfig.Registry.File.Path, "path to file based routing table")
f.StringVar(&cfg.Registry.Static.Routes, "registry.static.routes", defaultConfig.Registry.Static.Routes, "static routes")
f.StringVar(&cfg.Registry.Consul.Addr, "registry.consul.addr", defaultConfig.Registry.Consul.Addr, "address of the consul agent")
f.StringVar(&cfg.Registry.Consul.Token, "registry.consul.token", defaultConfig.Registry.Consul.Token, "token for consul agent")
f.StringVar(&cfg.Registry.Consul.KVPath, "registry.consul.kvpath", defaultConfig.Registry.Consul.KVPath, "consul KV path for manual overrides")
f.StringVar(&cfg.Registry.Consul.TagPrefix, "registry.consul.tagprefix", defaultConfig.Registry.Consul.TagPrefix, "prefix for consul tags")
f.BoolVar(&cfg.Registry.Consul.Register, "registry.consul.register.enabled", defaultConfig.Registry.Consul.Register, "register fabio in consul")
f.StringVar(&cfg.Registry.Consul.ServiceAddr, "registry.consul.register.addr", defaultConfig.Registry.Consul.ServiceAddr, "service registration address")
f.StringVar(&cfg.Registry.Consul.ServiceName, "registry.consul.register.name", defaultConfig.Registry.Consul.ServiceName, "service registration name")
f.StringSliceVar(&cfg.Registry.Consul.ServiceTags, "registry.consul.register.tags", defaultConfig.Registry.Consul.ServiceTags, "service registration tags")
f.StringSliceVar(&cfg.Registry.Consul.ServiceStatus, "registry.consul.service.status", defaultConfig.Registry.Consul.ServiceStatus, "valid service status values")
f.DurationVar(&cfg.Registry.Consul.CheckInterval, "registry.consul.register.checkInterval", defaultConfig.Registry.Consul.CheckInterval, "service check interval")
f.DurationVar(&cfg.Registry.Consul.CheckTimeout, "registry.consul.register.checkTimeout", defaultConfig.Registry.Consul.CheckTimeout, "service check timeout")
f.IntVar(&cfg.Runtime.GOGC, "runtime.gogc", defaultConfig.Runtime.GOGC, "sets runtime.GOGC")
f.IntVar(&cfg.Runtime.GOMAXPROCS, "runtime.gomaxprocs", defaultConfig.Runtime.GOMAXPROCS, "sets runtime.GOMAXPROCS")
f.StringVar(&cfg.UI.Addr, "ui.addr", defaultConfig.UI.Addr, "address the UI/API is listening on")
f.StringVar(&cfg.UI.Color, "ui.color", defaultConfig.UI.Color, "background color of the UI")
f.StringVar(&cfg.UI.Title, "ui.title", defaultConfig.UI.Title, "optional title for the UI")

var awsApiGWCertCN string
f.StringVar(&awsApiGWCertCN, "aws.apigw.cert.cn", "", "deprecated. use caupgcn=<CN> for cert source")

// filter out -test flags
var args []string
for _, a := range os.Args[1:] {
if strings.HasPrefix(a, "-test.") {
continue
}
args = append(args, a)
}

// parse configuration
prefixes := []string{"FABIO_", ""}
if err := f.ParseFlags(args, os.Environ(), prefixes, p); err != nil {
if err := f.ParseFlags(cmdline[1:], environ, envprefix, props); err != nil {
return nil, err
}

Expand All @@ -164,18 +177,18 @@ func load(p *properties.Properties) (cfg *Config, err error) {

cfg.Registry.Consul.Scheme, cfg.Registry.Consul.Addr = parseScheme(cfg.Registry.Consul.Addr)

cfg.CertSources, err = parseCertSources(cfg.CertSourcesValue)
certSources, err := parseCertSources(certSourcesValue)
if err != nil {
return nil, err
}

cfg.Listen, err = parseListeners(cfg.ListenerValue, cfg.CertSources, cfg.Proxy.ReadTimeout, cfg.Proxy.WriteTimeout)
cfg.Listen, err = parseListeners(listenerValue, certSources, readTimeout, writeTimeout)
if err != nil {
return nil, err
}

if cfg.Proxy.GZIPContentTypesValue != "" {
cfg.Proxy.GZIPContentTypes, err = regexp.Compile(cfg.Proxy.GZIPContentTypesValue)
if gzipContentTypesValue != "" {
cfg.Proxy.GZIPContentTypes, err = regexp.Compile(gzipContentTypesValue)
if err != nil {
return nil, fmt.Errorf("invalid expression for content types: %s", err)
}
Expand All @@ -195,13 +208,20 @@ func load(p *properties.Properties) (cfg *Config, err error) {
// to "http" if no scheme was given.
func parseScheme(s string) (scheme, addr string) {
s = strings.ToLower(s)
if strings.HasPrefix(s, "https://") {
return "https", s[len("https://"):]
switch {
case strings.HasPrefix(s, "https://"):
scheme, addr = "https", s[len("https://"):]
case strings.HasPrefix(s, "http://"):
scheme, addr = "http", s[len("http://"):]
default:
scheme, addr = "http", s
}
if strings.HasPrefix(s, "http://") {
return "http", s[len("http://"):]

// strip off anything after a final slash
if n := strings.Index(addr, "/"); n >= 0 {
addr = addr[:n]
}
return "http", s
return
}

func parseListeners(cfgs []string, cs map[string]CertSource, readTimeout, writeTimeout time.Duration) (listen []Listen, err error) {
Expand Down
Loading

0 comments on commit 2a2f96d

Please sign in to comment.