package auth

import (


// Names of the environment variables that parseEnv falls back to when the
// matching Config field is empty.
const (
	// envTenant supplies Config.Tenant.
	envTenant = "AZURE_CLIENT_TENANT"
	// envClientID supplies Config.ClientID.
	envClientID = "AZURE_CLIENT_ID"
	// envClientSecret supplies Config.ClientSecret.
	envClientSecret = "AZURE_CLIENT_SECRET"
	// envClientPKCS12Certificate supplies Config.ClientPKCS12Certificate.
	envClientPKCS12Certificate = "AZURE_CLIENT_PKCS12_CERTIFICATE"
	// envClientCertificate supplies Config.ClientCertificate.
	envClientCertificate = "AZURE_CLIENT_CERTIFICATE"
	// envClientCertificateThumbprint supplies Config.ClientCertificateThumbprint.
	envClientCertificateThumbprint = "AZURE_CLIENT_CERTIFICATE_THUMBPRINT"
	// envPrivateKey supplies Config.PrivateKey.
	envPrivateKey = "AZURE_CLIENT_PRIVATE_KEY"
	// envPrivateKeyPassword supplies Config.PrivateKeyPassword.
	envPrivateKeyPassword = "AZURE_DECRYPT_PASSWORD"
)

const (
	// assertionTokenDuration is how long a signed client assertion stays
	// valid: its "exp" claim is set this far after the time of signing.
	assertionTokenDuration = 24 * time.Hour

	// assertionType is the client_assertion_type form value sent together
	// with a signed JWT client assertion.
	assertionType = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
)

// Config is the configuration for using Azure authentication.
//
// A non-empty ClientSecret selects the client-secret flow. Otherwise the
// client authenticates with a signed JWT assertion, which needs a private key
// plus the SHA-1 thumbprint of the matching certificate.
type Config struct {
	// Tenant identifies the directory; it is inserted into the token URL.
	Tenant string `yaml:"tenant_id" validate:"omitempty,uuid4"`
	// ClientID is sent as client_id and, for JWT assertions, used as the
	// "iss" and "sub" claims.
	ClientID string `yaml:"client_id" validate:"omitempty,uuid4"`
	// ClientSecret is sent as client_secret when set.
	ClientSecret string `yaml:"client_secret"`
	// ClientPKCS12Certificate is the path of a PKCS#12 file holding both the
	// private key and the certificate. When set it takes precedence over
	// PrivateKey, ClientCertificate and ClientCertificateThumbprint.
	ClientPKCS12Certificate string `yaml:"client_pkcs12_certificate"`
	// ClientCertificate is the path of a PEM certificate file from which the
	// thumbprint is computed.
	ClientCertificate string `yaml:"client_certificate"`
	// ClientCertificateThumbprint is the textual SHA-1 certificate
	// thumbprint, used only when ClientCertificate is empty.
	ClientCertificateThumbprint string `yaml:"client_certificate_thumbprint"`
	// PrivateKey is the path of a PEM file holding the private key.
	PrivateKey string `yaml:"client_private_key"`
	// PrivateKeyPassword decrypts the PKCS#12 file or an encrypted PEM
	// private key.
	PrivateKeyPassword string `yaml:"decrypt_password"`
}

func (c *Config) tokenURL() string {
    return fmt.Sprintf("", c.Tenant)

func (c *Config) parseEnv() *Config {
    res := *c
    if res.Tenant == "" {
        res.Tenant = os.Getenv(envTenant)
    if res.ClientID == "" {
        res.ClientID = os.Getenv(envClientID)
    if res.ClientSecret == "" {
        res.ClientSecret = os.Getenv(envClientSecret)
    if res.ClientCertificate == "" {
        res.ClientCertificate = os.Getenv(envClientCertificate)
    if res.ClientCertificateThumbprint == "" {
        res.ClientCertificateThumbprint = os.Getenv(envClientCertificateThumbprint)
    if res.ClientPKCS12Certificate == "" {
        res.ClientPKCS12Certificate = os.Getenv(envClientPKCS12Certificate)
    if res.PrivateKey == "" {
        res.PrivateKey = os.Getenv(envPrivateKey)
    if res.PrivateKeyPassword == "" {
        res.PrivateKeyPassword = os.Getenv(envPrivateKeyPassword)
    return &res

func parsePKCS12Certificate(name, password string) (pk interface{}, thumbprint []byte, err error) {
    buf, err := ioutil.ReadFile(name)
    if err != nil {
        return nil, nil, err

    pk, cert, err := pkcs12.Decode(buf, password)
    if err != nil {
        return nil, nil, err

    sum := sha1.Sum(cert.Raw)
    return pk, sum[:], nil

// parsePrivateKey reads the first PEM block of the file at path name and
// parses it as a private key.
//
// A block of type "RSA PRIVATE KEY" is parsed as PKCS#1; if the block itself
// is encrypted it is first decrypted with password. A block of any other type
// is parsed as unencrypted PKCS#8, and password is ignored.
func parsePrivateKey(name, password string) (pk interface{}, err error) {
	raw, err := ioutil.ReadFile(name)
	if err != nil {
		return nil, err
	}

	block, _ := pem.Decode(raw)
	if block == nil {
		return nil, errors.New("failed to parse PEM block containing the private key")
	}

	// Everything that is not a PKCS#1 RSA block goes to the PKCS#8 parser
	// as-is; encrypted PKCS#8 is not supported.
	if block.Type != "RSA PRIVATE KEY" {
		return x509.ParsePKCS8PrivateKey(block.Bytes)
	}

	der := block.Bytes
	if x509.IsEncryptedPEMBlock(block) {
		// NOTE(review): the x509 PEM-encryption helpers are marked deprecated
		// in current Go releases; confirm whether encrypted PKCS#1 keys
		// still need to be supported here.
		if der, err = x509.DecryptPEMBlock(block, []byte(password)); err != nil {
			return nil, err
		}
	}
	return x509.ParsePKCS1PrivateKey(der)
}

// getThumbprint reads the first PEM block of the file at path name, parses it
// as an X.509 certificate, and returns the SHA-1 digest of the certificate's
// raw DER encoding.
func getThumbprint(name string) ([]byte, error) {
	pemData, err := ioutil.ReadFile(name)
	if err != nil {
		return nil, err
	}

	block, _ := pem.Decode(pemData)
	if block == nil {
		return nil, errors.New("failed to parse PEM block containing the certificate")
	}

	cert, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		return nil, err
	}

	h := sha1.New()
	h.Write(cert.Raw)
	return h.Sum(nil), nil
}

type jwtTokenSource struct {
    conf           *Config
    scopes         []string
    certThumbprint []byte
    key            *rsa.PrivateKey
    ctx            context.Context

func (c *Config) jwtTokenSource(ctx context.Context, scopes []string) (oauth2.TokenSource, error) {
    var (
        pk         interface{}
        thumbprint []byte
        err        error

    if c.ClientPKCS12Certificate != "" {
        pk, thumbprint, err = parsePKCS12Certificate(c.ClientPKCS12Certificate, c.PrivateKeyPassword)
        if err != nil {
            return nil, err

    } else {
        pk, err = parsePrivateKey(c.PrivateKey, c.PrivateKeyPassword)
        if err != nil {
            return nil, err

        if c.ClientCertificate != "" {
            thumbprint, err = getThumbprint(c.ClientCertificate)
            if err != nil {
                return nil, err

        } else {
            thumbprint, err = hex.DecodeString(c.ClientCertificateThumbprint)
            if err != nil || len(thumbprint) != sha1.Size {
                thumbprint, err = base64.URLEncoding.DecodeString(c.ClientCertificateThumbprint)
                if err != nil || len(thumbprint) != sha1.Size {
                    return nil, errors.New("failed to decode thumbprint string")

    key, ok := pk.(*rsa.PrivateKey)
    if !ok {
        return nil, fmt.Errorf("not a RSA key: %T", pk)

    return &jwtTokenSource{
        conf:           c,
        scopes:         scopes,
        certThumbprint: thumbprint,
        key:            key,
        ctx:            ctx,
    }, nil

func fetchToken(ctx context.Context, url string, v url.Values) (*oauth2.Token, error) {
    client := oauth2.NewClient(ctx, nil)
    resp, err := client.PostForm(url, v)
    if err != nil {
        return nil, fmt.Errorf("auth: cannot fetch token: %w", err)
    defer resp.Body.Close()

    body, err := ioutil.ReadAll(resp.Body)
    if err != nil {
        return nil, fmt.Errorf("auth: cannot fetch token: %w", err)

    if resp.StatusCode/100 != 2 {
        return nil, &oauth2.RetrieveError{
            Response: resp,
            Body:     body,

    var res struct {
        TokenType   string `json:"token_type"`
        ExpiresIn   int64  `json:"expires_in"`
        AccessToken string `json:"access_token"`

    if err := json.Unmarshal(body, &res); err != nil {
        return nil, fmt.Errorf("auth: cannot fetch token: %w", err)

    token := oauth2.Token{
        AccessToken: res.AccessToken,
        TokenType:   res.TokenType,
    if res.ExpiresIn > 0 {
        token.Expiry = time.Now().Add(time.Duration(res.ExpiresIn) * time.Second)

    var (
        p      jwt.Parser
        claims jwt.Claims
    if _, _, err := p.ParseUnverified(res.AccessToken, claims); err == nil {
        exp, err := claims.GetExpirationTime()
        if err != nil {
            return nil, fmt.Errorf("auth: cannot fetch expiration time: %v", err)
        } else {
            token.Expiry = exp.Time
    return &token, nil

func (j *jwtTokenSource) Token() (*oauth2.Token, error) {
    jti := make([]byte, 20)
    _, err := rand.Read(jti)
    if err != nil {
        return nil, fmt.Errorf("auth: %w", err)

    now := time.Now()

    claims := jwt.MapClaims{
        "aud": j.conf.tokenURL(),
        "iss": j.conf.ClientID,
        "sub": j.conf.ClientID,
        "jti": base64.RawURLEncoding.EncodeToString(jti),
        "nbf": now.Unix(),
        "exp": now.Add(assertionTokenDuration).Unix(),

    assertionToken := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)

    kid := base64.RawURLEncoding.EncodeToString(j.certThumbprint)
    assertionToken.Header["kid"] = kid
    assertionToken.Header["x5t"] = kid

    assertion, err := assertionToken.SignedString(j.key)
    if err != nil {
        return nil, fmt.Errorf("auth: %w", err)

    v := url.Values{
        "client_id":             []string{j.conf.ClientID},
        "scope":                 []string{strings.Join(j.scopes, " ")},
        "client_assertion_type": []string{assertionType},
        "client_assertion":      []string{assertion},
        "grant_type":            []string{"client_credentials"},

    return fetchToken(j.ctx, j.conf.tokenURL(), v)

type clientSecretTokenSource struct {
    conf   *Config
    scopes []string
    ctx    context.Context

func (c *clientSecretTokenSource) Token() (*oauth2.Token, error) {
    v := url.Values{
        "client_id":     []string{c.conf.ClientID},
        "scope":         []string{strings.Join(c.scopes, " ")},
        "grant_type":    []string{"client_credentials"},
        "client_secret": []string{c.conf.ClientSecret},

    return fetchToken(c.ctx, c.conf.tokenURL(), v)

// TokenSource returns new token source using the configuration.
func (c *Config) TokenSource(ctx context.Context, scopes []string) (ts oauth2.TokenSource, err error) {
    if c.ClientSecret != "" {
        ts = &clientSecretTokenSource{
            conf:   c,
            scopes: scopes,
            ctx:    ctx,
    } else if ts, err = c.jwtTokenSource(ctx, scopes); err != nil {
        return nil, fmt.Errorf("auth: %w", err)

    return oauth2.ReuseTokenSource(nil, ts), nil

// Client returns an HTTP client wrapping the context's
// HTTP transport and adding Authorization headers with tokens
// obtained from c.
func (c *Config) Client(ctx context.Context, scopes []string) (*http.Client, error) {
    ts, err := c.TokenSource(ctx, scopes)
    if err != nil {
        return nil, err
    return oauth2.NewClient(ctx, ts), nil