
Commit 0758305
remove end_ids from sampling
irexyc committed Sep 6, 2024
1 parent 622c52a commit 0758305
Showing 4 changed files with 5 additions and 13 deletions.
lmdeploy/turbomind/turbomind.py (2 changes: 2 additions & 0 deletions)
@@ -565,6 +565,8 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
             bad_words.append(self.eos_id)
         else:
             stop_words = gen_config.stop_token_ids
+            if self.eos_id not in stop_words:
+                stop_words.append(self.eos_id)
         stop_words = _construct_stop_or_bad_words(stop_words)
         bad_words = _construct_stop_or_bad_words(bad_words)

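With end_ids no longer checked inside the sampling kernel, the Python side now guarantees that eos_id is always part of the stop-word list, so generation still terminates at EOS through stop-word matching. Below is a minimal sketch of that guarantee; build_stop_words is an illustrative helper, not code from the repository:

    # Illustrative sketch, not repository code: mirror the guarantee added in
    # turbomind.py above: eos_id must always appear among the stop words.
    def build_stop_words(stop_token_ids, eos_id):
        stop_words = list(stop_token_ids or [])
        if eos_id not in stop_words:
            stop_words.append(eos_id)
        return stop_words

    # eos_id (2 here) is appended even when only custom stop tokens are supplied,
    # and is not duplicated when the caller already listed it.
    assert build_stop_words([128009], eos_id=2) == [128009, 2]
    assert build_stop_words([2, 128009], eos_id=2) == [2, 128009]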
src/turbomind/kernels/sampling_kernels.cu (10 changes: 2 additions & 8 deletions)
@@ -17,8 +17,6 @@ __global__ void sampling(T* logits,
                          int* indices,
                          int* kept,
                          curandState_t* curandstate,
-                         bool* finished,
-                         const int* end_ids,
                          int* output_ids,
                          int* sequence_length,
                          float* sampled_logprobs,
@@ -55,10 +53,8 @@ __global__ void sampling(T* logits,
             selected = min(i, n - 1);
             output_ids[batch_id] = indices[selected];

-            if (sequence_length != nullptr && finished != nullptr) {
-                sequence_length[batch_id] =
-                    finished[batch_id] ? sequence_length[batch_id] : sequence_length[batch_id] + 1;
-                finished[batch_id] = output_ids[batch_id] == end_ids[batch_id] ? 1 : 0;
+            if (sequence_length != nullptr) {
+                sequence_length[batch_id] += 1;
             }
         }
         break;
@@ -94,8 +90,6 @@ void invokeSampling(SamplingParams& params, cudaStream_t stream)
                    params.indices,
                    params.kept,
                    params.curandstate,
-                   params.finished,
-                   params.end_ids,
                    params.output_ids,
                    params.sequence_length,
                    params.sampled_logprobs,
src/turbomind/kernels/sampling_kernels.h (2 changes: 0 additions & 2 deletions)
@@ -29,8 +29,6 @@ struct SamplingParams {
     int* kept;
     curandState_t* curandstate;
     size_t batch_size;
-    bool* finished;
-    int* end_ids;
     int* output_ids;
     int* sequence_length;
     float* sampled_logprobs;
src/turbomind/layers/sampling_layers/SamplingLayer.cc (4 changes: 1 addition & 3 deletions)
@@ -220,9 +220,7 @@ void SamplingLayer<T>::forward(TensorMap* output_tensors, TensorMap* input_tenso
     params.kept = kept_;
     params.curandstate = output_tensors->at("curand_state").getPtr<curandState_t>();
     params.batch_size = batch_size;
-    params.finished = output_tensors->at("finished", Tensor{MEMORY_GPU, TYPE_INVALID, {}, nullptr}).getPtr<bool>();
-    params.end_ids = input_tensors->at("end_id").getPtr<int>();
-    params.output_ids = output_tensors->at("output_ids").getPtrWithOffset<int>(step * batch_size);
+    params.output_ids = output_tensors->at("output_ids").getPtrWithOffset<int>(step * batch_size);
     params.sequence_length =
         output_tensors->at("sequence_length", Tensor{MEMORY_GPU, TYPE_INVALID, {}, nullptr}).getPtr<int>();
     params.sampled_logprobs =
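Taken together, the C++ changes narrow the sampling kernel's contract: it records the sampled token id and increments sequence_length unconditionally, while the finished/EOS decision is left to logic outside the kernel (the turbomind.py change above ensures eos_id is always covered by the stop words). A minimal Python sketch of that division of labor follows; sampling_step and its arguments are illustrative names, not code from the repository:

    # Illustrative sketch, not repository code: per-step bookkeeping after this commit.
    def sampling_step(sampled_token, sequence_length, stop_words):
        sequence_length += 1                    # kernel side: always bump the length
        finished = sampled_token in stop_words  # stop/EOS detection happens outside the sampler
        return sampled_token, sequence_length, finished

    # Sampling the EOS token (id 2) bumps the length and is caught by the stop-word check.
    token, length, done = sampling_step(sampled_token=2, sequence_length=7, stop_words={2})
    assert (token, length, done) == (2, 8, True)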
