main.go
// SPDX-FileCopyrightText: 2021 Comcast Cable Communications Management, LLC
// SPDX-License-Identifier: Apache-2.0
package main
//
import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
_ "net/http/pprof"
"os"
"os/signal"
"runtime"
"syscall"
"time"
"go.uber.org/zap"
"github.com/gorilla/mux"
"github.com/spf13/pflag"
"github.com/spf13/viper"
"github.com/xmidt-org/ancla"
"github.com/xmidt-org/bascule/basculehelper"
"github.com/xmidt-org/candlelight"
"github.com/xmidt-org/httpaux/recovery"
"github.com/xmidt-org/sallust"
"github.com/xmidt-org/webpa-common/v2/adapter"
// nolint:staticcheck
"github.com/xmidt-org/webpa-common/v2/concurrent"
// nolint:staticcheck
"github.com/xmidt-org/webpa-common/v2/server"
"github.com/xmidt-org/webpa-common/v2/service/servicecfg"
"go.opentelemetry.io/contrib/instrumentation/github.com/gorilla/mux/otelmux"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
)
const (
applicationName = "caduceus"
defaultKeyID = "current"
tracingConfigKey = "tracing"
)
var (
GitCommit = "undefined"
Version = "undefined"
BuildTime = "undefined"
)
// httpClientTimeout contains timeouts for an HTTP client and its requests.
type httpClientTimeout struct {
// ClientTimeout is HTTP Client Timeout.
ClientTimeout time.Duration
// NetDialerTimeout is the net dialer timeout
NetDialerTimeout time.Duration
}
// caduceus is the driver function for Caduceus. It performs everything main() would do,
// except for obtaining the command-line arguments (which are passed to it).
func caduceus(arguments []string) int {
beginCaduceus := time.Now()
var (
f = pflag.NewFlagSet(applicationName, pflag.ContinueOnError)
v = viper.New()
logger, metricsRegistry, webPA, err = server.Initialize(applicationName, arguments, f, v, Metrics, AnclaHelperMetrics, basculehelper.AuthCapabilitiesMetrics, basculehelper.AuthValidationMetrics)
)
if parseErr, done := printVersion(f, arguments); done {
// if we're done, we're exiting no matter what
if parseErr != nil {
friendlyError := fmt.Sprintf("failed to parse arguments. detailed error: %s", parseErr)
logger.Error(friendlyError)
os.Exit(1)
}
os.Exit(0)
}
if err != nil {
fmt.Fprintf(os.Stderr, "Unable to initialize Viper environment: %s\n", err)
return 1
}
logger.Info("viper environment successfully initialized", zap.Any("configurationFile", v.ConfigFileUsed()))
caduceusConfig := new(CaduceusConfig)
err = v.Unmarshal(caduceusConfig)
if err != nil {
fmt.Fprintf(os.Stderr, "Unable to unmarshal configuration data into struct: %s\n", err)
return 1
}
tracing, err := loadTracing(v, applicationName)
if err != nil {
fmt.Fprintf(os.Stderr, "Unable to build tracing component: %v \n", err)
return 1
}
logger.Info("tracing status", zap.Bool("enabled", !tracing.IsNoop()))
var tr http.RoundTripper = &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: caduceusConfig.Sender.DisableClientHostnameValidation},
MaxIdleConnsPerHost: caduceusConfig.Sender.NumWorkersPerSender,
ResponseHeaderTimeout: caduceusConfig.Sender.ResponseHeaderTimeout,
IdleConnTimeout: caduceusConfig.Sender.IdleConnTimeout,
}
tr = otelhttp.NewTransport(tr,
otelhttp.WithPropagators(tracing.Propagator()),
otelhttp.WithTracerProvider(tracing.TracerProvider()),
)
caduceusSenderWrapper, err := SenderWrapperFactory{
NumWorkersPerSender: caduceusConfig.Sender.NumWorkersPerSender,
QueueSizePerSender: caduceusConfig.Sender.QueueSizePerSender,
CutOffPeriod: caduceusConfig.Sender.CutOffPeriod,
Linger: caduceusConfig.Sender.Linger,
DeliveryRetries: caduceusConfig.Sender.DeliveryRetries,
DeliveryInterval: caduceusConfig.Sender.DeliveryInterval,
MetricsRegistry: metricsRegistry,
Logger: logger,
Sender: doerFunc((&http.Client{
Transport: tr,
Timeout: caduceusConfig.Sender.ClientTimeout,
}).Do),
CustomPIDs: caduceusConfig.Sender.CustomPIDs,
DisablePartnerIDs: caduceusConfig.Sender.DisablePartnerIDs,
}.New()
if err != nil {
fmt.Fprintf(os.Stderr, "Unable to initialize new caduceus sender wrapper: %s\n", err)
return 1
}
serverWrapper := &ServerHandler{
Logger: logger,
caduceusHandler: &CaduceusHandler{
senderWrapper: caduceusSenderWrapper,
Logger: logger,
},
errorRequests: metricsRegistry.NewCounter(ErrorRequestBodyCounter),
emptyRequests: metricsRegistry.NewCounter(EmptyRequestBodyCounter),
invalidCount: metricsRegistry.NewCounter(DropsDueToInvalidPayload),
incomingQueueDepthMetric: metricsRegistry.NewGauge(IncomingQueueDepth),
modifiedWRPCount: metricsRegistry.NewCounter(ModifiedWRPCounter),
maxOutstanding: 0,
// 0 is for the unused `buckets` argument in xmetrics.Registry.NewHistogram
incomingQueueLatency: metricsRegistry.NewHistogram(IncomingQueueLatencyHistogram, 0),
now: time.Now,
}
caduceusConfig.Webhook.Logger = logger
caduceusConfig.Listener.Measures = NewHelperMeasures(metricsRegistry)
argusClientTimeout, err := newArgusClientTimeout(v)
if err != nil {
fmt.Fprintf(os.Stderr, "Unable to parse argus client timeout config values: %v \n", err)
return 1
}
caduceusConfig.Webhook.BasicClientConfig.HTTPClient = newHTTPClient(argusClientTimeout, tracing)
svc, err := ancla.NewService(caduceusConfig.Webhook, getLogger)
if err != nil {
fmt.Fprintf(os.Stderr, "Webhook service initialization error: %v\n", err)
return 1
}
stopWatches, err := svc.StartListener(caduceusConfig.Listener, setLoggerInContext(), caduceusSenderWrapper)
if err != nil {
fmt.Fprintf(os.Stderr, "Webhook service start listener error: %v\n", err)
return 1
}
logger.Info("Webhook service enabled")
rootRouter := mux.NewRouter()
rootRouter.Use(
recovery.Middleware(
recovery.WithStatusCode(555), // a wacky status code that will show up in metrics
// TODO: should probably customize things a bit
),
)
otelMuxOptions := []otelmux.Option{
otelmux.WithPropagators(tracing.Propagator()),
otelmux.WithTracerProvider(tracing.TracerProvider()),
}
rootRouter.Use(otelmux.Middleware("primary", otelMuxOptions...), candlelight.EchoFirstTraceNodeInfo(tracing.Propagator(), false))
primaryHandler, err := NewPrimaryHandler(logger, v, metricsRegistry, serverWrapper, svc, rootRouter, v.GetBool("previousVersionSupport"))
if err != nil {
fmt.Fprintf(os.Stderr, "Handler creation error: %v\n", err)
return 1
}
_, runnable, done := webPA.Prepare(logger, nil, metricsRegistry, primaryHandler)
waitGroup, shutdown, err := concurrent.Execute(runnable)
if err != nil {
fmt.Fprintf(os.Stderr, "Unable to start device manager: %s\n", err)
return 1
}
//
// Now, initialize the service discovery infrastructure
//
if !v.IsSet("service") {
logger.Info("no service discovery configured")
} else {
var log = &adapter.Logger{
Logger: logger,
}
e, err := servicecfg.NewEnvironment(log, v.Sub("service"))
if err != nil {
logger.Error("Unable to initialize service discovery environment", zap.Error(err))
return 4
}
defer e.Close()
logger.Info("service discovery environment successfully initialized", zap.Any("configurationFile", v.ConfigFileUsed()))
e.Register()
}
logger.Info("Caduceus is up and running!", zap.Any("elapsedTime", time.Since(beginCaduceus)))
signals := make(chan os.Signal, 10)
signal.Notify(signals, syscall.SIGTERM, os.Interrupt)
for exit := false; !exit; {
select {
case s := <-signals:
logger.Error("exiting due to signal", zap.Any("signal", s))
exit = true
case <-done:
logger.Error("one or more servers exited")
exit = true
}
}
close(shutdown)
waitGroup.Wait()
// shutdown the sender wrapper gently so that all queued messages get serviced
caduceusSenderWrapper.Shutdown(true)
stopWatches()
return 0
}
func setLoggerInContext() func(context.Context, *zap.Logger) context.Context {
return func(parent context.Context, logger *zap.Logger) context.Context {
return sallust.With(parent, logger)
}
}
func loadTracing(v *viper.Viper, appName string) (candlelight.Tracing, error) {
var traceConfig candlelight.Config
err := v.UnmarshalKey(tracingConfigKey, &traceConfig)
if err != nil {
return candlelight.Tracing{}, err
}
traceConfig.ApplicationName = appName
tracing, err := candlelight.New(traceConfig)
return tracing, err
}
func newArgusClientTimeout(v *viper.Viper) (httpClientTimeout, error) {
var timeouts httpClientTimeout
err := v.UnmarshalKey("argusClientTimeout", &timeouts)
if err != nil {
return httpClientTimeout{}, err
}
if timeouts.ClientTimeout == 0 {
timeouts.ClientTimeout = time.Second * 50
}
if timeouts.NetDialerTimeout == 0 {
timeouts.NetDialerTimeout = time.Second * 5
}
return timeouts, nil
}
func newHTTPClient(timeouts httpClientTimeout, tracing candlelight.Tracing) *http.Client {
var transport http.RoundTripper = &http.Transport{
Dial: (&net.Dialer{
Timeout: timeouts.NetDialerTimeout,
}).Dial,
}
transport = otelhttp.NewTransport(transport,
otelhttp.WithPropagators(tracing.Propagator()),
otelhttp.WithTracerProvider(tracing.TracerProvider()),
)
return &http.Client{
Timeout: timeouts.ClientTimeout,
Transport: transport,
}
}
func printVersion(f *pflag.FlagSet, arguments []string) (error, bool) {
printVer := f.BoolP("version", "v", false, "displays the version number")
if err := f.Parse(arguments); err != nil {
return err, true
}
if *printVer {
printVersionInfo(os.Stdout)
return nil, true
}
return nil, false
}
func printVersionInfo(writer io.Writer) {
fmt.Fprintf(writer, "%s:\n", applicationName)
fmt.Fprintf(writer, " version: \t%s\n", Version)
fmt.Fprintf(writer, " go version: \t%s\n", runtime.Version())
fmt.Fprintf(writer, " built time: \t%s\n", BuildTime)
fmt.Fprintf(writer, " git commit: \t%s\n", GitCommit)
fmt.Fprintf(writer, " os/arch: \t%s/%s\n", runtime.GOOS, runtime.GOARCH)
}
func main() {
os.Exit(caduceus(os.Args))
}