diff --git a/contracts/StakeVault.sol b/contracts/StakeVault.sol index dea4af3..fde47b7 100644 --- a/contracts/StakeVault.sol +++ b/contracts/StakeVault.sol @@ -4,6 +4,7 @@ pragma solidity ^0.8.18; import { Ownable } from "@openzeppelin/contracts/access/Ownable.sol"; import { ERC20 } from "@openzeppelin/contracts/token/ERC20/ERC20.sol"; +import { SafeERC20 } from "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol"; import { StakeManager } from "./StakeManager.sol"; /** @@ -11,8 +12,9 @@ import { StakeManager } from "./StakeManager.sol"; * @author Ricardo Guilherme Schmidt * @notice Secures user stake */ - contract StakeVault is Ownable { + using SafeERC20 for ERC20; + error StakeVault__MigrationNotAvailable(); StakeManager private stakeManager; @@ -28,7 +30,7 @@ contract StakeVault is Ownable { } function stake(uint256 _amount, uint256 _time) external onlyOwner { - STAKED_TOKEN.transferFrom(msg.sender, address(this), _amount); + STAKED_TOKEN.safeTransferFrom(msg.sender, address(this), _amount); stakeManager.stake(_amount, _time); emit Staked(msg.sender, address(this), _amount, _time); @@ -40,7 +42,7 @@ contract StakeVault is Ownable { function unstake(uint256 _amount) external onlyOwner { stakeManager.unstake(_amount); - STAKED_TOKEN.transferFrom(address(this), msg.sender, _amount); + STAKED_TOKEN.safeTransfer(msg.sender, _amount); } function leave() external onlyOwner { diff --git a/test/StakeManager.t.sol b/test/StakeManager.t.sol index d5745b9..97322b4 100644 --- a/test/StakeManager.t.sol +++ b/test/StakeManager.t.sol @@ -97,6 +97,24 @@ contract UnstakeTest is StakeManagerTest { vm.expectRevert(StakeManager.StakeManager__FundsLocked.selector); userVault.unstake(100); } + + function test_UnstakeShouldReturnFunds() public { + // ensure user has funds + deal(stakeToken, testUser, 1000); + StakeVault userVault = _createTestVault(testUser); + + vm.startPrank(testUser); + ERC20(stakeToken).approve(address(userVault), 100); + + userVault.stake(100, 0); + assertEq(ERC20(stakeToken).balanceOf(testUser), 900); + + userVault.unstake(100); + + assertEq(stakeManager.stakeSupply(), 0); + assertEq(ERC20(stakeToken).balanceOf(address(userVault)), 0); + assertEq(ERC20(stakeToken).balanceOf(testUser), 1000); + } } contract LockTest is StakeManagerTest {