Skip to content

Commit

Permalink
[cherry-pick]Sparse static graph (#46838)
Browse files Browse the repository at this point in the history
cherry-pick : #46322, #46245
Sparse API: add static-graph (静态图) support
  • Loading branch information
zhangkaihuo authored Oct 17, 2022
1 parent 976af0d commit 10225d2
Show file tree
Hide file tree
Showing 45 changed files with 937 additions and 74 deletions.
5 changes: 3 additions & 2 deletions paddle/fluid/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ cc_test(
cc_library(
var_type_traits
SRCS var_type_traits.cc
DEPS framework_proto scope tensor_array)
DEPS framework_proto scope tensor_array sparse_coo_tensor)
if(WITH_GPU)
target_link_libraries(var_type_traits dynload_cuda)
endif()
Expand Down Expand Up @@ -1138,7 +1138,8 @@ cc_library(
phi
phi_api_utils
op_info
shape_inference)
shape_inference
sparse_coo_tensor)
cc_test(
infershape_utils_test
SRCS infershape_utils_test.cc
Expand Down
12 changes: 10 additions & 2 deletions paddle/fluid/framework/feed_fetch_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ limitations under the License. */

namespace paddle {
namespace framework {
using FeedType = paddle::variant<LoDTensor, Strings>;
using FeedType = paddle::variant<LoDTensor, Strings, phi::SparseCooTensor>;
using FeedList = std::vector<FeedType>;

using FetchType = paddle::variant<LoDTensor, LoDTensorArray, framework::Vocab>;
using FetchType = paddle::
variant<LoDTensor, LoDTensorArray, framework::Vocab, phi::SparseCooTensor>;
using FetchList = std::vector<FetchType>;

using FetchUnmergedList = std::vector<std::vector<FetchType>>;
Expand All @@ -52,6 +53,13 @@ inline bool data_is_string_tensor(const FeedType &data) {
return false;
}

// Returns true when the fetched variable holds a phi::SparseCooTensor
// (rather than one of the other FetchType alternatives such as
// LoDTensor, LoDTensorArray, or Vocab).
inline bool data_is_sparse_coo_tensor(const FetchType &data) {
  // Idiom fix: return the comparison directly instead of
  // `if (cond) return true; return false;`.
  return data.type() == typeid(phi::SparseCooTensor);
}

static const char kFeedOpType[] = "feed";
static const char kFetchOpType[] = "fetch";

Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/framework/framework.proto
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ message VarType {
FEED_LIST = 28;
// The data type of phi::StringTensor
PSTRING = 29;
// the data type of phi::SparseCooTensor
SPARSE_COO = 30;
}

required Type type = 1;
Expand Down Expand Up @@ -186,6 +188,7 @@ message VarType {
optional TensorDesc string = 8;
optional TensorDesc strings = 9;
optional TensorDesc vocab = 10;
optional TensorDesc sparse_coo = 11;
}

message VarDesc {
Expand Down
46 changes: 45 additions & 1 deletion paddle/fluid/framework/infershape_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ class InferShapeArgumentMappingContext : public phi::ArgumentMappingContext {
});
}

// True iff the compile-time var type recorded for input `name` is
// SPARSE_COO (i.e. the input is declared as a sparse COO tensor).
bool IsSparseCooTensorInput(const std::string& name) const override {
  return ctx_.GetInputVarType(name) == proto::VarType::SPARSE_COO;
}

bool IsDenseTensorOutput(const std::string& name) const override {
auto var_types = ctx_.GetOutputsVarType(name);
return std::all_of(var_types.begin(),
Expand Down Expand Up @@ -145,6 +150,26 @@ int64_t CompatMetaTensor::numel() const {
}
}

bool CompatMetaTensor::is_dense() const {
if (is_runtime_) {
auto* var = PADDLE_GET_CONST(Variable*, var_);
return var->IsType<phi::DenseTensor>();
} else {
auto* var = PADDLE_GET_CONST(VarDesc*, var_);
return var->GetType() == proto::VarType::LOD_TENSOR;
}
}

bool CompatMetaTensor::is_tensor_array() const {
if (is_runtime_) {
auto* var = PADDLE_GET_CONST(Variable*, var_);
return var->IsType<framework::LoDTensorArray>();
} else {
auto* var = PADDLE_GET_CONST(VarDesc*, var_);
return var->GetType() == proto::VarType::LOD_TENSOR_ARRAY;
}
}

DDim CompatMetaTensor::dims() const {
ValidCheck(*this);
if (is_runtime_) {
Expand All @@ -153,6 +178,8 @@ DDim CompatMetaTensor::dims() const {
return var->Get<phi::DenseTensor>().dims();
} else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().dims();
} else if (var->IsType<phi::SparseCooTensor>()) {
return var->Get<phi::SparseCooTensor>().dims();
} else if (var->IsType<framework::LoDTensorArray>()) {
// use tensor array size as dims
auto& tensor_array = var->Get<framework::LoDTensorArray>();
Expand All @@ -178,6 +205,8 @@ phi::DataType CompatMetaTensor::dtype() const {
return var->Get<phi::DenseTensor>().dtype();
} else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().dtype();
} else if (var->IsType<phi::SparseCooTensor>()) {
return var->Get<phi::SparseCooTensor>().dtype();
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported get dtype from LoDTensorArray now
Expand All @@ -200,6 +229,8 @@ DataLayout CompatMetaTensor::layout() const {
return var->Get<phi::DenseTensor>().layout();
} else if (var->IsType<phi::SelectedRows>()) {
return var->Get<phi::SelectedRows>().layout();
} else if (var->IsType<phi::SparseCooTensor>()) {
return var->Get<phi::SparseCooTensor>().layout();
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported get layout from LoDTensorArray now
Expand All @@ -226,6 +257,9 @@ void CompatMetaTensor::set_dims(const DDim& dims) {
} else if (var->IsType<phi::SelectedRows>()) {
auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
} else if (var->IsType<phi::SparseCooTensor>()) {
auto* tensor = var->GetMutable<phi::SparseCooTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dims = dims;
} else if (var->IsType<framework::LoDTensorArray>()) {
auto* tensor_array = var->GetMutable<framework::LoDTensorArray>();
// Note: Here I want enforce `tensor_array->size() == 0UL`, because
Expand Down Expand Up @@ -257,6 +291,9 @@ void CompatMetaTensor::set_dtype(phi::DataType dtype) {
} else if (var->IsType<phi::SelectedRows>()) {
auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype;
} else if (var->IsType<phi::SparseCooTensor>()) {
auto* tensor = var->GetMutable<phi::SparseCooTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->dtype = dtype;
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported set dtype for LoDTensorArray now
Expand All @@ -280,6 +317,9 @@ void CompatMetaTensor::set_layout(DataLayout layout) {
} else if (var->IsType<phi::SelectedRows>()) {
auto* tensor = var->GetMutable<phi::SelectedRows>()->mutable_value();
phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
} else if (var->IsType<phi::SparseCooTensor>()) {
auto* tensor = var->GetMutable<phi::SparseCooTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->layout = layout;
} else if (var->IsType<framework::LoDTensorArray>()) {
// NOTE(chenweihang): do nothing
// Unsupported set dtype for LoDTensorArray now
Expand All @@ -299,7 +339,7 @@ void CompatMetaTensor::share_lod(const MetaTensor& meta_tensor) {
ValidCheck(meta_tensor);
if (is_runtime_) {
auto* var = PADDLE_GET(Variable*, var_);
if (var->IsType<phi::DenseTensor>()) {
if (var->IsType<phi::DenseTensor>() && meta_tensor.is_dense()) {
auto* tensor = var->GetMutable<phi::DenseTensor>();
phi::DenseTensorUtils::GetMutableMeta(tensor)->lod =
static_cast<const CompatMetaTensor&>(meta_tensor).GetRuntimeLoD();
Expand All @@ -309,6 +349,10 @@ void CompatMetaTensor::share_lod(const MetaTensor& meta_tensor) {
}
} else {
auto* var = PADDLE_GET(VarDesc*, var_);
if (!meta_tensor.is_dense() && !meta_tensor.is_tensor_array()) {
VLOG(3) << "input metatensor is not LoDTensor or LoDTensorArray.";
return;
}
var->SetLoDLevel(
static_cast<const CompatMetaTensor&>(meta_tensor).GetCompileTimeLoD());
}
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/framework/infershape_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ class CompatMetaTensor : public phi::MetaTensor {

bool initialized() const override { return initialized_; };

bool is_tensor_array() const;
bool is_dense() const;

operator unspecified_bool_type() const override {
return initialized_ ? unspecified_bool_true : 0;
}
Expand Down
40 changes: 40 additions & 0 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2382,6 +2382,17 @@ void OperatorWithKernel::ParseInputDataType(
t = &var->Get<LoDTensor>();
} else if (var->IsType<phi::SelectedRows>()) {
t = &(var->Get<phi::SelectedRows>().value());
} else if (var->IsType<phi::SparseCooTensor>()) {
const phi::SparseCooTensor* sp_t = &(var->Get<phi::SparseCooTensor>());
PADDLE_ENFORCE_EQ(
sp_t->initialized(),
true,
platform::errors::InvalidArgument("The %s Op's Input Variable `%s` "
"contains uninitialized Tensor.",
Type(),
name));
*data_type = paddle::framework::TransToProtoVarType(sp_t->dtype());
return;
} else if (var->IsType<LoDTensorArray>()) {
auto t_arr = &var->Get<LoDTensorArray>();
for (size_t j = 0; j < t_arr->size(); j++) {
Expand Down Expand Up @@ -2419,6 +2430,29 @@ void OperatorWithKernel::ParseMultiInputDataType(
t = &var->Get<LoDTensor>();
} else if (var->IsType<phi::SelectedRows>()) {
t = &(var->Get<phi::SelectedRows>().value());
} else if (var->IsType<phi::SparseCooTensor>()) {
const phi::SparseCooTensor* sp_t = &(var->Get<phi::SparseCooTensor>());
PADDLE_ENFORCE_EQ(
sp_t->initialized(),
true,
platform::errors::InvalidArgument("The %s Op's Input Variable `%s` "
"contains uninitialized Tensor.",
Type(),
name));
proto::VarType::Type tmp =
paddle::framework::TransToProtoVarType(sp_t->dtype());
PADDLE_ENFORCE(tmp == *data_type || *data_type == default_data_type,
platform::errors::InvalidArgument(
"The DataType of %s Op's duplicable or different "
"slot Variable %s must be "
"consistent or reigster GetExpectedKernelType. The "
"current variable type is (%s), but the "
"previous variable type is (%s).",
Type(),
name,
DataTypeToString(tmp),
DataTypeToString(*data_type)));
*data_type = tmp;
} else if (var->IsType<LoDTensorArray>()) {
auto t_arr = &var->Get<LoDTensorArray>();
for (size_t j = 0; j < t_arr->size(); j++) {
Expand Down Expand Up @@ -2663,6 +2697,9 @@ void OperatorWithKernel::BuildPhiKernelContext(
} else if (var->IsType<phi::SelectedRows>()) {
tensor_in = &(var->Get<phi::SelectedRows>());
phi_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
} else if (var->IsType<phi::SparseCooTensor>()) {
tensor_in = &(var->Get<phi::SparseCooTensor>());
phi_kernel_context->EmplaceBackInputWithoutSetRange(tensor_in);
} else if (var->IsType<framework::LoDTensorArray>()) {
need_prepare_phi_data_ = true;
tensor_in = &(var->Get<framework::LoDTensorArray>());
Expand Down Expand Up @@ -2708,6 +2745,9 @@ void OperatorWithKernel::BuildPhiKernelContext(
} else if (var->template IsType<phi::SelectedRows>()) {
tensor_out = var->template GetMutable<phi::SelectedRows>();
phi_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
} else if (var->template IsType<phi::SparseCooTensor>()) {
tensor_out = var->template GetMutable<phi::SparseCooTensor>();
phi_kernel_context->EmplaceBackOutputWithoutSetRange(tensor_out);
} else if (var->template IsType<framework::LoDTensorArray>()) {
tensor_out = var->template GetMutable<framework::LoDTensorArray>();
// Note: If the input LoDTensorArray size is 0, the output
Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,11 @@ class ExecutionArgumentMappingContext : public phi::ArgumentMappingContext {
});
}

// True iff the runtime variable bound to input `name` currently holds
// a phi::SparseCooTensor.
bool IsSparseCooTensorInput(const std::string& name) const override {
  return ctx_.InputVar(name)->IsType<phi::SparseCooTensor>();
}

bool IsDenseTensorOutput(const std::string& name) const override {
auto vars = ctx_.MultiOutputVar(name);
return std::all_of(vars.begin(), vars.end(), [](const Variable* var) {
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"

namespace paddle {
namespace framework {
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/var_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ const proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
return desc_.type().strings();
case proto::VarType::VOCAB:
return desc_.type().vocab();
case proto::VarType::SPARSE_COO:
return desc_.type().sparse_coo();
default:
PADDLE_THROW(platform::errors::Unavailable(
"Getting 'tensor_desc' is not supported by the %s type variable.",
Expand Down Expand Up @@ -284,6 +286,8 @@ proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
return desc_.mutable_type()->mutable_strings();
case proto::VarType::VOCAB:
return desc_.mutable_type()->mutable_vocab();
case proto::VarType::SPARSE_COO:
return desc_.mutable_type()->mutable_sparse_coo();
default:
PADDLE_THROW(
platform::errors::Unavailable("Getting 'mutable_tensor_desc' is not "
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/var_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ inline proto::VarType::Type ToVarType(int type) {
switch (type) {
case proto::VarType::LOD_TENSOR:
case proto::VarType::SELECTED_ROWS:
case proto::VarType::SPARSE_COO:
case proto::VarType::LOD_RANK_TABLE:
case proto::VarType::LOD_TENSOR_ARRAY:
case proto::VarType::FETCH_LIST:
Expand All @@ -59,6 +60,9 @@ inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
case proto::VarType::SELECTED_ROWS:
visitor(var.Get<phi::SelectedRows>());
return;
case proto::VarType::SPARSE_COO:
visitor(var.Get<phi::SparseCooTensor>());
return;
case proto::VarType::READER:
visitor(var.Get<ReaderHolder>());
return;
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/framework/var_type_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
namespace phi {
class DenseTensor;
class SelectedRows;
class SparseCooTensor;
} // namespace phi

// Users should add forward declarations here
Expand Down Expand Up @@ -180,6 +181,7 @@ struct VarTypeRegistryImpl {
using VarTypeRegistry = detail::VarTypeRegistryImpl<
Tensor,
phi::SelectedRows,
phi::SparseCooTensor,
std::vector<Scope *>,
LoDRankTable,
Strings,
Expand Down Expand Up @@ -252,6 +254,7 @@ REG_PROTO_VAR_TYPE_TRAIT(float, proto::VarType::FP32);
REG_PROTO_VAR_TYPE_TRAIT(Vocab, proto::VarType::VOCAB);
REG_PROTO_VAR_TYPE_TRAIT(String, proto::VarType::STRING);
REG_PROTO_VAR_TYPE_TRAIT(Strings, proto::VarType::STRINGS);
REG_PROTO_VAR_TYPE_TRAIT(phi::SparseCooTensor, proto::VarType::SPARSE_COO);

/** End of variable type registration */

Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/variable_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ void InitializeVariable(Variable *var, proto::VarType::Type var_type) {
var->GetMutable<ReaderHolder>();
} else if (var_type == proto::VarType::RAW) {
// GetMutable will be called in operator
} else if (var_type == proto::VarType::SPARSE_COO) {
var->GetMutable<phi::SparseCooTensor>();
} else {
PADDLE_THROW(platform::errors::Unavailable(
"Variable type %d is not in "
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ bool PluginArgumentMappingContext::IsSelectedRowsInput(
const std::string& name) const {
return false;
}
// TensorRT plugin inputs are always dense tensors, so a sparse COO
// input can never occur here; report false unconditionally (the `name`
// argument is intentionally unused, matching the sibling overrides).
bool PluginArgumentMappingContext::IsSparseCooTensorInput(
    const std::string& name) const {
  return false;
}
bool PluginArgumentMappingContext::IsDenseTensorVectorInput(
const std::string& name) const {
return false;
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/inference/tensorrt/plugin_arg_mapping_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class PluginArgumentMappingContext : public ::phi::ArgumentMappingContext {

bool IsSelectedRowsInput(const std::string& name) const override;

bool IsSparseCooTensorInput(const std::string& name) const override;

bool IsDenseTensorVectorInput(const std::string& name) const override;

bool IsDenseTensorOutput(const std::string& name) const override;
Expand Down
17 changes: 17 additions & 0 deletions paddle/fluid/operators/controlflow/feed_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/core/tensor_utils.h"

namespace paddle {
namespace framework {
Expand Down Expand Up @@ -61,6 +62,22 @@ class FeedVariableVisitor {
*out_str = in_str;
}

// Feed a sparse COO tensor into the output variable. If the input
// already lives on the target place, share it by assignment; otherwise
// copy its indices and values tensors to `place_` and rebuild the
// sparse tensor from the copies, reusing the input's meta.
void operator()(const phi::SparseCooTensor &in_tensor) const {
  phi::SparseCooTensor *out_tensor =
      out_var_->GetMutable<phi::SparseCooTensor>();
  if (platform::is_same_place(in_tensor.place(), place_)) {
    // Same device: assignment shares the underlying storage — no copy.
    *out_tensor = in_tensor;
  } else {
    platform::DeviceContext *context =
        platform::DeviceContextPool::Instance().Get(place_);

    phi::DenseTensor indices, values;
    // Copy the two component tensors across devices on the target
    // place's context. NOTE(review): TensorCopy with a DeviceContext is
    // presumably asynchronous — confirm downstream consumers synchronize
    // before reading the fed tensor.
    framework::TensorCopy(in_tensor.indices(), place_, *context, &indices);
    framework::TensorCopy(in_tensor.values(), place_, *context, &values);
    // Reuse the source meta (dims/dtype/layout) for the rebuilt tensor.
    out_tensor->SetMember(indices, values, in_tensor.meta());
  }
}

private:
framework::Variable *out_var_;
const platform::Place &place_;
Expand Down
Loading

0 comments on commit 10225d2

Please sign in to comment.