moleculer-go/moleculer

View on GitHub
transit/tcp/udp.go

Summary

Maintainability
A
1 hr
Test Coverage
package tcp

import (
    "fmt"
    "net"
    "strconv"
    "strings"
    "time"

    "golang.org/x/net/ipv4"

    "github.com/moleculer-go/moleculer"
    log "github.com/sirupsen/logrus"
)

// UdpServerEntry describes one bound UDP socket together with the hosts that
// discovery messages are sent to through that socket.
type UdpServerEntry struct {
    // conn is the socket used both to read incoming datagrams and to send
    // discovery messages.
    conn             *net.UDPConn
    // packetConn is the ipv4 wrapper returned when a multicast group was
    // joined; it is nil when the entry was started without a multicast group.
    packetConn       *ipv4.PacketConn
    // discoveryTargets holds the destination hosts (without port) that
    // BroadcastDiscoveryMessage writes to.
    discoveryTargets []string
}

// OnUdpMessage is the callback invoked for every well-formed discovery packet
// received in the server's namespace. nodeID and port are parsed from the
// packet payload; ip is the sender address of the datagram.
type OnUdpMessage func(nodeID, ip string, port int)

// UdpServer sends and receives node discovery messages over UDP, using either
// multicast or broadcast depending on its options.
type UdpServer struct {
    // state is set to STARTED by Start and to STOPPED by Close; the receive
    // loops run only while it is STARTED.
    state            State
    // opts is the configuration supplied to NewUdpServer.
    opts             UdpServerOptions
    // discoveryCounter is incremented on every BroadcastDiscoveryMessage call
    // and compared against opts.MaxDiscovery.
    discoveryCounter int
    logger           *log.Entry
    // discoverTimer drives periodic discovery; it is nil while periodic
    // discovery is not running.
    discoverTimer    *time.Ticker
    // servers collects one entry per socket opened by startServer.
    servers          []*UdpServerEntry
    // onUdpMessage is called for each accepted discovery packet.
    onUdpMessage     OnUdpMessage
    // registry supplies the local node whose ID and port are advertised.
    registry         moleculer.Registry
}

// UdpServerOptions configures a UdpServer.
type UdpServerOptions struct {
    // Port is the UDP port sockets are bound to and discovery messages are
    // sent to.
    Port           int
    // Multicast is the multicast group address; an empty value selects
    // broadcast mode instead.
    Multicast      string
    // MulticastTTL is applied to the multicast socket via SetMulticastTTL.
    MulticastTTL   int
    // BindAddress is the local IP to bind. In multicast mode an empty value
    // makes Start bind every local IPv4 address instead.
    BindAddress    string
    // BroadcastAddrs, when non-empty and Multicast is empty, lists the
    // destinations for discovery messages; otherwise the broadcast addresses
    // are derived from the local interfaces.
    BroadcastAddrs []string
    // DiscoverPeriod is multiplied by time.Second when the discovery ticker is
    // created, i.e. it is treated as a plain count of seconds.
    DiscoverPeriod time.Duration
    // MaxDiscovery, when greater than zero, stops periodic discovery once that
    // many discovery messages have been broadcast.
    MaxDiscovery   int
    // Discovery enables periodic discovery broadcasts.
    Discovery      bool
    // Namespace is the first field of every discovery message; packets carrying
    // a different namespace are ignored.
    Namespace      string
    // NodeID is not read by the code in this file; the advertised ID comes from
    // the registry's local node. NOTE(review): confirm whether it is used
    // elsewhere.
    NodeID         string
}

// NewUdpServer builds a UdpServer from the given options, registry, message
// callback and logger. No socket is opened until Start is called.
func NewUdpServer(opts UdpServerOptions, registry moleculer.Registry, onUdpMessage OnUdpMessage, logger *log.Entry) *UdpServer {
    server := new(UdpServer)
    server.logger = logger
    server.registry = registry
    server.onUdpMessage = onUdpMessage
    server.opts = opts
    return server
}

