packages/contracts-core/contracts/libs/merkle/MerkleMath.sol
// SPDX-License-Identifier: MIT
pragma solidity 0.8.17;
import {TreeHeightTooLow} from "../Errors.sol";
library MerkleMath {
    // ═════════════════════════════════════════ BASIC MERKLE CALCULATIONS ═════════════════════════════════════════════

    /**
     * @notice Calculates the merkle root for the given leaf and merkle proof.
     * @dev Will revert with `TreeHeightTooLow` if proof length exceeds the tree height.
     * The proof is allowed to be shorter than `height`: every sibling above the last proof element
     * is treated as ZERO, which matches the `H(0,0) = 0` implementation used by `getParent`.
     * Only the lowest `height` bits of `index` are inspected (to pick left/right on each level).
     * @param index     Index of `leaf` in tree
     * @param leaf      Leaf of the merkle tree
     * @param proof     Proof of inclusion of `leaf` in the tree
     * @param height    Height of the merkle tree
     * @return root_    Calculated Merkle Root
     */
    function proofRoot(uint256 index, bytes32 leaf, bytes32[] memory proof, uint256 height)
        internal
        pure
        returns (bytes32 root_)
    {
        // Proof length could not exceed the tree height
        uint256 proofLen = proof.length;
        if (proofLen > height) revert TreeHeightTooLow();
        root_ = leaf;
        // Unchecked is only applied to the ++h operations: h is always below `height`, so it never overflows
        unchecked {
            // Go up the tree levels from the leaf following the proof
            for (uint256 h = 0; h < proofLen; ++h) {
                // Get a sibling node on current level: this is proof[h]
                root_ = getParent(root_, proof[h], index, h);
            }
            // Go up to the root: the remaining siblings are EMPTY
            for (uint256 h = proofLen; h < height; ++h) {
                root_ = getParent(root_, bytes32(0), index, h);
            }
        }
    }

    /**
     * @notice Calculates the parent of a node on the path from one of the leafs to root.
     * @param node          Node on a path from tree leaf to root
     * @param sibling       Sibling for a given node
     * @param leafIndex     Index of the tree leaf
     * @param nodeHeight    "Level height" for `node` (ZERO for leafs, tree height for the root)
     * @return parent       Parent of `node` and `sibling`, with `node` placed on the side given by `leafIndex`
     */
    function getParent(bytes32 node, bytes32 sibling, uint256 leafIndex, uint256 nodeHeight)
        internal
        pure
        returns (bytes32 parent)
    {
        // Index for `node` on its "tree level" is (leafIndex / 2**nodeHeight)
        // "Left child" has even index, "right child" has odd index
        // Note: in Solidity `&` binds tighter than `==`, so this is ((leafIndex >> nodeHeight) & 1) == 0
        if ((leafIndex >> nodeHeight) & 1 == 0) {
            // Left child
            return getParent(node, sibling);
        } else {
            // Right child
            return getParent(sibling, node);
        }
    }

    /// @notice Calculates the parent of two nodes in the merkle tree.
    /// @dev We use implementation with H(0,0) = 0
    /// This makes EVERY empty node in the tree equal to ZERO,
    /// saving us from storing H(0,0), H(H(0,0), H(0, 0)), and so on
    /// @param leftChild    Left child of the calculated node
    /// @param rightChild   Right child of the calculated node
    /// @return parent      Value for the node having above mentioned children
    function getParent(bytes32 leftChild, bytes32 rightChild) internal pure returns (bytes32 parent) {
        if (leftChild == bytes32(0) && rightChild == bytes32(0)) {
            return 0;
        } else {
            return keccak256(bytes.concat(leftChild, rightChild));
        }
    }

    // ════════════════════════════════ ROOT/PROOF CALCULATION FOR A LIST OF LEAFS ═════════════════════════════════════

    /**
     * @notice Calculates merkle root for a list of given leafs.
     * Merkle Tree is constructed by padding the list with ZERO values for leafs until list length is `2**height`.
     * Merkle Root is calculated for the constructed tree, and then saved in `hashes[0]`.
     * > Note:
     * > - `hashes` values are overwritten in the process to avoid excessive memory allocations.
     * > - Caller is expected not to reuse `hashes` list after the call, and only use `hashes[0]` value,
     * which is guaranteed to contain the calculated merkle root when `hashes` is non-empty.
     * > - If `hashes` is empty, it stays empty (there is no `hashes[0]`); the root of such an all-empty tree
     * is ZERO under the `H(0,0) = 0` implementation.
     * > - root is calculated using the `H(0,0) = 0` Merkle Tree implementation. See MerkleTree.sol for details.
     * @dev Amount of leaves should be at most `2**height`, otherwise reverts with `TreeHeightTooLow`.
     * @param hashes    List of leafs for the merkle tree (to be overwritten)
     * @param height    Height of the Merkle Tree to construct
     */
    function calculateRoot(bytes32[] memory hashes, uint256 height) internal pure {
        uint256 levelLength = hashes.length;
        // Amount of hashes could not exceed amount of leafs in tree with the given height.
        // `1 << height` is truncated (not checked) and evaluates to ZERO for height >= 256, so the
        // comparison is only meaningful for height < 256. Any uint256 length fits into a tree of height >= 256.
        if (height < 256 && levelLength > (1 << height)) revert TreeHeightTooLow();
        // h, leftIndex, rightIndex and levelLength never overflow
        unchecked {
            // Iterate `height` levels up from the leaf level
            // For every level we will only record "significant values", i.e. not equal to ZERO
            for (uint256 h = 0; h < height; ++h) {
                // Let H be the height of the "current level". H = 0 for the "leafs level".
                // Invariant: a total of 2**(HEIGHT-H) nodes are on the current level
                // Invariant: hashes[0 .. length) are "significant values" for the "current level" nodes
                // Invariant: bytes32(0) is the value for nodes with indexes [length .. 2**(HEIGHT-H))
                // Iterate over every pair of (leftChild, rightChild) on the current level
                for (uint256 leftIndex = 0; leftIndex < levelLength; leftIndex += 2) {
                    uint256 rightIndex = leftIndex + 1;
                    bytes32 leftChild = hashes[leftIndex];
                    // Note: rightChild might be ZERO
                    bytes32 rightChild = rightIndex < levelLength ? hashes[rightIndex] : bytes32(0);
                    // Record the parent hash in the same array. This will not affect
                    // further calculations for the same level: (leftIndex >> 1) <= leftIndex.
                    hashes[leftIndex >> 1] = getParent(leftChild, rightChild);
                }
                // Set length for the "parent level": the amount of iterations for the for loop above.
                levelLength = (levelLength + 1) >> 1;
            }
        }
    }

    /**
     * @notice Generates a proof of inclusion of a leaf in the list. If the requested index is outside
     * of the list range, generates a proof of inclusion for an empty leaf (proof of non-inclusion).
     * The Merkle Tree is constructed by padding the list with ZERO values until list length is a power of two
     * __AND__ index is in the extended list range. For example:
     * - `hashes.length == 6` and `0 <= index <= 7` will "extend" the list to 8 entries.
     * - `hashes.length == 6` and `7 < index <= 15` will "extend" the list to 16 entries.
     * > Note: `hashes` values are overwritten in the process to avoid excessive memory allocations.
     * Caller is expected not to reuse `hashes` list after the call.
     * @dev Reverts (checked arithmetic) for `index == type(uint256).max`, as `index + 1` overflows.
     * @param hashes    List of leafs for the merkle tree (to be overwritten)
     * @param index     Leaf index to generate the proof for
     * @return proof    Generated merkle proof: proof[h] is the sibling of the path node on level h
     */
    function calculateProof(bytes32[] memory hashes, uint256 index) internal pure returns (bytes32[] memory proof) {
        // Use the smallest tree that covers both the whole list and the requested index
        // (we want to generate proofs for leafs outside of the list as well)
        uint256 height = getHeight(index < hashes.length ? hashes.length : (index + 1));
        proof = new bytes32[](height);
        uint256 levelLength = hashes.length;
        // h, leftIndex, rightIndex and levelLength never overflow
        unchecked {
            // Iterate `height` levels up from the leaf level
            // For every level we will only record "significant values", i.e. not equal to ZERO
            for (uint256 h = 0; h < height; ++h) {
                // Use sibling for the merkle proof; `index^1` is index of our sibling
                // Note: in Solidity `^` binds tighter than `<`, so this is (index ^ 1) < levelLength
                proof[h] = (index ^ 1 < levelLength) ? hashes[index ^ 1] : bytes32(0);
                // Let H be the height of the "current level". H = 0 for the "leafs level".
                // Invariant: a total of 2**(HEIGHT-H) nodes are on the current level
                // Invariant: hashes[0 .. length) are "significant values" for the "current level" nodes
                // Invariant: bytes32(0) is the value for nodes with indexes [length .. 2**(HEIGHT-H))
                // Iterate over every pair of (leftChild, rightChild) on the current level
                for (uint256 leftIndex = 0; leftIndex < levelLength; leftIndex += 2) {
                    uint256 rightIndex = leftIndex + 1;
                    bytes32 leftChild = hashes[leftIndex];
                    // Note: rightChild might be ZERO
                    bytes32 rightChild = rightIndex < levelLength ? hashes[rightIndex] : bytes32(0);
                    // Record the parent hash in the same array. This will not affect
                    // further calculations for the same level: (leftIndex >> 1) <= leftIndex.
                    hashes[leftIndex >> 1] = getParent(leftChild, rightChild);
                }
                // Set length for the "parent level"
                levelLength = (levelLength + 1) >> 1;
                // Traverse to parent node
                index >>= 1;
            }
        }
    }

    /**
     * @notice Returns the height of the tree having a given amount of leafs,
     * i.e. the smallest `height` such that `2**height >= leafs` (ZERO for 0 or 1 leafs).
     * @dev Computed as the bit length of `leafs - 1`. Shifting the value right always terminates,
     * unlike doubling a counter, which wraps to ZERO past `2**255` (shifts are truncated, not checked).
     * For `leafs > 2**255` this returns 256.
     * @param leafs     Amount of leafs in the tree
     * @return height   Height of the tree
     */
    function getHeight(uint256 leafs) internal pure returns (uint256 height) {
        // Zero or one leaf fits into a tree of height ZERO
        if (leafs <= 1) return 0;
        uint256 remaining = leafs - 1;
        // height is at most 256, so ++height never overflows
        unchecked {
            while (remaining != 0) {
                ++height;
                remaining >>= 1;
            }
        }
    }
}