package inventory
import (
// Manager is the interface for the inventory manager. It exposes the relayer's
// committable token balances, gas checks, token approvals and rebalancing.
//go:generate go run --name Manager --output ./mocks --case=underscore
type Manager interface {
// Start starts the inventory manager.
Start(ctx context.Context) (err error)
// GetCommittableBalance gets the balance of a single token on a chain that is
// available for new quotes. It does not include on-chain balances committed to
// previous quotes, even though those may be refunded in the event of a revert.
GetCommittableBalance(ctx context.Context, chainID int, token common.Address, options ...BalanceFetchArgOption) (*big.Int, error)
// GetCommittableBalances gets the committable balances of all tracked tokens,
// keyed by chain id and then by token address.
GetCommittableBalances(ctx context.Context, options ...BalanceFetchArgOption) (map[int]map[common.Address]*big.Int, error)
// ApproveAllTokens submits the approvals that let the configured contracts
// spend the tokens held by the relayer address.
ApproveAllTokens(ctx context.Context) error
// HasSufficientGas reports whether the committable gas balance on chainID
// covers the configured minimum gas amount plus gasValue. gasValue may be nil.
HasSufficientGas(ctx context.Context, chainID int, gasValue *big.Int) (bool, error)
// Rebalance attempts any rebalances that could be executed across all supported tokens and chains.
Rebalance(ctx context.Context) error
// GetTokenMetadata gets the metadata for a token on a chain. An error is
// returned when the token is not tracked.
GetTokenMetadata(chainID int, token common.Address) (*TokenMetadata, error)
// inventoryManagerImpl is the Manager implementation returned by NewInventoryManager.
type inventoryManagerImpl struct {
// tokens maps chainID->address->TokenMetadata
tokens map[int]map[common.Address]*TokenMetadata
// gasBalances maps chainID->gas token balance of the relayer
gasBalances map[int]*big.Int
// mux is the read/write lock used by the methods that read or initialize tokens
mux sync.RWMutex
// handler is the metrics handler (tracer and meter provider)
handler metrics.Handler
// cfg is the relayer config
cfg relconfig.Config
// relayerAddress contains the relayer address whose balances are tracked
relayerAddress common.Address
// chainClient fetches the RPC client for a given chain (omnirpc)
chainClient submitter.ClientFetcher
// txSubmitter is the transaction submitter used to send approvals
txSubmitter submitter.TransactionSubmitter
// rebalanceManagers maps each configured rebalance method to its manager
rebalanceManagers map[relconfig.RebalanceMethod]RebalanceManager
// db is the database
db reldb.Service
// meter is the metrics meter for this package
meter metric.Meter
// balanceGauge is the observable gauge that reports token balances
balanceGauge metric.Float64ObservableGauge
// inFlightQuoteManager is the cache for in flight quotes
inFlightQuoteManager *inFlightManager
// ErrUnsupportedChain is returned when a gas balance is requested for a chain
// that has no tracked gas balance.
var ErrUnsupportedChain = errors.New("could not get gas balance for unsupported chain")
// GetCommittableBalance gets the committable balances.
func (i *inventoryManagerImpl) GetCommittableBalance(ctx context.Context, chainID int, token common.Address, options ...BalanceFetchArgOption) (*big.Int, error) {
committableBalances, err := i.GetCommittableBalances(ctx, options...)
if err != nil {
return nil, fmt.Errorf("could not get balances: %w", err)
balance := committableBalances[chainID][token]
// the gas token may not be registered in the inventory tokens map,
// but it is always tracked in gasBalances.
if balance == nil && token == util.EthAddress {
gasBalance, ok := i.gasBalances[chainID]
if !ok || gasBalance == nil {
return nil, ErrUnsupportedChain
balance = i.gasBalances[chainID]
return balance, nil
func (i *inventoryManagerImpl) GetCommittableBalances(ctx context.Context, options ...BalanceFetchArgOption) (res map[int]map[common.Address]*big.Int, err error) {
reqOptions := makeOptions(options)
// TODO: hard fail if cache skip breaks
if reqOptions.shouldRefreshBalances {
// TODO; no need for this if refresh already in flight
_ = i.refreshBalances(ctx)
// get db first
// Add other committed, but incomplete statuses here
// TODO: clean me up: you can do this by having a IsLiquidityCommitted() method on the type.
inFlightQuotes, err := i.inFlightQuoteManager.GetInFlightQuotes(ctx, reqOptions.skipDBCache)
if err != nil {
return nil, fmt.Errorf("could not get in flight quotes: %w", err)
// TODO: lock should be context aware
defer i.mux.RUnlock()
res = make(map[int]map[common.Address]*big.Int)
for chainID, tokenMap := range i.tokens {
res[chainID] = map[common.Address]*big.Int{}
for address, tokenData := range tokenMap {
res[chainID][address] = core.CopyBigInt(tokenData.Balance)
// now subtract by in flight quotes.
// Yeah, this is an algorithmically atrocious for
// TODO: fix, but we're really talking about 4 tokens
for _, quote := range inFlightQuotes {
if quote.Transaction.DestToken == address && quote.Transaction.DestChainId == uint32(chainID) {
res[chainID][address] = new(big.Int).Sub(res[chainID][address], quote.Transaction.DestAmount)
// TODO: db subtraction
return res, nil
// TokenMetadata contains metadata for a token tracked by the inventory manager.
type TokenMetadata struct {
// Name is the name of the token in the config
Name string
// ChainName is the name of the token onchain; for the gas token it is set to
// the config name instead of being fetched
ChainName string
// Balance is the last fetched balance of the relayer for this token; for the
// gas token it points at the chain's entry in the gas balance map
Balance *big.Int
// Decimals is the number of decimals of the token; set to 18 for the gas token
// and fetched from the token contract otherwise
Decimals uint8
// IsGasToken is true when the token's config name equals the chain's native token
IsGasToken bool
// ChainID is the id of the chain the token lives on
ChainID int
// Addr is the token address; util.EthAddress is used for the gas token
Addr common.Address
// ERC20 view functions used to fetch the relayer's balance and the token metadata.
var (
funcBalanceOf = w3.MustNewFunc("balanceOf(address)", "uint256")
funcName = w3.MustNewFunc("name()", "string")
funcDecimals = w3.MustNewFunc("decimals()", "uint8")
// TODO: replace w/ config.
// NOTE(review): defaultPollPeriod is not referenced in the visible code and its unit is not stated — confirm before use.
const defaultPollPeriod = 5
// meterName is the name passed to the metrics handler when creating this package's meter.
const meterName = ""
// NewInventoryManager creates a new inventory manager.
// TODO: too many args here.
func NewInventoryManager(ctx context.Context, clientFetcher submitter.ClientFetcher, handler metrics.Handler, cfg relconfig.Config, relayer common.Address, txSubmitter submitter.TransactionSubmitter, db reldb.Service) (Manager, error) {
rebalanceMethods, err := cfg.GetAllRebalanceMethods()
if err != nil {
return nil, fmt.Errorf("could not get rebalance methods: %w", err)
rebalanceManagers := make(map[relconfig.RebalanceMethod]RebalanceManager)
for method := range rebalanceMethods {
switch method {
case relconfig.RebalanceMethodSynapseCCTP:
rebalanceManagers[method] = newRebalanceManagerSynapseCCTP(cfg, handler, clientFetcher, txSubmitter, relayer, db)
case relconfig.RebalanceMethodCircleCCTP:
rebalanceManagers[method] = newRebalanceManagerCircleCCTP(cfg, handler, clientFetcher, txSubmitter, relayer, db)
case relconfig.RebalanceMethodScroll:
rebalanceManagers[method] = newRebalanceManagerScroll(cfg, handler, clientFetcher, txSubmitter, relayer, db)
return nil, fmt.Errorf("unsupported rebalance method: %s", method)
i := inventoryManagerImpl{
relayerAddress: relayer,
handler: handler,
cfg: cfg,
chainClient: clientFetcher,
txSubmitter: txSubmitter,
rebalanceManagers: rebalanceManagers,
db: db,
meter: handler.Meter(meterName),
inFlightQuoteManager: newInflightManager(db),
i.balanceGauge, err = i.meter.Float64ObservableGauge("inventory_balance")
if err != nil {
return nil, fmt.Errorf("could not create balance gauge: %w", err)
_, err = i.meter.RegisterCallback(i.recordBalances, i.balanceGauge)
if err != nil {
return nil, fmt.Errorf("could not register callback: %w", err)
err = i.initializeTokens(ctx, cfg)
if err != nil {
return nil, fmt.Errorf("could not initialize tokens: %w", err)
return &i, nil
func (i *inventoryManagerImpl) Start(ctx context.Context) error {
g, ctx := errgroup.WithContext(ctx)
for _, rebalanceManager := range i.rebalanceManagers {
rebalanceManager := rebalanceManager
g.Go(func() error {
err := rebalanceManager.Start(ctx)
if err != nil {
return fmt.Errorf("could not start rebalance manager: %w", err)
return nil
// continuously refresh balances
g.Go(func() error {
for {
select {
case <-ctx.Done():
return fmt.Errorf("context canceled: %w", ctx.Err())
case <-time.After(250 * time.Millisecond):
// this returning an error isn't really possible unless a config error happens
// TODO: need better error handling.
err := i.refreshBalances(ctx)
if err != nil {
logger.Errorf("could not refresh balances")
return nil
// continuously check for rebalances
rebalanceInterval := i.cfg.GetRebalanceInterval()
if rebalanceInterval > 0 {
g.Go(func() error {
for {
select {
case <-ctx.Done():
return fmt.Errorf("context canceled: %w", ctx.Err())
case <-time.After(rebalanceInterval):
rebalanceCtx, span := i.handler.Tracer().Start(ctx, "newRebalanceInterval")
err := i.refreshBalances(rebalanceCtx)
if err != nil {
metrics.EndSpanWithErr(span, err)
return fmt.Errorf("could not refresh balances: %w", err)
err = i.Rebalance(rebalanceCtx)
if err != nil {
logger.Errorf("could not rebalance: %v", err)
metrics.EndSpanWithErr(span, err)
err := g.Wait()
if err != nil {
return fmt.Errorf("error starting inventory manager: %w", err)
return nil
// maxBatchSize is the chunk size initializeTokens uses when splitting its deferred calls into batch requests.
const maxBatchSize = 10
// ApproveAllTokens approves all checks if allowance is set and if not approves.
// nolint:gocognit,nestif,cyclop
func (i *inventoryManagerImpl) ApproveAllTokens(ctx context.Context) (err error) {
ctx, span := i.handler.Tracer().Start(ctx, "approveAllTokens")
defer func() {
metrics.EndSpanWithErr(span, err)
defer i.mux.RUnlock()
for chainID, tokenMap := range i.tokens {
backendClient, err := i.chainClient.GetClient(ctx, big.NewInt(int64(chainID)))
if err != nil {
return fmt.Errorf("could not get chain client: %w", err)
for tokenAddr := range tokenMap {
// Note: in the case where submitter hasn't finished from last boot,
// this will double submit approvals unfortunately.
contractAddr, addrErr := i.cfg.GetRFQAddress(chainID)
if addrErr == nil {
err = i.approve(ctx, tokenAddr, contractAddr, backendClient)
if err != nil {
return fmt.Errorf("could not approve RFQ contract: %w", err)
contractAddr, addrErr = i.cfg.GetSynapseCCTPAddress(chainID)
if addrErr == nil {
err = i.approve(ctx, tokenAddr, contractAddr, backendClient)
if err != nil {
return fmt.Errorf("could not approve SynapseCCTP contract: %w", err)
contractAddr, addrErr = i.cfg.GetTokenMessengerAddress(chainID)
if addrErr == nil {
err = i.approve(ctx, tokenAddr, contractAddr, backendClient)
if err != nil {
return fmt.Errorf("could not approve TokenMessenger contract: %w", err)
parentAddr, addrErr := i.cfg.GetL1GatewayAddress(chainID)
if addrErr == nil {
span.AddEvent(fmt.Sprintf("got l1 gateway address: %s", parentAddr.Hex()))
err = i.approve(ctx, tokenAddr, parentAddr, backendClient)
if err != nil {
return fmt.Errorf("could not approve L1GatewayRouter contract: %w", err)
contract, err := l1gateway.NewL1GatewayRouter(parentAddr, backendClient)
if err != nil {
return fmt.Errorf("could not get L1Gateway contract: %w", err)
contractAddr, err = contract.ERC20Gateway(&bind.CallOpts{Context: ctx}, tokenAddr)
if err != nil {
return fmt.Errorf("could not get L1ERC20Gateway address: %w", err)
span.AddEvent(fmt.Sprintf("got l1 erc20 gateway address: %s", contractAddr.Hex()))
err = i.approve(ctx, tokenAddr, contractAddr, backendClient)
if err != nil {
return fmt.Errorf("could not approve L1ERC20Gateway contract: %w", err)
parentAddr, addrErr = i.cfg.GetL2GatewayAddress(chainID)
if addrErr == nil {
span.AddEvent(fmt.Sprintf("got l2 gateway address: %s", parentAddr.Hex()))
err = i.approve(ctx, tokenAddr, parentAddr, backendClient)
if err != nil {
return fmt.Errorf("could not approve L2GatewayRouter contract: %w", err)
contract, err := l2gateway.NewL2GatewayRouter(parentAddr, backendClient)
if err != nil {
return fmt.Errorf("could not get L2Gateway contract: %w", err)
contractAddr, err = contract.ERC20Gateway(&bind.CallOpts{Context: ctx}, tokenAddr)
if err != nil {
return fmt.Errorf("could not get L2ERC20Gateway address: %w", err)
span.AddEvent(fmt.Sprintf("got l2 erc20 gateway address: %s", contractAddr.Hex()))
err = i.approve(ctx, tokenAddr, contractAddr, backendClient)
if err != nil {
return fmt.Errorf("could not approve L2ERC20Gateway contract: %w", err)
return nil
// approve submits an ERC20 approval for a given token and contract address.
func (i *inventoryManagerImpl) approve(parentCtx context.Context, tokenAddr, contractAddr common.Address, backendClient client.EVM) (err error) {
ctx, span := i.handler.Tracer().Start(parentCtx, "approve", trace.WithAttributes(
attribute.String("token_address", tokenAddr.Hex()),
attribute.String("contract_address", contractAddr.Hex()),
defer func() {
metrics.EndSpanWithErr(span, err)
if tokenAddr == util.EthAddress {
span.AddEvent("not approving with eth address")
return nil
if contractAddr == (common.Address{}) {
span.AddEvent("not approving to zero address")
return nil
erc20, err := ierc20.NewIERC20(tokenAddr, backendClient)
if err != nil {
return fmt.Errorf("could not get erc20: %w", err)
allowance, err := erc20.Allowance(&bind.CallOpts{Context: ctx}, i.relayerAddress, contractAddr)
if err != nil {
return fmt.Errorf("could not get allowance: %w", err)
if allowance.Cmp(big.NewInt(0)) > 0 {
span.AddEvent("already has positive allowance")
return nil
chainID, err := backendClient.ChainID(ctx)
if err != nil {
return fmt.Errorf("could not get chain id: %w", err)
_, err = i.txSubmitter.SubmitTransaction(ctx, chainID, func(transactor *bind.TransactOpts) (tx *types.Transaction, err error) {
tx, err = erc20.Approve(transactor, contractAddr, abi.MaxUint256)
if err != nil {
return nil, fmt.Errorf("could not approve: %w", err)
return tx, nil
if err != nil {
return fmt.Errorf("could not submit approval: %w", err)
return nil
// HasSufficientGas checks if there is sufficient gas for a given route.
func (i *inventoryManagerImpl) HasSufficientGas(parentCtx context.Context, chainID int, gasValue *big.Int) (sufficient bool, err error) {
ctx, span := i.handler.Tracer().Start(parentCtx, "HasSufficientGas", trace.WithAttributes(
attribute.Int(metrics.ChainID, chainID),
defer func(err error) {
metrics.EndSpanWithErr(span, err)
gasThreshRaw, err := i.cfg.GetMinGasToken(chainID)
if err != nil {
return false, fmt.Errorf("error getting min gas token on origin: %w", err)
gasThresh := core.CopyBigInt(gasThreshRaw)
if gasValue != nil {
gasThresh = new(big.Int).Add(gasThresh, gasValue)
span.SetAttributes(attribute.String("gas_value", gasValue.String()))
gasBalance, err := i.GetCommittableBalance(ctx, chainID, util.EthAddress, SkipDBCache())
if err != nil {
return false, fmt.Errorf("error getting committable gas on origin: %w", err)
sufficient = gasBalance.Cmp(gasThresh) >= 0
attribute.String("gas_threshold_raw", gasThreshRaw.String()),
attribute.String("gas_threshold", gasThresh.String()),
attribute.String("gas_balance", gasBalance.String()),
attribute.Bool("sufficient", sufficient),
return sufficient, nil
func (i *inventoryManagerImpl) Rebalance(ctx context.Context) (err error) {
ctx, span := i.handler.Tracer().Start(ctx, "Rebalance")
defer func(err error) {
metrics.EndSpanWithErr(span, err)
rebalances, err := getRebalances(ctx, i.cfg, i.tokens)
if err != nil {
return fmt.Errorf("could not get rebalances: %w", err)
for tokenName, rebalance := range rebalances {
if rebalance == nil || rebalance.Amount == nil {
err = i.tryExecuteRebalance(ctx, rebalance)
if err != nil {
return fmt.Errorf("could not execute rebalance for token %s: %w", tokenName, err)
return nil
func (i *inventoryManagerImpl) tryExecuteRebalance(ctx context.Context, rebalance *RebalanceData) (err error) {
ctx, span := i.handler.Tracer().Start(ctx, "tryExecuteRebalance", trace.WithAttributes(
attribute.Int("origin", rebalance.OriginMetadata.ChainID),
attribute.Int("dest", rebalance.DestMetadata.ChainID),
attribute.String("origin_token", rebalance.OriginMetadata.Addr.Hex()),
attribute.String("dest_token", rebalance.DestMetadata.Addr.Hex()),
attribute.String("origin_balance", rebalance.OriginMetadata.Balance.String()),
attribute.String("dest_balance", rebalance.DestMetadata.Balance.String()),
attribute.String("rebalance_amount", rebalance.Amount.String()),
attribute.String("token_name", rebalance.OriginMetadata.Name),
defer func(err error) {
metrics.EndSpanWithErr(span, err)
// make sure there are no pending rebalances that touch the given path
pendingRebalances, err := i.db.GetPendingRebalances(ctx, uint64(rebalance.OriginMetadata.ChainID), uint64(rebalance.DestMetadata.ChainID))
if err != nil {
return fmt.Errorf("could not check pending rebalance: %w", err)
var pending bool
for _, pendingRebalance := range pendingRebalances {
if pendingRebalance.TokenName == rebalance.OriginMetadata.Name {
pending = true
span.SetAttributes(attribute.Bool("rebalance_pending", pending))
if pending {
return nil
// execute the rebalance
manager, ok := i.rebalanceManagers[rebalance.Method]
if !ok {
return fmt.Errorf("no rebalance manager for method: %s", rebalance.Method)
err = manager.Execute(ctx, rebalance)
if err != nil {
return fmt.Errorf("could not execute rebalance: %w", err)
return nil
func (i *inventoryManagerImpl) GetTokenMetadata(chainID int, token common.Address) (*TokenMetadata, error) {
defer i.mux.RUnlock()
tokenData, ok := i.tokens[chainID][token]
if !ok {
return nil, fmt.Errorf("token not found")
return tokenData, nil
// initializeTokens converts the configuration into a data structure we can use to determine inventory
// it gets metadata like name, decimals, etc once and exports these to prometheus for ease of debugging.
func (i *inventoryManagerImpl) initializeTokens(parentCtx context.Context, cfg relconfig.Config) (err error) {
defer i.mux.Unlock()
ctx, span := i.handler.Tracer().Start(parentCtx, "initializeTokens", trace.WithAttributes(
attribute.String("relayer_address", i.relayerAddress.String()),
defer func(err error) {
metrics.EndSpanWithErr(span, err)
// TODO: this needs to be a struct bound variable otherwise will be stuck.
i.tokens = make(map[int]map[common.Address]*TokenMetadata)
i.gasBalances = make(map[int]*big.Int)
// TODO: this can be pre-capped w/ len(cfg.Tokens) for each chain id.
// here we register metrics for exporting through otel. We wait to call these functions until are tokens have been initialized to avoid nil issues.
deferredCalls := make(map[int][]w3types.Caller)
// iterate through all tokens to get the metadata
for cid, chainCfg := range cfg.GetChains() {
chainID := cid // capture func literal
i.tokens[chainID] = map[common.Address]*TokenMetadata{}
// set up balance fetching for this chain's gas token
i.gasBalances[chainID] = new(big.Int)
deferredCalls[chainID] = append(deferredCalls[chainID],
eth.Balance(i.relayerAddress, nil).Returns(i.gasBalances[chainID]),
// assign metadata for each configured token
for tokenName, tokenCfg := range chainCfg.Tokens {
nativeToken, err := cfg.GetNativeToken(chainID)
if err != nil {
return fmt.Errorf("could not get native token: %w", err)
rtoken := &TokenMetadata{
Name: tokenName,
IsGasToken: tokenName == nativeToken,
ChainID: chainID,
var token common.Address
if rtoken.IsGasToken {
token = util.EthAddress
} else {
token = common.HexToAddress(tokenCfg.Address)
i.tokens[chainID][token] = rtoken
rtoken.Addr = token
// requires non-nil pointer
rtoken.Balance = new(big.Int)
if rtoken.IsGasToken {
rtoken.Decimals = 18
rtoken.Balance = i.gasBalances[chainID]
rtoken.ChainName = tokenName
} else {
deferredCalls[chainID] = append(deferredCalls[chainID],
eth.CallFunc(funcBalanceOf, token, i.relayerAddress).Returns(rtoken.Balance),
eth.CallFunc(funcDecimals, token).Returns(&rtoken.Decimals),
eth.CallFunc(funcName, token).Returns(&rtoken.ChainName),
// run through the deferred cals
g, gctx := errgroup.WithContext(ctx)
for chainID := range deferredCalls {
chainID := chainID // capture func literal
chainClient, err := i.chainClient.GetClient(ctx, big.NewInt(int64(chainID)))
if err != nil {
return fmt.Errorf("can't initialize tokens, no chain client available for chain %d: %w", chainID, err)
g.Go(func() error {
// TODO: add retries
// TODO: we should see if we can move this to ethergo and deduplicate. We do this a lot, especially in
// the prom exporter
batches := core.ChunkSlice(deferredCalls[chainID], maxBatchSize)
for _, batch := range batches {
err = chainClient.BatchWithContext(gctx, batch...)
if err != nil {
return fmt.Errorf("could not batch: %w", err)
return nil
err = g.Wait()
if err != nil {
return fmt.Errorf("could not get tx: %w", err)
return nil
// logger is the package logger, registered under the name "inventory".
var logger = log.Logger("inventory")
// refreshBalances refreshes all the token balances.
func (i *inventoryManagerImpl) refreshBalances(parentCtx context.Context) (err error) {
ctx, span := i.handler.Tracer().Start(parentCtx, "refreshBalances")
defer func() {
metrics.EndSpanWithErr(span, err)
var wg sync.WaitGroup
// TODO: this can be pre-capped w/ len(cfg.Tokens) for each chain id.
// here we register metrics for exporting through otel. We wait to call these functions until are tokens have been initialized to avoid nil issues.
for cid, tokenMap := range i.tokens {
chainID := cid // capture func literal
chainClient, clientErr := i.chainClient.GetClient(ctx, big.NewInt(int64(chainID)))
if clientErr != nil {
logger.Warnf("could not get chain client: %w", clientErr)
span.SetAttributes(attribute.String(fmt.Sprintf("client_error_%d", chainID), clientErr.Error()))
// queue gas token balance fetch
deferredCalls := []w3types.Caller{
eth.Balance(i.relayerAddress, nil).Returns(i.gasBalances[chainID]),
// queue token balance fetches
for ta, token := range tokenMap {
tokenAddress := ta // capture func literal
// TODO: make sure Returns does nothing on error
if !token.IsGasToken {
deferredCalls = append(deferredCalls, eth.CallFunc(funcBalanceOf, tokenAddress, i.relayerAddress).Returns(token.Balance))
go func() {
defer wg.Done()
err = chainClient.BatchWithContext(ctx, deferredCalls...)
if err != nil {
logger.Warnf("could not refresh balances on %d: %v", chainID, err)
return nil
func (i *inventoryManagerImpl) recordBalances(ctx context.Context, observer metric.Observer) (err error) {
if i.meter == nil || i.balanceGauge == nil {
return nil
defer i.mux.RUnlock()
for chainID, tokens := range i.tokens {
for token, tokenData := range tokens {
opts := metric.WithAttributes(
attribute.Int(metrics.ChainID, chainID),
attribute.String("relayer_address", i.relayerAddress.String()),
attribute.String("token_name", tokenData.Name),
attribute.Int("decimals", int(tokenData.Decimals)),
attribute.String("token_address", token.String()),
attribute.String("raw_balance", tokenData.Balance.String()),
attribute.String("relayer", i.relayerAddress.Hex()),
// Convert the balance and record it
decimalBalance := core.BigToDecimals(tokenData.Balance, tokenData.Decimals)
observer.ObserveFloat64(i.balanceGauge, decimalBalance, opts)
return nil