Skip to content

Commit

Permalink
stake-pool: Add tolerance for stake accounts at minimum (solana-labs#…
Browse files Browse the repository at this point in the history
…3839)

* stake-pool: Add tolerance for stake accounts at minimum

* Use test-case to check more cases

* Add more tolerance on withdrawal

* Potentially fix test per solana-labs#3854

* Keep throwing solutions until CI passes

* Fix repeated transaction issue

* Fix preferred withdrawal tolerance too

* Remove doubled tolerance
  • Loading branch information
joncinque authored and HaoranYi committed Jul 19, 2023
1 parent 4411a49 commit 60f2458
Show file tree
Hide file tree
Showing 6 changed files with 301 additions and 97 deletions.
39 changes: 13 additions & 26 deletions stake-pool/program/src/big_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,14 @@ impl<'data> BigVec<'data> {
}

/// Find matching data in the array
pub fn find<T: Pack>(&self, data: &[u8], predicate: fn(&[u8], &[u8]) -> bool) -> Option<&T> {
pub fn find<T: Pack, F: Fn(&[u8]) -> bool>(&self, predicate: F) -> Option<&T> {
let len = self.len() as usize;
let mut current = 0;
let mut current_index = VEC_SIZE_BYTES;
while current != len {
let end_index = current_index + T::LEN;
let current_slice = &self.data[current_index..end_index];
if predicate(current_slice, data) {
if predicate(current_slice) {
return Some(unsafe { &*(current_slice.as_ptr() as *const T) });
}
current_index = end_index;
Expand All @@ -165,18 +165,14 @@ impl<'data> BigVec<'data> {
}

/// Find matching data in the array
pub fn find_mut<T: Pack>(
&mut self,
data: &[u8],
predicate: fn(&[u8], &[u8]) -> bool,
) -> Option<&mut T> {
pub fn find_mut<T: Pack, F: Fn(&[u8]) -> bool>(&mut self, predicate: F) -> Option<&mut T> {
let len = self.len() as usize;
let mut current = 0;
let mut current_index = VEC_SIZE_BYTES;
while current != len {
let end_index = current_index + T::LEN;
let current_slice = &self.data[current_index..end_index];
if predicate(current_slice, data) {
if predicate(current_slice) {
return Some(unsafe { &mut *(current_slice.as_ptr() as *mut T) });
}
current_index = end_index;
Expand Down Expand Up @@ -242,10 +238,7 @@ impl<'data, 'vec, T: Pack + 'data> Iterator for IterMut<'data, 'vec, T> {

#[cfg(test)]
mod tests {
use {
super::*,
solana_program::{program_memory::sol_memcmp, program_pack::Sealed},
};
use {super::*, solana_program::program_pack::Sealed};

#[derive(Debug, PartialEq)]
struct TestStruct {
Expand Down Expand Up @@ -317,11 +310,11 @@ mod tests {
check_big_vec_eq(&v, &[2, 4]);
}

fn find_predicate(a: &[u8], b: &[u8]) -> bool {
if a.len() != b.len() {
fn find_predicate(a: &[u8], b: u64) -> bool {
if a.len() != 8 {
false
} else {
sol_memcmp(a, b, a.len()) == 0
u64::try_from_slice(&a[0..8]).unwrap() == b
}
}

Expand All @@ -330,32 +323,26 @@ mod tests {
let mut data = [0u8; 4 + 8 * 4];
let v = from_slice(&mut data, &[1, 2, 3, 4]);
assert_eq!(
v.find::<TestStruct>(&1u64.to_le_bytes(), find_predicate),
v.find::<TestStruct, _>(|x| find_predicate(x, 1)),
Some(&TestStruct::new(1))
);
assert_eq!(
v.find::<TestStruct>(&4u64.to_le_bytes(), find_predicate),
v.find::<TestStruct, _>(|x| find_predicate(x, 4)),
Some(&TestStruct::new(4))
);
assert_eq!(
v.find::<TestStruct>(&5u64.to_le_bytes(), find_predicate),
None
);
assert_eq!(v.find::<TestStruct, _>(|x| find_predicate(x, 5)), None);
}

#[test]
fn find_mut() {
let mut data = [0u8; 4 + 8 * 4];
let mut v = from_slice(&mut data, &[1, 2, 3, 4]);
let mut test_struct = v
.find_mut::<TestStruct>(&1u64.to_le_bytes(), find_predicate)
.find_mut::<TestStruct, _>(|x| find_predicate(x, 1))
.unwrap();
test_struct.value = 0;
check_big_vec_eq(&v, &[0, 2, 3, 4]);
assert_eq!(
v.find_mut::<TestStruct>(&5u64.to_le_bytes(), find_predicate),
None
);
assert_eq!(v.find_mut::<TestStruct, _>(|x| find_predicate(x, 5)), None);
}

#[test]
Expand Down
102 changes: 57 additions & 45 deletions stake-pool/program/src/processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -844,10 +844,9 @@ impl Processor {
if header.max_validators == validator_list.len() {
return Err(ProgramError::AccountDataTooSmall);
}
let maybe_validator_stake_info = validator_list.find::<ValidatorStakeInfo>(
validator_vote_info.key.as_ref(),
ValidatorStakeInfo::memcmp_pubkey,
);
let maybe_validator_stake_info = validator_list.find::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::memcmp_pubkey(x, validator_vote_info.key)
});
if maybe_validator_stake_info.is_some() {
return Err(StakePoolError::ValidatorAlreadyAdded.into());
}
Expand Down Expand Up @@ -994,10 +993,9 @@ impl Processor {

let (meta, stake) = get_stake_state(stake_account_info)?;
let vote_account_address = stake.delegation.voter_pubkey;
let maybe_validator_stake_info = validator_list.find_mut::<ValidatorStakeInfo>(
vote_account_address.as_ref(),
ValidatorStakeInfo::memcmp_pubkey,
);
let maybe_validator_stake_info = validator_list.find_mut::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::memcmp_pubkey(x, &vote_account_address)
});
if maybe_validator_stake_info.is_none() {
msg!(
"Vote account {} not found in stake pool",
Expand Down Expand Up @@ -1154,10 +1152,9 @@ impl Processor {
let (meta, stake) = get_stake_state(validator_stake_account_info)?;
let vote_account_address = stake.delegation.voter_pubkey;

let maybe_validator_stake_info = validator_list.find_mut::<ValidatorStakeInfo>(
vote_account_address.as_ref(),
ValidatorStakeInfo::memcmp_pubkey,
);
let maybe_validator_stake_info = validator_list.find_mut::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::memcmp_pubkey(x, &vote_account_address)
});
if maybe_validator_stake_info.is_none() {
msg!(
"Vote account {} not found in stake pool",
Expand Down Expand Up @@ -1316,10 +1313,9 @@ impl Processor {

let vote_account_address = validator_vote_account_info.key;

let maybe_validator_stake_info = validator_list.find_mut::<ValidatorStakeInfo>(
vote_account_address.as_ref(),
ValidatorStakeInfo::memcmp_pubkey,
);
let maybe_validator_stake_info = validator_list.find_mut::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::memcmp_pubkey(x, vote_account_address)
});
if maybe_validator_stake_info.is_none() {
msg!(
"Vote account {} not found in stake pool",
Expand Down Expand Up @@ -1481,10 +1477,9 @@ impl Processor {
}

if let Some(vote_account_address) = vote_account_address {
let maybe_validator_stake_info = validator_list.find::<ValidatorStakeInfo>(
vote_account_address.as_ref(),
ValidatorStakeInfo::memcmp_pubkey,
);
let maybe_validator_stake_info = validator_list.find::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::memcmp_pubkey(x, &vote_account_address)
});
match maybe_validator_stake_info {
Some(vsi) => {
if vsi.status != StakeStatus::Active {
Expand Down Expand Up @@ -2031,10 +2026,9 @@ impl Processor {
}

let mut validator_stake_info = validator_list
.find_mut::<ValidatorStakeInfo>(
vote_account_address.as_ref(),
ValidatorStakeInfo::memcmp_pubkey,
)
.find_mut::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::memcmp_pubkey(x, &vote_account_address)
})
.ok_or(StakePoolError::ValidatorNotFound)?;
check_validator_stake_address(
program_id,
Expand Down Expand Up @@ -2428,7 +2422,7 @@ impl Processor {
.checked_sub(pool_tokens_fee)
.ok_or(StakePoolError::CalculationFailure)?;

let withdraw_lamports = stake_pool
let mut withdraw_lamports = stake_pool
.calc_lamports_withdraw_amount(pool_tokens_burnt)
.ok_or(StakePoolError::CalculationFailure)?;

Expand All @@ -2442,17 +2436,27 @@ impl Processor {
let meta = stake_state.meta().ok_or(StakePoolError::WrongStakeState)?;
let required_lamports = minimum_stake_lamports(&meta, stake_minimum_delegation);

let lamports_per_pool_token = stake_pool
.get_lamports_per_pool_token()
.ok_or(StakePoolError::CalculationFailure)?;
let minimum_lamports_with_tolerance =
required_lamports.saturating_add(lamports_per_pool_token);

let has_active_stake = validator_list
.find::<ValidatorStakeInfo>(
&required_lamports.to_le_bytes(),
ValidatorStakeInfo::active_lamports_not_equal,
)
.find::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::active_lamports_greater_than(
x,
&minimum_lamports_with_tolerance,
)
})
.is_some();
let has_transient_stake = validator_list
.find::<ValidatorStakeInfo>(
&0u64.to_le_bytes(),
ValidatorStakeInfo::transient_lamports_not_equal,
)
.find::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::transient_lamports_greater_than(
x,
&minimum_lamports_with_tolerance,
)
})
.is_some();

let validator_list_item_info = if *stake_split_from.key == stake_pool.reserve_stake {
Expand All @@ -2478,25 +2482,23 @@ impl Processor {
stake_pool.preferred_withdraw_validator_vote_address
{
let preferred_validator_info = validator_list
.find::<ValidatorStakeInfo>(
preferred_withdraw_validator.as_ref(),
ValidatorStakeInfo::memcmp_pubkey,
)
.find::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::memcmp_pubkey(x, &preferred_withdraw_validator)
})
.ok_or(StakePoolError::ValidatorNotFound)?;
let available_lamports = preferred_validator_info
.active_stake_lamports
.saturating_sub(required_lamports);
.saturating_sub(minimum_lamports_with_tolerance);
if preferred_withdraw_validator != vote_account_address && available_lamports > 0 {
msg!("Validator vote address {} is preferred for withdrawals, it currently has {} lamports available. Please withdraw those before using other validator stake accounts.", preferred_withdraw_validator, preferred_validator_info.active_stake_lamports);
return Err(StakePoolError::IncorrectWithdrawVoteAddress.into());
}
}

let validator_stake_info = validator_list
.find_mut::<ValidatorStakeInfo>(
vote_account_address.as_ref(),
ValidatorStakeInfo::memcmp_pubkey,
)
.find_mut::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::memcmp_pubkey(x, &vote_account_address)
})
.ok_or(StakePoolError::ValidatorNotFound)?;

let withdraw_source = if has_active_stake {
Expand Down Expand Up @@ -2548,11 +2550,21 @@ impl Processor {
}
}
StakeWithdrawSource::ValidatorRemoval => {
if withdraw_lamports != stake_split_from.lamports() {
msg!("Cannot withdraw a whole account worth {} lamports, must withdraw exactly {} lamports worth of pool tokens",
withdraw_lamports, stake_split_from.lamports());
let split_from_lamports = stake_split_from.lamports();
let upper_bound = split_from_lamports.saturating_add(lamports_per_pool_token);
if withdraw_lamports < split_from_lamports || withdraw_lamports > upper_bound {
msg!(
"Cannot withdraw a whole account worth {} lamports, \
must withdraw at least {} lamports worth of pool tokens \
with a margin of {} lamports",
withdraw_lamports,
split_from_lamports,
lamports_per_pool_token
);
return Err(StakePoolError::StakeLamportsNotEqualToMinimum.into());
}
// truncate the lamports down to the amount in the account
withdraw_lamports = split_from_lamports;
}
}
Some((validator_stake_info, withdraw_source))
Expand Down
29 changes: 19 additions & 10 deletions stake-pool/program/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,15 @@ impl StakePool {
}
}

/// Get the current value of pool tokens, rounded up
#[inline]
pub fn get_lamports_per_pool_token(&self) -> Option<u64> {
self.total_lamports
.checked_add(self.pool_token_supply)?
.checked_sub(1)?
.checked_div(self.pool_token_supply)
}

/// Checks that the withdraw or deposit authority is valid
fn check_program_derived_authority(
authority_address: &Pubkey,
Expand Down Expand Up @@ -660,24 +669,24 @@ impl ValidatorStakeInfo {

/// Performs a very cheap comparison, for checking if this validator stake
/// info matches the vote account address
pub fn memcmp_pubkey(data: &[u8], vote_address_bytes: &[u8]) -> bool {
pub fn memcmp_pubkey(data: &[u8], vote_address: &Pubkey) -> bool {
sol_memcmp(
&data[41..41 + PUBKEY_BYTES],
vote_address_bytes,
vote_address.as_ref(),
PUBKEY_BYTES,
) == 0
}

/// Performs a very cheap comparison, for checking if this validator stake
/// info does not have active lamports equal to the given bytes
pub fn active_lamports_not_equal(data: &[u8], lamports_le_bytes: &[u8]) -> bool {
sol_memcmp(&data[0..8], lamports_le_bytes, 8) != 0
/// Performs a comparison, used to check if this validator stake
/// info has more active lamports than some limit
pub fn active_lamports_greater_than(data: &[u8], lamports: &u64) -> bool {
u64::try_from_slice(&data[0..8]).unwrap() > *lamports
}

/// Performs a very cheap comparison, for checking if this validator stake
/// info does not have lamports equal to the given bytes
pub fn transient_lamports_not_equal(data: &[u8], lamports_le_bytes: &[u8]) -> bool {
sol_memcmp(&data[8..16], lamports_le_bytes, 8) != 0
/// Performs a comparison, used to check if this validator stake
/// info has more transient lamports than some limit
pub fn transient_lamports_greater_than(data: &[u8], lamports: &u64) -> bool {
u64::try_from_slice(&data[8..16]).unwrap() > *lamports
}

/// Check that the validator stake info is valid
Expand Down
2 changes: 1 addition & 1 deletion stake-pool/program/tests/huge_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use {
},
};

const HUGE_POOL_SIZE: u32 = 2_000;
const HUGE_POOL_SIZE: u32 = 3_300;
const STAKE_AMOUNT: u64 = 200_000_000_000;

async fn setup(
Expand Down
21 changes: 15 additions & 6 deletions stake-pool/program/tests/update_validator_list_balance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ async fn setup(
// Warp forward so the stakes properly activate, and deposit
slot += slots_per_epoch;
context.warp_to_slot(slot).unwrap();
let last_blockhash = context
.banks_client
.get_new_latest_blockhash(&context.last_blockhash)
.await
.unwrap();

stake_pool_accounts
.update_all(
Expand All @@ -111,12 +116,6 @@ async fn setup(
)
.await;

let last_blockhash = context
.banks_client
.get_new_latest_blockhash(&context.last_blockhash)
.await
.unwrap();

for deposit_account in &mut deposit_accounts {
deposit_account
.deposit_stake(
Expand All @@ -130,6 +129,11 @@ async fn setup(

slot += slots_per_epoch;
context.warp_to_slot(slot).unwrap();
let last_blockhash = context
.banks_client
.get_new_latest_blockhash(&context.last_blockhash)
.await
.unwrap();

stake_pool_accounts
.update_all(
Expand Down Expand Up @@ -422,6 +426,11 @@ async fn merge_into_validator_stake() {

// Warp just a little bit to get a new blockhash and update again
context.warp_to_slot(slot + 10).unwrap();
let last_blockhash = context
.banks_client
.get_new_latest_blockhash(&last_blockhash)
.await
.unwrap();

// Update, should not change, no merges yet
let error = stake_pool_accounts
Expand Down
Loading

0 comments on commit 60f2458

Please sign in to comment.