alibaba/kt-connect

View on GitHub
pkg/kt/service/dns/dnsserver.go

Summary

Maintainability
A
2 hrs
Test Coverage
package dns

import (
    "fmt"
    "github.com/alibaba/kt-connect/pkg/common"
    opt "github.com/alibaba/kt-connect/pkg/kt/command/options"
    "github.com/alibaba/kt-connect/pkg/kt/service/cluster"
    "github.com/alibaba/kt-connect/pkg/kt/util"
    "github.com/miekg/dns"
    "github.com/rs/zerolog/log"
    "net"
    "regexp"
    "strconv"
    "strings"
    "time"
)

// DnsServer is the handler that SetupLocalDns passes to common.SetupDnsServer;
// its ServeDNS method answers every incoming request through query.
type DnsServer struct {
	// upstream DNS servers tried in order, each formatted "protocol:ip:port"
	// (query skips entries whose protocol is neither tcp nor udp)
	dnsAddresses []string
	// domain pattern -> IP answered locally with an A record; patterns may
	// contain '*' wildcards (matched by wildcardMatch)
	extraDomains map[string]string
}

// SetupLocalDns starts the local DNS server (udp, localDnsPort) in a background
// goroutine.
//
// Upstream addresses are derived from dnsOrder by getDnsAddresses, using the
// name server returned by GetNameServer for upstream entries and
// remoteDnsPort for the cluster entry. Ingress host names (see
// getIngressDomains) are answered locally and handed to HandleExtraDomainMapping.
//
// If common.SetupDnsServer returns within one second (the window also covers
// the address/ingress discovery above), its error is returned. Otherwise nil
// is returned; presumably the server is then still serving in the background.
func SetupLocalDns(remoteDnsPort, localDnsPort int, dnsOrder []string) error {
	// Buffered so the goroutine can always deliver its result and exit, even
	// after this function has stopped waiting. With an unbuffered channel a
	// server error arriving after the timeout would block the send forever and
	// leak the goroutine. Such a late error is discarded.
	res := make(chan error, 1)
	go func() {
		upstreamDnsAddresses := getDnsAddresses(dnsOrder, GetNameServer(), remoteDnsPort)
		// domain-name -> ip
		extraDomains := getIngressDomains()
		log.Info().Msgf("Setup local DNS with upstream %v", upstreamDnsAddresses)
		HandleExtraDomainMapping(extraDomains, localDnsPort)
		res <- common.SetupDnsServer(&DnsServer{
			dnsAddresses: upstreamDnsAddresses,
			extraDomains: extraDomains,
		}, localDnsPort, "udp")
	}()
	select {
	case err := <-res:
		return err
	case <-time.After(time.Second):
		return nil
	}
}

// getIngressDomains maps every non-empty host declared in the ingress rules of
// the configured namespace (opt Global.Namespace) to the configured ingress IP
// (opt Connect.IngressIp).
//
// It returns an empty map when no ingress IP is configured or the IP is
// invalid, and also when the ingresses cannot be listed. The last two cases
// are logged as warnings.
func getIngressDomains() map[string]string {
	ingressIp := opt.Get().Connect.IngressIp
	if ingressIp == "" {
		return map[string]string{}
	}
	if !util.IsValidIp(ingressIp) {
		// Pass the value as an argument: concatenating it into the format
		// string would misinterpret any '%' it contains.
		log.Warn().Msgf("Ingress Ip '%s' is invalid", ingressIp)
		return map[string]string{}
	}
	ingressDomains := make(map[string]string)
	ingresses, err := cluster.Ins().GetAllIngressInNamespace(opt.Get().Global.Namespace)
	if err != nil {
		log.Warn().Err(err).Msgf("Failed to find ingress instances")
		return ingressDomains
	}
	for _, ingress := range ingresses.Items {
		for _, rule := range ingress.Spec.Rules {
			if rule.Host == "" {
				continue
			}
			log.Debug().Msgf("Find ingress domain %s", rule.Host)
			ingressDomains[rule.Host] = ingressIp
		}
	}
	return ingressDomains
}

