-
Notifications
You must be signed in to change notification settings - Fork 771
/
TNNBlazeFaceDetectorViewModel.mm
109 lines (93 loc) · 4.62 KB
/
TNNBlazeFaceDetectorViewModel.mm
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
// Tencent is pleased to support the open source community by making TNN available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#import "TNNBlazeFaceDetectorViewModel.h"
#import "blazeface_detector.h"
using namespace std;
@implementation TNNBlazeFaceDetectorViewModel

/// Loads the BlazeFace model, proto and anchor files from the main bundle and
/// initializes a BlazeFaceDetector with them.
/// @param units Compute units the detector is configured to run on.
/// @return TNN_OK on success; otherwise the status describing the failing step.
///         self.predictor is only replaced when the detector initialized successfully.
-(Status)loadNeuralNetworkModel:(TNNComputeUnits)units {
    Status status = TNN_OK;

    // Check that the Release configuration is selected at Product -> Scheme when running.
    // Get the metallib path from the app bundle.
    // PS: a script (Build Phases -> Run Script) copies tnn.metallib from the TNN framework
    // project into the TNNExamples app.
    NSBundle *bundle = [NSBundle mainBundle];
    NSString *library_path = [bundle pathForResource:@"tnn.metallib" ofType:nil];
    NSString *model_path = [bundle pathForResource:@"model/blazeface/blazeface.tnnmodel"
                                            ofType:nil];
    NSString *proto_path = [bundle pathForResource:@"model/blazeface/blazeface.tnnproto"
                                            ofType:nil];
    NSString *anchor_path = [bundle pathForResource:@"model/blazeface/blazeface_anchors.txt"
                                             ofType:nil];
    if (model_path.length == 0 || proto_path.length == 0 || anchor_path.length == 0) {
        status = Status(TNNERR_NET_ERR, "Error: proto or model or anchor path is invalid");
        NSLog(@"Error: proto or model or anchor path is invalid");
        return status;
    }

    NSString *protoFormat = [NSString stringWithContentsOfFile:proto_path
                                                      encoding:NSUTF8StringEncoding
                                                         error:nil];
    // protoFormat is nil when the file cannot be read as UTF-8, and UTF8String then yields
    // NULL. Building a std::string from NULL is undefined behavior, so use an empty string
    // and let the emptiness check below report the failure.
    const char *proto_cstr = protoFormat.UTF8String;
    std::string proto_content = proto_cstr ? proto_cstr : "";

    NSData *data = [NSData dataWithContentsOfFile:model_path];
    std::string model_content =
        data.length > 0 ? std::string((const char *)data.bytes, data.length) : "";
    if (proto_content.empty() || model_content.empty()) {
        status = Status(TNNERR_NET_ERR, "Error: proto or model path is invalid");
        NSLog(@"Error: proto or model path is invalid");
        return status;
    }

    // library_path is nil when tnn.metallib is not bundled; avoid constructing a std::string
    // from NULL. NOTE(review): presumably Init rejects an empty path when Metal is required
    // for the requested units — confirm.
    const char *library_cstr = library_path.UTF8String;

    auto option = std::make_shared<BlazeFaceDetectorOption>();
    {
        option->proto_content = proto_content;
        option->model_content = model_content;
        option->library_path = library_cstr ? library_cstr : "";
        option->compute_units = units;
        option->cache_path = NSTemporaryDirectory().UTF8String;

        option->input_width = 128;
        option->input_height = 128;
        // min_score_thresh
        option->min_score_threshold = 0.75;
        // min_suppression_thresh
        option->min_suppression_threshold = 0.3;
        // predefined anchor file path
        option->anchor_path = std::string(anchor_path.UTF8String);
    }

    auto predictor = std::make_shared<BlazeFaceDetector>();
    status = predictor->Init(option);
    if (status != TNN_OK) {
        // Do not publish a detector that failed to initialize.
        NSLog(@"Error: BlazeFaceDetector init failed: %s", status.description().c_str());
        return status;
    }

    BenchOption bench_option;
    bench_option.forward_count = 1;
    predictor->SetBenchOption(bench_option);

    // For thread safety, assign to the member property only after initialization has
    // fully succeeded.
    self.predictor = predictor;
    return status;
}

/// Converts a detector output into a list of generic object infos.
/// @param sdk_output Output produced by the detector; may be null.
/// @return One BlazeFaceInfo copy per entry of face_list, or an empty list when
///         sdk_output is null or is not a BlazeFaceDetectorOutput.
-(std::vector<std::shared_ptr<ObjectInfo> >)getObjectList:(std::shared_ptr<TNNSDKOutput>)sdk_output {
    std::vector<std::shared_ptr<ObjectInfo> > object_list;
    // dynamic_cast of a null pointer yields null, so this also covers an empty sdk_output.
    auto face_output = dynamic_cast<BlazeFaceDetectorOutput *>(sdk_output.get());
    if (!face_output) {
        return object_list;
    }
    for (const auto &item : face_output->face_list) {
        auto face = std::make_shared<BlazeFaceInfo>();
        *face = item;
        object_list.push_back(face);
    }
    return object_list;
}

/// Returns the object's score formatted with two decimals, or nil for a null object.
-(NSString*)labelForObject:(std::shared_ptr<ObjectInfo>)object {
    if (object) {
        return [NSString stringWithFormat:@"%.2f", object->score];
    }
    return nil;
}

@end