# ocean_provider/utils/datatoken.py
import json
import logging
from datetime import datetime, timezone
from typing import Optional
from eth_keys import KeyAPI
from eth_keys.backends import NativeECCBackend
from eth_typing.encoding import HexStr
from eth_typing.evm import HexAddress
from hexbytes import HexBytes
from ocean_provider.utils.address import get_contract_definition
from ocean_provider.utils.basics import get_provider_wallet, get_network_name
from ocean_provider.utils.currency import to_wei
from ocean_provider.utils.data_nft import get_data_nft_contract
from ocean_provider.utils.services import Service
from web3.contract import Contract
from web3.logs import DISCARD
from web3.main import Web3
from websockets import ConnectionClosed
# Module-level logger shared by the functions below.
logger = logging.getLogger(__name__)
# eth_keys API instance built on NativeECCBackend; verify_order_tx uses it to
# rebuild the provider-fee signature and to verify it against the provider key.
keys = KeyAPI(NativeECCBackend)
def get_datatoken_contract(web3: Web3, address: Optional[str] = None) -> Contract:
    """
    Return a web3 Contract bound to `address` using the `ERC20Template` ABI.

    The ABI is looked up through `get_contract_definition("ERC20Template")`.
    As stated by the original authors, this relies on the `ERC20Template`
    stored at index 1 of the `ERC721Factory` providing everything Provider
    needs, especially the `getMetaData` contract method.

    NOTE(review): `address` defaults to None but is always passed through
    `web3.toChecksumAddress` — confirm whether None is ever a valid input.
    """
    definition = get_contract_definition("ERC20Template")
    erc20_abi = definition["abi"]
    checksum_address = web3.toChecksumAddress(address)
    return web3.eth.contract(address=checksum_address, abi=erc20_abi)
def _get_tx_receipt(web3, tx_hash):
    """Wait for the receipt of `tx_hash` (timeout=120) and return it."""
    wait_for_receipt = web3.eth.wait_for_transaction_receipt
    hash_bytes = HexBytes(tx_hash)
    return wait_for_receipt(hash_bytes, timeout=120)
def _get_tx_receipt_with_retry(web3, tx_hash):
    """Fetch a tx receipt, retrying exactly once if the connection was closed."""
    try:
        return _get_tx_receipt(web3, tx_hash)
    except ConnectionClosed:
        # try again in this case
        return _get_tx_receipt(web3, tx_hash)


