status-im/status-go

View on GitHub
server/pairing/connection.go

Summary

Maintainability
A
0 mins
Test Coverage
C
72%
package pairing

import (
    "crypto/ecdsa"
    "crypto/elliptic"
    "fmt"
    "log"
    "math/big"
    "net"
    "net/url"
    "strings"

    "github.com/btcsuite/btcutil/base58"
    "github.com/google/uuid"

    "github.com/status-im/status-go/server/pairing/versioning"
)

const (
    // connectionStringID is the fixed prefix that ToString writes at the start
    // of every connection string and that FromString requires before parsing.
    connectionStringID = "cs"
)

// ConnectionParams holds the values that ToString encodes into a connection
// string and that FromString decodes back out of one.
type ConnectionParams struct {
    // version is set to versioning.LatestConnectionParamVer by NewConnectionParams.
    version        versioning.ConnectionParamVersion
    // netIPs are the candidate addresses; URL selects one of them by index.
    netIPs         []net.IP
    // port must be within 1 - 65535 to pass validation.
    port           int
    // publicKey must be on elliptic.P256 with non-zero X and Y to pass validation.
    publicKey      *ecdsa.PublicKey
    // aesKey must be exactly 32 bytes long to pass validation.
    aesKey         []byte
    // installationID is optional; when set, ToString parses it as a UUID.
    installationID string
    // keyUID is optional; when set, ToString encodes its raw bytes.
    keyUID         string
}

// NewConnectionParams builds a ConnectionParams from the supplied values and
// stamps it with versioning.LatestConnectionParamVer.
//
// No validation happens here, and the slices and the key pointer are stored as
// given rather than copied.
func NewConnectionParams(netIPs []net.IP, port int, publicKey *ecdsa.PublicKey, aesKey []byte, installationID, keyUID string) *ConnectionParams {
    return &ConnectionParams{
        version:        versioning.LatestConnectionParamVer,
        netIPs:         netIPs,
        port:           port,
        publicKey:      publicKey,
        aesKey:         aesKey,
        installationID: installationID,
        keyUID:         keyUID,
    }
}

// ToString generates a string required for generating a secure connection to another Status device.
//
// The returned string will look like below, followed by two further
// ":"-delimited fields (installation ID and key UID) that may be empty:
//   - "cs2:4FHRnp:H6G:uqnnMwVUfJc2Fkcaojet8F1ufKC3hZdGEt47joyBx9yd:BbnZ7Gc66t54a9kEFCf7FW8SGQuYypwHVeNkRYeNoqV6"
//
// Format: the string type identifier "cs", immediately followed by these
// fields, each encoded into a base58 string and delimited by ":"
//   - version
//   - array of IPs in next form:
//     | 1 byte | 4*N bytes | 1 byte | 16*M bytes |
//     |   N    | N * IPv4  |    M   |  M * IPv6  |
//   - port
//   - ecdsa CompressedPublicKey
//   - AES encryption key
//   - InstallationID of the sending device (the 16 UUID bytes), empty when unset
//   - KeyUID of the sending device (its raw string bytes), empty when unset
//
// cp.publicKey must be non-nil; it is dereferenced unconditionally.
//
// If the installation ID is set but is not a valid UUID, the problem is logged
// and the field is left empty. ToString has no error return, and terminating
// the host process from library code is not acceptable.
//
// NOTE:
// - append(accrete) parameters instead of changing(breaking) existing parameters. Appending should **never** break, modifying existing parameters will break. Watch this before making changes: https://www.youtube.com/watch?v=oyLBGkS5ICk
// - never strictly check version, unless you really want to break
func (cp *ConnectionParams) ToString() string {
    v := base58.Encode(new(big.Int).SetInt64(int64(cp.version)).Bytes())
    ips := base58.Encode(SerializeNetIps(cp.netIPs))
    p := base58.Encode(new(big.Int).SetInt64(int64(cp.port)).Bytes())
    k := base58.Encode(elliptic.MarshalCompressed(cp.publicKey.Curve, cp.publicKey.X, cp.publicKey.Y))
    ek := base58.Encode(cp.aesKey)

    var i string
    if cp.installationID != "" {
        // The installation ID field is optional on the parsing side, so an
        // unparsable value degrades to "not present" instead of killing the
        // process (log.Fatalf would call os.Exit).
        if u, err := uuid.Parse(cp.installationID); err != nil {
            log.Printf("Failed to parse UUID: %v", err)
        } else {
            // Encode the raw 16 UUID bytes, not the textual form.
            i = base58.Encode(u[:])
        }
    }

    var kuid string
    if cp.keyUID != "" {
        kuid = base58.Encode([]byte(cp.keyUID))
    }

    return fmt.Sprintf("%s%s:%s:%s:%s:%s:%s:%s", connectionStringID, v, ips, p, k, ek, i, kuid)
}

// InstallationID returns the installation ID stored in cp; it is empty when
// none was supplied or parsed.
func (cp *ConnectionParams) InstallationID() string {
    return cp.installationID
}

