UPDATE: POT now supports GPU usage, and has way more algorithms implemented! Please use their library instead: https://github.com/rflamary/POT/.
Implements Sinkhorn optimal transport algorithms in PyTorch. Currently, two versions of the Sinkhorn algorithm are implemented: the original and the log-stabilized version. This code essentially just reworks a couple of the implementations from the awesome POT library (https://github.com/rflamary/POT/) in PyTorch.
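For readers who want to see the idea behind the two variants, here is a minimal sketch, not the code in this repository. The function names `sinkhorn_sketch` and `sinkhorn_log_sketch`, the uniform marginals, and the fixed iteration count are all assumptions made for illustration. The second function is a generic log-domain formulation, which may differ in detail from the stabilized scheme used in POT.

```python
import math
import torch

def sinkhorn_sketch(M, reg, n_iters=100):
    """Plain Sinkhorn scaling; returns the transport cost <P, M>.

    Assumes uniform weights on both samples (illustrative choice).
    """
    n, m = M.shape
    a = torch.full((n,), 1.0 / n, dtype=M.dtype, device=M.device)
    b = torch.full((m,), 1.0 / m, dtype=M.dtype, device=M.device)
    K = torch.exp(-M / reg)              # Gibbs kernel
    v = torch.ones_like(b)
    for _ in range(n_iters):
        u = a / (K @ v)                  # rescale to match row marginals
        v = b / (K.t() @ u)              # rescale to match column marginals
    P = u[:, None] * K * v[None, :]      # transport plan
    return torch.sum(P * M)

def sinkhorn_log_sketch(M, reg, n_iters=100):
    """Same iteration carried out on log-scalings via logsumexp."""
    n, m = M.shape
    log_a = torch.full((n,), -math.log(n), dtype=M.dtype, device=M.device)
    log_b = torch.full((m,), -math.log(m), dtype=M.dtype, device=M.device)
    g = torch.zeros(m, dtype=M.dtype, device=M.device)
    for _ in range(n_iters):
        f = reg * (log_a - torch.logsumexp((g[None, :] - M) / reg, dim=1))
        g = reg * (log_b - torch.logsumexp((f[:, None] - M) / reg, dim=0))
    P = torch.exp((f[:, None] + g[None, :] - M) / reg)
    return torch.sum(P * M)
```

The plain version computes exp(-M / reg) directly, which can underflow to zero when reg is small relative to the entries of M. Working with logsumexp avoids forming that kernel explicitly, which is the usual motivation for a log-domain or stabilized variant.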
Example usage:
from ot_pytorch import sink
M = pairwise_distance_matrix()  # placeholder: cost matrix of pairwise distances between the two samples
dist = sink(M, reg=5, cuda=False)
Setting cuda=True runs the computation on the GPU via CUDA.
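The call to pairwise_distance_matrix() in the usage above is a placeholder. One way to build such a matrix is sketched below, assuming two point clouds stored as (n, d) and (m, d) tensors and a Euclidean ground cost. The two-argument helper is hypothetical and is not part of ot_pytorch.

```python
import torch

def pairwise_distance_matrix(x, y):
    """Hypothetical helper: Euclidean distances between rows of x and rows of y.

    x has shape (n, d), y has shape (m, d); the result has shape (n, m).
    """
    diff = x[:, None, :] - y[None, :, :]   # broadcast to (n, m, d)
    return diff.norm(dim=2)

# Illustrative data: two random samples of 5 and 7 points in the plane.
x = torch.rand(5, 2)
y = torch.rand(7, 2)
M = pairwise_distance_matrix(x, y)         # shape (5, 7)

# M can then be passed on as in the usage above:
# dist = sink(M, reg=5, cuda=False)
```

Broadcasting materializes an (n, m, d) intermediate tensor, which is fine for small samples but may be memory-hungry for large ones.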
The examples.py file contains two basic examples.
Example 1:
Let Z_i ~ Uniform[0,1], and define the data X_i = (0, Z_i), Y_i = (θ, Z_i) for i = 1, ..., N, where the parameter θ is varied over [-1, 1]. The true optimal transport distance is |θ|, since the optimal plan moves each X_i to Y_i by a horizontal shift of θ. The algorithm yields: