diff --git a/internal/app/app.go b/internal/app/app.go index 61eef724..0fb081a7 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -1,81 +1,103 @@ package app import ( - "context" + "log" + "net/http" "os" "os/signal" - "sync" "syscall" + "time" runtime "github.com/banzaicloud/logrus-runtime-formatter" "github.com/metal-toolbox/flasher/internal/model" + "github.com/pkg/errors" "github.com/sirupsen/logrus" + "github.com/spf13/viper" + + // nolint:gosec // pprof path is only exposed over localhost + _ "net/http/pprof" +) + +var ( + ErrAppInit = errors.New("error initializing app") +) + +const ( + ProfilingEndpoint = "localhost:9091" ) // Config holds configuration data when running mctl // App holds attributes for the mtl application type App struct { - // Sync waitgroup to wait for running go routines on termination. - SyncWG *sync.WaitGroup + // Viper loads configuration parameters. + v *viper.Viper // Flasher configuration. - Config *model.Config - // TermCh is the channel to terminate the app based on a signal - TermCh chan os.Signal + Config *Configuration // Logger is the app logger Logger *logrus.Logger + // Kind is the type of application - worker + Kind model.AppKind } // New returns returns a new instance of the flasher app -func New(ctx context.Context, appKind model.AppKind, inventorySourceKind, cfgFile string, loglevel int) (*App, error) { - // load configuration - cfg := &model.Config{ - AppKind: appKind, - InventorySource: inventorySourceKind, - } - - if err := cfg.Load(cfgFile); err != nil { - return nil, err +func New(appKind model.AppKind, storeKind model.StoreKind, cfgFile, loglevel string, profiling bool) (*App, <-chan os.Signal, error) { + if appKind != model.AppKindWorker { + return nil, nil, errors.Wrap(ErrAppInit, "invalid app kind: "+string(appKind)) } app := &App{ - Config: cfg, - SyncWG: &sync.WaitGroup{}, + v: viper.New(), + Kind: appKind, + Config: &Configuration{}, Logger: logrus.New(), - TermCh: make(chan os.Signal), } - runtimeFormatter := &runtime.Formatter{ - ChildFormatter: &logrus.JSONFormatter{}, - File: true, - Line: true, - BaseNameOnly: true, + if err := app.LoadConfiguration(cfgFile, storeKind); err != nil { + return nil, nil, err } - // set log level, format - switch loglevel { + switch model.LogLevel(loglevel) { case model.LogLevelDebug: app.Logger.Level = logrus.DebugLevel - - // set runtime formatter options - runtimeFormatter.BaseNameOnly = true - runtimeFormatter.File = true - runtimeFormatter.Line = true - case model.LogLevelTrace: app.Logger.Level = logrus.TraceLevel - - // set runtime formatter options - runtimeFormatter.File = true - runtimeFormatter.Line = true - runtimeFormatter.Package = true default: app.Logger.Level = logrus.InfoLevel } + runtimeFormatter := &runtime.Formatter{ + ChildFormatter: &logrus.JSONFormatter{}, + File: true, + Line: true, + BaseNameOnly: true, + } + app.Logger.SetFormatter(runtimeFormatter) + termCh := make(chan os.Signal, 1) + // register for SIGINT, SIGTERM - signal.Notify(app.TermCh, syscall.SIGINT, syscall.SIGTERM) + signal.Notify(termCh, syscall.SIGINT, syscall.SIGTERM) + + if profiling { + enableProfilingEndpoint() + } + + return app, termCh, nil +} + +// enableProfilingEndpoint enables the profiling endpoint +func enableProfilingEndpoint() { + go func() { + server := &http.Server{ + Addr: "", + ReadHeaderTimeout: 2 * time.Second, // nolint:gomnd // time duration value is clear as is. + } + + if err := server.ListenAndServe(); err != nil { + log.Println(err) + } + }() - return app, nil + log.Println("profiling enabled: " + ProfilingEndpoint + "/debug/pprof") } diff --git a/internal/app/config.go b/internal/app/config.go new file mode 100644 index 00000000..cda5da6c --- /dev/null +++ b/internal/app/config.go @@ -0,0 +1,302 @@ +package app + +import ( + "net/url" + "os" + "strings" + "time" + + "github.com/jeremywohl/flatten" + "github.com/metal-toolbox/flasher/internal/model" + "github.com/mitchellh/mapstructure" + "github.com/pkg/errors" + "go.hollow.sh/toolbox/events" +) + +const ( + WorkerConcurrency = 1 +) + +var ( + ErrConfig = errors.New("configuration error") +) + +// Config holds application configuration read from a YAML or set by env variables. +// +// nolint:govet // prefer readability over field alignment optimization for this case. +type Configuration struct { + // LogLevel is the app verbose logging level. + // one of - info, debug, trace + LogLevel string `mapstructure:"log_level"` + + // AppKind is the application kind - worker / client + AppKind model.AppKind `mapstructure:"app_kind"` + + // Worker configuration + Concurrency int `mapstructure:"concurrency"` + + // FacilityCode limits this flasher to events in a facility. + FacilityCode string `mapstructure:"facility_code"` + + // The inventory source - one of serverservice OR Yaml + InventorySource string `mapstructure:"inventory_source"` + + StoreKind model.StoreKind `mapstructure:"store_kind"` + + // ServerserviceOptions defines the serverservice client configuration parameters + // + // This parameter is required when StoreKind is set to serverservice. + ServerserviceOptions *ServerserviceOptions `mapstructure:"serverservice"` + + // EventsBrokerKind indicates the kind of event broker configuration to enable, + // + // Supported parameter value - nats + EventsBorkerKind string `mapstructure:"events_broker_kind"` + + // NatsOptions defines the NATs events broker configuration parameters. + // + // This parameter is required when EventsBrokerKind is set to nats. + NatsOptions *events.NatsOptions `mapstructure:"nats"` +} + +// ServerserviceOptions defines configuration for the Serverservice client. +// https://github.com/metal-toolbox/hollow-serverservice +type ServerserviceOptions struct { + EndpointURL *url.URL + FacilityCode string `mapstructure:"facility_code"` + Endpoint string `mapstructure:"endpoint"` + OidcIssuerEndpoint string `mapstructure:"oidc_issuer_endpoint"` + OidcAudienceEndpoint string `mapstructure:"oidc_audience_endpoint"` + OidcClientSecret string `mapstructure:"oidc_client_secret"` + OidcClientID string `mapstructure:"oidc_client_id"` + OutofbandFirmwareNS string `mapstructure:"outofband_firmware_ns"` + AssetStateAttributeNS string `mapstructure:"device_state_attribute_ns"` + AssetStateAttributeKey string `mapstructure:"device_state_attribute_key"` + OidcClientScopes []string `mapstructure:"oidc_client_scopes"` + DeviceStates []string `mapstructure:"device_states"` + DisableOAuth bool `mapstructure:"disable_oauth"` +} + +// LoadConfiguration loads application configuration +// +// Reads in the cfgFile when available and overrides from environment variables. +func (a *App) LoadConfiguration(cfgFile string, storeKind model.StoreKind) error { + a.v.SetConfigType("yaml") + a.v.SetEnvPrefix(model.AppName) + a.v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) + a.v.AutomaticEnv() + + // these are initialized here so viper can read in configuration from env vars + // once https://github.com/spf13/viper/pull/1429 is merged, this can go. + a.Config.ServerserviceOptions = &ServerserviceOptions{} + a.Config.NatsOptions = &events.NatsOptions{ + Stream: &events.NatsStreamOptions{}, + Consumer: &events.NatsConsumerOptions{}, + } + + if cfgFile != "" { + fh, err := os.Open(cfgFile) + if err != nil { + return errors.Wrap(ErrConfig, err.Error()) + } + + if err = a.v.ReadConfig(fh); err != nil { + return errors.Wrap(ErrConfig, "ReadConfig error:"+err.Error()) + } + } + + a.v.SetDefault("log.level", "info") + + if err := a.envBindVars(); err != nil { + return errors.Wrap(ErrConfig, "env var bind error:"+err.Error()) + } + + if err := a.v.Unmarshal(a.Config); err != nil { + return errors.Wrap(ErrConfig, "Unmarshal error: "+err.Error()) + } + + a.envVarAppOverrides() + + if a.Config.EventsBorkerKind == "nats" { + if err := a.envVarNatsOverrides(); err != nil { + return errors.Wrap(ErrConfig, "nats env overrides error:"+err.Error()) + } + } + + if storeKind == model.InventoryStoreServerservice { + if err := a.envVarServerserviceOverrides(); err != nil { + return errors.Wrap(ErrConfig, "serverservice env overrides error:"+err.Error()) + } + } + + return nil +} + +func (a *App) envVarAppOverrides() { + if a.v.GetString("log.level") != "" { + a.Config.LogLevel = a.v.GetString("log.level") + } +} + +// envBindVars binds environment variables to the struct +// without a configuration file being unmarshalled, +// this is a workaround for a viper bug, +// +// This can be replaced by the solution in https://github.com/spf13/viper/pull/1429 +// once that PR is merged. +func (a *App) envBindVars() error { + envKeysMap := map[string]interface{}{} + if err := mapstructure.Decode(a.Config, &envKeysMap); err != nil { + return err + } + + // Flatten nested conf map + flat, err := flatten.Flatten(envKeysMap, "", flatten.DotStyle) + if err != nil { + return errors.Wrap(err, "Unable to flatten config") + } + + for k := range flat { + if err := a.v.BindEnv(k); err != nil { + return errors.Wrap(ErrConfig, "env var bind error: "+err.Error()) + } + } + + return nil +} + +// NATs streaming configuration +var ( + defaultNatsConnectTimeout = 100 * time.Millisecond +) + +// nolint:gocyclo // nats env config load is cyclomatic +func (a *App) envVarNatsOverrides() error { + if a.Config.NatsOptions == nil { + a.Config.NatsOptions = &events.NatsOptions{} + } + + if a.v.GetString("nats.url") != "" { + a.Config.NatsOptions.URL = a.v.GetString("nats.url") + } + + if a.Config.NatsOptions.URL == "" { + return errors.New("missing parameter: nats.url") + } + + if a.v.GetString("nats.stream.user") != "" { + a.Config.NatsOptions.StreamUser = a.v.GetString("nats.stream.user") + } + + if a.v.GetString("nats.stream.pass") != "" { + a.Config.NatsOptions.StreamPass = a.v.GetString("nats.stream.pass") + } + + if a.v.GetString("nats.creds.file") != "" { + a.Config.NatsOptions.CredsFile = a.v.GetString("nats.creds.file") + } + + if a.v.GetString("nats.stream.name") != "" { + if a.Config.NatsOptions.Stream == nil { + a.Config.NatsOptions.Stream = &events.NatsStreamOptions{} + } + + a.Config.NatsOptions.Stream.Name = a.v.GetString("nats.stream.name") + } + + if a.Config.NatsOptions.Stream.Name == "" { + return errors.New("A stream name is required") + } + + if a.v.GetString("nats.consumer.name") != "" { + if a.Config.NatsOptions.Consumer == nil { + a.Config.NatsOptions.Consumer = &events.NatsConsumerOptions{} + } + + a.Config.NatsOptions.Consumer.Name = a.v.GetString("nats.consumer.name") + } + + if a.Config.NatsOptions.ConnectTimeout == 0 { + a.Config.NatsOptions.ConnectTimeout = defaultNatsConnectTimeout + } + + return nil +} + +// Server service configuration options + +// nolint:gocyclo // parameter validation is cyclomatic +func (a *App) envVarServerserviceOverrides() error { + if a.Config.ServerserviceOptions == nil { + a.Config.ServerserviceOptions = &ServerserviceOptions{} + } + + if a.v.GetString("serverservice.endpoint") != "" { + a.Config.ServerserviceOptions.Endpoint = a.v.GetString("serverservice.endpoint") + } + + if a.v.GetString("serverservice.facility.code") != "" { + a.Config.ServerserviceOptions.FacilityCode = a.v.GetString("serverservice.facility.code") + } + + if a.Config.ServerserviceOptions.FacilityCode == "" { + return errors.New("serverservice facility code not defined") + } + + endpointURL, err := url.Parse(a.Config.ServerserviceOptions.Endpoint) + if err != nil { + return errors.New("serverservice endpoint URL error: " + err.Error()) + } + + a.Config.ServerserviceOptions.EndpointURL = endpointURL + + if a.v.GetString("serverservice.disable.oauth") != "" { + a.Config.ServerserviceOptions.DisableOAuth = a.v.GetBool("serverservice.disable.oauth") + } + + if a.Config.ServerserviceOptions.DisableOAuth { + return nil + } + + if a.v.GetString("serverservice.oidc.issuer.endpoint") != "" { + a.Config.ServerserviceOptions.OidcIssuerEndpoint = a.v.GetString("serverservice.oidc.issuer.endpoint") + } + + if a.Config.ServerserviceOptions.OidcIssuerEndpoint == "" { + return errors.New("serverservice oidc.issuer.endpoint not defined") + } + + if a.v.GetString("serverservice.oidc.audience.endpoint") != "" { + a.Config.ServerserviceOptions.OidcAudienceEndpoint = a.v.GetString("serverservice.oidc.audience.endpoint") + } + + if a.Config.ServerserviceOptions.OidcAudienceEndpoint == "" { + return errors.New("serverservice oidc.audience.endpoint not defined") + } + + if a.v.GetString("serverservice.oidc.client.secret") != "" { + a.Config.ServerserviceOptions.OidcClientSecret = a.v.GetString("serverservice.oidc.client.secret") + } + + if a.Config.ServerserviceOptions.OidcClientSecret == "" { + return errors.New("serverservice.oidc.client.secret not defined") + } + + if a.v.GetString("serverservice.oidc.client.id") != "" { + a.Config.ServerserviceOptions.OidcClientID = a.v.GetString("serverservice.oidc.client.id") + } + + if a.Config.ServerserviceOptions.OidcClientID == "" { + return errors.New("serverservice.oidc.client.id not defined") + } + + if a.v.GetString("serverservice.oidc.client.scopes") != "" { + a.Config.ServerserviceOptions.OidcClientScopes = a.v.GetStringSlice("serverservice.oidc.client.scopes") + } + + if len(a.Config.ServerserviceOptions.OidcClientScopes) == 0 { + return errors.New("serverservice oidc.client.scopes not defined") + } + + return nil +}