0x4b53/amqp-rpc

View on GitHub
testing.go

Summary

Maintainability
A
0 mins
Test Coverage
A
91%
package amqprpc

import (
    "context"
    "encoding/json"
    "fmt"
    "net/http"
    "net/url"
    "sync"
    "time"

    amqp "github.com/rabbitmq/amqp091-go"
)

const (
    testURL                  = "amqp://guest:guest@localhost:5672/"
    serverAPITestURL         = "http://guest:guest@localhost:15672/api"
    defaultTestQueue         = "test.queue"
    testClientConnectionName = "test-client"
    testServerConnectionName = "test-server"
)

// MockAcknowledger is a mocked amqp.Acknowledger, useful for tests.
type MockAcknowledger struct {
    Acks    int
    Nacks   int
    Rejects int
    OnAckFn func() error
}

// Ack increases Acks.
func (ma *MockAcknowledger) Ack(_ uint64, _ bool) error {
    ma.Acks++

    if ma.OnAckFn != nil {
        return ma.OnAckFn()
    }

    return nil
}

// Nack increases Nacks.
func (ma *MockAcknowledger) Nack(_ uint64, _, _ bool) error {
    ma.Nacks++
    return nil
}

// Reject increases Rejects.
func (ma *MockAcknowledger) Reject(_ uint64, _ bool) error {
    ma.Rejects++
    return nil
}

// startAndWait will start s by running ListenAndServe, it will then block
// until the server is started.
func startAndWait(s *Server) func() {
    started := make(chan struct{})
    once := sync.Once{}

    s.OnStarted(func(_, _ *amqp.Connection, _, _ *amqp.Channel) {
        once.Do(func() {
            close(started)
        })
    })

    done := make(chan struct{})

    go func() {
        s.ListenAndServe()
        close(done)
    }()

    <-started

    return func() {
        s.Stop()
        <-done
    }
}

func deleteQueue(name string) {
    queueURL := fmt.Sprintf("%s/queues/%s/%s", serverAPITestURL, url.PathEscape("/"), url.PathEscape(name))

    req, err := http.NewRequest(http.MethodDelete, queueURL, http.NoBody)
    if err != nil {
        panic(err)
    }

    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        panic(err)
    }

    _ = resp.Body.Close()
}

func closeConnections(names ...string) {
    var (
        connectionsURL = fmt.Sprintf("%s/connections", serverAPITestURL)
        connections    []map[string]interface{}
    )

    // It takes a while (0.5s - 4s) for the management plugin to discover the
    // connections so we loop until we've found some.
    for i := 0; i < 20; i++ {
        resp, err := http.Get(connectionsURL) //nolint:gosec // This URL is fine!
        if err != nil {
            panic(err)
        }

        err = json.NewDecoder(resp.Body).Decode(&connections)
        if err != nil {
            panic(err)
        }

        _ = resp.Body.Close()

        if len(connections) == 0 {
            time.Sleep(time.Duration(i*100) * time.Millisecond)
            continue
        }

        break
    }

    for _, conn := range connections {
        // Should we close this connection?
        shouldRemove := false

        for _, name := range names {
            if conn["user_provided_name"] == name {
                shouldRemove = true
                break
            }
        }

        if !shouldRemove {
            continue
        }

        connName, ok := conn["name"].(string)
        if !ok {
            panic("name is not a string")
        }

        connectionURL := fmt.Sprintf("%s/connections/%s", serverAPITestURL, url.PathEscape(connName))

        req, err := http.NewRequest(http.MethodDelete, connectionURL, http.NoBody)
        if err != nil {
            panic(err)
        }

        resp, err := http.DefaultClient.Do(req)
        if err != nil {
            panic(err)
        }

        _ = resp.Body.Close()

        fmt.Println("closed", conn["user_provided_name"])
    }
}

func testServer() *Server {
    server := NewServer(testURL).
        WithDebugLogger(testLogFunc("debug")).
        WithErrorLogger(testLogFunc("ERROR")).
        WithDialConfig(
            amqp.Config{
                Properties: amqp.Table{
                    "connection_name": testServerConnectionName,
                },
            },
        )

    server.Bind(DirectBinding(defaultTestQueue, func(_ context.Context, rw *ResponseWriter, d amqp.Delivery) {
        fmt.Fprintf(rw, "Got message: %s", d.Body)
    }))

    return server
}

func testClient() *Client {
    return NewClient(testURL).
        WithDebugLogger(testLogFunc("debug")).
        WithErrorLogger(testLogFunc("ERROR")).
        WithDialConfig(
            amqp.Config{
                Properties: amqp.Table{
                    "connection_name": testClientConnectionName,
                },
            },
        )
}

func initTest() (server *Server, client *Client, start, stop func()) {
    deleteQueue(defaultTestQueue) // Ensure queue is clean from the start.

    server = testServer()
    client = testClient()

    var stopServer func()

    stop = func() {
        client.Stop()
        stopServer()
    }

    start = func() {
        stopServer = startAndWait(server)

        // Ensure the client is started.
        _, err := client.Send(
            NewRequest().
                WithRoutingKey(defaultTestQueue).
                WithResponse(false),
        )
        if err != nil {
            panic(err)
        }

        stop = func() {
            stopServer()
            client.Stop()
        }
    }

    return
}

func testLogFunc(prefix string) LogFunc {
    startTime := time.Now()

    return func(format string, args ...interface{}) {
        format = fmt.Sprintf("[%s] %s: %s\n", prefix, time.Since(startTime).String(), format)
        fmt.Printf(format, args...)
    }
}