async infershape (PaddlePaddle#26)
Co-authored-by: root <root@yq01-inf-hic-k8s-a100-ab2-0175.yq01.baidu.com>
xcpher and root authored Jun 27, 2022
1 parent a5e3184 commit bb9733e
Showing 15 changed files with 1,225 additions and 556 deletions.
305 changes: 225 additions & 80 deletions paddle/fluid/framework/data_feed.cc

Large diffs are not rendered by default.

15 changes: 4 additions & 11 deletions paddle/fluid/framework/data_feed.cu
@@ -71,11 +71,8 @@ void SlotRecordInMemoryDataFeed::FillSlotValueOffset(
const int ins_num, const int used_slot_num, size_t *slot_value_offsets,
const int *uint64_offsets, const int uint64_slot_size,
const int *float_offsets, const int float_slot_size,
-    const UsedSlotGpuType *used_slots) {
-  auto stream =
-      dynamic_cast<platform::CUDADeviceContext *>(
-          paddle::platform::DeviceContextPool::Instance().Get(this->place_))
-          ->stream();
+    const UsedSlotGpuType *used_slots,
+    cudaStream_t stream) {
FillSlotValueOffsetKernel<<<GET_BLOCKS(used_slot_num), CUDA_NUM_THREADS, 0,
stream>>>(
ins_num, used_slot_num, slot_value_offsets, uint64_offsets,
@@ -130,12 +127,8 @@ void SlotRecordInMemoryDataFeed::CopyForTensor(
const int *uint64_offsets, const int *uint64_ins_lens,
const int uint64_slot_size, const float *float_feas,
const int *float_offsets, const int *float_ins_lens,
-    const int float_slot_size, const UsedSlotGpuType *used_slots) {
-  auto stream =
-      dynamic_cast<platform::CUDADeviceContext *>(
-          paddle::platform::DeviceContextPool::Instance().Get(this->place_))
-          ->stream();
-
+    const int float_slot_size, const UsedSlotGpuType *used_slots,
+    cudaStream_t stream) {
CopyForTensorKernel<<<GET_BLOCKS(used_slot_num * ins_num), CUDA_NUM_THREADS,
0, stream>>>(
used_slot_num, ins_num, dest, slot_value_offsets, uint64_feas,
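Note on the two hunks above: both helpers previously fetched their CUDA stream from the global DeviceContextPool inside the function body; they now take the stream as an explicit parameter, so each MiniBatchGpuPack can launch work on its own stream (see get_stream() in data_feed.h below). A minimal standalone sketch of that pattern, independent of Paddle (FillOffsetsKernel and FillOffsets are illustrative names):

#include <cuda_runtime.h>
#include <cstdio>

__global__ void FillOffsetsKernel(int n, size_t* offsets) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) offsets[i] = static_cast<size_t>(i) * 16;  // toy offset math
}

// The caller supplies the stream, so two packs can overlap their launches
// on independent streams instead of serializing on one context stream.
void FillOffsets(int n, size_t* d_offsets, cudaStream_t stream) {
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;
  FillOffsetsKernel<<<blocks, threads, 0, stream>>>(n, d_offsets);
}

int main() {
  const int n = 1024;
  size_t* d_offsets = nullptr;
  cudaMalloc(&d_offsets, n * sizeof(size_t));
  cudaStream_t stream;       // per-pack stream, as MiniBatchGpuPack now owns
  cudaStreamCreate(&stream);
  FillOffsets(n, d_offsets, stream);
  cudaStreamSynchronize(stream);
  cudaStreamDestroy(stream);
  cudaFree(d_offsets);
  return 0;
}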
106 changes: 83 additions & 23 deletions paddle/fluid/framework/data_feed.h
@@ -530,6 +530,9 @@ class MiniBatchGpuPack {
MiniBatchGpuPack(const paddle::platform::Place& place,
const std::vector<UsedSlotInfo>& infos);
~MiniBatchGpuPack();

bool is_use();
void set_use_flag(bool is_use);
void reset(const paddle::platform::Place& place);
void pack_instance(const SlotRecord* ins_vec, int num);
int ins_num() { return ins_num_; }
@@ -559,6 +562,8 @@
}
LoDTensor& float_tensor(void) { return float_tensor_; }
LoDTensor& uint64_tensor(void) { return uint64_tensor_; }
std::vector<LoDTensor>& float_tensor_vec(void) { return float_tensor_vec_; }
std::vector<LoDTensor>& uint64_tensor_vec(void) { return uint64_tensor_vec_; }

HostBuffer<size_t>& offsets(void) { return offsets_; }
HostBuffer<void*>& h_tensor_ptrs(void) { return h_tensor_ptrs_; }
@@ -583,6 +588,10 @@
return batch_ins_[idx]->ins_id_;
}

cudaStream_t get_stream() {
return stream_;
}

private:
void transfer_to_gpu(void);
void pack_all_data(const SlotRecord* ins_vec, int num);
@@ -605,7 +614,9 @@
}

private:
bool is_using_ = false;
paddle::platform::Place place_;
std::unique_ptr<platform::stream::CUDAStream> stream_holder_;
cudaStream_t stream_;
BatchGPUValue value_;
BatchCPUValue buf_;
@@ -624,8 +635,10 @@

// uint64 tensor
LoDTensor uint64_tensor_;
std::vector<LoDTensor> uint64_tensor_vec_;
// float tensor
LoDTensor float_tensor_;
std::vector<LoDTensor> float_tensor_vec_;
// batch
HostBuffer<size_t> offsets_;
HostBuffer<void*> h_tensor_ptrs_;
@@ -638,33 +651,42 @@ class MiniBatchGpuPackMgr {

public:
MiniBatchGpuPackMgr() {
+    pack_list_.resize(MAX_DEIVCE_NUM);
    for (int i = 0; i < MAX_DEIVCE_NUM; ++i) {
-      pack_list_[i] = nullptr;
+      pack_list_[i].clear();
}
}
~MiniBatchGpuPackMgr() {
for (int i = 0; i < MAX_DEIVCE_NUM; ++i) {
-      if (pack_list_[i] == nullptr) {
-        continue;
-      }
-      delete pack_list_[i];
-      pack_list_[i] = nullptr;
+      for (size_t j = 0; j < pack_list_[i].size(); j++) {
+        if (pack_list_[i][j] == nullptr) {
+          continue;
+        }
+        delete pack_list_[i][j];
+        pack_list_[i][j] = nullptr;
+      }
}
}
-  // one device one thread

+  // thread unsafe
MiniBatchGpuPack* get(const paddle::platform::Place& place,
const std::vector<UsedSlotInfo>& infos) {
int device_id = place.GetDeviceId();
-    if (pack_list_[device_id] == nullptr) {
-      pack_list_[device_id] = new MiniBatchGpuPack(place, infos);
-    } else {
-      pack_list_[device_id]->reset(place);
-    }
-    return pack_list_[device_id];
+    for (size_t i = 0; i < pack_list_[device_id].size(); i++) {
+      if (!pack_list_[device_id][i]->is_use()) {
+        pack_list_[device_id][i]->set_use_flag(true);
+        pack_list_[device_id][i]->reset(place);
+        return pack_list_[device_id][i];
+      }
+    }
+    auto* pack = new MiniBatchGpuPack(place, infos);
+    pack->set_use_flag(true);
+    pack_list_[device_id].push_back(pack);
+    return pack;
}

private:
-  MiniBatchGpuPack* pack_list_[MAX_DEIVCE_NUM];
+  std::vector<std::vector<MiniBatchGpuPack*>> pack_list_;
};
// global mgr
inline MiniBatchGpuPackMgr& BatchGpuPackMgr() {
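The manager above replaces the old one-pack-per-device slot with a per-device pool keyed by an in-use flag: get() hands back the first free pack, or grows the pool when every pack is busy. A standalone sketch of the same discipline with simplified stand-in types (Pack and PackPool are illustrative, not Paddle's):

#include <memory>
#include <vector>

struct Pack {
  bool in_use = false;
  void reset() { /* re-prime buffers for the next minibatch */ }
};

class PackPool {
 public:
  explicit PackPool(int num_devices) : pools_(num_devices) {}

  // Mirrors the "// thread unsafe" note in the diff: callers are expected
  // to serialize access per device.
  Pack* Get(int device_id) {
    for (auto& p : pools_[device_id]) {
      if (!p->in_use) {  // reuse the first free pack
        p->in_use = true;
        p->reset();
        return p.get();
      }
    }
    // Every pack is busy: grow the pool, as the rewritten get() does.
    pools_[device_id].push_back(std::make_unique<Pack>());
    pools_[device_id].back()->in_use = true;
    return pools_[device_id].back().get();
  }

  // Assumed counterpart of set_use_flag(false): the consumer clears the
  // flag once the pack's tensors have been handed off.
  void Release(Pack* p) { p->in_use = false; }

 private:
  std::vector<std::vector<std::unique_ptr<Pack>>> pools_;
};

The pool grows on demand and only shrinks on teardown, matching the rewritten ~MiniBatchGpuPackMgr() above.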
@@ -744,6 +766,7 @@ class DLManager {
if (it != handle_map_.end()) {
return it->second.parser;
}

// load so symbol
// Export the libps and core_avx symbols so the parser can share them
const std::vector<std::string> packages {"libps.so", "core_avx.so"};
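The symbol-export trick the comment refers to relies on dlopen's RTLD_GLOBAL flag: a library loaded with it publishes its symbols into the global namespace, so a parser .so loaded afterwards can resolve against them. A minimal sketch of that mechanism (libparser.so is a hypothetical name; link with -ldl on Linux):

#include <dlfcn.h>
#include <cstdio>

int main() {
  const char* packages[] = {"libps.so", "core_avx.so"};
  for (const char* pkg : packages) {
    // RTLD_GLOBAL exports the library's symbols to subsequently loaded
    // shared objects; RTLD_NOW resolves everything up front.
    if (!dlopen(pkg, RTLD_NOW | RTLD_GLOBAL)) {
      std::fprintf(stderr, "load %s failed: %s\n", pkg, dlerror());
    }
  }
  // A parser plugin loaded now can link against libps/core_avx symbols.
  void* parser = dlopen("libparser.so", RTLD_NOW);  // hypothetical plugin
  if (!parser) std::fprintf(stderr, "%s\n", dlerror());
  return 0;
}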
@@ -825,6 +848,10 @@ class DataFeed {
// This function is used for binding feed_vec memory in a given scope
virtual void AssignFeedVar(const Scope& scope);

virtual std::vector<std::string> GetInputVarNames() {
return std::vector<std::string>();
}

// This function will do nothing at default
virtual void SetInputPvChannel(void* channel) {}
// This function will do nothing at default
@@ -871,6 +898,14 @@
}
virtual const paddle::platform::Place& GetPlace() const { return place_; }

#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
virtual MiniBatchGpuPack* get_pack(MiniBatchGpuPack* last_pack) { return nullptr; }
virtual void PackToScope(MiniBatchGpuPack* pack, const Scope* scope) {
PADDLE_THROW(platform::errors::Unimplemented(
"This function(PackToScope) is not implemented."));
}
#endif

protected:
// The following three functions are used to check if it is executed in this
// order:
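The two virtuals added above form the async-infershape handoff: get_pack() takes back the pack the caller has finished with and returns the next pre-built minibatch, and PackToScope() materializes a pack's tensors into a scope. A compilable sketch of the intended consumption loop, with stub types rather than Paddle's (the stub recycles by delete; the real manager presumably clears the pack's use flag instead):

#include <cstdio>

struct MiniBatchGpuPack {};
struct Scope {};

struct DataFeedLike {
  int remaining = 3;  // pretend three minibatches arrive
  MiniBatchGpuPack* get_pack(MiniBatchGpuPack* last_pack) {
    delete last_pack;                      // recycle the consumed pack
    if (remaining-- == 0) return nullptr;  // end of data
    return new MiniBatchGpuPack();         // normally popped from a queue
  }
  void PackToScope(MiniBatchGpuPack*, Scope*) { /* bind tensors */ }
};

int main() {
  DataFeedLike feed;
  Scope scope;
  MiniBatchGpuPack* last = nullptr;
  while (MiniBatchGpuPack* pack = feed.get_pack(last)) {
    feed.PackToScope(pack, &scope);
    std::puts("run ops on this minibatch");  // stand-in for op execution
    last = pack;
  }
  return 0;
}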
@@ -1403,13 +1438,8 @@ class MultiSlotInMemoryDataFeed : public InMemoryDataFeed<Record> {
class SlotRecordInMemoryDataFeed : public InMemoryDataFeed<SlotRecord> {
public:
SlotRecordInMemoryDataFeed() {}
-  virtual ~SlotRecordInMemoryDataFeed() {
-#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
-    if (pack_ != nullptr) {
-      pack_ = nullptr;
-    }
-#endif
-  }
+  virtual ~SlotRecordInMemoryDataFeed();

virtual void Init(const DataFeedDesc& data_feed_desc);
virtual void LoadIntoMemory();
void ExpandSlotRecord(SlotRecord* ins);
Expand All @@ -1433,21 +1463,37 @@ class SlotRecordInMemoryDataFeed : public InMemoryDataFeed<SlotRecord> {
bool ParseOneInstance(const std::string& line, SlotRecord* rec);
virtual void PutToFeedVec(const SlotRecord* ins_vec, int num);
virtual void AssignFeedVar(const Scope& scope);
virtual std::vector<std::string> GetInputVarNames() {
std::vector<std::string> var_names;
for (int i = 0; i < use_slot_size_; ++i) {
var_names.push_back(used_slots_info_[i].slot);
}
return var_names;
}

#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
-  void BuildSlotBatchGPU(const int ins_num);
+  void BuildSlotBatchGPU(const int ins_num, MiniBatchGpuPack* pack);

// async infershape
virtual MiniBatchGpuPack* get_pack(MiniBatchGpuPack* last_pack);
virtual void PackToScope(MiniBatchGpuPack* pack, const Scope* scope = nullptr);

void FillSlotValueOffset(const int ins_num, const int used_slot_num,
size_t* slot_value_offsets,
const int* uint64_offsets,
const int uint64_slot_size, const int* float_offsets,
const int float_slot_size,
-                           const UsedSlotGpuType* used_slots);
+                           const UsedSlotGpuType* used_slots,
+                           cudaStream_t stream);
void CopyForTensor(const int ins_num, const int used_slot_num, void** dest,
const size_t* slot_value_offsets,
const uint64_t* uint64_feas, const int* uint64_offsets,
const int* uint64_ins_lens, const int uint64_slot_size,
const float* float_feas, const int* float_offsets,
const int* float_ins_lens, const int float_slot_size,
-                     const UsedSlotGpuType* used_slots);
+                     const UsedSlotGpuType* used_slots,
+                     cudaStream_t stream);
#endif
float sample_rate_ = 1.0f;
int use_slot_size_ = 0;
@@ -1459,7 +1505,21 @@ class SlotRecordInMemoryDataFeed : public InMemoryDataFeed<SlotRecord> {
std::vector<int> float_total_dims_without_inductives_;

#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_HETERPS)
-  MiniBatchGpuPack* pack_ = nullptr;
+  int pack_thread_num_ {5};
+  std::vector<std::thread> pack_threads_;
+  std::vector<MiniBatchGpuPack*> pack_vec_;
+  BlockingQueue<MiniBatchGpuPack*> free_pack_queue_;
+  BlockingQueue<MiniBatchGpuPack*> using_pack_queue_;
+  std::atomic<bool> pack_is_end_ {false};
+  std::atomic<uint64_t> pack_offset_index_ {0};
+  MiniBatchGpuPack* last_pack_ {nullptr};
+  std::atomic<bool> stop_token_ {false};
+  std::atomic<int> thread_count_ {0};
+  std::mutex pack_mutex_;
+
+  // async infershape
+  std::map<const Scope*, std::vector<LoDTensor*> > scpoe_feed_vec_;

#endif
};

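Taken together, the new members above sketch the feed's pipeline: pack_thread_num_ packer threads claim batch indices by bumping pack_offset_index_, build packs (each owning a stream), and hand them to the trainer through the two BlockingQueues, with stop_token_ and pack_is_end_ controlling shutdown. A standalone analogue of that producer/consumer shape under assumed semantics (unlike the real feed, this toy does not restore batch order on the consumer side):

#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <cstdio>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

template <typename T>
class BlockingQueue {  // stand-in for paddle::framework::BlockingQueue
 public:
  void Push(T v) {
    { std::lock_guard<std::mutex> g(mu_); q_.push(std::move(v)); }
    cv_.notify_one();
  }
  T Pop() {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [&] { return !q_.empty(); });
    T v = std::move(q_.front()); q_.pop(); return v;
  }
 private:
  std::mutex mu_; std::condition_variable cv_; std::queue<T> q_;
};

int main() {
  const int num_batches = 8, num_threads = 3;  // cf. pack_thread_num_ {5}
  std::atomic<uint64_t> next_batch{0};         // cf. pack_offset_index_
  BlockingQueue<int> ready;                    // cf. using_pack_queue_
  std::vector<std::thread> packers;
  for (int t = 0; t < num_threads; ++t) {
    packers.emplace_back([&] {
      for (;;) {
        uint64_t b = next_batch.fetch_add(1);  // claim the next batch index
        if (b >= num_batches) { ready.Push(-1); return; }  // end marker
        ready.Push(static_cast<int>(b));  // "pack" the batch, hand it over
      }
    });
  }
  for (int done = 0; done < num_threads;) {    // consumer: the trainer side
    int b = ready.Pop();
    if (b < 0) { ++done; continue; }
    std::printf("consume pack for batch %d\n", b);
  }
  for (auto& th : packers) th.join();
  return 0;
}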
38 changes: 37 additions & 1 deletion paddle/fluid/framework/device_worker.h
@@ -528,7 +528,7 @@ class HeterCpuWorker : public HogwildWorker {
class PSGPUWorker : public HogwildWorker {
public:
PSGPUWorker() {}
-  virtual ~PSGPUWorker() {}
+  virtual ~PSGPUWorker();
virtual void Initialize(const TrainerDesc& desc);
virtual void TrainFiles();
virtual void TrainFilesWithProfiler();
@@ -542,13 +542,27 @@ class PSGPUWorker : public HogwildWorker {
virtual void SetEvent(const gpuEvent_t event) { event_ = event; }
void ResetStat();

// async infershape
virtual void CreateDeviceResource(const ProgramDesc& main_prog);
virtual void BindingDataFeedMemory();

protected:
void PushGradients();
void CopySparseTable();
void CopyDenseTable();
void CopyDenseVars();
void PrepareCudaGraph();

struct InferShapeCheckData {
std::vector<std::vector<DDim>> pre_dims;
std::vector<std::vector<LoD>> pre_lods;
std::vector<std::vector<DDim>> after_dims;
std::vector<std::vector<LoD>> after_lods;
};

int OpRunAndShapeCheck(OperatorBase& op,
const Scope& scope,
const platform::Place& place);
private:
int mpi_rank_;
std::mutex mutex_;
@@ -618,6 +632,28 @@
double gpu_2_cpu_time_;
double cpu_2_gpu_time_;
uint64_t total_inst_;

// async infershape
int task_threads_num_ {6};
int scope_num_ {task_threads_num_ + 1};
std::atomic<int> thread_count_ {0};
std::atomic<bool> stop_token_ {false};
std::atomic<bool> pack_is_end_ {false};
std::vector<std::thread> task_threads_;
std::vector<Scope*> thread_scope_vec_;
std::map<Scope*, std::vector<Variable*>> need_reuse_var_vec_;
std::vector<Variable*> need_reuse_var_;

struct TaskData {
int ins_num;
Scope* scope;
MiniBatchGpuPack* pack;
};
paddle::framework::BlockingQueue<TaskData> free_task_queue_;
paddle::framework::BlockingQueue<TaskData> using_task_queue_;

static std::atomic<int> shape_check_count_;
static std::atomic<bool> shape_check_flag_;
};
#endif

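PSGPUWorker's new InferShapeCheckData and OpRunAndShapeCheck above suggest a guard for the async path: record each op's output dims and LoDs as produced by the cached infershape, run the op, and verify the shapes did not change. A standalone sketch of that comparison, with stand-in types (DDim and LoD here are simplified aliases, and the single-op layout is an assumption):

#include <cstdio>
#include <vector>

using DDim = std::vector<long>;                // stand-in for framework::DDim
using LoD = std::vector<std::vector<size_t>>;  // stand-in for framework::LoD

struct InferShapeCheckData {  // mirrors the struct in the diff, one op only
  std::vector<DDim> pre_dims;
  std::vector<LoD> pre_lods;
  std::vector<DDim> after_dims;
  std::vector<LoD> after_lods;
};

// Returns 0 when the pre-recorded shapes still match the post-run shapes,
// i.e. reusing the async infershape result for this op was safe.
int ShapeCheck(const InferShapeCheckData& d) {
  if (d.pre_dims != d.after_dims || d.pre_lods != d.after_lods) {
    std::fprintf(stderr, "shape check failed: op output changed shape\n");
    return -1;
  }
  return 0;
}

int main() {
  InferShapeCheckData d;
  d.pre_dims = {{8, 16}};      // dims recorded before running the op
  d.after_dims = {{8, 16}};    // dims observed after running it
  d.pre_lods = {{{0, 4, 8}}};
  d.after_lods = d.pre_lods;   // unchanged, so the check passes
  return ShapeCheck(d);
}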