Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mean IOU op. #10519

Merged
merged 10 commits into from
Jun 14, 2018
Merged

Add mean IOU op. #10519

merged 10 commits into from
Jun 14, 2018

Conversation

wanghaoshuang
Copy link
Contributor

@wanghaoshuang wanghaoshuang commented May 9, 2018

Performance on GPU P40:

input size class_num GPU CPU
1024 * 2048 100 0.168812ms 13.0831ms
1024 * 2048 50 0.172748ms 13.5145ms
1024 * 2048 20 0.174807ms 14.4619ms
1024 * 2048 10 0.188483ms 16.1516ms
1024 * 2048 1 0.230743ms 12.7893ms
1024 * 2048 *2 100 0.308306ms 26.4576
1024 * 2048 * 2 50 0.326073ms 26.9835ms
1024 * 2048 *2 20 0.28971ms 29.0224ms
1024 * 2048* 2 10 0.267694ms 34.2029ms
1024 * 2048 * 2 1 0.295844ms 25.4808ms

test code:

import paddle.fluid as fluid
import numpy as np
from paddle.fluid import core

num_classes=1
input_size=1024 * 2048 * 2

images = fluid.layers.data(name='image', shape=[1], dtype='int32')
label = fluid.layers.data(name='label', shape=[1], dtype='int32')
fluid.layers.mean_iou(images, label, num_classes)

place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

image_t = core.LoDTensor()
label_t = core.LoDTensor()
image_t.set(np.random.randint(num_classes, size=[input_size]).astype("int32"), place)
label_t.set(np.random.randint(num_classes, size=[input_size]).astype("int32"), place)

with fluid.profiler.profiler('GPU', sorted_key='total'):
    for i in range(500):
        exe.run(feed={"image": image_t, "label":label_t})

@wanghaoshuang
Copy link
Contributor Author

请先不要review,我再优化下GPU kernel.

1. Merge computing in GPU to two kernel.
2. Use wrong array and correct array instead of confusion matrix.
"A Tensor representing the"
" mean intersection-over-union.");
AddOutput("out_wrong", "A Tensor with shape [num_classes]. ");
AddOutput("out_correct", "A Tensor with shape [num_classes]. ");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx. FIxed.

AddComment(R"DOC(
mean-IOU Operator.
Mean Intersection-Over-Union is a common evaluation metric for semantic image segmentation, which first computes the IOU for each semantic class and then computes the average over classes. IOU is defined as follows: IOU = true_positive / (true_positive + false_positive + false_negative). The predictions are accumulated in a confusion matrix and mean-IOU is then calculated from it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we have iou_similarity_op:

https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/operators/detection/iou_similarity_op.cc#L71

The doc here better to give more details for the difference.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx. FIxed.

.AsDispensable();
AddOutput("out_mean_iou",
"A Tensor representing the"
" mean intersection-over-union.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also need to give the shape.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx. FIxed.

REGISTER_OPERATOR(mean_iou, ops::MeanIoUOp, ops::MeanIoUOpMaker,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(mean_iou, ops::MeanIoUKernel<int>,
ops::MeanIoUKernel<int64_t>);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

上面文档描述里是支持int32和int64,这里没有注册int32。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx. FIxed.


namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(mean_iou, ops::MeanIoUCUDAOpKernel<int>,
ops::MeanIoUKernel<int64_t>);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx. FIxed.

: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("predictions",
"A Tensor of prediction results for semantic labels"
" with type int32 or int64.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to give the shape

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx. FIxed.

float* out_mean_iou_data =
out_mean_iou->mutable_data<float>(ctx.GetPlace());

// get eigen tensor
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

eigen -> Eigen

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx. FIxed.

auto out_wrong_t = EigenTensor<int, 1>::From(*out_wrong);
auto out_correct_t = EigenTensor<int, 1>::From(*out_correct);

// Tmp tensor
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dot not use Tmp

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thx. FIxed.

.AsDispensable();
AddInput("in_mean_iou",
"A list of Tensor that Output(mean_iou) should "
"be added to. Empty list is also valid here.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in_wrongs, in_corrects, in_mean_iou是干啥的?和out_wrong/correct/mean_iou有啥区别?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in_wrongs, in_corrects, in_mean_iou之前当前batch之前累计的数据,加上当前batch的统计结果,就得到:out_wrong/correct/mean_iou

for (int i = threadIdx.x; i < num_classes; i += blockDim.x) {
atomicAdd(wrong + i, wrong_c[i]);
atomicAdd(correct + i, correct_c[i]);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果num_classes较小, predictions的shape较大,会导致这个kernel的性能非常低效,其实感觉类似这样的kernel,先CPU即可,后续最好评估下时间。

Copy link
Contributor Author

@wanghaoshuang wanghaoshuang Jun 12, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

input size class_num GPU CPU
1024 * 2048 100 0.168812ms 13.0831ms
1024 * 2048 50 0.172748ms 13.5145ms
1024 * 2048 20 0.174807ms 14.4619ms
1024 * 2048 10 0.188483ms 16.1516ms
1024 * 2048 1 0.230743ms 12.7893ms
1024 * 2048 *2 100 0.308306ms 26.4576
1024 * 2048 * 2 50 0.326073ms 26.9835ms
1024 * 2048 *2 20 0.28971ms 29.0224ms
1024 * 2048* 2 10 0.267694ms 34.2029ms
1024 * 2048 * 2 1 0.295844ms 25.4808ms

'softmax_with_cross_entropy', 'smooth_l1', 'one_hot',
'autoincreased_step_counter', 'reshape', 'lod_reset', 'lrn', 'pad',
'label_smooth', 'roi_pool', 'dice_loss', 'image_resize',
'image_resize_short', 'resize_bilinear', 'gather', 'random_crop', 'mean_iou'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is due to he yapf version?

Copy link
Contributor Author

@wanghaoshuang wanghaoshuang Jun 13, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The version of my yapf is 0.22. I'm not sure it is due to he yapf version. Which style is correct?

@wanghaoshuang wanghaoshuang merged commit 6fcdb24 into PaddlePaddle:develop Jun 14, 2018
@wanghaoshuang wanghaoshuang deleted the mean_iou branch May 20, 2022 03:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants