// oauth2-proxy/oauth2-proxy
//
// View on GitHub
// pkg/logger/logger.go
//
// Summary
//
// Maintainability
// A
// 2 hrs
// Test Coverage

package logger

import (
    "bytes"
    "fmt"
    "io"
    "net/http"
    "net/url"
    "os"
    "runtime"
    "sync"
    "text/template"
    "time"

    middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware"
    requestutil "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/requests/util"
)

type (
    // AuthStatus labels the outcome of an authentication attempt; its string
    // value is what appears as {{.Status}} in auth log lines.
    AuthStatus string

    // Level is the severity of a standard log message. Output uses it to
    // choose between the default writer and the error writer.
    Level int
)

// Package-level constants: default log formats, auth statuses, logger flags
// and log levels.
//
// NOTE: all of these live in a single const block, and iota counts every
// constant spec in the block from the top. The flag and level values below
// therefore do not start at 1/0; the resulting values are spelled out next to
// each group. Do not insert or reorder entries without checking the effect on
// these values.
const (
    // DefaultStandardLoggingFormat defines the default standard log format
    DefaultStandardLoggingFormat = "[{{.Timestamp}}] [{{.File}}] {{.Message}}"
    // DefaultAuthLoggingFormat defines the default auth log format
    DefaultAuthLoggingFormat = "{{.Client}} - {{.RequestID}} - {{.Username}} [{{.Timestamp}}] [{{.Status}}] {{.Message}}"
    // DefaultRequestLoggingFormat defines the default request log format
    DefaultRequestLoggingFormat = "{{.Client}} - {{.RequestID}} - {{.Username}} [{{.Timestamp}}] {{.Host}} {{.RequestMethod}} {{.Upstream}} {{.RequestURI}} {{.Protocol}} {{.UserAgent}} {{.StatusCode}} {{.ResponseSize}} {{.RequestDuration}}"

    // AuthSuccess indicates that an auth attempt has succeeded explicitly
    AuthSuccess AuthStatus = "AuthSuccess"
    // AuthFailure indicates that an auth attempt has failed explicitly
    AuthFailure AuthStatus = "AuthFailure"
    // AuthError indicates that an auth attempt has failed due to an error
    AuthError AuthStatus = "AuthError"

    // Llongfile flag to log full file name and line number: /a/b/c/d.go:23
    // (iota is 6 at this position, so the value is 1<<6 = 64.)
    Llongfile = 1 << iota
    // Lshortfile flag to log final file name element and line number: d.go:23. overrides Llongfile
    // (value 1<<7 = 128.)
    Lshortfile
    // LUTC flag to log UTC datetime rather than the local time zone
    // (value 1<<8 = 256.)
    LUTC
    // LstdFlags flag for initial values for the logger
    LstdFlags = Lshortfile

    // DEFAULT is the default log level (effectively INFO)
    // (iota is 10 at this position, so DEFAULT == 10.)
    DEFAULT Level = iota
    // ERROR is for error-level logging
    // (ERROR == 11.)
    ERROR
)

// The *LogMessageData structs carry the values exposed as variables to the
// logging format templates. Every field is a pre-formatted string so that a
// format can reference it directly.

// stdLogMessageData is the template input for standard log messages.
type stdLogMessageData struct {
    Timestamp string
    File      string
    Message   string
}

// authLogMessageData is the template input for auth log messages. All
// fields are pre-formatted strings.
type authLogMessageData struct {
    Client        string
    Host          string
    Protocol      string
    RequestID     string
    RequestMethod string
    Timestamp     string
    UserAgent     string
    Username      string
    Status        string
    Message       string
}

// reqLogMessageData is the template input for request log messages. All
// fields are pre-formatted strings.
type reqLogMessageData struct {
    Client          string
    Host            string
    Protocol        string
    RequestID       string
    RequestDuration string
    RequestMethod   string
    RequestURI      string
    ResponseSize    string
    StatusCode      string
    Timestamp       string
    Upstream        string
    UserAgent       string
    Username        string
}

// GetClientFunc is a function that returns the apparent "real client IP" of
// a request as a string. The result is used as the Client value in auth and
// request log lines.
type GetClientFunc = func(r *http.Request) string

