Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proposal: A simpler and more flexible IDifferentiable system #4865

Merged
merged 2 commits into from
Aug 21, 2024

Conversation

saipraveenb25
Copy link
Collaborator

Designed to work for both value types as well as ref/ptr/buffer types.

Resolves #4034

Designed to work for both value types as well as ref/ptr/buffer types
// Force arithmetic types to be a differential pair of the same two types.
// Make it simple to define derivatives of arithmetic operations.
//
associatedtype Differential : This;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is meaning what you want it to mean. This is saying that the differnetial type of a cofnormance type must be inheriting from itself. I think you really want a type-equality constraint that

associatedtype Differential : IDifferentiableValueType where Differential == This;

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is about time we add a where clause and implement type equality constraints.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this part is confusing.. what we want is the equality constraints.
(We could implement ISelfDifferentiableValueType as a special case for now, in case the where clause implementation is too complex)

interface IDifferentiableValueType : __IDifferentiableBase
{
associatedtype Differential : IDifferentiableValueType & IDefaultInitializable;
[Differentiable] This use(); // auto-synthesized
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also need use_bwd and use_fwd?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe that using [Differentiable] in the method signature should automatically require the two derivatives to be provided?

[BackwardDerivative(get_fwd)]
T get() { return this.buffer[offset]; }

[deriv_method] DifferentialPair<T> get_fwd()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use_fwd?

Copy link
Collaborator Author

@saipraveenb25 saipraveenb25 Aug 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is different from use. the IDifferentiablePtrType explicitly doesn't require a use (because we don't need any type-specific handling of bwd_diff(use). use is invoked internally by the compiler for any extra uses of an value-typed object (if used multiple times)

get() is a normal getter/setter method implemented on a buffer or pointer type (that a user can use) and is not required in the differentiable interface.
It's effectively just derefencing or indexing in this scenario. (Maybe I should have used load() instead?). The core idea is that on these 'pointer-style' types all differentiation happens on the methods used to access the data, so we'll rely on custom derivatives, and simply make sure the pair is passed as-is through all the generated code so the buffer references are available to the custom derivative code.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main reason for introducing IDifferentiablePtrType is to make this explicit. It explicitly has no use and no way to initialize it, which means that this type will not get 'transposed'. They will behave similar to how DiffTensorView<T> works, with the main difference being that auto-diff itself will create the pair of buffers instead of us having to side-step the differentiable type system.

return diffPair(this.p.buffer[offset], this.d().buffer[offset]);
}

[deriv_method] void get_bwd(Differential d)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use_bwd?

// After fwd-mode:
DiffPair<float> f_fwd(DiffPair<float> dpa)
{
DiffPair<float> dpa_extra = dpa.use_fwd();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does DiffPair<float>.use_fwd do?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is effectively a noop 'passthrough' in this case. It only exists because when converting to reverse-mode, we will replace use_fwd into use_bwd which is the method that has non-trivial logic. Both use and use_fwd are always intended to be passthroughs.
(We could do away with use entirely and simply require the implementation of the backward part, which is just dadd but not static)

@@ -0,0 +1,283 @@
## Problem
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should start with a title
SP #10: New differentiable type system

Unfortunately, this is a rather hacky workaround with several drawbacks:
1. `DiffTensorView<T>` does not conform to `IDifferentiable`, but is used for derivatives. This makes our type system less useful as checks for `is_subtype` from applications using reflection need workarounds to account for corner cases like these.
2. `DiffTensorView<T>` always has two buffer pointers even when used in non-differentiable methods. This is extra data in the struct, and potentially extra tensor allocations (we explicitly handle this case in `slangtorch` by leaving the diff part uninitialized if a primal method is invoked)
3. Higher-order derivatives don't work well with this workaround. Differentiating a method twice needs a set of 4 pointers, but we need to account for this ahead of time by using new types like `DiffDiffTensorView` that worsens the problem of carrying around extra data where its not required.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does higher order diff work with the new system?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this scenario, we can have a DTensorView<T> : IDifferentiablePtrType which can have a single buffer in primal code.
It becomes DifferentialPair<DTensorView<T>> under the first derivative pass (2 buffers), and then becomes DifferentialPair<DifferentialPair<<DTensorView<T>>> under the second derivative pass (4 buffers).

// ----

// backward pass.
float.Differential d_val = DPtr<float>::set_bwd(output); // set_bwd works on the entire pair.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where are the argument to set?


// trivial auto-synthesis (maybe we even prevent the user from overriding this).
[ForwardDerivativeOf(use)]
[deriv_method] void use_fwd() { return this; }
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why can't this be just a static method where the first parameter is DiffPair<T>? It seems that [deriv_method] isn't necessary.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@csyonghe: It's syntactic sugar for a static method that operates on pairs. It's mostly to keep code clarity. If a primal function is a member method that operates on this, then the derivative is also a member method that operates on a this-pair.
But we can certainly avoid adding a new decorator, if it doesn't really help with clarity.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would rather us not adding new decorations if it can be achieved with existing mechanisms to avoid increasing maintenance costs down the road.

@csyonghe csyonghe merged commit 359e96c into shader-slang:master Aug 21, 2024
aroidzap added a commit to aroidzap/slang that referenced this pull request Aug 21, 2024
Slang v2024.10

This release brings support for tuple types, variadic generics and depedent generic constraints.

No breaking changes.

Changes:

e97e7e5 Revert "Fetch slang-llvm.so from correct release (shader-slang#4847)" (shader-slang#4893)
359e96c Proposal: A simpler and more flexible `IDifferentiable` system (shader-slang#4865)
f9f6a28 Support dependent generic constraints. (shader-slang#4870)
03e1e17 Fix `tests\autodiff\reverse-while-loop-3.slang` test (shader-slang#4886)
bcb5391 Exclude synthesized code from code auto documentation system (shader-slang#4889)
6b1b243 Track uninitialized values of `Ptr<Specialize<T>>` inside type `T` without hang (shader-slang#4885)
77e6c64 Fixes shader-slang#4879 (shader-slang#4881)
579d59c Fetch slang-llvm.so from correct release (shader-slang#4847)
d286ff5 Implement Path::createDirectoryRecursive (shader-slang#4871)
f77a5ac Remove using SpvStorageClass values casted into AddressSpace values (shader-slang#4861)
453683b Tuple swizzling, concat, comparison and `countof`. (shader-slang#4856)
ecf85df Variadic Generics Part 2: IR lowering and specialization. (shader-slang#4849)
ca5d303 Make sure to resolve overloaded expr for call args. (shader-slang#4864)
25bc5a3 Avoiding the use of the global AST builder in DeclRefType::create (shader-slang#4866)
b411c05 Include inout cast operation as an aliasing instruction (shader-slang#4859)
9bf5dc9 Design proposal: IFunc interface. (shader-slang#4851)
f447b74 Update documentation for #include to indicate it is for legacy code and new code should use modules (shader-slang#4862)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Rethink IDifferentiable system for buffer/ptr types
2 participants