Skip to content

Commit

Permalink
[CI] Apply linting rules to caffe tests
Browse files Browse the repository at this point in the history
  • Loading branch information
blackkker committed Jul 7, 2022
1 parent 17b8425 commit 0d6e77d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 16 deletions.
4 changes: 2 additions & 2 deletions tests/lint/pylint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@ python3 -m pylint python/tvm --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint vta/python/vta --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/unittest/test_tvmscript_type.py --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/contrib/test_cmsisnn --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/relay/aot/*.py --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/ci --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/frontend/caffe/test_forward.py --rcfile="$(dirname "$0")"/pylintrc

25 changes: 11 additions & 14 deletions tests/python/frontend/caffe/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,9 @@
This article is a test script to test Caffe operator with Relay.
"""
import os

os.environ["GLOG_minloglevel"] = "2"
import sys
import logging

logging.basicConfig(level=logging.ERROR)

import numpy as np

from google.protobuf import text_format
import caffe
from caffe import layers as L, params as P
Expand All @@ -37,9 +32,13 @@
import tvm
import tvm.testing
from tvm import relay
from tvm.contrib import utils, graph_executor
from tvm.contrib import graph_executor
from tvm.contrib.download import download_testdata

os.environ["GLOG_minloglevel"] = "2"

logging.basicConfig(level=logging.ERROR)

CURRENT_DIR = os.path.join(os.path.expanduser("~"), ".tvm_test_data", "caffe_test")

#######################################################################
Expand All @@ -57,7 +56,8 @@ def _list_to_str(ll):
"""Convert list or tuple to str, separated by underline."""
if isinstance(ll, (tuple, list)):
tmp = [str(i) for i in ll]
return "_".join(tmp)
res = "_".join(tmp)
return res


def _gen_filename_str(op_name, data_shape, *args, **kwargs):
Expand Down Expand Up @@ -221,11 +221,7 @@ def _run_tvm(data, proto_file, blob_file):


def _compare_caffe_tvm(caffe_out, tvm_out, is_network=False):
<<<<<<< HEAD
for i in range(len(caffe_out)):
=======
for i, _ in enumerate(caffe_out):
>>>>>>> 2b7ec741f... fix error
if is_network:
caffe_out[i] = caffe_out[i][:1]
tvm.testing.assert_allclose(caffe_out[i], tvm_out[i], rtol=1e-5, atol=1e-5)
Expand Down Expand Up @@ -965,8 +961,9 @@ def _test_embed(data, **kwargs):


def test_forward_Embed():
"""Embed"""
k = 20
data = [i for i in range(k)]
data = list(i for i in range(k))
np.random.shuffle(data)
# dimension is 1
data = np.asarray(data)
Expand Down Expand Up @@ -1197,4 +1194,4 @@ def test_forward_Inceptionv1():
test_forward_Mobilenetv2()
test_forward_Alexnet()
test_forward_Resnet50()
test_forward_Inceptionv1()
test_forward_Inceptionv1()

0 comments on commit 0d6e77d

Please sign in to comment.