Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
Fix TFLite 2.9 tests (apache#12130)
Browse files Browse the repository at this point in the history
This pr fixes the tests that will be broken when we will update TFLite to
the 2.9 version.

We will update TensorFlow and TFLite versions to 2.9 so that we can
benefit from improvements in packaging to support multiple platforms
and Operating Systems.
  • Loading branch information
Nicola Lancellotti authored and xinetzone committed Nov 25, 2022
1 parent 09bc192 commit 1ebbec8
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 11 deletions.
8 changes: 5 additions & 3 deletions python/tvm/relay/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,9 +635,11 @@ def _convert_pooling(
_op.nn.global_max_pool2d(inexpr, **global_pool_params), keras_layer, etab, data_layout
)
if pool_type == "GlobalAveragePooling2D":
return _convert_flatten(
_op.nn.global_avg_pool2d(inexpr, **global_pool_params), keras_layer, etab, data_layout
)
global_avg_pool2d = _op.nn.global_avg_pool2d(inexpr, **global_pool_params)
keep_dims = len(keras_layer.input.shape) == len(keras_layer.output.shape)
if keep_dims:
return global_avg_pool2d
return _convert_flatten(global_avg_pool2d, keras_layer, etab, data_layout)
pool_h, pool_w = keras_layer.pool_size
stride_h, stride_w = keras_layer.strides
params = {
Expand Down
33 changes: 25 additions & 8 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,6 +963,10 @@ def representative_data_gen():
input_node = subgraph.Tensors(model_input).Name().decode("utf-8")

tflite_output = run_tflite_graph(tflite_model_quant, data)
if tf.__version__ < LooseVersion("2.9"):
input_node = data_in.name.replace(":0", "")
else:
input_node = "serving_default_" + data_in.name + ":0"
tvm_output = run_tvm_graph(tflite_model_quant, data, input_node)
tvm.testing.assert_allclose(
np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-2, atol=1e-2
Expand Down Expand Up @@ -1997,10 +2001,12 @@ def _test_abs(data, quantized, int_quant_dtype=tf.int8):
# TFLite 2.6.x upgrade support
if tf.__version__ < LooseVersion("2.6.1"):
in_node = ["serving_default_input_int8"]
else:
elif tf.__version__ < LooseVersion("2.9"):
in_node = (
["serving_default_input_int16"] if int_quant_dtype == tf.int16 else ["tfl.quantize"]
)
else:
in_node = "serving_default_input"

tvm_output = run_tvm_graph(tflite_model_quant, data, in_node)
tvm.testing.assert_allclose(
Expand Down Expand Up @@ -2028,8 +2034,10 @@ def _test_rsqrt(data, quantized, int_quant_dtype=tf.int8):
tf.math.rsqrt, data, int_quant_dtype=int_quant_dtype
)
tflite_output = run_tflite_graph(tflite_model_quant, data)
in_node = ["tfl.quantize"]

if tf.__version__ < LooseVersion("2.9"):
in_node = ["tfl.quantize"]
else:
in_node = "serving_default_input"
tvm_output = run_tvm_graph(tflite_model_quant, data, in_node)
tvm.testing.assert_allclose(
np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2
Expand Down Expand Up @@ -2110,7 +2118,10 @@ def _test_cos(data, quantized, int_quant_dtype=tf.int8):
tf.math.cos, data, int_quant_dtype=int_quant_dtype
)
tflite_output = run_tflite_graph(tflite_model_quant, data)
in_node = ["tfl.quantize"]
if tf.__version__ < LooseVersion("2.9"):
in_node = ["tfl.quantize"]
else:
in_node = "serving_default_input"
tvm_output = run_tvm_graph(tflite_model_quant, data, in_node)
tvm.testing.assert_allclose(
np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2
Expand Down Expand Up @@ -3024,7 +3035,6 @@ def _test_quantize_dequantize(data):
add = tf.keras.layers.Add()([data_in, relu])
concat = tf.keras.layers.Concatenate(axis=0)([relu, add])
keras_model = tf.keras.models.Model(inputs=data_in, outputs=concat)
input_name = data_in.name.split(":")[0]

# To create quantized values with dynamic range of activations, needs representative dataset
def representative_data_gen():
Expand All @@ -3034,7 +3044,11 @@ def representative_data_gen():
tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen, True, True)

tflite_output = run_tflite_graph(tflite_model_quant, data)
tvm_output = run_tvm_graph(tflite_model_quant, data, input_name)
if tf.__version__ < LooseVersion("2.9"):
in_node = data_in.name.split(":")[0]
else:
in_node = "serving_default_" + data_in.name + ":0"
tvm_output = run_tvm_graph(tflite_model_quant, data, in_node)
tvm.testing.assert_allclose(
np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2
)
Expand All @@ -3051,7 +3065,6 @@ def _test_quantize_dequantize_const(data):
add = tf.keras.layers.Add()([data, relu])
concat = tf.keras.layers.Concatenate(axis=0)([relu, add])
keras_model = tf.keras.models.Model(inputs=data_in, outputs=concat)
input_name = data_in.name.split(":")[0]

# To create quantized values with dynamic range of activations, needs representative dataset
def representative_data_gen():
Expand All @@ -3061,7 +3074,11 @@ def representative_data_gen():
tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen, True, True)

tflite_output = run_tflite_graph(tflite_model_quant, data)
tvm_output = run_tvm_graph(tflite_model_quant, data, input_name)
if tf.__version__ < LooseVersion("2.9"):
in_node = data_in.name.split(":")[0]
else:
in_node = "serving_default_" + data_in.name + ":0"
tvm_output = run_tvm_graph(tflite_model_quant, data, in_node)
tvm.testing.assert_allclose(
np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2
)
Expand Down

0 comments on commit 1ebbec8

Please sign in to comment.