diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt
index 09c84cf176c7e..06dfc406dd79e 100644
--- a/paddle/fluid/inference/tests/api/CMakeLists.txt
+++ b/paddle/fluid/inference/tests/api/CMakeLists.txt
@@ -684,16 +684,26 @@ if(ON_INFER OR WITH_GPU)
 endif()
 
 # IPU
-if (WITH_IPU)
-  set(INFERENCE_EXTRA_DEPS paddle_inference_api paddle_inference_io ir_pass_manager analysis_predictor benchmark)
-  # ERNIE from test_analyzer_ernie
+if (WITH_IPU)
+  # word2vec sample
+  set(WORD2VEC_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/word2vec/word2vec.inference.model")
+  inference_analysis_test(ipu_word2vec_sample SRCS ipu_word2vec_sample.cc
+    EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
+    ARGS --infer_model=${WORD2VEC_INSTALL_DIR})
+
+  # ERNIE
   set(ERNIE_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/Ernie")
   inference_analysis_api_test(ipu_ernie_test ${ERNIE_INSTALL_DIR} ipu_ernie_test.cc
     ARGS --warmup=true --repeat=1000)
+  inference_analysis_api_test(ipu_ernie_fp16_test ${ERNIE_INSTALL_DIR} ipu_ernie_fp16_test.cc
+    ARGS --warmup=true --repeat=1000)
 
-  #resnet50
+  # Resnet50
   set(RESNET50_MODEL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/resnet50")
   inference_analysis_test(ipu_resnet50_test SRCS ipu_resnet50_test.cc
     EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
     ARGS --infer_model=${RESNET50_MODEL_DIR} --warmup=true --repeat=1000)
+  inference_analysis_test(ipu_resnet50_fp16_test SRCS ipu_resnet50_fp16_test.cc
+    EXTRA_DEPS ${INFERENCE_EXTRA_DEPS}
+    ARGS --infer_model=${RESNET50_MODEL_DIR} --warmup=true --repeat=1000)
 endif()
diff --git a/paddle/fluid/inference/tests/api/ipu_ernie_fp16_test.cc b/paddle/fluid/inference/tests/api/ipu_ernie_fp16_test.cc
new file mode 100644
index 0000000000000..1f2e38cc342b1
--- /dev/null
+++ b/paddle/fluid/inference/tests/api/ipu_ernie_fp16_test.cc
@@ -0,0 +1,184 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/fluid/inference/tests/api/tester_helper.h" + +namespace paddle { +namespace inference { + +using paddle::PaddleTensor; + +template +void GetValueFromStream(std::stringstream *ss, T *t) { + (*ss) >> (*t); +} + +template <> +void GetValueFromStream(std::stringstream *ss, std::string *t) { + *t = ss->str(); +} + +// Split string to vector +template +void Split(const std::string &line, char sep, std::vector *v) { + std::stringstream ss; + T t; + for (auto c : line) { + if (c != sep) { + ss << c; + } else { + GetValueFromStream(&ss, &t); + v->push_back(std::move(t)); + ss.str({}); + ss.clear(); + } + } + + if (!ss.str().empty()) { + GetValueFromStream(&ss, &t); + v->push_back(std::move(t)); + ss.str({}); + ss.clear(); + } +} + +// Parse tensor from string +template +bool ParseTensor(const std::string &field, paddle::PaddleTensor *tensor) { + std::vector data; + Split(field, ':', &data); + if (data.size() < 2) return false; + + std::string shape_str = data[0]; + + std::vector shape; + Split(shape_str, ' ', &shape); + + std::string mat_str = data[1]; + + std::vector mat; + Split(mat_str, ' ', &mat); + + tensor->shape = shape; + auto size = + std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()) * + sizeof(T); + tensor->data.Resize(size); + std::copy(mat.begin(), mat.end(), static_cast(tensor->data.data())); + tensor->dtype = GetPaddleDType(); + + return true; +} + +// Parse input tensors from string +bool ParseLine(const std::string &line, + std::vector *tensors) { + std::vector fields; + Split(line, ';', &fields); + + tensors->clear(); + tensors->reserve(4); + + int i = 0; + auto input_name = FLAGS_ernie_large ? "eval_placeholder_" : "placeholder_"; + for (; i < 3; i++) { + paddle::PaddleTensor temp; + ParseTensor(fields[i], &temp); + temp.name = input_name + std::to_string(i); + tensors->push_back(temp); + } + + // input_mask + paddle::PaddleTensor input_mask; + ParseTensor(fields[i], &input_mask); + // fp32 to fp16 + ConvertFP32toFP16(input_mask); + input_mask.name = input_name + std::to_string(i); + tensors->push_back(input_mask); + + return true; +} + +bool LoadInputData(std::vector> *inputs, + int batch_size = 1) { + if (FLAGS_infer_data.empty()) { + LOG(ERROR) << "please set input data path"; + return false; + } + + std::ifstream fin(FLAGS_infer_data); + std::string line; + int sample = 0; + + // The unit-test dataset only have 10 samples, each sample have 5 feeds. 
+  while (std::getline(fin, line)) {
+    std::vector<paddle::PaddleTensor> feed_data;
+    ParseLine(line, &feed_data);
+    inputs->push_back(std::move(feed_data));
+    sample++;
+    if (!FLAGS_test_all_data && sample == batch_size) break;
+  }
+  LOG(INFO) << "number of samples: " << sample;
+  return true;
+}
+
+void SetConfig(AnalysisConfig *cfg, int batch_size = 1) {
+  cfg->SetModel(FLAGS_infer_model);
+  // ipu_device_num, ipu_micro_batch_size, ipu_enable_pipelining
+  cfg->EnableIpu(1, batch_size, false);
+  // ipu_enable_fp16, ipu_replica_num, ipu_available_memory_proportion,
+  // ipu_enable_half_partial
+  cfg->SetIpuConfig(true, 1, 1.0, true);
+}
+
+// Compare results
+TEST(Analyzer_Ernie_ipu, compare_results) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
+
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  LoadInputData(&input_slots_all);
+
+  std::ifstream fin(FLAGS_refer_result);
+  std::string line;
+  std::vector<float> ref;
+
+  while (std::getline(fin, line)) {
+    Split(line, ' ', &ref);
+  }
+
+  auto predictor = CreateTestPredictor(
+      reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
+      FLAGS_use_analysis);
+
+  std::vector<PaddleTensor> outputs;
+  for (size_t i = 0; i < input_slots_all.size(); i++) {
+    outputs.clear();
+    predictor->Run(input_slots_all[i], &outputs);
+
+    auto output = outputs.front();
+    ConvertFP16toFP32(output);
+    size_t outputs_size = 1;
+    for (auto dim : output.shape) {
+      outputs_size *= dim;
+    }
+    float *fp32_data = reinterpret_cast<float *>(output.data.data());
+    for (size_t j = 0; j < outputs_size; ++j) {
+      EXPECT_NEAR(ref[i * outputs_size + j], fp32_data[j], 5e-3);
+    }
+  }
+}
+
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tests/api/ipu_ernie_test.cc b/paddle/fluid/inference/tests/api/ipu_ernie_test.cc
index 1438f4705728a..28b964a1c955f 100644
--- a/paddle/fluid/inference/tests/api/ipu_ernie_test.cc
+++ b/paddle/fluid/inference/tests/api/ipu_ernie_test.cc
@@ -131,27 +131,15 @@ bool LoadInputData(std::vector<std::vector<paddle::PaddleTensor>> *inputs,
   return true;
 }
 
-void SetConfig(AnalysisConfig *cfg, bool use_ipu = false, int batch_size = 1) {
+void SetConfig(AnalysisConfig *cfg, int batch_size = 1) {
   cfg->SetModel(FLAGS_infer_model);
-  if (use_ipu) {
-    // num_ipu, enable_pipelining, batches_per_step, batch_size,
-    // need_avg_shard
-    cfg->EnableIpu(4, false, 1, batch_size, true);
-  }
-}
-
-void SetPipelineConfig(AnalysisConfig *cfg, bool use_ipu = false) {
-  cfg->SetModel(FLAGS_infer_model);
-  if (use_ipu) {
-    // num_ipu, enable_pipelining, batches_per_step, batch_size,
-    // need_avg_shard
-    cfg->EnableIpu(4, true, 4, 1, true);
-  }
+  // ipu_device_num, ipu_micro_batch_size, ipu_enable_pipelining
+  cfg->EnableIpu(1, batch_size, false);
 }
 
-void profile(bool use_ipu = false) {
+void profile() {
   AnalysisConfig config;
-  SetConfig(&config, use_ipu);
+  SetConfig(&config);
 
   std::vector<std::vector<PaddleTensor>> outputs;
   std::vector<std::vector<PaddleTensor>> inputs;
@@ -161,12 +149,12 @@
 }
 
 // performance profile
-TEST(Analyzer_Ernie_ipu, performance_profile) { profile(true); }
+TEST(Analyzer_Ernie_ipu, performance_profile) { profile(); }
 
 // Compare Deterministic result
 TEST(Analyzer_Ernie_ipu, compare_determine) {
   AnalysisConfig cfg;
-  SetConfig(&cfg, true);
+  SetConfig(&cfg);
 
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   LoadInputData(&input_slots_all);
@@ -177,7 +165,7 @@ TEST(Analyzer_Ernie_ipu, compare_determine) {
 // Compare results
 TEST(Analyzer_Ernie_ipu, compare_results) {
   AnalysisConfig cfg;
-  SetConfig(&cfg, true);
+  SetConfig(&cfg);
 
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   LoadInputData(&input_slots_all);
@@ -207,38 +195,5 @@ TEST(Analyzer_Ernie_ipu, compare_results) {
   }
 }
 
-// Compare pipeline result
-// TEST(Analyzer_Ernie_pipeline, compare_results) {
-//   AnalysisConfig cfg;
-//   SetPipelineConfig(&cfg, true);
-
-//   std::vector<std::vector<PaddleTensor>> input_slots_all;
-//   LoadInputData(&input_slots_all, 4);
-
-//   std::ifstream fin(FLAGS_refer_result);
-//   std::string line;
-//   std::vector<float> ref;
-
-//   while (std::getline(fin, line)) {
-//     Split(line, ' ', &ref);
-//   }
-
-//   auto predictor = CreateTestPredictor(
-//       reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
-//       FLAGS_use_analysis);
-
-//   std::vector<PaddleTensor> outputs;
-//   for (size_t i = 0; i < input_slots_all.size(); i++) {
-//     outputs.clear();
-//     predictor->Run(input_slots_all[i], &outputs);
-//     auto outputs_size = outputs.front().data.length() / (sizeof(float));
-//     for (size_t j = 0; j < outputs_size; ++j) {
-//       EXPECT_NEAR(ref[i * outputs_size + j],
-//                   static_cast<float *>(outputs[0].data.data())[j],
-//                   FLAGS_accuracy);
-//     }
-//   }
-// }
-
 }  // namespace inference
 }  // namespace paddle
diff --git a/paddle/fluid/inference/tests/api/ipu_resnet50_fp16_test.cc b/paddle/fluid/inference/tests/api/ipu_resnet50_fp16_test.cc
new file mode 100644
index 0000000000000..1c31d51a873a8
--- /dev/null
+++ b/paddle/fluid/inference/tests/api/ipu_resnet50_fp16_test.cc
@@ -0,0 +1,86 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include <string>
+#include <vector>
+
+#include "gflags/gflags.h"
+#include "paddle/fluid/inference/tests/api/tester_helper.h"
+
+namespace paddle {
+namespace inference {
+
+// Compare results with 1 batch
+TEST(Analyzer_Resnet50_ipu, compare_results_1_batch) {
+  std::string model_dir = FLAGS_infer_model + "/" + "model";
+  AnalysisConfig config;
+  // ipu_device_num, ipu_micro_batch_size, ipu_enable_pipelining
+  config.EnableIpu(1, 1, false);
+  // ipu_enable_fp16, ipu_replica_num, ipu_available_memory_proportion,
+  // ipu_enable_half_partial
+  config.SetIpuConfig(true, 1, 1.0, true);
+  config.SetModel(model_dir + "/model", model_dir + "/params");
+
+  std::vector<PaddleTensor> inputs;
+  auto predictor = CreatePaddlePredictor(config);
+  const int batch = 1;
+  const int channel = 3;
+  const int height = 318;
+  const int width = 318;
+  const int input_num = batch * channel * height * width;
+  std::vector<float> input(input_num, 1);
+
+  PaddleTensor in;
+  in.shape = {batch, channel, height, width};
+  in.data =
+      PaddleBuf(static_cast<void *>(input.data()), input_num * sizeof(float));
+  in.dtype = PaddleDType::FLOAT32;
+  ConvertFP32toFP16(in);
+  inputs.emplace_back(in);
+
+  std::vector<PaddleTensor> outputs;
+
+  ASSERT_TRUE(predictor->Run(inputs, &outputs));
+
+  const std::vector<float> truth_values = {
+      127.779f,  738.165f,  1013.22f,  -438.17f,  366.401f,  927.659f,
+      736.222f,  -633.684f, -329.927f, -430.155f, -633.062f, -146.548f,
+      -1324.28f, -1349.36f, -242.675f, 117.448f,  -801.723f, -391.514f,
+      -404.818f, 454.16f,   515.48f,   -133.031f, 69.293f,   590.096f,
+      -1434.69f, -1070.89f, 307.074f,  400.525f,  -316.12f,  -587.125f,
+      -161.056f, 800.363f,  -96.4708f, 748.706f,  868.174f,  -447.938f,
+      112.737f,  1127.2f,   47.4355f,  677.72f,   593.186f,  -336.4f,
+      551.362f,  397.823f,  78.3979f,  -715.398f, 405.969f,  404.256f,
+      246.019f,  -8.42969f, 131.365f,  -648.051f};
+
+  const size_t expected_size = 1;
+  EXPECT_EQ(outputs.size(), expected_size);
+
+  auto output = outputs.front();
+  ConvertFP16toFP32(output);
+  size_t outputs_size = 1;
+  for (auto dim : output.shape) {
+    outputs_size *= dim;
+  }
+  float *fp32_data = reinterpret_cast<float *>(output.data.data());
+
+  for (size_t j = 0; j < outputs_size; j += 10) {
+    EXPECT_NEAR((fp32_data[j] - truth_values[j / 10]) / truth_values[j / 10],
+                0., 9e-2);
+  }
+}
+
+}  // namespace inference
+}  // namespace paddle
\ No newline at end of file
diff --git a/paddle/fluid/inference/tests/api/ipu_resnet50_test.cc b/paddle/fluid/inference/tests/api/ipu_resnet50_test.cc
index f270784c570f6..a8afb449371e0 100644
--- a/paddle/fluid/inference/tests/api/ipu_resnet50_test.cc
+++ b/paddle/fluid/inference/tests/api/ipu_resnet50_test.cc
@@ -26,9 +26,8 @@ namespace inference {
 TEST(Analyzer_Resnet50_ipu, performance_profile) {
   std::string model_dir = FLAGS_infer_model + "/" + "model";
   AnalysisConfig config;
-  // num_ipu, enable_pipelining, batches_per_step, batch_size,
-  // need_avg_shard
-  config.EnableIpu(1, false);
+  // ipu_device_num, ipu_micro_batch_size, ipu_enable_pipelining
+  config.EnableIpu(1, 1, false);
   config.SetModel(model_dir + "/model", model_dir + "/params");
 
   std::vector<PaddleTensor> inputs;
@@ -55,9 +54,8 @@ TEST(Analyzer_Resnet50_ipu, performance_profile) {
 TEST(Analyzer_Resnet50_ipu, compare_results_1_batch) {
   std::string model_dir = FLAGS_infer_model + "/" + "model";
   AnalysisConfig config;
-  // num_ipu, enable_pipelining, batches_per_step, batch_size,
-  // need_avg_shard
-  config.EnableIpu(1, false);
+  // ipu_device_num, ipu_micro_batch_size, ipu_enable_pipelining
+  config.EnableIpu(1, 1, false);
+ "/params"); std::vector inputs; @@ -105,9 +103,8 @@ TEST(Analyzer_Resnet50_ipu, compare_results_1_batch) { TEST(Analyzer_Resnet50_ipu, compare_results_2_batch) { std::string model_dir = FLAGS_infer_model + "/" + "model"; AnalysisConfig config; - // num_ipu, enable_pipelining, batches_per_step, batch_size, - // need_avg_shard - config.EnableIpu(2, false, 1, 2, 1); + // ipu_device_num, ipu_micro_batch_size, ipu_enable_pipelining + config.EnableIpu(1, 2, false); config.SetModel(model_dir + "/model", model_dir + "/params"); std::vector inputs; diff --git a/paddle/fluid/inference/tests/api/ipu_word2vec_sample.cc b/paddle/fluid/inference/tests/api/ipu_word2vec_sample.cc new file mode 100644 index 0000000000000..14efd1e0746cc --- /dev/null +++ b/paddle/fluid/inference/tests/api/ipu_word2vec_sample.cc @@ -0,0 +1,81 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +/* + * This file contains a simple demo for how to take a model for inference with + * IPUs. + * Model: wget -q + * http://paddle-inference-dist.bj.bcebos.com/word2vec.inference.model.tar.gz + */ + +#include +#include +#include +#include + +#include "gflags/gflags.h" +#include "glog/logging.h" +#include "paddle/fluid/inference/api/paddle_inference_api.h" + +DEFINE_string(infer_model, "", "Directory of the inference model."); + +using paddle_infer::Config; +using paddle_infer::Predictor; +using paddle_infer::CreatePredictor; + +void inference(std::string model_path, bool use_ipu, + std::vector *out_data) { + //# 1. Create Predictor with a config. + Config config; + config.SetModel(FLAGS_infer_model); + if (use_ipu) { + // ipu_device_num, ipu_micro_batch_size + config.EnableIpu(1, 4); + } + auto predictor = CreatePredictor(config); + + //# 2. Prepare input/output tensor. + auto input_names = predictor->GetInputNames(); + std::vector data{1, 2, 3, 4}; + // For simplicity, we set all the slots with the same data. + for (auto input_name : input_names) { + auto input_tensor = predictor->GetInputHandle(input_name); + input_tensor->Reshape({4, 1}); + input_tensor->CopyFromCpu(data.data()); + } + + //# 3. Run + predictor->Run(); + + //# 4. Get output. 
+  auto output_names = predictor->GetOutputNames();
+  auto output_tensor = predictor->GetOutputHandle(output_names[0]);
+  std::vector<int> output_shape = output_tensor->shape();
+  int out_num = std::accumulate(output_shape.begin(), output_shape.end(), 1,
+                                std::multiplies<int>());
+  out_data->resize(out_num);
+  output_tensor->CopyToCpu(out_data->data());
+}
+
+int main(int argc, char *argv[]) {
+  ::GFLAGS_NAMESPACE::ParseCommandLineFlags(&argc, &argv, true);
+  std::vector<float> ipu_result;
+  std::vector<float> cpu_result;
+  inference(FLAGS_infer_model, true, &ipu_result);
+  inference(FLAGS_infer_model, false, &cpu_result);
+  for (size_t i = 0; i < ipu_result.size(); i++) {
+    CHECK_NEAR(ipu_result[i], cpu_result[i], 1e-6);
+  }
+  LOG(INFO) << "Finished";
+}
\ No newline at end of file
diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h
index 170b915ec7436..d1fd8234d218b 100644
--- a/paddle/fluid/inference/tests/api/tester_helper.h
+++ b/paddle/fluid/inference/tests/api/tester_helper.h
@@ -77,6 +77,7 @@ namespace paddle {
 namespace inference {
 
 using paddle::framework::proto::VarType;
+using float16 = paddle::platform::float16;
 
 template <typename T>
 constexpr paddle::PaddleDType GetPaddleDType();
@@ -1057,5 +1058,46 @@ static bool CompareTensor(const framework::LoDTensor &a,
   return true;
 }
 
+// Converts an FP32 PaddleTensor to FP16 in place. The converted buffer is
+// handed to PaddleBuf, which does not take ownership of it.
+void ConvertFP32toFP16(paddle::PaddleTensor &tensor) {
+  int num = 1;
+  for (auto dim : tensor.shape) {
+    num *= dim;
+  }
+  PADDLE_ENFORCE_EQ(
+      tensor.dtype, PaddleDType::FLOAT32,
+      platform::errors::InvalidArgument(
+          "The tensor dtype is not float32; only float32 input is supported."));
+  float *fp32_data = reinterpret_cast<float *>(tensor.data.data());
+  float16 *fp16_data =
+      static_cast<float16 *>(std::malloc(num * sizeof(float16)));
+  for (int i = 0; i < num; i++) {
+    fp16_data[i] = float16(fp32_data[i]);
+  }
+  tensor.data =
+      PaddleBuf(static_cast<void *>(fp16_data), num * sizeof(float16));
+  tensor.dtype = PaddleDType::FLOAT16;
+}
+
+// Converts an FP16 PaddleTensor back to FP32 in place.
+void ConvertFP16toFP32(paddle::PaddleTensor &tensor) {
+  int num = 1;
+  for (auto dim : tensor.shape) {
+    num *= dim;
+  }
+  PADDLE_ENFORCE_EQ(
+      tensor.dtype, PaddleDType::FLOAT16,
+      platform::errors::InvalidArgument(
+          "The tensor dtype is not float16; only float16 input is supported."));
+  float16 *fp16_data = reinterpret_cast<float16 *>(tensor.data.data());
+  float *fp32_data = static_cast<float *>(std::malloc(num * sizeof(float)));
+  for (int i = 0; i < num; i++) {
+    fp32_data[i] = static_cast<float>(fp16_data[i]);
+  }
+  tensor.data = PaddleBuf(static_cast<void *>(fp32_data), num * sizeof(float));
+  tensor.dtype = PaddleDType::FLOAT32;
+}
+
 }  // namespace inference
 }  // namespace paddle
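
A quick way to sanity-check the two tester_helper.h additions in isolation is a round-trip unit test along the following lines. This is only a sketch, not part of the patch: the test name and sample values are invented, and it assumes compilation inside paddle/fluid/inference/tests/api (after this patch, so that the float16 alias and the two helpers are visible through tester_helper.h).

// Hypothetical round-trip check for ConvertFP32toFP16/ConvertFP16toFP32.
#include "paddle/fluid/inference/tests/api/tester_helper.h"

namespace paddle {
namespace inference {

TEST(TesterHelper, fp16_roundtrip_sketch) {
  paddle::PaddleTensor tensor;
  tensor.shape = {2, 2};
  // Values chosen to be exactly representable in fp16, so the
  // round trip should reproduce them bit-exactly.
  std::vector<float> src{0.f, 0.5f, -1.f, 2.f};
  tensor.data = PaddleBuf(src.data(), src.size() * sizeof(float));
  tensor.dtype = PaddleDType::FLOAT32;

  ConvertFP32toFP16(tensor);
  EXPECT_EQ(tensor.dtype, PaddleDType::FLOAT16);
  EXPECT_EQ(tensor.data.length(), src.size() * sizeof(float16));

  ConvertFP16toFP32(tensor);
  EXPECT_EQ(tensor.dtype, PaddleDType::FLOAT32);
  const float *out = static_cast<const float *>(tensor.data.data());
  for (size_t i = 0; i < src.size(); ++i) {
    EXPECT_FLOAT_EQ(out[i], src[i]);
  }
}

}  // namespace inference
}  // namespace paddle

Note that, as the helpers are written, each conversion malloc()s a buffer that the PaddleBuf does not own, so a sketch like this leaks a few bytes per call; that mirrors how the tests in the patch use the helpers.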