ory-am/hydra

View on GitHub
x/oauth2cors/cors_test.go

Summary

Maintainability
F
3 days
Test Coverage
// Copyright © 2022 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package oauth2cors_test

import (
    "bytes"
    "context"
    "fmt"
    "io"
    "net/http"
    "net/http/httptest"
    "net/url"
    "testing"
    "time"

    "github.com/ory/hydra/v2/driver"
    "github.com/ory/x/contextx"

    "github.com/ory/hydra/v2/x"

    "github.com/stretchr/testify/assert"
    "github.com/stretchr/testify/require"

    "github.com/ory/fosite"
    "github.com/ory/hydra/v2/client"
    "github.com/ory/hydra/v2/internal"
    "github.com/ory/hydra/v2/oauth2"
)

// TestOAuth2AwareCORSMiddleware sends a table of requests through the
// registry's OAuth2-aware middleware and asserts the response status code and
// the exact set of response headers for each one.
//
// Every case uses the same registry instance; only the configuration is
// replaced with fresh defaults before a case runs. Clients created by earlier
// cases are presumably still stored when later cases run (hence the ignored
// unique violations on client creation) — confirm against the registry
// implementation. The wrapped handler always answers 501 Not Implemented,
// which is the status every case expects.
func TestOAuth2AwareCORSMiddleware(t *testing.T) {
    ctx := context.Background()
    r := internal.NewRegistryMemory(t, internal.NewConfigurationWithDefaults(), &contextx.Default{})

    // The token is only used by the bearer-token case; fail early if it cannot
    // be generated instead of letting that case fail on a header mismatch.
    token, signature, err := r.OAuth2HMACStrategy().GenerateAccessToken(ctx, nil)
    require.NoError(t, err)

    // basicAuth builds the value of an HTTP Basic Authorization header.
    basicAuth := func(id, secret string) string {
        return fmt.Sprintf("Basic %s", x.BasicAuth(id, secret))
    }

    // corsAllowed returns the response headers expected when origin is accepted.
    corsAllowed := func(origin string) http.Header {
        return http.Header{
            "Access-Control-Allow-Credentials": {"true"},
            "Access-Control-Allow-Origin":      {origin},
            "Access-Control-Expose-Headers":    {"Cache-Control, Expires, Last-Modified, Pragma, Content-Length, Content-Language, Content-Type"},
            "Vary":                             {"Origin"},
        }
    }

    // corsDenied returns the response headers expected when the origin is rejected.
    corsDenied := func() http.Header {
        return http.Header{"Vary": {"Origin"}}
    }

    // createClient stores a client with secret "bar" and the given per-client
    // CORS origins. The error is deliberately ignored: several cases reuse the
    // same client ID, and unique violations from re-creating it are expected.
    createClient := func(r driver.Registry, id string, origins ...string) {
        _ = r.ClientManager().CreateClient(ctx, &client.Client{ID: id, Secret: "bar", AllowedCORSOrigins: origins})
    }

    for i, tc := range []struct {
        d            string                            // human-readable description, part of the subtest name
        prep         func(*testing.T, driver.Registry) // optional per-case setup (config values, clients, sessions)
        code         int                               // expected response status
        header       http.Header                       // request headers to send
        expectHeader http.Header                       // exact response headers expected
        method       string                            // request method, GET when empty
        body         io.Reader                         // optional request body
    }{
        {
            // No prep: CORS stays at its default configuration.
            d:            "should ignore when disabled",
            code:         http.StatusNotImplemented,
            header:       http.Header{},
            expectHeader: http.Header{},
        },
        {
            d: "should reject when basic auth but client does not exist and cors enabled",
            prep: func(t *testing.T, r driver.Registry) {
                r.Config().MustSet(ctx, "serve.public.cors.enabled", true)
                r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"})
            },
            code:         http.StatusNotImplemented,
            header:       http.Header{"Origin": {"http://foobar.com"}, "Authorization": {basicAuth("foo", "bar")}},
            expectHeader: corsDenied(),
        },
        {
            d: "should reject when post auth client exists but origin not allowed",
            prep: func(t *testing.T, r driver.Registry) {
                r.Config().MustSet(ctx, "serve.public.cors.enabled", true)
                r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"})
                createClient(r, "foo-2", "http://not-foobar.com")
            },
            code:         http.StatusNotImplemented,
            header:       http.Header{"Origin": {"http://foobar.com"}, "Content-Type": {"application/x-www-form-urlencoded"}},
            expectHeader: corsDenied(),
            method:       http.MethodPost,
            body:         bytes.NewBufferString(url.Values{"client_id": {"foo-2"}}.Encode()),
        },
        {
            d: "should accept when post auth client exists and origin allowed",
            prep: func(t *testing.T, r driver.Registry) {
                r.Config().MustSet(ctx, "serve.public.cors.enabled", true)
                r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"})
                createClient(r, "foo-3", "http://foobar.com")
            },
            code:         http.StatusNotImplemented,
            header:       http.Header{"Origin": {"http://foobar.com"}, "Content-Type": {"application/x-www-form-urlencoded"}},
            expectHeader: corsAllowed("http://foobar.com"),
            method:       http.MethodPost,
            body:         bytes.NewBufferString(url.Values{"client_id": {"foo-3"}}.Encode()),
        },
        {
            d: "should reject when basic auth client exists but origin not allowed",
            prep: func(t *testing.T, r driver.Registry) {
                r.Config().MustSet(ctx, "serve.public.cors.enabled", true)
                r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"})
                createClient(r, "foo-2", "http://not-foobar.com")
            },
            code:         http.StatusNotImplemented,
            header:       http.Header{"Origin": {"http://foobar.com"}, "Authorization": {basicAuth("foo-2", "bar")}},
            expectHeader: corsDenied(),
        },
        {
            // Leaves serve.public.cors.allowed_origins at its default value.
            d: "should accept when basic auth client exists and origin allowed and global allowed origins are unset",
            prep: func(t *testing.T, r driver.Registry) {
                r.Config().MustSet(ctx, "serve.public.cors.enabled", true)
                createClient(r, "foo-3", "http://foobar.com")
            },
            code:         http.StatusNotImplemented,
            header:       http.Header{"Origin": {"http://foobar.com"}, "Authorization": {basicAuth("foo-3", "bar")}},
            expectHeader: corsAllowed("http://foobar.com"),
        },
        {
            d: "should accept when basic auth client exists and origin allowed and global allowed origins are empty",
            prep: func(t *testing.T, r driver.Registry) {
                r.Config().MustSet(ctx, "serve.public.cors.enabled", true)
                r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{})
                createClient(r, "foo-3", "http://foobar.com")
            },
            code:         http.StatusNotImplemented,
            header:       http.Header{"Origin": {"http://foobar.com"}, "Authorization": {basicAuth("foo-3", "bar")}},
            expectHeader: corsAllowed("http://foobar.com"),
        },
        {
            d: "should accept when basic auth client exists and origin (with partial wildcard) is allowed per client",
            prep: func(t *testing.T, r driver.Registry) {
                r.Config().MustSet(ctx, "serve.public.cors.enabled", true)
                r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{})
                createClient(r, "foo-4", "http://*.foobar.com")
            },
            code:         http.StatusNotImplemented,
            header:       http.Header{"Origin": {"http://foo.foobar.com"}, "Authorization": {basicAuth("foo-4", "bar")}},
            expectHeader: corsAllowed("http://foo.foobar.com"),
        },
        {
            d: "should accept when basic auth client exists and wildcard origin is allowed per client",
            prep: func(t *testing.T, r driver.Registry) {
                r.Config().MustSet(ctx, "serve.public.cors.enabled", true)
                r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{})

                // NOTE(review): "foo-4" was already created by the previous case
                // with origin "http://*.foobar.com". If that client is still
                // stored, this call presumably fails with an (ignored) unique
                // violation and the "http://*" pattern is never exercised —
                // confirm and consider a dedicated client ID.
                createClient(r, "foo-4", "http://*")
            },
            code:         http.StatusNotImplemented,
            header:       http.Header{"Origin": {"http://foo.foobar.com"}, "Authorization": {basicAuth("foo-4", "bar")}},
            expectHeader: corsAllowed("http://foo.foobar.com"),
        },
        {
            d: "should accept when basic auth client exists and origin (with full wildcard) is allowed globally",
            prep: func(t *testing.T, r driver.Registry) {
                r.Config().MustSet(ctx, "serve.public.cors.enabled", true)
                r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{"*"})
                createClient(r, "foo-5", "http://barbar.com")
            },
            code:         http.StatusNotImplemented,
            header:       http.Header{"Origin": {"*"}, "Authorization": {basicAuth("foo-5", "bar")}},
            expectHeader: corsAllowed("*"),
        },
        {
            d: "should accept when basic auth client exists and origin (with partial wildcard) is allowed globally",
            prep: func(t *testing.T, r driver.Registry) {
                r.Config().MustSet(ctx, "serve.public.cors.enabled", true)
                r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{"http://*.foobar.com"})
                createClient(r, "foo-6", "http://barbar.com")
            },
            code:         http.StatusNotImplemented,
            header:       http.Header{"Origin": {"http://foo.foobar.com"}, "Authorization": {basicAuth("foo-6", "bar")}},
            expectHeader: corsAllowed("http://foo.foobar.com"),
        },
        {
            d: "should accept when basic auth client exists and origin (with full wildcard) allowed per client",
            prep: func(t *testing.T, r driver.Registry) {
                r.Config().MustSet(ctx, "serve.public.cors.enabled", true)
                r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"})
                createClient(r, "foo-7", "*")
            },
            code:         http.StatusNotImplemented,
            header:       http.Header{"Origin": {"http://foobar.com"}, "Authorization": {basicAuth("foo-7", "bar")}},
            expectHeader: corsAllowed("http://foobar.com"),
        },
        {
            d: "should succeed on pre-flight request when token introspection fails",
            prep: func(t *testing.T, r driver.Registry) {
                r.Config().MustSet(ctx, "serve.public.cors.enabled", true)
                r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"})
            },
            code:         http.StatusNotImplemented,
            header:       http.Header{"Origin": {"http://foobar.com"}, "Authorization": {"Bearer 1234"}},
            expectHeader: corsAllowed("http://foobar.com"),
            method:       http.MethodOptions,
        },
        {
            d: "should fail when token introspection fails",
            prep: func(t *testing.T, r driver.Registry) {
                r.Config().MustSet(ctx, "serve.public.cors.enabled", true)
                r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"})
            },
            code:         http.StatusNotImplemented,
            header:       http.Header{"Origin": {"http://foobar.com"}, "Authorization": {"Bearer 1234"}},
            expectHeader: corsDenied(),
        },
        {
            d: "should work when token introspection returns a session",
            prep: func(t *testing.T, r driver.Registry) {
                r.Config().MustSet(ctx, "serve.public.cors.enabled", true)
                r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{"http://not-test-domain.com"})

                sess := oauth2.NewSession("foo-9")
                sess.SetExpiresAt(fosite.AccessToken, time.Now().Add(time.Hour))
                ar := fosite.NewAccessRequest(sess)
                cl := &client.Client{ID: "foo-9", Secret: "bar", AllowedCORSOrigins: []string{"http://foobar.com"}}
                ar.Client = cl

                // Ignore unique violations
                _ = r.ClientManager().CreateClient(ctx, cl)

                // The session is stored under the signature of the token sent
                // in this case's Authorization header; the case cannot pass
                // without it, so a storage error must not be swallowed.
                require.NoError(t, r.OAuth2Storage().CreateAccessTokenSession(ctx, signature, ar))
            },
            code:         http.StatusNotImplemented,
            header:       http.Header{"Origin": {"http://foobar.com"}, "Authorization": {"Bearer " + token}},
            expectHeader: corsAllowed("http://foobar.com"),
        },
        {
            d: "should accept any allowed specified origin protocol",
            prep: func(t *testing.T, r driver.Registry) {
                r.Config().MustSet(ctx, "serve.public.cors.enabled", true)
                r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{"http://*", "https://*"})
                createClient(r, "foo-11", "*")
            },
            code:         http.StatusNotImplemented,
            header:       http.Header{"Origin": {"http://foo.foobar.com"}, "Authorization": {basicAuth("foo-11", "bar")}},
            expectHeader: corsAllowed("http://foo.foobar.com"),
        },
        {
            d: "should accept client origin when basic auth client exists and origin is set at the client as well as the server",
            prep: func(t *testing.T, r driver.Registry) {
                r.Config().MustSet(ctx, "serve.public.cors.enabled", true)
                r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{"http://**.example.com"})
                createClient(r, "foo-12", "http://myapp.example.biz")
            },
            code:         http.StatusNotImplemented,
            header:       http.Header{"Origin": {"http://myapp.example.biz"}, "Authorization": {basicAuth("foo-12", "bar")}},
            expectHeader: corsAllowed("http://myapp.example.biz"),
        },
        {
            d: "should accept server origin when basic auth client exists and origin is set at the client as well as the server",
            prep: func(t *testing.T, r driver.Registry) {
                r.Config().MustSet(ctx, "serve.public.cors.enabled", true)
                r.Config().MustSet(ctx, "serve.public.cors.allowed_origins", []string{"http://**.example.com"})
                createClient(r, "foo-13", "http://myapp.example.biz")
            },
            code:         http.StatusNotImplemented,
            header:       http.Header{"Origin": {"http://client-app.example.com"}, "Authorization": {basicAuth("foo-13", "bar")}},
            expectHeader: corsAllowed("http://client-app.example.com"),
        },
    } {
        t.Run(fmt.Sprintf("case=%d/description=%s", i, tc.d), func(t *testing.T) {
            // Start every case from default configuration so settings made by
            // an earlier case's prep do not carry over.
            r.WithConfig(internal.NewConfigurationWithDefaults())

            if tc.prep != nil {
                tc.prep(t, r)
            }

            method := http.MethodGet
            if tc.method != "" {
                method = tc.method
            }

            req, err := http.NewRequest(method, "http://foobar.com/", tc.body)
            require.NoError(t, err)
            for name := range tc.header {
                req.Header.Set(name, tc.header.Get(name))
            }

            res := httptest.NewRecorder()
            next := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
                w.WriteHeader(http.StatusNotImplemented)
            })
            r.OAuth2AwareMiddleware()(next).ServeHTTP(res, req)

            assert.EqualValues(t, tc.code, res.Code)
            assert.EqualValues(t, tc.expectHeader, res.Header())
        })
    }
}