def verify_order_tx(
    web3: Web3,
    datatoken_address: HexAddress,
    tx_id: HexStr,
    service: Service,
    amount: int,
    sender: HexAddress,
    extra_data: Optional[dict],
    allow_expired_provider_fees=False,
):
    """Check order tx and provider fees validity on-chain for the given parameters.

    Steps, in order:
      1. Fetch the receipt of `tx_id` and require a non-zero status.
      2. Find a `ProviderFee` event whose JSON `providerData` names this
         datatoken (`dt`) and this service (`id`).
      3. If `extra_data` is given, require its `environment` to match the one
         in `providerData`.
      4. Require the event's `providerFeeAddress` to be the provider wallet and
         the event signature (r, s, v) to verify against the provider key.
      5. Unless `allow_expired_provider_fees` is set, check `validUntil`.
      6. If the receipt contains an `OrderReused` event with an `orderTxId`,
         switch to the receipt of that referenced transaction.
      7. Find the `OrderStarted` event for `service.index`, then check the
         ordered amount, the service timeout and that `sender` is the consumer
         or the payer.

    :param web3: Web3 instance used for all on-chain lookups
    :param datatoken_address: address of the datatoken contract
    :param tx_id: hash of the order (or reuse-order) transaction
    :param service: service being consumed; `id`, `index` and `timeout` are read
    :param amount: minimum amount that must appear in the `OrderStarted` event
    :param sender: address expected to be the consumer or payer of the order
    :param extra_data: optional mapping; when truthy its `environment` key is
        compared against the one in the signed provider data
    :param allow_expired_provider_fees: skip the `validUntil` check when True
    :return: tuple `(tx, order_log, provider_fee_order_log, start_order_tx_id)`
    :raises AssertionError: receipt, provider-fee event, environment, fee
        address, signature or `validUntil` checks failed, or no matching
        `OrderStarted` event was found
    :raises ValueError: ordered amount too small, order expired, or `sender`
        is neither consumer nor payer
    """
    provider_wallet = get_provider_wallet(web3.chain_id)
    tx_receipt = _get_tx_receipt_with_retry(web3, tx_id)
    network_name = get_network_name(web3.eth.chain_id)
    if tx_receipt is None:
        raise AssertionError(
            f"Provider {network_name}: Failed to get tx receipt for the `startOrder` transaction.."
        )
    if tx_receipt.status == 0:
        raise AssertionError(f"Provider {network_name}: order transaction failed.")

    # check provider fees
    datatoken_contract = get_datatoken_contract(web3, datatoken_address)
    provider_fee_order_log = None
    provider_fee_event_logs = datatoken_contract.events.ProviderFee().processReceipt(
        tx_receipt, errors=DISCARD
    )
    # search in all provider_fee events until we have a match. if not, we don't have a valid event
    # also, make sure that somebody is not spoofing provider fee event from another datatoken
    for provider_fees_logs in provider_fee_event_logs:
        try:
            provider_data = json.loads(provider_fees_logs.args.providerData)
            if (
                provider_data["dt"].lower() == datatoken_address.lower()
                and provider_data["id"].lower() == service.id.lower()
            ):
                provider_fee_order_log = provider_fees_logs
        except Exception as e:
            # providerData comes from the chain and may be malformed (bad JSON,
            # missing keys, wrong types): skip that event instead of failing.
            # `Exception` rather than a bare except, so KeyboardInterrupt and
            # SystemExit are no longer swallowed here.
            logger.debug(
                f"Provider {network_name}: skipping unparsable ProviderFee event: {e}"
            )
    if not provider_fee_order_log:
        raise AssertionError(
            f"Provider {network_name}: Cannot find the event for the provider fee in tx id {tx_id}."
        )

    provider_initialize_timestamp = 0
    if extra_data:
        provider_data = json.loads(provider_fee_order_log.args.providerData)
        if extra_data["environment"] != provider_data["environment"]:
            raise AssertionError(
                f"Provider {network_name}: Mismatch between ordered c2d environment and selected one."
            )
        provider_initialize_timestamp = provider_data["timestamp"]

    if Web3.toChecksumAddress(
        provider_fee_order_log.args.providerFeeAddress
    ) != Web3.toChecksumAddress(provider_wallet.address):
        raise AssertionError(
            f"Provider {network_name}: The providerFeeAddress {provider_fee_order_log.args.providerFeeAddress} in the event does "
            f"not match the provider address {provider_wallet.address}\n"
        )

    # Rebuild the signature from the event as r || s || (v - 27).
    bts = b"".join(
        [
            provider_fee_order_log.args.r,
            provider_fee_order_log.args.s,
            Web3.toBytes(provider_fee_order_log.args.v - 27),
        ]
    )
    signature = keys.Signature(signature_bytes=bts)
    # Recompute the hash over the fee fields carried by the event ...
    message_hash = Web3.solidityKeccak(
        ["bytes", "address", "address", "uint256", "uint256"],
        [
            provider_fee_order_log.args.providerData,
            provider_fee_order_log.args.providerFeeAddress,
            provider_fee_order_log.args.providerFeeToken,
            provider_fee_order_log.args.providerFeeAmount,
            provider_fee_order_log.args.validUntil,
        ],
    )
    # ... and wrap it with the Ethereum signed-message prefix before verifying.
    prefix = "\x19Ethereum Signed Message:\n32"
    signable_hash = Web3.solidityKeccak(
        ["bytes", "bytes"], [Web3.toBytes(text=prefix), Web3.toBytes(message_hash)]
    )
    pk = keys.PrivateKey(provider_wallet.key)
    if not keys.ecdsa_verify(signable_hash, signature, pk.public_key):
        raise AssertionError(
            f"Provider {network_name}: Provider was not able to check the signed message in ProviderFees event\n"
        )

    timestamp_now = datetime.now(timezone.utc).timestamp()
    # check validUntil
    if provider_fee_order_log.args.validUntil > 0 and not allow_expired_provider_fees:
        if timestamp_now >= provider_fee_order_log.args.validUntil:
            # expired validUntil. let's add the difference between provider_data["timestamp"](time of initializeCompute) and transaction timestamp.
            # Also add 90 secs as buffer, since block.timestamp can be manipulated
            # NOTE(review): when no extra_data was given,
            # provider_initialize_timestamp stays 0 and an expired validUntil
            # does not raise here — confirm this is intended.
            block = web3.eth.get_block(tx_receipt.blockHash, False)
            if provider_initialize_timestamp > 0 and (
                timestamp_now
                >= provider_fee_order_log.args.validUntil
                + (block.timestamp - provider_initialize_timestamp + 90)
            ):
                raise AssertionError(
                    f"Provider {network_name}: Ordered c2d time was exceeded, check validUntil."
                )
    # end check provider fees

    # check if we have an OrderReused event. If so, get orderTxId and switch next checks to use that
    start_order_tx_id = tx_receipt.transactionHash
    # Pre-bind so a failure in processReceipt leaves an empty list instead of
    # an unbound name in the statements below.
    event_logs = []
    try:
        event_logs = datatoken_contract.events.OrderReused().processReceipt(
            tx_receipt, errors=DISCARD
        )
    except Exception as e:
        logger.error(f"Provider {network_name}: {e}")
    logger.debug(
        f"Provider {network_name}: Got events log when searching for ReuseOrder : {event_logs}"
    )

    log_timestamp = None
    order_log = event_logs[0] if event_logs else None
    if order_log and order_log.args.orderTxId:
        log_timestamp = order_log.args.timestamp
        tx_receipt = _get_tx_receipt_with_retry(web3, order_log.args.orderTxId)
        if tx_receipt is None:
            raise AssertionError(
                f"Provider {network_name}: Failed to get tx receipt referenced in OrderReused.."
            )
        if tx_receipt.status == 0:
            raise AssertionError(
                f"Provider {network_name}: order referenced in OrderReused failed."
            )

    logger.debug(
        f"Provider {network_name}: Search for orderStarted in tx_receipt : {tx_receipt}"
    )
    # this has changed now if the original original_tx was a reuseOrder
    start_order_tx_id = tx_receipt.transactionHash
    # Reset so a failure here cannot leave the OrderReused logs to be scanned
    # as if they were OrderStarted logs.
    event_logs = []
    try:
        event_logs = datatoken_contract.events.OrderStarted().processReceipt(
            tx_receipt, errors=DISCARD
        )
    except Exception as e:
        logger.error(f"Provider {network_name}: {e}")
    logger.debug(
        f"Provider {network_name}: Got events log when searching for OrderStarted : {event_logs}"
    )

    order_log = None
    # search in all startOrder events until we have a match. if not, we don't have a valid event
    # (the last matching event wins; a separate serviceIndex mismatch check is
    # unnecessary because only events with an equal serviceIndex are selected)
    for log in event_logs:
        if log.args.serviceIndex == service.index:
            order_log = log
    if not order_log:
        raise AssertionError(
            f"Provider {network_name}: Cannot find the event for the order transaction with tx id {tx_id}."
        )

    if order_log.args.amount < amount:
        raise ValueError(
            f"Provider {network_name}: The amount in the event is less than the amount requested. \n"
            f"requested: amount={amount}\n"
            f"event: amount={order_log.args.amount}"
        )

    # Check if order expired. timeout == 0 means order is valid forever
    # use orderReused timestamp if it exists
    log_timestamp = (
        log_timestamp if log_timestamp is not None else order_log.args.timestamp
    )
    timestamp_delta = timestamp_now - log_timestamp
    logger.debug(
        f"Provider {network_name}: verify_order_tx: service timeout = {service.timeout}, timestamp delta = {timestamp_delta}"
    )
    if service.timeout != 0:
        if timestamp_delta > service.timeout:
            raise ValueError(
                f"Provider {network_name}: The order has expired. \n"
                f"current timestamp={timestamp_now}\n"
                f"order timestamp={log_timestamp}\n"
                f"timestamp delta={timestamp_delta}\n"
                f"service timeout={service.timeout}"
            )

    if web3.toChecksumAddress(sender) not in [
        web3.toChecksumAddress(order_log.args.consumer),
        web3.toChecksumAddress(order_log.args.payer),
    ]:
        raise ValueError(
            f"Provider {network_name}: sender of order transaction is not the consumer/payer."
        )

    tx = web3.eth.get_transaction(HexBytes(tx_id))
    return tx, order_log, provider_fee_order_log, start_order_tx_id
