Skip to content

Commit

Permalink
fix: Fix http basic flag.
Browse files Browse the repository at this point in the history
It seems viper function RegisterAlias() does not works as expected, the GetStringSlice() function does not returns anything for any flag (the aliased as well as the alias).
So the fix consist of not using RegisterAlias() and simply handle both flags.
Also the PR move some bootstrap code from cmd/container.go and cmd/root.go inside package cmd/internal/http_basic.go.
Now, all code related to http basic feature is located there.
  • Loading branch information
gfyrag committed Jun 27, 2022
1 parent 8ccf42a commit 8a7d762
Show file tree
Hide file tree
Showing 6 changed files with 231 additions and 30 deletions.
14 changes: 3 additions & 11 deletions cmd/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"github.com/numary/go-libs/sharedpublish"
"github.com/numary/go-libs/sharedpublish/sharedpublishhttp"
"github.com/numary/go-libs/sharedpublish/sharedpublishkafka"
"github.com/numary/ledger/cmd/internal"
"github.com/numary/ledger/pkg/analytics"
"github.com/numary/ledger/pkg/api"
"github.com/numary/ledger/pkg/api/middlewares"
Expand Down Expand Up @@ -248,17 +249,8 @@ func NewContainer(v *viper.Viper, userOptions ...fx.Option) *fx.App {
res := make([]gin.HandlerFunc, 0)

methods := make([]sharedauth.Method, 0)
if basicAuth := v.GetStringSlice(authBasicCredentialsFlag); len(basicAuth) > 0 &&
(!v.IsSet(authBasicEnabledFlag) || v.GetBool(authBasicEnabledFlag)) { // Keep compatibility, we disable the feature only if the flag is explicitely set to false
credentials := sharedauth.Credentials{}
for _, kv := range basicAuth {
parts := strings.SplitN(kv, ":", 2)
credentials[parts[0]] = sharedauth.Credential{
Password: parts[1],
Scopes: routes.AllScopes,
}
}
methods = append(methods, sharedauth.NewHTTPBasicMethod(credentials))
if httpBasicMethod := internal.HTTPBasicAuthMethod(v); httpBasicMethod != nil {
methods = append(methods, httpBasicMethod)
}
if v.GetBool(authBearerEnabledFlag) {
methods = append(methods, sharedauth.NewHttpBearerMethod(
Expand Down
3 changes: 2 additions & 1 deletion cmd/doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"
"text/tabwriter"

"github.com/numary/ledger/cmd/internal"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
Expand Down Expand Up @@ -43,7 +44,7 @@ func NewDocFlagCommand() *cobra.Command {
panic(err)
}
for _, key := range allKeys {
asEnvVar := strings.ToUpper(replacer.Replace(key))
asEnvVar := strings.ToUpper(internal.EnvVarReplacer.Replace(key))
flag := cmd.Parent().Parent().PersistentFlags().Lookup(key)
if flag == nil {
continue
Expand Down
19 changes: 19 additions & 0 deletions cmd/internal/env.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package internal

import (
"strings"

"github.com/spf13/viper"
)

const (
envPrefix = "numary"
)

var EnvVarReplacer = strings.NewReplacer(".", "_", "-", "_")

func BindEnv(v *viper.Viper) {
v.SetEnvPrefix(envPrefix)
v.SetEnvKeyReplacer(EnvVarReplacer)
v.AutomaticEnv()
}
42 changes: 42 additions & 0 deletions cmd/internal/http_basic.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package internal

import (
"strings"

"github.com/numary/go-libs/sharedauth"
"github.com/numary/ledger/pkg/api/routes"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)

const (
serverHttpBasicAuthFlag = "server.http.basic_auth"
authBasicEnabledFlag = "auth-basic-enabled"
authBasicCredentialsFlag = "auth-basic-credentials"
)

func HTTPBasicAuthMethod(v *viper.Viper) sharedauth.Method {
basicAuth := v.GetStringSlice(serverHttpBasicAuthFlag)
if len(basicAuth) == 0 {
basicAuth = v.GetStringSlice(authBasicCredentialsFlag)
}
if len(basicAuth) > 0 &&
(!v.IsSet(authBasicEnabledFlag) || v.GetBool(authBasicEnabledFlag)) { // Keep compatibility, we disable the feature only if the flag is explicitely set to false
credentials := sharedauth.Credentials{}
for _, kv := range basicAuth {
parts := strings.SplitN(kv, ":", 2)
credentials[parts[0]] = sharedauth.Credential{
Password: parts[1],
Scopes: routes.AllScopes,
}
}
return sharedauth.NewHTTPBasicMethod(credentials)
}
return nil
}

func InitHTTPBasicFlags(cmd *cobra.Command) {
cmd.PersistentFlags().Bool(authBasicEnabledFlag, false, "Enable basic auth")
cmd.PersistentFlags().StringSlice(authBasicCredentialsFlag, []string{}, "HTTP basic auth credentials (<username>:<password>)")
cmd.PersistentFlags().String(serverHttpBasicAuthFlag, "", "Http basic auth")
}
155 changes: 155 additions & 0 deletions cmd/internal/http_basic_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
package internal

import (
"fmt"
"os"
"reflect"
"strings"
"testing"

"github.com/numary/go-libs/sharedauth"
"github.com/spf13/cobra"
"github.com/spf13/viper"
"github.com/stretchr/testify/require"
)

func withPrefix(flag string) string {
return strings.ToUpper(fmt.Sprintf("%s_%s", envPrefix, EnvVarReplacer.Replace(flag)))
}

func setEnvVar(key, value string) func() {
prefixedFlag := withPrefix(key)
oldEnv := os.Getenv(prefixedFlag)
os.Setenv(prefixedFlag, value)
return func() {
os.Setenv(prefixedFlag, oldEnv)
}
}

func TestViperEnvBinding(t *testing.T) {

type testCase struct {
name string
key string
envValue string
viperMethod interface{}
expectedValue interface{}
}

for _, testCase := range []testCase{
{
name: "using deprecated credentials flag",
key: serverHttpBasicAuthFlag,
envValue: "foo:bar",
viperMethod: (*viper.Viper).GetString,
expectedValue: "foo:bar",
},
{
name: "using credentials flag",
key: authBasicCredentialsFlag,
envValue: "foo:bar",
viperMethod: (*viper.Viper).GetStringSlice,
expectedValue: []string{"foo:bar"},
},
{
name: "using http basic enabled flags",
key: authBasicEnabledFlag,
envValue: "true",
viperMethod: (*viper.Viper).GetBool,
expectedValue: true,
},
} {
t.Run(testCase.name, func(t *testing.T) {
v := viper.GetViper()
cmd := &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
ret := reflect.ValueOf(testCase.viperMethod).Call([]reflect.Value{
reflect.ValueOf(v),
reflect.ValueOf(testCase.key),
})
require.Len(t, ret, 1)

rValue := ret[0].Interface()
require.Equal(t, testCase.expectedValue, rValue)
},
}
InitHTTPBasicFlags(cmd)
BindEnv(v)

restoreEnvVar := setEnvVar(testCase.key, testCase.envValue)
defer restoreEnvVar()

require.NoError(t, v.BindPFlags(cmd.PersistentFlags()))

require.NoError(t, cmd.Execute())
})
}
}

func TestHTTPBasicAuthMethod(t *testing.T) {

type testCase struct {
name string
args []string
expectedBasicAuthMethod bool
}

for _, testCase := range []testCase{
{
name: "no flag defined",
args: []string{},
expectedBasicAuthMethod: false,
},
{
name: "with latest credentials flag",
args: []string{
fmt.Sprintf("--%s=%s", authBasicCredentialsFlag, "foo:bar"),
},
expectedBasicAuthMethod: true,
},
{
name: "with deprecated credentials flag",
args: []string{
fmt.Sprintf("--%s=%s", serverHttpBasicAuthFlag, "foo:bar"),
},
expectedBasicAuthMethod: true,
},
{
name: "with enabled flag set to false",
args: []string{
fmt.Sprintf("--%s=%s", serverHttpBasicAuthFlag, "foo:bar"),
fmt.Sprintf("--%s=false", authBasicEnabledFlag),
},
expectedBasicAuthMethod: false,
},
{
name: "with enabled flag set to true",
args: []string{
fmt.Sprintf("--%s=%s", serverHttpBasicAuthFlag, "foo:bar"),
fmt.Sprintf("--%s=true", authBasicEnabledFlag),
},
expectedBasicAuthMethod: true,
},
} {
t.Run(testCase.name, func(t *testing.T) {
var method sharedauth.Method
cmd := &cobra.Command{
RunE: func(cmd *cobra.Command, args []string) error {
method = HTTPBasicAuthMethod(viper.GetViper())
return nil
},
}
InitHTTPBasicFlags(cmd)
require.NoError(t, viper.BindPFlags(cmd.PersistentFlags()))

cmd.SetArgs(testCase.args)

require.NoError(t, cmd.Execute())
if testCase.expectedBasicAuthMethod {
require.NotNil(t, method)
} else {
require.Nil(t, method)
}
})
}
}
28 changes: 10 additions & 18 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ import (
"fmt"
"os"
"path"
"strings"
"time"

"github.com/numary/ledger/cmd/internal"
"github.com/numary/ledger/pkg/redis"
"github.com/pkg/errors"
"github.com/spf13/cobra"
Expand All @@ -23,7 +23,7 @@ const (
serverHttpBindAddressFlag = "server.http.bind_address"
uiHttpBindAddressFlag = "ui.http.bind_address"
// Deprecated
serverHttpBasicAuthFlag = "server.http.basic_auth"

lockStrategyFlag = "lock-strategy"
lockStrategyRedisUrlFlag = "lock-strategy-redis-url"
lockStrategyRedisDurationFlag = "lock-strategy-redis-duration"
Expand Down Expand Up @@ -54,13 +54,12 @@ const (
publisherKafkaTLSEnabled = "publisher-kafka-tls-enabled"
publisherTopicMappingFlag = "publisher-topic-mapping"
publisherHttpEnabledFlag = "publisher-http-enabled"
authBasicEnabledFlag = "auth-basic-enabled"
authBasicCredentialsFlag = "auth-basic-credentials"
authBearerEnabledFlag = "auth-bearer-enabled"
authBearerIntrospectUrlFlag = "auth-bearer-introspect-url"
authBearerAudienceFlag = "auth-bearer-audience"
authBearerAudiencesWildcardFlag = "auth-bearer-audiences-wildcard"
authBearerUseScopesFlag = "auth-bearer-use-scopes"

authBearerEnabledFlag = "auth-bearer-enabled"
authBearerIntrospectUrlFlag = "auth-bearer-introspect-url"
authBearerAudienceFlag = "auth-bearer-audience"
authBearerAudiencesWildcardFlag = "auth-bearer-audiences-wildcard"
authBearerUseScopesFlag = "auth-bearer-use-scopes"

segmentEnabledFlag = "segment-enabled"
segmentWriteKey = "segment-write-key"
Expand All @@ -73,8 +72,6 @@ var (
BuildDate = "-"
Commit = "-"
DefaultSegmentWriteKey = ""

replacer = strings.NewReplacer(".", "_", "-", "_")
)

func NewRootCommand() *cobra.Command {
Expand Down Expand Up @@ -134,7 +131,6 @@ func NewRootCommand() *cobra.Command {
root.PersistentFlags().Bool(storageCacheFlag, true, "Storage cache")
root.PersistentFlags().String(serverHttpBindAddressFlag, "localhost:3068", "API bind address")
root.PersistentFlags().String(uiHttpBindAddressFlag, "localhost:3068", "UI bind address")
root.PersistentFlags().String(serverHttpBasicAuthFlag, "", "Http basic auth")
root.PersistentFlags().Bool(otelTracesFlag, false, "Enable OpenTelemetry traces support")
root.PersistentFlags().Bool(otelTracesBatchFlag, false, "Use OpenTelemetry batching")
root.PersistentFlags().String(otelTracesExporterFlag, "stdout", "OpenTelemetry traces exporter")
Expand Down Expand Up @@ -165,8 +161,6 @@ func NewRootCommand() *cobra.Command {
root.PersistentFlags().String(publisherKafkaSASLMechanism, "", "SASL authentication mechanism")
root.PersistentFlags().Int(publisherKafkaSASLScramSHASize, 512, "SASL SCRAM SHA size")
root.PersistentFlags().Bool(publisherKafkaTLSEnabled, false, "Enable TLS to connect on kafka")
root.PersistentFlags().Bool(authBasicEnabledFlag, false, "Enable basic auth")
root.PersistentFlags().StringSlice(authBasicCredentialsFlag, []string{}, "HTTP basic auth credentials (<username>:<password>)")
root.PersistentFlags().Bool(authBearerEnabledFlag, false, "Enable bearer auth")
root.PersistentFlags().String(authBearerIntrospectUrlFlag, "", "OAuth2 introspect URL")
root.PersistentFlags().StringSlice(authBearerAudienceFlag, []string{}, "Allowed audiences")
Expand All @@ -177,7 +171,7 @@ func NewRootCommand() *cobra.Command {
root.PersistentFlags().String(segmentWriteKey, DefaultSegmentWriteKey, "Segment write key")
root.PersistentFlags().Duration(segmentHeartbeatInterval, 24*time.Hour, "Segment heartbeat interval")

viper.RegisterAlias(serverHttpBasicAuthFlag, authBasicCredentialsFlag)
internal.InitHTTPBasicFlags(root)

if err = viper.BindPFlags(root.PersistentFlags()); err != nil {
panic(err)
Expand All @@ -191,9 +185,7 @@ func NewRootCommand() *cobra.Command {
fmt.Printf("loading config file: %s\n", err)
}

viper.SetEnvPrefix("numary")
viper.SetEnvKeyReplacer(replacer)
viper.AutomaticEnv()
internal.BindEnv(viper.GetViper())

return root
}
Expand Down

0 comments on commit 8a7d762

Please sign in to comment.