// startServer binds a UDP4 socket on ip:port and registers it in u.servers.
//
// When multicast is non-empty the socket additionally joins that multicast
// group (with the given TTL); otherwise it is used as a plain broadcast
// socket. discoveryTargets are the hosts that discovery messages will be sent
// to through this socket.
//
// It returns the first error encountered while resolving, binding or joining.
// On failure nothing is registered and the socket, if already opened, is
// closed again.
func (u *UdpServer) startServer(ip string, port int, multicast string, multicastTTL int, discoveryTargets []string) error {
    udpAddr, err := net.ResolveUDPAddr("udp4", fmt.Sprintf("%s:%d", ip, port))
    if err != nil {
        u.logger.Warnf("Unable to resolve UDP address: %s\n", err)
        return err
    }

    udpConn, err := net.ListenUDP("udp4", udpAddr)
    if err != nil {
        u.logger.Warnf("Unable to listen on UDP address: %s\n", err)
        return err
    }

    var packetConn *ipv4.PacketConn
    if multicast != "" {
        u.logger.Infof("UDP Multicast Server is listening on %s:%d. Membership:  %s \n", ip, port, multicast)
        packetConn, err = u.joinMulticastGroup(multicast, udpConn, multicastTTL, ip, port)
        if err != nil {
            u.logger.Error("Error joining multicast group:", err)
            // The socket is not registered in u.servers on this path, so
            // nothing else would ever close it.
            if closeErr := udpConn.Close(); closeErr != nil {
                u.logger.Warnf("Unable to close UDP socket after failed multicast join: %s", closeErr)
            }
            return err
        }
    } else {
        u.logger.Infof("UDP Broadcast Server is listening on %s:%d\n", ip, port)
    }

    u.servers = append(u.servers, &UdpServerEntry{
        conn:             udpConn,
        packetConn:       packetConn,
        discoveryTargets: discoveryTargets,
    })
    return nil
}

// joinMulticastGroup wraps udpConn in an ipv4.PacketConn, tries to join the
// multicast group on every network interface and sets the multicast TTL.
//
// A failed join on an individual interface is only traced; resolving the
// group address, listing the interfaces or setting the TTL failing aborts the
// call with that error. ip and port are used for the group address port and
// for logging.
func (u *UdpServer) joinMulticastGroup(multicast string, udpConn *net.UDPConn, multicastTTL int, ip string, port int) (*ipv4.PacketConn, error) {
    group, resolveErr := net.ResolveUDPAddr("udp4", multicast+":"+strconv.Itoa(port))
    if resolveErr != nil {
        u.logger.Warnf("Unable to resolve multicast address: %s\n", resolveErr)
        return nil, resolveErr
    }

    pc := ipv4.NewPacketConn(udpConn)

    nics, listErr := net.Interfaces()
    if listErr != nil {
        u.logger.Warnf("Unable to get network interfaces: %s\n", listErr)
        return nil, listErr
    }

    for idx := range nics {
        nic := &nics[idx]
        joinErr := pc.JoinGroup(nic, group)
        if joinErr == nil {
            continue
        }
        u.logger.Tracef("Unable to join multicast group on interface %s: %s\n", nic.Name, joinErr)
    }

    ttlErr := pc.SetMulticastTTL(multicastTTL)
    if ttlErr != nil {
        u.logger.Tracef("Unable to set multicast TTL: %s\n", ttlErr)
        return nil, ttlErr
    }

    u.logger.Infof("UDP Multicast Server is listening on %s:%d. Membership: %s\n", ip, port, multicast)
    return pc, nil
}

// getAllIPs returns the textual form of every IPv4 address assigned to any
// local network interface. Interfaces whose addresses cannot be read are
// logged and skipped; if the interface list itself cannot be read an empty
// slice is returned.
func (u *UdpServer) getAllIPs() []string {
    found := []string{}

    nics, err := net.Interfaces()
    if err != nil {
        u.logger.Error("Error getting interfaces:", err)
        return found
    }

    for _, nic := range nics {
        addrs, addrErr := nic.Addrs()
        if addrErr != nil {
            u.logger.Error("Error getting addresses for interface:", addrErr)
            continue
        }

        for _, raw := range addrs {
            var candidate net.IP
            if ipNet, ok := raw.(*net.IPNet); ok {
                candidate = ipNet.IP
            } else if ipAddr, ok := raw.(*net.IPAddr); ok {
                candidate = ipAddr.IP
            }

            // Keep only addresses that have an IPv4 representation.
            if candidate != nil && candidate.To4() != nil {
                text := candidate.String()
                found = append(found, text)
                u.logger.Debugf("Interface: %v, IP Address: %v", nic.Name, text)
            }
        }
    }
    return found
}

