
Function request: logerfc, logerfcx special functions #31945

Open
cossio opened this issue Jan 8, 2020 · 8 comments
Labels
function request A request for a new function or the addition of new arguments/modes to an existing function. module: numerical-stability Problems related to numerical stability of operations module: numpy Related to numpy support, and also numpy compatibility of our operators module: special Functions with no exact solutions, analogous to those in scipy.special triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@cossio

cossio commented Jan 8, 2020

🚀 Feature

Implement the erfcx(x) special function, which computes exp(x^2) * erfc(x) in a numerically stable way. Also for convenience, add logerfc(x) = log(erfc(x)) and logerfcx(x) = log(erfcx(x)).

erfcx is available in many numerical packages, such as MATLAB, Julia, SciPy, R, and others.

From erfcx it is easy to obtain logerfc and logerfcx, but this involves a conditional, which can be slow in pure Python code. So I recommend adding logerfc and logerfcx as well; they can be implemented as:

from math import log
from scipy.special import erfc, erfcx  # scalar reference implementations

def logerfc(x):
    if x > 0.0:
        return log(erfcx(x)) - x**2
    else:
        return log(erfc(x))

def logerfcx(x):
    if x < 0.0:
        return log(erfc(x)) + x**2
    else:
        return log(erfcx(x))
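As a quick illustration of why the stable form matters (a sketch using NumPy and SciPy's erfcx as the reference): for moderately large x the naive product exp(x**2) * erfc(x) evaluates to inf * 0 = nan, while erfcx and the log-space variant above stay finite.

```python
import numpy as np
from scipy.special import erfc, erfcx

x = 30.0

with np.errstate(over='ignore', invalid='ignore'):
    naive = np.exp(x**2) * erfc(x)  # exp(900) -> inf, erfc(30) -> 0.0, product -> nan

stable = erfcx(x)                   # ~ 1/(x*sqrt(pi)) ~ 0.0188, finite

# logerfc as defined above, x > 0 branch:
log_erfc = np.log(erfcx(x)) - x**2  # ~ -903.97, far outside float64's exp range
```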

Motivation

These special functions are very useful whenever we have to work with truncated normal distributions.
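For instance (a sketch using SciPy as a reference; logerfc is the scalar version proposed above): the log-normalizer of a standard normal truncated to [a, ∞) is log P(X > a) = logerfc(a/√2) − log 2, which matches SciPy's stable log-survival function even where erfc itself underflows.

```python
import math
from scipy.special import erfc, erfcx
from scipy.stats import norm

def logerfc(x):
    # scalar logerfc from above
    if x > 0.0:
        return math.log(erfcx(x)) - x**2
    return math.log(erfc(x))

a = 30.0
# log P(X > a) for X ~ N(0, 1): erfc(a / sqrt(2)) underflows here, logerfc does not
log_tail = logerfc(a / math.sqrt(2)) - math.log(2)
ref = norm.logsf(a)  # SciPy's numerically stable reference
```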

Related: #2129, #32293

cc @mruberry @rgommers @heitorschueroff

@zou3519 zou3519 added feature A request for a proper, new feature. module: operators module: numerical-stability Problems related to numerical stability of operations triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Jan 9, 2020
@zou3519
Contributor

zou3519 commented Jan 9, 2020

Seems reasonable especially if all those other libraries have it.

@cossio
Author

cossio commented Jan 20, 2020

@zou3519 I think I could implement this taking inspiration from the algorithms used in the above libraries.

If I wanted to implement this myself, do you have any suggestions on what I should do? Maybe there are some PRs I can look at where people implemented special functions in PyTorch?
Would it have to be in C++, or could it be pure Python code?

@cossio cossio changed the title Implement erfcx special function Implement erfcx, logerfc, logerfcx special functions Jan 23, 2020
@cossio
Author

cossio commented Jan 23, 2020

I did a Python-based implementation that calls into SciPy:

import numpy as np
import scipy.special
import torch

class ErfcxFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        input_np = input.detach().numpy()
        result_np = scipy.special.erfcx(input_np)
        result = input.new(result_np)
        ctx.save_for_backward(input, result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        input, result = ctx.saved_tensors
        # d/dx erfcx(x) = 2*x*erfcx(x) - 2/sqrt(pi)
        g = -2 / np.sqrt(np.pi) + 2 * input * result
        return g * grad_output

erfcx = ErfcxFunction.apply

def logerfc(x):
    return torch.where(x > 0, erfcx(x).log() - x**2, x.erfc().log())

def logerfcx(x):
    return torch.where(x < 0, x.erfc().log() + x**2, erfcx(x).log())

Any comments are welcome. With some guidance on the PyTorch source code, I can turn this into a PR.

@zou3519
Contributor

zou3519 commented Feb 3, 2020

@cossio,

I'm not sure what erfcx does under the hood, but PyTorch does not have a numpy/scipy dependency so we cannot take a pull request if it is implemented as shown.

Given that erfcx is a pointwise function (it operates element-wise), we should put it in here: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/UnaryOps.cpp , in C++. It's possible to implement it in Python, but we prefer C++ implementations unless the implementation is trivial.

I'm happy to review any pull requests on this subject.

@french-paragon

french-paragon commented Mar 27, 2020

I wrote a dummy implementation of log(erfc(x)) in PyTorch, with a polynomial approximation for large positive x. There's still a NumPy/SciPy dependency, but only to get some math constants in my project. It should be trivial to port to C++ and CUDA, as both already implement the erfc function. If it interests someone, I can try to implement a proper C++ version. I just don't know whether someone better at math than me would have reservations about my approximations.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Mar 27 14:01:28 2020

@author: laurent
"""

import torch
from torch.autograd import Function

import numpy as np
from scipy.special import erfc

class logErfc(Function):
    """
    An implementation of log(erfc(x)) with a polynomial approximation for x
    leading to infs or nans. This should work for a large range of float32
    values.
    """

    @staticmethod
    def forward(ctx, x):
        ret = torch.log(torch.erfc(x))
        # Where log(erfc(x)) over/underflowed, fall back to the large-x
        # approximation log(erfc(x)) ~= log(erfc(10)) - x**2 + 100
        mask = torch.isinf(ret) | torch.isnan(ret)
        ret[mask] = torch.tensor(np.log(erfc(10.)), dtype=x.dtype, device=x.device) - x[mask]**2 + 100
        ctx.save_for_backward(x)
        return ret

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # d/dx log(erfc(x)) = -2/sqrt(pi) * exp(-x**2) / erfc(x)
        delta = -2. / np.sqrt(np.pi) * torch.exp(-x**2) / torch.erfc(x)
        # Where that over/underflowed, use the asymptotic slope -2*x
        mask = torch.isinf(delta) | torch.isnan(delta)
        delta[mask] = -2 * x[mask]
        return grad_output * delta
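As a rough sanity check of the large-x fallback (SciPy only, no torch needed): for x up to roughly 26, erfc(x) is still representable in float64, so the approximation log(erfc(10)) − x² + 100 can be compared directly against the exact log. At x = 15 the absolute error is about log(15/10) ≈ 0.4, since the fallback drops the slowly varying log(x√π) term:

```python
import numpy as np
from scipy.special import erfc

def fallback(x):
    # the large-x branch used in logErfc.forward above
    return np.log(erfc(10.)) - x**2 + 100

x = 15.0
exact = np.log(erfc(x))  # erfc(15) ~ 7e-100, still representable in float64
err = abs(fallback(x) - exact)
```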

@mruberry mruberry added function request A request for a new function or the addition of new arguments/modes to an existing function. module: numpy Related to numpy support, and also numpy compatibility of our operators and removed feature A request for a proper, new feature. module: operators (deprecated) labels Oct 10, 2020
@cossio
Author

cossio commented Mar 17, 2021

@french-paragon It would be nice to have this!

@mruberry mruberry added the module: special Functions with no exact solutions, analogous to those in scipy.special label Mar 17, 2021
@mruberry
Collaborator

cc @kshitij12345, more requests to track in #50345.

fyi @cossio, we now have the torch.special namespace in nightlies and @kshitij12345 is adding ops to it for the 1.9 release. We won't cover every torch.special op in the 1.9 release (or maybe ever), but these requests are helpful in prioritizing which ops we do add.

@cossio
Author

cossio commented Mar 17, 2021

Thanks @mruberry! Hope this one gets added.

facebook-github-bot pushed a commit that referenced this issue Jun 22, 2021
Summary:
Implement erfcx() #31945

Reference: #50345

Pull Request resolved: #58194

Reviewed By: ngimel

Differential Revision: D29285979

Pulled By: mruberry

fbshipit-source-id: 5bcfe77fddfabbeb8c8068658ba6d9fec6430399
@mruberry mruberry changed the title Implement erfcx, logerfc, logerfcx special functions Function request: logerfc, logerfcx special functions Jun 11, 2022