Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Support speculate decoding #2541

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions llm/server/server/engine/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ def read_from_env(self):
self.block_size = int(env.get("BLOCK_SIZE", 64))
self.use_cache_kv_int8 = int(os.getenv("USE_CACHE_KV_INT8", 0))
self.use_cache_kv_int4 = int(os.getenv("USE_CACHE_KV_INT4", 0))

# speculate decoding config
self.speculate_method = str(env.get("SPECULATE_METHOD", None))
self.speculate_max_draft_token_num = int(os.getenv("SPECULATE_MAX_DRAFT_TOKEN_NUM", 5))
self.speculate_max_ngram_size = int(os.getenv("SPECULATE_MAX_NGRAM_SIZE", 2))

# infer config
self.max_batch_size = int(env.get("BATCH_SIZE", 50))
Expand Down
107 changes: 92 additions & 15 deletions llm/server/server/engine/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
from paddlenlp.utils.llm_utils import get_rotary_position_embedding
from paddlenlp_ops import step_paddle
from paddlenlp_ops import step_paddle, speculate_step_paddle, speculate_update_input_ids_cpu
from server.data.processor import DataProcessor
from server.engine.config import Config
from server.engine.proposers import InferenceWithReferenceProposer
from server.utils import get_logger
from task_queue_manager import TaskQueueManager

Expand Down Expand Up @@ -62,6 +63,15 @@ def __init__(self, args):
self.cache_kvs = {}
self.init_inputs()

# whether use speculate decoding
if self.config.speculate_method is not None and self.config.speculate_method == "inference_with_reference":
self.proposer = InferenceWithReferenceProposer(
self.config.speculate_max_draft_token_num,
self.config.speculate_max_ngram_size,
self.args.max_batch_size)
else:
self.proposer = None

self.infer_queue = TaskQueueManager(rank=self.rank, mp_num=self.nranks, port=self.config.infer_port)

model_rank_path = os.path.join(self.args.model_dir, f"rank_{self.rank}")
Expand Down Expand Up @@ -246,6 +256,21 @@ def init_inputs(self):
self.share_inputs['free_list_len'] = paddle.full(
shape=[1], fill_value=self.free_list_len, dtype="int32")