// Start opens the UDP sockets required by the configured mode, marks the
// server as STARTED, launches one receive loop per socket, schedules the
// first discovery message and starts periodic discovery.
//
// Mode selection:
//   - Multicast and BindAddress set: a single multicast socket on BindAddress.
//   - Multicast only: one multicast socket per local IPv4 address; a failure
//     on an individual address is logged and skipped.
//   - BroadcastAddrs set: one socket on BindAddress targeting those addresses.
//   - otherwise: one socket on BindAddress targeting the broadcast addresses
//     computed from the local interfaces.
//
// In the single-socket modes an error from opening the socket is returned and
// the server is not started. Previously those modes returned right after
// opening the socket, so the receive loops and discovery never ran.
func (u *UdpServer) Start() error {
    switch {
    case u.opts.Multicast != "" && u.opts.BindAddress != "":
        u.logger.Debug("Multicast + BindAddress options specified - Binding to a specific interface:", u.opts.BindAddress)
        // Bind only one interface
        if err := u.startServer(u.opts.BindAddress, u.opts.Port, u.opts.Multicast, u.opts.MulticastTTL, []string{u.opts.Multicast}); err != nil {
            return err
        }

    case u.opts.Multicast != "":
        // List all interfaces and the IP addresses of each interface.
        u.logger.Debug("Multicast option specified - listing all interfaces and the ip addresses of each interface")
        for _, ip := range u.getAllIPs() {
            u.logger.Debug("Starting UDP server on IP:", ip, "Port:", u.opts.Port, "Multicast:", u.opts.Multicast, "MulticastTTL:", u.opts.MulticastTTL)
            if err := u.startServer(ip, u.opts.Port, u.opts.Multicast, u.opts.MulticastTTL, []string{u.opts.Multicast}); err != nil {
                // Best effort: other addresses may still bind successfully.
                u.logger.Error("Error starting server on IP:", ip, err)
            }
        }

    case len(u.opts.BroadcastAddrs) > 0:
        u.logger.Debug("Starting UDP server on IP (BindAddress):", u.opts.BindAddress, "Port:", u.opts.Port, " BroadcastAddrs option specified - Broadcasting to the specified addresses - BroadcastAddrs:", u.opts.BroadcastAddrs)
        if err := u.startServer(u.opts.BindAddress, u.opts.Port, "", 0, u.opts.BroadcastAddrs); err != nil {
            return err
        }

    default:
        broadcastAddrss := u.getBroadcastAddresses()
        u.logger.Debug("Starting UDP server on IP (BindAddress):", u.opts.BindAddress, "Port:", u.opts.Port, "No Multicast or BroadcastAddrs options specified - Broadcasting to all interfaces - broadcastAddrss: ", broadcastAddrss)
        if err := u.startServer(u.opts.BindAddress, u.opts.Port, "", 0, broadcastAddrss); err != nil {
            return err
        }
    }

    u.state = STARTED

    go u.firstDiscoveryMessage()

    for _, server := range u.servers {
        go u.handleIncomingMessagesForServer(server)
    }

    u.startDiscovering()

    return nil
}

// getBroadcastAddresses returns the IPv4 broadcast address of every
// non-loopback IPv4 network configured on a local interface.
//
// Interfaces whose addresses cannot be read are logged and skipped; if the
// interface list itself cannot be read an empty slice is returned.
func (u *UdpServer) getBroadcastAddresses() []string {
    list := []string{}
    interfaces, err := net.Interfaces()
    if err != nil {
        u.logger.Error("Error getting network interfaces:", err)
        return list
    }
    for _, iface := range interfaces {
        addrs, err := iface.Addrs()
        if err != nil {
            u.logger.Error("Error getting addresses for interface:", iface.Name, err)
            continue
        }
        for _, addr := range addrs {
            ipnet, ok := addr.(*net.IPNet)
            if !ok || ipnet.IP.IsLoopback() {
                continue
            }
            ip := ipnet.IP.To4()
            if ip == nil {
                continue
            }

            // Normalise the mask to 4 bytes so it lines up with ip.
            mask := ipnet.Mask
            if len(mask) == net.IPv6len {
                mask = mask[net.IPv6len-net.IPv4len:]
            }
            if len(mask) != net.IPv4len {
                continue
            }

            // The broadcast address is the IP OR'ed with the inverted netmask.
            // The destination must be a 4-byte slice: net.IPv4() returns the
            // 16-byte form, and writing bytes 0..3 of that would overwrite the
            // IPv4-in-IPv6 prefix and produce an IPv6-looking string.
            broadcast := make(net.IP, net.IPv4len)
            for i := range ip {
                broadcast[i] = ip[i] | ^mask[i]
            }
            list = append(list, broadcast.String())
        }
    }
    return list
}

// firstDiscoveryMessage broadcasts a single discovery message after a delay
// of one second. Start runs it in its own goroutine.
func (u *UdpServer) firstDiscoveryMessage() {
    // Block for one second before announcing this node for the first time.
    <-time.After(time.Second)
    u.BroadcastDiscoveryMessage()
}

