aergoio/aergo

View on GitHub
contract/ethstorageproof.go

Summary

Maintainability
A
3 hrs
Test Coverage
D
68%
package contract

import (
    "bytes"
    "encoding/binary"
    "errors"
    "math"

    "github.com/aergoio/aergo/v2/internal/enc/hex"
    "golang.org/x/crypto/sha3"
)

// Item counts of decoded trie nodes, and the lowercase hex alphabet used to
// turn path nibbles into key characters.
const (
    shortNode  = 2  // leaf or extension node: [encoded path, value or child]
    branchNode = 17 // branch node: 16 child slots followed by a value slot
    hexChar    = "0123456789abcdef"
)

// rlpNode holds the raw string items of one decoded RLP list.
type rlpNode [][]byte

// keyStream is a cursor over the hex characters of the key being matched
// against the trie path; reads advance the embedded buffer.
type keyStream struct {
    *bytes.Buffer
}

// errDecode is returned for any malformed proof encoding.
var errDecode = errors.New("storage proof decode error")

// lenBuf and nilBuf are 8-byte scratch buffers for big-endian length
// decoding. NOTE(review): they are mutable package-level state, so any
// concurrent users race on them.
var (
    lenBuf = make([]byte, 8)
    nilBuf = make([]byte, 8)
)

// verifyEthStorageProof reports whether proof shows that the trie whose root
// reference is expectedHash maps key to value.
//
// key is hashed with keccak256 and hex-encoded; each resulting hex character
// is one nibble of the trie path. value is compared in its RLP-encoded form.
// proof holds the RLP-encoded trie nodes along the path, starting with the
// node whose keccak256 hash is expectedHash.
//
// It returns false for empty inputs, for a node that fails to decode or does
// not match the reference held by its parent, for a path mismatch, and when
// proof runs out before a leaf or branch value is reached.
func verifyEthStorageProof(key []byte, value rlpObject, expectedHash []byte, proof [][]byte) bool {
    if len(key) == 0 || value == nil || len(proof) == 0 {
        return false
    }
    // The trie path is the nibble sequence of keccak256(key), kept as hex characters.
    key = []byte(hex.Encode(keccak256(key)))
    valueRlpEncoded := rlpEncode(value)
    ks := keyStream{bytes.NewBuffer(key)}
    for i, p := range proof {
        // Accept p when keccak256(p) equals expectedHash (and, after the first
        // node, p is at least 32 bytes long), or when p is byte-identical to
        // expectedHash - presumably a child node stored inline by its parent.
        // NOTE(review): toString rejects list-typed items, so a parent that
        // embeds a child node inline may fail to decode; confirm whether the
        // inline path is reachable.
        if ((i != 0 && len(p) < 32) || !bytes.Equal(expectedHash, keccak256(p))) && !bytes.Equal(expectedHash, p) {
            return false
        }
        n := decodeRlpTrieNode(p)
        switch len(n) {
        case shortNode:
            // Leaf or extension node: n[0] is the hex-prefix encoded path.
            if len(n[0]) == 0 {
                return false
            }
            // The first byte carries the leaf flag and, for odd-length paths,
            // the first nibble; the remaining bytes are hex-encoded into
            // further nibble characters.
            leaf, sharedNibbles, err := decodeHpHeader(n[0][0])
            if err != nil {
                return false
            }
            sharedNibbles = append(sharedNibbles, []byte(hex.Encode(n[0][1:]))...)
            if len(sharedNibbles) == 0 {
                return false
            }
            if leaf {
                // A leaf must match the whole remaining key and hold the value.
                return bytes.Equal(sharedNibbles, ks.key(-1)) && bytes.Equal(n[1], valueRlpEncoded)
            }
            // Extension: consume the shared path, then descend into n[1].
            if !bytes.Equal(sharedNibbles, ks.key(len(sharedNibbles))) {
                return false
            }
            expectedHash = n[1]
        case branchNode:
            // Branch: slots 0-15 are children, slot 16 holds the value of a
            // key that ends at this node.
            if ks.Len() == 0 {
                return bytes.Equal(n[16], valueRlpEncoded)
            }
            k := ks.index()
            // index signals a read or hex-decode failure with 0x10.
            if k > 0x0f {
                return false
            }
            expectedHash = n[k]
        default:
            // decodeRlpTrieNode failed (nil node) or returned another item count.
            return false
        }
    }
    // The proof ended without reaching a leaf or a branch value.
    return false
}