// KeyUID returns the key UID stored in cp; it is empty when none was supplied
// or parsed.
func (cp *ConnectionParams) KeyUID() string {
    return cp.keyUID
}

// SerializeNetIps packs ips into the byte layout used inside a connection
// string:
//
//    | 1 byte | 4*N bytes | 1 byte | 16*M bytes |
//    |   N    | N * IPv4  |    M   |  M * IPv6  |
//
// Addresses that convert with To4 are written in their 4-byte form, all IPv4
// addresses first and then all 16-byte addresses, each group keeping its input
// order. Both count bytes are always written, so a nil or empty input yields
// {0, 0}.
//
// Because this function cannot return an error, input that cannot be
// represented is dropped rather than emitted as a corrupt stream:
//   - addresses that are neither IPv4 nor exactly 16 bytes long (e.g. nil)
//   - addresses beyond the first 255 of a family, since the count is one byte
func SerializeNetIps(ips []net.IP) []byte {
    // Each family's count is stored in a single byte.
    const maxPerFamily = 255

    var ipv4, ipv6 []net.IP
    for _, ip := range ips {
        if v4 := ip.To4(); v4 != nil {
            if len(ipv4) < maxPerFamily {
                ipv4 = append(ipv4, v4)
            }
        } else if len(ip) == net.IPv6len {
            if len(ipv6) < maxPerFamily {
                ipv6 = append(ipv6, ip)
            }
        }
        // Anything else has an invalid length and would break the fixed-width
        // layout, so it is skipped.
    }

    out := make([]byte, 0, 2+len(ipv4)*net.IPv4len+len(ipv6)*net.IPv6len)
    for _, family := range [][]net.IP{ipv4, ipv6} {
        out = append(out, byte(len(family)))
        for _, ip := range family {
            out = append(out, ip...)
        }
    }

    return out
}

// ParseNetIps decodes the byte layout produced by SerializeNetIps:
//
//    | 1 byte | 4*N bytes | 1 byte | 16*M bytes |
//    |   N    | N * IPv4  |    M   |  M * IPv6  |
//
// It returns the IPv4 addresses followed by the IPv6 addresses. Both count
// bytes are required; an input that ends before a count byte or before all
// announced address bytes yields an error instead of a panic. Bytes after the
// IPv6 section are ignored. The returned addresses are copies and do not alias
// in.
func ParseNetIps(in []byte) ([]net.IP, error) {
    var out []net.IP

    if len(in) < 1 {
        return nil, fmt.Errorf("net.ip field is too short: '%d', at least 1 byte required", len(in))
    }

    for _, ipLen := range []int{net.IPv4len, net.IPv6len} {
        // The count byte must be checked for every family: the input may end
        // right after the IPv4 section.
        if len(in) < 1 {
            return nil, fmt.Errorf("net.ip.ip%d field is missing its count byte", ipLen)
        }

        count := int(in[0])
        in = in[1:]

        if expectedLen := ipLen * count; len(in) < expectedLen {
            return nil, fmt.Errorf("net.ip.ip%d field is too short, expected at least '%d' bytes, '%d' bytes found", ipLen, expectedLen, len(in))
        }

        for ; count > 0; count-- {
            // Copy so that a later write or append on one address cannot
            // touch its neighbours or the caller's buffer.
            ip := make(net.IP, ipLen)
            copy(ip, in)
            out = append(out, ip)
            in = in[ipLen:]
        }
    }

    return out, nil
}

// FromString parses a connection params string required to securely connect to another Status device.
// This function parses a connection string generated by ToString.
//
// The string must start with the "cs" identifier and carry at least five
// ":"-delimited fields (version, IPs, port, public key, AES key). The
// installation ID and key UID fields are optional, and any further fields are
// ignored for forward compatibility. The version field is not read.
//
// Fields are written into cp as they are decoded, so cp may be partially
// populated when an error is returned. The populated value is checked with
// validate before returning.
func (cp *ConnectionParams) FromString(s string) error {
    if len(s) < 2 {
        return fmt.Errorf("connection string is too short: '%s'", s)
    }

    if !strings.HasPrefix(s, connectionStringID) {
        return fmt.Errorf("connection string doesn't begin with identifier '%s'", connectionStringID)
    }

    const requiredParams = 5

    sData := strings.Split(s[len(connectionStringID):], ":")
    // NOTE: always allow extra parameters for forward compatibility, error on not enough required parameters or failing to parse
    if len(sData) < requiredParams {
        return fmt.Errorf("expected data '%s' to have length of '%d', received '%d'", s, requiredParams, len(sData))
    }

    netIps, err := ParseNetIps(base58.Decode(sData[1]))
    if err != nil {
        return err
    }
    cp.netIPs = netIps

    // The port arrives as an arbitrary-length big-endian integer from
    // untrusted input. Int64 is undefined for values that do not fit, and the
    // conversion to int could wrap an oversized value into the valid range, so
    // anything wider than 16 bits is rejected before converting.
    port := new(big.Int).SetBytes(base58.Decode(sData[2]))
    if port.BitLen() > 16 {
        return fmt.Errorf("port '%s' outside of bounds of 1 - 65535", port.String())
    }
    cp.port = int(port.Int64())

    // UnmarshalCompressed yields nil coordinates for invalid data; that case
    // is reported by validate below.
    cp.publicKey = new(ecdsa.PublicKey)
    cp.publicKey.X, cp.publicKey.Y = elliptic.UnmarshalCompressed(elliptic.P256(), base58.Decode(sData[3]))
    cp.publicKey.Curve = elliptic.P256()
    cp.aesKey = base58.Decode(sData[4])

    if len(sData) > 5 && len(sData[5]) != 0 {
        installationID, err := uuid.FromBytes(base58.Decode(sData[5]))
        if err != nil {
            return err
        }
        cp.installationID = installationID.String()
    }

    if len(sData) > 6 && len(sData[6]) != 0 {
        cp.keyUID = string(base58.Decode(sData[6]))
    }

    return cp.validate()
}

