pkg/stomp/broker.go
package stomp
import (
"errors"
"fmt"
"log"
"net"
"strconv"
"strings"
"sync"
"github.com/cenkalti/backoff"
"github.com/go-co-op/gocron"
"github.com/google/uuid"
)
// Session tracks the state of a single STOMP client connection: the transport
// connection, the identifier reported back in CONNECTED, the optional
// authentication hook, and the heartbeat settings negotiated for this client.
type Session struct {
	// Transport and identity.
	conn      net.Conn
	sessionID string

	// Authentication hook; when nil, CONNECT frames are not authenticated.
	loginFunc LoginFunc

	// Group released (Done) when the session is cleaned up; presumably the
	// owning server waits on it during shutdown — confirm at the call site.
	wgSessions *sync.WaitGroup

	// Heartbeat intervals in milliseconds (0 means disabled) and the
	// scheduled job that emits outgoing heartbeats, if one was started.
	hbSendIntervalMsec int
	hbRecvIntervalMsec int
	hbJob              *gocron.Job
}
// NewSession builds a Session for conn and assigns it a freshly generated
// session ID. The two heartbeat arguments are the broker's own preferred
// intervals in milliseconds; they are renegotiated with the client when it
// sends CONNECT. wg is kept and released when the session is cleaned up.
func NewSession(conn net.Conn, loginFunc LoginFunc, wg *sync.WaitGroup,
	heartbeatSendIntervalMsec, heartbeatReceiveIntervalMsec int,
) *Session {
	s := new(Session)
	s.sessionID = uuid.NewString()
	s.conn = conn
	s.loginFunc = loginFunc
	s.wgSessions = wg
	s.hbSendIntervalMsec = heartbeatSendIntervalMsec
	s.hbRecvIntervalMsec = heartbeatReceiveIntervalMsec
	return s
}
// LoginFunc is the user-supplied authentication hook called while handling a
// CONNECT/STOMP frame. It receives the client's login and passcode headers;
// returning a non-nil error rejects the connection.
type LoginFunc func(login, passcode string) error
// Start runs the session's read loop. Each raw frame produced by frameScanner
// is decoded, validated as a client frame and handed to the state machine.
//
// The loop ends when frameScanner stops producing frames, when a frame cannot
// be decoded or validated (an ERROR frame is sent to the client first), or
// when the state machine returns an error (which is logged). The session is
// always cleaned up on return.
func (sess *Session) Start() {
	defer sess.cleanup()
	for raw := range frameScanner(sess.conn) {
		frame, err := NewFrameFromBytes(raw)
		if err != nil {
			// frame may be nil when decoding fails, so report the raw bytes
			// rather than dereferencing it.
			_ = sess.sendError(err, "Frame serialization error:"+string(raw))
			return
		}
		if err = frame.Validate(ClientFrame); err != nil {
			_ = sess.sendError(err, "Frame validation error:"+frame.String())
			return
		}
		if err = sess.stateMachine(frame); err != nil {
			log.Println(err)
			return
		}
	}
}
// cleanup releases everything the session holds; Start defers it.
//
// Teardown order matters:
//  1. The heartbeat job is stopped first so it cannot write to a closed
//     connection.
//  2. This session's subscriptions are dropped. A client that disappears
//     without DISCONNECT would otherwise stay registered as a subscriber.
//  3. The connection is closed.
//  4. The WaitGroup is released last, so anyone waiting on it only proceeds
//     once the session is fully torn down.
func (sess *Session) cleanup() {
	if sess.hbJob != nil {
		sched.RemoveByReference(sess.hbJob)
		sess.hbJob = nil
	}
	// Best effort: DISCONNECT may already have removed them. The error is
	// ignored, as it is in the DISCONNECT handler.
	_ = cleanupSubscriptions(sess.sessionID)
	_ = sess.conn.Close()
	if sess.wgSessions != nil {
		sess.wgSessions.Done()
	}
}
// sendError writes an ERROR frame to the client. The error text is placed in
// the "message" header, and payload is sent as a text/plain body with a
// matching content-length. err must be non-nil.
func (sess *Session) sendError(err error, payload string) error {
	body := []byte(payload)
	headers := make(map[Header]string, 3)
	headers[HdrKeyContentType] = "text/plain"
	headers[HdrKeyContentLength] = strconv.Itoa(len(body))
	headers[HdrKeyMessage] = err.Error()
	return sess.send(CmdError, headers, body)
}
// stateMachine dispatches a validated client frame to the handler for its
// command. A non-nil error ends the session, because Start logs it and
// returns. Commands without a case here are ignored.
func (sess *Session) stateMachine(frame *Frame) error {
	switch frame.command {
	case CmdConnect, CmdStomp:
		return sess.handleConnect(frame)

	case CmdSend:
		// Messages inside a transaction are buffered until COMMIT/ABORT.
		if txID := frame.getHeader(HdrKeyTransaction); txID != "" {
			return bufferTxMessage(txID, frame)
		}
		// Not part of a transaction: deliver right away.
		return publish(frame, "")

	case CmdSubscribe:
		// Default to auto-ack; honour an explicit ack header when the client
		// sends one.
		ack := HdrValAckAuto
		if ackStr := frame.getHeader(HdrKeyAck); ackStr != "" {
			ack = AckMode(ackStr)
		}
		return addSubscription(frame.getHeader(HdrKeyDestination), frame.getHeader(HdrKeyID), ack, sess)

	case CmdUnsubscribe:
		return removeSubscription(frame.getHeader(HdrKeyID))

	case CmdAck:
		return processAck(frame.getHeader(HdrKeyID))

	case CmdNack:
		// NACK is accepted but not acted upon yet.
		// if err := processNack(frame.headers[HdrKeyID]); err != nil {
		// 	return err
		// }

	case CmdBegin:
		return startTx(frame.getHeader(HdrKeyTransaction))

	case CmdCommit:
		txID := frame.getHeader(HdrKeyTransaction)
		// Publish each buffered message of the transaction to its
		// subscribers; publish errors are handed back to foreachTx.
		return foreachTx(txID, func(frameTx *Frame) error {
			return publish(frameTx, txID)
		})

	case CmdAbort:
		return dropTx(frame.getHeader(HdrKeyTransaction))

	case CmdDisconnect:
		// Teardown is best effort: errors are ignored because the connection
		// is being closed regardless.
		_ = cleanupSubscriptions(sess.sessionID)
		// Only acknowledge with RECEIPT when the client asked for one.
		if receipt := frame.getHeader(HdrKeyReceipt); receipt != "" {
			_ = sess.send(CmdReceipt, map[Header]string{HdrKeyReceiptID: receipt}, nil)
		}
		_ = sess.conn.Close()
	}
	return nil
}
// sendMessage delivers a MESSAGE frame for subscription subsID on dest.
//
// The broker-assigned headers are set first: destination, a fresh message-id,
// subscription, ack, and transaction when txID is non-empty. The caller's
// headers are then copied in with lower-cased names and win on any collision.
func (sess *Session) sendMessage(dest, subsID string, ackNum uint32, txID string, headers map[Header]string,
	body []byte,
) error {
	out := make(map[Header]string, len(headers)+5)
	out[HdrKeyDestination] = dest
	out[HdrKeyMessageID] = uuid.NewString()
	out[HdrKeySubscription] = subsID
	out[HdrKeyAck] = fmtAckNum(dest, subsID, ackNum)
	if len(txID) > 0 {
		out[HdrKeyTransaction] = txID
	}
	for name, value := range headers {
		out[Header(strings.ToLower(string(name)))] = value
	}
	return sess.send(CmdMessage, out, body)
}
// sendRaw writes body to the client connection unchanged. It is used directly
// for heartbeat EOLs and by send for serialized frames.
//
// A failed write is logged and retried up to 3 times with exponential
// backoff. If a write fails after transferring part of the buffer, the retry
// resumes from the first unsent byte. Resending the whole buffer would
// duplicate bytes on the wire and corrupt the STOMP stream.
func (sess *Session) sendRaw(body []byte) error {
	remaining := body
	sendIt := func() error {
		n, err := sess.conn.Write(remaining)
		// io.Writer guarantees 0 <= n <= len(remaining).
		remaining = remaining[n:]
		if err != nil {
			log.Println(err)
			return err
		}
		return nil
	}
	return backoff.Retry(sendIt, backoff.WithMaxRetries(backoff.NewExponentialBackOff(), 3))
}
// send builds a frame from cmd, headers and body, validates it as a server
// frame, and writes it to the client through sendRaw, which logs and retries
// failed writes. If validation fails, the error is returned and nothing is
// written.
func (sess *Session) send(cmd Command, headers map[Header]string, body []byte) error {
	f := NewFrame(cmd, headers, body)
	// Make this check optional later
	if err := f.Validate(ServerFrame); err != nil {
		return err
	}
	// Serialize once; the retry logic is shared with heartbeat writes.
	return sess.sendRaw(f.Serialize())
}
// handleConnect answers a CONNECT/STOMP frame. In order, it:
//   - authenticates the client through loginFunc, when one is configured;
//   - checks that the client accepts STOMP 1.2;
//   - negotiates heartbeats when the client sent a heart-beat header;
//   - replies with CONNECTED.
//
// Authentication, version and heartbeat failures each send an ERROR frame to
// the client and return an error, so the session ends.
func (sess *Session) handleConnect(f *Frame) error {
	// Authentication
	if sess.loginFunc != nil {
		login, passcode := f.getHeader(HdrKeyLogin), f.getHeader(HdrKeyPassCode)
		if err := sess.loginFunc(login, passcode); err != nil {
			_ = sess.sendError(errors.New("login failed"), "Authentication failed:\n"+err.Error())
			return errorMsg(errBrokerStateMachine, "Login error: "+err.Error())
		}
	}

	// Version negotiation: accept-version is a comma-separated list, so
	// tolerate whitespace around entries (e.g. "1.1, 1.2").
	acceptVersion := f.getHeader(HdrKeyAcceptVersion)
	ver := ""
	for _, v := range strings.Split(acceptVersion, ",") {
		if strings.TrimSpace(v) == "1.2" {
			ver = "1.2"
			break
		}
	}
	if ver == "" {
		_ = sess.sendError(errors.New("unsupported protocol version"), "Supported protocol versions are 1.2")
		return errorMsg(errBrokerStateMachine, "Invalid client version received: "+acceptVersion)
	}

	// Heartbeat negotiation
	if hbVal := f.getHeader(HdrKeyHeartBeat); hbVal != "" {
		if err := sess.negotiateHeartbeats(hbVal); err != nil {
			_ = sess.sendError(errors.New("heartbeat negotiation failed"), err.Error())
			return errorMsg(errBrokerStateMachine, "Heartbeat negotiation: "+err.Error())
		}
	}

	// Respond with CONNECTED, advertising the (possibly negotiated) intervals.
	return sess.send(CmdConnected, map[Header]string{
		HdrKeyVersion:   ver,
		HdrKeySession:   sess.sessionID,
		HdrKeyServer:    "go-proto-stomp/" + releaseVersion,
		HdrKeyHeartBeat: fmt.Sprintf("%d,%d", sess.hbSendIntervalMsec, sess.hbRecvIntervalMsec),
	}, nil)
}
// negotiateHeartbeats combines the client's heart-beat header value, parsed
// by parseHbVal into the client's send and receive intervals, with the
// broker's own intervals. The results are stored back on the session.
//
// For each direction: if either side reports 0, heartbeats in that direction
// are disabled (0); otherwise the larger of the two values is used. When the
// broker->client interval is non-zero, a job is scheduled on sched that
// writes an EOL to the client every hbSendIntervalMsec milliseconds.
//
// Any job left over from an earlier CONNECT on this session is removed first,
// so repeated CONNECT frames do not accumulate heartbeat writers.
func (sess *Session) negotiateHeartbeats(hbVal string) error {
	clientSendInterval, clientRecvInterval, err := parseHbVal(hbVal)
	if err != nil {
		return err
	}
	if sess.hbJob != nil {
		sched.RemoveByReference(sess.hbJob)
		sess.hbJob = nil
	}
	// Send-HB negotiation (client sends, broker receives)
	if clientSendInterval == 0 || sess.hbRecvIntervalMsec == 0 {
		sess.hbRecvIntervalMsec = 0
	} else if clientSendInterval > sess.hbRecvIntervalMsec {
		sess.hbRecvIntervalMsec = clientSendInterval
	}
	// Receive-HB negotiation (broker sends, client receives)
	if clientRecvInterval == 0 || sess.hbSendIntervalMsec == 0 {
		sess.hbSendIntervalMsec = 0
		return nil // no heartbeats to be sent
	}
	if clientRecvInterval > sess.hbSendIntervalMsec {
		sess.hbSendIntervalMsec = clientRecvInterval
	}
	// Schedule sending heartbeats by hbSendIntervalMsec
	job, err := sched.Every(sess.hbSendIntervalMsec).Milliseconds().Tag(sess.sessionID).Do(
		func() {
			_ = sess.sendRaw([]byte("\n"))
		})
	if err != nil {
		return errorMsg(errBrokerStateMachine, "Heartbeat setup error: "+err.Error())
	}
	sess.hbJob = job
	sched.StartAsync()
	return nil
}
// Broker is the handle StartBroker returns for a running, transport-specific
// STOMP server.
type Broker interface {
	// ListenAndServe accepts client connections and serves STOMP on them.
	// It blocks for as long as the server keeps running.
	ListenAndServe()

	// Shutdown brings the underlying server down gracefully.
	Shutdown()
}
// BrokerOpts configures StartBroker. Zero-valued fields are replaced with the
// defaults listed on each field, and negative heartbeat intervals are treated
// as 0.
type BrokerOpts struct {
	// Transport selects the wire protocol STOMP runs over: TransportTCP or
	// TransportWebsocket. Defaults to TransportTCP.
	Transport Transport

	// Host is the host name or IP address the server binds to.
	// Defaults to "localhost".
	Host string

	// Port is the port the server listens on. Defaults to DefaultPort (61613).
	Port string

	// LoginFunc, when set, authenticates clients on CONNECT. Its shape is
	// `func(login, passcode string) error`. Defaults to nil (no authentication).
	LoginFunc LoginFunc

	// HeartbeatSendIntervalMsec is the shortest interval, in milliseconds, at
	// which the broker is willing to send heartbeats. It is the broker's
	// starting value when negotiating with each client. Defaults to 0, meaning
	// the broker sends no heartbeats.
	HeartbeatSendIntervalMsec int

	// HeartbeatReceiveIntervalMsec is the shortest interval, in milliseconds,
	// at which the broker can accept heartbeats. Clients are told not to send
	// them more often than this. Defaults to 0 (no heartbeats expected).
	HeartbeatReceiveIntervalMsec int
}
// StartBroker is the entry point for the STOMP broker. It does three things:
//   - fills in defaults for unset options in place: Host "localhost",
//     Port DefaultPort, Transport TransportTCP;
//   - clamps negative heartbeat intervals to 0;
//   - starts the server for the selected transport.
//
// A nil opts is treated as an all-default configuration. An unrecognised
// Transport is reported as an error rather than returning a nil Broker with a
// nil error.
func StartBroker(opts *BrokerOpts) (Broker, error) {
	if opts == nil {
		opts = &BrokerOpts{}
	}
	// Set default values (opts is updated in place).
	if opts.Host == "" {
		opts.Host = "localhost"
	}
	if opts.Port == "" {
		opts.Port = DefaultPort
	}
	if opts.Transport == "" {
		opts.Transport = TransportTCP
	}
	if opts.HeartbeatSendIntervalMsec < 0 {
		opts.HeartbeatSendIntervalMsec = 0
	}
	if opts.HeartbeatReceiveIntervalMsec < 0 {
		opts.HeartbeatReceiveIntervalMsec = 0
	}

	switch opts.Transport {
	case TransportTCP:
		tcp, err := startTcpBroker(opts)
		if err != nil {
			return nil, err
		}
		return tcp, nil
	case TransportWebsocket:
		wss, err := startWebsocketBroker(opts)
		if err != nil {
			return nil, err
		}
		return wss, nil
	default:
		return nil, fmt.Errorf("unsupported transport %q", opts.Transport)
	}
}