// decodeRlpTrieNode decodes one RLP-encoded trie node into its string items.
// data must be exactly one RLP list (short form 0xc0-0xf7 or long form
// 0xf8-0xff) whose declared length covers the whole input, and the list must
// hold 2 or 17 items (enforced by toList). Otherwise nil is returned.
func decodeRlpTrieNode(data []byte) rlpNode {
    if len(data) == 0 {
        return nil
    }
    total := uint64(len(data))
    prefix := data[0]
    if prefix < 0xc0 {
        // Not a list header; trie nodes are always lists.
        return nil
    }
    if prefix < 0xf8 {
        // Short list: the payload length is encoded in the prefix itself.
        payloadLen := uint64(prefix - 0xc0)
        if total != payloadLen+1 {
            return nil
        }
        return toList(data[1:], payloadLen)
    }
    // Long list: the prefix gives how many bytes hold the payload length.
    sizeLen := int(prefix - 0xf7)
    payloadLen, err := decodeLen(data[1:], sizeLen)
    if err != nil || total != 1+uint64(sizeLen)+payloadLen {
        return nil
    }
    return toList(data[1+sizeLen:], payloadLen)
}

// decodeLen decodes the big-endian unsigned integer stored in the first
// lenLen bytes of data. It is used for the length field of long-form RLP
// strings and lists.
//
// lenLen must be between 1 and 8. data must be strictly longer than lenLen,
// because callers always expect an encoded payload to follow the length
// bytes. Any violation yields errDecode.
func decodeLen(data []byte, lenLen int) (uint64, error) {
    if lenLen < 1 || lenLen > 8 || len(data) <= lenLen {
        return 0, errDecode
    }
    // Right-align the length bytes in a zeroed local buffer. A local array,
    // rather than shared package-level scratch buffers, keeps concurrent
    // proof verification free of data races.
    var buf [8]byte
    copy(buf[8-lenLen:], data[:lenLen])
    return binary.BigEndian.Uint64(buf[:]), nil
}

// toList splits the payload of an RLP list into its string items. dataLen is
// the declared payload length; items are read until exactly dataLen bytes are
// consumed. It returns nil if any item fails to decode, if the items overrun
// dataLen, or if the item count is neither 2 (short node) nor 17 (branch node).
func toList(data []byte, dataLen uint64) rlpNode {
    var items rlpNode
    for consumed := uint64(0); ; {
        item, size, err := toString(data[consumed:])
        if err != nil {
            return nil
        }
        items = append(items, item)
        consumed += size
        if consumed == dataLen {
            break
        }
        if consumed > dataLen {
            return nil
        }
    }
    switch len(items) {
    case 2, 17:
        return items
    default:
        return nil
    }
}

// toString decodes the RLP string item at the front of data. It returns the
// item's payload, the total number of bytes the item occupies in data
// (prefix plus payload), and an error.
//
// Only string items are accepted: a prefix of 0xc0 or above (an RLP list)
// yields errDecode. Because proof bytes are untrusted, every declared length
// is checked against len(data). A truncated or hostile encoding therefore
// returns errDecode instead of panicking on an out-of-range slice.
func toString(data []byte) ([]byte, uint64, error) {
    if len(data) == 0 {
        return nil, 0, errDecode
    }
    prefix := data[0]
    switch {
    case prefix <= 0x7f: // a single byte is its own encoding
        return data[0:1], 1, nil
    case prefix <= 0xb7: // string of 0-55 bytes; length is in the prefix
        end := 1 + int(prefix-0x80)
        if end > len(data) {
            return nil, 0, errDecode
        }
        return data[1:end], uint64(end), nil
    case prefix <= 0xbf: // string > 55 bytes; prefix gives the length-of-length
        lenLen := int(prefix - 0xb7)
        l, err := decodeLen(data[1:], lenLen)
        if err != nil {
            return nil, 0, err
        }
        start := 1 + lenLen
        // Compare against the bytes that remain instead of computing start+l
        // first, which could overflow for a hostile length.
        if l > uint64(len(data)-start) {
            return nil, 0, errDecode
        }
        end := uint64(start) + l
        return data[start:end], end, nil
    default:
        return nil, 0, errDecode
    }
}

// keccak256 returns the legacy Keccak-256 digest (original Keccak padding,
// not FIPS-202 SHA3-256) of all data slices hashed in order, as if they had
// been concatenated.
func keccak256(data ...[]byte) []byte {
    hasher := sha3.NewLegacyKeccak256()
    for i := range data {
        // hash.Hash.Write is documented never to return an error.
        _, _ = hasher.Write(data[i])
    }
    return hasher.Sum(nil)
}

