Discrepancies in ONNX Runtime Inference Results on RISC-V #22530

sarmentow opened this issue Oct 22, 2024 · 2 comments
Labels: contributions welcome (lower priority issues for the core ORT teams)

Comments

@sarmentow

Describe the issue

We have successfully built ONNX Runtime Python wheels targeting the RISC-V architecture, using both the cross-compilation process outlined in the documentation and an emulated RISC-V Docker container running Ubuntu 22.04. Both builds completed without errors.

However, after training a model in PyTorch and exporting it to the ONNX format, we observed that the inference results from the ONNX Runtime Python package vary significantly across platforms. Specifically, the results from the RISC-V wheels we built (both the cross-compiled and the emulated versions) do not match the expected outputs seen from running inference in PyTorch before the ONNX export, nor do they match the outputs produced by the ONNX Runtime x64 wheel on the same model.

This leads us to believe that the issue lies in ONNX Runtime's support for RISC-V.

Example Outputs

To illustrate the discrepancy, after training a PyTorch model, we get the following outputs for the input [0] when using the pre-built ONNX Runtime wheels for x64:

[array([[ -9.021126,  17.9599  , -18.350208, -11.425449]], dtype=float32)]

In contrast, the output from the RISC-V wheel for the same model and input is:

[array([[  5.5013514 , -13.528254  ,  -8.2745905 ,  -0.89257914]], dtype=float32)]

Both outputs are from the same model, using the same input, highlighting the inconsistency.
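The discrepancy changes the predicted class as well as the raw logits. The sketch below decodes both vectors the same way the comparison script does (argmax over the four logits) and measures how far apart they are. It uses plain Python and only the two vectors quoted above; the helper names are illustrative.

```python
# Logits for input [0], copied from the two outputs quoted above.
x64_logits = [-9.021126, 17.9599, -18.350208, -11.425449]
riscv_logits = [5.5013514, -13.528254, -8.2745905, -0.89257914]

def argmax(values):
    """Return the index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)

def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two equal-length lists."""
    return max(abs(x - y) for x, y in zip(a, b))

print("x64 class:", argmax(x64_logits))       # 1, the training label for input 0
print("RISC-V class:", argmax(riscv_logits))  # 0
print("max |diff|: %.4f" % max_abs_diff(x64_logits, riscv_logits))  # 31.4882
```

On x64, the output for input 0 decodes to class 1, which matches its training label. On RISC-V it decodes to class 0. The problem is therefore a wrong prediction, not a small floating-point drift.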

Investigation

Through extensive troubleshooting, we found that the discrepancy appears specifically when the model contains torch.nn.Linear layers. Models built only from basic arithmetic operators (+, -, *, /) produce matching results. We are fairly confident the problem lies in ONNX Runtime, because the same model saved in PyTorch's .pth format and run with PyTorch in the RISC-V environment gives the expected outputs.
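Because nn.Linear is the trigger, it can help to compute the expected output on the RISC-V machine without relying on ONNX Runtime or the NumPy wheels from the alternative index. For a 2-D input, torch.onnx typically lowers nn.Linear to an ONNX Gemm node computing Y = X·Wᵀ + b (transB=1). The dependency-free sketch below assumes that lowering. It also assumes the weights and biases have been extracted separately, for example from model.state_dict() or from the graph initializers via onnx.numpy_helper.to_array. The function names and toy weights are hypothetical and do not come from the trained model.

```python
def gemm_trans_b(x, weight, bias):
    """Reference for ONNX Gemm with alpha=1, beta=1, transA=0, transB=1.

    x:      rows of shape (N, in_features)
    weight: PyTorch layout, shape (out_features, in_features)
    bias:   shape (out_features,)
    Returns Y[n][o] = sum_i x[n][i] * weight[o][i] + bias[o].
    """
    return [
        [sum(xi * wi for xi, wi in zip(row, w_row)) + b
         for w_row, b in zip(weight, bias)]
        for row in x
    ]

def relu(x):
    """Element-wise max(0, v) over a 2-D list."""
    return [[max(0.0, v) for v in row] for row in x]

def mlp_forward(x, layers):
    """Gemm followed by ReLU for every layer except the last, which has no activation.

    layers: list of (weight, bias) pairs, mirroring fc1, fc2, fc3 in Net.
    """
    for weight, bias in layers[:-1]:
        x = relu(gemm_trans_b(x, weight, bias))
    weight, bias = layers[-1]
    return gemm_trans_b(x, weight, bias)

# Hypothetical toy network 1 -> 2 -> 2, used only to exercise the functions.
toy_layers = [
    ([[1.0], [-1.0]], [0.5, 0.5]),            # hidden layer: 1 -> 2, then ReLU
    ([[2.0, 0.0], [0.0, 3.0]], [0.0, -1.0]),  # output layer: 2 -> 2, no ReLU
]
print(mlp_forward([[2.0]], toy_layers))  # [[5.0, -1.0]]
```

If mlp_forward, fed the real fc1/fc2/fc3 weights, agrees with the x64 wheel but not with session.run on RISC-V, the Gemm (or MatMul) kernel path in the RISC-V build becomes the prime suspect. Check the exported node's transB attribute before relying on the weight layout assumed here.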

Reproduction

Below are the PyTorch training code, the Dockerfile for the build environment, and the script used to compare inference results across platforms.

Model Training Code

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np

np.random.seed(69)
torch.manual_seed(69)

# Define the dataset class
class CustomDataset(Dataset):
    def __init__(self, inputs, outputs):
        self.inputs = inputs
        self.outputs = outputs

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        input_val = self.inputs[idx]
        output_val = self.outputs[idx]
        return input_val, output_val

# Define the model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(1, 20)  # input layer (1) -> hidden layer (20)
        self.fc2 = nn.Linear(20, 20)  # hidden layer (20) -> hidden layer (20)
        self.fc3 = nn.Linear(20, 4)  # hidden layer (20) -> output layer (4)

    def forward(self, x):
        x = torch.relu(self.fc1(x))  # activation function for hidden layer
        x = torch.relu(self.fc2(x))  # activation function for hidden layer
        x = self.fc3(x)
        return x

# Define the inputs and outputs
inputs = np.array([0, 4, 8, 9, 10, 14, 15])
outputs = np.array([1, 1, 2, 2, 1, 2, 2])

# Create the dataset and data loader
dataset = CustomDataset(inputs.reshape(-1, 1), outputs)
data_loader = DataLoader(dataset, batch_size=7, shuffle=False)

# Initialize the model, loss function, and optimizer
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Train the model
for epoch in range(1000):  # loop over the dataset multiple times
    for i, data in enumerate(data_loader, 0):
        inputs, labels = data
        inputs = inputs.float()  # DataLoader already yields tensors; just cast the dtype
        labels = labels.long()   # CrossEntropyLoss expects int64 class indices

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    # print statistics
    if epoch % 100 == 0:
        print('Epoch %d, Loss: %.3f' % (epoch+1, loss.item()))

# Evaluate the model
def evaluate_model(model, inputs, expected_outputs):
    inputs = torch.tensor(inputs.reshape(-1, 1), dtype=torch.float32)
    outputs = model(inputs)
    _, predicted = torch.max(outputs, dim=1)
    print('Expected Outputs: ', expected_outputs)
    print('Predicted Outputs: ', predicted.detach().numpy())

# Test the model
test_inputs = np.array([0, 4, 8, 9, 10, 14, 15])
test_outputs = np.array([1, 1, 2, 2, 1, 1, 2])
evaluate_model(model, test_inputs, test_outputs)

# Export the model to ONNX
dummy_input = torch.randn(1, 1)
torch.onnx.export(model, dummy_input, 'winner.onnx', export_params=True, opset_version=17)

Dockerfile for Build Environment

FROM --platform=linux/riscv64 ubuntu:22.04
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update && apt-get install -y \
    build-essential \
    git \
    python3 \
    python3-pip \
    python3-dev \
    python3-numpy \
    libssl-dev \
    wget \
    && apt-get clean
WORKDIR /workspace
CMD ["/bin/bash"]

We also built CMake from source to meet ONNX Runtime's minimum CMake version requirement. To save you the hassle, I have pushed an image with CMake 3.30 installed to Docker Hub: docker pull sarmentow/onnxruntime-build-env-with-cmake

ONNX Runtime Inference Comparison Code

import onnxruntime as ort
import numpy as np
MODEL_INPUT_SHAPE = (1,1)
model_path = "winner.onnx"  # exported by the training script; copied into the test image by the Dockerfile
session = ort.InferenceSession(model_path)
input_names = [session.get_inputs()[0].name]
output_names = [session.get_outputs()[0].name]

for i in [0, 4, 8, 9, 10, 14, 15]:
    print("Input:", i)
    outputs = session.run(output_names, {input_names[0]: np.array([i], dtype=np.float32).reshape(MODEL_INPUT_SHAPE)})
    print(outputs)
    print(np.argmax(outputs))

Build Process

We used the following command to build the ONNX Runtime wheel for RISC-V (build.py is tools/ci_build/build.py in the ONNX Runtime repository):

python3 build.py --parallel 6 --config MinSizeRel --skip_tests --cmake_extra_defines CMAKE_CXX_FLAGS="-pthread -latomic ${CMAKE_CXX_FLAGS}" CMAKE_C_FLAGS="-pthread -latomic ${CMAKE_C_FLAGS}" CMAKE_EXE_LINKER_FLAGS="-pthread -latomic" --enable_pybind --build_wheel --wheel_name_suffix=riscv --build_dir=/workspace/onnxruntime/build/older --compile_no_warning_as_error --allow_running_as_root

Testing Environment

We utilized the following Dockerfile for the testing environment:

FROM --platform=linux/riscv64 ubuntu:22.04

ARG DEBIAN_FRONTEND=noninteractive
RUN <<EOF
set -e
apt-get update
apt-get install -y --no-install-recommends \
  busybox-static=1:1.30.1-7ubuntu3 \
  python3 python3-pip python3-dev \
  libatomic1 \
  libopenblas-dev
EOF

WORKDIR /workspace
COPY ./requirements.txt .
COPY ./winner.onnx .

RUN <<EOF
set -e
pip install -r requirements.txt
EOF


COPY ./simple_inference.py .

ENTRYPOINT ["bash"]

Dependencies

We installed ONNX Runtime and other necessary packages as listed in the requirements.txt file:

https://github.com/sarmentow/riscv-wheels/raw/refs/heads/main/onnxruntime_riscv-1.19.2-cp310-cp310-linux_riscv64.whl
-i https://think-and-dev.github.io/riscv-python-wheels/pip-index/
numpy == 1.26.2

We used an alternative pip index to install NumPy RISC-V wheels, which we believe are not causing the issue.

Conclusion

Based on our testing, the issue appears to be specific to ONNX Runtime's support for RISC-V, in particular for models containing torch.nn.Linear layers. A single linear layer is enough to reproduce the discrepancy between platforms. We hope this information helps in diagnosing the problem, and we are happy to assist further if needed.

Thank you for your attention to this matter. We look forward to your insights.

Urgency

The issue is urgent as my team depends on this functionality to ship a project this week. We'd be extremely grateful for some attention on this.

Target platform

RISC-V

Build script

The build ran inside a container based on the sarmentow/onnxruntime-build-env-with-cmake Docker image.

Error / output

See the outputs under Example Outputs above (input [0], x64 wheel vs. RISC-V wheel).

Visual Studio Version

No response

GCC / Compiler Version

11.4.0

sarmentow added the "build" label (build issues; typically submitted using template) on Oct 22, 2024.
snnn added the "contributions welcome" label (lower priority issues for the core ORT teams) and removed the "build" label on Oct 22, 2024.
sarmentow changed the title from "[Build] Discrepancies in ONNX Runtime Inference Results on RISC-V" to "Discrepancies in ONNX Runtime Inference Results on RISC-V" on Oct 22, 2024.
@sarmentow
Author

Okay, I just ran the tests in onnxruntime/test/python/onnxruntime_test_python.py, and four of them fail: test_memory_arena_shrinkage, test_run_model2, test_run_model2_contiguous, and test_run_model_symbolic_input.

All three test_run_* tests fail with the same assertion error, which appears to corroborate the very different results described in this issue:

AssertionError:
Not equal to tolerance rtol=1e-05, atol=1e-08

Mismatched elements: 3 / 3 (100%)
Max absolute difference: 12.
Max relative difference: 4.
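The assertion above comes from numpy.testing.assert_allclose. It treats an element as mismatched when |actual - desired| > atol + rtol * |desired|, then reports the largest absolute and relative differences. The sketch below re-implements that check in plain Python, so results can be checked on a machine where the NumPy build itself might be in doubt. The function name and sample values are hypothetical and are not the data used by test_run_model2. NaN handling, which NumPy treats specially, is ignored here.

```python
import math

def allclose_report(actual, desired, rtol=1e-5, atol=1e-8):
    """Approximate numpy.testing.assert_allclose's tolerance check on flat lists.

    An element passes when |actual - desired| <= atol + rtol * |desired|.
    Returns (mismatched, total, max_abs_diff, max_rel_diff).
    """
    if len(actual) != len(desired):
        raise ValueError("inputs must have the same length")
    diffs = [abs(a - d) for a, d in zip(actual, desired)]
    mismatched = sum(
        1 for diff, d in zip(diffs, desired) if not diff <= atol + rtol * abs(d)
    )
    # Relative difference is only defined where the desired value is non-zero.
    rel = [diff / abs(d) for diff, d in zip(diffs, desired) if d != 0]
    max_rel = max(rel) if rel else math.inf
    return mismatched, len(desired), max(diffs), max_rel

# Hypothetical three-element outputs, chosen only to exercise the check.
m, n, max_abs, max_rel = allclose_report([1.0, 2.0, 9.0], [1.0, 2.000001, 3.0])
print("Mismatched elements: %d / %d" % (m, n))
print("Max absolute difference:", max_abs)
print("Max relative difference:", max_rel)
```

With the default rtol=1e-5 and atol=1e-8, only near-identical values pass. A difference of 12 on every element, as reported above, points to a computation error, not a rounding difference between architectures.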

Will look into it later.

@pttuan

pttuan commented Oct 25, 2024

I'm experiencing the same issue when trying to use the exported ONNX model from: https://github.com/IBM/ai-on-z-fraud-detection/blob/main/ccf_220_keras_lstm_static-OS.ipynb

The model runs fine on x86 architecture, but on RISC-V, it produces completely incorrect outputs using the same code and input data.
