diff --git a/solidity/contracts/hooks/aggregation/StaticAggregationHook.sol b/solidity/contracts/hooks/aggregation/StaticAggregationHook.sol index 5146ca2c02..3e63a6386c 100644 --- a/solidity/contracts/hooks/aggregation/StaticAggregationHook.sol +++ b/solidity/contracts/hooks/aggregation/StaticAggregationHook.sol @@ -13,11 +13,19 @@ pragma solidity >=0.8.0; @@@@@@@@@ @@@@@@@@@ @@@@@@@@@ @@@@@@@@*/ +// ============ Internal Imports ============ +import {StandardHookMetadata} from "../libs/StandardHookMetadata.sol"; import {AbstractPostDispatchHook} from "../libs/AbstractPostDispatchHook.sol"; import {IPostDispatchHook} from "../../interfaces/hooks/IPostDispatchHook.sol"; import {MetaProxy} from "../../libs/MetaProxy.sol"; +// ============ External Imports ============ +import {Address} from "@openzeppelin/contracts/utils/Address.sol"; + contract StaticAggregationHook is AbstractPostDispatchHook { + using StandardHookMetadata for bytes; + using Address for address payable; + // ============ External functions ============ /// @inheritdoc IPostDispatchHook @@ -32,17 +40,26 @@ contract StaticAggregationHook is AbstractPostDispatchHook { ) internal override { address[] memory _hooks = hooks(message); uint256 count = _hooks.length; + uint256 gasRemaining = msg.value; for (uint256 i = 0; i < count; i++) { uint256 quote = IPostDispatchHook(_hooks[i]).quoteDispatch( metadata, message ); + gasRemaining -= quote; IPostDispatchHook(_hooks[i]).postDispatch{value: quote}( metadata, message ); } + + if (gasRemaining > 0) { + address payable refundAddress = payable( + metadata.refundAddress(msg.sender) + ); + refundAddress.sendValue(gasRemaining); + } } /// @inheritdoc AbstractPostDispatchHook diff --git a/solidity/test/hooks/AggregationHook.t.sol b/solidity/test/hooks/AggregationHook.t.sol index a37b20f1a1..d344b3da04 100644 --- a/solidity/test/hooks/AggregationHook.t.sol +++ b/solidity/test/hooks/AggregationHook.t.sol @@ -72,6 +72,35 @@ contract AggregationHookTest is Test { hook.postDispatch{value: _msgValue}("", message); } + function test_postDispatch_refundsExcess(uint8 _hooks) public { + uint256 fee = PER_HOOK_GAS_AMOUNT; + address[] memory hooksDeployed = deployHooks(_hooks, fee); + uint256 requiredValue = hooksDeployed.length * fee; + uint256 overpaidValue = requiredValue + 1000; + + vm.prank(address(this)); + + uint256 initialBalance = address(this).balance; + + bytes memory message = abi.encodePacked("hello world"); + hook.postDispatch{value: overpaidValue}("", message); + + assertEq(address(hook).balance, 0); + assertEq(address(this).balance, initialBalance - requiredValue); + } + + function testPostDispatch_preventsUsingContractFunds(uint8 _hooks) public { + vm.assume(_hooks > 0); + + // aggregation hook has left over funds + uint256 additionalFunds = 1 ether; + vm.deal(address(hook), additionalFunds); + + bytes memory message = abi.encodePacked("hello world"); + vm.expectRevert(); // underflow + hook.postDispatch{value: 0}("", message); + } + function testQuoteDispatch(uint8 _hooks) public { uint256 fee = PER_HOOK_GAS_AMOUNT; address[] memory hooksDeployed = deployHooks(_hooks, fee); @@ -94,4 +123,6 @@ contract AggregationHookTest is Test { deployHooks(1, 0); assertEq(hook.hookType(), uint8(IPostDispatchHook.Types.AGGREGATION)); } + + receive() external payable {} }