diff --git a/bittensor/extrinsics/unstaking.py b/bittensor/extrinsics/unstaking.py index 6046124f40..f88f6a93eb 100644 --- a/bittensor/extrinsics/unstaking.py +++ b/bittensor/extrinsics/unstaking.py @@ -72,13 +72,13 @@ def __do_remove_stake_single( def check_threshold_amount( - subtensor: "bittensor.subtensor", unstaking_balance: Balance + subtensor: "bittensor.subtensor", stake_balance: Balance ) -> bool: """ - Checks if the unstaking amount is above the threshold or 0 + Checks if the remaining stake balance is above the minimum required stake threshold. Args: - unstaking_balance (Balance): + stake_balance (Balance): the balance to check for threshold limits. Returns: @@ -88,9 +88,9 @@ def check_threshold_amount( """ min_req_stake: Balance = subtensor.get_minimum_required_stake() - if min_req_stake > unstaking_balance > 0: + if min_req_stake > stake_balance > 0: bittensor.__console__.print( - f":cross_mark: [red]Unstaking balance of {unstaking_balance} less than minimum of {min_req_stake} TAO[/red]" + f":cross_mark: [yellow]Remaining stake balance of {stake_balance} less than minimum of {min_req_stake} TAO[/yellow]" ) return False else: @@ -161,9 +161,12 @@ def unstake_extrinsic( return False if not check_threshold_amount( - subtensor=subtensor, unstaking_balance=unstaking_balance + subtensor=subtensor, stake_balance=(stake_on_uid - unstaking_balance) ): - return False + bittensor.__console__.print( + f":warning: [yellow]This action will unstake the entire staked balance![/yellow]" + ) + unstaking_balance = stake_on_uid # Ask before moving on. if prompt: @@ -337,9 +340,12 @@ def unstake_multiple_extrinsic( continue if not check_threshold_amount( - subtensor=subtensor, unstaking_balance=unstaking_balance + subtensor=subtensor, stake_balance=(stake_on_uid - unstaking_balance) ): - return False + bittensor.__console__.print( + f":warning: [yellow]This action will unstake the entire staked balance![/yellow]" + ) + unstaking_balance = stake_on_uid # Ask before moving on. if prompt: diff --git a/tests/integration_tests/test_cli.py b/tests/integration_tests/test_cli.py index c20c905549..675d621879 100644 --- a/tests/integration_tests/test_cli.py +++ b/tests/integration_tests/test_cli.py @@ -782,16 +782,18 @@ def test_unstake_with_thresholds(self, _): config.no_prompt = True # as the minimum required stake may change, this method allows us to dynamically # update the amount in the mock without updating the tests - config.amount = Balance.from_rao(_subtensor_mock.min_required_stake() - 1) + min_stake = Balance.from_rao(_subtensor_mock.min_required_stake()) + # Must be a float + config.amount = min_stake.tao # Unstake below the minimum required stake config.wallet.name = "fake_wallet" config.hotkeys = ["hk0", "hk1", "hk2"] config.all_hotkeys = False # Notice no max_stake specified mock_stakes: Dict[str, Balance] = { - "hk0": Balance.from_float(10.0), - "hk1": Balance.from_float(11.1), - "hk2": Balance.from_float(12.2), + "hk0": 2 * min_stake - 1, # remaining stake will be below the threshold + "hk1": 2 * min_stake - 2, + "hk2": 2 * min_stake - 5, } mock_coldkey_kp = _get_mock_keypair(0, self.id()) @@ -827,27 +829,47 @@ def mock_get_wallet(*args, **kwargs): else: return mock_wallets[0] - with patch("bittensor.wallet") as mock_create_wallet: - mock_create_wallet.side_effect = mock_get_wallet + with patch("bittensor.__console__.print") as mock_print: # Catch console print + with patch("bittensor.wallet") as mock_create_wallet: + mock_create_wallet.side_effect = mock_get_wallet - # Check stakes before unstaking - for wallet in mock_wallets: - stake = _subtensor_mock.get_stake_for_coldkey_and_hotkey( - hotkey_ss58=wallet.hotkey.ss58_address, - coldkey_ss58=wallet.coldkey.ss58_address, - ) - self.assertEqual(stake.rao, mock_stakes[wallet.hotkey_str].rao) + # Check stakes before unstaking + for wallet in mock_wallets: + stake = _subtensor_mock.get_stake_for_coldkey_and_hotkey( + hotkey_ss58=wallet.hotkey.ss58_address, + coldkey_ss58=wallet.coldkey.ss58_address, + ) + self.assertEqual(stake.rao, mock_stakes[wallet.hotkey_str].rao) - cli.run() + with patch.object(_subtensor_mock, "_do_unstake") as mock_unstake: + cli.run() - # Check stakes after unstaking - for wallet in mock_wallets: - stake = _subtensor_mock.get_stake_for_coldkey_and_hotkey( - hotkey_ss58=wallet.hotkey.ss58_address, - coldkey_ss58=wallet.coldkey.ss58_address, - ) - # because the amount is less than the threshold, none of these should unstake - self.assertEqual(stake.tao, mock_stakes[wallet.hotkey_str].tao) + # Filter for console print calls + console_prints = [call[0][0] for call in mock_print.call_args_list] + minimum_print = filter( + lambda x: "less than minimum of" + in x[1].lower(), # Check for warning + enumerate(console_prints), + ) + # Check for each hotkey + unstake_calls = mock_unstake.call_args_list + for wallet, unstake_call in zip(mock_wallets, unstake_calls): + _, kwargs = unstake_call + # Verify hotkey was unstaked + self.assertEqual( + kwargs["hotkey_ss58"], wallet.hotkey.ss58_address + ) + # Should unstake *all* the stake + staked = mock_stakes[wallet.hotkey_str] + self.assertEqual(kwargs["amount"], staked) + + # Check warning was printed + console_print = next( + minimum_print + ) # advance so there is one per hotkey + self.assertIsNotNone( + console_print + ) # Check that the warning was printed def test_unstake_all(self, _): config = self.config