Skip to content

Commit

Permalink
Support multiple indices in clad::gradient calls.
Browse files Browse the repository at this point in the history
Currently, the user can provide a string as the second argument of ``clad::gradient`` to specify independent parameters as a list of comma-separated names. This commit allows users to specify indices alongside with names. e.g.
```
clad::gradient(fn, "0");
clad::gradient(fn, "1, z");
...
```
Previously, it was possible to provide a single index as an integer literal. e.g.
```
clad::gradient(fn, 0);
```
Fixes #46.
  • Loading branch information
PetroZarytskyi committed Aug 5, 2024
1 parent 7cec7c8 commit 432e478
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 1 deletion.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ Reverse-mode AD allows computing the gradient of `f` using *at most* a constant
1. `f` is a pointer to a function or a method to be differentiated
2. `ARGS` is either:
* not provided, then `f` is differentiated w.r.t. its every argument
* a string literal with comma-separated names of independent variables (e.g. `"x"` or `"y"` or `"x, y"` or `"y, x"`)
* a string literal with comma-separated names/indices of independent variables (e.g. `"x"`, `"y"`, `"x, y"`, `"y, x"`, "0, 1", "0, y", etc.)
* a SINGLE number representing the index of the independent variable
Since a vector of derivatives must be returned from a function generated by the reverse mode, its signature is slightly different. The generated function has `void` return type and same input arguments. The function has additional `n` arguments (where `n` refers to the number of arguments whose gradient was requested) of type `T*`, where `T` is the type of the corresponding original variable. Each of these variables stores the derivative of the elements as they appear in the orignal function signature. *The caller is responsible for allocating and zeroing-out the gradient storage*. Example:
```cpp
Expand Down
15 changes: 15 additions & 0 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,21 @@ namespace clad {
DiffInputVarInfo dVarInfo;

dVarInfo.source = diffSpec.str();
// Check if diffSpec represents an index of an independent variable.
if ('0' <= diffSpec[0] && diffSpec[0] <= '9') {
unsigned idx = std::stoi(dVarInfo.source);
// Fail if the specified index is invalid.
if ((idx < 0) || (idx >= FD->getNumParams())) {
utils::EmitDiag(
semaRef, DiagnosticsEngine::Error, diffArgs->getEndLoc(),
"Invalid argument index '%0' of '%1' argument(s)",
{std::to_string(idx), std::to_string(FD->getNumParams())});
return;
}
dVarInfo.param = FD->getParamDecl(idx);
DVI.push_back(dVarInfo);
continue;
}
llvm::StringRef pName = computeParamName(diffSpec);
auto it = std::find_if(std::begin(candidates), std::end(candidates),
[&pName](
Expand Down
10 changes: 10 additions & 0 deletions test/Gradient/DiffInterface.C
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,22 @@ int main () {

auto f1_grad_y = clad::gradient(f_1, "y");
TEST(f1_grad_y, &result[1]); // CHECK-EXEC: {0.00, 1.00, 0.00}

auto f1_grad_0 = clad::gradient(f_1, "1");
TEST(f1_grad_0, &result[1]); // CHECK-EXEC: {0.00, 1.00, 0.00}

auto f1_grad_z = clad::gradient(f_1, "z");
TEST(f1_grad_z, &result[2]); // CHECK-EXEC: {0.00, 0.00, 2.00}

auto f1_grad_xy = clad::gradient(f_1, "x, y");
TEST(f1_grad_xy, &result[0], &result[1]); // CHECK-EXEC: {0.00, 1.00, 0.00}

auto f1_grad_0y = clad::gradient(f_1, "0, y");
TEST(f1_grad_0y, &result[0], &result[1]); // CHECK-EXEC: {0.00, 1.00, 0.00}

auto f1_grad_10 = clad::gradient(f_1, "1, 0");
TEST(f1_grad_10, &result[0], &result[1]); // CHECK-EXEC: {0.00, 1.00, 0.00}

auto f1_grad_yx = clad::gradient(f_1, "y, x");
TEST(f1_grad_yx, &result[0], &result[1]); // CHECK-EXEC: {0.00, 1.00, 0.00}

Expand Down

0 comments on commit 432e478

Please sign in to comment.