def validate_order(
    web3,
    sender,
    tx_id,
    asset,
    service,
    extra_data=None,
    allow_expired_provider_fees=False,
):
    """Validate an order transaction for one datatoken of `service`.

    Checks via the data NFT contract's `isDeployed` call that the service's
    datatoken is accepted by the asset's data NFT, then delegates to
    `verify_order_tx`. Only `ConnectionClosed` is retried, up to 3 attempts in
    total; any other exception propagates immediately.

    :param web3: Web3 instance used for on-chain lookups
    :param sender: address expected to be consumer or payer of the order
    :param tx_id: hash of the order transaction
    :param asset: asset object; `did`, `chain_id` and `nft["address"]` are read
    :param service: service object; `id` and `datatoken_address` are read here
    :param extra_data: forwarded unchanged to `verify_order_tx`
    :param allow_expired_provider_fees: forwarded unchanged to `verify_order_tx`
    :return: tuple `(tx, order_event, provider_fees_event, start_order_tx_id)`
    :raises AssertionError: the `isDeployed` check failed, or raised by
        `verify_order_tx`
    :raises ConnectionClosed: the last retry attempt still lost the connection
    """
    did = asset.did
    token_address = web3.toChecksumAddress(service.datatoken_address)
    num_tokens = 1
    network_name = get_network_name(asset.chain_id)
    logger.debug(
        f"Provider {network_name}: validate_order: did={did}, service_id={service.id}, tx_id={tx_id}, "
        f"sender={sender}, num_tokens={num_tokens}, token_address={token_address}"
    )

    nft_contract = get_data_nft_contract(web3, asset.nft["address"])
    # Explicit raise instead of `assert`: assert statements are removed when
    # Python runs with -O, which would silently disable this check.
    if not nft_contract.caller.isDeployed(token_address):
        raise AssertionError(
            f"Provider {network_name}: isDeployed check failed for datatoken {token_address}."
        )

    amount = to_wei(num_tokens)
    num_tries = 3
    for attempt in range(1, num_tries + 1):
        logger.debug(
            f"Provider {network_name}: validate_order is on trial {attempt} in {num_tries}."
        )
        try:
            tx, order_event, provider_fees_event, start_order_tx_id = verify_order_tx(
                web3,
                token_address,
                tx_id,
                service,
                amount,
                sender,
                extra_data,
                allow_expired_provider_fees,
            )
            logger.debug(
                f"Provider {network_name}: validate_order succeeded for: did={did}, service_id={service.id}, tx_id={tx_id}, "
                f"sender={sender}, num_tokens={num_tokens}, token_address={token_address}. "
                f"result is: tx={tx}, order_event={order_event}."
            )
            return tx, order_event, provider_fees_event, start_order_tx_id
        except ConnectionClosed:
            logger.debug(
                f"Provider {network_name}: got ConnectionClosed error on validate_order."
            )
            if attempt == num_tries:
                logger.debug(
                    f"Provider {network_name}: reached max no. of tries, raise ConnectionClosed in validate_order."
                )
                raise
def validate_transfer_not_used_for_other_service(
    did, service_id, transfer_tx_id, consumer_address, token_address
):
    """Log the given arguments at debug level; performs no validation and returns None."""
    message = (
        f"validate_transfer_not_used_for_other_service: "
        f"did={did}, service_id={service_id}, transfer_tx_id={transfer_tx_id},"
        f" consumer_address={consumer_address}, token_address={token_address}"
    )
    logger.debug(message)
def record_consume_request(
    did, service_id, order_tx_id, consumer_address, token_address, amount
):
    """Log the given consume-request arguments at debug level; stores nothing and returns None."""
    message = (
        f"record_consume_request: "
        f"did={did}, service_id={service_id}, transfer_tx_id={order_tx_id}, "
        f"consumer_address={consumer_address}, token_address={token_address}, "
        f"amount={amount}"
    )
    logger.debug(message)