cyberark/secretless-broker

View on GitHub
internal/plugin/connectors/http/aws/connector.go

Summary

Maintainability
A
0 mins
Test Coverage
A
91%
package aws

import (
    gohttp "net/http"
    "strings"

    "github.com/cyberark/secretless-broker/pkg/secretless/log"
    "github.com/cyberark/secretless-broker/pkg/secretless/plugin/connector"
)

// Connector injects an HTTP request with AWS authorization headers.
//
// NOTE(review): Connect calls c.logger.Debugf unconditionally for every
// request it signs, so a Connector whose logger is unset will presumably
// panic there — confirm that construction always populates the logger.
type Connector struct {
	// logger receives debug output describing the service, region and
	// signed headers of each request Connect signs.
	logger log.Logger
}

// Connect implements the http.Connector function signature. It receives the
// client's http.Request together with the credentials (keyed by ID) and is
// expected to decorate the request with Authorization headers.
//
// The "accessKeyId", "secretAccessKey" and optional "accessToken" credentials
// are used to sign the Authorization header in the AWS signature format.
// Requests that carry no AWS signature metadata are passed through unsigned.
func (c *Connector) Connect(
	req *gohttp.Request,
	credentialsByID connector.CredentialValuesByID,
) error {
	// Pull the date, region and service name out of the client's signature.
	reqMeta, err := newRequestMetadata(req)
	switch {
	case err != nil:
		return err
	case reqMeta == nil:
		// The client did not sign its request, so neither do we.
		return nil
	}

	// The endpoint has to be rewritten before signing: modifying the request
	// afterwards would make it fail the integrity check.
	if err = maybeSetAmzEndpoint(req, reqMeta); err != nil {
		return err
	}

	c.logger.Debugf(
		"Signing for service=%s region=%s signedHeaders=%s",
		reqMeta.serviceName,
		reqMeta.region,
		strings.Join(reqMeta.signedHeaders, ","),
	)

	// Only the headers the client originally signed take part in the new
	// signature; everything else is set aside and put back once signing is
	// done, without overwriting anything signing produced.
	unsignedHeaders := removeUnsignedHeaders(req, reqMeta)
	if err = signRequest(req, reqMeta, credentialsByID); err != nil {
		return err
	}
	reinstateUnsignedHeaders(req, unsignedHeaders)

	return nil
}

// removeUnsignedHeaders strips from req every header that was not listed in
// the SignedHeaders of the original request's AWS signature. It returns the
// removed headers, keyed exactly as they appeared in req.Header, so they can
// be restored once the request has been re-signed.
//
// Header names are compared in canonical MIME form (net/http
// CanonicalHeaderKey). AWS SigV4 lists signed headers in lowercase, whereas
// net/http stores header keys canonicalized, so comparing raw strings could
// classify a signed header as unsigned and strip it before signing.
// Canonicalization is idempotent, so this is also correct when
// reqMeta.signedHeaders is already canonical.
func removeUnsignedHeaders(req *gohttp.Request, reqMeta *requestMetadata) map[string][]string {
	signedHeaders := make(map[string]struct{}, len(reqMeta.signedHeaders))
	for _, name := range reqMeta.signedHeaders {
		signedHeaders[gohttp.CanonicalHeaderKey(name)] = struct{}{}
	}

	unsignedHeaders := make(map[string][]string)
	for key, values := range req.Header {
		if _, isSigned := signedHeaders[gohttp.CanonicalHeaderKey(key)]; isSigned {
			continue
		}

		unsignedHeaders[key] = values
		// Delete the exact map key: Header.Del canonicalizes its argument and
		// would leave behind a header stored under a non-canonical key.
		// Deleting entries while ranging over a map is allowed by the Go spec.
		delete(req.Header, key)
	}

	return unsignedHeaders
}

// reinstateUnsignedHeaders adds back headers previously set aside by
// removeUnsignedHeaders, after the request has been re-signed.
//
// Every header name already present on req is treated as reserved and left
// untouched. These are expected to be only the signed headers and those
// generated by signing, so an unsigned header must never clobber or append to
// them. Names are compared in canonical form because Header.Add writes under
// the canonical key. Comparing raw keys would let a non-canonically stored
// unsigned header (e.g. "x-amz-date") be appended to its reserved canonical
// counterpart ("X-Amz-Date").
func reinstateUnsignedHeaders(req *gohttp.Request, unsignedHeaders map[string][]string) {
	reservedHeaders := make(map[string]struct{}, len(req.Header))
	for key := range req.Header {
		reservedHeaders[gohttp.CanonicalHeaderKey(key)] = struct{}{}
	}

	for key, values := range unsignedHeaders {
		// Ignore reserved headers; we don't want to mess with those.
		if _, isReserved := reservedHeaders[gohttp.CanonicalHeaderKey(key)]; isReserved {
			continue
		}

		// Add back every value of each non-reserved unsigned header.
		for _, value := range values {
			req.Header.Add(key, value)
		}
	}
}