
PyTorch 1.0 implementation of the approximate Earth Mover's Distance


meder411/PyTorch-EMDLoss


PyTorch EMDLoss


This is a PyTorch wrapper around CUDA code that computes an approximation of the Earth Mover's Distance (EMD) loss.

The original source code can be found here. This repository updates that code for compatibility with PyTorch 1.0 and extends the implementation to handle data of arbitrary dimensionality.
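Because the loss is only an approximation, it can help to have an exact reference value on tiny inputs when testing. The sketch below is not part of this repository. It brute-forces the exact EMD between two equal-size point sets with uniform weights by trying every one-to-one matching, and it works for points of any dimension D. It assumes a Euclidean ground distance and returns the total (summed) matching cost. This README does not say whether the CUDA kernel uses squared or plain distances or how it normalizes the result, so adjust before comparing numbers. The function name `exact_emd` is hypothetical.

```python
import itertools
import math
from typing import Sequence

Point = Sequence[float]


def exact_emd(xyz1: Sequence[Point], xyz2: Sequence[Point]) -> float:
    """Exact EMD between two equal-size point sets with uniform weights.

    Hypothetical reference helper (not from this repo). It tries every
    one-to-one matching, so it runs in O(n!) time and is only usable for
    very small n. Ground distance is assumed to be Euclidean; the result is
    the summed cost of the best matching (divide by n for a mean).
    """
    if len(xyz1) != len(xyz2):
        raise ValueError("this sketch requires point sets of equal size")
    n = len(xyz1)
    if n == 0:
        return 0.0
    best = math.inf
    for perm in itertools.permutations(range(n)):
        # Point i of xyz1 is matched to point perm[i] of xyz2.
        cost = sum(math.dist(xyz1[i], xyz2[j]) for i, j in enumerate(perm))
        best = min(best, cost)
    return best
```

For example, `exact_emd([(0.0,), (10.0,)], [(1.0,), (11.0,)])` picks the matching 0→1, 10→11 for a total cost of 2.0, rather than the crossed matching with cost 20.0.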

Installation should be as simple as running python setup.py install.

Limitations and Known Bugs:

  • Double tensors must have at most 11 dimensions, while float tensors must have at most 23. This limit comes from the computation's use of CUDA shared memory, which the hardware caps at 48 kB.
  • With larger point sets (M, N > ~2000), the CUDA kernel fails. I suspect an overflow error in the approximate matching kernel. I have pinpointed the source of the bug here; any suggestions for a fix would be greatly appreciated.
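Until these limits are lifted, a caller may want to fail early with a clear message rather than hit a CUDA error. The following is a minimal, hypothetical pre-flight check that encodes only the limits stated above: at most 23 dimensions for float and 11 for double, with a warning above roughly 2000 points. It is not part of this repository's API. The names `check_emd_inputs`, `MAX_DIMS`, and `MAX_POINTS_APPROX` are illustrative. The sketch also assumes that "dimensions" refers to the per-point dimension D, and it uses plain dtype strings instead of `torch.dtype` objects to stay self-contained.

```python
import warnings

# Hypothetical constants mirroring the limits stated in this README.
# Assumption: "dimensions" means the per-point dimension D.
MAX_DIMS = {"float32": 23, "float64": 11}
# Approximate point count above which the CUDA kernel is reported to fail.
MAX_POINTS_APPROX = 2000


def check_emd_inputs(dtype: str, point_dim: int, m: int, n: int) -> None:
    """Hypothetical pre-flight check before calling the EMD loss.

    Raises ValueError for unsupported dtypes or too many dimensions, and
    emits a RuntimeWarning when either point set exceeds ~2000 points.
    """
    if dtype not in MAX_DIMS:
        raise ValueError(
            f"unsupported dtype {dtype!r}; expected one of {sorted(MAX_DIMS)}"
        )
    limit = MAX_DIMS[dtype]
    if point_dim > limit:
        raise ValueError(
            f"{dtype} inputs support at most {limit} dimensions, got {point_dim}"
        )
    if max(m, n) > MAX_POINTS_APPROX:
        warnings.warn(
            f"point sets larger than ~{MAX_POINTS_APPROX} points "
            f"(got M={m}, N={n}) are reported to make the CUDA kernel fail",
            RuntimeWarning,
        )
```

With PyTorch, one would map `torch.float32` and `torch.float64` to these keys and pass `x.shape[-1]` as `point_dim`, assuming points are laid out along the last axis.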
