status-im/status-go

View on GitHub
services/wallet/reader.go

Summary

Maintainability
A
3 hrs
Test Coverage
D
66%
package wallet

import (
    "context"
    "math"
    "math/big"
    "sync"
    "time"

    "golang.org/x/exp/maps"

    "github.com/ethereum/go-ethereum/common"
    "github.com/ethereum/go-ethereum/common/hexutil"
    "github.com/ethereum/go-ethereum/event"
    "github.com/ethereum/go-ethereum/log"
    gocommon "github.com/status-im/status-go/common"
    "github.com/status-im/status-go/rpc/chain"
    "github.com/status-im/status-go/services/wallet/async"
    "github.com/status-im/status-go/services/wallet/market"
    "github.com/status-im/status-go/services/wallet/thirdparty"
    "github.com/status-im/status-go/services/wallet/token"
    "github.com/status-im/status-go/services/wallet/transfer"
    "github.com/status-im/status-go/services/wallet/walletevent"
)

// EventWalletTickReload is sent on the wallet feed to request a reload of the
// wallet balance and history. In this file it is emitted every
// walletTickReloadPeriod and, after a delay, when transfer activity is detected.
const EventWalletTickReload walletevent.EventType = "wallet-tick-reload"
// EventWalletTickCheckConnected is a wallet event type declared here; it is not
// emitted by the code visible in this file.
const EventWalletTickCheckConnected walletevent.EventType = "wallet-tick-check-connected"

const (
    // Interval between two periodic EventWalletTickReload emissions.
    walletTickReloadPeriod      = 10 * time.Minute
    activityReloadDelay         = 30 // Wait this many seconds after activity is detected before triggering a wallet reload
    activityReloadMarginSeconds = 30 // A transfer event triggers a wallet reload if its timestamp is later than the last token update minus this many seconds
)

// getFixedCurrencies returns a freshly allocated list of the currencies that
// are always requested in addition to the caller's currency.
func getFixedCurrencies() []string {
    fixed := make([]string, 0, 1)
    fixed = append(fixed, "USD")
    return fixed
}

// belongsToMandatoryTokens reports whether symbol is one of the mandatory
// token symbols (ETH, DAI, SNT, STT). The comparison is case-sensitive.
func belongsToMandatoryTokens(symbol string) bool {
    switch symbol {
    case "ETH", "DAI", "SNT", "STT":
        return true
    default:
        return false
    }
}

// NewReader builds a Reader from its collaborators. The balance cache starts
// out flagged for refresh (refreshBalanceCache is true).
func NewReader(tokenManager token.ManagerInterface, marketManager *market.Manager, persistence token.TokenBalancesStorage, walletFeed *event.Feed) *Reader {
    reader := new(Reader)
    reader.tokenManager = tokenManager
    reader.marketManager = marketManager
    reader.persistence = persistence
    reader.walletFeed = walletFeed
    reader.refreshBalanceCache = true
    return reader
}

// Reader fetches wallet token balances, caches them through persistence,
// enriches them with market data and emits wallet reload events.
type Reader struct {
    // Source of token lists, on-chain balances and historical balances.
    tokenManager                   token.ManagerInterface
    // Source of prices, token details and token market values.
    marketManager                  *market.Manager
    // Storage used to save and load the cached tokens (SaveTokens / GetTokens).
    persistence                    token.TokenBalancesStorage
    // Feed on which reload events are sent and transfer events are watched.
    walletFeed                     *event.Feed
    // Cancels the context of the ticker goroutine started by Start.
    cancel                         context.CancelFunc
    // Watcher subscribed to walletFeed; nil while no watcher is running.
    walletEventsWatcher            *walletevent.Watcher
    // Maps an account address to the Unix time (int64 seconds) of its last token update.
    lastWalletTokenUpdateTimestamp sync.Map
    // Pending delayed wallet reload; nil when none is scheduled.
    reloadDelayTimer               *time.Timer
    // True when the cached balances must be fetched again; guarded by rw.
    refreshBalanceCache            bool
    // Guards refreshBalanceCache.
    rw                             sync.RWMutex
}

