# Fix data leakage in KV cache initialization (#1057)
### Description

This PR fixes a data leakage issue that can occur between generators
when buffer sharing is enabled.

### Motivation and Context

Suppose you create generator A with one input id and run one iteration
of the generation loop with it. Now destroy generator A and create
generator B with three input ids. Before generator B runs any
iterations, its input KV caches already contain the values left in
generator A's output KV caches after A's single iteration.

The data leakage is eliminated by always zeroing the KV cache memory
when the caches are initialized.
kunal-vaishnavi authored Nov 12, 2024
1 parent f66e4f5 commit cc4577e
Showing 1 changed file with 9 additions and 0 deletions.
src/models/kv_cache.cpp (9 additions, 0 deletions)
```diff
@@ -166,10 +166,19 @@ KV_Cache::KV_Cache(State& state)
     }
   }
 
+  auto kv_cache_size_bytes = SizeOf(type_) * shape_[0] * shape_[1] * shape_[2] * shape_[3];
   for (int i = 0; i < layer_count_ * 2; ++i) {
     presents_.push_back(
         sb_kv_caches_.empty() ? OrtValue::CreateTensor(*model_.allocator_kvcache_, shape_, type_)
                               : sb_kv_caches_[i]->CreateTensorOnStaticBuffer(shape_, type_));
+#if USE_CUDA
+    if (model_.device_type_ == DeviceType::CUDA) {
+      cudaMemsetAsync(presents_.back()->GetTensorMutableRawData(), 0, kv_cache_size_bytes, model_.cuda_stream_);
+    } else
+#endif
+    {
+      memset(presents_.back()->GetTensorMutableRawData(), 0, kv_cache_size_bytes);
+    }
   }
 }
```

