Skip to content

Commit

Permalink
fix: Memory counter leak (#10358)
Browse files Browse the repository at this point in the history
  • Loading branch information
liurenjie1024 authored Jun 16, 2023
1 parent 1c1354c commit 2c2a2b7
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 101 deletions.
197 changes: 134 additions & 63 deletions src/batch/src/executor/hash_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,8 @@ impl<K: HashKey + Send + Sync> HashAggExecutor<K> {
self.mem_context.add(memory_usage_diff);
}

// generate output data chunks
let mut result = groups.into_iter();
// Don't use `into_iter` here, it may cause memory leak.
let mut result = groups.iter_mut();
let cardinality = self.chunk_size;
loop {
let mut group_builders: Vec<_> = self
Expand All @@ -259,9 +259,9 @@ impl<K: HashKey + Send + Sync> HashAggExecutor<K> {
array_len += 1;
key.deserialize_to_builders(&mut group_builders[..], &self.group_key_types)?;
states
.into_iter()
.iter_mut()
.zip_eq_fast(&mut agg_builders)
.try_for_each(|(mut aggregator, builder)| aggregator.output(builder))?;
.try_for_each(|(aggregator, builder)| aggregator.output(builder))?;
}
if !has_next {
break; // exit loop
Expand All @@ -281,6 +281,11 @@ impl<K: HashKey + Send + Sync> HashAggExecutor<K> {

#[cfg(test)]
mod tests {
use std::alloc::{AllocError, Allocator, Global, Layout};
use std::ptr::NonNull;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use prometheus::IntGauge;
use risingwave_common::catalog::{Field, Schema};
use risingwave_common::test_prelude::DataChunkTestExt;
Expand All @@ -296,9 +301,11 @@ mod tests {

#[tokio::test]
async fn execute_int32_grouped() {
let src_exec = Box::new(MockExecutor::with_chunk(
DataChunk::from_pretty(
"i i i
let parent_mem = MemoryContext::root(IntGauge::new("root_memory_usage", " ").unwrap());
{
let src_exec = Box::new(MockExecutor::with_chunk(
DataChunk::from_pretty(
"i i i
0 1 1
1 1 1
0 0 1
Expand All @@ -307,68 +314,75 @@ mod tests {
0 0 2
1 1 3
0 1 2",
),
Schema::new(vec![
Field::unnamed(DataType::Int32),
Field::unnamed(DataType::Int32),
Field::unnamed(DataType::Int64),
]),
));

let agg_call = AggCall {
r#type: Type::Sum as i32,
args: vec![InputRef {
index: 2,
r#type: Some(PbDataType {
type_name: TypeName::Int32 as i32,
),
Schema::new(vec![
Field::unnamed(DataType::Int32),
Field::unnamed(DataType::Int32),
Field::unnamed(DataType::Int64),
]),
));

let agg_call = AggCall {
r#type: Type::Sum as i32,
args: vec![InputRef {
index: 2,
r#type: Some(PbDataType {
type_name: TypeName::Int32 as i32,
..Default::default()
}),
}],
return_type: Some(PbDataType {
type_name: TypeName::Int64 as i32,
..Default::default()
}),
}],
return_type: Some(PbDataType {
type_name: TypeName::Int64 as i32,
..Default::default()
}),
distinct: false,
order_by: vec![],
filter: None,
direct_args: vec![],
};

let agg_prost = HashAggNode {
group_key: vec![0, 1],
agg_calls: vec![agg_call],
};

let mem_context = MemoryContext::root(IntGauge::new("memory_usage", " ").unwrap());
let actual_exec = HashAggExecutorBuilder::deserialize(
&agg_prost,
src_exec,
TaskId::default(),
"HashAggExecutor".to_string(),
CHUNK_SIZE,
mem_context.clone(),
)
.unwrap();

// TODO: currently the order is fixed unless the hasher is changed
let expect_exec = Box::new(MockExecutor::with_chunk(
DataChunk::from_pretty(
"i i I
distinct: false,
order_by: vec![],
filter: None,
direct_args: vec![],
};

let agg_prost = HashAggNode {
group_key: vec![0, 1],
agg_calls: vec![agg_call],
};

let mem_context = MemoryContext::new(
Some(parent_mem.clone()),
IntGauge::new("memory_usage", " ").unwrap(),
);
let actual_exec = HashAggExecutorBuilder::deserialize(
&agg_prost,
src_exec,
TaskId::default(),
"HashAggExecutor".to_string(),
CHUNK_SIZE,
mem_context.clone(),
)
.unwrap();

// TODO: currently the order is fixed unless the hasher is changed
let expect_exec = Box::new(MockExecutor::with_chunk(
DataChunk::from_pretty(
"i i I
1 0 1
0 0 3
0 1 3
1 1 6",
),
Schema::new(vec![
Field::unnamed(DataType::Int32),
Field::unnamed(DataType::Int32),
Field::unnamed(DataType::Int64),
]),
));
diff_executor_output(actual_exec, expect_exec).await;

// check estimated memory usage = 4 groups x state size
assert_eq!(mem_context.get_bytes_used() as usize, 4 * 72);
),
Schema::new(vec![
Field::unnamed(DataType::Int32),
Field::unnamed(DataType::Int32),
Field::unnamed(DataType::Int64),
]),
));
diff_executor_output(actual_exec, expect_exec).await;

// check estimated memory usage = 4 groups x state size
assert_eq!(mem_context.get_bytes_used() as usize, 4 * 72);
}

// Ensure that agg memory counter has been dropped.
assert_eq!(0, parent_mem.get_bytes_used());
}

#[tokio::test]
Expand Down Expand Up @@ -425,4 +439,61 @@ mod tests {
);
diff_executor_output(actual_exec, Box::new(expect_exec)).await;
}

/// A test to verify that `HashMap` may leak memory counter when using `into_iter`.
#[test]
fn test_hashmap_into_iter_bug() {
let dropped: Arc<AtomicBool> = Arc::new(AtomicBool::new(false));

{
struct MyAllocInner {
drop_flag: Arc<AtomicBool>,
}

#[derive(Clone)]
struct MyAlloc {
inner: Arc<MyAllocInner>,
}

impl Drop for MyAllocInner {
fn drop(&mut self) {
println!("MyAlloc freed.");
self.drop_flag.store(true, Ordering::SeqCst);
}
}

unsafe impl Allocator for MyAlloc {
fn allocate(
&self,
layout: Layout,
) -> std::result::Result<NonNull<[u8]>, AllocError> {
let g = Global;
g.allocate(layout)
}

unsafe fn deallocate(&self, ptr: NonNull<u8>, layout: Layout) {
let g = Global;
g.deallocate(ptr, layout)
}
}

let mut map = hashbrown::HashMap::with_capacity_in(
10,
MyAlloc {
inner: Arc::new(MyAllocInner {
drop_flag: dropped.clone(),
}),
},
);
for i in 0..10 {
map.entry(i).or_insert_with(|| "i".to_string());
}

for (k, v) in map {
println!("{}, {}", k, v);
}
}

assert!(!dropped.load(Ordering::SeqCst));
}
}
82 changes: 48 additions & 34 deletions src/batch/src/executor/join/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1845,6 +1845,7 @@ impl<K> HashJoinExecutor<K> {
mod tests {
use futures::StreamExt;
use futures_async_stream::for_await;
use prometheus::IntGauge;
use risingwave_common::array::{ArrayBuilderImpl, DataChunk};
use risingwave_common::catalog::{Field, Schema};
use risingwave_common::error::Result;
Expand Down Expand Up @@ -2051,6 +2052,7 @@ mod tests {
left_child: BoxedExecutor,
right_child: BoxedExecutor,
shutdown_rx: Option<Receiver<ShutdownMsg>>,
parent_mem_ctx: Option<MemoryContext>,
) -> BoxedExecutor {
let join_type = self.join_type;

Expand All @@ -2067,6 +2069,8 @@ mod tests {
None
};

let mem_ctx =
MemoryContext::new(parent_mem_ctx, IntGauge::new("memory_usage", " ").unwrap());
Box::new(HashJoinExecutor::<Key32>::new(
join_type,
output_indices,
Expand All @@ -2079,7 +2083,7 @@ mod tests {
"HashJoinExecutor".to_string(),
chunk_size,
shutdown_rx,
MemoryContext::none(),
mem_ctx,
))
}

Expand All @@ -2106,45 +2110,53 @@ mod tests {
left_executor: BoxedExecutor,
right_executor: BoxedExecutor,
) {
let join_executor = self.create_join_executor_with_chunk_size_and_executors(
has_non_equi_cond,
null_safe,
chunk_size,
left_executor,
right_executor,
None,
);

let mut data_chunk_merger = DataChunkMerger::new(self.output_data_types()).unwrap();
let parent_mem_context =
MemoryContext::root(IntGauge::new("total_memory_usage", " ").unwrap());

{
let join_executor = self.create_join_executor_with_chunk_size_and_executors(
has_non_equi_cond,
null_safe,
chunk_size,
left_executor,
right_executor,
None,
Some(parent_mem_context.clone()),
);

let mut data_chunk_merger = DataChunkMerger::new(self.output_data_types()).unwrap();

let fields = &join_executor.schema().fields;

if self.join_type.keep_all() {
assert_eq!(fields[1].data_type, DataType::Float32);
assert_eq!(fields[3].data_type, DataType::Float64);
} else if self.join_type.keep_left() {
assert_eq!(fields[1].data_type, DataType::Float32);
} else if self.join_type.keep_right() {
assert_eq!(fields[1].data_type, DataType::Float64)
} else {
unreachable!()
}

let fields = &join_executor.schema().fields;
let mut stream = join_executor.execute();

if self.join_type.keep_all() {
assert_eq!(fields[1].data_type, DataType::Float32);
assert_eq!(fields[3].data_type, DataType::Float64);
} else if self.join_type.keep_left() {
assert_eq!(fields[1].data_type, DataType::Float32);
} else if self.join_type.keep_right() {
assert_eq!(fields[1].data_type, DataType::Float64)
} else {
unreachable!()
}
while let Some(data_chunk) = stream.next().await {
let data_chunk = data_chunk.unwrap();
let data_chunk = data_chunk.compact();
data_chunk_merger.append(&data_chunk).unwrap();
}

let mut stream = join_executor.execute();
let result_chunk = data_chunk_merger.finish().unwrap();
println!("expected: {:?}", expected);
println!("result: {:?}", result_chunk);

while let Some(data_chunk) = stream.next().await {
let data_chunk = data_chunk.unwrap();
let data_chunk = data_chunk.compact();
data_chunk_merger.append(&data_chunk).unwrap();
// TODO: Replace this with unsorted comparison
// assert_eq!(expected, result_chunk);
assert!(is_data_chunk_eq(&expected, &result_chunk));
}

let result_chunk = data_chunk_merger.finish().unwrap();
println!("expected: {:?}", expected);
println!("result: {:?}", result_chunk);

// TODO: Replace this with unsorted comparison
// assert_eq!(expected, result_chunk);
assert!(is_data_chunk_eq(&expected, &result_chunk));
assert_eq!(0, parent_mem_context.get_bytes_used());
}

async fn do_test_shutdown(&self, has_non_equi_cond: bool) {
Expand All @@ -2159,6 +2171,7 @@ mod tests {
left_executor,
right_executor,
Some(shutdown_rx),
None,
);
shutdown_tx.send(ShutdownMsg::Cancel).unwrap();
#[for_await]
Expand All @@ -2178,6 +2191,7 @@ mod tests {
left_executor,
right_executor,
Some(shutdown_rx),
None,
);
shutdown_tx
.send(ShutdownMsg::Abort("Test".to_string()))
Expand Down
Loading

0 comments on commit 2c2a2b7

Please sign in to comment.