// splitVerifiedTokens partitions tokens by their Verified flag, preserving
// input order. Both returned slices are non-nil, even when empty.
func splitVerifiedTokens(tokens []*token.Token) ([]*token.Token, []*token.Token) {
    var (
        verified   = []*token.Token{}
        unverified = []*token.Token{}
    )

    for _, tok := range tokens {
        if !tok.Verified {
            unverified = append(unverified, tok)
            continue
        }
        verified = append(verified, tok)
    }

    return verified, unverified
}

// getTokenBySymbols groups tokens by their Symbol. Within each group the
// tokens keep the order in which they appear in the input.
func getTokenBySymbols(tokens []*token.Token) map[string][]*token.Token {
    bySymbol := map[string][]*token.Token{}

    for _, tok := range tokens {
        // Appending to a missing key starts a new group from a nil slice.
        bySymbol[tok.Symbol] = append(bySymbol[tok.Symbol], tok)
    }

    return bySymbol
}

// getTokenAddresses returns the distinct contract addresses of tokens.
// The result is non-nil and its order is unspecified (map iteration order).
func getTokenAddresses(tokens []*token.Token) []common.Address {
    unique := map[common.Address]struct{}{}
    for _, tok := range tokens {
        unique[tok.Address] = struct{}{}
    }

    addresses := []common.Address{}
    for addr := range unique {
        addresses = append(addresses, addr)
    }
    return addresses
}

// Start launches the wallet events watcher and a background goroutine that
// triggers a wallet reload every walletTickReloadPeriod until the Reader is
// stopped.
//
// If Start is called again without an intervening Stop, the ticker goroutine
// from the previous call is cancelled first; otherwise its cancel function
// would be overwritten and the goroutine would keep ticking forever.
//
// It always returns nil.
func (r *Reader) Start() error {
    if r.cancel != nil {
        r.cancel()
    }

    ctx, cancel := context.WithCancel(context.Background())
    r.cancel = cancel

    r.startWalletEventsWatcher()

    go func() {
        defer gocommon.LogOnPanic()
        ticker := time.NewTicker(walletTickReloadPeriod)
        defer ticker.Stop()
        for {
            select {
            case <-ctx.Done():
                return
            case <-ticker.C:
                r.triggerWalletReload()
            }
        }
    }()
    return nil
}

// Stop cancels the periodic reload goroutine, stops the wallet events
// watcher, cancels any pending delayed reload and forgets all recorded
// token update timestamps. It is safe to call on a Reader that was never
// started.
func (r *Reader) Stop() {
    if r.cancel != nil {
        r.cancel()
        r.cancel = nil
    }

    r.stopWalletEventsWatcher()

    r.cancelDelayedWalletReload()

    // Remove the entries in place instead of overwriting the sync.Map value:
    // replacing the whole struct is not synchronized with concurrent
    // Load/Store calls on the same map.
    r.lastWalletTokenUpdateTimestamp.Range(func(key, _ interface{}) bool {
        r.lastWalletTokenUpdateTimestamp.Delete(key)
        return true
    })
}

// Restart stops the Reader and starts it again, returning Start's error.
func (r *Reader) Restart() error {
    r.Stop()

    err := r.Start()
    return err
}

// triggerWalletReload drops any pending delayed reload and sends an
// EventWalletTickReload event on the wallet feed.
func (r *Reader) triggerWalletReload() {
    r.cancelDelayedWalletReload()

    reloadEvent := walletevent.Event{Type: EventWalletTickReload}
    r.walletFeed.Send(reloadEvent)
}

