ajbeach2/sqsworker

View on GitHub
sqsworker.go

Summary

Maintainability
A
35 mins
Test Coverage
B
89%
package sqsworker

import (
    "context"
    "fmt"
    "github.com/aws/aws-sdk-go/aws"
    "github.com/aws/aws-sdk-go/aws/awserr"
    "github.com/aws/aws-sdk-go/aws/session"
    "github.com/aws/aws-sdk-go/service/sns"
    "github.com/aws/aws-sdk-go/service/sns/snsiface"
    "github.com/aws/aws-sdk-go/service/sqs"
    "github.com/aws/aws-sdk-go/service/sqs/sqsiface"
    "go.uber.org/zap"
    "os"
    "runtime"
    "sync"
)

const (
	// DefaultWorkers is a default count of worker goroutines, each of which
	// runs the Processor.
	// NOTE(review): NewWorker does not read this constant; it falls back to
	// runtime.NumCPU() when WorkerConfig.Workers is 0.
	DefaultWorkers = 1

	// DefaultMaxNumberOfMessages is the MaxNumberOfMessages requested by each
	// SQS ReceiveMessage call.
	DefaultMaxNumberOfMessages = 10

	// DefaultVisibilityTimeout is the VisibilityTimeout, in seconds, requested
	// for messages returned by ReceiveMessage.
	DefaultVisibilityTimeout = 60

	// DefaultWaitTimeSeconds is the long-polling wait, in seconds, of each
	// ReceiveMessage call.
	DefaultWaitTimeSeconds = 20
)

// Processor is implemented by SQS consumers. The Worker calls Process once
// per received message, passing its context, the message, and the SNS
// publish input that carries the message's output.
//
// The publish input is nil when the Worker has neither a Callback nor a
// TopicArn. Otherwise its Message field points at a string the
// implementation may overwrite; setting Message to nil skips publishing.
// When Process returns a non-nil error the message is not deleted from the
// queue.
type Processor interface {
	Process(context.Context, *sqs.Message, *sns.PublishInput) error
}

// Callback, when set on a Worker, is invoked after every message has been
// handled, on failure as well as on success. It receives the Message
// pointer of the SNS publish input given to the Processor, and an error
// that is nil when the Processor and the queue deletion both succeeded.
type Callback func(*string, error)

// Worker consumes messages from an SQS queue with a pool of goroutines and
// can publish each message's output to an SNS topic. Build one with
// NewWorker, start it with Run and stop it with Close.
type Worker struct {
	// QueueURL is the SQS queue that is polled. TopicArn is the SNS topic
	// outputs are published to; publishing is skipped when it is empty.
	QueueURL, TopicArn string
	// Queue and Topic are the SQS and SNS clients used for all API calls.
	Queue sqsiface.SQSAPI
	Topic snsiface.SNSAPI
	// Session is the AWS session passed to NewWorker.
	Session *session.Session
	// Consumers is the number of consumer goroutines Run starts.
	Consumers int
	// Logger receives error and info logs; a nil Logger disables logging.
	Logger *zap.Logger
	// Processor handles each message; Callback, when non-nil, is told the
	// outcome of every message.
	Processor Processor
	Callback  Callback
	// Name is attached to every log entry as the "app" field.
	Name string
	// done is closed by Close to make Run shut down.
	done chan error
}

// WorkerConfig holds the settings passed to the NewWorker constructor.
type WorkerConfig struct {
	// QueueURL is the SQS queue to consume. When empty, NewWorker falls back
	// to the QUEUE_URL environment variable.
	QueueURL string
	// TopicArn is the SNS topic outputs are published to. When empty,
	// NewWorker falls back to the TOPIC_ARN environment variable.
	TopicArn string
	// Workers is the number of consumer goroutines. If it is 0, NewWorker
	// uses runtime.NumCPU().
	Workers   int
	Processor Processor
	Callback  Callback
	// Name is attached to log entries as the "app" field.
	Name string
	// Logger is used for worker logs; when nil, NewWorker creates a zap
	// production logger.
	Logger *zap.Logger
}

// logError logs msg at error level with the worker name and err attached.
// It is a no-op when the worker has no logger.
//
// msg is used as the log message itself. A field keyed "msg" would collide
// with the message key of zap's production encoder and emit duplicate JSON
// keys. zap.Error also tolerates a nil err, whereas err.Error() would panic.
func (w *Worker) logError(msg string, err error) {
	if w.Logger == nil {
		return
	}
	w.Logger.Error(msg,
		zap.String("app", w.Name),
		zap.Error(err),
	)
}

