alibaba/kt-connect

View on GitHub
pkg/kt/service/dns/dns_darwin.go

Summary

Maintainability
A
0 mins
Test Coverage
package dns

import (
    "fmt"
    "github.com/alibaba/kt-connect/pkg/common"
    opt "github.com/alibaba/kt-connect/pkg/kt/command/options"
    "github.com/alibaba/kt-connect/pkg/kt/service/cluster"
    "github.com/alibaba/kt-connect/pkg/kt/util"
    "github.com/rs/zerolog/log"
    "io/ioutil"
    "os"
    "os/signal"
    "strconv"
    "strings"
    "syscall"
)

const (
	// resolverDir is the directory in which per-domain resolver files are written
	// (the macOS /etc/resolver convention).
	resolverDir = "/etc/resolver"
	// ktResolverPrefix starts the name of every resolver file this package creates;
	// RestoreNameServer deletes only files carrying this prefix.
	ktResolverPrefix = "kt."
	// resolverComment is emitted as the first line of every generated resolver file.
	resolverComment = "# Generated by KtConnect"
)

// SetNameServer sets dns server records: it writes one resolver file under resolverDir for
// the cluster domain, for "svc", and for every namespace (or only the configured namespace
// when the namespace list cannot be fetched), all pointing at dnsServer.
//
// dnsServer is "ip" or "ip:port"; when the port part is missing or empty,
// common.StandardDnsPort is used.
//
// The call returns once the files have been written. A background goroutine keeps running
// and removes the files (via RestoreNameServer) after SIGINT or SIGTERM is received.
// The returned error is non-nil only when resolverDir cannot be created.
func SetNameServer(dnsServer string) error {
	dnsSignal := make(chan error)
	if err := util.CreateDirIfNotExist(resolverDir); err != nil {
		log.Error().Err(err).Msgf("Failed to create resolver dir")
		return err
	}
	go func() {
		var nsList []string
		namespaces, err := cluster.Ins().GetAllNamespaces()
		if err != nil {
			log.Info().Err(err).Msgf("Cannot list all namespaces, set dns for '%s' only", opt.Get().Global.Namespace)
			nsList = append(nsList, opt.Get().Global.Namespace)
		} else {
			for _, ns := range namespaces.Items {
				nsList = append(nsList, ns.Name)
			}
		}

		preferredDnsInfo := strings.Split(dnsServer, ":")
		dnsIp := preferredDnsInfo[0]
		dnsPort := strconv.Itoa(common.StandardDnsPort)
		// An empty port (e.g. "10.0.0.1:") would otherwise be written as a bare "port " line.
		if len(preferredDnsInfo) > 1 && preferredDnsInfo[1] != "" {
			dnsPort = preferredDnsInfo[1]
		}

		// Subscribe to termination signals before the first resolver file is written, so a
		// signal arriving while the files are being created is buffered in sigCh and still
		// triggers the deferred cleanup once setup finishes.
		sigCh := make(chan os.Signal, 1)
		signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM)
		defer RestoreNameServer()

		createResolverFile("local", opt.Get().Connect.ClusterDomain, dnsIp, dnsPort)
		createResolverFile("svc.local", "svc", dnsIp, dnsPort)
		for _, ns := range nsList {
			createResolverFile(fmt.Sprintf("%s.local", ns), ns, dnsIp, dnsPort)
		}
		// Unblock the caller: resolver files are in place.
		dnsSignal <- nil

		<-sigCh
	}()
	return <-dnsSignal
}

// HandleExtraDomainMapping handles extra domain changes: it writes a resolver file pointing
// at common.Localhost:localDnsPort for the last label of every key in extraDomains, and for
// every entry of the comma-separated opt.Get().Connect.IncludeDomains setting.
// Surrounding whitespace in IncludeDomains entries is ignored, and empty entries are skipped.
func HandleExtraDomainMapping(extraDomains map[string]string, localDnsPort int) {
	port := strconv.Itoa(localDnsPort)
	for _, suffix := range getAllDomainSuffixes(extraDomains) {
		createResolverFile(fmt.Sprintf("%s.local", suffix), suffix, common.Localhost, port)
	}
	for _, suffix := range strings.Split(opt.Get().Connect.IncludeDomains, ",") {
		// "a, b" must yield "b", not " b" (which would produce a resolver for a bogus domain).
		suffix = strings.TrimSpace(suffix)
		if len(suffix) > 0 {
			createResolverFile(fmt.Sprintf("%s.local", suffix), suffix, common.Localhost, port)
		}
	}
}

// RestoreNameServer removes the nameservers added by ktctl: every non-directory entry in
// resolverDir whose name starts with ktResolverPrefix is deleted. Failures are logged and
// never returned, so this is safe to call during shutdown.
func RestoreNameServer() {
	rd, err := ioutil.ReadDir(resolverDir)
	if err != nil {
		// A missing resolver directory simply means there is nothing to clean up.
		if !os.IsNotExist(err) {
			log.Warn().Err(err).Msgf("Failed to read resolver dir %s", resolverDir)
		}
		return
	}
	for _, f := range rd {
		if f.IsDir() || !strings.HasPrefix(f.Name(), ktResolverPrefix) {
			continue
		}
		if err := os.Remove(fmt.Sprintf("%s/%s", resolverDir, f.Name())); err != nil {
			log.Warn().Err(err).Msgf("Failed to remove resolver file %s", f.Name())
		}
	}
}

// createResolverFile (re)writes resolverDir/<ktResolverPrefix><postfix> with mode 0644 and a
// body of the form:
//
//	# Generated by KtConnect
//	domain <domain>
//	nameserver <dnsIp>
//	port <dnsPort>
//
// Failures are logged, not returned.
func createResolverFile(postfix, domain, dnsIp, dnsPort string) {
	resolverFile := fmt.Sprintf("%s/%s%s", resolverDir, ktResolverPrefix, postfix)
	// Remove any previous entry so WriteFile creates a fresh file with the intended mode.
	// This is done unconditionally instead of behind os.Stat: Stat follows symlinks, so a
	// dangling symlink at this path would survive and WriteFile would create its target.
	if err := os.Remove(resolverFile); err != nil && !os.IsNotExist(err) {
		log.Debug().Err(err).Msgf("Failed to remove existing resolver file %s", resolverFile)
	}
	resolverContent := fmt.Sprintf("%s\ndomain %s\nnameserver %s\nport %s\n",
		resolverComment, domain, dnsIp, dnsPort)
	if err := ioutil.WriteFile(resolverFile, []byte(resolverContent), 0644); err != nil {
		log.Warn().Err(err).Msgf("Failed to create resolver file of %s", domain)
	}
}

// getAllDomainSuffixes returns the distinct last labels (e.g. "com" for "a.b.com") of the
// keys of extraDomains. Keys without a dot are skipped. A single trailing root dot
// ("a.b.com.") is ignored, and keys whose last label is still empty are skipped.
// The order of the result is unspecified, and nil is returned when nothing qualifies.
func getAllDomainSuffixes(extraDomains map[string]string) []string {
	seen := make(map[string]struct{}, len(extraDomains))
	var suffixes []string
	for domain := range extraDomains {
		// Without this, "foo.com." would yield an empty suffix and a "kt..local" resolver.
		domain = strings.TrimSuffix(domain, ".")
		i := strings.LastIndex(domain, ".")
		if i < 0 {
			continue
		}
		suffix := domain[i+1:]
		if suffix == "" {
			continue
		}
		if _, dup := seen[suffix]; dup {
			continue
		}
		seen[suffix] = struct{}{}
		suffixes = append(suffixes, suffix)
	}
	return suffixes
}