// triggerDelayedWalletReload replaces any pending delayed reload with a new
// timer that calls triggerWalletReload after activityReloadDelay seconds.
func (r *Reader) triggerDelayedWalletReload() {
    r.cancelDelayedWalletReload()

    delay := activityReloadDelay * time.Second
    r.reloadDelayTimer = time.AfterFunc(delay, r.triggerWalletReload)
}

// cancelDelayedWalletReload stops and discards the pending delayed reload
// timer, if there is one.
func (r *Reader) cancelDelayedWalletReload() {
    timer := r.reloadDelayTimer
    if timer == nil {
        return
    }

    timer.Stop()
    r.reloadDelayTimer = nil
}

// startWalletEventsWatcher subscribes to the wallet feed and reacts to
// detected ETH/ERC20 transfers by scheduling a delayed wallet reload and
// invalidating the balance cache. It does nothing if a watcher already exists.
func (r *Reader) startWalletEventsWatcher() {
    if r.walletEventsWatcher != nil {
        return
    }

    // Respond to ETH/Token transfers
    onWalletEvent := func(ev walletevent.Event) {
        isETHTransfer := ev.Type == transfer.EventInternalETHTransferDetected
        isERC20Transfer := ev.Type == transfer.EventInternalERC20TransferDetected
        if !isETHTransfer && !isERC20Transfer {
            return
        }

        for _, account := range ev.Accounts {
            lastUpdate, known := r.lastWalletTokenUpdateTimestamp.Load(account)
            // An account with a recorded update is skipped when the event is not
            // newer than that update minus the activity margin.
            if known && ev.At <= lastUpdate.(int64)-activityReloadMarginSeconds {
                continue
            }

            // One qualifying account is enough: schedule the reload once and stop.
            r.triggerDelayedWalletReload()
            r.invalidateBalanceCache()
            break
        }
    }

    watcher := walletevent.NewWatcher(r.walletFeed, onWalletEvent)
    r.walletEventsWatcher = watcher

    watcher.Start()
}

// stopWalletEventsWatcher stops and discards the wallet events watcher, if any.
func (r *Reader) stopWalletEventsWatcher() {
    watcher := r.walletEventsWatcher
    if watcher == nil {
        return
    }

    watcher.Stop()
    r.walletEventsWatcher = nil
}

// tokensCachedForAddresses reports whether the persisted token cache has an
// entry for every one of addresses. A cache read error counts as "not cached".
func (r *Reader) tokensCachedForAddresses(addresses []common.Address) bool {
    cached, err := r.getCachedWalletTokensWithoutMarketData()
    if err != nil {
        return false
    }

    for i := range addresses {
        if _, found := cached[addresses[i]]; !found {
            return false
        }
    }

    return true
}

// isCacheTimestampValidForAddress reports whether a token update timestamp
// has been recorded for address.
func (r *Reader) isCacheTimestampValidForAddress(address common.Address) bool {
    if _, recorded := r.lastWalletTokenUpdateTimestamp.Load(address); !recorded {
        return false
    }
    return true
}

// areCacheTimestampsValid reports whether every address has a recorded token
// update timestamp. It is true for an empty address list.
func (r *Reader) areCacheTimestampsValid(addresses []common.Address) bool {
    for i := 0; i < len(addresses); i++ {
        if valid := r.isCacheTimestampValidForAddress(addresses[i]); !valid {
            return false
        }
    }

    return true
}

// isBalanceCacheValid reports whether the cached balances can be used for
// addresses: no refresh is pending, every address has cached tokens and every
// address has a recorded update timestamp. The read lock is held throughout.
func (r *Reader) isBalanceCacheValid(addresses []common.Address) bool {
    r.rw.RLock()
    defer r.rw.RUnlock()

    if r.refreshBalanceCache {
        return false
    }
    if !r.tokensCachedForAddresses(addresses) {
        return false
    }
    return r.areCacheTimestampsValid(addresses)
}

