Skip to content

Commit

Permalink
feat(compression): allocate resource variables in persistent buffer (#…
Browse files Browse the repository at this point in the history
…3013)

Allocate resource variables in a persistent buffer when the input
tensor is compressed. Extend tests to validate operation.

BUG=part of #2636
  • Loading branch information
rkuester authored Dec 16, 2024
1 parent b2f2718 commit 9a32964
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 12 deletions.
51 changes: 49 additions & 2 deletions tensorflow/lite/micro/kernels/assign_variable.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/lite/micro/micro_graph.h"
#include "tensorflow/lite/micro/micro_log.h"
#include "tensorflow/lite/micro/micro_resource_variable.h"
#include "tensorflow/lite/micro/micro_utils.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace tflite {
Expand All @@ -35,6 +36,20 @@ namespace {
constexpr int kInputVariableId = 0;
constexpr int kInputValue = 1;

#ifdef USE_TFLM_COMPRESSION

struct OpData {
// scratch buffer for compressed input tensor
int scratch_index;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
TFLITE_DCHECK(context->AllocatePersistentBuffer != nullptr);
return context->AllocatePersistentBuffer(context, sizeof(OpData));
}

#endif // USE_TFLM_COMPRESSION

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 0);
Expand Down Expand Up @@ -70,6 +85,17 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
context, input_value));
}

#ifdef USE_TFLM_COMPRESSION

TFLITE_DCHECK(node->user_data != nullptr);
OpData* data = static_cast<OpData*>(node->user_data);
// Compression scratch buffers.
// These will only be allocated if the tensor is compressed.
data->scratch_index =
micro_context->AllocateDecompressionScratchBuffer(node, kInputValue);

#endif // USE_TFLM_COMPRESSION

micro_context->DeallocateTempTfLiteTensor(input_value);
return kTfLiteOk;
}
Expand All @@ -93,15 +119,36 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
"ResourceVariables and pass it to the interpreter.");
return kTfLiteError;
}

#ifdef USE_TFLM_COMPRESSION
OpData* data = static_cast<OpData*>(node->user_data);
const CompressionTensorData* comp_td =
micro_context->GetTensorCompressionData(node, kInputValue);
const void* buffer = tflite::micro::GetTensorData<void>(
micro_context, input_value, comp_td, data->scratch_index);
#else // USE_TFLM_COMPRESSION
const void* buffer = tflite::micro::GetTensorData<void>(input_value);
#endif // USE_TFLM_COMPRESSION

TF_LITE_ENSURE_OK(context,
resources->Assign(input_id->data.i32[0], input_value));
resources->Assign(input_id->data.i32[0],
EvalTensorBytes(input_value), buffer));
return kTfLiteOk;
}

} // namespace.

#ifdef USE_TFLM_COMPRESSION

TFLMRegistration Register_ASSIGN_VARIABLE() {
return tflite::micro::RegisterOp(Init, Prepare, Eval);

#else // USE_TFLM_COMPRESSION

TFLMRegistration Register_ASSIGN_VARIABLE() {
return tflite::micro::RegisterOp(nullptr, Prepare, Eval);

#endif // USE_TFLM_COMPRESSION
}

} // namespace tflite
11 changes: 6 additions & 5 deletions tensorflow/lite/micro/micro_resource_variable.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -113,8 +113,8 @@ TfLiteStatus MicroResourceVariables::Allocate(int id, TfLiteContext* context,
return kTfLiteOk;
}

TfLiteStatus MicroResourceVariables::Assign(int id,
const TfLiteEvalTensor* tensor) {
TfLiteStatus MicroResourceVariables::Assign(int id, size_t count_bytes,
const void* input_buffer) {
if (id < 0 || id >= num_resource_variables_) {
MicroPrintf("Attempting to read non-existent resource variable %d", id);
return kTfLiteError;
Expand All @@ -128,8 +128,9 @@ TfLiteStatus MicroResourceVariables::Assign(int id,
"with a TfLiteTensor first.");
return kTfLiteError;
}
TFLITE_DCHECK(EvalTensorBytes(tensor) == variable.bytes);
memcpy(variable.resource_buffer, tensor->data.raw, variable.bytes);
TFLITE_DCHECK(count_bytes == variable.bytes);
TFLITE_DCHECK(input_buffer != nullptr);
memcpy(variable.resource_buffer, input_buffer, variable.bytes);
return kTfLiteOk;
}

Expand Down
6 changes: 3 additions & 3 deletions tensorflow/lite/micro/micro_resource_variable.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -46,10 +46,10 @@ class MicroResourceVariables {
TfLiteStatus Allocate(int id, TfLiteContext* context,
const TfLiteTensor* tensor);

// Copies input tensor contents to the resource buffer.
// Copies input_buffer contents to the resource buffer.
// AllocateResourceVariable with a TFLite tensor must have been called first
// in order to allocate the resource buffer.
TfLiteStatus Assign(int id, const TfLiteEvalTensor* tensor);
TfLiteStatus Assign(int id, size_t count_bytes, const void* input_buffer);

// Zeros out all resource buffers.
TfLiteStatus ResetAll();
Expand Down
7 changes: 5 additions & 2 deletions tensorflow/lite/micro/micro_resource_variable_test.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/lite/micro/micro_resource_variable.h"

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/micro/micro_utils.h"
#include "tensorflow/lite/micro/test_helpers.h"
#include "tensorflow/lite/micro/testing/micro_test.h"

Expand Down Expand Up @@ -120,7 +121,9 @@ TF_LITE_MICRO_TEST(VerifyAssignAndReadResourceBuffer) {

.type = kTfLiteFloat32,
};
resource_variables->Assign(id, &assign_tensor);
resource_variables->Assign(
id, tflite::EvalTensorBytes(&assign_tensor),
tflite::micro::GetTensorData<void>(&assign_tensor));

int32_t buffer[32];
TfLiteEvalTensor read_tensor = {
Expand Down

0 comments on commit 9a32964

Please sign in to comment.