ecadlabs/signatory

View on GitHub
pkg/vault/azure/auth/auth.go

Summary

Maintainability
A
2 hrs
Test Coverage
F
44%
package auth

import (
    "context"
    "crypto/rand"
    "crypto/rsa"
    "crypto/sha1"
    "crypto/x509"
    "encoding/base64"
    "encoding/hex"
    "encoding/json"
    "encoding/pem"
    "errors"
    "fmt"
    "io/ioutil"
    "net/http"
    "net/url"
    "os"
    "strings"
    "time"

    "github.com/golang-jwt/jwt/v5"
    "golang.org/x/crypto/pkcs12"
    "golang.org/x/oauth2"
)

// Environment variables consulted by (*Config).parseEnv for every setting that
// is left empty in the configuration.
const (
    envTenant                      = "AZURE_CLIENT_TENANT"
    envClientID                    = "AZURE_CLIENT_ID"
    envClientSecret                = "AZURE_CLIENT_SECRET"
    envClientPKCS12Certificate     = "AZURE_CLIENT_PKCS12_CERTIFICATE"
    envClientCertificate           = "AZURE_CLIENT_CERTIFICATE"
    envClientCertificateThumbprint = "AZURE_CLIENT_CERTIFICATE_THUMBPRINT"
    envPrivateKey                  = "AZURE_CLIENT_PRIVATE_KEY"
    envPrivateKeyPassword          = "AZURE_DECRYPT_PASSWORD"
)

const (
    // assertionTokenDuration is the validity window ("nbf" to "exp") of each
    // signed client assertion.
    assertionTokenDuration = 24 * time.Hour

    // assertionType is the client_assertion_type form value sent with a
    // certificate-based token request.
    assertionType = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
)

// Config is the configuration for using Azure authentication.
//
// Empty fields can be filled in from the AZURE_* environment variables via
// parseEnv. A non-empty ClientSecret selects the client secret flow in
// TokenSource; otherwise the certificate/private-key fields are used to sign
// a client assertion.
type Config struct {
    // Tenant is the directory (tenant) ID; it is part of the token endpoint URL.
    Tenant                      string `yaml:"tenant_id" validate:"omitempty,uuid4"`
    // ClientID is the application (client) ID sent with every token request.
    ClientID                    string `yaml:"client_id" validate:"omitempty,uuid4"`
    // ClientSecret, when set, selects the client secret flow.
    ClientSecret                string `yaml:"client_secret"`
    // ClientPKCS12Certificate is the path of a PKCS#12 file holding both the
    // private key and the certificate; it takes precedence over PrivateKey.
    ClientPKCS12Certificate     string `yaml:"client_pkcs12_certificate"`
    // ClientCertificate is the path of a PEM certificate whose thumbprint
    // identifies the signing key.
    ClientCertificate           string `yaml:"client_certificate"`
    // ClientCertificateThumbprint is the encoded certificate thumbprint, used
    // when ClientCertificate is empty.
    ClientCertificateThumbprint string `yaml:"client_certificate_thumbprint"`
    // PrivateKey is the path of a PEM-encoded private key file.
    PrivateKey                  string `yaml:"client_private_key"`
    // PrivateKeyPassword decrypts the PKCS#12 file or an encrypted PEM key.
    PrivateKeyPassword          string `yaml:"decrypt_password"`
}

// tokenURL returns the v2.0 token endpoint of the configured tenant.
func (c *Config) tokenURL() string {
    const endpointTemplate = "https://login.microsoftonline.com/%s/oauth2/v2.0/token"
    return fmt.Sprintf(endpointTemplate, c.Tenant)
}

// parseEnv returns a copy of c in which each empty field is replaced by the
// value of its environment variable. Fields that are already set win, and c
// itself is left untouched.
func (c *Config) parseEnv() *Config {
    merged := *c
    fallbacks := []struct {
        field *string
        name  string
    }{
        {&merged.Tenant, envTenant},
        {&merged.ClientID, envClientID},
        {&merged.ClientSecret, envClientSecret},
        {&merged.ClientCertificate, envClientCertificate},
        {&merged.ClientCertificateThumbprint, envClientCertificateThumbprint},
        {&merged.ClientPKCS12Certificate, envClientPKCS12Certificate},
        {&merged.PrivateKey, envPrivateKey},
        {&merged.PrivateKeyPassword, envPrivateKeyPassword},
    }
    for _, fb := range fallbacks {
        if *fb.field != "" {
            continue
        }
        *fb.field = os.Getenv(fb.name)
    }
    return &merged
}

func parsePKCS12Certificate(name, password string) (pk interface{}, thumbprint []byte, err error) {
    buf, err := ioutil.ReadFile(name)
    if err != nil {
        return nil, nil, err
    }

    pk, cert, err := pkcs12.Decode(buf, password)
    if err != nil {
        return nil, nil, err
    }

    sum := sha1.Sum(cert.Raw)
    return pk, sum[:], nil
}

