owasp-amass/resolve

View on GitHub
resolvers.go

Summary

Maintainability
A
25 mins
Test Coverage
// Copyright © by Jeff Foley 2017-2024. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
// SPDX-License-Identifier: Apache-2.0

package resolve

import (
    "context"
    "errors"
    "io"
    "log"
    "net"
    "runtime"
    "sync"
    "time"

    "github.com/caffix/queue"
    "github.com/miekg/dns"
    "go.uber.org/ratelimit"
)

// Resolvers is a pool of DNS resolvers managed for brute forcing using random selection.
//
// Methods such as SetTimeout, QPS, and AddResolvers hold the embedded mutex while they
// read or update the pool's timeout, QPS, and resolver bookkeeping.
type Resolvers struct {
    sync.Mutex
    // done is closed by Stop; the background goroutines and Query select on it.
    done      chan struct{}
    // log is the pool logger. NewResolvers installs one that discards output.
    log       *log.Logger
    // conns is used to write queries (WriteMsg) and is closed by Stop.
    conns     *connections
    // pool holds the resolvers added through AddResolvers.
    pool      selector
    // rmap is the set of resolver IP strings already added, used to skip duplicates.
    rmap      map[string]struct{}
    // wildcards is not referenced in this part of the file.
    // NOTE(review): presumably per-domain wildcard detection state — confirm.
    wildcards map[string]*wildcard
    // queue holds requests submitted via Query until enforceMaxQPS dispatches them.
    queue     queue.Queue
    // resps is the queue handed to newConnections and drained by processResponses.
    resps     queue.Queue
    // qps is the pool-wide queries per second: the sum of the per-resolver values
    // accumulated by AddResolvers, or the value assigned by SetMaxQPS.
    qps       int
    // maxSet is true after SetMaxQPS received a positive value; AddResolvers then
    // leaves qps and rate untouched.
    maxSet    bool
    // rate is the pool-wide limiter taken by enforceMaxQPS; nil disables it.
    rate      ratelimit.Limiter
    // servRates is optional and stays nil unless SetRateTracker is called.
    servRates *RateTracker
    // detector is an extra resolver kept outside of pool; it is included alongside
    // the pool resolvers for timeout updates, expiry sweeps, and shutdown.
    // NOTE(review): presumably dedicated to wildcard detection — confirm.
    detector  *resolver
    // timeout is how long the pool waits for responses; it starts at DefaultTimeout.
    timeout   time.Duration
    // options is allocated by NewResolvers and not otherwise used in this part of the file.
    options   *ThresholdOptions
}

// resolver is a single DNS server in the pool together with the state needed
// to rate limit, send, and track the requests assigned to it.
type resolver struct {
    // done is closed by stop to end processRequests.
    done    chan struct{}
    // pool is the Resolvers that created this resolver; its conns are used for writes.
    pool    *Resolvers
    // queue holds the requests assigned to this resolver until processRequests sends them.
    queue   queue.Queue
    // xchgs tracks requests that were sent and are still awaiting a response.
    xchgs   *xchgMgr
    // address is the UDP address of the DNS server.
    address *net.UDPAddr
    // qps is the queries per second value provided when the resolver was created.
    qps     int
    // rate is the limiter built from qps and taken before each request is written.
    rate    ratelimit.Limiter
    // stats is allocated at creation.
    // NOTE(review): presumably updated by collectStats, which is defined elsewhere — confirm.
    stats   *stats
}

// initializeResolver creates a resolver for the DNS server at addr, limited to
// qps queries per second, and starts the goroutine that processes its requests.
// When addr cannot be split into host and port, the default DNS port 53 is
// appended. It returns nil if the address cannot be resolved to a UDP address.
func (r *Resolvers) initializeResolver(qps int, addr string) *resolver {
    // A failed split is treated as a missing port number.
    if _, _, splitErr := net.SplitHostPort(addr); splitErr != nil {
        addr = net.JoinHostPort(addr, "53")
    }

    udpAddr, resolveErr := net.ResolveUDPAddr("udp", addr)
    if resolveErr != nil {
        return nil
    }

    created := &resolver{
        done:    make(chan struct{}, 1),
        pool:    r,
        queue:   queue.NewQueue(),
        xchgs:   newXchgMgr(r.timeout),
        address: udpAddr,
        qps:     qps,
        rate:    ratelimit.New(qps),
        stats:   new(stats),
    }
    go created.processRequests()
    return created
}

