Skip to content

Commit

Permalink
refine Vectorize and x86 builtin codebase (PaddlePaddle#95)
Browse files Browse the repository at this point in the history
* add Split test to test02

* add vectorize to test01 and add align to buffer
  • Loading branch information
Superjomn authored Mar 20, 2020
1 parent c6c14e7 commit ff82b53
Show file tree
Hide file tree
Showing 22 changed files with 289 additions and 66 deletions.
24 changes: 20 additions & 4 deletions cinn/backends/_x86_builtin_source.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,13 +129,13 @@ struct ExternalVec {

// AVX256 load
//@{
inline __m256 cinn_avx256_load(float* dst) { return _mm256_load_ps(dst); }
inline __m256d cinn_avx256_load(double* dst) { return _mm256_load_pd(dst); }
inline __m256 cinn_avx256_load(const float* dst) { return _mm256_load_ps(dst); }
inline __m256d cinn_avx256_load(const double* dst) { return _mm256_load_pd(dst); }
//@}
// AVX512 load
//@{
inline __m512 cinn_avx512_load(float* dst) { return _mm512_load_ps(dst); }
inline __m512d cinn_avx512_load(double* dst) { return _mm512_load_pd(dst); }
inline __m512 cinn_avx512_load(const float* dst) { return _mm512_load_ps(dst); }
inline __m512d cinn_avx512_load(const double* dst) { return _mm512_load_pd(dst); }
//@}

// FP32x8 * FP32x8
Expand Down Expand Up @@ -313,6 +313,22 @@ inline __m512 cinn_avx512_set1(float value) { return _mm512_set1_ps(value); }
inline __m512d cinn_avx512_set1(double value) { return _mm512_set1_pd(value); }
// @}

//! store
// @{
// Vector store helpers dispatched by overload on the lane type.
// NOTE: these wrap the *aligned* store intrinsics (_mm512_store_p[sd],
// _mm256_store_p[sd]); `dst` must be 64-byte (AVX512) / 32-byte (AVX256)
// aligned, otherwise the store faults at runtime. The buffer `align`
// attribute set elsewhere in this commit is what guarantees this.
inline void cinn_avx512_store(float* dst, const __m512& x) { _mm512_store_ps(dst, x); }
inline void cinn_avx512_store(double* dst, const __m512d& x) { _mm512_store_pd(dst, x); }
inline void cinn_avx256_store(float* dst, const __m256& x) { _mm256_store_ps(dst, x); }
inline void cinn_avx256_store(double* dst, const __m256d& x) { _mm256_store_pd(dst, x); }
// @}

//! add
// @{
// Lane-wise vector addition helpers, dispatched by overload on the vector
// type so generated code can call a single `cinn_avx*_add` name for both
// float and double lanes.
inline __m256 cinn_avx256_add(const __m256& a, const __m256& b) { return _mm256_add_ps(a, b); }
inline __m256d cinn_avx256_add(const __m256d& a, const __m256d& b) { return _mm256_add_pd(a, b); }
inline __m512 cinn_avx512_add(const __m512& a, const __m512& b) { return _mm512_add_ps(a, b); }
inline __m512d cinn_avx512_add(const __m512d& a, const __m512d& b) { return _mm512_add_pd(a, b); }
// @}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/// )END Predefined utilities in CINN
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
4 changes: 2 additions & 2 deletions cinn/backends/codegen_c.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,21 @@ using namespace utils;

void CodeGenC::Compile(const lang::Module &module, const Outputs &outputs) {
if (!outputs.c_header_name.empty()) {
LOG(WARNING) << "Output C source to file " << outputs.c_header_name;
auto source = Compile(module, OutputKind::CHeader);
std::ofstream file(outputs.c_header_name);
CHECK(file.is_open()) << "failed to open file " << outputs.c_header_name;
file << source;
file.close();
LOG(WARNING) << "Output C header to file " << outputs.c_header_name;
}

if (!outputs.c_source_name.empty()) {
LOG(WARNING) << "Output C source to file " << outputs.c_source_name;
auto source = Compile(module, OutputKind::CImpl);
std::ofstream file(outputs.c_source_name);
CHECK(file.is_open()) << "failed to open file " << outputs.c_source_name;
file << source;
file.close();
LOG(WARNING) << "Output C source to file " << outputs.c_source_name;
}
}

Expand Down
26 changes: 12 additions & 14 deletions cinn/backends/codegen_c_x86.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ void CodeGenCX86::Visit(const ir::Mul *op) { VisitBinaryOp(op, op->a, op->b, "mu
void CodeGenCX86::Visit(const ir::Div *op) { VisitBinaryOp(op, op->a, op->b, "div"); }

void CodeGenCX86::Visit(const ir::Load *op) {
LOG(INFO) << "visit load arguemnt";

Expr dense_strided_ramp = detail::StridedRampBase(op->index, 1);
if (dense_strided_ramp.defined()) { // Loading a continuous Ramp address.
CHECK(op->type().is_vector());
Expand Down Expand Up @@ -42,27 +44,23 @@ void CodeGenCX86::Visit(const ir::Store *op) {
int bits = op->type().bits() * op->type().lanes();
if (SupportsAVX512()) {
CHECK_EQ(bits, 512);
os() << "cinn_avx512_store(" << op->tensor.As<ir::_Tensor_>()->name << ", " << op->value << ")";
os() << "cinn_avx512_store(";
PrintAbsAddr(op);
os() << ", ";
Print(op->value);
os() << ")";
} else if (SupportsAVX256()) {
CHECK_EQ(bits, 256);
os() << "cinn_avx256_store(" << op->tensor.As<ir::_Tensor_>()->name << ", " << op->value << ")";
os() << "cinn_avx256_store(";
PrintAbsAddr(op);
os() << ", ";
Print(op->value);
os() << ")";
} else {
CodeGenC::Visit(op);
}
}

void CodeGenCX86::PrintAbsAddr(const ir::Load *op) {
os() << op->tensor.As<ir::_Tensor_>()->name << " + ";

auto *ramp_n = op->index.As<ir::Ramp>();
if (ramp_n) {
CHECK(!ramp_n->base.As<ir::Ramp>()) << "base of a Ramp node should not be Ramp type";
Print(ramp_n->base);
} else {
Print(op->index);
}
}

void CodeGenCX86::PrintVecInputArgument(const Expr *op) {
int bits = op->type().bits() * op->type().lanes();
auto *broadcast_n = op->As<ir::Broadcast>();
Expand Down
15 changes: 14 additions & 1 deletion cinn/backends/codegen_c_x86.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,20 @@ class CodeGenCX86 : public CodeGenC {
void PrintVecInputArgument(const Expr *op);
//! The output argument, such as the destination for Load.
void PrintVecOutputArgument(const Expr *op);
void PrintAbsAddr(const ir::Load *op);

//! Emit the absolute address expression `tensor_name + base_index` for a
//! vectorized Load/Store node \p op. Templated so both ir::Load and
//! ir::Store (which share the `tensor`/`index` fields) can use it.
template <typename Op>
void PrintAbsAddr(const Op *op) {
  os() << op->tensor.template As<ir::_Tensor_>()->name << " + ";

  // For a dense Ramp index (base, base+1, ...), the vector load/store
  // intrinsic only needs the scalar base offset, not the whole Ramp.
  auto *ramp_n = op->index.template As<ir::Ramp>();
  if (ramp_n) {
    CHECK(!ramp_n->base.template As<ir::Ramp>()) << "base of a Ramp node should not be Ramp type";
    Print(ramp_n->base);
  } else {
    // Scalar (non-Ramp) index: print it as-is.
    Print(op->index);
  }
}

template <typename Op>
void VisitBinaryOp(const Op *op, Expr a, Expr b, const std::string &op_repr);
};
Expand Down
2 changes: 2 additions & 0 deletions cinn/cinn.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once
#include "cinn/backends/codegen_c.h"
#include "cinn/backends/codegen_c_x86.h"
#include "cinn/common/common.h"
#include "cinn/lang/builtin.h"
#include "cinn/lang/compute.h"
Expand All @@ -10,6 +11,7 @@
namespace cinn {

using backends::CodeGenC;
using backends::CodeGenCX86;
using backends::Outputs;
using ir::Var;
using lang::Buffer;
Expand Down
51 changes: 43 additions & 8 deletions cinn/common/graph_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <unordered_map>
#include <vector>

#include <algorithm>
#include "cinn/common/object.h"
#include "cinn/common/shared.h"

Expand Down Expand Up @@ -43,6 +44,23 @@ class GraphEdge : public Object {
GraphNode* sink_{};
};

} // namespace common
} // namespace cinn

namespace std {

//! Hash for a Shared<GraphEdge> so edges can be kept in unordered containers.
//! Combines the source and sink node addresses. NOTE: XOR is symmetric, so
//! edges (a, b) and (b, a) hash identically — still correct, just a
//! potential collision.
template <>
struct hash<cinn::common::Shared<cinn::common::GraphEdge>> {
  //! `const` is required: standard containers invoke the hasher through a
  //! const object, so a non-const operator() would fail to compile there.
  size_t operator()(const cinn::common::Shared<cinn::common::GraphEdge>& key) const {
    return reinterpret_cast<size_t>(key->source()) ^ reinterpret_cast<size_t>(key->sink());
  }
};

} // namespace std

namespace cinn {
namespace common {

/**
* @brief The base class of all node of graph.
* This is used to normalize and share the graph operations.
Expand All @@ -55,17 +73,34 @@ class GraphNode : public Object {
//! Links from this to other.
template <typename EdgeT = GraphEdge>
std::tuple<EdgeT*, EdgeT*> LinkTo(GraphNode* other) {
EdgeT *a, *b;
CHECK(other);
CHECK_NE(other, this) << "cannot link to itself";
other->inlinks_.push_back(make_shared<GraphEdge>(other, this));
outlinks_.push_back(make_shared<GraphEdge>(this, other));
return std::make_tuple(static_cast<EdgeT*>(outlinks_.back().get()),
static_cast<EdgeT*>(other->inlinks().back().get()));
auto source_edge = make_shared<GraphEdge>(this, other);
auto sink_edge = make_shared<GraphEdge>(this, other);

outlinks_.insert(source_edge);
other->inlinks_.insert(sink_edge);

for (auto& item : outlinks_) {
if (item->sink()->id() == other->id()) {
a = static_cast<EdgeT*>(item.get());
break;
}
}
for (auto& item : other->inlinks_) {
if (item->sink()->id() == other->id()) {
b = static_cast<EdgeT*>(item.get());
break;
}
}
return std::make_tuple(a, b);
}

//! Get the input links of the node.
virtual std::list<Shared<GraphEdge>> inlinks() const { return inlinks_; }
virtual std::set<Shared<GraphEdge>> inlinks() const { return inlinks_; }
//! Get the output links of the node.
virtual std::list<Shared<GraphEdge>> outlinks() const { return outlinks_; }
virtual std::set<Shared<GraphEdge>> outlinks() const { return outlinks_; }
//! Get a derived pointer.
template <typename Derived>
Derived* As() {
Expand All @@ -90,10 +125,10 @@ class GraphNode : public Object {
protected:
//! The input links of the node.
//! \note We record the raw pointer rather than the shared pointer to avoid cycle reference.
std::list<common::Shared<GraphEdge>> inlinks_;
std::set<common::Shared<GraphEdge>> inlinks_;
//! The output links of the node.
//! \note We record the raw pointer rather than the shared pointer to avoid cycle reference.
std::list<common::Shared<GraphEdge>> outlinks_;
std::set<common::Shared<GraphEdge>> outlinks_;

mutable int visited_time_{};
};
Expand Down
5 changes: 3 additions & 2 deletions cinn/lang/lower.cc
Original file line number Diff line number Diff line change
Expand Up @@ -238,8 +238,9 @@ std::vector<ir::Argument> PrepareArguments(const std::vector<Tensor>& tensors, c
std::vector<ir::LoweredFunc> Lower(const std::string& name, const std::vector<Tensor>& args) {
// make sure the graph's start-points in the args.

auto stages = poly::GatherStagesInTensors(args);
auto graph = poly::CreateGraph(stages);
auto stages = poly::GatherStagesInTensors(args);
auto extra_dependencies = poly::ExtractExtraDependencyFromStages(stages);
auto graph = poly::CreateGraph(stages, extra_dependencies);
LOG(INFO) << "Graph:\n" << graph->Visualize();

// Create a dic for stages and tensors.
Expand Down
6 changes: 3 additions & 3 deletions cinn/lang/tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -235,15 +235,15 @@ Expr _Tensor_::tensor_store_expanded_body() {
}

void _Tensor_::Bind(lang::Buffer &buffer) {
// Extract the tensors that have been bound to this buffer.
buffer_depended_tensor_names_ = buffer.buffer()->binded_tensor_names();

buffer.buffer()->BindTo(this);
CHECK(!buffer->binded_tensor_names().empty());
this->buffer = buffer.buffer();
CHECK(this->buffer.defined());
CHECK(!inlined());

// Extract the tensors that have been bound to this buffer.
buffer_depended_tensor_names_ = this->buffer->binded_tensor_names();

// Reset stage to nullptr to tell others this tensor should be inlined.
InitStage();
}
Expand Down
2 changes: 2 additions & 0 deletions cinn/optim/optimize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

#include "cinn/optim/ir_copy.h"
#include "cinn/optim/ir_simplify.h"
#include "cinn/optim/vectorize_loops.h"

namespace cinn {
namespace optim {

Expr Optimize(Expr e) {
auto copied = IRCopy(e);
Simplify(&copied);
VectorizeLoops(&copied, Target());

return copied;
}
Expand Down
3 changes: 3 additions & 0 deletions cinn/optim/vectorize_loops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,10 @@ struct VectorizeLoops_ : public IRMutator<Expr *> {
CHECK_GT(extent, 0) << "Loop over " << Expr(forloop->loop_var) << " has extent " << forloop->extent
<< ". Can only vectorize loops over a constant extent > 1";

VLOG(2) << "Vectorizing " << forloop->loop_var << " extent " << extent;
VLOG(2) << "body:\n" << node->body;
Vectorizer(forloop->loop_var, extent).Visit(&node->body);
VLOG(2) << "after vectorize body:\n" << node->body;

// Remove the forloop.
*expr = node->body;
Expand Down
3 changes: 3 additions & 0 deletions cinn/poly/ast_gen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,9 @@ void IslAstExprToCinnExpr(const isl::ast_expr& node, ir::Expr* expr) {
case isl_ast_op_eq:
*expr = ir::EQ::Make(ops[0], ops[1]);
break;
case isl_ast_op_pdiv_q:
*expr = ir::Div::Make(ops[0], ops[1]);
break;
case isl_ast_op_call: {
ir::Expr caller_expr = ops.front();
// TODO(Superjomn) make it an string
Expand Down
5 changes: 4 additions & 1 deletion cinn/poly/poly_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,10 @@ std::unique_ptr<Schedule> PolyScheduler::BuildSchedule() {
PolyScheduler::PolyScheduler(const std::vector<Stage*>& stages) {
CHECK_GT(stages.size(), 0) << "No stage is provided";

dfg_ = CreateGraph(stages);
// collect extra links
auto extra_links = ExtractExtraDependencyFromStages(stages);

dfg_ = CreateGraph(stages, extra_links);

for (auto* stage : stages) {
AddStage(*stage);
Expand Down
20 changes: 20 additions & 0 deletions cinn/poly/stage.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ Stage::Stage(const isl::set &domain, Expr expr) : domain_(domain), expr_(expr) {
InitTransform();
}

//! Split the level-th output dimension of the transform by \p factor.
//! Resolves the positional level to its dimension name, then delegates to
//! the name-based Split overload.
std::tuple<Iterator, Iterator> Stage::Split(int level, int factor, SplitRestStrategy strategy) {
  const auto axis = GetDimNames(transform_, isl_dim_out).at(level);
  return Split(axis, factor, strategy);
}

std::tuple<Iterator, Iterator> Stage::Split(const Iterator &level, int factor, SplitRestStrategy strategy) {
int offset = isl_set_find_dim_by_name(transformed_domain().get(), isl_dim_set, level.id.c_str());
CHECK_GE(offset, 0) << "iterator " << level << " not in " << domain_;
Expand Down Expand Up @@ -215,5 +221,19 @@ std::string Stage::ith_dim_name(int level) {
return dims[level];
}

//! Build an Iterator for the level-th dimension, identified by its name.
Iterator Stage::ith_iterator(int level) {
  const auto dim_name = ith_dim_name(level);
  return Iterator(dim_name);
}

//! Collect the extra (non-dataflow) dependency links declared on each stage.
//! Each link is a (source tensor name, dependent stage id) pair, later fed
//! into graph construction as additional edges.
std::vector<std::pair<std::string, std::string>> ExtractExtraDependencyFromStages(const std::vector<Stage *> &stages) {
  std::vector<std::pair<std::string, std::string>> extra_links;
  for (auto *stage : stages) {
    const auto &sink_id = stage->id();  // invariant across the inner loop
    for (const auto &tensor_name : stage->extra_depend_stages()) {
      LOG(INFO) << "extra link " << tensor_name << " -> " << sink_id;
      extra_links.emplace_back(tensor_name, sink_id);
    }
  }
  return extra_links;
}

} // namespace poly
} // namespace cinn
6 changes: 6 additions & 0 deletions cinn/poly/stage.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class Stage : public Object {
Split(const Iterator& level, int factor, SplitRestStrategy strategy = SplitRestStrategy::kAuto);
std::tuple<Iterator, Iterator> //
Split(const std::string& level, int factor, SplitRestStrategy strategy = SplitRestStrategy::kAuto);
std::tuple<Iterator, Iterator> //
Split(int level, int factor, SplitRestStrategy strategy = SplitRestStrategy::kAuto);

/**
* Reorder the iterators.
Expand Down Expand Up @@ -108,6 +110,8 @@ class Stage : public Object {

//! Get the level-th dimensional name.
std::string ith_dim_name(int level);
//! Get the i-th iterator.
Iterator ith_iterator(int level);

//! Get the statements.
std::vector<std::string> input_statements() const;
Expand Down Expand Up @@ -143,6 +147,8 @@ class Stage : public Object {
std::set<std::string> extra_depend_stages_;
};

std::vector<std::pair<std::string, std::string>> ExtractExtraDependencyFromStages(const std::vector<Stage*>& stages);

struct ComputeAtRelation {
Shared<Stage> stage;
int level{-1};
Expand Down
6 changes: 5 additions & 1 deletion cinn/runtime/cinn_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,10 @@ cinn_type_t cinn_float64_t() { return cinn_type_t(cinn_type_float, 64); }

} // extern "C"

struct cinn_buffer_t* cinn_buffer_t::new_(cinn_device_kind_t device, cinn_type_t type, const std::vector<int>& shape) {
struct cinn_buffer_t* cinn_buffer_t::new_(cinn_device_kind_t device,
cinn_type_t type,
const std::vector<int>& shape,
int align) {
int32_t dimensions = shape.size();
cinn_dimension_t* dims = new cinn_dimension_t[dimensions];
memcpy(dims, shape.data(), shape.size() * sizeof(int));
Expand All @@ -93,5 +96,6 @@ struct cinn_buffer_t* cinn_buffer_t::new_(cinn_device_kind_t device, cinn_type_t

x->dims = dims;
x->dimensions = dimensions;
x->align = align;
return x;
}
Loading

0 comments on commit ff82b53

Please sign in to comment.