frskplis/sumcumprod_jax

Extending JAX with custom C++ and CUDA code

This is a sumcumprod extension for JAX, implemented in C++ for the CPU and in CUDA for the GPU. It is equivalent to the following Python code:

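import jax.numpy as jnp

def grand_true_function(x, y):
    # Dense reference implementation: builds n x n intermediate arrays.
    assert x.ndim == y.ndim == 1
    assert x.shape == y.shape
    i = jnp.arange(x.shape[0])
    # mask[r, c] is True for c < r (entries left of the diagonal).
    mask = i[None, :] < i[:, None]
    # Masked entries are set to 1 so they do not change the running product.
    cumprod = jnp.where(mask, 1, 1 / (1 + x[None, :] * y[:, None])).cumprod(1)
    # Zero out the masked entries and sum each row.
    return jnp.where(mask, 0, cumprod).sum(1)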

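The extension should be much faster in execution because it needs far less memory: the reference version above materializes full n × n intermediate arrays (the mask and the cumulative product).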
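To make the memory argument concrete, here is a minimal pure-JAX sketch that computes the same result without building any n × n array. It is only an illustration, not the repository's C++/CUDA kernel, and the name `sumcumprod_lowmem` is hypothetical. For row r the reference returns out[r] = sum over c ≥ r of prod over k = r..c of a_k, with a_k = 1 / (1 + x[k] * y[r]). That sum can be rewritten in nested (Horner-like) form as a_r * (1 + a_(r+1) * (1 + ... * (1 + a_(n-1)))), so evaluating the nesting from k = n-1 down to 0 for all rows at once needs only one running value per row. The sketch assumes floating-point `x` and `y`.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sumcumprod_lowmem(x, y):
    """Hypothetical O(n)-memory equivalent of grand_true_function (sketch only)."""
    assert x.ndim == y.ndim == 1
    assert x.shape == y.shape
    n = x.shape[0]
    rows = jnp.arange(n)

    def body(j, t):
        # Walk k from n-1 down to 0 so the nested form is built inside-out.
        k = n - 1 - j
        # a_k for every row r at once: 1 / (1 + x[k] * y[r]), shape (n,).
        a = 1.0 / (1.0 + x[k] * y)
        # Rows with r <= k absorb the factor; rows with r > k are left unchanged.
        return jnp.where(rows <= k, a * (1.0 + t), t)

    # The carry t holds one running value per row: O(n) memory in total.
    # Assumes floating-point inputs so the carry dtype stays constant.
    return jax.lax.fori_loop(0, n, body, jnp.zeros_like(y))


if __name__ == "__main__":
    x = jnp.array([0.5, 1.0, 2.0, 0.1])
    y = jnp.array([1.0, 0.3, 0.7, 2.0])
    print(sumcumprod_lowmem(x, y))
```

The loop performs n sequential steps of O(n) work each, so it saves memory but not arithmetic, and the serial dependency along k limits parallelism. It is meant to show why the computation does not inherently need quadratic memory, not to stand in for the extension or to benchmark it.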