Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added EMD histogram comparison support. #689

Merged
merged 11 commits into from
Oct 16, 2019
21 changes: 20 additions & 1 deletion src/amberscript/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1386,6 +1386,8 @@ Result Parser::ParseExpect() {
return Result("missing buffer name between EXPECT and EQ_BUFFER");
if (token->AsString() == "RMSE_BUFFER")
return Result("missing buffer name between EXPECT and RMSE_BUFFER");
if (token->AsString() == "EMD_BUFFER")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we want to deprecate RMSE_BUFFER in favour of EQ_RMSE_BUFFER should this be EQ_EMD_BUFFER?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we also change "emd" to "hist emd" (or "histogram emd" where appropriate, such as in function names)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does EMD stand for? My preference would be the full histogram over hist.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Earth mover's distance. Yeah histogram sounds good.

return Result("missing buffer name between EXPECT and EMD_BUFFER");

size_t line = tokenizer_->GetCurrentLine();
auto* buffer = script_->GetBuffer(token->AsString());
Expand All @@ -1398,7 +1400,8 @@ Result Parser::ParseExpect() {
if (!token->IsString())
return Result("Invalid comparator in EXPECT command");

if (token->AsString() == "EQ_BUFFER" || token->AsString() == "RMSE_BUFFER") {
if (token->AsString() == "EQ_BUFFER" || token->AsString() == "RMSE_BUFFER" ||
token->AsString() == "EMD_BUFFER") {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be EQ_HISTOGRAM_EMD_BUFFER now?

auto type = token->AsString();

token = tokenizer_->NextToken();
Expand Down Expand Up @@ -1446,6 +1449,22 @@ Result Parser::ParseExpect() {
if (!r.IsSuccess())
return r;

cmd->SetTolerance(token->AsFloat());
} else if (type == "EMD_BUFFER") {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add EMD_BUFFER to the amber_script.md file similar to RMSE_BUFFER?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be EQ_HISTOGRAM_EMD_BUFFER ?

cmd->SetComparator(CompareBufferCommand::Comparator::kEmd);

token = tokenizer_->NextToken();
if (!token->IsString() && token->AsString() == "TOLERANCE")
return Result("Missing TOLERANCE for EXPECT EMD_BUFFER");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lower case 'm' the error strings start lower cased (and below)


token = tokenizer_->NextToken();
if (!token->IsInteger() && !token->IsDouble())
return Result("Invalid TOLERANCE for EXPECT EMD_BUFFER");

Result r = token->ConvertToDouble();
if (!r.IsSuccess())
return r;

cmd->SetTolerance(token->AsFloat());
}

Expand Down
101 changes: 89 additions & 12 deletions src/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ double CalculateDiff(const Format::Segment* seg,
return Sub<uint32_t>(buf1, buf2);
if (type::Type::IsUint64(mode, num_bits))
return Sub<uint64_t>(buf1, buf2);
// TOOD(dsinclair): Handle float16 ...
// TODO(dsinclair): Handle float16 ...
if (type::Type::IsFloat16(mode, num_bits)) {
assert(false && "Float16 suppport not implemented");
return 0.0;
Expand Down Expand Up @@ -115,16 +115,9 @@ Result Buffer::CopyTo(Buffer* buffer) const {
}

Result Buffer::IsEqual(Buffer* buffer) const {
if (!buffer->format_->Equal(format_))
return Result{"Buffers have a different format"};
if (buffer->element_count_ != element_count_)
return Result{"Buffers have a different size"};
if (buffer->width_ != width_)
return Result{"Buffers have a different width"};
if (buffer->height_ != height_)
return Result{"Buffers have a different height"};
if (buffer->bytes_.size() != bytes_.size())
return Result{"Buffers have a different number of values"};
auto result = CheckCompability(buffer);
if (!result.IsSuccess())
return result;

uint32_t num_different = 0;
uint32_t first_different_index = 0;
Expand Down Expand Up @@ -177,7 +170,7 @@ std::vector<double> Buffer::CalculateDiffs(const Buffer* buffer) const {
return diffs;
}

Result Buffer::CompareRMSE(Buffer* buffer, float tolerance) const {
Result Buffer::CheckCompability(Buffer* buffer) const {
if (!buffer->format_->Equal(format_))
return Result{"Buffers have a different format"};
if (buffer->element_count_ != element_count_)
Expand All @@ -189,6 +182,14 @@ Result Buffer::CompareRMSE(Buffer* buffer, float tolerance) const {
if (buffer->ValueCount() != ValueCount())
return Result{"Buffers have a different number of values"};

return Result{};
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just return {}; should work

}

Result Buffer::CompareRMSE(Buffer* buffer, float tolerance) const {
auto result = CheckCompability(buffer);
if (!result.IsSuccess())
return result;

auto diffs = CalculateDiffs(buffer);
double sum = 0.0;
for (const auto val : diffs)
Expand All @@ -204,6 +205,82 @@ Result Buffer::CompareRMSE(Buffer* buffer, float tolerance) const {
return {};
}

std::vector<double> Buffer::GetHistogramForChannel(uint32_t channel,
uint32_t num_bins) const {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel that num_bins has to be 256 for this function to make sense. I suggest adding an assert. You may also want to assert something about the format_ (again) like format_->IsUint8()

std::vector<double> bins(num_bins, 0.0);
auto* buf_ptr = GetValues<uint8_t>();
const auto& segments = format_->GetSegments();
const auto comp = segments[channel].GetComponent();
auto num_channels = format_->ValuesPerElement();
buf_ptr += channel * comp->SizeInBytes();
for (size_t i = 0; i < ElementCount(); ++i) {
const auto bin = *reinterpret_cast<const uint8_t*>(buf_ptr);
bins[bin]++;
// Assumes unpacked format where all channels have the same bit size
buf_ptr += comp->SizeInBytes() * num_channels;
}
dj2 marked this conversation as resolved.
Show resolved Hide resolved

return bins;
}

Result Buffer::CompareHistogramEMD(Buffer* buffer, float tolerance) const {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to add unit tests for this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean under parser_expect_test.cc?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in buffer_test.cc

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I'll ask another team member to write me those. I will push the other changes before that.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw. the RMSE tests are under parser_expect_test.cc. Should we have all in the same place for consistency?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if I added unit tests for RMSE or just parser tests. I think I didn't add unit tests at the time I wrote it, so we should add some in the buffer class at some point in the future.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we do separate issues for both so this PR won't have to wait for the unit tests?

auto result = CheckCompability(buffer);
if (!result.IsSuccess())
return result;

const int num_bins = 256;
auto num_channels = format_->ValuesPerElement();

for (auto segment : format_->GetSegments()) {
dj2 marked this conversation as resolved.
Show resolved Hide resolved
if (!segment.GetComponent()->IsUint8() || num_channels != 4)
return Result(
"EMD comparison only supports 8bit unorm format with four channels.");
}

std::vector<std::vector<double>> histograms[2];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather this was 2 distinct vectors:

std::vector<double> buffer1_histograms;
std::vector<double> buffer2_historgrams;

seems a bit clearer.

Or, looking below, should the histogram be a struct?

struct {
  std::vector<double> values;
  std::vector<uint32_t> totals;
}

Then below we don't have to re-sum the totals to get the normalization value.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like we don't have to calculate the totals at all. It equals to the number of pixels.

for (uint32_t c = 0; c < num_channels; ++c) {
histograms[0].push_back(GetHistogramForChannel(c, num_bins));
histograms[1].push_back(buffer->GetHistogramForChannel(c, num_bins));
}

// Normalize histograms
for (uint32_t i = 0; i < 2; ++i)
for (uint32_t c = 0; c < num_channels; ++c) {
double total = 0;
for (auto value : histograms[i][c])
total += value;
for (auto& value : histograms[i][c])
value /= total;
}

double max_emd = 0;

for (uint32_t c = 0; c < num_channels; ++c) {
std::vector<double> diffs(histograms[0][c].size());
for (size_t i = 0; i < histograms[0][c].size(); ++i)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should probably work off the segments, and use CalculateDiff above for consistency.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't that complicate things? The histogram format is different from the original pixel format, so we would need to create new format object for this to work. And add pointer iteration instead of one single for loop.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need a new format. CalculateDiff just takes a single segment which in this case is just your Uint, 8 bits. Or, loop over the segments from the buffer format and use those.

diffs[i] = histograms[0][c][i] - histograms[1][c][i];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Add a blank line after blocks which don't have {}'s around them (and below)

// Accumulate diffs
for (size_t i = 1; i < histograms[0][c].size(); ++i)
diffs[i] += diffs[i - 1];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a comment before this method on what the algorithm is here? Why do we add the diffs together this way?

Copy link
Collaborator

@paulthomson paulthomson Oct 11, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Edited)
We calculate the minimal cost of moving "earth" to transform the first histogram into the second, where each bin of the histogram can be thought of as a column of units of earth. The cost is the amount of earth moved times the distance carried (the distance is the number of adjacent bins over which the earth is carried). We calculate this using the cumulative difference of the bins, which works as long as both histograms have the same amount of earth. We have to sum the absolute values of the cumulative difference to get the final cost of how much (and how far) the earth was moved.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice comment, Paul.

"the distance is the number of adjacent bins over which the dirt is carried" - change "dirt" to "earth", as the only other mention of "dirt" is in an "i.e.", and I think it's best to have one term be the main term used. (I would even drop the "i.e. dirt".)

"The cumulative difference of the bins is one way to calculate this" - phrase this to make it clear that this is the approach taken here: "We calculate this as the cumulative difference of ..."

"same amount of dirt" and "the dirt was moved": "dirt" -> "earth"

// Take absolute value and calculate total
double diff_total = 0;
for (auto diff : diffs)
diff_total += fabs(diff);
// Normalize diff
diff_total /= num_bins;

if (diff_total > max_emd)
max_emd = diff_total;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

max_emd = std::max(diff_total, max_emd);

}

if (max_emd > static_cast<double>(tolerance)) {
return Result("Histogram EMD value of " + std::to_string(max_emd) +
" is greater than tolerance of " + std::to_string(tolerance));
}

return {};
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: doesn't matter for this PR, but when adding to the standalone image_diff tool, I guess you might want to move the code that yields max_emd into another (public?) function so that image_diff can get and print the value, even if the tolerance was not exceeded.

}

Result Buffer::SetData(const std::vector<Value>& data) {
return SetDataWithOffset(data, 0);
}
Expand Down
11 changes: 11 additions & 0 deletions src/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,21 @@ class Buffer {
/// Succeeds only if both buffer contents are equal
Result IsEqual(Buffer* buffer) const;

/// Returns a histogram
std::vector<double> GetHistogramForChannel(uint32_t channel,
uint32_t num_bins) const;

/// Checks if buffers are compatible for comparison
Result CheckCompability(Buffer* buffer) const;

/// Compare the RMSE of this buffer against |buffer|. The RMSE must be
/// less than |tolerance|.
Result CompareRMSE(Buffer* buffer, float tolerance) const;

/// Compare the histogram EMD of this buffer against |buffer|. The EMD must be
/// less than |tolerance|.
Result CompareHistogramEMD(Buffer* buffer, float tolerance) const;

private:
uint32_t WriteValueFromComponent(const Value& value,
FormatMode mode,
Expand Down
2 changes: 1 addition & 1 deletion src/command.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ class DrawArraysCommand : public PipelineCommand {
/// A command to compare two buffers.
class CompareBufferCommand : public Command {
public:
enum class Comparator { kEq, kRmse };
enum class Comparator { kEq, kRmse, kEmd };
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kHistogramEmd ? Is there any other kind of EMD?


CompareBufferCommand(Buffer* buffer_1, Buffer* buffer_2);
~CompareBufferCommand() override;
Expand Down
2 changes: 2 additions & 0 deletions src/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ Result Executor::ExecuteCommand(Engine* engine, Command* cmd) {
switch (compare->GetComparator()) {
case CompareBufferCommand::Comparator::kRmse:
return buffer_1->CompareRMSE(buffer_2, compare->GetTolerance());
case CompareBufferCommand::Comparator::kEmd:
return buffer_1->CompareHistogramEMD(buffer_2, compare->GetTolerance());
case CompareBufferCommand::Comparator::kEq:
return buffer_1->IsEqual(buffer_2);
}
Expand Down
70 changes: 70 additions & 0 deletions tests/cases/buffer_emd.amber
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#!amber
#
# Copyright 2019 The Amber Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

SHADER vertex vtex_shader GLSL
#version 430

layout(location = 0) in vec4 position;
layout(location = 0) out vec4 frag_color;

layout(set = 0, binding = 0) readonly buffer block1 {
vec4 in_color;
};

void main() {
gl_Position = position;
frag_color = vec4(0.25, 0.5, 0.75, 1.0);
}
END

SHADER fragment frag_shader GLSL
#version 430

layout(location = 0) in vec4 frag_color;
layout(location = 0) out vec4 final_color;

void main() {
final_color = frag_color;
}
END

BUFFER frame1 FORMAT B8G8R8A8_UNORM
BUFFER frame2 FORMAT B8G8R8A8_UNORM

PIPELINE graphics pipeline1
ATTACH vtex_shader
ATTACH frag_shader

FRAMEBUFFER_SIZE 256 256
BIND BUFFER frame1 AS color LOCATION 0
END

PIPELINE graphics pipeline2
ATTACH vtex_shader
ATTACH frag_shader

FRAMEBUFFER_SIZE 256 256
BIND BUFFER frame2 AS color LOCATION 0
END

CLEAR pipeline1
CLEAR pipeline2
# Draw a rect to different parts of the frame buffers.
# Histograms will still match.
RUN pipeline1 DRAW_RECT POS 0 0 SIZE 128 128
RUN pipeline2 DRAW_RECT POS 128 128 SIZE 128 128

EXPECT frame1 EMD_BUFFER frame2 TOLERANCE 0.1