// A Logger represents an active logging object that generates lines of
// output to an io.Writer passed through a formatter. Output renders each
// message fully and then makes a single call to the Writer's Write method.
// A Logger can be used simultaneously from multiple goroutines; all writes
// are performed while holding mu, which serializes access to the Writer.
type Logger struct {
    // mu is held by every setter and while a message is being written.
    mu             sync.Mutex
    // flag is a bit set of Llongfile, Lshortfile and LUTC.
    flag           int
    // writer receives standard (non-error), auth and request log lines.
    writer         io.Writer
    // errWriter receives standard log lines logged at ERROR level.
    errWriter      io.Writer
    // stdEnabled, authEnabled and reqEnabled switch the three kinds of
    // logging on or off.
    stdEnabled     bool
    authEnabled    bool
    reqEnabled     bool
    // getClientFunc computes the Client value for auth and request logs.
    getClientFunc  GetClientFunc
    // excludePaths is the set of URL paths that PrintReq does not log.
    excludePaths   map[string]struct{}
    // Parsed templates for the three kinds of log line.
    stdLogTemplate *template.Template
    authTemplate   *template.Template
    reqTemplate    *template.Template
}

// New creates a Logger that writes to os.Stdout (and os.Stderr for
// ERROR-level messages), with standard, auth and request logging all enabled
// and the default format for each. flag is a bit set of Llongfile,
// Lshortfile and LUTC. The client shown in auth/request logs defaults to the
// request's RemoteAddr, and no paths are excluded from request logging.
func New(flag int) *Logger {
    stdTmpl := template.Must(template.New("std-log").Parse(DefaultStandardLoggingFormat))
    authTmpl := template.Must(template.New("auth-log").Parse(DefaultAuthLoggingFormat))
    reqTmpl := template.Must(template.New("req-log").Parse(DefaultRequestLoggingFormat))

    logger := &Logger{
        flag:           flag,
        writer:         os.Stdout,
        errWriter:      os.Stderr,
        stdEnabled:     true,
        authEnabled:    true,
        reqEnabled:     true,
        stdLogTemplate: stdTmpl,
        authTemplate:   authTmpl,
        reqTemplate:    reqTmpl,
    }
    // excludePaths is left at its zero value (nil): nothing is excluded.
    logger.getClientFunc = func(r *http.Request) string {
        return r.RemoteAddr
    }
    return logger
}

// std is the package's standard logger, used by the package-level functions.
var std = New(LstdFlags)

// formatLogMessage renders message through the standard log template and
// returns the rendered bytes followed by a newline.
//
// calldepth is forwarded (plus one, for this frame) to GetFileLineString to
// locate the file and line of the original log call; the lookup is skipped
// when neither Lshortfile nor Llongfile is set, in which case "???:0" is
// used. The timestamp is formatted with this logger's own flags, so LUTC is
// honoured per Logger rather than taken from the standard logger.
//
// It panics if the template fails to execute.
func (l *Logger) formatLogMessage(calldepth int, message string) []byte {
    now := time.Now()
    file := "???:0"

    if l.flag&(Lshortfile|Llongfile) != 0 {
        file = l.GetFileLineString(calldepth + 1)
    }

    var logBuff bytes.Buffer
    err := l.stdLogTemplate.Execute(&logBuff, stdLogMessageData{
        Timestamp: l.FormatTimestamp(now),
        File:      file,
        Message:   message,
    })
    if err != nil {
        panic(err)
    }

    // bytes.Buffer.WriteByte always returns a nil error.
    logBuff.WriteByte('\n')

    return logBuff.Bytes()
}

// Output formats message with the standard log template and writes it, with
// a trailing newline, in a single Write call. Messages at ERROR level go to
// the error writer; every other level goes to the default writer. Nothing is
// written when standard logging is disabled.
//
// calldepth is the number of frames between Output and the code whose
// file/line should be reported. Output panics if the write fails.
func (l *Logger) Output(lvl Level, calldepth int, message string) {
    l.mu.Lock()
    defer l.mu.Unlock()

    if l.stdEnabled {
        line := l.formatLogMessage(calldepth+1, message)

        dest := l.writer
        if lvl == ERROR {
            dest = l.errWriter
        }

        if _, writeErr := dest.Write(line); writeErr != nil {
            panic(writeErr)
        }
    }
}

