-
Notifications
You must be signed in to change notification settings - Fork 65
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added EMD histogram comparison support. #689
Changes from 2 commits
3561c59
7161cde
eaa0051
f491e52
44b3875
31534f3
fc9898b
23b1a3d
81e8a07
ebece53
1c69936
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1386,6 +1386,8 @@ Result Parser::ParseExpect() { | |
return Result("missing buffer name between EXPECT and EQ_BUFFER"); | ||
if (token->AsString() == "RMSE_BUFFER") | ||
return Result("missing buffer name between EXPECT and RMSE_BUFFER"); | ||
if (token->AsString() == "EMD_BUFFER") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we also change "emd" to "hist emd" (or "histogram emd" where appropriate, such as in function names)? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What does EMD stand for? My preference would be the full "histogram" over "hist". There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Earth mover's distance. Yeah, "histogram" sounds good. |
||
return Result("missing buffer name between EXPECT and EMD_BUFFER"); | ||
|
||
size_t line = tokenizer_->GetCurrentLine(); | ||
auto* buffer = script_->GetBuffer(token->AsString()); | ||
|
@@ -1398,7 +1400,8 @@ Result Parser::ParseExpect() { | |
if (!token->IsString()) | ||
return Result("Invalid comparator in EXPECT command"); | ||
|
||
if (token->AsString() == "EQ_BUFFER" || token->AsString() == "RMSE_BUFFER") { | ||
if (token->AsString() == "EQ_BUFFER" || token->AsString() == "RMSE_BUFFER" || | ||
token->AsString() == "EMD_BUFFER") { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be EQ_HISTOGRAM_EMD_BUFFER now? |
||
auto type = token->AsString(); | ||
|
||
token = tokenizer_->NextToken(); | ||
|
@@ -1446,6 +1449,22 @@ Result Parser::ParseExpect() { | |
if (!r.IsSuccess()) | ||
return r; | ||
|
||
cmd->SetTolerance(token->AsFloat()); | ||
} else if (type == "EMD_BUFFER") { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add EMD_BUFFER to the amber_script.md file similar to RMSE_BUFFER? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be EQ_HISTOGRAM_EMD_BUFFER ? |
||
cmd->SetComparator(CompareBufferCommand::Comparator::kEmd); | ||
|
||
token = tokenizer_->NextToken(); | ||
if (!token->IsString() && token->AsString() == "TOLERANCE") | ||
return Result("Missing TOLERANCE for EXPECT EMD_BUFFER"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lower-case 'm': the error strings start with a lower-case letter (here and below). |
||
|
||
token = tokenizer_->NextToken(); | ||
if (!token->IsInteger() && !token->IsDouble()) | ||
return Result("Invalid TOLERANCE for EXPECT EMD_BUFFER"); | ||
|
||
Result r = token->ConvertToDouble(); | ||
if (!r.IsSuccess()) | ||
return r; | ||
|
||
cmd->SetTolerance(token->AsFloat()); | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -81,7 +81,7 @@ double CalculateDiff(const Format::Segment* seg, | |
return Sub<uint32_t>(buf1, buf2); | ||
if (type::Type::IsUint64(mode, num_bits)) | ||
return Sub<uint64_t>(buf1, buf2); | ||
// TOOD(dsinclair): Handle float16 ... | ||
// TODO(dsinclair): Handle float16 ... | ||
if (type::Type::IsFloat16(mode, num_bits)) { | ||
assert(false && "Float16 suppport not implemented"); | ||
return 0.0; | ||
|
@@ -115,16 +115,9 @@ Result Buffer::CopyTo(Buffer* buffer) const { | |
} | ||
|
||
Result Buffer::IsEqual(Buffer* buffer) const { | ||
if (!buffer->format_->Equal(format_)) | ||
return Result{"Buffers have a different format"}; | ||
if (buffer->element_count_ != element_count_) | ||
return Result{"Buffers have a different size"}; | ||
if (buffer->width_ != width_) | ||
return Result{"Buffers have a different width"}; | ||
if (buffer->height_ != height_) | ||
return Result{"Buffers have a different height"}; | ||
if (buffer->bytes_.size() != bytes_.size()) | ||
return Result{"Buffers have a different number of values"}; | ||
auto result = CheckCompability(buffer); | ||
if (!result.IsSuccess()) | ||
return result; | ||
|
||
uint32_t num_different = 0; | ||
uint32_t first_different_index = 0; | ||
|
@@ -177,7 +170,7 @@ std::vector<double> Buffer::CalculateDiffs(const Buffer* buffer) const { | |
return diffs; | ||
} | ||
|
||
Result Buffer::CompareRMSE(Buffer* buffer, float tolerance) const { | ||
Result Buffer::CheckCompability(Buffer* buffer) const { | ||
if (!buffer->format_->Equal(format_)) | ||
return Result{"Buffers have a different format"}; | ||
if (buffer->element_count_ != element_count_) | ||
|
@@ -189,6 +182,14 @@ Result Buffer::CompareRMSE(Buffer* buffer, float tolerance) const { | |
if (buffer->ValueCount() != ValueCount()) | ||
return Result{"Buffers have a different number of values"}; | ||
|
||
return Result{}; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just |
||
} | ||
|
||
Result Buffer::CompareRMSE(Buffer* buffer, float tolerance) const { | ||
auto result = CheckCompability(buffer); | ||
if (!result.IsSuccess()) | ||
return result; | ||
|
||
auto diffs = CalculateDiffs(buffer); | ||
double sum = 0.0; | ||
for (const auto val : diffs) | ||
|
@@ -204,6 +205,82 @@ Result Buffer::CompareRMSE(Buffer* buffer, float tolerance) const { | |
return {}; | ||
} | ||
|
||
std::vector<double> Buffer::GetHistogramForChannel(uint32_t channel, | ||
uint32_t num_bins) const { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel that |
||
std::vector<double> bins(num_bins, 0.0); | ||
auto* buf_ptr = GetValues<uint8_t>(); | ||
const auto& segments = format_->GetSegments(); | ||
const auto comp = segments[channel].GetComponent(); | ||
auto num_channels = format_->ValuesPerElement(); | ||
buf_ptr += channel * comp->SizeInBytes(); | ||
for (size_t i = 0; i < ElementCount(); ++i) { | ||
const auto bin = *reinterpret_cast<const uint8_t*>(buf_ptr); | ||
bins[bin]++; | ||
// Assumes unpacked format where all channels have the same bit size | ||
buf_ptr += comp->SizeInBytes() * num_channels; | ||
} | ||
dj2 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
return bins; | ||
} | ||
|
||
Result Buffer::CompareHistogramEMD(Buffer* buffer, float tolerance) const { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it possible to add unit tests for this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You mean under parser_expect_test.cc? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in buffer_test.cc There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, I'll ask another team member to write me those. I will push the other changes before that. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Btw. the RMSE tests are under parser_expect_test.cc. Should we have all in the same place for consistency? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure if I added unit tests for RMSE or just parser tests. I think I didn't add unit tests at the time I wrote it, so we should add some in the buffer class at some point in the future. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we do separate issues for both so this PR won't have to wait for the unit tests? |
||
auto result = CheckCompability(buffer); | ||
if (!result.IsSuccess()) | ||
return result; | ||
|
||
const int num_bins = 256; | ||
auto num_channels = format_->ValuesPerElement(); | ||
|
||
for (auto segment : format_->GetSegments()) { | ||
dj2 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if (!segment.GetComponent()->IsUint8() || num_channels != 4) | ||
return Result( | ||
"EMD comparison only supports 8bit unorm format with four channels."); | ||
} | ||
|
||
std::vector<std::vector<double>> histograms[2]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd rather this was 2 distinct vectors:
seems a bit clearer. Or, looking below, should the histogram be a struct?
Then below we don't have to re-sum the totals to get the normalization value. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Looks like we don't have to calculate the totals at all. The total equals the number of pixels. |
||
for (uint32_t c = 0; c < num_channels; ++c) { | ||
histograms[0].push_back(GetHistogramForChannel(c, num_bins)); | ||
histograms[1].push_back(buffer->GetHistogramForChannel(c, num_bins)); | ||
} | ||
|
||
// Normalize histograms | ||
for (uint32_t i = 0; i < 2; ++i) | ||
for (uint32_t c = 0; c < num_channels; ++c) { | ||
double total = 0; | ||
for (auto value : histograms[i][c]) | ||
total += value; | ||
for (auto& value : histograms[i][c]) | ||
value /= total; | ||
} | ||
|
||
double max_emd = 0; | ||
|
||
for (uint32_t c = 0; c < num_channels; ++c) { | ||
std::vector<double> diffs(histograms[0][c].size()); | ||
for (size_t i = 0; i < histograms[0][c].size(); ++i) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should probably work off the segments, and use CalculateDiff above for consistency. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wouldn't that complicate things? The histogram format is different from the original pixel format, so we would need to create a new format object for this to work, and add pointer iteration instead of one single for loop. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You don't need a new format. CalculateDiff just takes a single segment which in this case is just your Uint, 8 bits. Or, loop over the segments from the buffer format and use those. |
||
diffs[i] = histograms[0][c][i] - histograms[1][c][i]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: Add a blank line after blocks which don't have {}'s around them (and below) |
||
// Accumulate diffs | ||
for (size_t i = 1; i < histograms[0][c].size(); ++i) | ||
diffs[i] += diffs[i - 1]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add a comment before this method on what the algorithm is here? Why do we add the diffs together this way? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (Edited) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice comment, Paul. "the distance is the number of adjacent bins over which the dirt is carried" - change "dirt" to "earth", as the only other mention of "dirt" is in an "i.e.", and I think it's best to have one term be the main term used. (I would even drop the "i.e. dirt".) "The cumulative difference of the bins is one way to calculate this" - phrase this to make it clear that this is the approach taken here: "We calculate this as the cumulative difference of ..." "same amount of dirt" and "the dirt was moved": "dirt" -> "earth" |
||
// Take absolute value and calculate total | ||
double diff_total = 0; | ||
for (auto diff : diffs) | ||
diff_total += fabs(diff); | ||
// Normalize diff | ||
diff_total /= num_bins; | ||
|
||
if (diff_total > max_emd) | ||
max_emd = diff_total; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. max_emd = std::max(diff_total, max_emd); |
||
} | ||
|
||
if (max_emd > static_cast<double>(tolerance)) { | ||
return Result("Histogram EMD value of " + std::to_string(max_emd) + | ||
" is greater than tolerance of " + std::to_string(tolerance)); | ||
} | ||
|
||
return {}; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: doesn't matter for this PR, but when adding to the standalone |
||
} | ||
|
||
Result Buffer::SetData(const std::vector<Value>& data) { | ||
return SetDataWithOffset(data, 0); | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -212,7 +212,7 @@ class DrawArraysCommand : public PipelineCommand { | |
/// A command to compare two buffers. | ||
class CompareBufferCommand : public Command { | ||
public: | ||
enum class Comparator { kEq, kRmse }; | ||
enum class Comparator { kEq, kRmse, kEmd }; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. kHistogramEmd ? Is there any other kind of EMD? |
||
|
||
CompareBufferCommand(Buffer* buffer_1, Buffer* buffer_2); | ||
~CompareBufferCommand() override; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
#!amber | ||
# | ||
# Copyright 2019 The Amber Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
SHADER vertex vtex_shader GLSL | ||
#version 430 | ||
|
||
layout(location = 0) in vec4 position; | ||
layout(location = 0) out vec4 frag_color; | ||
|
||
layout(set = 0, binding = 0) readonly buffer block1 { | ||
vec4 in_color; | ||
}; | ||
|
||
void main() { | ||
gl_Position = position; | ||
frag_color = vec4(0.25, 0.5, 0.75, 1.0); | ||
} | ||
END | ||
|
||
SHADER fragment frag_shader GLSL | ||
#version 430 | ||
|
||
layout(location = 0) in vec4 frag_color; | ||
layout(location = 0) out vec4 final_color; | ||
|
||
void main() { | ||
final_color = frag_color; | ||
} | ||
END | ||
|
||
BUFFER frame1 FORMAT B8G8R8A8_UNORM | ||
BUFFER frame2 FORMAT B8G8R8A8_UNORM | ||
|
||
PIPELINE graphics pipeline1 | ||
ATTACH vtex_shader | ||
ATTACH frag_shader | ||
|
||
FRAMEBUFFER_SIZE 256 256 | ||
BIND BUFFER frame1 AS color LOCATION 0 | ||
END | ||
|
||
PIPELINE graphics pipeline2 | ||
ATTACH vtex_shader | ||
ATTACH frag_shader | ||
|
||
FRAMEBUFFER_SIZE 256 256 | ||
BIND BUFFER frame2 AS color LOCATION 0 | ||
END | ||
|
||
CLEAR pipeline1 | ||
CLEAR pipeline2 | ||
# Draw a rect to different parts of the frame buffers. | ||
# Histograms will still match. | ||
RUN pipeline1 DRAW_RECT POS 0 0 SIZE 128 128 | ||
RUN pipeline2 DRAW_RECT POS 128 128 SIZE 128 128 | ||
|
||
EXPECT frame1 EMD_BUFFER frame2 TOLERANCE 0.1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we want to deprecate RMSE_BUFFER in favour of EQ_RMSE_BUFFER, should this be EQ_EMD_BUFFER?