Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added EMD histogram comparison support. #689

Merged
merged 11 commits into from
Oct 16, 2019
8 changes: 7 additions & 1 deletion docs/amber_script.md
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ CLEAR {pipeline}
* `EQ_RGBA`
* `EQ_BUFFER`
* `RMSE_BUFFER`
* `EQ_HISTOGRAM_EMD_BUFFER`

```groovy
# Checks that |buffer_name| at |x| has the given |value|s when compared
Expand Down Expand Up @@ -451,9 +452,14 @@ EXPECT {buffer_name} IDX _x_in_pixels_ _y_in_pixels_ \
EXPECT {buffer_1} EQ_BUFFER {buffer_2}

# Checks that the Root Mean Square Error when comparing |buffer_1| to
# |buffer_2| is less than or equal too |tolerance|. Note, |tolerance| is a
# |buffer_2| is less than or equal to |tolerance|. Note, |tolerance| is a
# unit-less number.
EXPECT {buffer_1} RMSE_BUFFER {buffer_2} TOLERANCE _value_

# Checks that the Earth Mover's Distance when comparing histograms of
# |buffer_1| to |buffer_2| is less than or equal to |tolerance|.
# Note, |tolerance| is a unit-less number.
EXPECT {buffer_1} EQ_HISTOGRAM_EMD_BUFFER {buffer_2} TOLERANCE _value_
```

## Examples
Expand Down
34 changes: 27 additions & 7 deletions src/amberscript/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ Result Parser::ParsePipelineAttach(Pipeline* pipeline) {
return {};
}
if (!token->IsString())
return Result("Invalid token after ATTACH");
return Result("invalid token after ATTACH");

bool set_shader_type = false;
ShaderType shader_type = shader->GetType();
Expand All @@ -466,7 +466,7 @@ Result Parser::ParsePipelineAttach(Pipeline* pipeline) {
type = token->AsString();
}
if (set_shader_type && type != "ENTRY_POINT")
return Result("Unknown ATTACH parameter: " + type);
return Result("unknown ATTACH parameter: " + type);