// logInfo logs msg at info level tagged with the worker name; it does
// nothing when no logger is configured.
func (w *Worker) logInfo(msg string) {
	if w.Logger == nil {
		return
	}
	w.Logger.Info(msg, zap.String("app", w.Name))
}

// deleteMessage removes a message from the queue, returning only the
// error reported by the SQS client.
func (w *Worker) deleteMessage(m *sqs.DeleteMessageInput) error {
	_, deleteErr := w.Queue.DeleteMessage(m)
	return deleteErr
}

// sendMessage publishes msg to SNS. Nothing is sent when the worker has no
// topic configured or when msg carries no Message body.
func (w *Worker) sendMessage(msg *sns.PublishInput) error {
	// The topic check comes first so msg is never dereferenced when
	// publishing is disabled.
	if w.TopicArn == "" || msg.Message == nil {
		return nil
	}
	_, publishErr := w.Topic.Publish(msg)
	return publishErr
}

// consumer handles messages from in until ctx is cancelled or in is closed.
//
// For each message it calls w.Processor. When processing succeeds, it
// publishes the optional SNS output via sendMessage and deletes the message
// from the queue. When processing fails, the message is not deleted.
// If w.Callback is set, it is invoked for every message with the output
// Message pointer and the error recorded last. That error is the Process
// error, or, if Process succeeded, the result of the delete; publish
// failures are only logged.
func (w *Worker) consumer(ctx context.Context, in chan *sqs.Message) {
	for {
		select {
		case <-ctx.Done():
			return
		case msg, ok := <-in:
			if !ok {
				// The producer closed the channel. A closed channel yields nil
				// messages, which would reach Process and then be
				// dereferenced (msg.ReceiptHandle) below.
				return
			}

			// Allocated per message. A single string shared across iterations
			// leaked one message's output into the next publish, and could
			// change a string already handed to Callback.
			var msgString string
			var sendInput *sns.PublishInput
			if w.Callback != nil || w.TopicArn != "" {
				sendInput = &sns.PublishInput{TopicArn: &w.TopicArn, Message: &msgString}
			}

			err := w.Processor.Process(ctx, msg, sendInput)
			if err == nil {
				if sendErr := w.sendMessage(sendInput); sendErr != nil {
					w.logError("send message failed!", sendErr)
				}
				err = w.deleteMessage(&sqs.DeleteMessageInput{
					QueueUrl:      &w.QueueURL,
					ReceiptHandle: msg.ReceiptHandle,
				})
				if err != nil {
					w.logError("delete message failed!", err)
				}
			} else {
				w.logError("handler failed!", err)
			}

			// sendInput is non-nil whenever Callback is set (see above).
			if w.Callback != nil {
				w.Callback(sendInput.Message, err)
			}
		}
	}
}

// producer long-polls the SQS queue and forwards every received message to
// out until ctx is cancelled. Receive errors are logged and the poll is
// retried.
func (w *Worker) producer(ctx context.Context, out chan *sqs.Message) {
	params := &sqs.ReceiveMessageInput{
		QueueUrl:            aws.String(w.QueueURL),
		MaxNumberOfMessages: aws.Int64(DefaultMaxNumberOfMessages),
		VisibilityTimeout:   aws.Int64(DefaultVisibilityTimeout),
		WaitTimeSeconds:     aws.Int64(DefaultWaitTimeSeconds),
	}

	for {
		select {
		case <-ctx.Done():
			return
		default:
		}

		// NOTE(review): the in-flight request is not bound to ctx, so
		// shutdown may wait for up to one long-poll interval.
		req, resp := w.Queue.ReceiveMessageRequest(params)
		if err := req.Send(); err != nil {
			w.logError("receive messages failed!", err)
			continue
		}

		for _, message := range resp.Messages {
			// Consumers stop reading once ctx is cancelled. A bare send would
			// then block forever, leaking this goroutine and preventing the
			// caller from closing out.
			select {
			case out <- message:
			case <-ctx.Done():
				return
			}
		}
	}
}

// Close tells a running Worker to stop by closing its done channel, which
// Run turns into cancellation of the producer and consumers. Call it at
// most once: closing an already-closed channel panics.
func (w *Worker) Close() {
	done := w.done
	close(done)
}

