Update the derivative function's signature #1091

Open
kchristin22 opened this issue Sep 12, 2024 · 0 comments

@kchristin22 (Collaborator)

Clad currently supports differentiating a function with respect to a specific argument. As a result, the derivative function's signature cannot be deduced at compile time in order to assign it to the DerivedFnType template. Hence, a wrapper function is produced with the signature void (*)(Args..., OutputParamType_t<Args, void>...), where OutputParamType_t appends an extra argument of type void * for each argument of the original function. In the source this wrapper is referred to as the overload function, but the user assumes that the function returned is the one with the signature void (*)(Args..., $_d_args), where _d_args are the derivatives with respect to the args the user specified in the differentiation call. For instance:

double foo(double a, double b);

// clad::gradient(foo, "b");

void foo_grad(double a, double b, void *_d_a, void *_d_b); // overload function

void foo_grad(double a, double b, double *_d_b); // derivative function
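
For illustration, here is a minimal sketch (an assumption for this discussion, not Clad's actual definition) of how an OutputParamType_t-style alias can map every argument type to void *, which is what makes the overload's signature known at compile time:

template <typename T, typename = void>
using OutputParamType_t = void *;

// void (*)(Args..., OutputParamType_t<Args, void>...) then expands, for
// foo(double, double), to void (*)(double, double, void *, void *).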

The overload function includes a call to the derivative function and performs the necessary typecasting:

double foo(double a, double b) {
  a *= b;
  return a;
}

void foo_grad_1(double a, double b, double *_d_b) { // derivative function
    double _d_a = 0.;
    double _t0 = a;
    a *= b;
    _d_a += 1;
    {
        a = _t0;
        double _r_d0 = _d_a;
        _d_a = 0.;
        _d_a += _r_d0 * b;
        *_d_b += a * _r_d0;
    }
}

void foo_grad_1(double a, double b, void *_temp__d_b0, void *_d_1) { // overloaded function
    double *_d_b = (double *)_temp__d_b0;
    foo_grad_1(a, b, _d_b);
}

The overload function is the one actually returned to the user, so it is the one executed every time. The user still provides the args as if they were calling the internal derivative function, so Clad appends nullptr for the unused derivative args, which are placed at the end of the function's signature.
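
For concreteness, a call-side sketch (assuming the clad::gradient / CladFunction::execute interface; the forwarding in the final comment mirrors the overload shown above):

#include "clad/Differentiator/Differentiator.h"

double foo(double a, double b) {
  a *= b;
  return a;
}

int main() {
  double a = 3.0, b = 4.0, d_b = 0.0;
  auto grad = clad::gradient(foo, "b");
  grad.execute(a, b, &d_b); // the user passes only the adjoint of b
  // Internally this reaches the overload, roughly as:
  //   foo_grad_1(a, b, (void *)&d_b, /* unused adjoint */ nullptr);
  return 0;
}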

One issue with this approach arose when handling global CUDA kernels. In that case, the derivative function has to become a device function and the overload becomes the actual global kernel that is executed. However, if the user used shared memory inside the original kernel, it cannot be cloned into the device function, since shared memory can only be declared inside a global function in CUDA.

Updating the templates after they are initialized is quite complicated, especially since a constructor of a templated class is called inside clad::gradient.

An alternative to the overload function could look like this:

// clad::gradient(foo, "b");
void foo_grad_1_b(double a, double b, double *_d_a, double *_d_b) { // derivative function
    double _d_a0 = 0.;
    double _t0 = a;
    a *= b;
    _d_a0 += 1;
    {
        a = _t0;
        double _r_d0 = _d_a0;
        _d_a0 = 0.;
        _d_a0 += _r_d0 * b;
        *_d_b += a * _r_d0;
    }
}

This way:

  • We know the function signature at compile time: void (*)(Args..., $Args_pointers)
  • The users don't have to create all the adjoint variables; they can still pass nullptr for the adjoints they don't need (as in the sketch after this list)
  • The derivative function's name encodes the args it is differentiated with respect to, so the same function can be differentiated for every possible arg combination without any conflict
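
For illustration, a call-side sketch of the proposed signature (names taken from the example above; the derivative is called directly here only to show the argument layout):

double a = 3.0, b = 4.0, d_b = 0.0;
foo_grad_1_b(a, b, /*_d_a=*/nullptr, &d_b); // nullptr for the adjoint that isn't needed
// d_b now holds d(foo)/d(b) = a = 3.0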

There is also the argument of dropping the capability of differentiating with respect to only certain args altogether; in that case, the user would have to provide every adjoint.

This issue aims to gather every argument and idea on the way forward and arrive at a conclusion.