// parsePrivateKey reads the PEM file name and returns the first private key it
// contains. PEM blocks that do not hold a private key (for example a
// certificate stored in front of the key in a combined bundle) are skipped.
//
// Recognised block types:
//   - "RSA PRIVATE KEY": PKCS#1; if the block uses legacy PEM encryption it is
//     decrypted with password first.
//   - "EC PRIVATE KEY": SEC 1 elliptic-curve key.
//   - "ENCRYPTED PRIVATE KEY": rejected with an explicit error (encrypted
//     PKCS#8 cannot be decrypted by crypto/x509).
//   - any other "... PRIVATE KEY" label (normally "PRIVATE KEY"): unencrypted
//     PKCS#8.
//
// An error is returned when the file cannot be read, no private key block is
// found, or the key cannot be decoded.
func parsePrivateKey(name, password string) (pk interface{}, err error) {
    buf, err := os.ReadFile(name)
    if err != nil {
        return nil, err
    }

    rest := buf
    for {
        var block *pem.Block
        block, rest = pem.Decode(rest)
        if block == nil {
            return nil, errors.New("failed to parse PEM block containing the private key")
        }

        switch {
        case block.Type == "RSA PRIVATE KEY":
            der := block.Bytes
            // x509.IsEncryptedPEMBlock/DecryptPEMBlock are deprecated because
            // legacy RFC 1423 PEM encryption is weak; they are kept so keys
            // encrypted that way continue to load.
            if x509.IsEncryptedPEMBlock(block) {
                if der, err = x509.DecryptPEMBlock(block, []byte(password)); err != nil {
                    return nil, err
                }
            }
            rsaKey, parseErr := x509.ParsePKCS1PrivateKey(der)
            if parseErr != nil {
                // Return a plain nil rather than a typed-nil *rsa.PrivateKey.
                return nil, parseErr
            }
            return rsaKey, nil

        case block.Type == "EC PRIVATE KEY":
            ecKey, parseErr := x509.ParseECPrivateKey(block.Bytes)
            if parseErr != nil {
                return nil, parseErr
            }
            return ecKey, nil

        case block.Type == "ENCRYPTED PRIVATE KEY":
            return nil, errors.New("encrypted PKCS#8 private keys are not supported")

        case strings.HasSuffix(block.Type, "PRIVATE KEY"):
            return x509.ParsePKCS8PrivateKey(block.Bytes) // Unencrypted PKCS#8 only
        }
        // Not a private key block: keep scanning the rest of the file.
    }
}

// getThumbprint reads the PEM file name and returns the SHA-1 digest of the
// DER encoding of the first certificate in it (the certificate thumbprint).
// Blocks that are not certificates, such as a private key stored in the same
// file, are skipped.
//
// An error is returned when the file cannot be read, contains no certificate
// block, or the certificate cannot be parsed.
func getThumbprint(name string) ([]byte, error) {
    buf, err := os.ReadFile(name)
    if err != nil {
        return nil, err
    }

    rest := buf
    for {
        var block *pem.Block
        block, rest = pem.Decode(rest)
        if block == nil {
            return nil, errors.New("failed to parse PEM block containing the certificate")
        }
        // Accept "CERTIFICATE" as well as legacy labels such as
        // "X509 CERTIFICATE"; skip keys and anything else.
        if !strings.HasSuffix(block.Type, "CERTIFICATE") {
            continue
        }

        cert, err := x509.ParseCertificate(block.Bytes)
        if err != nil {
            return nil, err
        }

        sum := sha1.Sum(cert.Raw)
        return sum[:], nil
    }
}

// jwtTokenSource is the oauth2.TokenSource used for certificate-based
// authentication: each Token call signs a client assertion JWT with key and
// exchanges it at the tenant's token endpoint.
type jwtTokenSource struct {
    // conf supplies the tenant (token endpoint / "aud") and the client ID.
    conf           *Config
    // scopes are joined with spaces into the "scope" form field.
    scopes         []string
    // certThumbprint is sent base64url-encoded in the "kid" and "x5t" headers.
    certThumbprint []byte
    // key signs the RS256 client assertion.
    key            *rsa.PrivateKey
    // ctx is captured at construction because Token takes no context; it is
    // passed to fetchToken for every token request.
    ctx            context.Context
}