// balanceRefreshed clears the pending-refresh flag under the write lock.
func (r *Reader) balanceRefreshed() {
    r.rw.Lock()
    r.refreshBalanceCache = false
    r.rw.Unlock()
}

// invalidateBalanceCache sets the pending-refresh flag under the write lock,
// so that the next validity check reports the balance cache as stale.
func (r *Reader) invalidateBalanceCache() {
    r.rw.Lock()
    r.refreshBalanceCache = true
    r.rw.Unlock()
}

// FetchOrGetCachedWalletBalances returns the balances for addresses from the
// cache, fetching fresh balances first when forceRefresh is set, the cache is
// not valid, or the cache lacks data for an address or chain.
//
// A fetch failure is only logged; the cached balances are still returned.
func (r *Reader) FetchOrGetCachedWalletBalances(ctx context.Context, clients map[uint64]chain.ClientInterface, addresses []common.Address, forceRefresh bool) (map[common.Address][]token.StorageToken, error) {
    // The checks run in this order and stop at the first one demanding a fetch.
    needFetch := forceRefresh
    if !needFetch {
        needFetch = !r.isBalanceCacheValid(addresses)
    }
    if !needFetch {
        needFetch = r.isBalanceUpdateNeededAnyway(clients, addresses)
    }

    if needFetch {
        if _, err := r.FetchBalances(ctx, clients, addresses); err != nil {
            log.Error("FetchOrGetCachedWalletBalances error", "err", err)
        }
    }

    return r.GetCachedBalances(clients, addresses)
}

// isBalanceUpdateNeededAnyway reports whether the cache is missing data that
// only a fresh fetch can supply. It returns true when the cache cannot be
// read, when an address has no cached tokens, or when, for some address, no
// cached token carries a balance for one of the chains in clients.
func (r *Reader) isBalanceUpdateNeededAnyway(clients map[uint64]chain.ClientInterface, addresses []common.Address) bool {
    cachedTokens, err := r.getCachedWalletTokensWithoutMarketData()
    if err != nil {
        return true
    }

    chainIDs := maps.Keys(clients)
    for _, address := range addresses {
        storedTokens, found := cachedTokens[address]
        if !found || len(storedTokens) == 0 {
            return true
        }

        // Record which of the requested chains have a cached balance for this address.
        covered := map[uint64]bool{}
        for _, stored := range storedTokens {
            for _, chainID := range chainIDs {
                if _, has := stored.BalancesPerChain[chainID]; has {
                    covered[chainID] = true
                }
            }
        }

        for _, chainID := range chainIDs {
            if !covered[chainID] {
                return true
            }
        }
    }

    return false
}

// tokensToBalancesPerChain converts cached storage tokens into a nested
// lookup of chain ID -> account address -> token address -> raw balance.
// RawBalance is parsed as a base-10 integer.
func tokensToBalancesPerChain(cachedTokens map[common.Address][]token.StorageToken) map[uint64]map[common.Address]map[common.Address]*hexutil.Big {
    result := make(map[uint64]map[common.Address]map[common.Address]*hexutil.Big)

    for account, storedTokens := range cachedTokens {
        for _, stored := range storedTokens {
            for _, chainBalance := range stored.BalancesPerChain {
                perAccount, ok := result[chainBalance.ChainID]
                if !ok {
                    perAccount = map[common.Address]map[common.Address]*hexutil.Big{}
                    result[chainBalance.ChainID] = perAccount
                }

                perToken, ok := perAccount[account]
                if !ok {
                    perToken = map[common.Address]*hexutil.Big{}
                    perAccount[account] = perToken
                }

                // The parse result flag is ignored: on failure SetString returns
                // nil, so a nil balance is stored for that token.
                parsed, _ := new(big.Int).SetString(chainBalance.RawBalance, 10)
                perToken[chainBalance.Address] = (*hexutil.Big)(parsed)
            }
        }
    }

    return result
}

