Skip to content

Commit

Permalink
Adding support for reading binary files in tool --function_input=. (#…
Browse files Browse the repository at this point in the history
…9328)

Raw binary data read from a file can now be used to initialize buffer
view contents. No interpretation is done on the data. The shape and
element type information is still required.

Example:
```
iree-benchmark-module ... --function_input=4x2xi32=@some/file.bin
```
  • Loading branch information
benvanik authored Jun 4, 2022
1 parent f4189ca commit a30c840
Show file tree
Hide file tree
Showing 7 changed files with 231 additions and 14 deletions.
48 changes: 48 additions & 0 deletions runtime/src/iree/hal/string_util.c
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,54 @@ IREE_API_EXPORT iree_status_t iree_hal_format_element_type(
: iree_ok_status();
}

IREE_API_EXPORT iree_status_t iree_hal_parse_shape_and_element_type(
iree_string_view_t value, iree_host_size_t shape_capacity,
iree_hal_dim_t* out_shape, iree_host_size_t* out_shape_rank,
iree_hal_element_type_t* out_element_type) {
*out_shape_rank = 0;
*out_element_type = IREE_HAL_ELEMENT_TYPE_NONE;

// Strip whitespace that may come along (linefeeds/etc).
value = iree_string_view_trim(value);
value = iree_string_view_strip_prefix(value, IREE_SV("\""));
value = iree_string_view_strip_suffix(value, IREE_SV("\""));
if (iree_string_view_is_empty(value)) {
// Empty lines are invalid; need at least the shape/type information.
return iree_make_status(IREE_STATUS_INVALID_ARGUMENT, "empty string input");
}

// The part of the string corresponding to the shape, e.g. 1x2x3.
iree_string_view_t shape_str = iree_string_view_empty();
// The part of the string corresponding to the type, e.g. f32
iree_string_view_t type_str = iree_string_view_empty();
// The part of the string corresponding to the buffer data, e.g. 1 2 3 4 5 6
// We ignore this.
iree_string_view_t data_str = iree_string_view_empty();

iree_string_view_t shape_and_type_str = value;
iree_string_view_split(value, '=', &shape_and_type_str, &data_str);
iree_host_size_t last_x_index = iree_string_view_find_last_of(
shape_and_type_str, IREE_SV("x"), IREE_STRING_VIEW_NPOS);
if (last_x_index == IREE_STRING_VIEW_NPOS) {
// Scalar.
type_str = shape_and_type_str;
} else {
// Has a shape.
shape_str = iree_string_view_substr(shape_and_type_str, 0, last_x_index);
type_str = iree_string_view_substr(shape_and_type_str, last_x_index + 1,
IREE_STRING_VIEW_NPOS);
}

// AxBxC...
IREE_RETURN_IF_ERROR(iree_hal_parse_shape(shape_str, shape_capacity,
out_shape, out_shape_rank));

// f32, i32, etc
IREE_RETURN_IF_ERROR(iree_hal_parse_element_type(type_str, out_element_type));

return iree_ok_status();
}