// validate runs the individual field checks in a fixed order (IPs, port,
// public key, AES key) and returns the first error encountered, or nil when
// every check passes.
func (cp *ConnectionParams) validate() error {
    checks := []func() error{
        cp.validateNetIP,
        cp.validatePort,
        cp.validatePublicKey,
        cp.validateAESKey,
    }

    for _, check := range checks {
        if err := check(); err != nil {
            return err
        }
    }

    return nil
}

// validateNetIP checks that every stored address survives a round trip
// through its textual form. On the first address that does not, it returns an
// error whose message contains the whole address list.
func (cp *ConnectionParams) validateNetIP() error {
    for idx := range cp.netIPs {
        if parsed := net.ParseIP(cp.netIPs[idx].String()); parsed != nil {
            continue
        }
        return fmt.Errorf("invalid net ip '%s'", cp.netIPs)
    }
    return nil
}

// validatePort reports an error unless the stored port lies within 1 - 65535.
func (cp *ConnectionParams) validatePort() error {
    if cp.port <= 0 || cp.port > 0xffff {
        return fmt.Errorf("port '%d' outside of bounds of 1 - 65535", cp.port)
    }

    return nil
}

// validatePublicKey reports an error unless the stored public key is present,
// uses the elliptic.P256 curve and has non-nil, non-zero X and Y coordinates.
//
// The nil-key case is checked first: a ConnectionParams built through
// NewConnectionParams may carry a nil key, and dereferencing it would panic.
func (cp *ConnectionParams) validatePublicKey() error {
    switch {
    case cp.publicKey == nil:
        return fmt.Errorf("public key not set")
    case cp.publicKey.Curve == nil, cp.publicKey.Curve != elliptic.P256():
        return fmt.Errorf("public key Curve not `elliptic.P256`")
    // Case expressions are evaluated left to right, so the nil check guards
    // the method call that follows it.
    case cp.publicKey.X == nil, cp.publicKey.X.Sign() == 0:
        return fmt.Errorf("public key X not set")
    case cp.publicKey.Y == nil, cp.publicKey.Y.Sign() == 0:
        return fmt.Errorf("public key Y not set")
    default:
        return nil
    }
}

// validateAESKey reports an error unless the stored AES key is exactly 32
// bytes long.
func (cp *ConnectionParams) validateAESKey() error {
    if n := len(cp.aesKey); n != 32 {
        return fmt.Errorf("AES key invalid length, expect length 32, received length '%d'", n)
    }
    return nil
}

// URL returns the URL built by BuildURL for the stored address at position
// IPIndex. It fails when IPIndex is out of range or when the stored
// parameters do not pass validate.
func (cp *ConnectionParams) URL(IPIndex int) (*url.URL, error) {
    if count := len(cp.netIPs); IPIndex < 0 || IPIndex >= count {
        return nil, fmt.Errorf("invalid IP index '%d'", IPIndex)
    }

    if err := cp.validate(); err != nil {
        return nil, err
    }

    selected := cp.netIPs[IPIndex]
    return cp.BuildURL(selected), nil
}

// BuildURL returns an https URL whose host is ip combined with the stored
// port. No validation is performed on ip or on the port.
//
// The host is assembled with net.JoinHostPort so that IPv6 addresses are
// wrapped in brackets ("[fe80::1]:8080"); without the brackets the host would
// be ambiguous and unusable. IPv4 output is unchanged ("192.168.1.2:8080").
func (cp *ConnectionParams) BuildURL(ip net.IP) *url.URL {
    return &url.URL{
        Scheme: "https",
        Host:   net.JoinHostPort(ip.String(), fmt.Sprintf("%d", cp.port)),
    }
}

// ValidateConnectionString parses cs into a throwaway ConnectionParams and
// returns whatever error FromString reports, or nil if cs is acceptable.
func ValidateConnectionString(cs string) error {
    return new(ConnectionParams).FromString(cs)
}