// PrintAuthf writes auth info to the logger. Requires an http.Request to
// log request details. Remaining arguments are handled in the manner of
// fmt.Sprintf. Writes a final newline to the end of every message.
//
// The whole line is rendered into a buffer first and then handed to the
// writer in a single Write call, so a template error never leaves a partial
// line in the output. The timestamp is formatted with this logger's own
// flags. An empty username is logged as "-". Nothing is written when auth
// logging is disabled.
//
// It panics if the template fails to execute or the write fails.
func (l *Logger) PrintAuthf(username string, req *http.Request, status AuthStatus, format string, a ...interface{}) {
    if !l.authEnabled {
        return
    }

    now := time.Now()

    if username == "" {
        username = "-"
    }

    // Resolved before taking the lock, as in the original ordering.
    client := l.getClientFunc(req)

    l.mu.Lock()
    defer l.mu.Unlock()

    scope := middlewareapi.GetRequestScope(req)

    var buf bytes.Buffer
    err := l.authTemplate.Execute(&buf, authLogMessageData{
        Client:        client,
        Host:          requestutil.GetRequestHost(req),
        Protocol:      req.Proto,
        RequestID:     scope.RequestID,
        RequestMethod: req.Method,
        Timestamp:     l.FormatTimestamp(now),
        UserAgent:     fmt.Sprintf("%q", req.UserAgent()),
        Username:      username,
        Status:        string(status),
        Message:       fmt.Sprintf(format, a...),
    })
    if err != nil {
        panic(err)
    }

    // bytes.Buffer.WriteByte always returns a nil error.
    buf.WriteByte('\n')

    if _, err = l.writer.Write(buf.Bytes()); err != nil {
        panic(err)
    }
}

// PrintReq writes request details to the Logger using the http.Request,
// url, and timestamp of the request.  Writes a final newline to the end
// of every message.
//
// ts is the time the request started: it is the logged timestamp and the
// start point for the logged duration (seconds, three decimal places).
// Nothing is written when request logging is disabled or url.Path is one of
// the excluded paths. An empty username falls back to the user name embedded
// in url, if any, and otherwise to "-"; an empty upstream is logged as "-".
//
// The whole line is rendered into a buffer first and then handed to the
// writer in a single Write call, so a template error never leaves a partial
// line in the output. The timestamp is formatted with this logger's own
// flags.
//
// It panics if the template fails to execute or the write fails.
func (l *Logger) PrintReq(username, upstream string, req *http.Request, url url.URL, ts time.Time, status int, size int) {
    if !l.reqEnabled {
        return
    }

    if _, ok := l.excludePaths[url.Path]; ok {
        return
    }

    duration := float64(time.Since(ts)) / float64(time.Second)

    if username == "" {
        username = "-"
    }

    if upstream == "" {
        upstream = "-"
    }

    if url.User != nil && username == "-" {
        if name := url.User.Username(); name != "" {
            username = name
        }
    }

    // Resolved before taking the lock, as in the original ordering.
    client := l.getClientFunc(req)

    l.mu.Lock()
    defer l.mu.Unlock()

    scope := middlewareapi.GetRequestScope(req)

    var buf bytes.Buffer
    err := l.reqTemplate.Execute(&buf, reqLogMessageData{
        Client:          client,
        Host:            requestutil.GetRequestHost(req),
        Protocol:        req.Proto,
        RequestID:       scope.RequestID,
        RequestDuration: fmt.Sprintf("%0.3f", duration),
        RequestMethod:   req.Method,
        RequestURI:      fmt.Sprintf("%q", url.RequestURI()),
        ResponseSize:    fmt.Sprintf("%d", size),
        StatusCode:      fmt.Sprintf("%d", status),
        Timestamp:       l.FormatTimestamp(ts),
        Upstream:        upstream,
        UserAgent:       fmt.Sprintf("%q", req.UserAgent()),
        Username:        username,
    })
    if err != nil {
        panic(err)
    }

    // bytes.Buffer.WriteByte always returns a nil error.
    buf.WriteByte('\n')

    if _, err = l.writer.Write(buf.Bytes()); err != nil {
        panic(err)
    }
}

