Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stake-pool: Add tolerance for stake accounts at minimum #3839

Merged
merged 8 commits into from
Dec 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 13 additions & 26 deletions stake-pool/program/src/big_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,14 @@ impl<'data> BigVec<'data> {
}

/// Find matching data in the array
pub fn find<T: Pack>(&self, data: &[u8], predicate: fn(&[u8], &[u8]) -> bool) -> Option<&T> {
pub fn find<T: Pack, F: Fn(&[u8]) -> bool>(&self, predicate: F) -> Option<&T> {
let len = self.len() as usize;
let mut current = 0;
let mut current_index = VEC_SIZE_BYTES;
while current != len {
let end_index = current_index + T::LEN;
let current_slice = &self.data[current_index..end_index];
if predicate(current_slice, data) {
if predicate(current_slice) {
return Some(unsafe { &*(current_slice.as_ptr() as *const T) });
}
current_index = end_index;
Expand All @@ -165,18 +165,14 @@ impl<'data> BigVec<'data> {
}

/// Find matching data in the array
pub fn find_mut<T: Pack>(
&mut self,
data: &[u8],
predicate: fn(&[u8], &[u8]) -> bool,
) -> Option<&mut T> {
pub fn find_mut<T: Pack, F: Fn(&[u8]) -> bool>(&mut self, predicate: F) -> Option<&mut T> {
let len = self.len() as usize;
let mut current = 0;
let mut current_index = VEC_SIZE_BYTES;
while current != len {
let end_index = current_index + T::LEN;
let current_slice = &self.data[current_index..end_index];
if predicate(current_slice, data) {
if predicate(current_slice) {
return Some(unsafe { &mut *(current_slice.as_ptr() as *mut T) });
}
current_index = end_index;
Expand Down Expand Up @@ -242,10 +238,7 @@ impl<'data, 'vec, T: Pack + 'data> Iterator for IterMut<'data, 'vec, T> {

#[cfg(test)]
mod tests {
use {
super::*,
solana_program::{program_memory::sol_memcmp, program_pack::Sealed},
};
use {super::*, solana_program::program_pack::Sealed};

#[derive(Debug, PartialEq)]
struct TestStruct {
Expand Down Expand Up @@ -317,11 +310,11 @@ mod tests {
check_big_vec_eq(&v, &[2, 4]);
}

fn find_predicate(a: &[u8], b: &[u8]) -> bool {
if a.len() != b.len() {
fn find_predicate(a: &[u8], b: u64) -> bool {
if a.len() != 8 {
false
} else {
sol_memcmp(a, b, a.len()) == 0
u64::try_from_slice(&a[0..8]).unwrap() == b
}
}

Expand All @@ -330,32 +323,26 @@ mod tests {
let mut data = [0u8; 4 + 8 * 4];
let v = from_slice(&mut data, &[1, 2, 3, 4]);
assert_eq!(
v.find::<TestStruct>(&1u64.to_le_bytes(), find_predicate),
v.find::<TestStruct, _>(|x| find_predicate(x, 1)),
Some(&TestStruct::new(1))
);
assert_eq!(
v.find::<TestStruct>(&4u64.to_le_bytes(), find_predicate),
v.find::<TestStruct, _>(|x| find_predicate(x, 4)),
Some(&TestStruct::new(4))
);
assert_eq!(
v.find::<TestStruct>(&5u64.to_le_bytes(), find_predicate),
None
);
assert_eq!(v.find::<TestStruct, _>(|x| find_predicate(x, 5)), None);
}

#[test]
fn find_mut() {
let mut data = [0u8; 4 + 8 * 4];
let mut v = from_slice(&mut data, &[1, 2, 3, 4]);
let mut test_struct = v
.find_mut::<TestStruct>(&1u64.to_le_bytes(), find_predicate)
.find_mut::<TestStruct, _>(|x| find_predicate(x, 1))
.unwrap();
test_struct.value = 0;
check_big_vec_eq(&v, &[0, 2, 3, 4]);
assert_eq!(
v.find_mut::<TestStruct>(&5u64.to_le_bytes(), find_predicate),
None
);
assert_eq!(v.find_mut::<TestStruct, _>(|x| find_predicate(x, 5)), None);
}

#[test]
Expand Down
102 changes: 57 additions & 45 deletions stake-pool/program/src/processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -844,10 +844,9 @@ impl Processor {
if header.max_validators == validator_list.len() {
return Err(ProgramError::AccountDataTooSmall);
}
let maybe_validator_stake_info = validator_list.find::<ValidatorStakeInfo>(
validator_vote_info.key.as_ref(),
ValidatorStakeInfo::memcmp_pubkey,
);
let maybe_validator_stake_info = validator_list.find::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::memcmp_pubkey(x, validator_vote_info.key)
});
if maybe_validator_stake_info.is_some() {
return Err(StakePoolError::ValidatorAlreadyAdded.into());
}
Expand Down Expand Up @@ -994,10 +993,9 @@ impl Processor {

let (meta, stake) = get_stake_state(stake_account_info)?;
let vote_account_address = stake.delegation.voter_pubkey;
let maybe_validator_stake_info = validator_list.find_mut::<ValidatorStakeInfo>(
vote_account_address.as_ref(),
ValidatorStakeInfo::memcmp_pubkey,
);
let maybe_validator_stake_info = validator_list.find_mut::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::memcmp_pubkey(x, &vote_account_address)
});
if maybe_validator_stake_info.is_none() {
msg!(
"Vote account {} not found in stake pool",
Expand Down Expand Up @@ -1154,10 +1152,9 @@ impl Processor {
let (meta, stake) = get_stake_state(validator_stake_account_info)?;
let vote_account_address = stake.delegation.voter_pubkey;

let maybe_validator_stake_info = validator_list.find_mut::<ValidatorStakeInfo>(
vote_account_address.as_ref(),
ValidatorStakeInfo::memcmp_pubkey,
);
let maybe_validator_stake_info = validator_list.find_mut::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::memcmp_pubkey(x, &vote_account_address)
});
if maybe_validator_stake_info.is_none() {
msg!(
"Vote account {} not found in stake pool",
Expand Down Expand Up @@ -1316,10 +1313,9 @@ impl Processor {

let vote_account_address = validator_vote_account_info.key;

let maybe_validator_stake_info = validator_list.find_mut::<ValidatorStakeInfo>(
vote_account_address.as_ref(),
ValidatorStakeInfo::memcmp_pubkey,
);
let maybe_validator_stake_info = validator_list.find_mut::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::memcmp_pubkey(x, vote_account_address)
});
if maybe_validator_stake_info.is_none() {
msg!(
"Vote account {} not found in stake pool",
Expand Down Expand Up @@ -1481,10 +1477,9 @@ impl Processor {
}

if let Some(vote_account_address) = vote_account_address {
let maybe_validator_stake_info = validator_list.find::<ValidatorStakeInfo>(
vote_account_address.as_ref(),
ValidatorStakeInfo::memcmp_pubkey,
);
let maybe_validator_stake_info = validator_list.find::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::memcmp_pubkey(x, &vote_account_address)
});
match maybe_validator_stake_info {
Some(vsi) => {
if vsi.status != StakeStatus::Active {
Expand Down Expand Up @@ -2031,10 +2026,9 @@ impl Processor {
}

let mut validator_stake_info = validator_list
.find_mut::<ValidatorStakeInfo>(
vote_account_address.as_ref(),
ValidatorStakeInfo::memcmp_pubkey,
)
.find_mut::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::memcmp_pubkey(x, &vote_account_address)
})
.ok_or(StakePoolError::ValidatorNotFound)?;
check_validator_stake_address(
program_id,
Expand Down Expand Up @@ -2428,7 +2422,7 @@ impl Processor {
.checked_sub(pool_tokens_fee)
.ok_or(StakePoolError::CalculationFailure)?;

let withdraw_lamports = stake_pool
let mut withdraw_lamports = stake_pool
.calc_lamports_withdraw_amount(pool_tokens_burnt)
.ok_or(StakePoolError::CalculationFailure)?;

Expand All @@ -2442,17 +2436,27 @@ impl Processor {
let meta = stake_state.meta().ok_or(StakePoolError::WrongStakeState)?;
let required_lamports = minimum_stake_lamports(&meta, stake_minimum_delegation);

let lamports_per_pool_token = stake_pool
.get_lamports_per_pool_token()
.ok_or(StakePoolError::CalculationFailure)?;
let minimum_lamports_with_tolerance =
required_lamports.saturating_add(lamports_per_pool_token);

let has_active_stake = validator_list
.find::<ValidatorStakeInfo>(
&required_lamports.to_le_bytes(),
ValidatorStakeInfo::active_lamports_not_equal,
)
.find::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::active_lamports_greater_than(
x,
&minimum_lamports_with_tolerance,
)
})
.is_some();
let has_transient_stake = validator_list
.find::<ValidatorStakeInfo>(
&0u64.to_le_bytes(),
ValidatorStakeInfo::transient_lamports_not_equal,
)
.find::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::transient_lamports_greater_than(
x,
&minimum_lamports_with_tolerance,
)
})
.is_some();

let validator_list_item_info = if *stake_split_from.key == stake_pool.reserve_stake {
Expand All @@ -2478,25 +2482,23 @@ impl Processor {
stake_pool.preferred_withdraw_validator_vote_address
{
let preferred_validator_info = validator_list
.find::<ValidatorStakeInfo>(
preferred_withdraw_validator.as_ref(),
ValidatorStakeInfo::memcmp_pubkey,
)
.find::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::memcmp_pubkey(x, &preferred_withdraw_validator)
})
.ok_or(StakePoolError::ValidatorNotFound)?;
let available_lamports = preferred_validator_info
.active_stake_lamports
.saturating_sub(required_lamports);
.saturating_sub(minimum_lamports_with_tolerance);
if preferred_withdraw_validator != vote_account_address && available_lamports > 0 {
msg!("Validator vote address {} is preferred for withdrawals, it currently has {} lamports available. Please withdraw those before using other validator stake accounts.", preferred_withdraw_validator, preferred_validator_info.active_stake_lamports);
return Err(StakePoolError::IncorrectWithdrawVoteAddress.into());
}
}

let validator_stake_info = validator_list
.find_mut::<ValidatorStakeInfo>(
vote_account_address.as_ref(),
ValidatorStakeInfo::memcmp_pubkey,
)
.find_mut::<ValidatorStakeInfo, _>(|x| {
ValidatorStakeInfo::memcmp_pubkey(x, &vote_account_address)
})
.ok_or(StakePoolError::ValidatorNotFound)?;

let withdraw_source = if has_active_stake {
Expand Down Expand Up @@ -2548,11 +2550,21 @@ impl Processor {
}
}
StakeWithdrawSource::ValidatorRemoval => {
if withdraw_lamports != stake_split_from.lamports() {
msg!("Cannot withdraw a whole account worth {} lamports, must withdraw exactly {} lamports worth of pool tokens",
withdraw_lamports, stake_split_from.lamports());
let split_from_lamports = stake_split_from.lamports();
let upper_bound = split_from_lamports.saturating_add(lamports_per_pool_token);
if withdraw_lamports < split_from_lamports || withdraw_lamports > upper_bound {
msg!(
"Cannot withdraw a whole account worth {} lamports, \
must withdraw at least {} lamports worth of pool tokens \
with a margin of {} lamports",
withdraw_lamports,
split_from_lamports,
lamports_per_pool_token
);
return Err(StakePoolError::StakeLamportsNotEqualToMinimum.into());
}
// truncate the lamports down to the amount in the account
withdraw_lamports = split_from_lamports;
}
}
Some((validator_stake_info, withdraw_source))
Expand Down
29 changes: 19 additions & 10 deletions stake-pool/program/src/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,15 @@ impl StakePool {
}
}

/// Get the current value of pool tokens, rounded up
#[inline]
pub fn get_lamports_per_pool_token(&self) -> Option<u64> {
self.total_lamports
.checked_add(self.pool_token_supply)?
.checked_sub(1)?
.checked_div(self.pool_token_supply)
}

/// Checks that the withdraw or deposit authority is valid
fn check_program_derived_authority(
authority_address: &Pubkey,
Expand Down Expand Up @@ -660,24 +669,24 @@ impl ValidatorStakeInfo {

/// Performs a very cheap comparison, for checking if this validator stake
/// info matches the vote account address
pub fn memcmp_pubkey(data: &[u8], vote_address_bytes: &[u8]) -> bool {
pub fn memcmp_pubkey(data: &[u8], vote_address: &Pubkey) -> bool {
sol_memcmp(
&data[41..41 + PUBKEY_BYTES],
vote_address_bytes,
vote_address.as_ref(),
PUBKEY_BYTES,
) == 0
}

/// Performs a very cheap comparison, for checking if this validator stake
/// info does not have active lamports equal to the given bytes
pub fn active_lamports_not_equal(data: &[u8], lamports_le_bytes: &[u8]) -> bool {
sol_memcmp(&data[0..8], lamports_le_bytes, 8) != 0
/// Performs a comparison, used to check if this validator stake
/// info has more active lamports than some limit
pub fn active_lamports_greater_than(data: &[u8], lamports: &u64) -> bool {
u64::try_from_slice(&data[0..8]).unwrap() > *lamports
}

/// Performs a very cheap comparison, for checking if this validator stake
/// info does not have lamports equal to the given bytes
pub fn transient_lamports_not_equal(data: &[u8], lamports_le_bytes: &[u8]) -> bool {
sol_memcmp(&data[8..16], lamports_le_bytes, 8) != 0
/// Performs a comparison, used to check if this validator stake
/// info has more transient lamports than some limit
pub fn transient_lamports_greater_than(data: &[u8], lamports: &u64) -> bool {
u64::try_from_slice(&data[8..16]).unwrap() > *lamports
}

/// Check that the validator stake info is valid
Expand Down
2 changes: 1 addition & 1 deletion stake-pool/program/tests/huge_pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use {
},
};

const HUGE_POOL_SIZE: u32 = 2_000;
const HUGE_POOL_SIZE: u32 = 3_300;
const STAKE_AMOUNT: u64 = 200_000_000_000;

async fn setup(
Expand Down
21 changes: 15 additions & 6 deletions stake-pool/program/tests/update_validator_list_balance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ async fn setup(
// Warp forward so the stakes properly activate, and deposit
slot += slots_per_epoch;
context.warp_to_slot(slot).unwrap();
let last_blockhash = context
.banks_client
.get_new_latest_blockhash(&context.last_blockhash)
.await
.unwrap();

stake_pool_accounts
.update_all(
Expand All @@ -111,12 +116,6 @@ async fn setup(
)
.await;

let last_blockhash = context
.banks_client
.get_new_latest_blockhash(&context.last_blockhash)
.await
.unwrap();

for deposit_account in &mut deposit_accounts {
deposit_account
.deposit_stake(
Expand All @@ -130,6 +129,11 @@ async fn setup(

slot += slots_per_epoch;
context.warp_to_slot(slot).unwrap();
let last_blockhash = context
.banks_client
.get_new_latest_blockhash(&context.last_blockhash)
.await
.unwrap();

stake_pool_accounts
.update_all(
Expand Down Expand Up @@ -418,6 +422,11 @@ async fn merge_into_validator_stake() {

// Warp just a little bit to get a new blockhash and update again
context.warp_to_slot(slot + 10).unwrap();
let last_blockhash = context
.banks_client
.get_new_latest_blockhash(&last_blockhash)
.await
.unwrap();

// Update, should not change, no merges yet
let error = stake_pool_accounts
Expand Down
Loading