# speculate decoding input
if self.config.speculate_method is not None:
self.share_inputs["input_ids_cpu"] = paddle.full(
shape=[self.args.max_batch_size, self.args.max_seq_len], fill_value=1, dtype='int64').cpu()
self.share_inputs["accept_tokens"] = paddle.full(
shape=[self.args.max_batch_size, self.config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
)
self.share_inputs["accept_num"] = paddle.full(shape=[self.args.max_batch_size], fill_value=0, dtype="int32")
self.share_inputs["draft_tokens"] = paddle.full(
shape=[self.args.max_batch_size, self.config.speculate_max_draft_token_num + 1], fill_value=0, dtype="int64"
)
self.share_inputs["actual_draft_token_num"] = paddle.full(
shape=[self.args.max_batch_size], fill_value=self.config.speculate_max_draft_token_num, dtype="int32"
)

def dy_input_preprocess(self, tasks):
"""
dynamic insertion
Expand Down Expand Up @@ -288,23 +313,46 @@ def dy_input_preprocess(self, tasks):
self.share_inputs["block_tables"][idx:idx + 1, :encoder_block_num] = np.array(
task['block_tables'], dtype="int32")

if self.proposer is not None:
if self.config.speculate_method == "inference_with_reference":
speculate_update_input_ids_cpu(self.share_inputs['input_ids_cpu'], task['input_ids'], idx, self.args.max_seq_len)
self.share_inputs["draft_tokens"][idx:idx + 1] = np.zeros([self.config.speculate_max_draft_token_num + 1])
self.share_inputs["actual_draft_token_num"][idx:idx + 1] = np.array([self.config.speculate_max_draft_token_num])
self.proposer.update(idx, length)

def step_cuda(self, seq_lens_this_time):
"""
step cuda
"""
step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time,
self.share_inputs['step_seq_lens_encoder'],
self.share_inputs['seq_lens_encoder'],
self.share_inputs['seq_lens_decoder'], self.share_inputs["block_tables"],
self.share_inputs['encoder_block_lens'],
self.share_inputs["is_block_step"], self.share_inputs['step_block_list'],
self.share_inputs['step_lens'], self.share_inputs['recover_block_list'],
self.share_inputs['recover_lens'], self.share_inputs['need_block_list'],
self.share_inputs['need_block_len'], self.share_inputs['used_list_len'],
self.share_inputs['free_list'], self.share_inputs['free_list_len'],
self.share_inputs['input_ids'], self.share_inputs['pre_ids'],
self.share_inputs['step_idx'], self.share_inputs['next_tokens'],
self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id)
if self.config.speculate_method is None:
step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time,
self.share_inputs['step_seq_lens_encoder'],
self.share_inputs['seq_lens_encoder'],
self.share_inputs['seq_lens_decoder'], self.share_inputs["block_tables"],
self.share_inputs['encoder_block_lens'],
self.share_inputs["is_block_step"], self.share_inputs['step_block_list'],
self.share_inputs['step_lens'], self.share_inputs['recover_block_list'],
self.share_inputs['recover_lens'], self.share_inputs['need_block_list'],
self.share_inputs['need_block_len'], self.share_inputs['used_list_len'],
self.share_inputs['free_list'], self.share_inputs['free_list_len'],
self.share_inputs['input_ids'], self.share_inputs['pre_ids'],
self.share_inputs['step_idx'], self.share_inputs['next_tokens'],
self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id)
else:
speculate_step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time,
self.share_inputs['step_seq_lens_encoder'],
self.share_inputs['seq_lens_encoder'],
self.share_inputs['seq_lens_decoder'], self.share_inputs["block_tables"],
self.share_inputs['encoder_block_lens'],
self.share_inputs["is_block_step"], self.share_inputs['step_block_list'],
self.share_inputs['step_lens'], self.share_inputs['recover_block_list'],
self.share_inputs['recover_lens'], self.share_inputs['need_block_list'],
self.share_inputs['need_block_len'], self.share_inputs['used_list_len'],
self.share_inputs['free_list'], self.share_inputs['free_list_len'],
self.share_inputs['input_ids'], self.share_inputs['pre_ids'],
self.share_inputs['step_idx'], self.share_inputs['next_tokens'],
self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id,
self.config.speculate_max_draft_token_num)

def initialize_engine_ready_check_flag(self):
"""
Expand Down Expand Up @@ -404,6 +452,9 @@ def run(self):
self.share_inputs["seq_lens_this_time"][:real_bsz] = seq_lens_this_time

tasks, read_finish = self.infer_queue.get()
logger.info(f'tasks: {tasks}')
logger.info(f'read_finish: {read_finish}')

if read_finish:
flag_broadcast_array[0] = 0

Expand All @@ -412,7 +463,7 @@ def run(self):
real_bsz = int(bsz)
req_dicts.extend(req_dict)
logger.info(
f'rank: {self.rank}, real_bsz: {real_bsz}, query_num: {len(req_dicts)}'
f'req_dict: {req_dict} rank: {self.rank}, real_bsz: {real_bsz}, query_num: {len(req_dicts)}'
)

self.dy_input_preprocess(req_dicts)
Expand All @@ -429,10 +480,36 @@ def run(self):
time.sleep(0.001)
continue

if self.proposer is not None:
logger.info("start run proposer")
logger.info(f'before draft_tokens: {self.share_inputs["draft_tokens"]}')
logger.info(f'before accept_tokens: {self.share_inputs["accept_tokens"]}')

self.proposer.run(
self.share_inputs,
real_batch_size=self.args.max_batch_size,
seq_lens_this_time=self.share_inputs["seq_lens_this_time"],
)
logger.info(f'after draft_tokens: {self.share_inputs["draft_tokens"]}')
logger.info("finish run proposer")
logger.info(f'input_ids: {self.share_inputs["input_ids"]}')
logger.info(f'input_ids_cpu: {self.share_inputs["input_ids_cpu"]}')
logger.info(f'seq_lens_this_time: {self.share_inputs["seq_lens_this_time"]}')
logger.info(f'seq_lens_encoder: {self.share_inputs["seq_lens_encoder"]}')
logger.info(f'seq_lens_decoder: {self.share_inputs["seq_lens_decoder"]}')
logger.info(f'step_idx: {self.share_inputs["step_idx"]}')
logger.info(f'next_tokens: {self.share_inputs["next_tokens"]}')
logger.info(f'before block_tables: {self.share_inputs["block_tables"]}')

self.infer_engine.predictor.run()
logger.info(f'after accept_tokens: {self.share_inputs["accept_tokens"]}')
logger.info(f'after accept_num: {self.share_inputs["accept_num"]}')
logger.info(f'after block_tables: {self.share_inputs["block_tables"]}')

self.share_inputs['infer_seed'].add_(infer_seed_increment)
self.share_inputs['infer_seed'][:] %= self.MAX_INFER_SEED
if self.free_list_len > 0:
logger.info(f'free_list_len > 0')
self.step_cuda(seq_lens_this_time)


Expand Down
95 changes: 95 additions & 0 deletions llm/server/server/engine/proposers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from abc import ABC, abstractmethod

import paddle
from paddlenlp_ops import ngram_match


class Proposer(ABC):
"""
Abstract base class for all proposers that can be used in the speculative decoding framework.
The subclasses of this class must implement the run method to get the draft tokens that are
generated by the proposer.
"""

def __init__(self, **kwargs):
pass

@abstractmethod
def run(self, model_inputs: dict[str, paddle.Tensor], **kargs):
"""
Get the draft tokens that are generated by the proposer.
"""
raise NotImplementedError()


class InferenceWithReferenceProposer(Proposer):
"""
InferenceWithReference(https://arxiv.org/pdf/2304.04487) is one of the speculative decoding method.
It match tokens in the input and output as draft tokens.
"""

def __init__(self, max_draft_token_num: int, max_ngram_size: int, max_batch_size: int):
"""
Args:
max_draft_token_num (int):
Maximum number of tokens a proposer can generate at one time.
The hyperparameter of k in the paper.
max_ngram_size (int):
The maximum size of the window used to match inputs and outputs.
The hyperparameter of n in the paper.
max_batch_size (int):
The maximum batch size.
"""
super().__init__()
self.max_ngram_size = max_ngram_size
self.input_ids_len = paddle.zeros(shape=[max_batch_size, 1], dtype="int64").cpu()
self.max_batch_size = max_batch_size
self.max_draft_token_num = max_draft_token_num
# self.input_ids_cpu = paddle.full(shape=[max_batch_size, max_seq_len], fill_value=1, dtype="int64").cpu()

def update(self, bid: int, seq_len: int):
"""
Used when inserting a new query to update the length of the input_ids.
"""
self.input_ids_len[bid] = seq_len

def run(self, share_inputs: dict[str, paddle.Tensor], **kargs):
"""
Use ngram_match to get draft tokens from the input and output.
"""
draft_tokens = share_inputs["draft_tokens"].cpu()
seq_lens_this_time = kargs["seq_lens_this_time"].cpu()
seq_lens_encoder = share_inputs["seq_lens_encoder"].cpu()
seq_lens_decoder = share_inputs["seq_lens_decoder"].cpu()
ngram_match(
share_inputs["input_ids_cpu"],
self.input_ids_len.cpu(),
share_inputs["pre_ids"].cpu(),
share_inputs["step_idx"].cpu(),
share_inputs["actual_draft_token_num"].cpu(),
draft_tokens,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
kargs["real_batch_size"],
self.max_ngram_size,
self.max_draft_token_num,
)
share_inputs["draft_tokens"][:] = draft_tokens.cuda()
share_inputs["seq_lens_encoder"][:] = seq_lens_encoder.cuda()
kargs["seq_lens_this_time"][:] = seq_lens_this_time.cuda()
Loading