// stop signals the resolver to shut down and fails every request that is
// still waiting for a response so that callers can return. Calling stop on a
// resolver whose done channel is already closed does nothing.
func (r *resolver) stop() {
    select {
    case <-r.done:
        // Already stopped.
    default:
        // Closing done ends processRequests.
        close(r.done)
        // Answer everything still tracked as in flight.
        for _, pending := range r.xchgs.removeAll() {
            pending.errNoResponse()
            pending.release()
        }
    }
}

// NewResolvers returns an empty resolver pool with a discarding logger, a
// random selector, and the default timeout. It also starts the background
// goroutines that expire requests, enforce the pool QPS, perform threshold
// checks, and process responses.
func NewResolvers() *Resolvers {
    // The same queue is shared by the connections and the response processor.
    respQueue := queue.NewQueue()

    p := new(Resolvers)
    p.done = make(chan struct{}, 1)
    p.log = log.New(io.Discard, "", 0)
    p.conns = newConnections(runtime.NumCPU(), respQueue)
    p.pool = newRandomSelector()
    p.rmap = make(map[string]struct{})
    p.wildcards = make(map[string]*wildcard)
    p.queue = queue.NewQueue()
    p.resps = respQueue
    p.timeout = DefaultTimeout
    p.options = new(ThresholdOptions)

    workers := []func(){
        p.timeouts,
        p.enforceMaxQPS,
        p.thresholdChecks,
        p.processResponses,
    }
    for _, worker := range workers {
        go worker()
    }
    return p
}

// Len reports how many resolvers are currently held by the pool's selector.
func (r *Resolvers) Len() int {
    count := r.pool.Len()
    return count
}

// SetLogger assigns a new logger to the resolver pool. A nil logger is
// replaced by one that discards its output, matching the logger installed by
// NewResolvers, so that later logging calls never dereference a nil pointer.
func (r *Resolvers) SetLogger(l *log.Logger) {
    if l == nil {
        l = log.New(io.Discard, "", 0)
    }
    r.log = l
}

// SetRateTracker assigns the RateTracker used by the pool. While a tracker is
// set, Query calls Take on it for the queried name, and the response and
// timeout paths report Success and Timeout to it.
func (r *Resolvers) SetRateTracker(rt *RateTracker) {
    r.servRates = rt
}

// SetTimeout changes how long the pool waits for response messages and, while
// holding the pool lock, pushes the new value to every resolver that has not
// been stopped.
func (r *Resolvers) SetTimeout(d time.Duration) {
    r.Lock()
    defer r.Unlock()

    r.timeout = d
    // Existing resolvers keep their own timeout and must be told about the change.
    r.updateResolverTimeouts()
}

// updateResolverTimeouts applies the pool timeout to the exchange manager of
// every pool resolver, and of the detector when one is set, skipping any
// resolver that has already been stopped. It reads r.detector and r.timeout
// directly; SetTimeout calls it while holding the pool lock.
func (r *Resolvers) updateResolverTimeouts() {
    targets := r.pool.AllResolvers()
    if det := r.detector; det != nil {
        targets = append(targets, det)
    }

    for _, target := range targets {
        select {
        case <-target.done:
            // Stopped resolvers are left alone.
            continue
        default:
        }
        target.xchgs.setTimeout(r.timeout)
    }
}

// QPS returns the pool's current queries per second value, read while
// holding the pool lock.
func (r *Resolvers) QPS() int {
    r.Lock()
    defer r.Unlock()

    current := r.qps
    return current
}

