[Inference] Fix get_save_output op and refactor specu_decoding (#9576)
* fix get_save_output op and refactor specu_decoding
* add token_penalty for speculate decoding
Wanglongzhi2001 authored Dec 12, 2024
1 parent 70e80c7 commit 4f57179
Showing 11 changed files with 755 additions and 251 deletions.
70 changes: 26 additions & 44 deletions csrc/gpu/get_output.cc
@@ -1,11 +1,11 @@
 // Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
-//
+//
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 // You may obtain a copy of the License at
-//
+//
 //     http://www.apache.org/licenses/LICENSE-2.0
-//
+//
 // Unless required by applicable law or agreed to in writing, software
 // distributed under the License is distributed on an "AS IS" BASIS,
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -17,71 +17,53 @@
 #include <sys/ipc.h>
 #include <sys/msg.h>
 #include <sys/types.h>
 
 #include "paddle/extension.h"
 
-#define MAX_BSZ 256
-#define MAX_DRAFT_TOKENS 6
+#define MAX_BSZ 512
 
-template <int SIZE>
-struct MsgData {
-  long mtype;
-  std::array<int, SIZE> mtext;
+struct msgdata {
+  long mtype;
+  int mtext[MAX_BSZ + 2];  // stop_flag, bsz, tokens
 };
 
-template <int SIZE>
-void GetOutputFunc(MsgData<SIZE>& msg_rcv,  // NOLINT
-                   const paddle::Tensor& x,
-                   int64_t rank_id,
-                   bool wait_flag) {
+void GetOutput(const paddle::Tensor& x,
+               int64_t rank_id,
+               bool wait_flag) {
   if (rank_id > 0) return;
 
+  static struct msgdata msg_rcv;
+
   static key_t key = ftok("./", 1);
 
   static int msgid = msgget(key, IPC_CREAT | 0666);
 
+  int64_t *out_data = const_cast<int64_t*>(x.data<int64_t>());
   int ret = -1;
-  ret = msgrcv(
-      msgid, &msg_rcv, SIZE * sizeof(int), 0, wait_flag ? 0 : IPC_NOWAIT);
-
-  int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>());
-
-  if (ret == -1) {
+  if (!wait_flag) {
+    ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT);
+  } else {
+    ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, 0);
+  }
+  if (ret == -1) {
    // read none
    out_data[0] = -2;
    out_data[1] = 0;
    return;
   }
-
-  for (int64_t i = 0; i < SIZE; i++) {
+  int bsz = msg_rcv.mtext[1];
+
+  for (int64_t i = 0; i < bsz + 2; i++) {
    out_data[i] = (int64_t)msg_rcv.mtext[i];
   }
-
   return;
 }
 
-void GetOutput(const paddle::Tensor& x,
-               int64_t rank_id,
-               bool wait_flag,
-               bool speculative_decoding) {
-  if (!speculative_decoding) {
-    constexpr int SIZE = MAX_BSZ + 2;  // stop_flag, bsz, tokens...
-    static struct MsgData<SIZE> msg_rcv;
-    GetOutputFunc<SIZE>(msg_rcv, x, rank_id, wait_flag);
-  } else {
-    constexpr int SIZE = MAX_BSZ * MAX_DRAFT_TOKENS +
-                         MAX_BSZ +
-                         2;  // stop_flag, bsz, accept_num*bsz, tokens...
-    static struct MsgData<SIZE> specu_msg_rcv;
-    GetOutputFunc<SIZE>(specu_msg_rcv, x, rank_id, wait_flag);
-  }
-}
-
 PD_BUILD_OP(get_output)
     .Inputs({"x"})
     .Attrs({"rank_id: int64_t",
-            "wait_flag: bool",
-            "speculative_decoding: bool"})
+            "wait_flag: bool"})
     .Outputs({"x_out"})
     .SetInplaceMap({{"x", "x_out"}})
     .SetKernelFn(PD_KERNEL(GetOutput));
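
After this change, get_output reads a single fixed-layout message: mtext[0] is the stop flag (with -2 written back when a non-blocking read finds nothing), mtext[1] is the batch size, and the next bsz slots each hold one token id. Below is a minimal sketch of how a consumer could unpack that buffer. It is illustrative only, not part of the commit, and the sample values are made up.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Pretend this is what the op wrote back for a batch of 3 sequences:
  // {stop_flag, bsz, token_0, token_1, token_2}.
  std::vector<int64_t> out = {1, 3, 101, 205, 309};

  if (out[0] == -2) {  // a non-blocking read found no message
    std::printf("no message\n");
    return 0;
  }
  bool keep_running = (out[0] == 1);
  int64_t bsz = out[1];
  for (int64_t i = 0; i < bsz; ++i) {
    std::printf("seq %lld -> token %lld\n",
                static_cast<long long>(i),
                static_cast<long long>(out[2 + i]));
  }
  std::printf("keep_running = %d\n", keep_running ? 1 : 0);
  return 0;
}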
102 changes: 25 additions & 77 deletions csrc/gpu/save_with_output_msg.cc
@@ -1,11 +1,11 @@
 // Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
-//
+//
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
 // You may obtain a copy of the License at
-//
+//
 //     http://www.apache.org/licenses/LICENSE-2.0
-//
+//
 // Unless required by applicable law or agreed to in writing, software
 // distributed under the License is distributed on an "AS IS" BASIS,
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -17,96 +17,44 @@
 #include <sys/ipc.h>
 #include <sys/msg.h>
 #include <sys/types.h>
 
 #include "paddle/extension.h"
 
-#define MAX_BSZ 256
-#define MAX_DRAFT_TOKENS 6
+#define MAX_BSZ 512
 
-template <int SIZE>
-struct MsgData {
-  long mtype;
-  std::array<int, SIZE> mtext;
+struct msgdata {
+  long mtype;
+  int mtext[MAX_BSZ + 2];  // stop_flag, bsz, tokens
 };
 
-template <int SIZE>
-void SaveOutMsgFunc(MsgData<SIZE>& msg_sed,  // NOLINT
-                    const paddle::Tensor& x,
-                    const paddle::Tensor& not_need_stop,
-                    const paddle::optional<paddle::Tensor>& accept_num,
-                    int64_t rank_id) {
-  if (rank_id > 0) return;
-  auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
-  int64_t* x_data = x_cpu.data<int64_t>();
-  auto not_need_stop_cpu = not_need_stop.copy_to(paddle::CPUPlace(), false);
-  bool* not_need_stop_data = not_need_stop_cpu.data<bool>();
-
-  static key_t key = ftok("./", 1);
-  static int msgid = msgget(key, IPC_CREAT | 0666);
-  int bsz = x.shape()[0];
-
-  if (!accept_num) {
-    msg_sed.mtype = 1;
-    msg_sed.mtext[0] = not_need_stop_data[0] ? 1 : -1;
-    msg_sed.mtext[1] = bsz;
-    for (int i = 2; i < bsz + 2; i++) {
-      msg_sed.mtext[i] = (int)x_data[i - 2];
-    }
-    if ((msgsnd(msgid, &msg_sed, SIZE * sizeof(int), 0)) == -1) {
-      // printf("full msg buffer\n");
-    }
-  } else {
-    auto accept_num_cpu = accept_num.get().copy_to(paddle::CPUPlace(), false);
-    int* accept_num_data = accept_num_cpu.data<int>();
-
-    msg_sed.mtype = 1;
-    msg_sed.mtext[0] = not_need_stop_data[0] ? 1 : -1;
-    msg_sed.mtext[1] = bsz;
-    for (int i = 2; i < MAX_BSZ + 2; i++) {
-      if (i - 2 >= bsz) {
-        msg_sed.mtext[i] = 0;
-      } else {
-        msg_sed.mtext[i] = (int)accept_num_data[i - 2];
-      }
-    }
-    for (int i = MAX_BSZ + 2; i < SIZE; i++) {
-      int token_id = i - MAX_BSZ - 2;
-      int bid = token_id / MAX_DRAFT_TOKENS;
-      int local_token_id = token_id % MAX_DRAFT_TOKENS;
-      if (token_id / MAX_DRAFT_TOKENS >= bsz) {
-        msg_sed.mtext[i] = 0;
-      } else {
-        msg_sed.mtext[i] = x_data[bid * MAX_DRAFT_TOKENS + local_token_id];
-      }
-    }
-    if ((msgsnd(msgid, &msg_sed, SIZE * sizeof(int), 0)) == -1) {
-      printf("full msg buffer\n");
-    }
-  }
-
-  return;
-}
-
-void SaveOutMsg(const paddle::Tensor& x,
-                const paddle::Tensor& not_need_stop,
-                const paddle::optional<paddle::Tensor>& accept_num,
-                int64_t rank_id) {
-  if (!accept_num) {
-    constexpr int SIZE = MAX_BSZ + 2;  // stop_flag, bsz, tokens...
-    static struct MsgData<SIZE> msg_sed;
-    SaveOutMsgFunc<SIZE>(msg_sed, x, not_need_stop, accept_num, rank_id);
-  } else {
-    constexpr int SIZE = MAX_BSZ * MAX_DRAFT_TOKENS +
-                         MAX_BSZ +
-                         2;  // stop_flag, bsz, accept_num*bsz, tokens...
-    static struct MsgData<SIZE> specu_msg_sed;
-    SaveOutMsgFunc<SIZE>(specu_msg_sed, x, not_need_stop, accept_num, rank_id);
-  }
-  return;
-}
+void SaveOutMmsg(const paddle::Tensor& x,
+                 const paddle::Tensor& not_need_stop,
+                 int64_t rank_id) {
+  if (rank_id > 0) return;
+  auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
+  int64_t *x_data = x_cpu.data<int64_t>();
+  auto not_need_stop_cpu = not_need_stop.copy_to(paddle::CPUPlace(), false);
+  bool* not_need_stop_data = not_need_stop_cpu.data<bool>();
+
+  static struct msgdata msg_sed;
+  static key_t key = ftok("./", 1);
+  static int msgid = msgget(key, IPC_CREAT | 0666);
+
+  msg_sed.mtype = 1;
+  msg_sed.mtext[0] = not_need_stop_data[0] ? 1 : -1;
+  int bsz = x.shape()[0];
+  msg_sed.mtext[1] = bsz;
+  for (int i = 2; i < bsz + 2; i++) {
+    msg_sed.mtext[i] = (int)x_data[i - 2];
+  }
+  if ((msgsnd(msgid, &msg_sed, (MAX_BSZ + 2) * 4, 0)) == -1) {
+    // printf("full msg buffer\n");
+  }
+  return;
+}
 
 PD_BUILD_OP(save_output)
-    .Inputs({"x", "not_need_stop", paddle::Optional("accept_num")})
+    .Inputs({"x", "not_need_stop"})
     .Attrs({"rank_id: int64_t"})
     .Outputs({"x_out"})
     .SetInplaceMap({{"x", "x_out"}})
-    .SetKernelFn(PD_KERNEL(SaveOutMsg));
+    .SetKernelFn(PD_KERNEL(SaveOutMmsg));
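
One detail worth calling out in both ops: the size passed to msgsnd/msgrcv counts only the payload bytes (the mtext array) and excludes the leading mtype field, and the hard-coded (MAX_BSZ + 2) * 4 assumes a 4-byte int. The sketch below is illustrative only, not part of the commit; it shows the sizeof-based equivalent of that expression.

#include <cstdio>

#define MAX_BSZ 512

struct msgdata {
  long mtype;
  int mtext[MAX_BSZ + 2];  // stop_flag, bsz, tokens
};

int main() {
  // sizeof(msgdata::mtext) is the payload size msgsnd/msgrcv expect;
  // it equals (MAX_BSZ + 2) * 4 only on platforms where sizeof(int) == 4.
  std::printf("payload bytes: %zu (hard-coded form: %d)\n",
              sizeof(msgdata::mtext), (MAX_BSZ + 2) * 4);
  return 0;
}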
68 changes: 68 additions & 0 deletions csrc/gpu/speculate_decoding_kernels/speculate_get_output.cc
@@ -0,0 +1,68 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdio.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "paddle/extension.h"

#define SPECULATE_MAX_BSZ 256
#define MAX_DRAFT_TOKENS 6

struct msgdata {
long mtype;
int mtext[SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2]; // stop_flag, bsz, accept_num*bsz, tokens...
};

void SpeculateGetOutput(const paddle::Tensor& x,
int64_t rank_id,
bool wait_flag) {
if (rank_id > 0) {
return;
}
static struct msgdata msg_rcv;

static key_t key = ftok("./", 1);

static int msgid = msgget(key, IPC_CREAT | 0666);

int64_t *out_data = const_cast<int64_t*>(x.data<int64_t>());
int ret = -1;
if (!wait_flag) {
ret = msgrcv(msgid, &msg_rcv, (SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2) * 4, 0, IPC_NOWAIT);
} else {
ret = msgrcv(msgid, &msg_rcv, (SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2) * 4, 0, 0);
}
if(ret == -1) {
out_data[0] = -2;
out_data[1] = 0;
return;
}
int bsz = msg_rcv.mtext[1];

for (int64_t i = 0; i < SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2; i++) {
out_data[i] = (int64_t)msg_rcv.mtext[i];
}
return;
}

PD_BUILD_OP(speculate_get_output)
.Inputs({"x"})
.Attrs({"rank_id: int64_t",
"wait_flag: bool"})
.Outputs({"x_out"})
.SetInplaceMap({{"x", "x_out"}})
.SetKernelFn(PD_KERNEL(SpeculateGetOutput));
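
The speculative buffer is flat: slot 0 holds the stop flag, slot 1 the batch size, slots 2..SPECULATE_MAX_BSZ+1 the accepted-token count per sequence, and each sequence then owns MAX_DRAFT_TOKENS token slots starting at SPECULATE_MAX_BSZ + 2, mirroring the layout written by the speculative branch removed from save_with_output_msg.cc above. A minimal decoding sketch follows; it is illustrative only, not from this commit, and the sample values are made up.

#include <cstdint>
#include <cstdio>
#include <vector>

#define SPECULATE_MAX_BSZ 256
#define MAX_DRAFT_TOKENS 6

int main() {
  const int64_t kTokenBase = SPECULATE_MAX_BSZ + 2;  // first draft-token slot
  std::vector<int64_t> out(
      SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 0);

  // Made-up contents: bsz = 2, sequence 0 accepted 2 tokens, sequence 1 accepted 1.
  out[0] = 1;  // stop flag: keep running
  out[1] = 2;  // bsz
  out[2] = 2;  // accept_num for sequence 0
  out[3] = 1;  // accept_num for sequence 1
  out[kTokenBase + 0 * MAX_DRAFT_TOKENS + 0] = 101;
  out[kTokenBase + 0 * MAX_DRAFT_TOKENS + 1] = 102;
  out[kTokenBase + 1 * MAX_DRAFT_TOKENS + 0] = 201;

  int64_t bsz = out[1];
  for (int64_t bid = 0; bid < bsz; ++bid) {
    int64_t accept_num = out[2 + bid];
    std::printf("seq %lld accepted %lld token(s):",
                static_cast<long long>(bid),
                static_cast<long long>(accept_num));
    for (int64_t k = 0; k < accept_num; ++k) {
      std::printf(" %lld",
                  static_cast<long long>(
                      out[kTokenBase + bid * MAX_DRAFT_TOKENS + k]));
    }
    std::printf("\n");
  }
  return 0;
}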