diff --git a/core/src/banking_stage/transaction_scheduler/thread_aware_account_locks.rs b/core/src/banking_stage/transaction_scheduler/thread_aware_account_locks.rs index 4a9cfd2df9edcf..5ceda39c03858b 100644 --- a/core/src/banking_stage/transaction_scheduler/thread_aware_account_locks.rs +++ b/core/src/banking_stage/transaction_scheduler/thread_aware_account_locks.rs @@ -28,6 +28,17 @@ struct AccountReadLocks { lock_counts: [LockCount; MAX_THREADS], } +/// Account locks. +/// Write Locks - only one thread can hold a write lock at a time. +/// Contains how many write locks are held by the thread. +/// Read Locks - multiple threads can hold a read lock at a time. +/// Contains thread-set for easily checking which threads are scheduled. +#[derive(Default)] +struct AccountLocks { + pub write_locks: Option, + pub read_locks: Option, +} + /// Thread-aware account locks which allows for scheduling on threads /// that already hold locks on the account. This is useful for allowing /// queued transactions to be scheduled on a thread while the transaction @@ -35,13 +46,9 @@ struct AccountReadLocks { pub(crate) struct ThreadAwareAccountLocks { /// Number of threads. num_threads: usize, // 0..MAX_THREADS - /// Write locks - only one thread can hold a write lock at a time. - /// Contains how many write locks are held by the thread. - write_locks: HashMap, - /// Read locks - multiple threads can hold a read lock at a time. - /// Contains thread-set for easily checking which threads are scheduled. - /// Contains how many read locks are held by each thread. - read_locks: HashMap, + /// Locks for each account. An account should only have an entry if there + /// is at least one lock. + locks: HashMap, } impl ThreadAwareAccountLocks { @@ -55,8 +62,7 @@ impl ThreadAwareAccountLocks { Self { num_threads, - write_locks: HashMap::new(), - read_locks: HashMap::new(), + locks: HashMap::new(), } } @@ -144,9 +150,12 @@ impl ThreadAwareAccountLocks { /// holds all read locks. Otherwise, no threads are write-schedulable. /// If only read-locked, all threads are read-schedulable. fn schedulable_threads(&self, account: &Pubkey) -> ThreadSet { - match (self.write_locks.get(account), self.read_locks.get(account)) { - (None, None) => ThreadSet::any(self.num_threads), - (None, Some(read_locks)) => { + match self.locks.get(account) { + None => ThreadSet::any(self.num_threads), + Some(AccountLocks { + write_locks: None, + read_locks: Some(read_locks), + }) => { if WRITE { read_locks .thread_set @@ -157,14 +166,24 @@ impl ThreadAwareAccountLocks { ThreadSet::any(self.num_threads) } } - (Some(write_locks), None) => ThreadSet::only(write_locks.thread_id), - (Some(write_locks), Some(read_locks)) => { + Some(AccountLocks { + write_locks: Some(write_locks), + read_locks: None, + }) => ThreadSet::only(write_locks.thread_id), + Some(AccountLocks { + write_locks: Some(write_locks), + read_locks: Some(read_locks), + }) => { assert_eq!( read_locks.thread_set.only_one_contained(), Some(write_locks.thread_id) ); read_locks.thread_set } + Some(AccountLocks { + write_locks: None, + read_locks: None, + }) => unreachable!(), } } @@ -191,57 +210,61 @@ impl ThreadAwareAccountLocks { /// Locks the given `account` for writing on `thread_id`. /// Panics if the account is already locked for writing on another thread. fn write_lock_account(&mut self, account: &Pubkey, thread_id: ThreadId) { - match self.write_locks.entry(*account) { - Entry::Occupied(mut entry) => { - let AccountWriteLocks { - thread_id: lock_thread_id, - lock_count, - } = entry.get_mut(); - assert_eq!( - *lock_thread_id, thread_id, - "outstanding write lock must be on same thread" - ); + let entry = self.locks.entry(*account).or_default(); - *lock_count += 1; - } - Entry::Vacant(entry) => { - entry.insert(AccountWriteLocks { - thread_id, - lock_count: 1, - }); - } - } + let AccountLocks { + write_locks, + read_locks, + } = entry; - // Check for outstanding read-locks - if let Some(read_locks) = self.read_locks.get(account) { + if let Some(read_locks) = read_locks { assert_eq!( - read_locks.thread_set, - ThreadSet::only(thread_id), + read_locks.thread_set.only_one_contained(), + Some(thread_id), "outstanding read lock must be on same thread" ); } + + if let Some(write_locks) = write_locks { + assert_eq!( + write_locks.thread_id, thread_id, + "outstanding write lock must be on same thread" + ); + write_locks.lock_count += 1; + } else { + *write_locks = Some(AccountWriteLocks { + thread_id, + lock_count: 1, + }); + } } /// Unlocks the given `account` for writing on `thread_id`. /// Panics if the account is not locked for writing on `thread_id`. fn write_unlock_account(&mut self, account: &Pubkey, thread_id: ThreadId) { - match self.write_locks.entry(*account) { - Entry::Occupied(mut entry) => { - let AccountWriteLocks { - thread_id: lock_thread_id, - lock_count, - } = entry.get_mut(); - assert_eq!( - *lock_thread_id, thread_id, - "outstanding write lock must be on same thread" - ); - *lock_count -= 1; - if *lock_count == 0 { - entry.remove(); - } - } - Entry::Vacant(_) => { - panic!("write lock must exist for account: {account}"); + let Entry::Occupied(mut entry) = self.locks.entry(*account) else { + panic!("write lock must exist for account: {account}"); + }; + + let AccountLocks { + write_locks: maybe_write_locks, + read_locks, + } = entry.get_mut(); + + let Some(write_locks) = maybe_write_locks else { + panic!("write lock must exist for account: {account}"); + }; + + assert_eq!( + write_locks.thread_id, thread_id, + "outstanding write lock must be on same thread" + ); + + write_locks.lock_count -= 1; + if write_locks.lock_count == 0 { + *maybe_write_locks = None; + if read_locks.is_none() { + entry.remove(); } } } @@ -249,58 +272,64 @@ impl ThreadAwareAccountLocks { /// Locks the given `account` for reading on `thread_id`. /// Panics if the account is already locked for writing on another thread. fn read_lock_account(&mut self, account: &Pubkey, thread_id: ThreadId) { - match self.read_locks.entry(*account) { - Entry::Occupied(mut entry) => { - let AccountReadLocks { - thread_set, - lock_counts, - } = entry.get_mut(); - thread_set.insert(thread_id); - lock_counts[thread_id] += 1; + let AccountLocks { + write_locks, + read_locks, + } = self.locks.entry(*account).or_default(); + + if let Some(write_locks) = write_locks { + assert_eq!( + write_locks.thread_id, thread_id, + "outstanding write lock must be on same thread" + ); + } + + match read_locks { + Some(read_locks) => { + read_locks.thread_set.insert(thread_id); + read_locks.lock_counts[thread_id] += 1; } - Entry::Vacant(entry) => { + None => { let mut lock_counts = [0; MAX_THREADS]; lock_counts[thread_id] = 1; - entry.insert(AccountReadLocks { + *read_locks = Some(AccountReadLocks { thread_set: ThreadSet::only(thread_id), lock_counts, }); } } - - // Check for outstanding write-locks - if let Some(write_locks) = self.write_locks.get(account) { - assert_eq!( - write_locks.thread_id, thread_id, - "outstanding write lock must be on same thread" - ); - } } /// Unlocks the given `account` for reading on `thread_id`. /// Panics if the account is not locked for reading on `thread_id`. fn read_unlock_account(&mut self, account: &Pubkey, thread_id: ThreadId) { - match self.read_locks.entry(*account) { - Entry::Occupied(mut entry) => { - let AccountReadLocks { - thread_set, - lock_counts, - } = entry.get_mut(); - assert!( - thread_set.contains(thread_id), - "outstanding read lock must be on same thread" - ); - lock_counts[thread_id] -= 1; - if lock_counts[thread_id] == 0 { - thread_set.remove(thread_id); - if thread_set.is_empty() { - entry.remove(); - } + let Entry::Occupied(mut entry) = self.locks.entry(*account) else { + panic!("read lock must exist for account: {account}"); + }; + + let AccountLocks { + write_locks, + read_locks: maybe_read_locks, + } = entry.get_mut(); + + let Some(read_locks) = maybe_read_locks else { + panic!("read lock must exist for account: {account}"); + }; + + assert!( + read_locks.thread_set.contains(thread_id), + "outstanding read lock must be on same thread" + ); + + read_locks.lock_counts[thread_id] -= 1; + if read_locks.lock_counts[thread_id] == 0 { + read_locks.thread_set.remove(thread_id); + if read_locks.thread_set.is_empty() { + *maybe_read_locks = None; + if write_locks.is_none() { + entry.remove(); } } - Entry::Vacant(_) => { - panic!("read lock must exist for account: {account}"); - } } } } @@ -641,7 +670,7 @@ mod tests { locks.write_lock_account(&pk1, 1); locks.write_unlock_account(&pk1, 1); locks.write_unlock_account(&pk1, 1); - assert!(locks.write_locks.is_empty()); + assert!(locks.locks.is_empty()); } #[test] @@ -652,7 +681,7 @@ mod tests { locks.read_lock_account(&pk1, 1); locks.read_unlock_account(&pk1, 1); locks.read_unlock_account(&pk1, 1); - assert!(locks.read_locks.is_empty()); + assert!(locks.locks.is_empty()); } #[test]