Skip to content

Commit

Permalink
fix(StakeVault): make unstaking actually work
Browse files Browse the repository at this point in the history
Unstaking didn't actually work because it was using `transferFrom()` on the
`StakeVault` with the `from` address being the vault itself.
This would result in an approval error because the vault isn't creating
any approvals to spend its own funds.

The solution is to use `transfer` instead and ensuring the return value
is checked.
  • Loading branch information
0x-r4bbit committed Jan 19, 2024
1 parent edc44e0 commit 74ff357
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 2 deletions.
14 changes: 12 additions & 2 deletions contracts/StakeVault.sol
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ import { StakeManager } from "./StakeManager.sol";
contract StakeVault is Ownable {
error StakeVault__MigrationNotAvailable();

error StakeVault__StakingFailed();

error StakeVault__UnstakingFailed();

StakeManager private stakeManager;

ERC20 private immutable STAKED_TOKEN;
Expand All @@ -27,7 +31,10 @@ contract StakeVault is Ownable {
}

function stake(uint256 _amount, uint256 _time) external onlyOwner {
STAKED_TOKEN.transferFrom(msg.sender, address(this), _amount);
bool success = STAKED_TOKEN.transferFrom(msg.sender, address(this), _amount);
if (!success) {
revert StakeVault__StakingFailed();
}
stakeManager.stake(_amount, _time);

emit Staked(msg.sender, address(this), _amount, _time);
Expand All @@ -39,7 +46,10 @@ contract StakeVault is Ownable {

function unstake(uint256 _amount) external onlyOwner {
stakeManager.unstake(_amount);
STAKED_TOKEN.transferFrom(address(this), msg.sender, _amount);
bool success = STAKED_TOKEN.transfer(msg.sender, _amount);
if (!success) {
revert StakeVault__UnstakingFailed();
}
}

function leave() external onlyOwner {
Expand Down
18 changes: 18 additions & 0 deletions test/StakeManager.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,24 @@ contract UnstakeTest is StakeManagerTest {
vm.expectRevert(StakeManager.StakeManager__FundsLocked.selector);
userVault.unstake(100);
}

function test_UnstakeShouldReturnFunds() public {
// ensure user has funds
deal(stakeToken, testUser, 1000);
StakeVault userVault = _createTestVault(testUser);

vm.startPrank(testUser);
ERC20(stakeToken).approve(address(userVault), 100);

userVault.stake(100, 0);
assertEq(ERC20(stakeToken).balanceOf(testUser), 900);

userVault.unstake(100);

assertEq(stakeManager.stakeSupply(), 0);
assertEq(ERC20(stakeToken).balanceOf(address(userVault)), 0);
assertEq(ERC20(stakeToken).balanceOf(testUser), 1000);
}
}

contract LockTest is StakeManagerTest {
Expand Down
23 changes: 23 additions & 0 deletions test/StakeVault.t.sol
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.19;

import { ERC20 } from "@openzeppelin/contracts/token/ERC20/ERC20.sol";

import { Test } from "forge-std/Test.sol";
import { Deploy } from "../script/Deploy.s.sol";
import { DeployBroken } from "./script/DeployBroken.s.sol";
import { DeploymentConfig } from "../script/DeploymentConfig.s.sol";
import { StakeManager } from "../contracts/StakeManager.sol";
import { StakeVault } from "../contracts/StakeVault.sol";
Expand Down Expand Up @@ -42,3 +45,23 @@ contract StakedTokenTest is StakeVaultTest {
assertEq(address(stakeVault.stakedToken()), stakeToken);
}
}

contract StakeTest is StakeVaultTest {
function setUp() public override {
DeployBroken deployment = new DeployBroken();
(vaultFactory, stakeManager, stakeToken) = deployment.run();

vm.prank(testUser);
stakeVault = vaultFactory.createVault();
}

function test_RevertWhen_StakeTokenTransferFails() public {
// ensure user has funds
deal(stakeToken, testUser, 1000);

vm.startPrank(address(testUser));
ERC20(stakeToken).approve(address(stakeVault), 100);
vm.expectRevert(StakeVault.StakeVault__StakingFailed.selector);
stakeVault.stake(100, 0);
}
}
20 changes: 20 additions & 0 deletions test/mocks/BrokenERC20.s.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.19;

import { ERC20 } from "@openzeppelin/contracts/token/ERC20/ERC20.sol";

contract BrokenERC20 is ERC20 {
constructor() ERC20("Mock SNT", "SNT") {
_mint(msg.sender, 1_000_000_000_000_000_000);
}

// solhint-disable-next-line no-unused-vars
function transferFrom(address sender, address recipient, uint256 amount) public override returns (bool) {
return false;
}

// solhint-disable-next-line no-unused-vars
function transfer(address recipient, uint256 amount) public override returns (bool) {
return false;
}
}
20 changes: 20 additions & 0 deletions test/script/DeployBroken.s.sol
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// SPDX-License-Identifier: UNLICENSED
pragma solidity ^0.8.19;

import { BaseScript } from "../../script/Base.s.sol";
import { StakeManager } from "../../contracts/StakeManager.sol";
import { VaultFactory } from "../../contracts/VaultFactory.sol";
import { BrokenERC20 } from "../mocks/BrokenERC20.s.sol";

contract DeployBroken is BaseScript {
function run() public returns (VaultFactory, StakeManager, address) {
BrokenERC20 token = new BrokenERC20();

vm.startBroadcast(broadcaster);
StakeManager stakeManager = new StakeManager(address(token), address(0));
VaultFactory vaultFactory = new VaultFactory(address(stakeManager));
vm.stopBroadcast();

return (vaultFactory, stakeManager, address(token));
}
}

0 comments on commit 74ff357

Please sign in to comment.