// fetchBalances asks the token manager for the balances of tokenAddresses held
// by addresses on the chains in clients. Errors are logged and returned with a
// nil result.
func (r *Reader) fetchBalances(ctx context.Context, clients map[uint64]chain.ClientInterface, addresses []common.Address, tokenAddresses []common.Address) (map[uint64]map[common.Address]map[common.Address]*hexutil.Big, error) {
    balancesByChain, err := r.tokenManager.GetBalancesByChain(ctx, clients, addresses, tokenAddresses)
    if err == nil {
        return balancesByChain, nil
    }

    log.Error("tokenManager.GetBalancesByChain error", "err", err)
    return nil, err
}

// toChainBalance builds the ChainBalance of tok for address from balances
// (chain ID -> account address -> token address -> raw balance).
//
// The decimal balance is the raw balance divided by 10^decimals. When balances
// is nil, or holds no (or a nil) entry for the token, both the raw and the
// decimal balance are zero.
//
// It returns nil when the token should not be shown: its balance is not
// positive, it is not already present in cachedTokens for that address, symbol
// and chain, and isMandatoryToken is false. hasError is copied into the result.
func toChainBalance(
    balances map[uint64]map[common.Address]map[common.Address]*hexutil.Big,
    tok *token.Token,
    address common.Address,
    decimals uint,
    cachedTokens map[common.Address][]token.StorageToken,
    hasError bool,
    isMandatoryToken bool,
) *token.ChainBalance {
    hexBalance := &big.Int{}
    if balances != nil {
        // Lookups in missing inner maps yield a nil *hexutil.Big, and ToInt on
        // it yields a nil *big.Int.
        hexBalance = balances[tok.ChainID][address][tok.Address].ToInt()
    }

    balance := big.NewFloat(0.0)
    if hexBalance != nil {
        balance = new(big.Float).Quo(
            new(big.Float).SetInt(hexBalance),
            big.NewFloat(math.Pow(10, float64(decimals))),
        )
    } else {
        // No balance known: report zero. Calling String on a nil *big.Int
        // would otherwise produce the non-numeric text "<nil>" as RawBalance.
        hexBalance = new(big.Int)
    }

    isVisible := balance.Cmp(big.NewFloat(0.0)) > 0 || isCachedToken(cachedTokens, address, tok.Symbol, tok.ChainID)
    if !isVisible && !isMandatoryToken {
        return nil
    }

    return &token.ChainBalance{
        RawBalance:     hexBalance.String(),
        Balance:        balance,
        Balance1DayAgo: "0",
        Address:        tok.Address,
        ChainID:        tok.ChainID,
        HasError:       hasError,
    }
}

// getBalance1DayAgo asks the token manager for the historical balance of
// symbol held by address on balance.ChainID at dayAgoTimestamp. Errors are
// logged and returned with a nil result.
func (r *Reader) getBalance1DayAgo(balance *token.ChainBalance, dayAgoTimestamp int64, symbol string, address common.Address) (*big.Int, error) {
    historical, err := r.tokenManager.GetTokenHistoricalBalance(address, balance.ChainID, symbol, dayAgoTimestamp)
    if err == nil {
        return historical, nil
    }

    log.Error("tokenManager.GetTokenHistoricalBalance error", "err", err)
    return nil, err
}

