# [Converter] Support GRU operator conversion with separated_rnn_gate_calc=False #1013
# This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
# Workflow display name.
name: Unit tests

# Run for pushes to main and for pull requests that target main.
on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main
jobs:
  # Lint, style, unit and integration tests across the OS / PyTorch matrix.
  unit-test:
    name: Unit tests for ${{ matrix.os }}, PyTorch ${{ matrix.vers.pt_ver }}
    runs-on: ${{ matrix.os }}
    strategy:
      # Keep the remaining matrix jobs running when one combination fails.
      fail-fast: false
      matrix:
        os: ["ubuntu-latest", "macos-13", "windows-latest"]
        vers:
          - pt_ver: "1.6.0"
            tv_ver: "0.7.0"
          - pt_ver: "latest"
            tv_ver: "latest"
        # Extra entry on top of the os x vers product above.
        include:
          - os: macos-latest
            vers:
              pt_ver: latest
              tv_ver: latest
    defaults:
      run:
        # Login shell, so the conda environment created by setup-miniconda
        # is active in every `run` step.
        shell: bash -l {0}
    steps:
      - uses: actions/checkout@v4
      - uses: conda-incubator/setup-miniconda@v3
        with:
          auto-update-conda: true
          # Quoted so the version is passed as a string, not a YAML float.
          python-version: "3.8"
      # Package set depends on the runner OS and on whether the matrix asks
      # for pinned versions or "latest":
      #   - macOS installs without `cpuonly`; Windows and Linux add it.
      #   - pillow=6 is pinned on Windows, and on macOS for torchvision
      #     0.9.x / 0.10.x (a branch the current matrix values do not reach).
      - name: Install PyTorch
        env:
          PYTORCH_VER: ${{ matrix.vers.pt_ver }}
          TORCHVISION_VER: ${{ matrix.vers.tv_ver }}
        run: |
          if [ "$RUNNER_OS" == "macOS" ]; then
            if [[ "$TORCHVISION_VER" == "latest" && "$PYTORCH_VER" == "latest" ]]; then
              conda install pytorch::pytorch torchvision -c pytorch
            elif [[ "$TORCHVISION_VER" == "0.9."* || "$TORCHVISION_VER" == "0.10."* ]]; then
              conda install pytorch=$PYTORCH_VER torchvision=$TORCHVISION_VER pillow=6 -c pytorch
            else
              conda install pytorch=$PYTORCH_VER torchvision=$TORCHVISION_VER -c pytorch
            fi
          elif [ "$RUNNER_OS" == "Windows" ]; then
            if [[ "$TORCHVISION_VER" == "latest" && "$PYTORCH_VER" == "latest" ]]; then
              conda install pytorch torchvision cpuonly pillow=6 -c pytorch
            else
              conda install pytorch=$PYTORCH_VER torchvision=$TORCHVISION_VER cpuonly pillow=6 -c pytorch
            fi
          else
            if [[ "$TORCHVISION_VER" == "latest" && "$PYTORCH_VER" == "latest" ]]; then
              conda install pytorch torchvision cpuonly -c pytorch
            else
              conda install pytorch=$PYTORCH_VER torchvision=$TORCHVISION_VER cpuonly -c pytorch
            fi
          fi
      - name: Install TinyNeuralNetwork
        run: pip install .
      # The two branches differ only in the TensorFlow pin: Linux with the
      # latest PyTorch/torchvision gets 'tensorflow<2.12', everything else
      # gets an unpinned tensorflow.
      - name: Install test dependencies
        env:
          PYTORCH_VER: ${{ matrix.vers.pt_ver }}
          TORCHVISION_VER: ${{ matrix.vers.tv_ver }}
        run: |
          if [[ "$RUNNER_OS" == "Linux" && "$TORCHVISION_VER" == "latest" && "$PYTORCH_VER" == "latest" ]]; then
            pip install black 'ruff<0.0.234' 'tensorflow<2.12' scipy interval
          else
            pip install black 'ruff<0.0.234' tensorflow scipy interval
          fi
      - name: Lint checks
        run: python -m ruff .
      # NOTE(review): without `--check`, black rewrites files in the runner's
      # checkout and exits 0, so this step does not fail on unformatted code.
      # Confirm whether `python -m black --check .` was intended.
      - name: Style checks
        run: python -m black .
      - name: Run unit tests
        run: |
          cd tests
          python import_test.py -v
          python converter_op_test.py -v
          python converter_optimizer_test.py -v
          python quantizer_test.py -v
          python modifier_prune_test.py -v
          python modifier_test.py -v
          python cross_layer_equalization_test.py -v
      # Only the mobilenet cases are selected (-k mobilenet). The `out`
      # directory is cleared before each test script, and the QAT converter
      # test is skipped on Windows runners.
      # NOTE(review): activate.sh is sourced from the working directory and
      # is not visible here; presumably it lives at the repository root.
      - name: Run integration tests
        run: |
          source activate.sh
          cd tests
          rm -rf out || true
          python converter_test.py -v -k mobilenet
          if [ "$RUNNER_OS" != "Windows" ]; then
            rm -rf out || true
            python converter_qat_test.py -v -k mobilenet
          fi
          rm -rf out || true
          python tracer_test.py -v -k mobilenet

  # Regenerates the generated docs and fails if the committed copies are stale.
  misc-check:
    name: Misc checks
    runs-on: ubuntu-latest
    defaults:
      run:
        # Login shell, so the conda environment created by setup-miniconda
        # is active in every `run` step.
        shell: bash -l {0}
    steps:
      # Action versions are kept in line with the unit-test job.
      - uses: actions/checkout@v4
      - uses: conda-incubator/setup-miniconda@v3
        with:
          auto-update-conda: true
          # Quoted so the version is passed as a string, not a YAML float.
          python-version: "3.8"
      - name: Install PyTorch
        run: conda install pytorch torchvision cpuonly -c pytorch
      - name: Install TinyNeuralNetwork
        run: pip install .
      # Re-run the generator; any resulting working-tree diff means the
      # committed output is out of date, so print the diff and fail.
      - name: Check op matrix
        run: |
          python scripts/gen_op_docs.py
          if [[ $(git diff --stat) != '' ]]; then
            echo "OP matrix is not updated. Please run python scripts/gen_op_docs.py and update your commit"
            echo "Differences:"
            git --no-pager diff
            exit 1
          fi
      # Same stale-output check, for the quantization support document.
      - name: Check quantization support
        run: |
          python scripts/gen_quantized_docs.py
          if [[ $(git diff --stat) != '' ]]; then
            echo "Quantization support doc is not updated. Please run python scripts/gen_quantized_docs.py and update your commit"
            echo "Differences:"
            git --no-pager diff
            exit 1
          fi