-
Notifications
You must be signed in to change notification settings - Fork 0
/
SegmentationAccuracy.m
60 lines (50 loc) · 1.72 KB
/
SegmentationAccuracy.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
classdef SegmentationAccuracy < dagnn.Loss
%SEGMENTATIONACCURACY  Running segmentation accuracy statistics.
%   This layer accumulates a confusion matrix (rows: ground-truth label,
%   columns: predicted label) over successive calls to FORWARD and derives
%   three figures from it:
%
%     pixelAccuracy          sum of the diagonal over the total count
%     meanAccuracy           mean over classes of diagonal / row sum
%     meanIntersectionUnion  mean over classes of
%                            diagonal / (row sum + column sum - diagonal)
%
%   The number of classes is taken from the third dimension of the
%   prediction input, so the layer is not tied to a fixed label set.
%
%   The layer produces no derivatives; BACKWARD returns empty arrays.
%
%   NOTE(review): `average`, `numAveraged` and `load` are not defined in
%   this file; they are presumably inherited from dagnn.Loss / dagnn.Layer.

  properties (Transient)
    pixelAccuracy = 0
    meanAccuracy = 0
    meanIntersectionUnion = 0
    confusion = 0   % accumulated confusion matrix (scalar 0 until first forward)
  end

  methods
    function outputs = forward(obj, inputs, params)
    %FORWARD  Update the confusion matrix and the derived statistics.
    %   inputs{1} holds the per-class scores (the class index runs along
    %   dimension 3) and inputs{2} holds the labels. Only entries whose
    %   label is strictly positive are counted.
    %   outputs{1} is [pixelAccuracy ; meanAccuracy ; meanIntersectionUnion]
    %   computed from everything accumulated since the last RESET.
      scores = inputs{1} ;

      % The predicted class is the arg-max along the channel dimension.
      [~,predictions] = max(scores, [], 3) ;
      predictions = gather(predictions) ;
      labels = gather(inputs{2}) ;

      % One confusion-matrix row/column per score channel, instead of a
      % hard-coded class count.
      numClasses = size(scores, 3) ;

      % compute statistics only on pixels with a positive label
      % (presumably label 0 marks pixels to be ignored -- confirm with
      % the data pipeline)
      ok = labels > 0 ;
      numPixels = sum(ok(:)) ;

      % Build an N-by-2 subscript list [label, prediction]. Forcing
      % column shape and double precision keeps the concatenation from
      % turning into a row vector (row-shaped inputs) or from inheriting
      % a narrow integer class from the labels.
      gt = labels(ok) ;
      pr = predictions(ok) ;
      subs = [double(gt(:)), double(pr(:))] ;

      obj.confusion = obj.confusion + ...
        accumarray(subs, 1, [numClasses numClasses]) ;

      % compute various statistics of the confusion matrix
      pos = sum(obj.confusion,2) ;    % per-class ground-truth count
      res = sum(obj.confusion,1)' ;   % per-class predicted count
      tp = diag(obj.confusion) ;      % per-class correct count

      % max(1, .) guards against division by zero for empty classes
      obj.pixelAccuracy = sum(tp) / max(1,sum(obj.confusion(:))) ;
      obj.meanAccuracy = mean(tp ./ max(1, pos)) ;
      obj.meanIntersectionUnion = mean(tp ./ max(1, pos + res - tp)) ;

      obj.average = [obj.pixelAccuracy ; obj.meanAccuracy ; obj.meanIntersectionUnion] ;
      obj.numAveraged = obj.numAveraged + numPixels ;
      outputs{1} = obj.average ;
    end

    function [derInputs, derParams] = backward(obj, inputs, params, derOutputs)
    %BACKWARD  No derivatives are produced by this layer.
    %   Both input derivatives are returned empty and there are no
    %   parameter derivatives.
      derInputs{1} = [] ;
      derInputs{2} = [] ;
      derParams = {} ;
    end

    function reset(obj)
    %RESET  Clear the accumulated confusion matrix and all statistics.
      obj.confusion = 0 ;
      obj.pixelAccuracy = 0 ;
      obj.meanAccuracy = 0 ;
      obj.meanIntersectionUnion = 0 ;
      obj.average = [0;0;0] ;
      obj.numAveraged = 0 ;
    end

    function str = toString(obj)
    %TOSTRING  One-line summary of the three current statistics.
      str = sprintf('acc:%.2f, mAcc:%.2f, mIU:%.2f', ...
                    obj.pixelAccuracy, obj.meanAccuracy, obj.meanIntersectionUnion) ;
    end

    function obj = SegmentationAccuracy(varargin)
    %SEGMENTATIONACCURACY  Construct the layer, forwarding all arguments
    %   to the LOAD method.
      obj.load(varargin) ;
    end
  end
end