// Run starts one producer goroutine that long-polls SQS and w.Consumers
// consumer goroutines that handle the received messages. It then blocks
// until every consumer has returned; consumers are stopped by Close.
func (w *Worker) Run() {
	ctx, cancel := context.WithCancel(context.Background())
	// Always cancel, even when Run returns without Close having been called
	// (e.g. with zero consumers). This tells the producer and the watcher
	// goroutine below to stop instead of leaking.
	defer cancel()

	// make panics on a negative buffer size.
	buffer := w.Consumers
	if buffer < 0 {
		buffer = 0
	}
	messages := make(chan *sqs.Message, buffer)

	w.logInfo("Starting producer")
	go func() {
		w.producer(ctx, messages)
		// Only the sending side closes the channel.
		close(messages)
	}()

	// Translate Close (which closes w.done) into context cancellation.
	go func() {
		select {
		case <-w.done:
			cancel()
		case <-ctx.Done():
		}
	}()

	w.logInfo(fmt.Sprintf("Starting consumer with %d consumers", w.Consumers))
	var wg sync.WaitGroup
	for i := 0; i < w.Consumers; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			w.consumer(ctx, messages)
		}()
	}
	wg.Wait()
}

// CreateQueue creates an SQS queue called name using sqsc and returns the
// new queue's URL, or an empty string and the client's error.
func CreateQueue(name string, sqsc sqsiface.SQSAPI) (string, error) {
	input := &sqs.CreateQueueInput{QueueName: aws.String(name)}
	out, createErr := sqsc.CreateQueue(input)
	if createErr != nil {
		return "", createErr
	}
	url := *out.QueueUrl
	return url, nil
}

// GetOrCreateQueue returns the URL of the SQS queue called name. If SQS
// reports that the queue does not exist, it is created via CreateQueue.
// Any other GetQueueUrl error is returned with an empty URL.
func GetOrCreateQueue(name string, sqsc sqsiface.SQSAPI) (string, error) {
	queueOut, err := sqsc.GetQueueUrl(&sqs.GetQueueUrlInput{
		QueueName: aws.String(name),
	})
	if err != nil {
		if aerr, ok := err.(awserr.Error); ok && aerr.Code() == sqs.ErrCodeQueueDoesNotExist {
			return CreateQueue(name, sqsc)
		}
		// Do not touch queueOut on error: its QueueUrl may be nil (or the
		// output itself nil), and dereferencing it would panic.
		return "", err
	}

	return aws.StringValue(queueOut.QueueUrl), nil
}

// GetOrCreateTopic returns the ARN of the SNS topic called name, creating
// the topic if necessary. SNS CreateTopic returns the existing topic's ARN
// when the caller already owns a topic with that name. An empty name yields
// an empty ARN without any API call. On error the ARN is empty.
func GetOrCreateTopic(name string, snsc snsiface.SNSAPI) (string, error) {
	if name == "" {
		return "", nil
	}

	snsOut, err := snsc.CreateTopic(&sns.CreateTopicInput{
		Name: aws.String(name),
	})
	if err != nil {
		// snsOut (or its TopicArn) may be nil on failure; dereferencing it
		// would panic.
		return "", err
	}

	return aws.StringValue(snsOut.TopicArn), nil
}

// NewWorker builds a Worker from sess and wc. The SQS and SNS clients are
// created from sess, and unset settings fall back to defaults:
//   - an empty QueueURL uses the QUEUE_URL environment variable;
//   - an empty TopicArn uses the TOPIC_ARN environment variable;
//   - a Workers count of 0 (or below) uses runtime.NumCPU();
//   - a nil Logger is replaced by a zap production logger.
func NewWorker(sess *session.Session, wc WorkerConfig) *Worker {
	workers := wc.Workers
	if workers <= 0 {
		// A negative count would make Run panic when it sizes its message
		// channel, so treat it the same as "unset".
		workers = runtime.NumCPU()
	}

	logger := wc.Logger
	if logger == nil {
		// Best effort: if zap fails to build a logger, logger stays nil and
		// the Worker runs without logging (its log helpers check for nil).
		logger, _ = zap.NewProduction()
	}

	queueURL := wc.QueueURL
	if queueURL == "" {
		queueURL = os.Getenv("QUEUE_URL")
	}

	topicARN := wc.TopicArn
	if topicARN == "" {
		topicARN = os.Getenv("TOPIC_ARN")
	}

	// Keyed fields keep this literal correct if Worker's fields are reordered.
	return &Worker{
		QueueURL:  queueURL,
		TopicArn:  topicARN,
		Queue:     sqs.New(sess),
		Topic:     sns.New(sess),
		Session:   sess,
		Consumers: workers,
		Logger:    logger,
		Processor: wc.Processor,
		Callback:  wc.Callback,
		Name:      wc.Name,
		done:      make(chan error),
	}
}