// GetFileLineString returns "file:line" for the stack frame calldepth levels
// above this method (as counted by runtime.Caller), which lets log calls
// report the location of the non-logging caller. If the frame cannot be
// determined the result is "???:0". With Lshortfile set, only the final
// path element of the file name is kept.
func (l *Logger) GetFileLineString(calldepth int) string {
    _, path, lineNo, found := runtime.Caller(calldepth)
    if !found {
        path, lineNo = "???", 0
    }

    if l.flag&Lshortfile != 0 {
        // Scan backwards for the last '/'; index 0 is deliberately not
        // inspected, so a path whose only '/' is the first byte is kept whole.
        for idx := len(path) - 1; idx > 0; idx-- {
            if path[idx] == '/' {
                path = path[idx+1:]
                break
            }
        }
    }

    return fmt.Sprintf("%s:%d", path, lineNo)
}

// FormatTimestamp renders ts as "YYYY/MM/DD hh:mm:ss". When the LUTC flag is
// set, ts is converted to UTC first; otherwise its own location is used.
func (l *Logger) FormatTimestamp(ts time.Time) string {
    const layout = "2006/01/02 15:04:05"

    if l.flag&LUTC == 0 {
        return ts.Format(layout)
    }
    return ts.UTC().Format(layout)
}

// Flags reports the logger's current output flags (a bit set of Llongfile,
// Lshortfile and LUTC).
func (l *Logger) Flags() int {
    l.mu.Lock()
    current := l.flag
    l.mu.Unlock()
    return current
}

// SetFlags replaces the logger's output flags with flag (a bit set of
// Llongfile, Lshortfile and LUTC).
func (l *Logger) SetFlags(flag int) {
    l.mu.Lock()
    l.flag = flag
    l.mu.Unlock()
}

// SetStandardEnabled turns standard logging (Output) on when e is true and
// off when it is false.
func (l *Logger) SetStandardEnabled(e bool) {
    l.mu.Lock()
    l.stdEnabled = e
    l.mu.Unlock()
}

// SetErrToInfo controls where ERROR-level messages are written. When e is
// true they are sent to the logger's current default writer; when e is false
// the error writer is reset to os.Stderr.
//
// The default writer is captured at call time, and passing false always
// selects os.Stderr regardless of any error writer configured earlier.
func (l *Logger) SetErrToInfo(e bool) {
    l.mu.Lock()
    defer l.mu.Unlock()

    var target io.Writer = os.Stderr
    if e {
        target = l.writer
    }
    l.errWriter = target
}

// SetAuthEnabled turns auth logging (PrintAuthf) on when e is true and off
// when it is false.
func (l *Logger) SetAuthEnabled(e bool) {
    l.mu.Lock()
    l.authEnabled = e
    l.mu.Unlock()
}

// SetReqEnabled turns request logging (PrintReq) on when e is true and off
// when it is false.
func (l *Logger) SetReqEnabled(e bool) {
    l.mu.Lock()
    l.reqEnabled = e
    l.mu.Unlock()
}

// SetGetClientFunc installs f as the function used to determine the apparent
// "real client IP" reported in auth and request log lines.
func (l *Logger) SetGetClientFunc(f GetClientFunc) {
    l.mu.Lock()
    l.getClientFunc = f
    l.mu.Unlock()
}

// SetExcludePaths replaces the set of URL paths that PrintReq skips with the
// entries of s. Any previously configured paths are discarded; duplicates in
// s collapse into a single entry.
func (l *Logger) SetExcludePaths(s []string) {
    excluded := make(map[string]struct{}, len(s))
    for i := range s {
        excluded[s[i]] = struct{}{}
    }

    l.mu.Lock()
    l.excludePaths = excluded
    l.mu.Unlock()
}

// SetStandardTemplate parses t as a text/template and uses it as the format
// for standard log messages. It panics if t cannot be parsed.
func (l *Logger) SetStandardTemplate(t string) {
    l.mu.Lock()
    defer l.mu.Unlock()

    parsed, err := template.New("std-log").Parse(t)
    l.stdLogTemplate = template.Must(parsed, err)
}

// SetAuthTemplate parses t as a text/template and uses it as the format for
// auth log messages. It panics if t cannot be parsed.
func (l *Logger) SetAuthTemplate(t string) {
    l.mu.Lock()
    defer l.mu.Unlock()

    parsed, err := template.New("auth-log").Parse(t)
    l.authTemplate = template.Must(parsed, err)
}