// jwtTokenSource returns a token source that authenticates with a client
// assertion (a JWT signed with the application's RSA key) instead of a client
// secret.
//
// Key material is loaded from, in order of preference:
//   - ClientPKCS12Certificate: a PKCS#12 file supplying both the private key
//     and the certificate thumbprint, decrypted with PrivateKeyPassword;
//   - PrivateKey: a PEM private key file (PrivateKeyPassword decrypts legacy
//     encrypted PKCS#1 keys), paired with the thumbprint of the PEM
//     certificate in ClientCertificate or, if that is empty, with
//     ClientCertificateThumbprint decoded by decodeCertThumbprint.
//
// An error is returned when neither source is configured, when the files
// cannot be read or parsed, or when the key is not an RSA key. ctx is stored
// in the returned source and used for all of its token requests.
func (c *Config) jwtTokenSource(ctx context.Context, scopes []string) (oauth2.TokenSource, error) {
    var (
        pk         interface{}
        thumbprint []byte
        err        error
    )

    switch {
    case c.ClientPKCS12Certificate != "":
        pk, thumbprint, err = parsePKCS12Certificate(c.ClientPKCS12Certificate, c.PrivateKeyPassword)
        if err != nil {
            return nil, err
        }

    case c.PrivateKey != "":
        pk, err = parsePrivateKey(c.PrivateKey, c.PrivateKeyPassword)
        if err != nil {
            return nil, err
        }
        if c.ClientCertificate != "" {
            thumbprint, err = getThumbprint(c.ClientCertificate)
        } else {
            thumbprint, err = decodeCertThumbprint(c.ClientCertificateThumbprint)
        }
        if err != nil {
            return nil, err
        }

    default:
        // Without this check the failure would surface as an opaque
        // "open : no such file or directory" from reading an empty path.
        return nil, errors.New("no client secret, PKCS#12 certificate or private key configured")
    }

    key, ok := pk.(*rsa.PrivateKey)
    if !ok {
        return nil, fmt.Errorf("not a RSA key: %T", pk)
    }

    return &jwtTokenSource{
        conf:           c,
        scopes:         scopes,
        certThumbprint: thumbprint,
        key:            key,
        ctx:            ctx,
    }, nil
}

// decodeCertThumbprint decodes a SHA-1 certificate thumbprint written either
// as hex (any case, optionally separated by ':' or spaces) or as base64 in any
// of the standard/URL-safe, padded/unpadded alphabets, including the unpadded
// base64url form used for the "x5t" header. The decoded value must be exactly
// sha1.Size bytes long.
func decodeCertThumbprint(s string) ([]byte, error) {
    s = strings.TrimSpace(s)

    hexDigits := strings.NewReplacer(":", "", " ", "").Replace(s)
    if b, err := hex.DecodeString(hexDigits); err == nil && len(b) == sha1.Size {
        return b, nil
    }

    for _, enc := range []*base64.Encoding{
        base64.URLEncoding,
        base64.RawURLEncoding,
        base64.StdEncoding,
        base64.RawStdEncoding,
    } {
        if b, err := enc.DecodeString(s); err == nil && len(b) == sha1.Size {
            return b, nil
        }
    }
    return nil, errors.New("failed to decode thumbprint string")
}

// fetchToken POSTs the form v to the token endpoint and decodes the OAuth 2.0
// token response.
//
// The HTTP client comes from oauth2.NewClient(ctx, nil), and the request is
// bound to ctx so that cancellation and deadlines abort it. A non-2xx status
// is reported as *oauth2.RetrieveError carrying the raw response body. A
// response without an access token is an error.
//
// Expiry is taken from "expires_in". If the access token is a JWT whose
// payload can be decoded and which carries an "exp" claim, that claim takes
// precedence. The JWT signature is NOT verified; it is only inspected for its
// expiry, and any failure to read it is ignored.
func fetchToken(ctx context.Context, endpoint string, v url.Values) (*oauth2.Token, error) {
    req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(v.Encode()))
    if err != nil {
        return nil, fmt.Errorf("auth: cannot fetch token: %w", err)
    }
    req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

    resp, err := oauth2.NewClient(ctx, nil).Do(req)
    if err != nil {
        return nil, fmt.Errorf("auth: cannot fetch token: %w", err)
    }
    defer resp.Body.Close()

    body, err := ioutil.ReadAll(resp.Body)
    if err != nil {
        return nil, fmt.Errorf("auth: cannot fetch token: %w", err)
    }

    if resp.StatusCode/100 != 2 {
        return nil, &oauth2.RetrieveError{
            Response: resp,
            Body:     body,
        }
    }

    var res struct {
        TokenType   string `json:"token_type"`
        ExpiresIn   int64  `json:"expires_in"`
        AccessToken string `json:"access_token"`
    }

    if err := json.Unmarshal(body, &res); err != nil {
        return nil, fmt.Errorf("auth: cannot fetch token: %w", err)
    }
    if res.AccessToken == "" {
        // An empty token is never Valid(), so returning it would only cause
        // an unauthenticated request followed by endless refetching.
        return nil, errors.New("auth: cannot fetch token: server response missing access_token")
    }

    token := oauth2.Token{
        AccessToken: res.AccessToken,
        TokenType:   res.TokenType,
    }
    if res.ExpiresIn > 0 {
        token.Expiry = time.Now().Add(time.Duration(res.ExpiresIn) * time.Second)
    }

    // Decode into a concrete claims struct: a nil jwt.Claims interface cannot
    // be JSON-decoded, which previously made this lookup always fail. The
    // token may also be opaque, in which case expires_in is kept.
    var claims jwt.RegisteredClaims
    if _, _, err := jwt.NewParser().ParseUnverified(res.AccessToken, &claims); err == nil && claims.ExpiresAt != nil {
        token.Expiry = claims.ExpiresAt.Time
    }
    return &token, nil
}