// handleIncomingMessagesForServer reads datagrams from server.conn for as long
// as the server state is STARTED and passes every valid discovery packet to
// the onUdpMessage callback.
//
// A packet is valid when it has the form "namespace|nodeID|port", the
// namespace equals u.opts.Namespace and port parses as an integer. Anything
// else is logged at debug level and dropped.
func (u *UdpServer) handleIncomingMessagesForServer(server *UdpServerEntry) {
    buffer := make([]byte, 2048)
    for u.state == STARTED {
        n, addr, err := server.conn.ReadFromUDP(buffer)
        if err != nil {
            // Close() changes the state and then closes the socket, which
            // unblocks the read with an error. That is the normal shutdown
            // path, so it must not be reported as a failure.
            if u.state != STARTED {
                break
            }
            u.logger.Errorln("Error reading from UDP:", err)
            continue
        }
        message := string(buffer[:n])
        u.logger.Debugf("Received message from %s: %s", addr.String(), message)

        // Parse the message
        parts := strings.Split(message, "|")
        if len(parts) != 3 {
            u.logger.Debugf("Malformed UDP packet received: %s", message)
            continue
        }

        namespace := parts[0]

        if namespace != u.opts.Namespace {
            u.logger.Debugf("Message received for a different namespace: %s", namespace)
            continue
        }

        nodeID := parts[1]
        port, err := strconv.Atoi(parts[2])
        if err != nil {
            u.logger.Debugf("Error parsing port number: %s", err)
            continue
        }

        u.onUdpMessage(nodeID, addr.IP.String(), port)
    }
    u.logger.Debug("handleIncomingMessagesForServer() stopped")
}

// startDiscovering starts the periodic broadcast of discovery messages.
//
// It does nothing when discovery is disabled, already running, or configured
// with a non-positive period. u.opts.DiscoverPeriod is interpreted as a number
// of seconds. When u.opts.MaxDiscovery is greater than zero, periodic
// discovery stops itself once that many messages have been broadcast.
func (u *UdpServer) startDiscovering() {
    if !u.opts.Discovery {
        u.logger.Info("UDP Discovery is disabled.")
        return
    }
    if u.discoverTimer != nil {
        u.logger.Warn("Discovery already started.")
        return
    }

    period := u.opts.DiscoverPeriod * time.Second
    if period <= 0 {
        // time.NewTicker panics on a non-positive interval.
        u.logger.Warnf("Invalid UDP discovery period (%d), periodic discovery not started.", int64(u.opts.DiscoverPeriod))
        return
    }

    // Keep a local reference: StopDiscovering sets u.discoverTimer to nil, so
    // the goroutine must not read the field again.
    ticker := time.NewTicker(period)
    u.discoverTimer = ticker
    go func() {
        for range ticker.C {
            u.BroadcastDiscoveryMessage()
            if u.opts.MaxDiscovery > 0 && u.discoveryCounter >= u.opts.MaxDiscovery {
                u.logger.Info("Discovery limit reached, stopping UDP discovery")
                u.StopDiscovering()
                // Stop does not close the ticker channel, so leave the loop
                // explicitly instead of blocking on it forever.
                return
            }
        }
    }()
}

// BroadcastDiscoveryMessage sends one "namespace|nodeID|port" announcement for
// the registry's local node to every discovery target of every open socket and
// increments the discovery counter. Failures for individual targets are logged
// and do not stop the remaining sends.
func (u *UdpServer) BroadcastDiscoveryMessage() {
    local := u.registry.GetLocalNode()
    message := fmt.Sprintf("%s|%s|%d", u.opts.Namespace, local.GetID(), local.GetPort())
    u.logger.Debug("Broadcasting discovery message:", message)
    u.discoveryCounter++

    payload := []byte(message)
    for _, entry := range u.servers {
        for _, host := range entry.discoveryTargets {
            dest, err := net.ResolveUDPAddr("udp4", host+":"+strconv.Itoa(u.opts.Port))
            if err != nil {
                u.logger.Error("Error resolving UDP address:", err)
                continue
            }

            _, err = entry.conn.WriteToUDP(payload, dest)
            if err != nil {
                u.logger.Error("Error broadcasting discovery message to:", dest, " error:", err)
                continue
            }
            u.logger.Debug("Discovery message sent to:", dest)
        }
    }
}

// StopDiscovering stops the periodic discovery ticker, if one is running, and
// clears it so that discovery can be started again later.
func (u *UdpServer) StopDiscovering() {
    u.logger.Info("Stopping UDP Discovery.")
    ticker := u.discoverTimer
    if ticker == nil {
        return
    }
    ticker.Stop()
    u.discoverTimer = nil
}

// Close marks the server as STOPPED, stops periodic discovery and closes every
// socket opened by the server. Errors from closing a socket are logged.
// The socket list is emptied afterwards so that a later Start does not reuse
// connections that are already closed.
func (u *UdpServer) Close() {
    u.logger.Info("Closing UDP Server.")
    // The state is changed before the sockets are closed so that the receive
    // loops observe STOPPED once their blocked read returns.
    u.state = STOPPED
    u.StopDiscovering()
    for _, server := range u.servers {
        if server.conn == nil {
            continue
        }
        if err := server.conn.Close(); err != nil {
            u.logger.Warnf("Error closing UDP socket: %s", err)
        }
    }
    u.servers = nil
}