Function request: logerfc, logerfcx special functions #31945
Comments
Seems reasonable, especially if all those other libraries have it.
@zou3519 I think I could implement this, taking inspiration from the algorithms used in the above libraries. If I wanted to add this myself, do you have any suggestions on what I should do? Maybe there are some PRs I can look at where people implemented special functions in PyTorch?
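I did a Python-based implementation that calls into SciPy:

import numpy as np
import scipy.special
import torch

class ErfcxFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # SciPy operates on NumPy arrays, so move the data to the CPU first.
        input_np = input.detach().cpu().numpy()
        result_np = scipy.special.erfcx(input_np)
        # Convert back, matching the input's dtype and device.
        result = torch.from_numpy(result_np).to(input)
        ctx.save_for_backward(input, result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        input, result = ctx.saved_tensors
        # d/dx erfcx(x) = 2 * x * erfcx(x) - 2 / sqrt(pi)
        g = -2 / np.sqrt(np.pi) + 2 * input * result
        return g * grad_output

erfcx = ErfcxFunction.apply

def logerfc(x):
    # x > 0: log(erfc(x)) = log(erfcx(x)) - x**2 avoids underflow of erfc.
    # Clamping keeps the unused branch finite, so torch.where does not leak NaN gradients.
    return torch.where(x > 0, erfcx(x.clamp(min=0)).log() - x**2, x.clamp(max=0).erfc().log())

def logerfcx(x):
    # x < 0: log(erfcx(x)) = log(erfc(x)) + x**2 avoids overflow of erfcx.
    return torch.where(x < 0, x.clamp(max=0).erfc().log() + x**2, erfcx(x.clamp(min=0)).log())

Any comments are welcome. With some guidance on the PyTorch source code, I can turn this into a PR.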
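The gradient in backward and the branch identities used in logerfc and logerfcx follow directly from the definition of erfcx and the derivative of erfc:

```latex
\begin{aligned}
\operatorname{erfcx}(x) &= e^{x^2}\operatorname{erfc}(x), \qquad
\frac{d}{dx}\operatorname{erfc}(x) = -\frac{2}{\sqrt{\pi}}\,e^{-x^2},\\
\frac{d}{dx}\operatorname{erfcx}(x) &= 2x\,e^{x^2}\operatorname{erfc}(x) - \frac{2}{\sqrt{\pi}}\,e^{x^2}e^{-x^2}
  = 2x\,\operatorname{erfcx}(x) - \frac{2}{\sqrt{\pi}},\\
\log\operatorname{erfc}(x) &= \log\operatorname{erfcx}(x) - x^2,\\
\log\operatorname{erfcx}(x) &= \log\operatorname{erfc}(x) + x^2.
\end{aligned}
```

The first log identity is the useful one for x > 0, where erfc(x) underflows; the second is the useful one for x < 0, where erfcx(x) grows like 2e^{x^2} and eventually overflows.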
I'm not sure what erfcx does under the hood, but PyTorch does not have a NumPy/SciPy dependency, so we cannot accept a pull request implemented as shown. With that in mind, I'm happy to review any pull requests on this subject.
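One way to drop the SciPy dependency is to build erfcx from existing PyTorch ops: evaluate exp(x^2) * erfc(x) directly for moderate x, and for large x switch to the standard asymptotic series erfcx(x) ~ 1/(x*sqrt(pi)) * sum_n (-1)^n (2n-1)!! / (2x^2)^n before exp(x^2) overflows and erfc(x) underflows. This is a minimal sketch, not how SciPy or PyTorch actually implement it: the name erfcx_torch, the threshold of 10, the 12-term truncation, and computing internally in float64 are all illustrative assumptions. Because it is built from differentiable ops, autograd supplies the gradient without a custom backward.

```python
import math

import torch


def erfcx_torch(x: torch.Tensor, threshold: float = 10.0, n_terms: int = 12) -> torch.Tensor:
    """Hypothetical SciPy-free erfcx(x) = exp(x**2) * erfc(x) built from torch ops."""
    x64 = x.to(torch.float64)  # assumption: compute internally in double precision

    # Moderate x: the direct product is accurate here.
    # Clamp so this branch stays finite where it is not selected.
    x_small = x64.clamp(max=threshold)
    direct = torch.exp(x_small * x_small) * torch.erfc(x_small)

    # Large x: erfcx(x) ~ 1/(x*sqrt(pi)) * sum_n (-1)^n (2n-1)!! / (2x^2)^n
    x_big = x64.clamp(min=threshold)
    inv_2x2 = 1.0 / (2.0 * x_big * x_big)
    term = torch.ones_like(x_big)
    series = torch.ones_like(x_big)
    for n in range(1, n_terms + 1):
        # Ratio of consecutive series terms: -(2n-1) / (2x^2)
        term = -term * (2 * n - 1) * inv_2x2
        series = series + term
    asymptotic = series / (x_big * math.sqrt(math.pi))

    # Pick the branch per element and return in the caller's dtype.
    return torch.where(x64 < threshold, direct, asymptotic).to(x.dtype)
```

Clamping each branch's input before torch.where matters for autograd: torch.where still backpropagates through the masked branch (with zero weight), and an inf or NaN there would turn the gradient into NaN.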
I wrote a prototype implementation of log(erfc(x)) in PyTorch, with polynomial approximations for large positive x. It still has a NumPy dependency, but only to get a math constant in my project, and that should be trivial to implement in C++ and CUDA, as both already provide an erfc function. If it interests someone, I can try to implement a proper C++ version. I just don't know whether someone who is better at math than me would have reservations about my approximations.
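A log(erfc(x)) along these lines can be sketched with standard PyTorch ops: compute log(erfc(x)) directly while erfc(x) is still representable, and for large positive x use the asymptotic expansion log erfc(x) ~ -x^2 - log(x*sqrt(pi)) + log P(t), where P is the series polynomial in t = 1/(2x^2), evaluated with Horner's rule. This is not the commenter's code; the name logerfc_torch, the cutoff of 10, the nine coefficients, and the float64 internal computation are assumptions chosen for illustration.

```python
import math

import torch

# Coefficients (-1)^n (2n-1)!! of the asymptotic series in t = 1 / (2 x^2), n = 0..8.
_ASYMPTOTIC_COEFFS = [1.0, -1.0, 3.0, -15.0, 105.0, -945.0, 10395.0, -135135.0, 2027025.0]


def logerfc_torch(x: torch.Tensor, cutoff: float = 10.0) -> torch.Tensor:
    """Hypothetical log(erfc(x)) that stays finite for large positive x."""
    x64 = x.to(torch.float64)  # assumption: compute internally in double precision

    # Direct branch: erfc(x) is still a normal float64 number for x < cutoff.
    x_small = x64.clamp(max=cutoff)
    direct = torch.log(torch.erfc(x_small))

    # Asymptotic branch: log erfc(x) ~ -x^2 - log(x sqrt(pi)) + log P(t).
    x_big = x64.clamp(min=cutoff)
    t = 1.0 / (2.0 * x_big * x_big)
    poly = torch.zeros_like(t)
    for c in reversed(_ASYMPTOTIC_COEFFS):  # Horner's rule
        poly = poly * t + c
    asymptotic = -x_big * x_big - torch.log(x_big * math.sqrt(math.pi)) + torch.log(poly)

    # Clamped inputs keep both branches finite, so gradients through torch.where stay finite.
    return torch.where(x64 < cutoff, direct, asymptotic).to(x.dtype)
```

At a cutoff of 10 the first omitted term is below 1e-13 in relative size, so truncating the polynomial costs essentially nothing in float64; lowering the cutoff would need more care, since the asymptotic series diverges if pushed too far.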
@french-paragon It would be nice to have this!
cc @kshitij12345, more requests to track in #50345. FYI @cossio, we now have the torch.special namespace in the nightlies, and @kshitij12345 is adding ops to it for the 1.9 release. We won't cover every torch.special op in the 1.9 release (or maybe ever), but requests like these help us prioritize which ops to add.
Thanks @mruberry! Hope this one gets added.
🚀 Feature
Implement the erfcx(x) special function, which computes exp(x^2) * erfc(x) in a numerically stable way. Also, for convenience, add logerfc(x) = log(erfc(x)) and logerfcx(x) = log(erfcx(x)).

erfcx is available in many numerical packages, such as Matlab, Julia, SciPy, R, and others. From erfcx it is easy to obtain logerfc and logerfcx, but this involves a conditional, which can be slow in pure Python code. So I recommend adding logerfc and logerfcx as well, which can be implemented as:

Motivation
These special functions are very useful whenever we have to work with truncated normal distributions.
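To make the truncated-normal use case concrete: for a standard normal truncated to [a, inf), the log-density at z >= a is log phi(z) - log P(Z >= a), with P(Z >= a) = erfc(a / sqrt(2)) / 2. For large a that tail mass underflows, so computing log(erfc(...)) directly gives -inf. The sketch below uses a small, hypothetical stable _logerfc helper (direct below a cutoff of 15, six-term asymptotic expansion above it, float64 inputs assumed); the function names and the cutoff are illustrative, not an existing PyTorch API.

```python
import math

import torch


def _logerfc(x: torch.Tensor, cutoff: float = 15.0) -> torch.Tensor:
    """Hypothetical stable log(erfc(x)) for float64 tensors."""
    x_small = x.clamp(max=cutoff)
    x_big = x.clamp(min=cutoff)
    # Asymptotic expansion in t = 1/(2x^2): P(t) = 1 - t + 3t^2 - 15t^3 + 105t^4 - 945t^5
    t = 1.0 / (2.0 * x_big * x_big)
    poly = 1 - t * (1 - t * (3 - t * (15 - t * (105 - 945 * t))))
    asymptotic = -x_big * x_big - torch.log(x_big * math.sqrt(math.pi)) + torch.log(poly)
    return torch.where(x < cutoff, torch.log(torch.erfc(x_small)), asymptotic)


def truncated_normal_log_prob(z: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
    """Log-density at z >= a of a standard normal truncated to [a, inf)."""
    # Untruncated standard normal log-density log(phi(z)).
    log_phi = -0.5 * z * z - 0.5 * math.log(2.0 * math.pi)
    # log P(Z >= a) = log(erfc(a / sqrt(2)) / 2), computed without underflow.
    log_tail = _logerfc(a / math.sqrt(2.0)) - math.log(2.0)
    return log_phi - log_tail
```

With a = 40 in float64, torch.log(torch.erfc(a / sqrt(2))) evaluates to -inf, which would make the density +inf; the version above stays finite and close to log(a), as expected for a tail this far out.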
Related: #2129, #32293
cc @mruberry @rgommers @heitorschueroff