if (shader->GetType() == ShaderType::kShaderTypeMulti && !set_shader_type)
return Result("ATTACH missing TYPE for multi shader");
Expand Down Expand Up @@ -498,7 +498,7 @@ Result Parser::ParsePipelineAttach(Pipeline* pipeline) {
if (token->IsEOL() || token->IsEOS())
return {};
if (token->IsString())
return Result("Unknown ATTACH parameter: " + token->AsString());
return Result("unknown ATTACH parameter: " + token->AsString());
return Result("extra parameters after ATTACH command");
}
}
Expand Down Expand Up @@ -1386,6 +1386,9 @@ Result Parser::ParseExpect() {
return Result("missing buffer name between EXPECT and EQ_BUFFER");
if (token->AsString() == "RMSE_BUFFER")
return Result("missing buffer name between EXPECT and RMSE_BUFFER");
if (token->AsString() == "EQ_HISTOGRAM_EMD_BUFFER")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: add {}s since the body is multiline

return Result(
"missing buffer name between EXPECT and EQ_HISTOGRAM_EMD_BUFFER");

size_t line = tokenizer_->GetCurrentLine();
auto* buffer = script_->GetBuffer(token->AsString());
Expand All @@ -1396,9 +1399,10 @@ Result Parser::ParseExpect() {
token = tokenizer_->NextToken();

if (!token->IsString())
return Result("Invalid comparator in EXPECT command");
return Result("invalid comparator in EXPECT command");

if (token->AsString() == "EQ_BUFFER" || token->AsString() == "RMSE_BUFFER") {
if (token->AsString() == "EQ_BUFFER" || token->AsString() == "RMSE_BUFFER" ||
token->AsString() == "EMD_BUFFER") {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be EQ_HISTOGRAM_EMD_BUFFER now?

auto type = token->AsString();

token = tokenizer_->NextToken();
Expand Down Expand Up @@ -1436,11 +1440,27 @@ Result Parser::ParseExpect() {

token = tokenizer_->NextToken();
if (!token->IsString() && token->AsString() == "TOLERANCE")
return Result("Missing TOLERANCE for EXPECT RMSE_BUFFER");
return Result("missing TOLERANCE for EXPECT RMSE_BUFFER");

token = tokenizer_->NextToken();
if (!token->IsInteger() && !token->IsDouble())
return Result("invalid TOLERANCE for EXPECT RMSE_BUFFER");

Result r = token->ConvertToDouble();
if (!r.IsSuccess())
return r;

cmd->SetTolerance(token->AsFloat());
} else if (type == "EMD_BUFFER") {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add EMD_BUFFER to the amber_script.md file similar to RMSE_BUFFER?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be EQ_HISTOGRAM_EMD_BUFFER ?

cmd->SetComparator(CompareBufferCommand::Comparator::kEmd);

token = tokenizer_->NextToken();
if (!token->IsString() && token->AsString() == "TOLERANCE")
return Result("missing TOLERANCE for EXPECT EMD_BUFFER");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EQ_HISTOGRAM_EMD_BUFFER


token = tokenizer_->NextToken();
if (!token->IsInteger() && !token->IsDouble())
return Result("Invalid TOLERANCE for EXPECT RMSE_BUFFER");
return Result("invalid TOLERANCE for EXPECT EMD_BUFFER");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Uh, I was testing against the wrong amber file that still had the old syntax.


Result r = token->ConvertToDouble();
if (!r.IsSuccess())
Expand Down
104 changes: 92 additions & 12 deletions src/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ double CalculateDiff(const Format::Segment* seg,
return Sub<uint32_t>(buf1, buf2);
if (type::Type::IsUint64(mode, num_bits))
return Sub<uint64_t>(buf1, buf2);
// TOOD(dsinclair): Handle float16 ...
// TODO(dsinclair): Handle float16 ...
if (type::Type::IsFloat16(mode, num_bits)) {
assert(false && "Float16 suppport not implemented");
return 0.0;
Expand Down Expand Up @@ -115,16 +115,9 @@ Result Buffer::CopyTo(Buffer* buffer) const {
}

Result Buffer::IsEqual(Buffer* buffer) const {
if (!buffer->format_->Equal(format_))
return Result{"Buffers have a different format"};
if (buffer->element_count_ != element_count_)
return Result{"Buffers have a different size"};
if (buffer->width_ != width_)
return Result{"Buffers have a different width"};
if (buffer->height_ != height_)
return Result{"Buffers have a different height"};
if (buffer->bytes_.size() != bytes_.size())
return Result{"Buffers have a different number of values"};
auto result = CheckCompability(buffer);
if (!result.IsSuccess())
return result;

uint32_t num_different = 0;
uint32_t first_different_index = 0;
Expand Down Expand Up @@ -177,7 +170,7 @@ std::vector<double> Buffer::CalculateDiffs(const Buffer* buffer) const {
return diffs;
}

Result Buffer::CompareRMSE(Buffer* buffer, float tolerance) const {
Result Buffer::CheckCompability(Buffer* buffer) const {
if (!buffer->format_->Equal(format_))
return Result{"Buffers have a different format"};
if (buffer->element_count_ != element_count_)
Expand All @@ -189,6 +182,14 @@ Result Buffer::CompareRMSE(Buffer* buffer, float tolerance) const {
if (buffer->ValueCount() != ValueCount())
return Result{"Buffers have a different number of values"};

return {};
}

Result Buffer::CompareRMSE(Buffer* buffer, float tolerance) const {
auto result = CheckCompability(buffer);
if (!result.IsSuccess())
return result;

auto diffs = CalculateDiffs(buffer);
double sum = 0.0;
for (const auto val : diffs)
Expand All @@ -204,6 +205,85 @@ Result Buffer::CompareRMSE(Buffer* buffer, float tolerance) const {
return {};
}

std::vector<int64_t> Buffer::GetHistogramForChannel(uint32_t channel,
uint32_t num_bins) const {
assert(num_bins == 256);
std::vector<int64_t> bins(num_bins, 0.0);
auto* buf_ptr = GetValues<uint8_t>();
auto num_channels = format_->InputNeededPerElement();
uint32_t channel_id = 0;

for (size_t i = 0; i < ElementCount(); ++i) {
for (const auto& seg : format_->GetSegments()) {
if (seg.IsPadding()) {
buf_ptr += seg.PaddingBytes();
continue;
}
if (channel_id == channel) {
assert(type::Type::IsUint8(seg.GetFormatMode(), seg.GetNumBits()));
const auto bin = *reinterpret_cast<const uint8_t*>(buf_ptr);
bins[bin]++;
}
buf_ptr += seg.SizeInBytes();
channel_id = (channel_id + 1) % num_channels;
}
}
dj2 marked this conversation as resolved.
Show resolved Hide resolved

return bins;
}

Result Buffer::CompareHistogramEMD(Buffer* buffer, float tolerance) const {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to add unit tests for this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean under parser_expect_test.cc?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in buffer_test.cc

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I'll ask another team member to write me those. I will push the other changes before that.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw. the RMSE tests are under parser_expect_test.cc. Should we have all in the same place for consistency?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if I added unit tests for RMSE or just parser tests. I think I didn't add unit tests at the time I wrote it, so we should add some in the buffer class at some point in the future.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we do separate issues for both so this PR won't have to wait for the unit tests?

auto result = CheckCompability(buffer);
if (!result.IsSuccess())
return result;

const int num_bins = 256;
auto num_channels = format_->InputNeededPerElement();
for (auto segment : format_->GetSegments()) {
dj2 marked this conversation as resolved.
Show resolved Hide resolved
if (!type::Type::IsUint8(segment.GetFormatMode(), segment.GetNumBits()) ||
num_channels != 4)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: add {}'s

return Result(
"EMD comparison only supports 8bit unorm format with four channels.");
}

std::vector<std::vector<int64_t>> histogram1, histogram2;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you make this two lines? Each variable gets its own declaration.

Copy link
Collaborator

@paulthomson paulthomson Oct 14, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You have simplified the code nicely below. However, I would rather normalize the histograms first according to the number of elements in each buffer, before you do the EMD part. In theory, the images could have different sizes. Not sure if this can happen (nor whether we even want it to happen), but the algorithm is more future-proof if we normalize the histograms first. Thus, use vector<double> here.

for (uint32_t c = 0; c < num_channels; ++c) {
histogram1.push_back(GetHistogramForChannel(c, num_bins));
histogram2.push_back(buffer->GetHistogramForChannel(c, num_bins));
}

// Earth movers's distance: Calculate the minimal cost of moving "earth" to
// transform the first histogram into the second, where each bin of the
// histogram can be thought of as a column of units of earth. The cost is the
// amount of earth moved times the distance carried (the distance is the
// number of adjacent bins over which the earth is carried). Calculate this
// using the cumulative difference of the bins, which works as long as both
// histograms have the same amount of earth. Sum the absolute values of the
// cumulative difference to get the final cost of how much (and how far) the
// earth was moved.
double max_emd = 0;

for (uint32_t c = 0; c < num_channels; ++c) {
uint64_t diff_total = 0;
int64_t diff_accum = 0;

for (size_t i = 0; i < histogram1[c].size(); ++i) {
diff_accum += histogram1[c][i] - histogram2[c][i];
diff_total += std::abs(diff_accum);
}
// Normalize to range 0..1
double emd = static_cast<double>(diff_total) / (num_bins * element_count_);
max_emd = std::max(max_emd, emd);
}

if (max_emd > static_cast<double>(tolerance)) {
return Result("Histogram EMD value of " + std::to_string(max_emd) +
" is greater than tolerance of " + std::to_string(tolerance));
}

return {};
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: doesn't matter for this PR, but when adding to the standalone image_diff tool, I guess you might want to move the code that yields max_emd into another (public?) function so that image_diff can get and print the value, even if the tolerance was not exceeded.

}

Result Buffer::SetData(const std::vector<Value>& data) {
return SetDataWithOffset(data, 0);
}
Expand Down
11 changes: 11 additions & 0 deletions src/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,21 @@ class Buffer {
/// Succeeds only if both buffer contents are equal
Result IsEqual(Buffer* buffer) const;

/// Returns a histogram
std::vector<int64_t> GetHistogramForChannel(uint32_t channel,
uint32_t num_bins) const;

/// Checks if buffers are compatible for comparison
Result CheckCompability(Buffer* buffer) const;

/// Compare the RMSE of this buffer against |buffer|. The RMSE must be
/// less than |tolerance|.
Result CompareRMSE(Buffer* buffer, float tolerance) const;

/// Compare the histogram EMD of this buffer against |buffer|. The EMD must be
/// less than |tolerance|.
Result CompareHistogramEMD(Buffer* buffer, float tolerance) const;

private:
uint32_t WriteValueFromComponent(const Value& value,
FormatMode mode,
Expand Down
2 changes: 1 addition & 1 deletion src/command.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ class DrawArraysCommand : public PipelineCommand {
/// A command to compare two buffers.
class CompareBufferCommand : public Command {
public:
enum class Comparator { kEq, kRmse };
enum class Comparator { kEq, kRmse, kEmd };
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kHistogramEmd ? Is there any other kind of EMD?


CompareBufferCommand(Buffer* buffer_1, Buffer* buffer_2);
~CompareBufferCommand() override;
Expand Down
2 changes: 2 additions & 0 deletions src/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ Result Executor::ExecuteCommand(Engine* engine, Command* cmd) {
switch (compare->GetComparator()) {
case CompareBufferCommand::Comparator::kRmse:
return buffer_1->CompareRMSE(buffer_2, compare->GetTolerance());
case CompareBufferCommand::Comparator::kEmd:
return buffer_1->CompareHistogramEMD(buffer_2, compare->GetTolerance());
case CompareBufferCommand::Comparator::kEq:
return buffer_1->IsEqual(buffer_2);
}
Expand Down
70 changes: 70 additions & 0 deletions tests/cases/buffer_emd.amber
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#!amber
#
# Copyright 2019 The Amber Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

SHADER vertex vtex_shader GLSL
#version 430

layout(location = 0) in vec4 position;
layout(location = 0) out vec4 frag_color;

layout(set = 0, binding = 0) readonly buffer block1 {
vec4 in_color;
};

void main() {
gl_Position = position;
frag_color = vec4(0.25, 0.5, 0.75, 1.0);
}
END

SHADER fragment frag_shader GLSL
#version 430

layout(location = 0) in vec4 frag_color;
layout(location = 0) out vec4 final_color;

void main() {
final_color = frag_color;
}
END

BUFFER frame1 FORMAT B8G8R8A8_UNORM
BUFFER frame2 FORMAT B8G8R8A8_UNORM

PIPELINE graphics pipeline1
ATTACH vtex_shader
ATTACH frag_shader

FRAMEBUFFER_SIZE 256 256
BIND BUFFER frame1 AS color LOCATION 0
END

PIPELINE graphics pipeline2
ATTACH vtex_shader
ATTACH frag_shader

FRAMEBUFFER_SIZE 256 256
BIND BUFFER frame2 AS color LOCATION 0
END

CLEAR pipeline1
CLEAR pipeline2
# Draw a rect to different parts of the frame buffers.
# Histograms will still match.
RUN pipeline1 DRAW_RECT POS 0 0 SIZE 128 128
RUN pipeline2 DRAW_RECT POS 128 128 SIZE 128 128

EXPECT frame1 EQ_HISTOGRAM_EMD_BUFFER frame2 TOLERANCE 0.1