implements optimal transport algorithms in pytorch

rythei/PyTorchOT

PyTorchOT

UPDATE: POT now supports GPU usage and implements many more algorithms. Please use that library instead: https://github.com/rflamary/POT/.

Implements Sinkhorn optimal transport algorithms in PyTorch. Currently two versions of the Sinkhorn algorithm are implemented: the original and the log-stabilized version. This code essentially reworks a couple of the implementations from the awesome POT library (https://github.com/rflamary/POT/) in PyTorch.
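For readers unfamiliar with the two variants, here is a minimal sketch of what a plain Sinkhorn iteration and a log-domain variant can look like in PyTorch. It is not the code from this repository: the function names `sinkhorn_plain` and `sinkhorn_log`, the fixed iteration count, and the uniform marginals are all assumptions made for illustration.

```python
import math
import torch

def sinkhorn_plain(M, reg, n_iter=200):
    """Entropic OT via standard Sinkhorn scaling (illustrative sketch)."""
    n, m = M.shape
    a = torch.full((n,), 1.0 / n, dtype=M.dtype)  # uniform source weights (assumption)
    b = torch.full((m,), 1.0 / m, dtype=M.dtype)  # uniform target weights (assumption)
    K = torch.exp(-M / reg)                       # Gibbs kernel
    v = torch.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)       # rescale rows to match a
        v = b / (K.t() @ u)   # rescale columns to match b
    P = u[:, None] * K * v[None, :]               # transport plan
    return (P * M).sum(), P

def sinkhorn_log(M, reg, n_iter=200):
    """Same updates on dual potentials in the log domain (illustrative sketch)."""
    n, m = M.shape
    log_a = torch.full((n,), -math.log(n), dtype=M.dtype)
    log_b = torch.full((m,), -math.log(m), dtype=M.dtype)
    g = torch.zeros(m, dtype=M.dtype)
    for _ in range(n_iter):
        # logsumexp avoids overflow/underflow of exp(-M / reg) when reg is small
        f = reg * (log_a - torch.logsumexp((g[None, :] - M) / reg, dim=1))
        g = reg * (log_b - torch.logsumexp((f[:, None] - M) / reg, dim=0))
    P = torch.exp((f[:, None] + g[None, :] - M) / reg)
    return (P * M).sum(), P

# Small made-up problem to compare the two variants.
torch.manual_seed(0)
x = torch.rand(5, 2, dtype=torch.float64)
y = torch.rand(6, 2, dtype=torch.float64)
M = torch.cdist(x, y)
cost_plain, P_plain = sinkhorn_plain(M, reg=0.5)
cost_log, P_log = sinkhorn_log(M, reg=0.5)
print(float(cost_plain), float(cost_log))
```

On a well-conditioned problem like this one the two functions should agree closely; the log-domain form mainly matters when `reg` is small relative to the entries of `M`.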

Example usage:

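from ot_pytorch import sink

# M is the pairwise distance (cost) matrix between the two samples.
# pairwise_distance_matrix() is a placeholder for your own routine.
M = pairwise_distance_matrix()
# reg is the entropic regularization strength.
dist = sink(M, reg=5, cuda=False)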

Setting cuda=True runs the computation on the GPU via CUDA.
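The usage snippet leaves the construction of M open. One possible way to build it, sketched here under the assumption that the two samples are stored as rows of torch tensors and that the cost is the Euclidean distance, is `torch.cdist`. The helper below borrows the placeholder name from the snippet, but its two-argument signature and the sample sizes are hypothetical.

```python
import torch

def pairwise_distance_matrix(x, y):
    """Hypothetical helper: M[i, j] = Euclidean distance between x[i] and y[j]."""
    return torch.cdist(x, y, p=2)

torch.manual_seed(0)
x = torch.rand(100, 2)  # made-up source sample, one point per row
y = torch.rand(120, 2)  # made-up target sample, one point per row
M = pairwise_distance_matrix(x, y)
print(M.shape)

# Assuming sink accepts a torch tensor as in the usage above, M could then
# be passed on unchanged:
#     dist = sink(M, reg=5, cuda=False)
```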

The examples.py file contains two basic examples.

Example 1:

Let Z_i ~ Uniform[0,1], and define the data X_i = (0, Z_i), Y_i = (θ, Z_i) for i = 1,...,N, where the parameter θ is varied over [-1,1]. The true optimal transport distance is |θ|. The algorithm yields:

[Figure: distance returned by the algorithm as θ varies over [-1,1]]
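The data for Example 1 can be generated along the following lines. This is a sketch, not the code from examples.py: the function name `make_example`, the sample size, and the seed are assumptions. It also checks the stated true distance directly: every X_i is at distance at least |θ| from every Y_j, and pairing X_i with Y_i attains exactly |θ|, so the identity coupling is optimal.

```python
import torch

def make_example(theta, n=200, seed=0):
    """Hypothetical generator for X_i = (0, Z_i), Y_i = (theta, Z_i)."""
    gen = torch.Generator().manual_seed(seed)
    z = torch.rand(n, generator=gen, dtype=torch.float64)   # Z_i ~ Uniform[0, 1]
    x = torch.stack([torch.zeros_like(z), z], dim=1)        # X_i = (0, Z_i)
    y = torch.stack([torch.full_like(z, theta), z], dim=1)  # Y_i = (theta, Z_i)
    return x, y

theta = -0.4
x, y = make_example(theta)

# Pairwise Euclidean distances, computed by broadcasting.
M = (x[:, None, :] - y[None, :, :]).norm(dim=-1)

# Cost of the coupling that sends each X_i to its own Y_i.
identity_cost = M.diagonal().mean()
print(float(identity_cost), float(M.min()))
```

A regularized Sinkhorn estimate computed from this M would be expected to approach the same value as the regularization is decreased.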
