status-im/status-go

View on GitHub
services/typeddata/types.go

Summary

Maintainability
A
0 mins
Test Coverage
A
95%
package typeddata

import (
	"encoding/json"
	"errors"
	"fmt"
	"math/big"
	"sort"
	"strconv"
)

const (
	// eip712Domain is the name of the type that TypedData.Validate requires to
	// be declared in TypedData.Types.
	eip712Domain = "EIP712Domain"

	// ChainIDKey is the key in TypedData.Domain that ValidateChainID reads the
	// chain ID from.
	ChainIDKey = "chainId"
)

type (
	// Types maps each composite type name to its ordered list of fields.
	Types map[string][]Field

	// Field is a single member of a composite type: its name and the name of
	// its Solidity type.
	Field struct {
		Name string `json:"name"`
		Type string `json:"type"`
	}
)

// Validate reports an error when the field's Name or Type is empty. Name is
// checked first, so a field missing both reports the name error.
func (f Field) Validate() error {
	switch {
	case f.Name == "":
		return errors.New("`name` is required")
	case f.Type == "":
		return errors.New("`type` is required")
	default:
		return nil
	}
}

// TypedData defines typed data according to eip-712.
//
// Domain and Message keep their values as raw JSON; decoding is left to the
// code that needs a value (for example, ValidateChainID decodes the entry stored
// under ChainIDKey). Validate checks that the required parts are populated.
type TypedData struct {
    // Types holds the composite type definitions; Validate requires both
    // EIP712Domain and PrimaryType to be declared here.
    Types       Types                      `json:"types"`
    // PrimaryType names the main type; Validate requires it to be non-empty and declared in Types.
    PrimaryType string                     `json:"primaryType"`
    // Domain holds raw JSON values keyed by name; Validate requires it to be non-nil.
    Domain      map[string]json.RawMessage `json:"domain"`
    // Message holds raw JSON values keyed by name; Validate requires it to be non-nil.
    Message     map[string]json.RawMessage `json:"message"`
}

// Validate checks that the required parts of t are defined and that every
// declared field has a non-empty name and type.
//
// It requires the EIP712Domain type to be declared, a non-empty PrimaryType that
// is declared in Types, and non-nil Domain and Message maps. It does not check
// that dependencies of the main type are defined; that is validated when the
// type string is computed.
//
// Types are visited in sorted name order, so when several fields are invalid the
// reported error is the same on every call. A field error is wrapped with %w so
// callers can inspect it with errors.Is / errors.As.
func (t TypedData) Validate() error {
	if _, exist := t.Types[eip712Domain]; !exist {
		return fmt.Errorf("`%s` must be in `types`", eip712Domain)
	}
	if t.PrimaryType == "" {
		return errors.New("`primaryType` is required")
	}
	if _, exist := t.Types[t.PrimaryType]; !exist {
		return fmt.Errorf("primary type `%s` not defined in types", t.PrimaryType)
	}
	if t.Domain == nil {
		return errors.New("`domain` is required")
	}
	if t.Message == nil {
		return errors.New("`message` is required")
	}

	// Map iteration order is randomized; sort the names so the first reported
	// error is deterministic.
	names := make([]string, 0, len(t.Types))
	for name := range t.Types {
		names = append(names, name)
	}
	sort.Strings(names)

	for _, name := range names {
		for i, field := range t.Types[name] {
			if err := field.Validate(); err != nil {
				return fmt.Errorf("field %d from type `%s` is invalid: %w", i, name, err)
			}
		}
	}
	return nil
}

// ValidateChainID checks that the chain ID stored in t.Domain under ChainIDKey
// equals chain.
//
// The domain value may be a JSON number or a JSON string. A string is parsed
// with strconv.ParseInt in base 0, so prefixed forms such as "0x1" are accepted.
// Either form must fit in an int64.
//
// It returns an error if any of the following hold:
//   - chain is nil;
//   - the domain has no ChainIDKey entry;
//   - the value cannot be decoded;
//   - the two IDs differ.
//
// A chain outside the int64 range never matches. Without this check,
// big.Int.Int64 would truncate it and could produce a false match.
func (t TypedData) ValidateChainID(chain *big.Int) error {
	if chain == nil {
		return errors.New("selected chain is required")
	}
	raw, exist := t.Domain[ChainIDKey]
	if !exist {
		return fmt.Errorf("domain misses chain key %s", ChainIDKey)
	}
	var chainID int64
	if err := json.Unmarshal(raw, &chainID); err != nil {
		// Not an int64 JSON number; fall back to a string-encoded chain ID.
		var chainIDString string
		if err = json.Unmarshal(raw, &chainIDString); err != nil {
			return err
		}
		if chainID, err = strconv.ParseInt(chainIDString, 0, 64); err != nil {
			return err
		}
	}
	// Int64 is undefined for values outside int64, so check IsInt64 first.
	if !chain.IsInt64() || chainID != chain.Int64() {
		return fmt.Errorf("chainId %d doesn't match selected chain %s", chainID, chain)
	}
	return nil
}