// SetMaxQPS sets a preferred maximum number of queries per second for the pool.
// A positive qps installs a pool-wide rate limiter and marks the maximum as
// set, which stops AddResolvers from adjusting the pool QPS. Any other value
// clears the maximum and removes the pool-wide limiter.
//
// NOTE(review): unlike QPS and AddResolvers, this method does not take the
// pool lock before writing qps, maxSet, and rate.
func (r *Resolvers) SetMaxQPS(qps int) {
    r.qps = qps
    r.maxSet = qps > 0

    if !r.maxSet {
        r.rate = nil
        return
    }
    r.rate = ratelimit.New(qps)
}

// AddResolvers initializes and adds new resolvers to the pool of resolvers.
//
// Each address may be given with or without a port; port 53 is used when none
// is present. Addresses that cannot be parsed or resolved, and addresses whose
// IP is already in the pool, are skipped. Unless a maximum was fixed through
// SetMaxQPS, the pool QPS grows by qps for every resolver added and the
// pool-wide rate limiter is rebuilt from the new total.
//
// An error is returned only when qps is not greater than zero.
func (r *Resolvers) AddResolvers(qps int, addrs ...string) error {
    r.Lock()
    defer r.Unlock()

    // Negative values are rejected as well: the error message already promises
    // "greater than zero", and the value is used to build rate limiters.
    if qps <= 0 {
        return errors.New("failed to provide a maximum number of queries per second greater than zero")
    }

    for _, addr := range addrs {
        if _, _, err := net.SplitHostPort(addr); err != nil {
            // add the default port number to the IP address
            addr = net.JoinHostPort(addr, "53")
        }

        host, _, err := net.SplitHostPort(addr)
        if err != nil {
            continue
        }
        // rmap is keyed by the IP's canonical string form (see the insert
        // below), so normalise the host before the duplicate check; otherwise a
        // differently spelled IPv6 address would be added twice.
        if ip := net.ParseIP(host); ip != nil {
            host = ip.String()
        }
        if _, found := r.rmap[host]; found {
            continue
        }

        res := r.initializeResolver(qps, addr)
        if res == nil {
            continue
        }
        r.rmap[res.address.IP.String()] = struct{}{}
        r.pool.AddResolver(res)
        if !r.maxSet {
            r.qps += qps
        }
    }

    // Create the new rate limiter for the updated QPS. The total can still be
    // zero when no resolver was added; ratelimit.New divides by the rate, so it
    // must not be called with a non-positive value.
    if !r.maxSet && r.qps > 0 {
        r.rate = ratelimit.New(r.qps)
    }
    return nil
}

// Stop releases the resources of the resolver pool and of all added resolvers.
// It closes the done channel, stops the rate tracker when one is set, closes
// the connections, stops every pool resolver plus the detection resolver, and
// finally closes the selector. Calling Stop after the done channel has been
// closed does nothing.
func (r *Resolvers) Stop() {
    select {
    case <-r.done:
        return
    default:
    }

    close(r.done)
    if tracker := r.servRates; tracker != nil {
        tracker.Stop()
    }
    r.conns.Close()

    targets := r.pool.AllResolvers()
    if det := r.getDetectionResolver(); det != nil {
        targets = append(targets, det)
    }

    for _, target := range targets {
        // Without a fixed maximum, the pool QPS is reduced by each resolver's share.
        if !r.maxSet {
            r.qps -= target.qps
        }
        target.stop()
    }
    r.pool.Close()
}

// Query queues the provided DNS message and returns the response on the provided channel.
//
// A nil msg is sent straight back on ch. When the context is done, the pool
// has been stopped, or msg carries no question, msg is sent back on ch with
// its Rcode set to RcodeNoResponse. These early answers are sent
// synchronously, so ch needs a buffer or a ready receiver.
func (r *Resolvers) Query(ctx context.Context, msg *dns.Msg, ch chan *dns.Msg) {
    if msg == nil {
        ch <- msg
        return
    }

    // The rate tracker, the write-failure path, and the timeout handling all
    // index msg.Question[0]; a message without a question would panic there,
    // so it is rejected here instead of being queued.
    if len(msg.Question) > 0 {
        select {
        case <-ctx.Done():
        case <-r.done:
        default:
            req := reqPool.Get().(*request)

            req.Msg = msg
            req.Result = ch
            if r.servRates != nil {
                r.servRates.Take(msg.Question[0].Name)
            }
            r.queue.Append(req)
            return
        }
    }

    msg.Rcode = RcodeNoResponse
    ch <- msg
}

