ory-am/hydra

View on GitHub
test/mock-client/main.go

Summary

Maintainability
A
3 hrs
Test Coverage
// Copyright © 2022 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package main

import (
    "bytes"
    "context"
    "encoding/json"
    "flag"
    "fmt"
    "io"
    "log"
    "net/http"
    "net/http/cookiejar"
    "net/url"
    "os"
    "strings"
    "time"

    hydra "github.com/ory/hydra-client-go/v2"

    "golang.org/x/oauth2"

    "github.com/ory/hydra/v2/x"
    "github.com/ory/x/cmdx"
    "github.com/ory/x/urlx"
)

var hydraURL = urlx.ParseOrPanic(os.Getenv("HYDRA_ADMIN_URL"))
var sdk = hydra.NewAPIClient(hydra.NewConfiguration())

func init() {
    sdk.GetConfig().Servers = hydra.ServerConfigurations{{URL: hydraURL.String()}}
}

type oauth2token struct {
    IDToken      string    `json:"id_token"`
    AccessToken  string    `json:"access_token"`
    TokenType    string    `json:"token_type,omitempty"`
    RefreshToken string    `json:"refresh_token,omitempty"`
    Expiry       time.Time `json:"expiry,omitempty"`
}

var printToken, printCookie bool

func init() {
    flag.BoolVar(&printToken, "print-token", false, "")
    flag.BoolVar(&printCookie, "print-cookie", false, "")
}

func main() {
    flag.Parse()
    conf := oauth2.Config{
        ClientID:     os.Getenv("OAUTH2_CLIENT_ID"),
        ClientSecret: os.Getenv("OAUTH2_CLIENT_SECRET"),
        Endpoint: oauth2.Endpoint{
            AuthURL:  strings.TrimRight(os.Getenv("HYDRA_URL"), "/") + "/oauth2/auth",
            TokenURL: strings.TrimRight(os.Getenv("HYDRA_URL"), "/") + "/oauth2/token",
        },
        Scopes:      strings.Split(os.Getenv("OAUTH2_SCOPE"), ","),
        RedirectURL: os.Getenv("REDIRECT_URL"),
    }
    au := conf.AuthCodeURL("some-stupid-state-foo") + os.Getenv("OAUTH2_EXTRA")
    c, err := cookiejar.New(&cookiejar.Options{})
    if err != nil {
        log.Fatalf("Unable to create cookie jar: %s", err)
    }

    u, _ := url.Parse("http://127.0.0.1")
    if os.Getenv("AUTH_COOKIE") != "" {
        c.SetCookies(u, []*http.Cookie{{Name: "oauth2_authentication_session", Value: os.Getenv("AUTH_COOKIE")}})
    }

    resp, err := (&http.Client{
        Jar: c,
        // Hack to fix cookie across domains
        CheckRedirect: func(req *http.Request, via []*http.Request) error {
            if len(via) > 0 && req.Header.Get("cookie") == "" {
                req.Header.Set("Cookie", via[len(via)-1].Header.Get("Cookie"))
            }

            return nil
        },
    }).Get(au)
    cmdx.CheckResponse(err, http.StatusOK, resp)
    defer resp.Body.Close()

    out, err := io.ReadAll(resp.Body)
    if err != nil {
        log.Fatalf("Unable to read body: %s", err)
    }

    for _, c := range c.Cookies(u) {
        if c.Name == "oauth2_authentication_session" {
            if printCookie {
                fmt.Print(c.Value)
            }
        }
    }

    var token oauth2token
    if err := json.Unmarshal(out, &token); err != nil {
        log.Fatalf("Unable transform to token: %s", err)
    }

    checkTokenResponse(token)
    for i := 0; i <= 5; i++ {
        token = refreshToken(token)
        checkTokenResponse(token)
    }

    newToken := refreshToken(token)
    if printToken {
        fmt.Printf("%s", newToken.AccessToken)
    }

    // refreshing the same token twice does not work
    resp, err = refreshTokenRequest(token)
    cmdx.CheckResponse(err, http.StatusBadRequest, resp)
    defer resp.Body.Close()
}