// Token implements oauth2.TokenSource with the client credentials grant and a
// certificate-based client assertion, following
// https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-client-creds-grant-flow
//
// Every call signs a fresh RS256 assertion. The assertion has a random 160-bit
// "jti" and is valid from now for assertionTokenDuration. Its "kid" and "x5t"
// headers both carry the base64url-encoded certificate thumbprint. The
// assertion is then exchanged at the tenant's token endpoint.
func (j *jwtTokenSource) Token() (*oauth2.Token, error) {
    nonce := make([]byte, 20)
    if _, err := rand.Read(nonce); err != nil {
        return nil, fmt.Errorf("auth: %w", err)
    }

    endpoint := j.conf.tokenURL()
    issued := time.Now()

    unsigned := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
        "aud": endpoint,
        "iss": j.conf.ClientID,
        "sub": j.conf.ClientID,
        "jti": base64.RawURLEncoding.EncodeToString(nonce),
        "nbf": issued.Unix(),
        "exp": issued.Add(assertionTokenDuration).Unix(),
    })

    thumb := base64.RawURLEncoding.EncodeToString(j.certThumbprint)
    for _, header := range []string{"kid", "x5t"} {
        unsigned.Header[header] = thumb
    }

    signed, err := unsigned.SignedString(j.key)
    if err != nil {
        return nil, fmt.Errorf("auth: %w", err)
    }

    form := make(url.Values, 5)
    form.Set("client_id", j.conf.ClientID)
    form.Set("scope", strings.Join(j.scopes, " "))
    form.Set("client_assertion_type", assertionType)
    form.Set("client_assertion", signed)
    form.Set("grant_type", "client_credentials")

    return fetchToken(j.ctx, endpoint, form)
}

// clientSecretTokenSource is an oauth2.TokenSource that uses the client
// credentials grant, authenticating with Config.ClientSecret.
type clientSecretTokenSource struct {
    // conf supplies the tenant, client ID and client secret.
    conf   *Config
    // scopes are joined with spaces into the "scope" form field.
    scopes []string
    // ctx is passed to fetchToken for every request.
    ctx    context.Context
}

// Token requests a new access token from the tenant's token endpoint, sending
// the client ID and secret as form fields.
func (src *clientSecretTokenSource) Token() (*oauth2.Token, error) {
    form := url.Values{}
    form.Set("grant_type", "client_credentials")
    form.Set("client_id", src.conf.ClientID)
    form.Set("client_secret", src.conf.ClientSecret)
    form.Set("scope", strings.Join(src.scopes, " "))

    return fetchToken(src.ctx, src.conf.tokenURL(), form)
}

// TokenSource returns a token source for scopes built from the configuration.
//
// A non-empty ClientSecret selects the client secret flow. Otherwise the
// certificate-based client assertion flow is set up from the certificate and
// private-key settings; any error from loading those credentials is returned
// wrapped with an "auth: " prefix. The resulting source is wrapped with
// oauth2.ReuseTokenSource so that a token is reused while it remains valid.
func (c *Config) TokenSource(ctx context.Context, scopes []string) (ts oauth2.TokenSource, err error) {
    var src oauth2.TokenSource
    if c.ClientSecret == "" {
        jwtSrc, jwtErr := c.jwtTokenSource(ctx, scopes)
        if jwtErr != nil {
            return nil, fmt.Errorf("auth: %w", jwtErr)
        }
        src = jwtSrc
    } else {
        src = &clientSecretTokenSource{conf: c, scopes: scopes, ctx: ctx}
    }

    return oauth2.ReuseTokenSource(nil, src), nil
}

// Client returns an HTTP client that wraps the context's HTTP transport and
// adds Authorization headers carrying tokens obtained from
// c.TokenSource(ctx, scopes). Errors from TokenSource are returned unchanged.
func (c *Config) Client(ctx context.Context, scopes []string) (*http.Client, error) {
    var (
        src oauth2.TokenSource
        err error
    )
    if src, err = c.TokenSource(ctx, scopes); err != nil {
        return nil, err
    }
    return oauth2.NewClient(ctx, src), nil
}