// QueryChan queues the provided DNS message through Query and returns a
// channel, buffered for one message, on which the response will be sent.
func (r *Resolvers) QueryChan(ctx context.Context, msg *dns.Msg) chan *dns.Msg {
    result := make(chan *dns.Msg, 1)

    r.Query(ctx, msg, result)
    return result
}

// QueryBlocking queues the provided DNS message and waits for the associated
// response message. If the context is already done, msg is returned together
// with an error and nothing is queued. A nil response yields a "query failed"
// error.
func (r *Resolvers) QueryBlocking(ctx context.Context, msg *dns.Msg) (*dns.Msg, error) {
    // Err is non-nil exactly when the context's Done channel is closed.
    if ctx.Err() != nil {
        return msg, errors.New("the context expired")
    }

    answer := <-r.QueryChan(ctx, msg)
    if answer == nil {
        return nil, errors.New("query failed")
    }
    return answer, nil
}

// enforceMaxQPS moves requests from the pool queue to a resolver selected
// from the pool, taking from the pool-wide rate limiter first when one is
// set. A request is failed with a no-response answer when the pool has no
// resolver to offer. Once the pool is stopped, every request still waiting on
// the pool queue is failed the same way so that its caller can return.
func (r *Resolvers) enforceMaxQPS() {
loop:
    for {
        select {
        case <-r.done:
            break loop
        case <-r.queue.Signal():
            element, found := r.queue.Next()
            if !found {
                continue loop
            }

            if r.rate != nil {
                _ = r.rate.Take()
            }

            if req, ok := element.(*request); ok {
                if res := r.pool.GetResolver(); res != nil {
                    req.Res = res
                    res.queue.Append(req)
                } else {
                    req.errNoResponse()
                    req.release()
                }
            }
        }
    }
    // Release the requests remaining on the queue. Query appends *request
    // values, so the assertion must use the pointer type; asserting the value
    // type never matches and would leave these callers without an answer.
    r.queue.Process(func(element interface{}) {
        if req, ok := element.(*request); ok {
            req.errNoResponse()
            req.release()
        }
    })
}

// processResponses waits for the response queue to signal and hands each
// queued response to processSingleResp in its own goroutine. It returns once
// the pool has been stopped.
func (r *Resolvers) processResponses() {
    dispatch := func(element interface{}) {
        response, ok := element.(*resp)
        if !ok || response == nil {
            return
        }
        go r.processSingleResp(response)
    }

    for {
        select {
        case <-r.done:
            return
        case <-r.resps.Signal():
            r.resps.Process(dispatch)
        }
    }
}

// processSingleResp matches a received response to the request that is
// waiting for it and delivers the answer.
//
// The sending resolver is found by source IP, first in the pool and then by
// comparing against the detection resolver; responses from unknown sources
// are dropped. A truncated response is retried over TCP in a new goroutine.
// Any other response is sent on the request's result channel, recorded in the
// resolver stats, reported to the rate tracker when one is set, and the
// request is then released.
func (r *Resolvers) processSingleResp(response *resp) {
    // The message comes from the network; the code below indexes its first
    // question, so a missing message or an empty question section is dropped
    // here rather than allowed to panic.
    msg := response.Msg
    if msg == nil || len(msg.Question) == 0 {
        return
    }

    var res *resolver
    // A split failure leaves addr empty, which is expected to match no resolver.
    addr, _, _ := net.SplitHostPort(response.Addr.String())

    if res = r.pool.LookupResolver(addr); res == nil {
        if detector := r.getDetectionResolver(); detector != nil {
            if detector.address.IP.String() == addr {
                res = detector
            }
        }
    }
    if res == nil {
        return
    }

    name := msg.Question[0].Name
    if req := res.xchgs.remove(msg.Id, name); req != nil {
        req.Resp = msg
        if req.Resp.Truncated {
            go req.Res.tcpExchange(req)
        } else {
            req.Result <- req.Resp
            req.Res.collectStats(req.Resp)
            if r.servRates != nil {
                r.servRates.Success(name)
            }
            req.release()
        }
    }
}

