Skip to content

Commit

Permalink
[CINN] Strong constraint branch (PaddlePaddle#58719)
Browse files Browse the repository at this point in the history
* Strong Constraint Branch

* Change UpdateOpLoweredFuncKey location (PaddlePaddle#86)

* Remove useless parameter (PaddlePaddle#87)

* Change codes according to comments (PaddlePaddle#89)

* Delete useless code (PaddlePaddle#91)
  • Loading branch information
jiahy0825 authored Nov 13, 2023
1 parent 3a7fd89 commit 5511ae0
Show file tree
Hide file tree
Showing 110 changed files with 11,251 additions and 55 deletions.
1 change: 1 addition & 0 deletions paddle/cinn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ if(WITH_TESTING)
cinn_cc_library(cinn_gtest_main SRCS gtest_main.cc DEPS gtest ${flags_dep})
endif()

add_subdirectory(adt)
add_subdirectory(api)
add_subdirectory(ast_gen_ius)
add_subdirectory(auto_schedule)
Expand Down
35 changes: 35 additions & 0 deletions paddle/cinn/adt/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
add_subdirectory(print_utils)

core_gather_headers()

gather_srcs(
cinnapi_src
SRCS
anchor_sd_equation_context.cc
equation_function.cc
equation_solver.cc
equation_value.cc
generate_map_expr.cc
get_sub_reshape_dim_ranges.cc
igroup.cc
index_expr_infer_context.cc
kgroup.cc
m_ir.cc
naive_bidirection_equation_generator.cc
naive_op_equation_context.cc
partition_op_stmts.cc
schedule_descriptor.cc
schedule_dim.cc
schedule_mesh.cc
simplify_value.cc
write_broadcast_disabled_bidirection_equation_generator.cc)

cinn_cc_test(equation_value_match_trait_test SRCS
equation_value_match_trait_test.cc DEPS gtest glog)

cinn_cc_test(tree_test SRCS tree_test.cc DEPS gtest glog)

cinn_cc_test(inline_translator_test SRCS inline_translator_test.cc DEPS
cinncore)

message(STATUS "ADT srcs: ${cinnapi_src}")
71 changes: 71 additions & 0 deletions paddle/cinn/adt/adapter_tensor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "glog/logging.h"

#include "paddle/cinn/adt/adt.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/node.h"

namespace cinn::adt::adapter {

struct Tensor final {
const hlir::framework::NodeData* node_data;
const hlir::framework::Graph* graph;

bool operator==(const Tensor& other) const {
return this->node_data == other.node_data && this->graph == other.graph;
}

std::size_t GetRank() const {
const auto& shape_dict =
graph->GetAttrs<absl::flat_hash_map<std::string, utils::ShapeType>>(
"infershape");
CHECK(shape_dict.count(node_data->id()))
<< "Can't find " << node_data->id() << " 's shape!";
return shape_dict.at(node_data->id()).size();
}

const std::vector<int32_t>& GetShape() const {
const auto& shape_dict =
graph->GetAttrs<absl::flat_hash_map<std::string, utils::ShapeType>>(
"infershape");
CHECK(shape_dict.count(node_data->id()))
<< "Can't find " << node_data->id() << " 's shape!";
return shape_dict.at(node_data->id());
}

std::size_t GetNumel() const {
const auto& shape_dict =
graph->GetAttrs<absl::flat_hash_map<std::string, utils::ShapeType>>(
"infershape");
CHECK(shape_dict.count(node_data->id()))
<< "Can't find " << node_data->id() << " 's shape!";
std::vector<int32_t> shape = shape_dict.at(node_data->id());
std::size_t ret = 1;
for (int32_t dim_size : shape) {
ret = ret * dim_size;
}
return ret;
}
};

inline std::size_t GetHashValueImpl(const Tensor& tensor) {
return hash_combine(
std::hash<const hlir::framework::NodeData*>()(tensor.node_data),
std::hash<const hlir::framework::Graph*>()(tensor.graph));
}

} // namespace cinn::adt::adapter
Loading

0 comments on commit 5511ae0

Please sign in to comment.