// getDnsAddresses converts the configured DNS order into "protocol:host:port"
// addresses, preserving order.
//
// Each dnsOrder entry is one of:
//   - util.DnsOrderCluster: the cluster DNS, reached over tcp at
//     common.Localhost:clusterDnsPort;
//   - the util.DnsOrderUpstream keyword, optionally prefixed with a
//     three-letter protocol (e.g. "tcp:") and/or suffixed with ":port". The
//     keyword is replaced by upstreamDns, and the entry is skipped (with a
//     warning) when upstreamDns is empty;
//   - any other address written as host, host:port, protocol:host or
//     protocol:host:port. A two-part entry whose second part is numeric is
//     treated as host:port.
//
// A missing protocol defaults to udp and a missing port to
// common.StandardDnsPort. Entries with too many ':' separators are skipped
// with a warning.
func getDnsAddresses(dnsOrder []string, upstreamDns string, clusterDnsPort int) []string {
	// Compiled once per call instead of once per entry. QuoteMeta keeps the
	// keyword literal, so compilation cannot realistically fail. If it ever
	// did, entries fall through to plain-address handling.
	upstreamPattern, patternErr := regexp.Compile(
		fmt.Sprintf("^([cdptu]{3}:)?%s(:[0-9]+)?$", regexp.QuoteMeta(util.DnsOrderUpstream)))
	var dnsAddresses []string
	for _, dnsAddr := range dnsOrder {
		// len(parts) == number of ':' separators + 1
		parts := strings.Split(dnsAddr, ":")
		switch {
		case dnsAddr == util.DnsOrderCluster:
			dnsAddresses = append(dnsAddresses, fmt.Sprintf("tcp:%s:%d", common.Localhost, clusterDnsPort))
		case patternErr == nil && upstreamPattern.MatchString(dnsAddr):
			if upstreamDns == "" {
				// Make the skipped entry visible instead of dropping it silently.
				log.Warn().Msgf("Skip upstream dns server %s, no upstream name server available", dnsAddr)
				continue
			}
			switch len(parts) {
			case 1: // keyword only
				dnsAddresses = append(dnsAddresses, fmt.Sprintf("udp:%s:%d", upstreamDns, common.StandardDnsPort))
			case 2: // keyword:port or protocol:keyword
				if _, err := strconv.Atoi(parts[1]); err == nil {
					dnsAddresses = append(dnsAddresses, fmt.Sprintf("udp:%s:%s", upstreamDns, parts[1]))
				} else {
					dnsAddresses = append(dnsAddresses, fmt.Sprintf("%s:%s:%d", parts[0], upstreamDns, common.StandardDnsPort))
				}
			case 3: // protocol:keyword:port
				dnsAddresses = append(dnsAddresses, fmt.Sprintf("%s:%s:%s", parts[0], upstreamDns, parts[2]))
			default:
				log.Warn().Msgf("Skip invalid upstream dns server %s", dnsAddr)
			}
		default:
			switch len(parts) {
			case 1: // host
				dnsAddresses = append(dnsAddresses, fmt.Sprintf("udp:%s:%d", dnsAddr, common.StandardDnsPort))
			case 2: // host:port or protocol:host
				if _, err := strconv.Atoi(parts[1]); err == nil {
					dnsAddresses = append(dnsAddresses, fmt.Sprintf("udp:%s", dnsAddr))
				} else {
					dnsAddresses = append(dnsAddresses, fmt.Sprintf("%s:%d", dnsAddr, common.StandardDnsPort))
				}
			case 3: // protocol:host:port
				dnsAddresses = append(dnsAddresses, dnsAddr)
			default:
				log.Warn().Msgf("Skip invalid dns server %s", dnsAddr)
			}
		}
	}
	return dnsAddresses
}

// ServeDNS handles one DNS request. It builds an authoritative reply to req,
// fills its answer section via query (using the server's upstream addresses
// and extra domain mappings) and writes it back through w. A failure to write
// the reply is only logged.
func (s *DnsServer) ServeDNS(w dns.ResponseWriter, req *dns.Msg) {
	reply := new(dns.Msg)
	reply.SetReply(req)
	reply.Authoritative = true
	reply.Answer = query(req, s.dnsAddresses, s.extraDomains)

	err := w.WriteMsg(reply)
	if err == nil {
		return
	}
	log.Warn().Err(err).Msgf("Failed to reply dns request")
}

