Skip to content

Commit

Permalink
feat: Origin Tags part 3 (Memory) (#9758)
Browse files Browse the repository at this point in the history
This PR:
1. Adds Origin Tags for tracking dangerous interactions to all stdlib
memory primitives
2. Expands  the tests from TwinRomTable
3. Fixes a bug with the use of nonnormalized value.
  • Loading branch information
Rumata888 authored Nov 7, 2024
1 parent ae7cfe7 commit d77e473
Show file tree
Hide file tree
Showing 10 changed files with 389 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "../bool/bool.hpp"
#include "../circuit_builders/circuit_builders.hpp"
#include "barretenberg/circuit_checker/circuit_checker.hpp"
#include "barretenberg/transcript/origin_tag.hpp"

using namespace bb;

Expand All @@ -21,6 +22,60 @@ using field_ct = stdlib::field_t<Builder>;
using witness_ct = stdlib::witness_t<Builder>;
using DynamicArray_ct = stdlib::DynamicArray<Builder>;

STANDARD_TESTING_TAGS

/**
* @brief Check that tags in Dynamic array are propagated correctly
*
*/
TEST(DynamicArray, TagCorrectness)
{

Builder builder;
const size_t max_size = 4;

DynamicArray_ct array(&builder, max_size);

// Create random entries
field_ct entry_1 = witness_ct(&builder, bb::fr::random_element());
field_ct entry_2 = witness_ct(&builder, bb::fr::random_element());
field_ct entry_3 = witness_ct(&builder, bb::fr::random_element());
field_ct entry_4 = witness_ct(&builder, bb::fr::random_element());

// Assign a different tag to each entry
entry_1.set_origin_tag(submitted_value_origin_tag);
entry_2.set_origin_tag(challenge_origin_tag);
entry_3.set_origin_tag(next_challenge_tag);
// Entry 4 has an "instant death" tag, that triggers an exception when merged with another tag
entry_4.set_origin_tag(instant_death_tag);

// Fill out the dynamic array with the first 3 entries
array.push(entry_1);
array.push(entry_2);
array.push(entry_3);

// Check that the tags are preserved
EXPECT_EQ(array.read(1).get_origin_tag(), challenge_origin_tag);
EXPECT_EQ(array.read(2).get_origin_tag(), next_challenge_tag);
EXPECT_EQ(array.read(0).get_origin_tag(), submitted_value_origin_tag);
// Update an element of the array
array.write(0, entry_2);
// Check that the tag changed
EXPECT_EQ(array.read(0).get_origin_tag(), challenge_origin_tag);

#ifndef NDEBUG
// Check that "instant death" happens when an "instant death"-tagged element is taken from the array and added to
// another one
array.pop();
array.pop();
array.push(entry_4);
array.push(entry_2);
array.push(entry_3);

EXPECT_THROW(array.read(witness_ct(&builder, 1)) + array.read(witness_ct(&builder, 2)), std::runtime_error);
#endif
}

TEST(DynamicArray, DynamicArrayReadWriteConsistency)
{

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#include "ram_table.hpp"

#include "../circuit_builders/circuit_builders.hpp"
#include "barretenberg/numeric/uint256/uint256.hpp"
#include "barretenberg/transcript/origin_tag.hpp"
#include <vector>

namespace bb::stdlib {

Expand Down Expand Up @@ -52,6 +55,12 @@ template <typename Builder> ram_table<Builder>::ram_table(const std::vector<fiel
// if this is the case we might not have a valid pointer to a Builder
// We get around this, by initializing the table when `read` or `write` operator is called
// with a non-const field element.

// Store tags
_tags.resize(_length);
for (size_t i = 0; i < _length; i++) {
_tags[i] = table_entries[i].get_origin_tag();
}
}

/**
Expand All @@ -69,7 +78,6 @@ template <typename Builder> void ram_table<Builder>::initialize_table() const
return;
}
ASSERT(_context != nullptr);

_ram_id = _context->create_RAM_array(_length);

if (_raw_entries.size() > 0) {
Expand All @@ -88,6 +96,13 @@ template <typename Builder> void ram_table<Builder>::initialize_table() const
}
}

// Store the tags of the original entries
_tags.resize(_length);
if (_raw_entries.size() > 0) {
for (size_t i = 0; i < _length; i++) {
_tags[i] = _raw_entries[i].get_origin_tag();
}
}
_ram_table_generated_in_builder = true;
}

Expand All @@ -100,6 +115,7 @@ template <typename Builder> void ram_table<Builder>::initialize_table() const
template <typename Builder>
ram_table<Builder>::ram_table(const ram_table& other)
: _raw_entries(other._raw_entries)
, _tags(other._tags)
, _index_initialized(other._index_initialized)
, _length(other._length)
, _ram_id(other._ram_id)
Expand All @@ -117,6 +133,7 @@ ram_table<Builder>::ram_table(const ram_table& other)
template <typename Builder>
ram_table<Builder>::ram_table(ram_table&& other)
: _raw_entries(other._raw_entries)
, _tags(other._tags)
, _index_initialized(other._index_initialized)
, _length(other._length)
, _ram_id(other._ram_id)
Expand All @@ -135,6 +152,7 @@ ram_table<Builder>::ram_table(ram_table&& other)
template <typename Builder> ram_table<Builder>& ram_table<Builder>::operator=(const ram_table& other)
{
_raw_entries = other._raw_entries;
_tags = other._tags;
_length = other._length;
_ram_id = other._ram_id;
_index_initialized = other._index_initialized;
Expand All @@ -161,6 +179,7 @@ template <typename Builder> ram_table<Builder>& ram_table<Builder>::operator=(ra
_ram_table_generated_in_builder = other._ram_table_generated_in_builder;
_all_entries_written_to_with_constant_index = other._all_entries_written_to_with_constant_index;
_context = other._context;
_tags = other._tags;
return *this;
}

Expand All @@ -176,8 +195,8 @@ template <typename Builder> field_t<Builder> ram_table<Builder>::read(const fiel
if (_context == nullptr) {
_context = index.get_context();
}

if (uint256_t(index.get_value()) >= _length) {
const auto native_index = uint256_t(index.get_value());
if (native_index >= _length) {
// TODO: what's best practise here? We are assuming that this action will generate failing constraints,
// and we set failure message here so that it better describes the point of failure.
// However, we are not *ensuring* that failing constraints are generated at the point that `failure()` is
Expand All @@ -197,8 +216,15 @@ template <typename Builder> field_t<Builder> ram_table<Builder>::read(const fiel
index_wire = field_pt::from_witness_index(_context, _context->put_constant_variable(index.get_value()));
}

uint32_t output_idx = _context->read_RAM_array(_ram_id, index_wire.normalize().get_witness_index());
return field_pt::from_witness_index(_context, output_idx);
uint32_t output_idx = _context->read_RAM_array(_ram_id, index_wire.get_normalized_witness_index());
auto element = field_pt::from_witness_index(_context, output_idx);

const size_t cast_index = static_cast<size_t>(static_cast<uint64_t>(native_index));
// If the index is legitimate, restore the tag
if (native_index < _length) {
element.set_origin_tag(_tags[cast_index]);
}
return element;
}

/**
Expand All @@ -224,7 +250,7 @@ template <typename Builder> void ram_table<Builder>::write(const field_pt& index

initialize_table();
field_pt index_wire = index;
auto native_index = index.get_value();
const auto native_index = uint256_t(index.get_value());
if (index.is_constant()) {
// need to write every array element at a constant index before doing reads/writes at prover-defined indices
index_wire = field_pt::from_witness_index(_context, _context->put_constant_variable(native_index));
Expand All @@ -247,7 +273,13 @@ template <typename Builder> void ram_table<Builder>::write(const field_pt& index

_index_initialized[cast_index] = true;
} else {
_context->write_RAM_array(_ram_id, index_wire.normalize().get_witness_index(), value_wire.get_witness_index());
_context->write_RAM_array(
_ram_id, index_wire.get_normalized_witness_index(), value_wire.get_normalized_witness_index());
}
// Update the value of the stored tag, if index is legitimate

if (native_index < _length) {
_tags[cast_index] = value.get_origin_tag();
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once
#include "../circuit_builders/circuit_builders_fwd.hpp"
#include "../field/field.hpp"
#include "barretenberg/transcript/origin_tag.hpp"

namespace bb::stdlib {

Expand Down Expand Up @@ -48,6 +49,8 @@ template <typename Builder> class ram_table {

private:
std::vector<field_pt> _raw_entries;
// Origin Tags for detection of dangerous interactions within stdlib primitives
mutable std::vector<OriginTag> _tags;
mutable std::vector<bool> _index_initialized;
size_t _length = 0;
mutable size_t _ram_id = 0; // Builder identifier for this ROM table
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "barretenberg/circuit_checker/circuit_checker.hpp"
#include "barretenberg/numeric/random/engine.hpp"
#include "barretenberg/stdlib_circuit_builders/ultra_circuit_builder.hpp"
#include "barretenberg/transcript/origin_tag.hpp"
#include "ram_table.hpp"

using namespace bb;
Expand All @@ -15,8 +16,58 @@ using ram_table_ct = stdlib::ram_table<Builder>;
namespace {
auto& engine = numeric::get_debug_randomness();
}
STANDARD_TESTING_TAGS

/**
* @brief Check that Origin Tags within the ram table are propagated correctly (when we lookup an element it has the
* same tag as the one inserted originally)
*
*/
TEST(RamTable, TagCorrectness)
{

Builder builder;
std::vector<field_ct> table_values;

// Generate random witnesses
field_ct entry_1 = witness_ct(&builder, bb::fr::random_element());
field_ct entry_2 = witness_ct(&builder, bb::fr::random_element());
field_ct entry_3 = witness_ct(&builder, bb::fr::random_element());

// Tag them with 3 different tags
entry_1.set_origin_tag(submitted_value_origin_tag);
entry_2.set_origin_tag(challenge_origin_tag);
// The last tag is an instant death tag, that triggers a runtime failure if any computation happens on the element
entry_3.set_origin_tag(instant_death_tag);

table_values.emplace_back(entry_1);
table_values.emplace_back(entry_2);
table_values.emplace_back(entry_3);

// Initialize the table
ram_table_ct table(table_values);

// Check that each element has the same tag as original entries
EXPECT_EQ(table.read(field_ct(0)).get_origin_tag(), submitted_value_origin_tag);
EXPECT_EQ(table.read(field_ct(witness_ct(&builder, 0))).get_origin_tag(), submitted_value_origin_tag);
EXPECT_EQ(table.read(field_ct(1)).get_origin_tag(), challenge_origin_tag);
EXPECT_EQ(table.read(field_ct(witness_ct(&builder, 1))).get_origin_tag(), challenge_origin_tag);

// Replace one of the elements in the table with a new one
entry_2.set_origin_tag(next_challenge_tag);
table.write(field_ct(1), entry_2);

// Check that the tag has been updated accordingly
EXPECT_EQ(table.read(field_ct(1)).get_origin_tag(), next_challenge_tag);
EXPECT_EQ(table.read(field_ct(witness_ct(&builder, 1))).get_origin_tag(), next_challenge_tag);

#ifndef NDEBUG
// Check that interacting with the poisoned element causes a runtime error
EXPECT_THROW(table.read(0) + table.read(2), std::runtime_error);
#endif
}

TEST(ram_table, ram_table_init_read_consistency)
TEST(RamTable, RamTableInitReadConsistency)
{
Builder builder;

Expand Down Expand Up @@ -50,7 +101,7 @@ TEST(ram_table, ram_table_init_read_consistency)
EXPECT_EQ(verified, true);
}

TEST(ram_table, ram_table_read_write_consistency)
TEST(RamTable, RamTableReadWriteConsistency)
{
Builder builder;
const size_t table_size = 10;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ template <typename Builder> rom_table<Builder>::rom_table(const std::vector<fiel
// if this is the case we might not have a valid pointer to a Builder
// We get around this, by initializing the table when `operator[]` is called
// with a non-const field element.

// Initialize tags
_tags.resize(raw_entries.size());
for (size_t i = 0; i < length; ++i) {
_tags[i] = raw_entries[i].get_origin_tag();
}
}

// initialize the table once we perform a read. This ensures we always have a valid
Expand All @@ -37,8 +43,11 @@ template <typename Builder> void rom_table<Builder>::initialize_table() const
// populate table. Table entries must be normalized and cannot be constants
for (const auto& entry : raw_entries) {
if (entry.is_constant()) {
entries.emplace_back(
field_pt::from_witness_index(context, context->put_constant_variable(entry.get_value())));
auto fixed_witness =
field_pt::from_witness_index(context, context->put_constant_variable(entry.get_value()));
fixed_witness.set_origin_tag(entry.get_origin_tag());
entries.emplace_back(fixed_witness);

} else {
entries.emplace_back(entry.normalize());
}
Expand All @@ -49,13 +58,19 @@ template <typename Builder> void rom_table<Builder>::initialize_table() const
context->set_ROM_element(rom_id, i, entries[i].get_witness_index());
}

// Preserve tags to restore them in the future lookups
_tags.resize(raw_entries.size());
for (size_t i = 0; i < length; ++i) {
_tags[i] = raw_entries[i].get_origin_tag();
}
initialized = true;
}

template <typename Builder>
rom_table<Builder>::rom_table(const rom_table& other)
: raw_entries(other.raw_entries)
, entries(other.entries)
, _tags(other._tags)
, length(other.length)
, rom_id(other.rom_id)
, initialized(other.initialized)
Expand All @@ -66,6 +81,7 @@ template <typename Builder>
rom_table<Builder>::rom_table(rom_table&& other)
: raw_entries(other.raw_entries)
, entries(other.entries)
, _tags(other._tags)
, length(other.length)
, rom_id(other.rom_id)
, initialized(other.initialized)
Expand All @@ -76,6 +92,7 @@ template <typename Builder> rom_table<Builder>& rom_table<Builder>::operator=(co
{
raw_entries = other.raw_entries;
entries = other.entries;
_tags = other._tags;
length = other.length;
rom_id = other.rom_id;
initialized = other.initialized;
Expand All @@ -87,6 +104,7 @@ template <typename Builder> rom_table<Builder>& rom_table<Builder>::operator=(ro
{
raw_entries = other.raw_entries;
entries = other.entries;
_tags = other._tags;
length = other.length;
rom_id = other.rom_id;
initialized = other.initialized;
Expand All @@ -112,13 +130,24 @@ template <typename Builder> field_t<Builder> rom_table<Builder>::operator[](cons
if (context == nullptr) {
context = index.get_context();
}

initialize_table();
if (uint256_t(index.get_value()) >= length) {
const auto native_index = uint256_t(index.get_value());
if (native_index >= length) {
context->failure("rom_table: ROM array access out of bounds");
}

uint32_t output_idx = context->read_ROM_array(rom_id, index.normalize().get_witness_index());
return field_pt::from_witness_index(context, output_idx);
uint32_t output_idx = context->read_ROM_array(rom_id, index.get_normalized_witness_index());
auto element = field_pt::from_witness_index(context, output_idx);

const size_t cast_index = static_cast<size_t>(static_cast<uint64_t>(native_index));

// If the index is legitimate, restore the tag
if (native_index < length) {

element.set_origin_tag(_tags[cast_index]);
}
return element;
}

template class rom_table<bb::UltraCircuitBuilder>;
Expand Down
Loading

0 comments on commit d77e473

Please sign in to comment.