Formatter improvements
tzielinski-habana committed Sep 4, 2024
1 parent 85c4b2a commit d6616c6
Showing 6 changed files with 75 additions and 45 deletions.
10 changes: 6 additions & 4 deletions vllm/config.py
@@ -855,12 +855,14 @@ def _verify_args(self) -> None:
if self.enable_delayed_sampling and self.num_lookahead_slots != 1:
raise ValueError(
"num_lookahead_slots "
f"({self.num_lookahead_slots}) must be 1 for delayed sampling.")
f"({self.num_lookahead_slots}) must be 1 for delayed sampling."
)

if self.enable_delayed_sampling and not self.use_v2_block_manager:
raise ValueError(
"use_v2_block_manager "
f"({self.use_v2_block_manager}) must be True for delayed sampling.")
raise ValueError("use_v2_block_manager "
f"({self.use_v2_block_manager}) must be True "
"for delayed sampling.")


class DeviceConfig:

5 changes: 3 additions & 2 deletions vllm/engine/output_processor/interfaces.py
@@ -37,8 +37,9 @@ def create_output_processor(
This returns a single-step output processor if num_lookahead_slots is
zero, else returns a multi-step output processor.
"""
if (scheduler_config.num_lookahead_slots == 0 or (scheduler_config.num_lookahead_slots == 1
and scheduler_config.enable_delayed_sampling)):
if (scheduler_config.num_lookahead_slots == 0
or (scheduler_config.num_lookahead_slots == 1
and scheduler_config.enable_delayed_sampling)):
# Importing here to avoid cycle.
from vllm.engine.output_processor.single_step import (
SingleStepOutputProcessor)
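For reference, the condition changed in this file selects a single-step output processor when num_lookahead_slots is 0, or when it is 1 with delayed sampling enabled; otherwise a multi-step processor is used. A minimal standalone sketch of that predicate (SchedulerConfigStub and uses_single_step are hypothetical names for illustration only, not part of this commit):

from dataclasses import dataclass


@dataclass
class SchedulerConfigStub:
    # Hypothetical stand-in for the two scheduler-config fields the check reads.
    num_lookahead_slots: int = 0
    enable_delayed_sampling: bool = False


def uses_single_step(cfg: SchedulerConfigStub) -> bool:
    # Single-step processing applies with no lookahead slots, or with exactly
    # one slot reserved for delayed sampling; anything else is multi-step.
    return (cfg.num_lookahead_slots == 0
            or (cfg.num_lookahead_slots == 1 and cfg.enable_delayed_sampling))


assert uses_single_step(SchedulerConfigStub(0, False))
assert uses_single_step(SchedulerConfigStub(1, True))
assert not uses_single_step(SchedulerConfigStub(1, False))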
2 changes: 1 addition & 1 deletion vllm/engine/output_processor/single_step.py
@@ -126,7 +126,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
# delayed sampling is enabled).
if last_child_sample.output_token != -1:
parent.append_token_id(last_child_sample.output_token,
last_child_sample.logprobs)
last_child_sample.logprobs)
child_seqs.append((parent, parent))

for seq, _ in child_seqs:
13 changes: 9 additions & 4 deletions vllm/model_executor/layers/sampler.py
@@ -350,7 +350,9 @@ def _greedy_sample(
assert num_parent_seqs == 1, (
"Greedy sampling should have only one seq.")
parent_ids = list(range(num_parent_seqs))
next_token_ids = [sample_idx if token_positions_only else samples_lst[sample_idx]]
next_token_ids = [
sample_idx if token_positions_only else samples_lst[sample_idx]
]
results.append((next_token_ids, parent_ids))
sample_idx += num_parent_seqs
return results
@@ -586,15 +588,17 @@ def _sample_with_torch(
else:
raise ValueError(f"Unsupported sampling type: {sampling_type}")

# GPU<->CPU sync happens in the loop below, unless we're storing only token positions (token_positions_only=True)
# GPU<->CPU sync happens in the loop below,
# unless we're storing only token positions (token_positions_only=True)
# This also converts the sample output to Python objects.
if not sampling_metadata.skip_sampler_cpu_output:
for sampling_type in SamplingType:
if sampling_type not in sample_metadata:
continue
(seq_group_id, seq_groups) = sample_metadata[sampling_type]
if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample(seq_groups, greedy_samples, token_positions_only)
sample_results = _greedy_sample(seq_groups, greedy_samples,
token_positions_only)
elif sampling_type in (SamplingType.RANDOM,
SamplingType.RANDOM_SEED):
sample_results = _random_sample(
@@ -699,7 +703,8 @@ def _sample_with_triton_kernel(
def _sample(
probs: torch.Tensor, logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors,
include_gpu_probs_tensor: bool, modify_greedy_probs: bool, token_positions_only: bool
include_gpu_probs_tensor: bool, modify_greedy_probs: bool,
token_positions_only: bool
) -> Tuple[SampleResultType, Optional[torch.Tensor]]:
"""
Args:
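The token_positions_only path reformatted in this file lets greedy sampling record row positions into the still-on-device samples tensor instead of the token ids, deferring the GPU<->CPU sync. A simplified, self-contained sketch of that idea (greedy_sample_sketch is an illustrative helper, not vLLM's _greedy_sample):

from typing import List, Tuple


def greedy_sample_sketch(samples: List[int], token_positions_only: bool
                         ) -> List[Tuple[List[int], List[int]]]:
    # Returns (next_token_ids, parent_ids) per sequence group. With
    # token_positions_only=True the result carries positions into the
    # samples tensor instead of the token ids themselves, so no
    # device-to-host copy is needed at this point.
    results = []
    for sample_idx, token in enumerate(samples):
        parent_ids = [0]  # greedy sampling has exactly one parent sequence
        next_token_ids = [sample_idx if token_positions_only else token]
        results.append((next_token_ids, parent_ids))
    return results


samples = [42, 7, 13]
assert greedy_sample_sketch(samples, False) == [([42], [0]), ([7], [0]), ([13], [0])]
assert greedy_sample_sketch(samples, True) == [([0], [0]), ([1], [0]), ([2], [0])]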
7 changes: 4 additions & 3 deletions vllm/sequence.py
@@ -165,7 +165,7 @@ def output_token_ids(self, new_output_token_ids) -> None:
def output_token_ids_array(self) -> array:
return self._output_token_ids

def append_token_id(self, token_id: int, logprob: float) -> None:
def append_token_id(self, token_id: int, logprob: Optional[float]) -> None:
self._output_token_ids.append(token_id)
self._cached_all_token_ids.append(token_id)
self.cumulative_logprob += logprob if logprob is not None else 0.0
@@ -336,9 +336,10 @@ def append_token_id(
token_id: int,
logprobs: Dict[int, Logprob],
) -> None:
# assert token_id in logprobs
self.output_logprobs.append(logprobs)
self.data.append_token_id(token_id, logprobs[token_id].logprob if token_id in logprobs else None)
self.data.append_token_id(
token_id,
logprobs[token_id].logprob if token_id in logprobs else None)

def get_len(self) -> int:
return self.data.get_len()
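The append_token_id change in this file makes the logprob optional, so the -1 placeholder token produced under delayed sampling can be appended before its logprob is known. A minimal sketch of that behaviour (SequenceDataSketch is a simplified stand-in, not the real SequenceData class):

from typing import List, Optional


class SequenceDataSketch:
    # Simplified stand-in; not the real vllm.sequence.SequenceData.

    def __init__(self) -> None:
        self.output_token_ids: List[int] = []
        self.cumulative_logprob = 0.0

    def append_token_id(self, token_id: int, logprob: Optional[float]) -> None:
        self.output_token_ids.append(token_id)
        # A missing logprob (e.g. for the -1 placeholder token used by
        # delayed sampling) contributes nothing to the cumulative sum.
        self.cumulative_logprob += logprob if logprob is not None else 0.0


seq = SequenceDataSketch()
seq.append_token_id(5, -0.25)
seq.append_token_id(-1, None)  # placeholder token, logprob not known yet
assert seq.cumulative_logprob == -0.25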
83 changes: 52 additions & 31 deletions vllm/worker/habana_model_runner.py
@@ -323,6 +323,7 @@ class ModelInputForHPU(ModelRunnerInputBase):
real_batch_size: Optional[int] = None
batch_size_padded: Optional[int] = None
virtual_engine: int = 0
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None

def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
@@ -359,7 +360,6 @@ class ModelInputForHPUWithSamplingMetadata(ModelInputForHPU):
# Used for speculative decoding. We do not broadcast it because it is only
# used by the driver worker.
is_prompt: Optional[bool] = None
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None

def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
tensor_dict = {
@@ -855,8 +855,9 @@ def _prepare_decode(
generation_token = seq_data.get_last_token_id()
input_tokens.append([generation_token])

seq_len = ((seq_data.get_num_computed_tokens() + 1)
if self.scheduler_config.enable_delayed_sampling else seq_data.get_len())
seq_len = ((seq_data.get_num_computed_tokens() +
1) if self.scheduler_config.enable_delayed_sampling
else seq_data.get_len())
position = seq_len - 1
input_positions.append([position])

@@ -1635,29 +1636,41 @@ def execute_model(
logits_ids_list = []
logits_tensor = None
logits_tensor_list = []
for seq_group_metadata in model_input.seq_group_metadata_list:
assert len(seq_group_metadata.seq_data) == 1
for seq_data in seq_group_metadata.seq_data.values():
if seq_data.prev_logits is not None:
if logits_tensor is None:
logits_tensor = seq_data.prev_logits
if seq_data.prev_logits is logits_tensor:
# accumulate row ids from the same tensor
logits_ids_list.append(seq_data.prev_logits_idx)
if model_input.seq_group_metadata_list is not None:
for seq_group_metadata in model_input.seq_group_metadata_list:
assert len(seq_group_metadata.seq_data) == 1
for seq_data in seq_group_metadata.seq_data.values():
if seq_data.prev_logits is not None:
if logits_tensor is None:
logits_tensor = seq_data.prev_logits
if seq_data.prev_logits is logits_tensor:
# accumulate row ids from the same tensor
logits_ids_list.append(
seq_data.prev_logits_idx)
else:
# new logits tensor,
# gather all previously collected rows
logits_tensor_list.append(
logits_tensor[torch.tensor(
logits_ids_list,
device=seq_data.prev_logits.device)])
logits_ids_list = [seq_data.prev_logits_idx]
logits_tensor = seq_data.prev_logits
else:
# new logits tensor, gather all previously collected rows
logits_tensor_list.append(logits_tensor[torch.tensor(logits_ids_list, device=seq_data.prev_logits.device)])
logits_ids_list = [seq_data.prev_logits_idx]
logits_tensor = seq_data.prev_logits
else:
# warmup only, TODO add a check
logits_tensor_list.append(torch.zeros([1, 32000], dtype=torch.float, device="hpu"))
# warmup only, TODO add a check
logits_tensor_list.append(
torch.zeros([1, 32000],
dtype=torch.float,
device="hpu"))
if logits_tensor is not None:
logits_tensor_list.append(logits_tensor[torch.tensor(logits_ids_list, device=seq_data.prev_logits.device)])
logits_tensor_list.append(logits_tensor[torch.tensor(
logits_ids_list, device=seq_data.prev_logits.device)])

prev_logits = torch.cat(logits_tensor_list, dim=0)

with self.profiler.record_event('internal', f'sample_{"prompt" if is_prompt else "decode"}_bs{batch_size}_seq{seq_len}'):
with self.profiler.record_event(
'internal', f'sample_{"prompt" if is_prompt else "decode"}'
'_bs{batch_size}_seq{seq_len}'):
output = self.model.sample(
logits=prev_logits,
sampling_metadata=sampling_metadata,
@@ -1693,26 +1706,32 @@ def execute_model(
if self.scheduler_config.enable_delayed_sampling:
if not is_prompt:
htorch.core.mark_step()
# Only after dispatching next model.forward() read and update the previous token ids to return
# Only after dispatching next model.forward() read and update
# the previous token ids to return
sampled_token_ids = output.sampled_token_ids.tolist()
for seq_group_output in output.outputs[:real_batch_size]:
for sample in seq_group_output.samples:
sample.output_token = sampled_token_ids[sample.output_token][0]
sample.output_token = sampled_token_ids[
sample.output_token][0]
output = output
else:
# For prompts compose empty output
from vllm.sequence import (Logprob, SamplerOutput, CompletionSequenceGroupOutput, SequenceOutput)
from vllm.sequence import (CompletionSequenceGroupOutput,
Logprob, SamplerOutput,
SequenceOutput)
sampler_output = []
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
next_token_id, parent_id = -1, 0
seq_outputs = []
seq_outputs.append(
SequenceOutput(seq_ids[parent_id], next_token_id, {-1: Logprob(0.0)}))
SequenceOutput(seq_ids[parent_id], next_token_id,
{-1: Logprob(0.0)}))
sampler_output.append(
CompletionSequenceGroupOutput(seq_outputs, None))

sampled_token_probs, logprobs_tensor, sampled_token_ids = (None, None, None)
sampled_token_probs, logprobs_tensor, sampled_token_ids = (
None, None, None)
output = SamplerOutput(
outputs=sampler_output,
sampled_token_probs=sampled_token_probs,
@@ -1733,8 +1752,10 @@ def execute_model(
logits = self.model.compute_logits(hidden_states,
sampling_metadata)

if self.scheduler_config.enable_delayed_sampling:
for idx, seq_group_metadata in enumerate(model_input.seq_group_metadata_list):
if (self.scheduler_config.enable_delayed_sampling
and model_input.seq_group_metadata_list is not None):
for idx, seq_group_metadata in enumerate(
model_input.seq_group_metadata_list):
assert len(seq_group_metadata.seq_data) == 1
for seq_data in seq_group_metadata.seq_data.values():
seq_data.prev_logits = logits
@@ -1749,9 +1770,9 @@ def execute_model(
if not self.scheduler_config.enable_delayed_sampling:
with self.profiler.record_event(
'internal', ('sample_'
f'{"prompt" if is_prompt else "decode"}_'
f'bs{batch_size}_'
f'seq{seq_len}')):
f'{"prompt" if is_prompt else "decode"}_'
f'bs{batch_size}_'
f'seq{seq_len}')):
output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
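The reformatted loop in this file batches the lookup of previously stored logits: row indices are accumulated while consecutive sequences share the same prev_logits tensor, each group is gathered with one indexing operation, and the gathered slices are concatenated. An illustrative PyTorch sketch of that grouping, run on CPU with a hypothetical gather_prev_logits helper (the real code targets HPU tensors):

import torch


def gather_prev_logits(entries):
    # entries: list of (logits_tensor, row_idx) pairs in batch order, where
    # consecutive sequences may share the same tensor. Rows coming from the
    # same tensor are fetched with a single indexing op before concatenation.
    slices = []
    current = None
    row_ids = []
    for tensor, row_idx in entries:
        if current is None:
            current = tensor
        if tensor is current:
            row_ids.append(row_idx)  # accumulate row ids from the same tensor
        else:
            slices.append(current[torch.tensor(row_ids)])
            current, row_ids = tensor, [row_idx]
    if current is not None:
        slices.append(current[torch.tensor(row_ids)])
    return torch.cat(slices, dim=0)


a = torch.arange(12.0).reshape(3, 4)
b = torch.arange(100.0, 108.0).reshape(2, 4)
out = gather_prev_logits([(a, 0), (a, 2), (b, 1)])
assert out.shape == (3, 4) and torch.equal(out[2], b[1])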
