xjasonlyu/tun2socks

View on GitHub
proxy/relay.go

Summary

Maintainability
A
2 hrs
Test Coverage
package proxy

import (
    "bytes"
    "context"
    "encoding/binary"
    "errors"
    "fmt"
    "io"
    "math"
    "net"
    "sync"

    "github.com/go-gost/relay"

    "github.com/xjasonlyu/tun2socks/v2/common/pool"
    "github.com/xjasonlyu/tun2socks/v2/dialer"
    M "github.com/xjasonlyu/tun2socks/v2/metadata"
    "github.com/xjasonlyu/tun2socks/v2/proxy/proto"
)

var _ Proxy = (*Relay)(nil)

type Relay struct {
    *Base

    user string
    pass string

    noDelay bool
}

func NewRelay(addr, user, pass string, noDelay bool) (*Relay, error) {
    return &Relay{
        Base: &Base{
            addr:  addr,
            proto: proto.Relay,
        },
        user:    user,
        pass:    pass,
        noDelay: noDelay,
    }, nil
}

func (rl *Relay) DialContext(ctx context.Context, metadata *M.Metadata) (c net.Conn, err error) {
    return rl.dialContext(ctx, metadata)
}

func (rl *Relay) DialUDP(metadata *M.Metadata) (net.PacketConn, error) {
    ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout)
    defer cancel()

    return rl.dialContext(ctx, metadata)
}

func (rl *Relay) dialContext(ctx context.Context, metadata *M.Metadata) (rc *relayConn, err error) {
    var c net.Conn

    c, err = dialer.DialContext(ctx, "tcp", rl.Addr())
    if err != nil {
        return nil, fmt.Errorf("connect to %s: %w", rl.Addr(), err)
    }
    setKeepAlive(c)

    defer func(c net.Conn) {
        safeConnClose(c, err)
    }(c)

    req := relay.Request{
        Version: relay.Version1,
        Cmd:     relay.CmdConnect,
    }

    if metadata.Network == M.UDP {
        req.Cmd |= relay.FUDP
        req.Features = append(req.Features, &relay.NetworkFeature{
            Network: relay.NetworkUDP,
        })
    }

    if rl.user != "" {
        req.Features = append(req.Features, &relay.UserAuthFeature{
            Username: rl.user,
            Password: rl.pass,
        })
    }

    req.Features = append(req.Features, serializeRelayAddr(metadata))

    if rl.noDelay {
        if _, err = req.WriteTo(c); err != nil {
            return
        }
        if err = readRelayResponse(c); err != nil {
            return
        }
    }

    switch metadata.Network {
    case M.TCP:
        rc = newRelayConn(c, metadata.Addr(), rl.noDelay, false)
        if !rl.noDelay {
            if _, err = req.WriteTo(rc.wbuf); err != nil {
                return
            }
        }
    case M.UDP:
        rc = newRelayConn(c, metadata.Addr(), rl.noDelay, true)
        if !rl.noDelay {
            if _, err = req.WriteTo(rc.wbuf); err != nil {
                return
            }
        }
    default:
        err = fmt.Errorf("network %s is unsupported", metadata.Network)
        return
    }

    return
}

type relayConn struct {
    net.Conn
    udp  bool
    addr net.Addr
    once sync.Once
    wbuf *bytes.Buffer
}

func newRelayConn(c net.Conn, addr net.Addr, noDelay, udp bool) *relayConn {
    rc := &relayConn{
        Conn: c,
        addr: addr,
        udp:  udp,
    }
    if !noDelay {
        rc.wbuf = &bytes.Buffer{}
    }
    return rc
}

func (rc *relayConn) ReadFrom(b []byte) (int, net.Addr, error) {
    n, err := rc.Read(b)
    return n, rc.addr, err
}

func (rc *relayConn) Read(b []byte) (n int, err error) {
    rc.once.Do(func() {
        if rc.wbuf != nil {
            err = readRelayResponse(rc.Conn)
        }
    })
    if err != nil {
        return
    }

    if !rc.udp {
        return rc.Conn.Read(b)
    }

    var bb [2]byte
    _, err = io.ReadFull(rc.Conn, bb[:])
    if err != nil {
        return
    }

    dLen := int(binary.BigEndian.Uint16(bb[:]))
    if len(b) >= dLen {
        return io.ReadFull(rc.Conn, b[:dLen])
    }

    buf := pool.Get(dLen)
    defer pool.Put(buf)
    _, err = io.ReadFull(rc.Conn, buf)
    n = copy(b, buf)

    return
}

func (rc *relayConn) WriteTo(b []byte, _ net.Addr) (int, error) {
    return rc.Write(b)
}

func (rc *relayConn) Write(b []byte) (int, error) {
    if rc.udp {
        return rc.udpWrite(b)
    }
    return rc.tcpWrite(b)
}

func (rc *relayConn) tcpWrite(b []byte) (n int, err error) {
    if rc.wbuf != nil && rc.wbuf.Len() > 0 {
        n = len(b)
        rc.wbuf.Write(b)
        _, err = rc.Conn.Write(rc.wbuf.Bytes())
        rc.wbuf.Reset()
        return
    }
    return rc.Conn.Write(b)
}

func (rc *relayConn) udpWrite(b []byte) (n int, err error) {
    if len(b) > math.MaxUint16 {
        err = errors.New("write: data maximum exceeded")
        return
    }

    n = len(b)
    if rc.wbuf != nil && rc.wbuf.Len() > 0 {
        var bb [2]byte
        binary.BigEndian.PutUint16(bb[:], uint16(len(b)))
        rc.wbuf.Write(bb[:])
        rc.wbuf.Write(b)
        _, err = rc.wbuf.WriteTo(rc.Conn)
        return
    }

    var bb [2]byte
    binary.BigEndian.PutUint16(bb[:], uint16(len(b)))
    _, err = rc.Conn.Write(bb[:])
    if err != nil {
        return
    }
    return rc.Conn.Write(b)
}

func readRelayResponse(r io.Reader) error {
    resp := relay.Response{}
    if _, err := resp.ReadFrom(r); err != nil {
        return err
    }
    if resp.Version != relay.Version1 {
        return relay.ErrBadVersion
    }
    if resp.Status != relay.StatusOK {
        return fmt.Errorf("status %d", resp.Status)
    }
    return nil
}

func serializeRelayAddr(m *M.Metadata) *relay.AddrFeature {
    af := &relay.AddrFeature{
        Host: m.DstIP.String(),
        Port: m.DstPort,
    }
    if m.DstIP.To4() != nil {
        af.AType = relay.AddrIPv4
    } else {
        af.AType = relay.AddrIPv6
    }
    return af
}