// balancesToTokensByAddress builds, for each address, the list of storage
// tokens derived from balances. Tokens are grouped by symbol; verified tokens
// are emitted before unverified ones, and a symbol is skipped for an address
// when it yields no per-chain balance. Token metadata (name, decimals, image,
// ...) is taken from the first token of each symbol group.
//
// Addresses with no resulting token have no entry in the returned map. The
// order of symbols inside each verified/unverified group is unspecified (map
// iteration order).
func (r *Reader) balancesToTokensByAddress(connectedPerChain map[uint64]bool, addresses []common.Address, allTokens []*token.Token, balances map[uint64]map[common.Address]map[common.Address]*hexutil.Big, cachedTokens map[common.Address][]token.StorageToken) map[common.Address][]token.StorageToken {
    verifiedTokens, unverifiedTokens := splitVerifiedTokens(allTokens)

    // The grouping by symbol does not depend on the address, so compute it
    // once instead of once per address.
    tokenGroups := []map[string][]*token.Token{
        getTokenBySymbols(verifiedTokens),
        getTokenBySymbols(unverifiedTokens),
    }

    result := make(map[common.Address][]token.StorageToken)
    dayAgoTimestamp := time.Now().Add(-24 * time.Hour).Unix()

    for _, address := range addresses {
        for _, tokensBySymbol := range tokenGroups {
            for symbol, tokens := range tokensBySymbol {
                balancesPerChain := r.createBalancePerChainPerSymbol(address, balances, tokens, cachedTokens, connectedPerChain, dayAgoTimestamp)
                if balancesPerChain == nil {
                    continue
                }

                walletToken := token.StorageToken{
                    Token: token.Token{
                        Name:          tokens[0].Name,
                        Symbol:        symbol,
                        Decimals:      tokens[0].Decimals,
                        PegSymbol:     token.GetTokenPegSymbol(symbol),
                        Verified:      tokens[0].Verified,
                        CommunityData: tokens[0].CommunityData,
                        Image:         tokens[0].Image,
                    },
                    BalancesPerChain: balancesPerChain,
                }

                result[address] = append(result[address], walletToken)
            }
        }
    }

    return result
}

// createBalancePerChainPerSymbol computes the per-chain balances of address
// for tokens, which must be non-empty and are expected to share one symbol.
// Decimals and the mandatory-token check come from the first token.
//
// A chain balance is flagged with an error when its chain is known to be
// disconnected or when balances holds no entry for the token. The balance of
// one day ago is filled in when it can be retrieved. The result is nil when no
// token produced a balance.
func (r *Reader) createBalancePerChainPerSymbol(
    address common.Address,
    balances map[uint64]map[common.Address]map[common.Address]*hexutil.Big,
    tokens []*token.Token,
    cachedTokens map[common.Address][]token.StorageToken,
    clientConnectionPerChain map[uint64]bool,
    dayAgoTimestamp int64,
) map[uint64]token.ChainBalance {
    var perChain map[uint64]token.ChainBalance

    first := tokens[0]
    decimals := first.Decimals
    mandatory := belongsToMandatoryTokens(first.Symbol) // we expect all tokens in the list to have the same symbol

    for _, tok := range tokens {
        connected, connectionKnown := clientConnectionPerChain[tok.ChainID]
        _, fetched := balances[tok.ChainID][address][tok.Address]
        hasError := (connectionKnown && !connected) || !fetched

        // TODO: Avoid passing the entire balances map to toChainBalance. Iterate over the balances map once and pass the balance per address per token to toChainBalance
        chainBalance := toChainBalance(balances, tok, address, decimals, cachedTokens, hasError, mandatory)
        if chainBalance == nil {
            continue
        }

        // A failed historical lookup is ignored: the default value is kept.
        if dayAgo, _ := r.getBalance1DayAgo(chainBalance, dayAgoTimestamp, tok.Symbol, address); dayAgo != nil {
            chainBalance.Balance1DayAgo = dayAgo.String()
        }

        // The map is created lazily so that "no balance at all" stays nil.
        if perChain == nil {
            perChain = map[uint64]token.ChainBalance{}
        }
        perChain[tok.ChainID] = *chainBalance
    }

    return perChain
}