func keccak256Hex(data ...[]byte) string {
    return hex.Encode(keccak256(data...))
}

func decodeHpHeader(b byte) (bool, []byte, error) {
    switch b >> 4 {
    case 0:
        return false, []byte{}, nil
    case 1:
        return false, []byte{hexChar[b&0x0f]}, nil
    case 2:
        return true, []byte{}, nil
    case 3:
        return true, []byte{hexChar[b&0x0f]}, nil
    default:
        return false, []byte{}, errDecode
    }
}

// hexToIndex converts one hexadecimal digit (either letter case) to its
// value 0-15. Any other byte yields errDecode.
func hexToIndex(c byte) (byte, error) {
    // Setting bit 0x20 folds 'A'-'F' onto 'a'-'f'; the only bytes that fold
    // into that range are the hex letters themselves.
    switch lower := c | 0x20; {
    case '0' <= c && c <= '9':
        return c - '0', nil
    case 'a' <= lower && lower <= 'f':
        return lower - 'a' + 10, nil
    }
    return 0, errDecode
}

// index consumes the next key character and returns its nibble value (0-15).
// When the stream is exhausted or the character is not a hex digit, it
// returns the out-of-range sentinel 0x10 so callers can reject it with a
// single bounds check.
func (ks keyStream) index() byte {
    const invalid = 0x10
    c, err := ks.ReadByte()
    if err == nil {
        if nibble, convErr := hexToIndex(c); convErr == nil {
            return nibble
        }
    }
    return invalid
}

// key returns key characters from the stream. With l == -1 it returns all
// remaining characters without consuming them. Otherwise it consumes and
// returns the next l characters, or fewer if the stream is shorter.
func (ks keyStream) key(l int) []byte {
    if l != -1 {
        return ks.Next(l)
    }
    return ks.Bytes()
}

// rlpEncode returns the RLP encoding of o.
func rlpEncode(o rlpObject) []byte {
    encoded := o.rlpEncode()
    return encoded
}

// rlpObject is any value that can serialize itself as one RLP item.
type rlpObject interface {
    rlpEncode() []byte
}

// rlpString is an RLP byte-string item.
type rlpString []byte

// rlpEncode encodes s as an RLP string. A single byte below 0x80 is emitted
// as-is; anything else gets a length header (base offset 0x80) followed by
// the bytes.
func (s rlpString) rlpEncode() []byte {
    if len(s) == 1 && s[0] < 0x80 {
        return []byte{s[0]}
    }
    header := rlpLength(len(s), 0x80)
    out := make([]byte, 0, len(header)+len(s))
    out = append(out, header...)
    return append(out, s...)
}

// rlpList is an RLP list item whose elements are encoded in order.
type rlpList []rlpObject

// rlpEncode encodes l as an RLP list: a length header (base offset 0xc0)
// followed by the concatenated encodings of its elements.
func (l rlpList) rlpEncode() []byte {
    encodedItems := make([][]byte, len(l))
    total := 0
    for i, item := range l {
        encodedItems[i] = item.rlpEncode()
        total += len(encodedItems[i])
    }
    out := rlpLength(total, 0xc0)
    for _, enc := range encodedItems {
        out = append(out, enc...)
    }
    return out
}

// rlpLength builds the RLP length header for a payload of dataLen bytes.
// offset is 0x80 for strings and 0xc0 for lists. Payloads under 56 bytes use
// a single header byte. Longer payloads use offset+55+n followed by the n
// big-endian bytes of dataLen. For dataLen >= math.MaxInt32 an empty header
// is returned.
func rlpLength(dataLen int, offset byte) []byte {
    switch {
    case dataLen < 56:
        return []byte{offset + byte(dataLen)}
    case dataLen >= math.MaxInt32:
        return []byte{}
    }
    sizeBytes := toBinary(dataLen)
    header := make([]byte, 0, 1+len(sizeBytes))
    header = append(header, byte(int(offset)+55+len(sizeBytes)))
    return append(header, sizeBytes...)
}

// toBinary returns the minimal big-endian byte representation of d. It
// returns nil when d is zero or negative.
func toBinary(d int) []byte {
    var out []byte
    for v := d; v > 0; v >>= 8 {
        out = append(out, byte(v))
    }
    // Bytes were collected least-significant first; flip them to big-endian.
    for i, j := 0, len(out)-1; i < j; i, j = i+1, j-1 {
        out[i], out[j] = out[j], out[i]
    }
    return out
}