// query resolves the first question of req and returns the answer records.
//
// Resolution order:
//  1. the answer cache (common.ReadCache) with opt Connect.DnsCacheTtl;
//  2. extraDomains: the first pattern matching the query name (wildcardMatch)
//     is answered locally. A and ANY queries get an A record, any other record
//     type gets an empty answer. Map iteration order is random, so overlapping
//     patterns should map to the same IP;
//  3. each dnsAddresses entry ("protocol:ip:port") in order. The first
//     non-empty answer is cached and returned. Malformed entries, or entries
//     whose protocol is neither tcp nor udp, are skipped.
//
// When nothing answers, an empty result is cached with a timestamp backdated
// by half the TTL (presumably so negative answers expire sooner) and an empty
// slice is returned. A request without questions yields an empty answer.
func query(req *dns.Msg, dnsAddresses []string, extraDomains map[string]string) []dns.RR {
	if len(req.Question) == 0 {
		// Nothing to resolve; also guards the Question[0] access below.
		return []dns.RR{}
	}
	domain := req.Question[0].Name
	qtype := req.Question[0].Qtype
	cacheTtl := int64(opt.Get().Connect.DnsCacheTtl)

	if answer := common.ReadCache(domain, qtype, cacheTtl); answer != nil {
		log.Debug().Msgf("Found domain %s (%d) in cache", domain, qtype)
		return answer
	}

	for host, ip := range extraDomains {
		if !wildcardMatch(host, domain) {
			continue
		}
		// Locally mapped names are answered with an A record only. For other
		// record types (AAAA, MX, HTTPS, ...) reply with an empty answer rather
		// than an A record that does not match the question.
		if qtype == dns.TypeA || qtype == dns.TypeANY {
			return []dns.RR{toARecord(domain, ip)}
		}
		return []dns.RR{}
	}

	for _, dnsAddr := range dnsAddresses {
		dnsParts := strings.SplitN(dnsAddr, ":", 3)
		if len(dnsParts) != 3 {
			// skip malformed dns address (indexing below would panic)
			continue
		}
		protocol, ip := dnsParts[0], dnsParts[1]
		port, err := strconv.Atoi(dnsParts[2])
		if ip == "" || err != nil || (protocol != "tcp" && protocol != "udp") {
			// skip invalid dns address
			continue
		}
		res, err := common.NsLookup(domain, qtype, protocol, fmt.Sprintf("%s:%d", ip, port))
		if res != nil && len(res.Answer) > 0 {
			// only cache non-empty results
			log.Debug().Msgf("Found domain %s (%d) in dns (%s:%d)", domain, qtype, ip, port)
			common.WriteCache(domain, qtype, res.Answer, time.Now().Unix())
			return res.Answer
		}
		if err != nil && !common.IsDomainNotExist(err) {
			// usually io timeout error
			log.Warn().Err(err).Msgf("Failed to lookup %s (%d) in dns (%s:%d)", domain, qtype, ip, port)
		}
	}
	log.Debug().Msgf("Empty answer for domain lookup %s (%d)", domain, qtype)
	common.WriteCache(domain, qtype, []dns.RR{}, time.Now().Unix()-cacheTtl/2)
	return []dns.RR{}
}

// wildcardMatch reports whether the query name targetDomain matches
// patternDomain.
//
// A trailing dot is appended to patternDomain when missing, so "foo.com"
// matches the fully-qualified "foo.com.". Each '*' matches any (possibly
// empty) run of characters, dots included. Every other character is literal,
// so no regex metacharacters leak through. Comparison is case-insensitive
// because DNS names are (RFC 4343).
func wildcardMatch(patternDomain, targetDomain string) bool {
	if !strings.HasSuffix(patternDomain, ".") {
		patternDomain += "."
	}
	pattern := strings.ToLower(patternDomain)
	target := strings.ToLower(targetDomain)
	if !strings.Contains(pattern, "*") {
		return pattern == target
	}
	// Glob match: the literal head must prefix the target, each middle segment
	// must appear in order (the leftmost occurrence is always a safe choice),
	// and the literal tail must end what remains.
	segments := strings.Split(pattern, "*")
	head, tail := segments[0], segments[len(segments)-1]
	if !strings.HasPrefix(target, head) {
		return false
	}
	rest := target[len(head):]
	for _, seg := range segments[1 : len(segments)-1] {
		idx := strings.Index(rest, seg)
		if idx < 0 {
			return false
		}
		rest = rest[idx+len(seg):]
	}
	return strings.HasSuffix(rest, tail)
}

// toARecord builds an IN-class A record that answers domain with ip. The TTL
// is a fixed 5 seconds, and Rdlength is preset to the 4-byte IPv4 payload.
func toARecord(domain, ip string) dns.RR {
	header := dns.RR_Header{
		Name:     domain,
		Rrtype:   dns.TypeA,
		Class:    dns.ClassINET,
		Ttl:      5,
		Rdlength: 4,
	}
	return &dns.A{Hdr: header, A: net.ParseIP(ip)}
}