Skip to content

Commit

Permalink
【Complex op】add complex support for assign_value (#59536)
Browse files Browse the repository at this point in the history
* support_complex_for_assign_value

* add complex test for test_program_converter

* add complex test for assign_value xpu

* solve conflict

* fix timeout

* fix CE infer bug

* fix program convert bug

* fix program convert bug for assign_value

---------

Co-authored-by: zyt1024 <1522064645@qq.com>
  • Loading branch information
zyt1024 and zyt1024 authored Dec 28, 2023
1 parent 76ce9bb commit c1d7860
Show file tree
Hide file tree
Showing 31 changed files with 614 additions and 123 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/framework/op_version_proto.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace pb {
const std::unordered_map<std::string, uint32_t>& GetLegacyOpVersions() {
static std::unordered_map<std::string, uint32_t> op_versions = {
{"not_equal", 1},
{"assign_value", 0},
{"fake_channel_wise_dequantize_max_abs", 2},
{"yolo_box", 1},
{"data_norm", 1},
Expand Down
85 changes: 84 additions & 1 deletion paddle/fluid/framework/program_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,41 @@ void ConvertSetValueOp(OpDesc* op) {
}
}

void ConvertAssignValueOp(OpDesc* op) {
  // Downgrade a new-style assign_value op (single generic `values` attribute
  // holding Scalars) to the legacy form, which stores the payload in one of
  // several typed attributes (bool_values / fp32_values / fp64_values /
  // int32_values / int64_values) selected by the Scalars' dtype.
  std::vector<paddle::experimental::Scalar> values = PADDLE_GET_CONST(
      std::vector<paddle::experimental::Scalar>, op->GetAttr("values", false));
  op->RemoveAttr("values");
  // Reset every legacy typed attribute to empty first, so readers of the
  // converted program never see a missing attribute.
  op->SetAttr("bool_values", std::vector<int>());
  op->SetAttr("fp32_values", std::vector<float>());
  op->SetAttr("fp64_values", std::vector<double>());
  op->SetAttr("int32_values", std::vector<int>());
  op->SetAttr("int64_values", std::vector<int64_t>());

  // An empty value list carries no dtype information; fall back to FP32,
  // matching the legacy op's default typed attribute.
  phi::DataType dtype = phi::DataType::FLOAT32;
  if (!values.empty()) {
    dtype = values.at(0).dtype();
  }

  switch (dtype) {
    case phi::DataType::BOOL:
      // The legacy op represents booleans as a vector<int>.
      op->SetAttr("bool_values", ExtractPlainVector<int>(values));
      break;
    case phi::DataType::FLOAT32:
      op->SetAttr("fp32_values", ExtractPlainVector<float>(values));
      break;
    case phi::DataType::FLOAT64:
      // BUGFIX: FLOAT64 values were previously narrowed to float and written
      // into fp32_values, silently losing precision. The legacy op declares a
      // dedicated fp64_values attribute, so keep them in double precision.
      op->SetAttr("fp64_values", ExtractPlainVector<double>(values));
      break;
    case phi::DataType::INT32:
      op->SetAttr("int32_values", ExtractPlainVector<int>(values));
      break;
    case phi::DataType::INT64:
      op->SetAttr("int64_values", ExtractPlainVector<int64_t>(values));
      break;
    default:
      PD_THROW("Invalid data type `", dtype, "`.");
  }
}

void ConvertProgram(ProgramDesc* program) {
PADDLE_ENFORCE_NOT_NULL(
program,
Expand Down Expand Up @@ -144,6 +179,9 @@ void ConvertProgram(ProgramDesc* program) {
if (op_type == "set_value" || op_type == "set_value_grad") {
ConvertSetValueOp(op);
}
if (op_type == "assign_value") {
ConvertAssignValueOp(op);
}
}
}
}
Expand Down Expand Up @@ -204,6 +242,45 @@ void ConvertSetValueOp(OpDesc* op) {
op->SetAttr("values", values);
}

void ConvertAssignValueOp(OpDesc* op) {
  // Upgrade a legacy assign_value op (typed attributes bool_values /
  // fp32_values / fp64_values / int32_values / int64_values) to the new form
  // with a single generic `values` attribute of Scalars. At most one typed
  // attribute is expected to be non-empty; the last non-empty one wins.
  VLOG(3) << "convert old assign value op to new";
  std::vector<paddle::experimental::Scalar> values;

  if (op->HasAttr("bool_values")) {
    std::vector<int> bool_values =
        PADDLE_GET_CONST(std::vector<int>, op->GetAttr("bool_values", false));
    if (bool_values.size()) {
      values = WrapAsScalars(bool_values);
    }
    op->RemoveAttr("bool_values");
  }
  if (op->HasAttr("fp32_values")) {
    std::vector<float> fp32_values =
        PADDLE_GET_CONST(std::vector<float>, op->GetAttr("fp32_values", false));
    if (fp32_values.size()) {
      values = WrapAsScalars(fp32_values);
    }
    op->RemoveAttr("fp32_values");
  }
  if (op->HasAttr("fp64_values")) {
    // BUGFIX: fp64_values was previously ignored, so double-precision values
    // in legacy programs were silently dropped (and the stale attribute was
    // left behind on the op).
    std::vector<double> fp64_values = PADDLE_GET_CONST(
        std::vector<double>, op->GetAttr("fp64_values", false));
    if (fp64_values.size()) {
      values = WrapAsScalars(fp64_values);
    }
    op->RemoveAttr("fp64_values");
  }
  if (op->HasAttr("int32_values")) {
    std::vector<int> int32_values =
        PADDLE_GET_CONST(std::vector<int>, op->GetAttr("int32_values", false));
    if (int32_values.size()) {
      values = WrapAsScalars(int32_values);
    }
    op->RemoveAttr("int32_values");
  }
  if (op->HasAttr("int64_values")) {
    std::vector<int64_t> int64_values = PADDLE_GET_CONST(
        std::vector<int64_t>, op->GetAttr("int64_values", false));
    if (int64_values.size()) {
      values = WrapAsScalars(int64_values);
    }
    op->RemoveAttr("int64_values");
  }
  op->SetAttr("values", values);
}

void ConvertProgram(ProgramDesc* program) {
PADDLE_ENFORCE_NOT_NULL(
program,
Expand All @@ -214,6 +291,7 @@ void ConvertProgram(ProgramDesc* program) {
const std::unordered_map<std::string, uint32_t>& legacy_op_versions =
legacy_op_results.second;

VLOG(3) << "is_legacy_program : " << is_legacy_program;
if (!is_legacy_program) return;

VLOG(3) << "Updating Program Version and OpVersionMap";
Expand All @@ -232,10 +310,15 @@ void ConvertProgram(ProgramDesc* program) {
for (size_t j = 0; j < num_ops; j++) {
OpDesc* op = block->Op(static_cast<int>(j));
const std::string op_type = op->Type();

if (op_type == "assign_value") {
VLOG(3) << "Converting program from old to new, op_type=" << op_type;
ConvertAssignValueOp(op);
}
if (!legacy_op_versions.count(op_type)) {
continue;
}

VLOG(3) << "Converting program from old to new, op_type=" << op_type;
if (op_type == "set_value" || op_type == "set_value_grad") {
ConvertSetValueOp(op);
}
Expand Down
17 changes: 9 additions & 8 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -972,19 +972,20 @@ struct AssignValueOpTranscriber : public OpTranscriber {
ctx, phi::Place(phi::AllocationType::UNDEFINED));
attribute_map["place"] = attr_place;

int dtype = paddle::get<int>(op_desc.GetAttr("dtype"));

if (dtype == /*BOOL*/ 0) {
if (op_desc.HasAttr("bool_values")) {
legacy_attr = op_desc.GetAttr("bool_values");
} else if (dtype == /*INT32*/ 2) {
legacy_attr = op_desc.GetAttr("int32_values");
} else if (dtype == /*FP32*/ 5) {
} else if (op_desc.HasAttr("fp32_values")) {
legacy_attr = op_desc.GetAttr("fp32_values");
} else if (dtype == /*INT64*/ 3) {
} else if (op_desc.HasAttr("int32_values")) {
legacy_attr = op_desc.GetAttr("int32_values");
} else if (op_desc.HasAttr("int64_values")) {
legacy_attr = op_desc.GetAttr("int64_values");
} else if (op_desc.HasAttr("values")) {
legacy_attr = op_desc.GetAttr("values");
} else {
IR_THROW(
"Op assign_value should have attribute `**_values` but not find");
"Op assign_value should have attribute `**_values` or `values` but "
"not find");
}

pir::Attribute attr_values = attribute_translator(
Expand Down
26 changes: 2 additions & 24 deletions paddle/fluid/operators/ops_signature/assign_value_sig.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,8 @@ namespace phi {

KernelSignature AssignValueOpArgumentMapping(
const ArgumentMappingContext& ctx) {
// Here we must use `dtype` attr to determine which attr to use, we can't
// judge by whether the attr is empty, some unittests will failed
int dtype = paddle::any_cast<int>(ctx.Attr("dtype"));
// heer we can't depend on the fluid proto::VarType, so we use the dtype enum
// value directly, If the enum value is updated, the code also needs to be
// updated here, but the probability of updating the enum value is very low
if (dtype == /*BOOL*/ 0) {
return KernelSignature(
"assign_value", {}, {"shape", "dtype", "bool_values"}, {"Out"});
} else if (dtype == /*INT32*/ 2) {
return KernelSignature(
"assign_value", {}, {"shape", "dtype", "int32_values"}, {"Out"});
} else if (dtype == /*FP32*/ 5) {
return KernelSignature(
"assign_value", {}, {"shape", "dtype", "fp32_values"}, {"Out"});
} else if (dtype == /*FP64*/ 6) {
return KernelSignature(
"assign_value", {}, {"shape", "dtype", "fp64_values"}, {"Out"});
} else if (dtype == /*INT64*/ 3) {
return KernelSignature(
"assign_value", {}, {"shape", "dtype", "int64_values"}, {"Out"});
} else {
return KernelSignature("unregistered", {}, {}, {});
}
return KernelSignature(
"assign_value", {}, {"shape", "dtype", "values"}, {"Out"});
}

} // namespace phi
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/op_attribute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ phi::Scalar ScalarAttribute::data() {
return phi::Scalar(dyn_cast<pir::BoolAttribute>().data());
} else if (isa<pir::StrAttribute>()) {
return phi::Scalar(dyn_cast<pir::StrAttribute>().AsString());
} else if (isa<pir::Complex64Attribute>()) {
return phi::Scalar(dyn_cast<pir::Complex64Attribute>().data());
} else if (isa<pir::Complex128Attribute>()) {
return phi::Scalar(dyn_cast<pir::Complex128Attribute>().data());
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported ir attribute when casting it into "
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/pir/dialect/operator/ir/op_attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ class ScalarAttribute : public pir::Attribute {
(val.type_id() == pir::Int32Attribute::type_id()) ||
(val.type_id() == pir::IndexAttribute::type_id()) ||
(val.type_id() == pir::Int64Attribute::type_id()) ||
(val.type_id() == pir::StrAttribute::type_id());
(val.type_id() == pir::StrAttribute::type_id()) ||
(val.type_id() == pir::Complex64Attribute::type_id()) ||
(val.type_id() == pir::Complex128Attribute::type_id());
}

static pir::Attribute get(pir::IrContext *ctx, phi::Scalar scalar) {
Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/pir/dialect/operator/utils/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ static inline pir::Attribute TransToIrAttribute(phi::Scalar scalar,
return pir::Int64Attribute::get(ctx, scalar.to<int64_t>());
case phi::DataType::BOOL:
return pir::BoolAttribute::get(ctx, scalar.to<bool>());
case phi::DataType::COMPLEX64:
return pir::Complex64Attribute::get(
ctx, scalar.to<phi::dtype::complex<float>>());
case phi::DataType::COMPLEX128:
return pir::Complex128Attribute::get(
ctx, scalar.to<phi::dtype::complex<double>>());
default:
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported phi data type `%s` when casting it into "
Expand Down
21 changes: 17 additions & 4 deletions paddle/fluid/pybind/op_function_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ bool PyObject_CheckLongOrToLong(PyObject** obj) {
}

if (std::string(((PyTypeObject*)(*obj)->ob_type)->tp_name) // NOLINT
.find("numpy") != std::string::npos) {
.find("numpy.int") != std::string::npos) {
auto to = PyNumber_Long(*obj);
if (to) {
*obj = to;
Expand All @@ -95,8 +95,12 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj) {
(((TensorObject*)(*obj))->tensor.numel() == 1))) { // NOLINT
return true;
}
if (std::string(((PyTypeObject*)(*obj)->ob_type)->tp_name) // NOLINT
.find("numpy") != std::string::npos) {
auto type_name =
std::string(reinterpret_cast<PyTypeObject*>((*obj)->ob_type)->tp_name);
VLOG(4) << "type_name: " << type_name;

if (type_name.find("numpy") != std::string::npos &&
type_name.find("numpy.complex") == std::string::npos) {
auto to = PyNumber_Float(*obj);
if (to) {
*obj = to;
Expand All @@ -107,11 +111,15 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj) {
}

bool PyObject_CheckComplexOrToComplex(PyObject** obj) {
if (PyComplex_Check(*obj) || PyLong_Check(*obj) || PyFloat_Check(*obj) ||
if (PyComplex_Check(*obj) ||
PyObject_TypeCheck(*obj, g_vartype_pytype) || // NOLINT
PyObject_TypeCheck(*obj, p_tensor_type)) { // NOLINT
return true;
}
if (std::string(((PyTypeObject*)(*obj)->ob_type)->tp_name) // NOLINT
.find("numpy.complex") != std::string::npos) {
return true;
}
// consider numpy cfloat & numpy cdouble?
return false;
}
Expand Down Expand Up @@ -242,10 +250,15 @@ double CastPyArg2Double(PyObject* obj,
phi::dtype::complex<float> CastPyArg2Complex(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos) {
PyTypeObject* type = obj->ob_type;
auto type_name = std::string(type->tp_name);
if (PyComplex_Check(obj)) {
double real = PyComplex_RealAsDouble(obj);
double imag = PyComplex_ImagAsDouble(obj);
return phi::dtype::complex<float>(real, imag); // NOLINT
} else if (type_name == "numpy.complex64") {
Py_complex v = PyComplex_AsCComplex(obj);
return phi::dtype::complex<float>(v.real, v.imag);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
Expand Down
15 changes: 15 additions & 0 deletions paddle/phi/api/yaml/op_version.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,21 @@
- delete_attr : atol
comment : The attribute 'atol' is deleted. The reason why it is deleted is that
attributes do not support a float64 value and it is changed to a tensor.
- op : assign_value
version :
- checkpoint : Upgrade assign_value, remove plain attributes in favor of generic attribute.
action :
- add_attr : values
comment : replace generic types with scalar.
default : std::vector<paddle::experimental::Scalar>()
- delete_attr : bool_values
comment : remove plain attributes.
- delete_attr : fp32_values
comment : remove plain attributes.
- delete_attr : int32_values
comment : remove plain attributes.
- delete_attr : int64_values
comment : remove plain attributes.

- op : auc
version :
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/static_ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
backward : assign_grad

- op : assign_value
args : (int[] shape, DataType dtype, int[] bool_values = {}, float[] fp32_values = {}, double[] fp64_values = {}, int[] int32_values = {}, int64_t[] int64_values = {})
args : (int[] shape, DataType dtype, Scalar[] values = {})
output : Tensor(out)
infer_meta :
func : AssignValueInferMeta
Expand Down
12 changes: 9 additions & 3 deletions paddle/phi/kernels/assign_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,9 @@ PD_REGISTER_KERNEL(assign_value,
float,
double,
int8_t,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL_FOR_ALL_DTYPE(assign,
Expand Down Expand Up @@ -165,7 +167,9 @@ PD_REGISTER_KERNEL(assign_value,
float,
double,
int8_t,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#endif

#ifdef PADDLE_WITH_XPU
Expand Down Expand Up @@ -193,5 +197,7 @@ PD_REGISTER_KERNEL(assign_value,
int,
float,
double,
int64_t) {}
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#endif
5 changes: 5 additions & 0 deletions paddle/pir/core/builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <list>

#include "paddle/phi/common/complex.h"
#include "paddle/pir/core/block.h"
#include "paddle/pir/core/ir_context.h"
#include "paddle/pir/core/operation.h"
Expand Down Expand Up @@ -44,6 +45,8 @@ class Int64Attribute;
class ArrayAttribute;
class PointerAttribute;
class TensorNameAttribute;
class Complex64Attribute;
class Complex128Attribute;

using InsertionPoint = std::pair<Block *, Block::Iterator>;
///
Expand Down Expand Up @@ -150,6 +153,8 @@ class Builder {
IR_API ArrayAttribute array_attr(const std::vector<Attribute> &value);
IR_API PointerAttribute pointer_attr(void *value);
IR_API TensorNameAttribute tensor_name_attr(const std::string &value);
IR_API Complex64Attribute complex64_attr(phi::dtype::complex<float> value);
IR_API Complex128Attribute complex128_attr(phi::dtype::complex<double> value);

private:
Operation *Insert(Operation *op);
Expand Down
10 changes: 10 additions & 0 deletions paddle/pir/core/builtin_attribute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ void* PointerAttribute::data() const { return storage()->data(); }

Type TypeAttribute::data() const { return storage()->data(); }

// Returns the complex<float> value held by this attribute's storage.
phi::dtype::complex<float> Complex64Attribute::data() const {
return storage()->data();
}

// Returns the complex<double> value held by this attribute's storage.
phi::dtype::complex<double> Complex128Attribute::data() const {
return storage()->data();
}

// Orders StrAttributes by comparing their underlying storage pointers, i.e.
// an identity-based ordering, not a lexicographic one. Suitable for use as an
// ordered-container key, but not for sorting strings by content.
bool StrAttribute::operator<(const StrAttribute& right) const {
return storage() < right.storage();
}
Expand Down Expand Up @@ -109,3 +117,5 @@ IR_DEFINE_EXPLICIT_TYPE_ID(pir::ArrayAttribute)
IR_DEFINE_EXPLICIT_TYPE_ID(pir::PointerAttribute)
IR_DEFINE_EXPLICIT_TYPE_ID(pir::TypeAttribute)
IR_DEFINE_EXPLICIT_TYPE_ID(pir::TensorNameAttribute)
IR_DEFINE_EXPLICIT_TYPE_ID(pir::Complex64Attribute)
IR_DEFINE_EXPLICIT_TYPE_ID(pir::Complex128Attribute)
Loading

0 comments on commit c1d7860

Please sign in to comment.