diff --git a/viper.go b/viper.go index 7a49a0a9f..8a619ea8d 100644 --- a/viper.go +++ b/viper.go @@ -605,15 +605,35 @@ func (v *Viper) GetSizeInBytes(key string) uint { // Takes a single key and unmarshals it into a Struct func UnmarshalKey(key string, rawVal interface{}) error { return v.UnmarshalKey(key, rawVal) } func (v *Viper) UnmarshalKey(key string, rawVal interface{}) error { - return mapstructure.Decode(v.Get(key), rawVal) + config := &mapstructure.DecoderConfig{ + Metadata: nil, + Result: rawVal, + DecodeHook: mapstructure.StringToTimeDurationHookFunc(), + } + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + return decoder.Decode(v.Get(key)) } // Unmarshals the config into a Struct. Make sure that the tags // on the fields of the structure are properly set. func Unmarshal(rawVal interface{}) error { return v.Unmarshal(rawVal) } func (v *Viper) Unmarshal(rawVal interface{}) error { - err := mapstructure.WeakDecode(v.AllSettings(), rawVal) + config := &mapstructure.DecoderConfig{ + Metadata: nil, + Result: rawVal, + WeaklyTypedInput: true, + DecodeHook: mapstructure.StringToTimeDurationHookFunc(), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + err = decoder.Decode(v.AllSettings()) if err != nil { return err } @@ -631,6 +651,7 @@ func weakDecodeExact(input, output interface{}) error { Metadata: nil, Result: output, WeaklyTypedInput: true, + DecodeHook: mapstructure.StringToTimeDurationHookFunc(), } decoder, err := mapstructure.NewDecoder(config) diff --git a/viper_test.go b/viper_test.go index 858caff2f..b32c06abd 100644 --- a/viper_test.go +++ b/viper_test.go @@ -453,10 +453,12 @@ func TestRecursiveAliases(t *testing.T) { func TestUnmarshal(t *testing.T) { SetDefault("port", 1313) Set("name", "Steve") + Set("duration", "10s") type config struct { - Port int - Name string + Port int + Name string + Duration time.Duration } var C config @@ -466,14 +468,15 @@ func TestUnmarshal(t *testing.T) { t.Fatalf("unable to decode into struct, %v", err) } - assert.Equal(t, &C, &config{Name: "Steve", Port: 1313}) + assert.Equal(t, &config{Name: "Steve", Port: 1313, Duration: 10 * time.Second}, &C) Set("port", 1234) + Set("duration", "20m") err = Unmarshal(&C) if err != nil { t.Fatalf("unable to decode into struct, %v", err) } - assert.Equal(t, &C, &config{Name: "Steve", Port: 1234}) + assert.Equal(t, &config{Name: "Steve", Port: 1234, Duration: 20 * time.Minute}, &C) } func TestBindPFlags(t *testing.T) {