
Incorrect derivative for if-statements with assignments in forward mode #914

Open
gojakuch opened this issue May 29, 2024 · 5 comments

@gojakuch
Collaborator

In the following example, the derivative of fn should always evaluate to 2*x. Yet, Clad computes its derivative to be 1.

#include "clad/Differentiator/Differentiator.h"
#include <iostream>

double fn(double x) {
    if ((x = x*x)) {}
    return x;
} // = x^2

int main() {
    auto f_dx = clad::differentiate(fn, 0);

    std::cout << f_dx.execute(5) << '\n';
}

Expected output: 10

Output: 1

@vgvassilev
Owner

We should be careful here, though. Having conditional statements that depend on the independent variable generally implies a discontinuity in the mathematical sense. That example looks artificially constructed, but it would translate into:

double fn(double x) {
    x = x*x;
    if (x) {}
    return x;
} // = x^2

In general, we cannot really reason about that and should issue a warning... Can we reproduce the issue with a different example that makes more sense mathematically?

@gojakuch
Collaborator Author

gojakuch commented May 29, 2024

That example looks artificially constructed but it would translate into:

double fn(double x) {
    x = x*x;
    if (x) {}
    return x;
} // = x^2

In general we cannot really reason about that and should issue a warning

Well, if Clad can handle this potential discontinuity (or rather, non-differentiability) in its expanded form, shouldn't it also produce the same result when someone shortens the code and puts the assignment into the condition? After all, Clad supports if statements in the following situation:

double fn(double x) {
    x = x*x;
    if (x > 2) {
        return 2;
    }
    return x;
}

This is a continuous function, and it's differentiable at every point other than ±sqrt(2). So I think that, by the same logic, the shorter form

double fn(double x) {
    if ((x = x*x) > 2) {
        return 2;
    }
    return x;
}

should be differentiated accordingly. It's not that different. We can also construct an example that is differentiable everywhere but contains a functional if statement:

double fn(double x) { // supported
    x = x*x-1;
    if (x > 0) {
        return exp(x)-1;
    }
    return x;
}

And again, it can be abbreviated to:

double fn(double x) {
    if ((x = x*x-1) > 0) {
        return exp(x)-1;
    }
    return x;
}

which is not supported, but should be the same thing. I think an assignment inside a condition (for if statements, at least) is mostly just a shorter form of something Clad already supports, so I don't see why it shouldn't be supported.

@vgvassilev
Owner

double fn(double x) {
    x = x*x;
    if (x > 2) {
        return 2;
    }
    return x;
}

Technically, this should not be supported, or at least we should diagnose it as suspicious code from an AD perspective.

@gojakuch
Collaborator Author

gojakuch commented May 29, 2024

But I really don't understand why this code should not be supported. This issue is basically the equivalent of #894, which is the same issue for the reverse mode. I think it's quite beneficial to support differentiation of functions that have perhaps a couple of non-differentiable points; a user might only evaluate the derivative elsewhere.

In any case, it's hard to distinguish between a case where a condition creates a differentiability problem and a case like

double fn(double x) { // supported
    x = x*x-1;
    if (x > 0) {
        return exp(x)-1;
    }
    return x;
}

that I mentioned above (which is a fully differentiable function).

Issues like #273, #911, #367, #746, and many others I have seen address some sort of condition that depends on differentiable function arguments, so the same reasoning applies there. I thought this is one of the places where Clad could shine, because it treats most functions as differentiable. Where a function mathematically has no derivative at a point, it would simply be odd for the user to evaluate the derivative there; that would imply a logic error on the application side (which can be clarified and stated explicitly in the documentation). That is why I thought Clad just ignores this, and I think it should continue to do so.

@vgvassilev
Owner

Yes, your reasoning is right. All I am saying is that we will soon need to decide which rules we conform to. Standard AD theory tells us we should not be doing what we are doing. I am fine with not obeying it, but we need to be very careful about what we allow.

My suggestion was to issue a warning, not an error.

@PetroZarytskyi generally has a take on differentiable/non-differentiable support. Petro, what do you think we should be doing here in general?
