We have successfully built ONNX Runtime Python wheels targeting the RISC-V architecture, using both the cross-compilation process outlined in the documentation and an emulated RISC-V Docker container running Ubuntu 22.04. Both builds completed without errors.
However, after training a model in PyTorch and exporting it to the ONNX format, we observed that the inference results from the ONNX Runtime Python package vary significantly across platforms. Specifically, the results from the RISC-V wheels we built (both the cross-compiled and the emulated versions) do not match the expected outputs seen from running inference in PyTorch before the ONNX export, nor do they match the outputs produced by the ONNX Runtime x64 wheel on the same model.
This leads us to believe that the issue lies in ONNX Runtime's support for RISC-V.
Example Outputs
To illustrate the discrepancy, after training a PyTorch model, we get the following output for the input `[0]` when using the pre-built ONNX Runtime wheel for x64:
In contrast, the output from the RISC-V wheel for the same model and input is:
Both outputs come from the same model and the same input, which highlights the inconsistency.
Investigation
Through extensive troubleshooting, we have identified that this discrepancy occurs specifically when using `torch.nn.Linear` layers. Basic arithmetic operators (e.g., `+`, `-`, `*`, `/`) do not cause any issues. Furthermore, saving the model in PyTorch's `.pth` format and running inference on it in the RISC-V environment works as expected, so we are fairly sure the problem lies in ONNX Runtime's handling of RISC-V rather than in the model itself.
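For 2-D inputs such as the `(1, 1)` tensor used here, PyTorch's TorchScript-based exporter typically lowers `nn.Linear` to an ONNX `Gemm` node with `transB=1`, i.e. it computes y = x·Wᵀ + b. The pure-Python sketch below computes the same thing and can serve as a backend-independent reference when checking a single layer by hand. The helpers `linear` and `relu` are hypothetical illustrations, not part of PyTorch or ONNX Runtime.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def linear(x: Sequence[Sequence[float]],
           weight: Sequence[Sequence[float]],
           bias: Sequence[float]) -> Matrix:
    """Compute y = x @ weight.T + bias.

    This is the math of nn.Linear and of ONNX Gemm with transB=1,
    alpha=1, beta=1.
      x:      batch_size x in_features
      weight: out_features x in_features (PyTorch's layout)
      bias:   out_features
    """
    return [
        [sum(xi * wi for xi, wi in zip(row, w_row)) + b
         for w_row, b in zip(weight, bias)]
        for row in x
    ]


def relu(m: Matrix) -> Matrix:
    """Element-wise max(0, v), matching torch.relu / ONNX Relu."""
    return [[max(0.0, v) for v in row] for row in m]
```

For example, `linear([[0.0]], model.fc1.weight.detach().tolist(), model.fc1.bias.detach().tolist())` gives the pre-activation of `fc1` for the input `[0]`. For that input it is simply the bias vector, which makes it an easy value to compare across platforms.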
Reproduction
We have included the PyTorch training code, the Dockerfile for the build environment, and the scripts used to compare inference results between the platforms below.
Model Training Code
```python
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

np.random.seed(69)
torch.manual_seed(69)


# Define the dataset class
class CustomDataset(Dataset):
    def __init__(self, inputs, outputs):
        self.inputs = inputs
        self.outputs = outputs

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        input_val = self.inputs[idx]
        output_val = self.outputs[idx]
        return input_val, output_val


# Define the model
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 20)   # input layer (1) -> hidden layer (20)
        self.fc2 = nn.Linear(20, 20)  # hidden layer (20) -> hidden layer (20)
        self.fc3 = nn.Linear(20, 4)   # hidden layer (20) -> output layer (4)

    def forward(self, x):
        x = torch.relu(self.fc1(x))  # activation function for hidden layer
        x = torch.relu(self.fc2(x))  # activation function for hidden layer
        x = self.fc3(x)
        return x


# Define the inputs and outputs
inputs = np.array([0, 4, 8, 9, 10, 14, 15])
outputs = np.array([1, 1, 2, 2, 1, 2, 2])

# Create the dataset and data loader
dataset = CustomDataset(inputs.reshape(-1, 1), outputs)
data_loader = DataLoader(dataset, batch_size=7, shuffle=False)

# Initialize the model, loss function, and optimizer
model = Net()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Train the model
for epoch in range(1000):  # loop over the dataset multiple times
    for i, data in enumerate(data_loader, 0):
        # The DataLoader already yields tensors, so only the dtypes need converting
        batch_inputs, labels = data
        batch_inputs = batch_inputs.to(torch.float32)
        labels = labels.to(torch.long)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        logits = model(batch_inputs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        if epoch % 100 == 0:
            print('Epoch %d, Loss: %.3f' % (epoch + 1, loss.item()))


# Evaluate the model
def evaluate_model(model, inputs, expected_outputs):
    inputs = torch.tensor(inputs.reshape(-1, 1), dtype=torch.float32)
    with torch.no_grad():
        outputs = model(inputs)
    _, predicted = torch.max(outputs, dim=1)
    print('Expected Outputs: ', expected_outputs)
    print('Predicted Outputs: ', predicted.numpy())


# Test the model
test_inputs = np.array([0, 4, 8, 9, 10, 14, 15])
test_outputs = np.array([1, 1, 2, 2, 1, 1, 2])
evaluate_model(model, test_inputs, test_outputs)

# Export the model to ONNX (without dynamic_axes, the input shape is fixed to (1, 1))
model.eval()
dummy_input = torch.randn(1, 1)
torch.onnx.export(model, dummy_input, 'winner.onnx', export_params=True, opset_version=17)
```
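A minimal sketch of how the exported `winner.onnx` could be run through ONNX Runtime on each platform. This is our own illustration, not the authors' comparison script. It assumes `numpy` and `onnxruntime` are installed, and the helper names `run_onnx` and `argmax` are hypothetical.

```python
from typing import List, Sequence


def run_onnx(model_path: str, value: float) -> List[float]:
    """Run one scalar input through the exported model with ONNX Runtime's
    CPU execution provider and return the raw logits as Python floats."""
    # Imported lazily so the pure helper below works without these packages.
    import numpy as np
    import onnxruntime as ort

    session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    input_name = session.get_inputs()[0].name  # name chosen by torch.onnx.export
    # The export used a (1, 1) dummy input without dynamic_axes, so feed shape (1, 1).
    x = np.array([[value]], dtype=np.float32)
    (logits,) = session.run(None, {input_name: x})  # the model has a single output
    return [float(v) for v in logits[0]]


def argmax(values: Sequence[float]) -> int:
    """Index of the largest logit (the predicted class); first index wins on ties."""
    return max(range(len(values)), key=lambda i: values[i])
```

Printing `run_onnx("winner.onnx", v)` and its `argmax` for each `v` in `[0, 4, 8, 9, 10, 14, 15]` on both the x64 and the RISC-V machine gives two sets of logits that can be diffed directly. Creating a session per call keeps the sketch short. For many inputs, one would reuse a single session instead.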
Dockerfile for Build Environment
We also built CMake from source to meet the version requirement for building ONNX Runtime. To save you the hassle, I have pushed an image with CMake 3.30 installed to Docker Hub: `docker pull sarmentow/onnxruntime-build-env-with-cmake`.
Dependencies
We installed ONNX Runtime and the other necessary packages listed in the `requirements.txt` file. We used an alternative pip index to install NumPy RISC-V wheels, which we do not believe are causing the issue.
Conclusion
Based on our testing, the issue appears to be specific to ONNX Runtime's RISC-V support, particularly for layers such as `torch.nn.Linear`: a single linear layer is enough to reproduce the discrepancy between platforms. We hope this information helps in diagnosing the problem, and we are happy to assist further if needed.
Thank you for your attention to this matter. We look forward to your insights.
Urgency
The issue is urgent as my team depends on this functionality to ship a project this week. We'd be extremely grateful for some attention on this.
Target platform
RISC-V
Build script
We used the following command to build the ONNX Runtime wheel for RISC-V (with the build.py script at `tools/ci_build/build.py`):
This was run inside a container using the Docker image `sarmentow/onnxruntime-build-env-with-cmake`.
Error / output
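To illustrate the discrepancy, after training a PyTorch model, we get the following output for the input `[0]` when using the pre-built ONNX Runtime wheel for x64:
In contrast, the output from the RISC-V wheel for the same model and input is:
Both outputs come from the same model and the same input, which highlights the inconsistency.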
sarmentow changed the title from "[Build] Discrepancies in ONNX Runtime Inference Results on RISC-V" to "Discrepancies in ONNX Runtime Inference Results on RISC-V" on Oct 22, 2024.
Okay, just ran the tests in onnxruntime/test/python/onnxruntime_test_python.py and we're failing test_memory_arena_shrinkage, test_run_model2, test_run_model2_contiguous, and test_run_model_symbolic_input.
All of the test_run tests fail with the same error, which appears to corroborate the very different results described in the issue:
```
AssertionError:
Not equal to tolerance rtol=1e-05, atol=1e-08

Mismatched elements: 3 / 3 (100%)
Max absolute difference: 12.
Max relative difference: 4.
```
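That message has the format of `numpy.testing.assert_allclose`, under which an element passes when |actual − desired| ≤ atol + rtol·|desired|. With rtol=1e-05, a maximum absolute difference of 12 is far outside the tolerance for values of ordinary magnitude. The rough pure-Python re-implementation below shows how the three summary lines are derived and can be reused to compare logits from two platforms without NumPy. The name `allclose_report` is hypothetical, and NumPy's exact handling of zeros, NaNs and infinities differs in detail.

```python
from typing import Sequence, Tuple


def allclose_report(actual: Sequence[float], desired: Sequence[float],
                    rtol: float = 1e-05,
                    atol: float = 1e-08) -> Tuple[int, float, float]:
    """Approximate the element-wise check of numpy.testing.assert_allclose.

    An element matches when |actual - desired| <= atol + rtol * |desired|.
    Returns (mismatched count, max absolute difference, max relative difference).
    """
    if len(actual) != len(desired):
        raise ValueError("actual and desired must have the same length")
    abs_diffs = [abs(a - d) for a, d in zip(actual, desired)]
    mismatched = sum(
        1 for diff, d in zip(abs_diffs, desired) if diff > atol + rtol * abs(d)
    )
    # The relative difference is only defined where the expected value is nonzero.
    rel_diffs = [diff / abs(d) for diff, d in zip(abs_diffs, desired) if d != 0]
    return mismatched, max(abs_diffs, default=0.0), max(rel_diffs, default=0.0)
```

For instance, `allclose_report([1.0, 2.5, 3.0], [1.0, 2.0, 1.5])` returns `(2, 1.5, 1.0)`. Two of the three elements mismatch, with a maximum absolute difference of 1.5 and a maximum relative difference of 1.0.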
Visual Studio Version
No response
GCC / Compiler Version
11.4.0