Skip to content

Commit

Permalink
This is PR #12130.
Browse files Browse the repository at this point in the history
  • Loading branch information
leandron committed Jul 25, 2022
1 parent 15cf56e commit 8ae520b
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 12 deletions.
8 changes: 5 additions & 3 deletions python/tvm/relay/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,9 +635,11 @@ def _convert_pooling(
_op.nn.global_max_pool2d(inexpr, **global_pool_params), keras_layer, etab, data_layout
)
if pool_type == "GlobalAveragePooling2D":
return _convert_flatten(
_op.nn.global_avg_pool2d(inexpr, **global_pool_params), keras_layer, etab, data_layout
)
global_avg_pool2d = _op.nn.global_avg_pool2d(inexpr, **global_pool_params)
keep_dims = len(keras_layer.input.shape) == len(keras_layer.output.shape)
if keep_dims:
return global_avg_pool2d
return _convert_flatten(global_avg_pool2d, keras_layer, etab, data_layout)
pool_h, pool_w = keras_layer.pool_size
stride_h, stride_w = keras_layer.strides
params = {
Expand Down
35 changes: 26 additions & 9 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,11 @@ def representative_data_gen():
)

tflite_output = run_tflite_graph(tflite_model_quant, data)
tvm_output = run_tvm_graph(tflite_model_quant, data, data_in.name.replace(":0", ""))
if tf.__version__ < LooseVersion("2.9"):
input_node = data_in.name.replace(":0", "")
else:
input_node = "serving_default_" + data_in.name + ":0"
tvm_output = run_tvm_graph(tflite_model_quant, data, input_node)
tvm.testing.assert_allclose(
np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-2, atol=1e-2
)
Expand Down Expand Up @@ -1934,10 +1938,12 @@ def _test_abs(data, quantized, int_quant_dtype=tf.int8):
# TFLite 2.6.x upgrade support
if tf.__version__ < LooseVersion("2.6.1"):
in_node = ["serving_default_input_int8"]
else:
elif tf.__version__ < LooseVersion("2.9"):
in_node = (
["serving_default_input_int16"] if int_quant_dtype == tf.int16 else ["tfl.quantize"]
)
else:
in_node = "serving_default_input"

tvm_output = run_tvm_graph(tflite_model_quant, data, in_node)
tvm.testing.assert_allclose(
Expand Down Expand Up @@ -1965,8 +1971,10 @@ def _test_rsqrt(data, quantized, int_quant_dtype=tf.int8):
tf.math.rsqrt, data, int_quant_dtype=int_quant_dtype
)
tflite_output = run_tflite_graph(tflite_model_quant, data)
in_node = ["tfl.quantize"]

if tf.__version__ < LooseVersion("2.9"):
in_node = ["tfl.quantize"]
else:
in_node = "serving_default_input"
tvm_output = run_tvm_graph(tflite_model_quant, data, in_node)
tvm.testing.assert_allclose(
np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2
Expand Down Expand Up @@ -2047,7 +2055,10 @@ def _test_cos(data, quantized, int_quant_dtype=tf.int8):
tf.math.cos, data, int_quant_dtype=int_quant_dtype
)
tflite_output = run_tflite_graph(tflite_model_quant, data)
in_node = ["tfl.quantize"]
if tf.__version__ < LooseVersion("2.9"):
in_node = ["tfl.quantize"]
else:
in_node = "serving_default_input"
tvm_output = run_tvm_graph(tflite_model_quant, data, in_node)
tvm.testing.assert_allclose(
np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2
Expand Down Expand Up @@ -2955,7 +2966,6 @@ def _test_quantize_dequantize(data):
add = tf.keras.layers.Add()([data_in, relu])
concat = tf.keras.layers.Concatenate(axis=0)([relu, add])
keras_model = tf.keras.models.Model(inputs=data_in, outputs=concat)
input_name = data_in.name.split(":")[0]

# To create quantized values with dynamic range of activations, needs representative dataset
def representative_data_gen():
Expand All @@ -2965,7 +2975,11 @@ def representative_data_gen():
tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen, True, True)

tflite_output = run_tflite_graph(tflite_model_quant, data)
tvm_output = run_tvm_graph(tflite_model_quant, data, input_name)
if tf.__version__ < LooseVersion("2.9"):
in_node = data_in.name.split(":")[0]
else:
in_node = "serving_default_" + data_in.name + ":0"
tvm_output = run_tvm_graph(tflite_model_quant, data, in_node)
tvm.testing.assert_allclose(
np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2
)
Expand All @@ -2982,7 +2996,6 @@ def _test_quantize_dequantize_const(data):
add = tf.keras.layers.Add()([data, relu])
concat = tf.keras.layers.Concatenate(axis=0)([relu, add])
keras_model = tf.keras.models.Model(inputs=data_in, outputs=concat)
input_name = data_in.name.split(":")[0]

# To create quantized values with dynamic range of activations, needs representative dataset
def representative_data_gen():
Expand All @@ -2992,7 +3005,11 @@ def representative_data_gen():
tflite_model_quant = _quantize_keras_model(keras_model, representative_data_gen, True, True)

tflite_output = run_tflite_graph(tflite_model_quant, data)
tvm_output = run_tvm_graph(tflite_model_quant, data, input_name)
if tf.__version__ < LooseVersion("2.9"):
in_node = data_in.name.split(":")[0]
else:
in_node = "serving_default_" + data_in.name + ":0"
tvm_output = run_tvm_graph(tflite_model_quant, data, in_node)
tvm.testing.assert_allclose(
np.squeeze(tvm_output[0]), np.squeeze(tflite_output[0]), rtol=1e-5, atol=1e-2
)
Expand Down

0 comments on commit 8ae520b

Please sign in to comment.