
Merge branch 'develop' into sort
# Conflicts:
#	cinn/frontend/net_builder_test.cc
#	cinn/hlir/op/contrib/CMakeLists.txt
zrr1999 committed Sep 7, 2022
2 parents bd79ec5 + 68dfadf commit 2ca7fb1
Showing 68 changed files with 2,897 additions and 384 deletions.
5 changes: 3 additions & 2 deletions CMakeLists.txt
@@ -106,7 +106,7 @@ message(STATUS "PYTHON_LIBRARIES: ${PYTHON_LIBRARIES}")
message(STATUS "PYTHON_INCLUDE_DIR: ${PYTHON_INCLUDE_DIR}")

INCLUDE_DIRECTORIES(${PYTHON_INCLUDE_DIR})
cc_library(cinnapi SHARED SRCS ${cinnapi_src} DEPS glog ${llvm_libs} framework_proto param_proto framework_proto absl isl ginac pybind)
cc_library(cinnapi SHARED SRCS ${cinnapi_src} DEPS glog ${llvm_libs} framework_proto param_proto auto_schedule_proto framework_proto absl isl ginac pybind)
add_dependencies(cinnapi GEN_LLVM_RUNTIME_IR_HEADER ZLIB::ZLIB)
add_dependencies(cinnapi GEN_LLVM_RUNTIME_IR_HEADER ${core_deps})

@@ -133,7 +133,7 @@ function(gen_cinncore LINKTYPE)
if (${LINKTYPE} STREQUAL "STATIC")
set(CINNCORE_TARGET cinncore_static)
endif()
cc_library(${CINNCORE_TARGET} ${LINKTYPE} SRCS ${core_src} DEPS glog ${llvm_libs} framework_proto param_proto framework_proto absl isl ginac)
cc_library(${CINNCORE_TARGET} ${LINKTYPE} SRCS ${core_src} DEPS glog ${llvm_libs} framework_proto param_proto auto_schedule_proto framework_proto absl isl ginac)
add_dependencies(${CINNCORE_TARGET} GEN_LLVM_RUNTIME_IR_HEADER ZLIB::ZLIB)
add_dependencies(${CINNCORE_TARGET} GEN_LLVM_RUNTIME_IR_HEADER ${core_deps})

@@ -205,6 +205,7 @@ if (PUBLISH_LIBS)
COMMAND cmake -E copy ${CMAKE_BINARY_DIR}/libcinncore_static.a ${CMAKE_BINARY_DIR}/dist/cinn/lib/libcinncore_static.a
COMMAND cmake -E copy ${CMAKE_BINARY_DIR}/cinn/frontend/paddle/libframework_proto.a ${CMAKE_BINARY_DIR}/dist/cinn/lib/libframework_proto.a
COMMAND cmake -E copy ${CMAKE_BINARY_DIR}/cinn/hlir/pe/libparam_proto.a ${CMAKE_BINARY_DIR}/dist/cinn/lib/libparam_proto.a
COMMAND cmake -E copy ${CMAKE_BINARY_DIR}/cinn/auto_schedule/libauto_schedule_proto.a ${CMAKE_BINARY_DIR}/dist/cinn/lib/libauto_schedule_proto.a
COMMENT "distribute libcinncore_static.a and related header files."
DEPENDS cinncore_static
)
6 changes: 6 additions & 0 deletions cinn/auto_schedule/CMakeLists.txt
@@ -7,8 +7,14 @@ add_subdirectory(search_strategy)
add_subdirectory(task)
add_subdirectory(task_scheduler)

proto_library(auto_schedule_proto SRCS auto_schedule.proto)

core_gather_headers()

gather_srcs(cinnapi_src SRCS auto_tuner.cc)

cc_test(test_auto_tuner SRCS auto_tuner_test.cc DEPS cinncore)

foreach(header ${auto_schedule_proto_HDRS})
set(core_proto_includes "${core_proto_includes};${header}" CACHE INTERNAL "")
endforeach()
22 changes: 22 additions & 0 deletions cinn/auto_schedule/auto_schedule.proto
@@ -0,0 +1,22 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax ="proto3";

package cinn.auto_schedule.proto;

message TuningRecord {
string task_key = 1;
double execution_cost = 2;
}
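
The new proto file defines a single TuningRecord message keyed by task and carrying its measured execution cost. A minimal sketch of how the protoc-generated C++ class could be exercised — the set_/SerializeToString accessors are the standard protobuf-generated API, while the generated header path and the example values are assumptions, not taken from this commit:

#include <string>

#include "cinn/auto_schedule/auto_schedule.pb.h"  // assumed path of the protoc output

int main() {
  cinn::auto_schedule::proto::TuningRecord record;
  record.set_task_key("fn_add_relu_subgraph_0");  // hypothetical task key
  record.set_execution_cost(0.042);               // measured cost, e.g. in milliseconds

  std::string buffer;
  record.SerializeToString(&buffer);  // serialize, e.g. for an on-disk tuning database
  return buffer.empty() ? 1 : 0;
}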
1 change: 1 addition & 0 deletions cinn/auto_schedule/auto_tuner.cc
@@ -15,6 +15,7 @@
#include "cinn/auto_schedule/auto_tuner.h"

#include <glog/logging.h>
#include <pybind11/embed.h>

#include <algorithm>
#include <memory>
63 changes: 43 additions & 20 deletions cinn/auto_schedule/auto_tuner_test.cc
@@ -24,6 +24,9 @@
#include "cinn/frontend/syntax.h"
#include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/ir/ir_base.h"
#include "cinn/runtime/flags.h"

DECLARE_bool(auto_schedule_use_cost_model);

namespace cinn {
namespace auto_schedule {
@@ -88,6 +91,32 @@ class TestAutoTuner : public ::testing::Test {
ASSERT_EQ(2, runtime_program->size());
runtime_program->Execute();
}

void ZeroMeasure() {
// set config and options
AutoTuner::Config tuning_config;
tuning_config.task_schedule_strategy = "round_robin";

TuningOptions tuning_options;
tuning_options.num_measure_trials = 0;
auto result = InitializeAndTune(tuning_config, tuning_options);
BasicCheckResult(result);
ApplyTunedAndRun(result);
}

void NonZeroMeasure() {
// set config and options
AutoTuner::Config tuning_config;
tuning_config.task_schedule_strategy = "round_robin";

TuningOptions tuning_options;
tuning_options.num_measure_trials = 4;
tuning_options.num_samples_per_iteration = 2;

auto result = InitializeAndTune(tuning_config, tuning_options);
BasicCheckResult(result);
ApplyTunedAndRun(result);
}
};

frontend::Program TestAutoTuner::CreateAddReluProgram() {
@@ -101,30 +130,24 @@ frontend::Program TestAutoTuner::CreateAddReluProgram() {
return builder.Build();
}

TEST_F(TestAutoTuner, ZeroMeasure) {
// set config and options
AutoTuner::Config tuning_config;
tuning_config.task_schedule_strategy = "round_robin";

TuningOptions tuning_options;
tuning_options.num_measure_trials = 0;
auto result = InitializeAndTune(tuning_config, tuning_options);
BasicCheckResult(result);
ApplyTunedAndRun(result);
TEST_F(TestAutoTuner, ZeroMeasure_DisableCostModel) {
FLAGS_auto_schedule_use_cost_model = false;
ZeroMeasure();
}

TEST_F(TestAutoTuner, NonZeroMeasure) {
// set config and options
AutoTuner::Config tuning_config;
tuning_config.task_schedule_strategy = "round_robin";
TEST_F(TestAutoTuner, ZeroMeasure_EnableCostModel) {
FLAGS_auto_schedule_use_cost_model = true;
ZeroMeasure();
}

TuningOptions tuning_options;
tuning_options.num_measure_trials = 4;
tuning_options.num_samples_per_iteration = 2;
TEST_F(TestAutoTuner, NonZeroMeasure_DisableCostModel) {
FLAGS_auto_schedule_use_cost_model = false;
NonZeroMeasure();
}

auto result = InitializeAndTune(tuning_config, tuning_options);
BasicCheckResult(result);
ApplyTunedAndRun(result);
TEST_F(TestAutoTuner, NonZeroMeasure_EnableCostModel) {
FLAGS_auto_schedule_use_cost_model = true;
NonZeroMeasure();
}

} // namespace auto_schedule
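
The test refactor moves the shared ZeroMeasure/NonZeroMeasure bodies into the fixture and runs each of them twice, once with the cost model disabled and once enabled, via FLAGS_auto_schedule_use_cost_model. A minimal sketch of the gflags pattern the new tests rely on — the flag name comes from this diff, while the defining file and the default value shown here are assumptions:

// Definition side (assumed to live in cinn/runtime/flags.cc; default value is illustrative):
#include <gflags/gflags.h>
DEFINE_bool(auto_schedule_use_cost_model, false, "Use the cost model during auto-schedule");

// Consumer side, as in auto_tuner_test.cc: declare the flag and flip it per test case.
DECLARE_bool(auto_schedule_use_cost_model);

void RunBothCostModelModes() {
  FLAGS_auto_schedule_use_cost_model = false;  // exercise tuning without the cost model
  // ... run ZeroMeasure()/NonZeroMeasure() ...
  FLAGS_auto_schedule_use_cost_model = true;   // exercise tuning with the cost model
  // ... run them again ...
}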
14 changes: 4 additions & 10 deletions cinn/auto_schedule/cost_model/CMakeLists.txt
@@ -1,13 +1,7 @@
core_gather_headers()

gather_srcs(cinnapi_src SRCS cost_model.cc)
gather_srcs(cinnapi_src SRCS xgb_cost_model.cc expr_cost_model.cc feature.cc feature_extractor.cc)

set(Python_VIRTUALENV FIRST)
find_package(PythonInterp ${PY_VERSION} REQUIRED)
find_package(PythonLibs ${PY_VERSION} REQUIRED)

if (WITH_TESTING)
cc_test(test_cost_model SRCS cost_model_test.cc cost_model.cc DEPS pybind gtest_main)

target_link_libraries(test_cost_model ${PYTHON_LIBRARIES})
endif()
cc_test(test_xgb_cost_model SRCS xgb_cost_model_test.cc DEPS cinncore)
cc_test(test_feature_extractor SRCS feature_extractor_test.cc DEPS cinncore)
cc_test(test_feature SRCS feature_test.cc DEPS cinncore)
77 changes: 77 additions & 0 deletions cinn/auto_schedule/cost_model/expr_cost_model.cc
@@ -0,0 +1,77 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/auto_schedule/cost_model/expr_cost_model.h"

#include <glog/logging.h>

#include <atomic>
#include <vector>

#include "cinn/auto_schedule/cost_model/feature.h"
#include "cinn/auto_schedule/cost_model/feature_extractor.h"
#include "cinn/auto_schedule/search_space/search_state.h"
#include "cinn/common/target.h"
#include "cinn/ir/ir_schedule.h"

namespace cinn {
namespace auto_schedule {

float ExprCostModel::Predict(const ir::ModuleExpr& sample, const common::Target& target) const {
if (trained_times_.load() == 0) {
return SearchState::NOT_INIT_COST;
}
FeatureExtractor extractor;
Feature feature = extractor.Extract(sample, target);
std::vector<float> feature_numbers = feature.ToFixedSizeVector();
std::vector<float> pred = XgbCostModel::Predict({feature_numbers});
return pred[0];
}

void ExprCostModel::Train(const std::vector<const ir::ModuleExpr*>& samples,
const std::vector<float>& labels,
const common::Target& target) {
trained_times_.store(1);
size_t total_size = samples.size();
CHECK_EQ(total_size, labels.size()) << "Samples must have same size as labels";
std::vector<std::vector<float>> train_feature_numbers(total_size);
FeatureExtractor extractor;
for (size_t i = 0; i < total_size; ++i) {
CHECK(samples[i] != nullptr) << "Train samples cannot be nullptr";
Feature feature = extractor.Extract(*samples[i], target);
train_feature_numbers[i] = feature.ToFixedSizeVector();
}

XgbCostModel::Train(train_feature_numbers, labels);
}

void ExprCostModel::Update(const std::vector<const ir::ModuleExpr*>& samples,
const std::vector<float>& labels,
const common::Target& target) {
++trained_times_;
size_t total_size = samples.size();
CHECK_EQ(total_size, labels.size()) << "Samples must have same size as labels";
std::vector<std::vector<float>> train_feature_numbers(total_size);
FeatureExtractor extractor;
for (size_t i = 0; i < total_size; ++i) {
CHECK(samples[i] != nullptr) << "Train samples cannot be nullptr";
Feature feature = extractor.Extract(*samples[i], target);
train_feature_numbers[i] = feature.ToFixedSizeVector();
}

XgbCostModel::Update(train_feature_numbers, labels);
}

} // namespace auto_schedule
} // namespace cinn
45 changes: 45 additions & 0 deletions cinn/auto_schedule/cost_model/expr_cost_model.h
@@ -0,0 +1,45 @@
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <atomic>
#include <vector>

#include "cinn/auto_schedule/cost_model/xgb_cost_model.h"
#include "cinn/ir/ir_schedule.h"

namespace cinn {
namespace auto_schedule {

/**
* A C++ cost model which trains and predicts on ir::Expr
*
*/
class ExprCostModel : public XgbCostModel {
public:
float Predict(const ir::ModuleExpr& sample, const common::Target& target) const;
void Train(const std::vector<const ir::ModuleExpr*>& samples,
const std::vector<float>& labels,
const common::Target& target);
void Update(const std::vector<const ir::ModuleExpr*>& samples,
const std::vector<float>& labels,
const common::Target& target);

private:
std::atomic<int> trained_times_{0};
};

} // namespace auto_schedule
} // namespace cinn
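
Taken together with the .cc above, the header exposes a Train-then-Predict workflow over ir::ModuleExpr. A hedged usage sketch — the function and variable names are illustrative, the label value is made up, and DefaultHostTarget() is assumed to be the usual CINN helper for obtaining the host target:

#include <vector>

#include "cinn/auto_schedule/cost_model/expr_cost_model.h"
#include "cinn/common/target.h"

namespace cinn {
namespace auto_schedule {

// Illustrative only: trains on a single module expression and predicts its cost.
float TrainAndPredict(const ir::ModuleExpr& module_expr) {
  common::Target target = common::DefaultHostTarget();  // assumed helper

  ExprCostModel cost_model;
  // Before the first Train()/Update(), Predict() returns SearchState::NOT_INIT_COST.
  std::vector<const ir::ModuleExpr*> samples = {&module_expr};
  std::vector<float> labels = {0.25f};  // measured execution cost (made-up value)

  cost_model.Train(samples, labels, target);      // delegates to XgbCostModel::Train
  return cost_model.Predict(module_expr, target);
}

}  // namespace auto_schedule
}  // namespace cinn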