func refreshTokenRequest(token oauth2token) (*http.Response, error) {
    req, err := http.NewRequest("POST", strings.TrimRight(os.Getenv("HYDRA_URL"), "/")+"/oauth2/token", bytes.NewBufferString(url.Values{
        "refresh_token": {token.RefreshToken},
        "grant_type":    {"refresh_token"},
    }.Encode()))
    cmdx.Must(err, "%s", err)
    req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
    req.SetBasicAuth(os.Getenv("OAUTH2_CLIENT_ID"), os.Getenv("OAUTH2_CLIENT_SECRET"))
    return http.DefaultClient.Do(req)
}

func refreshToken(token oauth2token) (result oauth2token) {
    resp, err := refreshTokenRequest(token)
    cmdx.CheckResponse(err, http.StatusOK, resp)
    defer resp.Body.Close()

    err = json.NewDecoder(resp.Body).Decode(&result)
    cmdx.Must(err, "Unable to decode refresh token: %s", err)
    return result
}

func checkTokenResponse(token oauth2token) {
    if token.RefreshToken == "" {
        log.Fatalf("Expected a refresh token but none received: %+v", token)
    }

    // This value oscillates between bar and rab, depending on whether authorization was remembered or not. Check
    // mock-lcp which sets the value
    expectedValue := "bar"
    if strings.Contains(os.Getenv("OAUTH2_EXTRA"), "prompt=none") {
        expectedValue = "rab"
    }

    if os.Getenv("OAUTH2_ACCESS_TOKEN_STRATEGY") == "jwt" {
        parts := strings.Split(token.AccessToken, ".")

        if len(parts) != 3 {
            log.Fatalf("JWT Access Token does not seem to have three parts: %d - %+v - %v", len(parts), token, parts)
        }

        payload, err := x.DecodeSegment(parts[1])
        if err != nil {
            log.Fatalf("Unable to decode id token segment: %s", err)
        }

        var claims map[string]interface{}
        if err := json.Unmarshal(payload, &claims); err != nil {
            log.Fatalf("Unable to unmarshal id token body: %s", err)
        }

        if fmt.Sprintf("%s", claims["sub"]) != "the-subject" {
            log.Fatalf("Expected subject from access token to be %s but got %s", "the-subject", claims["sub"])
        }

        ext := claims["ext"].(map[string]interface{})
        if ext["foo"] != expectedValue {
            log.Fatalf("Expected extra field \"foo\" from access token to be \"%s\" but got %s", expectedValue, ext["foo"])
        }
    }

    intro, resp, err := sdk.OAuth2Api.IntrospectOAuth2Token(context.Background()).Token(token.AccessToken).Execute()
    defer resp.Body.Close()
    if err != nil {
        log.Fatalf("Unable to introspect OAuth2 token: %s", err)
    }

    if !intro.Active {
        log.Fatalf("Expected token to be active: %s", token.AccessToken)
    }

    if *intro.Sub != "the-subject" {
        log.Fatalf("Expected subject from access token to be %s but got %s", "the-subject", *intro.Sub)
    }

    if intro.Ext["foo"] != expectedValue {
        log.Fatalf("Expected extra field \"foo\" from access token to be \"%s\" but got %s", expectedValue, intro.Ext["foo"])
    }

    idt := token.IDToken
    if len(idt) == 0 {
        log.Fatalf("ID Token does not seem to be set: %+v", token)
    }

    parts := strings.Split(idt, ".")
    if len(parts) != 3 {
        log.Fatalf("ID Token does not seem to have three parts: %d - %+v - %v", len(parts), token, parts)
    }

    payload, err := x.DecodeSegment(parts[1])
    if err != nil {
        log.Fatalf("Unable to decode id token segment: %s", err)
    }

    var claims map[string]interface{}
    if err := json.Unmarshal(payload, &claims); err != nil {
        log.Fatalf("Unable to unmarshal id token body: %s", err)
    }

    if fmt.Sprintf("%s", claims["sub"]) != "the-subject" {
        log.Fatalf("Expected subject from id token to be %s but got %s", "the-subject", claims["sub"])
    }

    if fmt.Sprintf("%s", claims["baz"]) != expectedValue {
        log.Fatalf("Expected extra field \"baz\" from access token to be \"%s\" but got \"%s\"", expectedValue, claims["baz"])
    }
}