// GetWalletToken force-refreshes the balances of addresses and enriches the
// resulting tokens with market data: prices and market values for currency
// plus the fixed currencies, and token details (description, website, chain).
//
// Prices, token details and market values are fetched concurrently. A failure
// of any of these fetches is logged and does not fail the call. The call fails
// when balances cannot be obtained, when ctx is done before the market data
// arrives, or when the enriched tokens cannot be saved.
//
// On success the token update timestamp of every address is refreshed and the
// enriched tokens are persisted.
func (r *Reader) GetWalletToken(ctx context.Context, clients map[uint64]chain.ClientInterface, addresses []common.Address, currency string) (map[common.Address][]token.StorageToken, error) {
    currencies := make([]string, 0)
    currencies = append(currencies, currency)
    currencies = append(currencies, getFixedCurrencies()...)

    result, err := r.FetchOrGetCachedWalletBalances(ctx, clients, addresses, true)
    if err != nil {
        return nil, err
    }

    tokenSymbols := make([]string, 0)
    for _, storageTokens := range result {
        for _, t := range storageTokens {
            tokenSymbols = append(tokenSymbols, t.Token.Symbol)
        }
    }

    var (
        group             = async.NewAtomicGroup(ctx)
        prices            = map[string]map[string]market.DataPoint{}
        tokenDetails      = map[string]thirdparty.TokenDetails{}
        tokenMarketValues = map[string]thirdparty.TokenMarketValues{}
    )

    // Each task keeps its error in a local variable: the tasks run
    // concurrently, so sharing the outer err between them would be a data race.
    group.Add(func(parent context.Context) error {
        fetched, err := r.marketManager.GetOrFetchPrices(tokenSymbols, currencies, market.MaxAgeInSecondsForBalances)
        if err != nil {
            log.Info("marketManager.GetOrFetchPrices err", "err", err)
        }
        prices = fetched
        return nil
    })

    group.Add(func(parent context.Context) error {
        fetched, err := r.marketManager.FetchTokenDetails(tokenSymbols)
        if err != nil {
            log.Info("marketManager.FetchTokenDetails err", "err", err)
        }
        tokenDetails = fetched
        return nil
    })

    group.Add(func(parent context.Context) error {
        fetched, err := r.marketManager.GetOrFetchTokenMarketValues(tokenSymbols, currency, market.MaxAgeInSecondsForBalances)
        if err != nil {
            log.Info("marketManager.GetOrFetchTokenMarketValues err", "err", err)
        }
        tokenMarketValues = fetched
        return nil
    })

    select {
    case <-group.WaitAsync():
    case <-ctx.Done():
        return nil, ctx.Err()
    }
    if err := group.Error(); err != nil {
        return nil, err
    }

    for address, tokens := range result {
        for index, tok := range tokens {
            marketValuesPerCurrency := make(map[string]token.TokenMarketValues)
            // Without market values for the symbol the map stays empty.
            if values, ok := tokenMarketValues[tok.Symbol]; ok {
                for _, cur := range currencies {
                    marketValuesPerCurrency[cur] = token.TokenMarketValues{
                        MarketCap:       values.MKTCAP,
                        HighDay:         values.HIGHDAY,
                        LowDay:          values.LOWDAY,
                        ChangePctHour:   values.CHANGEPCTHOUR,
                        ChangePctDay:    values.CHANGEPCTDAY,
                        ChangePct24hour: values.CHANGEPCT24HOUR,
                        Change24hour:    values.CHANGE24HOUR,
                        Price:           prices[tok.Symbol][cur].Price,
                        HasError:        !r.marketManager.IsConnected,
                    }
                }
            }

            // NOTE(review): tokens without details also do not get their market
            // values assigned; this mirrors the existing behavior - confirm it
            // is intended.
            details, ok := tokenDetails[tok.Symbol]
            if !ok {
                continue
            }

            result[address][index].Description = details.Description
            result[address][index].AssetWebsiteURL = details.AssetWebsiteURL
            result[address][index].BuiltOn = details.BuiltOn
            result[address][index].MarketValuesPerCurrency = marketValuesPerCurrency
        }
    }

    r.updateTokenUpdateTimestamp(addresses)

    return result, r.persistence.SaveTokens(result)
}