// SetReqTemplate parses t as a text/template and uses it as the format for
// request log messages. It panics if t cannot be parsed.
func (l *Logger) SetReqTemplate(t string) {
    l.mu.Lock()
    defer l.mu.Unlock()

    parsed, err := template.New("req-log").Parse(t)
    l.reqTemplate = template.Must(parsed, err)
}

// These functions utilize the standard logger.

// FormatTimestamp returns ts formatted by the standard logger, i.e. using
// the standard logger's flags (LUTC) to decide whether to convert to UTC.
func FormatTimestamp(ts time.Time) string {
    return std.FormatTimestamp(ts)
}

// Flags returns the output flags for the standard logger (a bit set of
// Llongfile, Lshortfile and LUTC).
func Flags() int {
    return std.Flags()
}

// SetFlags sets the output flags for the standard logger (a bit set of
// Llongfile, Lshortfile and LUTC).
func SetFlags(flag int) {
    std.SetFlags(flag)
}

// SetOutput makes w the destination for the standard logger's default
// channel (non-error standard messages, auth logs and request logs).
func SetOutput(w io.Writer) {
    std.mu.Lock()
    std.writer = w
    std.mu.Unlock()
}

// SetErrOutput makes w the destination for the standard logger's error
// channel (standard messages logged at ERROR level).
func SetErrOutput(w io.Writer) {
    std.mu.Lock()
    std.errWriter = w
    std.mu.Unlock()
}

// SetStandardEnabled enables (e == true) or disables (e == false) standard
// logging for the standard logger.
func SetStandardEnabled(e bool) {
    std.SetStandardEnabled(e)
}

// SetErrToInfo controls where the standard logger writes ERROR-level
// messages: when e is true they go to the logger's current default output
// writer; when e is false the error writer is reset to os.Stderr.
func SetErrToInfo(e bool) {
    std.SetErrToInfo(e)
}

// SetAuthEnabled enables (e == true) or disables (e == false) auth logging
// for the standard logger.
func SetAuthEnabled(e bool) {
    std.SetAuthEnabled(e)
}

// SetReqEnabled enables (e == true) or disables (e == false) request logging
// for the standard logger.
func SetReqEnabled(e bool) {
    std.SetReqEnabled(e)
}

// SetGetClientFunc sets, for the standard logger, the function that
// determines the apparent client IP address (for example one set by a
// reverse proxy) shown in auth and request log lines.
func SetGetClientFunc(f GetClientFunc) {
    std.SetGetClientFunc(f)
}

// SetExcludePaths sets the URL paths that the standard logger excludes from
// request logging, e.g. health checks. It replaces any previously set paths.
func SetExcludePaths(s []string) {
    std.SetExcludePaths(s)
}

// SetStandardTemplate sets the template for standard logging for
// the standard logger. It panics if t is not a valid text/template.
func SetStandardTemplate(t string) {
    std.SetStandardTemplate(t)
}

// SetAuthTemplate sets the template for auth logging for the
// standard logger. It panics if t is not a valid text/template.
func SetAuthTemplate(t string) {
    std.SetAuthTemplate(t)
}

// SetReqTemplate sets the template for request logging for the
// standard logger. It panics if t is not a valid text/template.
func SetReqTemplate(t string) {
    std.SetReqTemplate(t)
}

// Print calls Output to print to the standard logger at DEFAULT level.
// Arguments are handled in the manner of fmt.Print.
func Print(v ...interface{}) {
    // Call depth 2 makes the reported file/line that of Print's caller.
    std.Output(DEFAULT, 2, fmt.Sprint(v...))
}

// Printf calls Output to print to the standard logger at DEFAULT level.
// Arguments are handled in the manner of fmt.Printf.
func Printf(format string, v ...interface{}) {
    // Call depth 2 makes the reported file/line that of Printf's caller.
    std.Output(DEFAULT, 2, fmt.Sprintf(format, v...))
}

// Println calls Output to print to the standard logger at DEFAULT level.
// Arguments are handled in the manner of fmt.Println.
//
// NOTE(review): fmt.Sprintln already ends the message with a newline and the
// standard formatter appends another after the rendered template, so with
// the default format each entry is followed by an empty line.
func Println(v ...interface{}) {
    // Call depth 2 makes the reported file/line that of Println's caller.
    std.Output(DEFAULT, 2, fmt.Sprintln(v...))
}

