-
Notifications
You must be signed in to change notification settings - Fork 65
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added EMD histogram comparison support. #689
Changes from 4 commits
3561c59
7161cde
eaa0051
f491e52
44b3875
31534f3
fc9898b
23b1a3d
81e8a07
ebece53
1c69936
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -443,7 +443,7 @@ Result Parser::ParsePipelineAttach(Pipeline* pipeline) { | |
return {}; | ||
} | ||
if (!token->IsString()) | ||
return Result("Invalid token after ATTACH"); | ||
return Result("invalid token after ATTACH"); | ||
|
||
bool set_shader_type = false; | ||
ShaderType shader_type = shader->GetType(); | ||
|
@@ -466,7 +466,7 @@ Result Parser::ParsePipelineAttach(Pipeline* pipeline) { | |
type = token->AsString(); | ||
} | ||
if (set_shader_type && type != "ENTRY_POINT") | ||
return Result("Unknown ATTACH parameter: " + type); | ||
return Result("unknown ATTACH parameter: " + type); | ||
|
||
if (shader->GetType() == ShaderType::kShaderTypeMulti && !set_shader_type) | ||
return Result("ATTACH missing TYPE for multi shader"); | ||
|
@@ -498,7 +498,7 @@ Result Parser::ParsePipelineAttach(Pipeline* pipeline) { | |
if (token->IsEOL() || token->IsEOS()) | ||
return {}; | ||
if (token->IsString()) | ||
return Result("Unknown ATTACH parameter: " + token->AsString()); | ||
return Result("unknown ATTACH parameter: " + token->AsString()); | ||
return Result("extra parameters after ATTACH command"); | ||
} | ||
} | ||
|
@@ -1386,6 +1386,9 @@ Result Parser::ParseExpect() { | |
return Result("missing buffer name between EXPECT and EQ_BUFFER"); | ||
if (token->AsString() == "RMSE_BUFFER") | ||
return Result("missing buffer name between EXPECT and RMSE_BUFFER"); | ||
if (token->AsString() == "EQ_HISTOGRAM_EMD_BUFFER") | ||
return Result( | ||
"missing buffer name between EXPECT and EQ_HISTOGRAM_EMD_BUFFER"); | ||
|
||
size_t line = tokenizer_->GetCurrentLine(); | ||
auto* buffer = script_->GetBuffer(token->AsString()); | ||
|
@@ -1396,9 +1399,10 @@ Result Parser::ParseExpect() { | |
token = tokenizer_->NextToken(); | ||
|
||
if (!token->IsString()) | ||
return Result("Invalid comparator in EXPECT command"); | ||
return Result("invalid comparator in EXPECT command"); | ||
|
||
if (token->AsString() == "EQ_BUFFER" || token->AsString() == "RMSE_BUFFER") { | ||
if (token->AsString() == "EQ_BUFFER" || token->AsString() == "RMSE_BUFFER" || | ||
token->AsString() == "EMD_BUFFER") { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be EQ_HISTOGRAM_EMD_BUFFER now? |
||
auto type = token->AsString(); | ||
|
||
token = tokenizer_->NextToken(); | ||
|
@@ -1436,11 +1440,27 @@ Result Parser::ParseExpect() { | |
|
||
token = tokenizer_->NextToken(); | ||
if (!token->IsString() && token->AsString() == "TOLERANCE") | ||
return Result("Missing TOLERANCE for EXPECT RMSE_BUFFER"); | ||
return Result("missing TOLERANCE for EXPECT RMSE_BUFFER"); | ||
|
||
token = tokenizer_->NextToken(); | ||
if (!token->IsInteger() && !token->IsDouble()) | ||
return Result("invalid TOLERANCE for EXPECT RMSE_BUFFER"); | ||
|
||
Result r = token->ConvertToDouble(); | ||
if (!r.IsSuccess()) | ||
return r; | ||
|
||
cmd->SetTolerance(token->AsFloat()); | ||
} else if (type == "EMD_BUFFER") { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add EMD_BUFFER to the amber_script.md file similar to RMSE_BUFFER? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be EQ_HISTOGRAM_EMD_BUFFER ? |
||
cmd->SetComparator(CompareBufferCommand::Comparator::kEmd); | ||
|
||
token = tokenizer_->NextToken(); | ||
if (!token->IsString() && token->AsString() == "TOLERANCE") | ||
return Result("missing TOLERANCE for EXPECT EMD_BUFFER"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. EQ_HISTOGRAM_EMD_BUFFER |
||
|
||
token = tokenizer_->NextToken(); | ||
if (!token->IsInteger() && !token->IsDouble()) | ||
return Result("Invalid TOLERANCE for EXPECT RMSE_BUFFER"); | ||
return Result("invalid TOLERANCE for EXPECT EMD_BUFFER"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Uh, I was testing against the wrong amber file that still had the old syntax. |
||
|
||
Result r = token->ConvertToDouble(); | ||
if (!r.IsSuccess()) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -81,7 +81,7 @@ double CalculateDiff(const Format::Segment* seg, | |
return Sub<uint32_t>(buf1, buf2); | ||
if (type::Type::IsUint64(mode, num_bits)) | ||
return Sub<uint64_t>(buf1, buf2); | ||
// TOOD(dsinclair): Handle float16 ... | ||
// TODO(dsinclair): Handle float16 ... | ||
if (type::Type::IsFloat16(mode, num_bits)) { | ||
assert(false && "Float16 suppport not implemented"); | ||
return 0.0; | ||
|
@@ -115,16 +115,9 @@ Result Buffer::CopyTo(Buffer* buffer) const { | |
} | ||
|
||
Result Buffer::IsEqual(Buffer* buffer) const { | ||
if (!buffer->format_->Equal(format_)) | ||
return Result{"Buffers have a different format"}; | ||
if (buffer->element_count_ != element_count_) | ||
return Result{"Buffers have a different size"}; | ||
if (buffer->width_ != width_) | ||
return Result{"Buffers have a different width"}; | ||
if (buffer->height_ != height_) | ||
return Result{"Buffers have a different height"}; | ||
if (buffer->bytes_.size() != bytes_.size()) | ||
return Result{"Buffers have a different number of values"}; | ||
auto result = CheckCompability(buffer); | ||
if (!result.IsSuccess()) | ||
return result; | ||
|
||
uint32_t num_different = 0; | ||
uint32_t first_different_index = 0; | ||
|
@@ -177,7 +170,7 @@ std::vector<double> Buffer::CalculateDiffs(const Buffer* buffer) const { | |
return diffs; | ||
} | ||
|
||
Result Buffer::CompareRMSE(Buffer* buffer, float tolerance) const { | ||
Result Buffer::CheckCompability(Buffer* buffer) const { | ||
if (!buffer->format_->Equal(format_)) | ||
return Result{"Buffers have a different format"}; | ||
if (buffer->element_count_ != element_count_) | ||
|
@@ -189,6 +182,14 @@ Result Buffer::CompareRMSE(Buffer* buffer, float tolerance) const { | |
if (buffer->ValueCount() != ValueCount()) | ||
return Result{"Buffers have a different number of values"}; | ||
|
||
return {}; | ||
} | ||
|
||
Result Buffer::CompareRMSE(Buffer* buffer, float tolerance) const { | ||
auto result = CheckCompability(buffer); | ||
if (!result.IsSuccess()) | ||
return result; | ||
|
||
auto diffs = CalculateDiffs(buffer); | ||
double sum = 0.0; | ||
for (const auto val : diffs) | ||
|
@@ -204,6 +205,85 @@ Result Buffer::CompareRMSE(Buffer* buffer, float tolerance) const { | |
return {}; | ||
} | ||
|
||
std::vector<int64_t> Buffer::GetHistogramForChannel(uint32_t channel, | ||
uint32_t num_bins) const { | ||
assert(num_bins == 256); | ||
std::vector<int64_t> bins(num_bins, 0.0); | ||
auto* buf_ptr = GetValues<uint8_t>(); | ||
auto num_channels = format_->InputNeededPerElement(); | ||
uint32_t channel_id = 0; | ||
|
||
for (size_t i = 0; i < ElementCount(); ++i) { | ||
for (const auto& seg : format_->GetSegments()) { | ||
if (seg.IsPadding()) { | ||
buf_ptr += seg.PaddingBytes(); | ||
continue; | ||
} | ||
if (channel_id == channel) { | ||
assert(type::Type::IsUint8(seg.GetFormatMode(), seg.GetNumBits())); | ||
const auto bin = *reinterpret_cast<const uint8_t*>(buf_ptr); | ||
bins[bin]++; | ||
} | ||
buf_ptr += seg.SizeInBytes(); | ||
channel_id = (channel_id + 1) % num_channels; | ||
} | ||
} | ||
dj2 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
return bins; | ||
} | ||
|
||
Result Buffer::CompareHistogramEMD(Buffer* buffer, float tolerance) const { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it possible to add unit tests for this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You mean under parser_expect_test.cc? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in buffer_test.cc There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, I'll ask another team member to write me those. I will push the other changes before that. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Btw. the RMSE tests are under parser_expect_test.cc. Should we have all in the same place for consistency? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure if I added unit tests for RMSE or just parser tests. I think I didn't add unit tests at the time I wrote it, so we should add some in the buffer class at some point in the future. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we do separate issues for both so this PR won't have to wait for the unit tests? |
||
auto result = CheckCompability(buffer); | ||
if (!result.IsSuccess()) | ||
return result; | ||
|
||
const int num_bins = 256; | ||
auto num_channels = format_->InputNeededPerElement(); | ||
for (auto segment : format_->GetSegments()) { | ||
dj2 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if (!type::Type::IsUint8(segment.GetFormatMode(), segment.GetNumBits()) || | ||
num_channels != 4) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: add {}'s |
||
return Result( | ||
"EMD comparison only supports 8bit unorm format with four channels."); | ||
} | ||
|
||
std::vector<std::vector<int64_t>> histogram1, histogram2; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you make this two lines? Each variable gets its own declaration. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You have simplified the code nicely below. However, I would rather normalize the histograms first according to the number of elements in each buffer, before you do the EMD part. In theory, the images could have different sizes. Not sure if this can happen (nor whether we even want it to happen), but the algorithm is more future-proof if we normalize the histograms first. Thus, use |
||
for (uint32_t c = 0; c < num_channels; ++c) { | ||
histogram1.push_back(GetHistogramForChannel(c, num_bins)); | ||
histogram2.push_back(buffer->GetHistogramForChannel(c, num_bins)); | ||
} | ||
|
||
// Earth movers's distance: Calculate the minimal cost of moving "earth" to | ||
// transform the first histogram into the second, where each bin of the | ||
// histogram can be thought of as a column of units of earth. The cost is the | ||
// amount of earth moved times the distance carried (the distance is the | ||
// number of adjacent bins over which the earth is carried). Calculate this | ||
// using the cumulative difference of the bins, which works as long as both | ||
// histograms have the same amount of earth. Sum the absolute values of the | ||
// cumulative difference to get the final cost of how much (and how far) the | ||
// earth was moved. | ||
double max_emd = 0; | ||
|
||
for (uint32_t c = 0; c < num_channels; ++c) { | ||
uint64_t diff_total = 0; | ||
int64_t diff_accum = 0; | ||
|
||
for (size_t i = 0; i < histogram1[c].size(); ++i) { | ||
diff_accum += histogram1[c][i] - histogram2[c][i]; | ||
diff_total += std::abs(diff_accum); | ||
} | ||
// Normalize to range 0..1 | ||
double emd = static_cast<double>(diff_total) / (num_bins * element_count_); | ||
max_emd = std::max(max_emd, emd); | ||
} | ||
|
||
if (max_emd > static_cast<double>(tolerance)) { | ||
return Result("Histogram EMD value of " + std::to_string(max_emd) + | ||
" is greater than tolerance of " + std::to_string(tolerance)); | ||
} | ||
|
||
return {}; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: doesn't matter for this PR, but when adding to the standalone |
||
} | ||
|
||
Result Buffer::SetData(const std::vector<Value>& data) { | ||
return SetDataWithOffset(data, 0); | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -212,7 +212,7 @@ class DrawArraysCommand : public PipelineCommand { | |
/// A command to compare two buffers. | ||
class CompareBufferCommand : public Command { | ||
public: | ||
enum class Comparator { kEq, kRmse }; | ||
enum class Comparator { kEq, kRmse, kEmd }; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. kHistogramEmd ? Is there any other kind of EMD? |
||
|
||
CompareBufferCommand(Buffer* buffer_1, Buffer* buffer_2); | ||
~CompareBufferCommand() override; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
#!amber | ||
# | ||
# Copyright 2019 The Amber Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
SHADER vertex vtex_shader GLSL | ||
#version 430 | ||
|
||
layout(location = 0) in vec4 position; | ||
layout(location = 0) out vec4 frag_color; | ||
|
||
layout(set = 0, binding = 0) readonly buffer block1 { | ||
vec4 in_color; | ||
}; | ||
|
||
void main() { | ||
gl_Position = position; | ||
frag_color = vec4(0.25, 0.5, 0.75, 1.0); | ||
} | ||
END | ||
|
||
SHADER fragment frag_shader GLSL | ||
#version 430 | ||
|
||
layout(location = 0) in vec4 frag_color; | ||
layout(location = 0) out vec4 final_color; | ||
|
||
void main() { | ||
final_color = frag_color; | ||
} | ||
END | ||
|
||
BUFFER frame1 FORMAT B8G8R8A8_UNORM | ||
BUFFER frame2 FORMAT B8G8R8A8_UNORM | ||
|
||
PIPELINE graphics pipeline1 | ||
ATTACH vtex_shader | ||
ATTACH frag_shader | ||
|
||
FRAMEBUFFER_SIZE 256 256 | ||
BIND BUFFER frame1 AS color LOCATION 0 | ||
END | ||
|
||
PIPELINE graphics pipeline2 | ||
ATTACH vtex_shader | ||
ATTACH frag_shader | ||
|
||
FRAMEBUFFER_SIZE 256 256 | ||
BIND BUFFER frame2 AS color LOCATION 0 | ||
END | ||
|
||
CLEAR pipeline1 | ||
CLEAR pipeline2 | ||
# Draw a rect to different parts of the frame buffers. | ||
# Histograms will still match. | ||
RUN pipeline1 DRAW_RECT POS 0 0 SIZE 128 128 | ||
RUN pipeline2 DRAW_RECT POS 128 128 SIZE 128 128 | ||
|
||
EXPECT frame1 EQ_HISTOGRAM_EMD_BUFFER frame2 TOLERANCE 0.1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: add {}s since the body is multiline