
models/export.py setting model.train() may be changed to false in onnx #3346

Closed
ChaofWang opened this issue May 26, 2021 · 6 comments · Fixed by #3362
Labels
bug Something isn't working


@ChaofWang
Contributor

🐛 Bug

I think the intent of `--train` is to skip grid construction in the Detect layer by calling `model.train()` before exporting to ONNX. However, `torch.onnx.export` appears to reset the model to eval mode by default, so `model.training` ends up `False` during export.
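A minimal illustration of why calling `model.train()` beforehand has no effect: `train()`/`eval()` only flip the `training` flag on every submodule, so whichever mode is set during tracing is the one the Detect layer sees. `temporarily_in_mode` is a hypothetical helper that mimics an exporter switching modes internally; it is not PyTorch's own code.

```python
import contextlib

import torch.nn as nn


@contextlib.contextmanager
def temporarily_in_mode(model: nn.Module, training: bool):
    """Hypothetical helper: force train/eval mode inside the block, then restore it.

    A simplified stand-in for an exporter that switches modes internally; it is not
    PyTorch's own implementation.
    """
    was_training = model.training
    model.train(training)  # recursively sets .training on every submodule
    try:
        yield model
    finally:
        # restores the top-level flag on all submodules (mixed per-module modes are not preserved)
        model.train(was_training)


model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))
model.train()  # what `--train` does before export
with temporarily_in_mode(model, training=False):  # what an EVAL-mode export does
    print(model.training, model[1].training)  # False False
print(model.training)  # True again afterwards
```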

The `training` parameter of `torch.onnx.export` can be set to `torch.onnx.TrainingMode.TRAINING` so that the model is exported in train mode, instead of relying on `model.train()`. I'm not sure this is the recommended approach, though. I also found that this part behaved correctly before version 4.0, but as of version 5.0 the `export` attribute has been removed from Detect.

To Reproduce (REQUIRED)

Input:

python models/export.py --weights yolov5s.pt --img 640 --batch 1 --train

Output:

TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if self.grid[i].shape[2:4] != x[i].shape[2:4] or self.onnx_dynamic:
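This is PyTorch's standard tracer warning: when a Python `if` depends on a tensor value, the tracer records only the branch taken for the example input. The toy function below (`scale_if_positive` is hypothetical and unrelated to Detect) shows the effect with `torch.jit.trace`.

```python
import warnings

import torch


def scale_if_positive(x: torch.Tensor) -> torch.Tensor:
    # Python `if` on a tensor: the tracer can only record the branch that was taken
    if x.sum() > 0:
        return x * 2
    return x


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    traced = torch.jit.trace(scale_if_positive, torch.ones(2))  # traced with a positive input

neg = -torch.ones(2)
print(scale_if_positive(neg))  # eager: takes the `return x` branch
print(traced(neg))             # traced: the `x * 2` branch is frozen in
```

Presumably the grid-shape check in Detect is frozen the same way when the inference branch runs during export.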

Expected behavior

This inference branch (grid construction and box decoding) should be skipped when exporting with `--train`:

yolov5/models/yolo.py

Lines 50 to 62 in aad99b6

```python
if not self.training:  # inference
    if self.grid[i].shape[2:4] != x[i].shape[2:4] or self.onnx_dynamic:
        self.grid[i] = self._make_grid(nx, ny).to(x[i].device)
    y = x[i].sigmoid()
    if self.inplace:
        y[..., 0:2] = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
        y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
    else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
        xy = (y[..., 0:2] * 2. - 0.5 + self.grid[i]) * self.stride[i]  # xy
        wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i].view(1, self.na, 1, 1, 2)  # wh
        y = torch.cat((xy, wh, y[..., 4:]), -1)
    z.append(y.view(bs, -1, self.no))
```
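To show what `self.training` gates, here is a minimal single-level sketch. `ToyDetect`, its anchors and its stride are hypothetical; it mirrors only the non-inplace decode path quoted above and omits multi-level handling and `onnx_dynamic`. In train mode the raw head output is returned untouched; in eval mode the grid is built and boxes are decoded.

```python
import torch
import torch.nn as nn


class ToyDetect(nn.Module):
    """Hypothetical single-level stand-in for YOLOv5's Detect head (not the real class)."""

    def __init__(self, nc: int = 80, stride: float = 8.0):
        super().__init__()
        self.na, self.no = 3, nc + 5  # anchors per cell, outputs per anchor
        self.stride = stride
        self.grid = torch.zeros(1)  # placeholder; rebuilt on the first inference call
        # assumed anchor sizes in pixels, shaped to broadcast over (bs, na, ny, nx, 2)
        anchors = torch.tensor([[10., 13.], [16., 30.], [33., 23.]])
        self.register_buffer("anchor_grid", anchors.view(1, self.na, 1, 1, 2))

    @staticmethod
    def _make_grid(nx: int, ny: int) -> torch.Tensor:
        # (1, 1, ny, nx, 2) tensor holding the (x, y) index of every grid cell
        yv, xv = torch.meshgrid(torch.arange(ny), torch.arange(nx), indexing="ij")
        return torch.stack((xv, yv), 2).view(1, 1, ny, nx, 2).float()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bs, _, ny, nx = x.shape
        # (bs, na*no, ny, nx) -> (bs, na, ny, nx, no)
        x = x.view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
        if self.training:
            return x  # raw head output: no grid, no decoding
        if self.grid.shape[2:4] != x.shape[2:4]:
            self.grid = self._make_grid(nx, ny).to(x.device)
        y = x.sigmoid()
        xy = (y[..., 0:2] * 2. - 0.5 + self.grid) * self.stride  # box centers in pixels
        wh = (y[..., 2:4] * 2) ** 2 * self.anchor_grid  # box sizes in pixels
        y = torch.cat((xy, wh, y[..., 4:]), -1)
        return y.view(bs, -1, self.no)


head = ToyDetect(nc=2)
feat = torch.randn(1, 3 * 7, 4, 5)  # 3 anchors x (2 classes + 5) channels on a 4x5 map
print(head.train()(feat).shape)  # torch.Size([1, 3, 4, 5, 7])
print(head.eval()(feat).shape)   # torch.Size([1, 60, 7])
```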

Environment


  • OS: Ubuntu 18.04
  • GPU: 3080
  • PyTorch: 1.8.0+cu111
  • ONNX: 1.9.0
@ChaofWang ChaofWang added the bug Something isn't working label May 26, 2021
@github-actions
Contributor

github-actions bot commented May 26, 2021

👋 Hello @ChaofWang, thank you for your interest in 🚀 YOLOv5! Please visit our ⭐️ Tutorials to get started, where you can find quickstart guides for simple tasks like Custom Data Training all the way to advanced concepts like Hyperparameter Evolution.

If this is a 🐛 Bug Report, please provide screenshots and minimum viable code to reproduce your issue, otherwise we can not help you.

If this is a custom training ❓ Question, please provide as much information as possible, including dataset images, training logs, screenshots, and a public link to online W&B logging if available.

For business inquiries or professional support requests please visit https://www.ultralytics.com or email Glenn Jocher at glenn.jocher@ultralytics.com.

Requirements

Python 3.8 or later with all requirements.txt dependencies installed, including torch>=1.7. To install run:

$ pip install -r requirements.txt

Environments

YOLOv5 may be run in any of the following up-to-date verified environments (with all dependencies including CUDA/CUDNN, Python and PyTorch preinstalled):

Status

CI CPU testing

If this badge is green, all YOLOv5 GitHub Actions Continuous Integration (CI) tests are currently passing. CI tests verify correct operation of YOLOv5 training (train.py), testing (test.py), inference (detect.py) and export (export.py) on MacOS, Windows, and Ubuntu every 24 hours and on every commit.

@glenn-jocher
Member

@ChaofWang hi, thanks for the bug report! You are correct, ONNX export is not operating in train mode for some reason. It seems the ONNX exporter is forcing it back into eval mode.

@glenn-jocher
Member

glenn-jocher commented May 26, 2021

@ChaofWang I've implemented a solution per your recommendation by adding two arguments to the `torch.onnx.export()` call in export.py L100:

```python
training=torch.onnx.TrainingMode.TRAINING if opt.train else torch.onnx.TrainingMode.EVAL,
do_constant_folding=not opt.train,
```
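A minimal sketch of how the `--train` flag can drive these two arguments. `parse_opt` and `onnx_mode_kwargs` are hypothetical names for illustration, only `--train` is parsed, and the actual `torch.onnx.export` call is left as a comment because `model`, `img` and `f` are assumed to come from the surrounding export.py code.

```python
import argparse

import torch


def parse_opt(argv=None):
    # only the --train flag from the report; export.py's other options are omitted
    parser = argparse.ArgumentParser()
    parser.add_argument("--train", action="store_true", help="export in training mode")
    return parser.parse_args(argv)


def onnx_mode_kwargs(train: bool) -> dict:
    """Hypothetical helper returning the two torch.onnx.export arguments from the fix."""
    return {
        # TRAINING exports with modules in train mode; EVAL switches them to eval mode
        "training": torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
        # the exporter warns when constant folding is left on for a training-mode export
        "do_constant_folding": not train,
    }


opt = parse_opt(["--train"])
kwargs = onnx_mode_kwargs(opt.train)
# torch.onnx.export(model, img, f, opset_version=12, **kwargs)  # model, img, f assumed
```

Returning a plain dict keeps the flag-to-argument mapping testable without running an export.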

Results are here:
[Screenshot, 2021-05-26 12:27:07: export results]

If this solution works for you please submit a PR with this update, thank you!

@glenn-jocher
Member

TODO: ONNX export in .train() mode fix

@ChaofWang
Contributor Author

@glenn-jocher hi, this solution works for me. I have submitted a PR for this update.

@glenn-jocher glenn-jocher linked a pull request May 27, 2021 that will close this issue
@glenn-jocher glenn-jocher removed the TODO label May 27, 2021
@glenn-jocher
Member

@ChaofWang good news 😃! Your original issue may now be fixed ✅ in PR #3362. To receive this update:

  • Git – `git pull` from within your yolov5/ directory or `git clone https://github.com/ultralytics/yolov5` again
  • PyTorch Hub – Force-reload with `model = torch.hub.load('ultralytics/yolov5', 'yolov5s', force_reload=True)`
  • Notebooks – View updated notebooks (Open In Colab, Open In Kaggle)
  • Docker – `sudo docker pull ultralytics/yolov5:latest` to update your image

Thank you for spotting this issue and informing us of the problem. Please let us know if this update resolves the issue for you, and feel free to inform us of any other issues you discover or feature requests that come to mind. Happy trainings with YOLOv5 🚀!

2 participants