// timeouts periodically fails the requests that have waited too long for a
// response. The sweep interval is half of the pool timeout as read when the
// goroutine starts. Every expired request is answered with a no-response
// message, recorded in the resolver stats, reported to the rate tracker when
// one is set, and released. The goroutine returns on the first tick, or the
// first resolver visited, after the pool has been stopped.
func (r *Resolvers) timeouts() {
    r.Lock()
    interval := r.timeout / 2
    r.Unlock()

    ticker := time.NewTicker(interval)
    defer ticker.Stop()

    for range ticker.C {
        select {
        case <-r.done:
            return
        default:
        }

        targets := r.pool.AllResolvers()
        if det := r.getDetectionResolver(); det != nil {
            targets = append(targets, det)
        }

        for _, target := range targets {
            // Check for shutdown again before each resolver is swept.
            select {
            case <-r.done:
                return
            default:
            }

            for _, expired := range target.xchgs.removeExpired() {
                expired.errNoResponse()
                target.collectStats(expired.Msg)
                if r.servRates != nil {
                    r.servRates.Timeout(expired.Msg.Question[0].Name)
                }
                expired.release()
            }
        }
    }
}

// processRequests waits for the resolver's queue to signal and, for every
// queued request, takes from the resolver's rate limiter before writing the
// request in a separate goroutine. It returns once the resolver is stopped.
func (r *resolver) processRequests() {
    send := func(element interface{}) {
        req, ok := element.(*request)
        if !ok || req == nil {
            return
        }
        // Take blocks as needed to keep this resolver within its QPS.
        _ = r.rate.Take()
        go r.writeReq(req)
    }

    for {
        select {
        case <-r.done:
            return
        case <-r.queue.Signal():
            r.queue.Process(send)
        }
    }
}

// writeReq timestamps the request, registers it with the exchange manager,
// and writes a copy of its message to the resolver's address. A request that
// cannot be registered or written is answered with a no-response message and
// released so that its caller is never left waiting.
func (r *resolver) writeReq(req *request) {
    msg := req.Msg.Copy()
    req.Timestamp = time.Now()

    // If the request cannot be tracked, no response could ever be matched to
    // it; answer it now instead of silently dropping it.
    // NOTE(review): assumes add leaves a rejected request untouched — confirm.
    if r.xchgs.add(req) != nil {
        req.errNoResponse()
        req.release()
        return
    }

    if err := r.pool.conns.WriteMsg(msg, r.address); err != nil {
        // Only answer when the request was still tracked. A nil result means
        // that another path (a matching response, the timeout sweep, or stop)
        // removed it first, and those paths answer and release the request
        // themselves.
        if r.xchgs.remove(msg.Id, msg.Question[0].Name) != nil {
            req.errNoResponse()
            req.release()
        }
    }
}

// tcpExchange repeats the request's query against the resolver over TCP with
// a one minute timeout. A successful reply is sent on the request's result
// channel and recorded in the resolver stats; a failure is answered with a
// no-response message. The request is released in both cases.
func (r *resolver) tcpExchange(req *request) {
    tcpClient := dns.Client{
        Net:     "tcp",
        Timeout: time.Minute,
    }

    reply, _, err := tcpClient.Exchange(req.Msg, r.address.String())
    if err != nil {
        req.errNoResponse()
    } else {
        req.Result <- reply
        r.collectStats(reply)
    }
    req.release()
}