diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py index cffb502c4..f31345db2 100644 --- a/lmdeploy/turbomind/turbomind.py +++ b/lmdeploy/turbomind/turbomind.py @@ -565,6 +565,8 @@ def _broadcast_np(data, dtype, shape=(batch_size, )): bad_words.append(self.eos_id) else: stop_words = gen_config.stop_token_ids + if self.eos_id not in stop_words: + stop_words.append(self.eos_id) stop_words = _construct_stop_or_bad_words(stop_words) bad_words = _construct_stop_or_bad_words(bad_words) diff --git a/src/turbomind/kernels/sampling_kernels.cu b/src/turbomind/kernels/sampling_kernels.cu index 427b9a4aa..7263cfcf8 100644 --- a/src/turbomind/kernels/sampling_kernels.cu +++ b/src/turbomind/kernels/sampling_kernels.cu @@ -17,8 +17,6 @@ __global__ void sampling(T* logits, int* indices, int* kept, curandState_t* curandstate, - bool* finished, - const int* end_ids, int* output_ids, int* sequence_length, float* sampled_logprobs, @@ -55,10 +53,8 @@ __global__ void sampling(T* logits, selected = min(i, n - 1); output_ids[batch_id] = indices[selected]; - if (sequence_length != nullptr && finished != nullptr) { - sequence_length[batch_id] = - finished[batch_id] ? sequence_length[batch_id] : sequence_length[batch_id] + 1; - finished[batch_id] = output_ids[batch_id] == end_ids[batch_id] ? 1 : 0; + if (sequence_length != nullptr) { + sequence_length[batch_id] += 1; } } break; @@ -94,8 +90,6 @@ void invokeSampling(SamplingParams& params, cudaStream_t stream) params.indices, params.kept, params.curandstate, - params.finished, - params.end_ids, params.output_ids, params.sequence_length, params.sampled_logprobs, diff --git a/src/turbomind/kernels/sampling_kernels.h b/src/turbomind/kernels/sampling_kernels.h index a3cd8fa02..954817e14 100644 --- a/src/turbomind/kernels/sampling_kernels.h +++ b/src/turbomind/kernels/sampling_kernels.h @@ -29,8 +29,6 @@ struct SamplingParams { int* kept; curandState_t* curandstate; size_t batch_size; - bool* finished; - int* end_ids; int* output_ids; int* sequence_length; float* sampled_logprobs; diff --git a/src/turbomind/layers/sampling_layers/SamplingLayer.cc b/src/turbomind/layers/sampling_layers/SamplingLayer.cc index 1d1b18cdc..6040402af 100644 --- a/src/turbomind/layers/sampling_layers/SamplingLayer.cc +++ b/src/turbomind/layers/sampling_layers/SamplingLayer.cc @@ -220,9 +220,7 @@ void SamplingLayer::forward(TensorMap* output_tensors, TensorMap* input_tenso params.kept = kept_; params.curandstate = output_tensors->at("curand_state").getPtr(); params.batch_size = batch_size; - params.finished = output_tensors->at("finished", Tensor{MEMORY_GPU, TYPE_INVALID, {}, nullptr}).getPtr(); - params.end_ids = input_tensors->at("end_id").getPtr(); - params.output_ids = output_tensors->at("output_ids").getPtrWithOffset(step * batch_size); + params.output_ids = output_tensors->at("output_ids").getPtrWithOffset(step * batch_size); params.sequence_length = output_tensors->at("sequence_length", Tensor{MEMORY_GPU, TYPE_INVALID, {}, nullptr}).getPtr(); params.sampled_logprobs =