// isCachedToken reports whether cachedTokens holds, for address, a token with
// the given symbol that has a balance entry for chainID.
func isCachedToken(cachedTokens map[common.Address][]token.StorageToken, address common.Address, symbol string, chainID uint64) bool {
    // A missing address yields a nil slice, so the loop simply does not run.
    for _, cached := range cachedTokens[address] {
        if cached.Symbol == symbol {
            if _, present := cached.BalancesPerChain[chainID]; present {
                return true
            }
        }
    }
    return false
}

// getCachedWalletTokensWithoutMarketData loads the tokens kept in persistence,
// i.e. the latest fetched balances without price information.
func (r *Reader) getCachedWalletTokensWithoutMarketData() (map[common.Address][]token.StorageToken, error) {
    storedTokens, err := r.persistence.GetTokens()
    return storedTokens, err
}

// updateTokenUpdateTimestamp records the current Unix time (in seconds) as the
// last token update time of every address.
func (r *Reader) updateTokenUpdateTimestamp(addresses []common.Address) {
    for i := range addresses {
        now := time.Now().Unix()
        r.lastWalletTokenUpdateTimestamp.Store(addresses[i], now)
    }
}

// FetchBalances fetches the balances of addresses for all tokens known on the
// chains in clients, converts them into storage tokens, saves them to
// persistence, records the update time of each address and marks the balance
// cache as refreshed.
//
// It returns an error when the cache cannot be read, the token list cannot be
// loaded or the balances cannot be fetched. A failure to save the tokens is
// not critical: it is logged and the fetched tokens are returned with a nil
// error.
func (r *Reader) FetchBalances(ctx context.Context, clients map[uint64]chain.ClientInterface, addresses []common.Address) (map[common.Address][]token.StorageToken, error) {
    cachedTokens, err := r.getCachedWalletTokensWithoutMarketData()
    if err != nil {
        return nil, err
    }

    chainIDs := maps.Keys(clients)
    allTokens, err := r.tokenManager.GetTokensByChainIDs(chainIDs)
    if err != nil {
        return nil, err
    }

    tokenAddresses := getTokenAddresses(allTokens)
    balances, err := r.fetchBalances(ctx, clients, addresses, tokenAddresses)
    if err != nil {
        log.Error("failed to update balances", "err", err)
        return nil, err
    }

    connectedPerChain := map[uint64]bool{}
    for chainID, client := range clients {
        connectedPerChain[chainID] = client.IsConnected()
    }

    tokens := r.balancesToTokensByAddress(connectedPerChain, addresses, allTokens, balances, cachedTokens)

    if saveErr := r.persistence.SaveTokens(tokens); saveErr != nil {
        log.Error("failed to save tokens", "err", saveErr) // Do not return error, as it is not critical
    }

    r.updateTokenUpdateTimestamp(addresses)
    r.balanceRefreshed()

    return tokens, nil
}

// GetCachedBalances rebuilds the storage tokens of addresses from the balances
// kept in persistence, without fetching anything from the chains. The
// connection state of each client is read to flag per-chain errors.
func (r *Reader) GetCachedBalances(clients map[uint64]chain.ClientInterface, addresses []common.Address) (map[common.Address][]token.StorageToken, error) {
    cachedTokens, err := r.getCachedWalletTokensWithoutMarketData()
    if err != nil {
        return nil, err
    }

    allTokens, err := r.tokenManager.GetTokensByChainIDs(maps.Keys(clients))
    if err != nil {
        return nil, err
    }

    connectedPerChain := make(map[uint64]bool)
    for chainID := range clients {
        connectedPerChain[chainID] = clients[chainID].IsConnected()
    }

    cachedBalances := tokensToBalancesPerChain(cachedTokens)
    tokensByAddress := r.balancesToTokensByAddress(connectedPerChain, addresses, allTokens, cachedBalances, cachedTokens)
    return tokensByAddress, nil
}