// Parses a string of two character pairs representing hex numbers into bytes.
static void iree_hal_hex_string_to_bytes(const char* from, uint8_t* to,
ptrdiff_t num) {
Expand Down
8 changes: 8 additions & 0 deletions runtime/src/iree/hal/string_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ IREE_API_EXPORT iree_status_t iree_hal_format_element_type(
iree_hal_element_type_t element_type, iree_host_size_t buffer_capacity,
char* buffer, iree_host_size_t* out_buffer_length);

// Parses a shape and type from a `[shape]x[type]` string |value|.
// Behaves the same as calling iree_hal_parse_shape and
// iree_hal_parse_element_type. Ignores any training `=`.
IREE_API_EXPORT iree_status_t iree_hal_parse_shape_and_element_type(
iree_string_view_t value, iree_host_size_t shape_capacity,
iree_hal_dim_t* out_shape, iree_host_size_t* out_shape_rank,
iree_hal_element_type_t* out_element_type);

// Parses a serialized element of |element_type| to its in-memory form.
// |data_ptr| must be at least large enough to contain the bytes of the element.
// For example, "1.2" of type IREE_HAL_ELEMENT_TYPE_FLOAT32 will write the 4
Expand Down
78 changes: 75 additions & 3 deletions runtime/src/iree/hal/string_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ using Shape = std::vector<iree_hal_dim_t>;
StatusOr<Shape> ParseShape(const std::string& value) {
Shape shape(6);
iree_host_size_t actual_rank = 0;
iree_status_t status;
iree_status_t status = iree_ok_status();
do {
status =
iree_hal_parse_shape(iree_string_view_t{value.data(), value.size()},
Expand All @@ -50,7 +50,7 @@ StatusOr<Shape> ParseShape(const std::string& value) {
StatusOr<std::string> FormatShape(iree::span<const iree_hal_dim_t> value) {
std::string buffer(16, '\0');
iree_host_size_t actual_length = 0;
iree_status_t status;
iree_status_t status = iree_ok_status();
do {
status =
iree_hal_format_shape(value.data(), value.size(), buffer.size() + 1,
Expand All @@ -77,7 +77,7 @@ StatusOr<iree_hal_element_type_t> ParseElementType(const std::string& value) {
StatusOr<std::string> FormatElementType(iree_hal_element_type_t value) {
std::string buffer(16, '\0');
iree_host_size_t actual_length = 0;
iree_status_t status;
iree_status_t status = iree_ok_status();
do {
status = iree_hal_format_element_type(value, buffer.size() + 1, &buffer[0],
&actual_length);
Expand All @@ -87,6 +87,34 @@ StatusOr<std::string> FormatElementType(iree_hal_element_type_t value) {
return std::move(buffer);
}

struct ShapeAndType {
Shape shape;
iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE;
ShapeAndType() = default;
ShapeAndType(Shape shape, iree_hal_element_type_t element_type)
: shape(std::move(shape)), element_type(element_type) {}
};
static bool operator==(const ShapeAndType& lhs,
const ShapeAndType& rhs) noexcept {
return lhs.shape == rhs.shape && lhs.element_type == rhs.element_type;
}

// Parses a serialized set of shape dimensions and an element type.
StatusOr<ShapeAndType> ParseShapeAndElementType(const std::string& value) {
Shape shape(6);
iree_host_size_t actual_rank = 0;
iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE;
iree_status_t status = iree_ok_status();
do {
status = iree_hal_parse_shape_and_element_type(
iree_string_view_t{value.data(), value.size()}, shape.size(),
shape.data(), &actual_rank, &element_type);
shape.resize(actual_rank);
} while (iree_status_is_out_of_range(status));
IREE_RETURN_IF_ERROR(std::move(status));
return ShapeAndType(std::move(shape), element_type);
}

// Parses a serialized element of |element_type| to its in-memory form.
// |buffer| be at least large enough to contain the bytes of the element.
// For example, "1.2" of type IREE_HAL_ELEMENT_TYPE_FLOAT32 will write the 4
Expand Down Expand Up @@ -593,6 +621,50 @@ TEST(ElementTypeStringUtilTest, FormatElementType) {
IsOkAndHolds(Eq("f4")));
}

TEST(StringUtilTest, ParseShapeAndElementType) {
EXPECT_THAT(
ParseShapeAndElementType("1xi8"),
IsOkAndHolds(Eq(ShapeAndType(Shape{1}, IREE_HAL_ELEMENT_TYPE_INT_8))));
EXPECT_THAT(ParseShapeAndElementType("1x2xi16"),
IsOkAndHolds(
Eq(ShapeAndType(Shape{1, 2}, IREE_HAL_ELEMENT_TYPE_INT_16))));
EXPECT_THAT(
ParseShapeAndElementType("1x2x3x4x5x6x7x8x9xi32=invalid stuff here"),
IsOkAndHolds(Eq(ShapeAndType(Shape{1, 2, 3, 4, 5, 6, 7, 8, 9},
IREE_HAL_ELEMENT_TYPE_INT_32))));
}

TEST(StringUtilTest, ParseShapeAndElementTypeInvalid) {
EXPECT_THAT(ParseShapeAndElementType(""),
StatusIs(StatusCode::kInvalidArgument));
EXPECT_THAT(ParseShapeAndElementType("0"),
StatusIs(StatusCode::kInvalidArgument));
EXPECT_THAT(ParseShapeAndElementType("="),
StatusIs(StatusCode::kInvalidArgument));
EXPECT_THAT(ParseShapeAndElementType("abc"),
StatusIs(StatusCode::kInvalidArgument));
EXPECT_THAT(ParseShapeAndElementType("1xf"),
StatusIs(StatusCode::kInvalidArgument));
EXPECT_THAT(ParseShapeAndElementType("1xff23"),
StatusIs(StatusCode::kInvalidArgument));
EXPECT_THAT(ParseShapeAndElementType("1xn3"),
StatusIs(StatusCode::kInvalidArgument));
EXPECT_THAT(ParseShapeAndElementType("x"),
StatusIs(StatusCode::kInvalidArgument));
EXPECT_THAT(ParseShapeAndElementType("x1"),
StatusIs(StatusCode::kInvalidArgument));
EXPECT_THAT(ParseShapeAndElementType("1x"),
StatusIs(StatusCode::kInvalidArgument));
EXPECT_THAT(ParseShapeAndElementType("x1x2="),
StatusIs(StatusCode::kInvalidArgument));
EXPECT_THAT(ParseShapeAndElementType("1xx2"),
StatusIs(StatusCode::kInvalidArgument));
EXPECT_THAT(ParseShapeAndElementType("1x2x"),
StatusIs(StatusCode::kInvalidArgument));
EXPECT_THAT(ParseShapeAndElementType("0x-1"),
StatusIs(StatusCode::kInvalidArgument));
}

TEST(ElementStringUtilTest, ParseElement) {
EXPECT_THAT(ParseElement<int8_t>("-128"), IsOkAndHolds(Eq(INT8_MIN)));
EXPECT_THAT(ParseElement<int8_t>("127"), IsOkAndHolds(Eq(INT8_MAX)));
Expand Down
101 changes: 94 additions & 7 deletions runtime/src/iree/tools/utils/vm_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,88 @@

namespace iree {

Status ParseToVariantList(iree_hal_allocator_t* allocator,
// Creates a HAL buffer view with the given |metadata| and reads the contents
// from the file at |file_path|.
//
// The file contents are directly read in to memory with no processing.
static iree_status_t CreateBufferViewFromFile(
iree_string_view_t metadata, iree_string_view_t file_path,
iree_hal_allocator_t* device_allocator,
iree_hal_buffer_view_t** out_buffer_view) {
*out_buffer_view = NULL;

// Parse shape and element type used to allocate the buffer view.
iree_hal_element_type_t element_type = IREE_HAL_ELEMENT_TYPE_NONE;
iree_host_size_t shape_rank = 0;
iree_status_t shape_result = iree_hal_parse_shape_and_element_type(
metadata, 0, NULL, &shape_rank, &element_type);
if (!iree_status_is_ok(shape_result) &&
!iree_status_is_out_of_range(shape_result)) {
return shape_result;
} else if (shape_rank > 128) {
return iree_make_status(
IREE_STATUS_RESOURCE_EXHAUSTED,
"a shape rank of %zu is just a little bit excessive, eh?", shape_rank);
}
shape_result = iree_status_ignore(shape_result);
iree_hal_dim_t* shape =
(iree_hal_dim_t*)iree_alloca(shape_rank * sizeof(iree_hal_dim_t));
IREE_RETURN_IF_ERROR(iree_hal_parse_shape_and_element_type(
metadata, shape_rank, shape, &shape_rank, &element_type));

// TODO(benvanik): allow specifying the encoding.
iree_hal_encoding_type_t encoding_type =
IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR;

// Open the file for reading.
std::string file_path_str(file_path.data, file_path.size);
FILE* file = std::fopen(file_path_str.c_str(), "rb");
if (!file) {
return iree_make_status(iree_status_code_from_errno(errno),
"failed to open file '%.*s'", (int)file_path.size,
file_path.data);
}

iree_hal_buffer_params_t buffer_params = {0};
buffer_params.type = IREE_HAL_MEMORY_TYPE_DEVICE_LOCAL;
buffer_params.usage =
IREE_HAL_BUFFER_USAGE_DISPATCH | IREE_HAL_BUFFER_USAGE_TRANSFER;
struct read_params_t {
FILE* file;
} read_params = {
file,
};
iree_status_t status = iree_hal_buffer_view_generate_buffer(
device_allocator, shape, shape_rank, element_type, encoding_type,
buffer_params,
+[](iree_hal_buffer_mapping_t* mapping, void* user_data) {
auto* read_params = reinterpret_cast<read_params_t*>(user_data);
size_t bytes_read =
std::fread(mapping->contents.data, 1, mapping->contents.data_length,
read_params->file);
if (bytes_read != mapping->contents.data_length) {
return iree_make_status(IREE_STATUS_OUT_OF_RANGE,
"file contents truncated; expected %zu bytes "
"based on buffer view size",
mapping->contents.data_length);
}
return iree_ok_status();
},
&read_params, out_buffer_view);

std::fclose(file);

return status;
}

Status ParseToVariantList(iree_hal_allocator_t* device_allocator,
iree::span<const std::string> input_strings,
iree_vm_list_t** out_list) {
*out_list = NULL;
vm::ref<iree_vm_list_t> variant_list;
IREE_RETURN_IF_ERROR(
iree_vm_list_create(/*element_type=*/nullptr, input_strings.size(),
iree_allocator_system(), &variant_list));
IREE_RETURN_IF_ERROR(iree_vm_list_create(
/*element_type=*/nullptr, input_strings.size(),
iree_hal_allocator_host_allocator(device_allocator), &variant_list));
for (size_t i = 0; i < input_strings.size(); ++i) {
iree_string_view_t input_view = iree_string_view_trim(iree_make_string_view(
input_strings[i].data(), input_strings[i].size()));
Expand All @@ -43,9 +117,22 @@ Status ParseToVariantList(iree_hal_allocator_t* allocator,
bool is_storage_reference = iree_string_view_consume_prefix(
&input_view, iree_make_cstring_view("&"));
iree_hal_buffer_view_t* buffer_view = nullptr;
IREE_RETURN_IF_ERROR(
iree_hal_buffer_view_parse(input_view, allocator, &buffer_view),
"parsing value '%.*s'", (int)input_view.size, input_view.data);
bool has_at = iree_string_view_find_char(input_view, '@', 0) !=
IREE_STRING_VIEW_NPOS;
if (has_at) {
// Referencing an external file; split into the portion used to
// initialize the buffer view and the file contents.
iree_string_view_t metadata, file_path;
iree_string_view_split(input_view, '@', &metadata, &file_path);
iree_string_view_consume_suffix(&metadata, iree_make_cstring_view("="));
IREE_RETURN_IF_ERROR(CreateBufferViewFromFile(
metadata, file_path, device_allocator, &buffer_view));
} else {
IREE_RETURN_IF_ERROR(iree_hal_buffer_view_parse(
input_view, device_allocator, &buffer_view),
"parsing value '%.*s'", (int)input_view.size,
input_view.data);
}
if (is_storage_reference) {
// Storage buffer reference; just take the storage for the buffer view -
// it'll still have whatever contents were specified (or 0) but we'll
Expand Down
6 changes: 2 additions & 4 deletions runtime/src/iree/tools/utils/vm_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,9 @@ namespace iree {
// Buffers should be in the IREE standard shaped buffer format:
// [shape]xtype=[value]
// described in iree/hal/api.h
// Uses |allocator| to allocate the buffers.
// Uses descriptors in |descs| for type information and validation.
// Uses |device_allocator| to allocate the buffers.
// The returned variant list must be freed by the caller.
Status ParseToVariantList(iree_hal_allocator_t* allocator,
Status ParseToVariantList(iree_hal_allocator_t* device_allocator,
iree::span<const std::string> input_strings,
iree_vm_list_t** out_list);

Expand All @@ -43,7 +42,6 @@ Status ParseToVariantList(iree_hal_allocator_t* allocator,
// [shape]xtype=[value]
// described in
// https://github.com/google/iree/tree/main/iree/hal/api.h
// Uses descriptors in |descs| for type information and validation.
Status PrintVariantList(iree_vm_list_t* variant_list, size_t max_element_count,
std::ostream* os);
inline Status PrintVariantList(iree_vm_list_t* variant_list, std::ostream* os) {
Expand Down
2 changes: 2 additions & 0 deletions tools/iree-benchmark-module-main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ IREE_FLAG_CALLBACK(
" 2x2xi32=1 2 3 4\n"
"Optionally, brackets may be used to separate the element values:\n"
" 2x2xi32=[[1 2][3 4]]\n"
"Raw binary files can be read to provide buffer contents:\n"
" 2x2xi32=@some/file.bin\n"
"Each occurrence of the flag indicates an input in the order they were\n"
"specified on the command line.");

Expand Down
2 changes: 2 additions & 0 deletions tools/iree-run-module-main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ IREE_FLAG_CALLBACK(
" 2x2xi32=1 2 3 4\n"
"Optionally, brackets may be used to separate the element values:\n"
" 2x2xi32=[[1 2][3 4]]\n"
"Raw binary files can be read to provide buffer contents:\n"
" 2x2xi32=@some/file.bin\n"
"Each occurrence of the flag indicates an input in the order they were\n"
"specified on the command line.");

Expand Down

0 comments on commit a30c840

Please sign in to comment.