// Error calls Output at ERROR level to print to the standard logger's error
// channel. Arguments are handled in the manner of fmt.Print.
func Error(v ...interface{}) {
    // Call depth 2 makes the reported file/line that of Error's caller.
    std.Output(ERROR, 2, fmt.Sprint(v...))
}

// Errorf calls Output at ERROR level to print to the standard logger's error
// channel. Arguments are handled in the manner of fmt.Printf.
func Errorf(format string, v ...interface{}) {
    // Call depth 2 makes the reported file/line that of Errorf's caller.
    std.Output(ERROR, 2, fmt.Sprintf(format, v...))
}

// Errorln calls Output at ERROR level to print to the standard logger's
// error channel. Arguments are handled in the manner of fmt.Println.
//
// NOTE(review): fmt.Sprintln already ends the message with a newline and the
// standard formatter appends another after the rendered template, so with
// the default format each entry is followed by an empty line.
func Errorln(v ...interface{}) {
    // Call depth 2 makes the reported file/line that of Errorln's caller.
    std.Output(ERROR, 2, fmt.Sprintln(v...))
}

// Fatal logs its arguments, formatted in the manner of fmt.Print, at ERROR
// level on the standard logger and then calls os.Exit(1).
func Fatal(v ...interface{}) {
    // Call depth 2 makes the reported file/line that of Fatal's caller.
    std.Output(ERROR, 2, fmt.Sprint(v...))
    os.Exit(1)
}

// Fatalf logs its arguments, formatted in the manner of fmt.Printf, at ERROR
// level on the standard logger and then calls os.Exit(1).
func Fatalf(format string, v ...interface{}) {
    // Call depth 2 makes the reported file/line that of Fatalf's caller.
    std.Output(ERROR, 2, fmt.Sprintf(format, v...))
    os.Exit(1)
}

// Fatalln logs its arguments, formatted in the manner of fmt.Println, at
// ERROR level on the standard logger and then calls os.Exit(1).
//
// NOTE(review): fmt.Sprintln already ends the message with a newline and the
// standard formatter appends another after the rendered template, so with
// the default format the entry is followed by an empty line.
func Fatalln(v ...interface{}) {
    // Call depth 2 makes the reported file/line that of Fatalln's caller.
    std.Output(ERROR, 2, fmt.Sprintln(v...))
    os.Exit(1)
}

// Panic logs its arguments, formatted in the manner of fmt.Print, at ERROR
// level on the standard logger and then panics with the formatted string.
func Panic(v ...interface{}) {
    s := fmt.Sprint(v...)
    // Call depth 2 makes the reported file/line that of Panic's caller.
    std.Output(ERROR, 2, s)
    panic(s)
}

// Panicf logs its arguments, formatted in the manner of fmt.Printf, at ERROR
// level on the standard logger and then panics with the formatted string.
func Panicf(format string, v ...interface{}) {
    s := fmt.Sprintf(format, v...)
    // Call depth 2 makes the reported file/line that of Panicf's caller.
    std.Output(ERROR, 2, s)
    panic(s)
}

// Panicln logs its arguments, formatted in the manner of fmt.Println, at
// ERROR level on the standard logger and then panics with the formatted
// string (which includes the trailing newline added by fmt.Sprintln).
//
// NOTE(review): the standard formatter appends another newline after the
// rendered template, so with the default format the entry is followed by an
// empty line.
func Panicln(v ...interface{}) {
    s := fmt.Sprintln(v...)
    // Call depth 2 makes the reported file/line that of Panicln's caller.
    std.Output(ERROR, 2, s)
    panic(s)
}

// PrintAuthf writes authentication details for req to the standard logger.
// username and status are logged alongside the message; format and a are
// handled in the manner of fmt.Printf.
func PrintAuthf(username string, req *http.Request, status AuthStatus, format string, a ...interface{}) {
    std.PrintAuthf(username, req, status, format, a...)
}

// PrintReq writes request details to the standard logger. ts is the request
// start time (used for the timestamp and the duration); status and size are
// the response status code and response size to log.
func PrintReq(username, upstream string, req *http.Request, url url.URL, ts time.Time, status int, size int) {
    std.PrintReq(username, upstream, req, url, ts, status, size)
}