diff --git a/src/lib.rs b/src/lib.rs index d0099fe..b88bba6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,7 +3,9 @@ mod utils; use events::Event; use near_sdk::json_types::{U128, U64}; -use near_sdk::{env, near, require, AccountId, EpochHeight, PanicOnDefault}; +use near_sdk::{ + env, ext_contract, near, require, AccountId, EpochHeight, Gas, PanicOnDefault, PromiseError, +}; use std::collections::HashMap; use utils::{validator_stake, validator_total_stake}; @@ -12,6 +14,8 @@ type Balance = u128; /// Timestamp in milliseconds type Timestamp = u64; +const GET_OWNER_ID_GAS: Gas = Gas::from_tgas(5); + /// Voting contract for any specific proposal. Once the majority of the stake holders agree to /// the proposal, the time will be recorded and the voting ends. #[near(contract_state)] @@ -25,6 +29,11 @@ pub struct Contract { last_epoch_height: EpochHeight, } +#[ext_contract(ext_staking_pool)] +pub trait StakingPoolContract { + fn get_owner_id(&self) -> AccountId; +} + // Implement the contract structure #[near] impl Contract { @@ -46,11 +55,23 @@ impl Contract { } } - /// Method for validators to vote or withdraw the vote. - /// Votes for if `is_vote` is true, or withdraws the vote if `is_vote` is false. - pub fn vote(&mut self, is_vote: bool) { + pub fn check_owner_id_and_vote( + &mut self, + owner_account_id: AccountId, + staking_pool_id: AccountId, + is_vote: bool, + #[callback_result] owner_account_id_result: Result, + ) { + require!( + owner_account_id == owner_account_id_result.unwrap(), + "Voting is only allowed for the staking pool owner" + ); + self.vote_internal(is_vote, staking_pool_id); + } + + fn vote_internal(&mut self, is_vote: bool, account_id: AccountId) { self.ping(); - let account_id = env::predecessor_account_id(); + let account_stake = if is_vote { let stake = validator_stake(&account_id); require!(stake > 0, format!("{} is not a validator", account_id)); @@ -85,6 +106,31 @@ impl Contract { } } + /// Method for validators to vote or withdraw the vote. + /// Votes for if `is_vote` is true, or withdraws the vote if `is_vote` is false. + pub fn vote(&mut self, is_vote: bool, staking_pool_id: Option) { + if let Some(pool_id) = staking_pool_id { + let strs = pool_id.as_str().split(".").collect::>(); + require!( + strs.len() == 3 && strs[1] == "pool" && strs[2] == "near", + "New staking_pool_id must be in the format .pool.near" + ); + ext_staking_pool::ext(pool_id.clone()) + .with_static_gas(GET_OWNER_ID_GAS) + .get_owner_id() + .then( + Self::ext(env::current_account_id()).check_owner_id_and_vote( + env::predecessor_account_id(), + pool_id, + is_vote, + ), + ); + } else { + let staking_pool_id = env::predecessor_account_id(); + self.vote_internal(is_vote, staking_pool_id); + } + } + /// Ping to update the votes according to current stake of validators. pub fn ping(&mut self) { require!( @@ -254,7 +300,7 @@ mod tests { ]); set_context_and_validators(&context, &validators); let mut contract = get_contract(); - contract.vote(true); + contract.vote(true, None); } #[test] @@ -269,10 +315,10 @@ mod tests { set_context_and_validators(&context, &validators); let mut contract = get_contract(); // vote - contract.vote(true); + contract.vote(true, None); assert!(contract.get_result().is_some()); // vote again. should panic because voting has ended - contract.vote(true); + contract.vote(true, None); } #[test] @@ -286,7 +332,7 @@ mod tests { let voter = validator(i); let mut context = get_context(&voter); set_context(&context); - contract.vote(true); + contract.vote(true, None); // check total voted stake context.is_view(true); @@ -319,7 +365,7 @@ mod tests { // vote by each validator let context = get_context_with_epoch_height(&validator(i), i); set_context(&context); - contract.vote(true); + contract.vote(true, None); // check votes assert_eq!(contract.get_votes().len() as u64, i + 1); // check voting result @@ -342,7 +388,7 @@ mod tests { let context = get_context_with_epoch_height(&validator(1), 1); set_context_and_validators(&context, &validators); let mut contract = get_contract(); - contract.vote(true); + contract.vote(true, None); // ping at epoch 2 validators.insert(validator(1).to_string(), NearToken::from_yoctonear(50)); let context = get_context_with_epoch_height(&validator(2), 2); @@ -361,12 +407,12 @@ mod tests { set_context_and_validators(&context, &validators); let mut contract = get_contract(); // vote at epoch 1 - contract.vote(true); + contract.vote(true, None); assert_eq!(contract.get_votes().len(), 1); // withdraw vote at epoch 2 let context = get_context_with_epoch_height(&validator(1), 2); set_context_and_validators(&context, &validators); - contract.vote(false); + contract.vote(false, None); assert!(contract.get_votes().is_empty()); } @@ -381,7 +427,7 @@ mod tests { set_context_and_validators(&context, &validators); let mut contract = get_contract(); // vote at epoch 1 - contract.vote(true); + contract.vote(true, None); assert_eq!((contract.get_total_voted_stake().0).0, 40); assert_eq!(contract.get_votes().len(), 1); // remove validator at epoch 2 @@ -428,7 +474,7 @@ mod tests { // vote after deadline set_context(context.block_timestamp(env::block_timestamp_ms() + 2000 * 1_000_000)); - contract.vote(true); + contract.vote(true, None); } #[test] @@ -439,7 +485,7 @@ mod tests { // vote at epoch 1 set_context(&context); - contract.vote(true); + contract.vote(true, None); // ping at epoch 2 after deadline set_context(