-
Notifications
You must be signed in to change notification settings - Fork 180
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Proposal: A simpler and more flexible IDifferentiable
system
#4865
Proposal: A simpler and more flexible IDifferentiable
system
#4865
Conversation
Designed to work for both value types as well as ref/ptr/buffer types
// Force arithmetic types to be a differential pair of the same two types. | ||
// Make it simple to define derivatives of arithmetic operations. | ||
// | ||
associatedtype Differential : This; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this is meaning what you want it to mean. This is saying that the differnetial type of a cofnormance type must be inheriting from itself. I think you really want a type-equality constraint that
associatedtype Differential : IDifferentiableValueType where Differential == This;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is about time we add a where
clause and implement type equality constraints.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, this part is confusing.. what we want is the equality constraints.
(We could implement ISelfDifferentiableValueType
as a special case for now, in case the where
clause implementation is too complex)
interface IDifferentiableValueType : __IDifferentiableBase | ||
{ | ||
associatedtype Differential : IDifferentiableValueType & IDefaultInitializable; | ||
[Differentiable] This use(); // auto-synthesized |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also need use_bwd
and use_fwd
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe that using [Differentiable]
in the method signature should automatically require the two derivatives to be provided?
[BackwardDerivative(get_fwd)] | ||
T get() { return this.buffer[offset]; } | ||
|
||
[deriv_method] DifferentialPair<T> get_fwd() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use_fwd?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is different from use
. the IDifferentiablePtrType
explicitly doesn't require a use
(because we don't need any type-specific handling of bwd_diff(use)
. use
is invoked internally by the compiler for any extra uses of an value-typed object (if used multiple times)
get()
is a normal getter/setter method implemented on a buffer or pointer type (that a user can use) and is not required in the differentiable interface.
It's effectively just derefencing or indexing in this scenario. (Maybe I should have used load() instead?). The core idea is that on these 'pointer-style' types all differentiation happens on the methods used to access the data, so we'll rely on custom derivatives, and simply make sure the pair is passed as-is through all the generated code so the buffer references are available to the custom derivative code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The main reason for introducing IDifferentiablePtrType
is to make this explicit. It explicitly has no use
and no way to initialize it, which means that this type will not get 'transposed'. They will behave similar to how DiffTensorView<T>
works, with the main difference being that auto-diff itself will create the pair of buffers instead of us having to side-step the differentiable type system.
return diffPair(this.p.buffer[offset], this.d().buffer[offset]); | ||
} | ||
|
||
[deriv_method] void get_bwd(Differential d) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use_bwd?
// After fwd-mode: | ||
DiffPair<float> f_fwd(DiffPair<float> dpa) | ||
{ | ||
DiffPair<float> dpa_extra = dpa.use_fwd(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what does DiffPair<float>.use_fwd
do?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is effectively a noop 'passthrough' in this case. It only exists because when converting to reverse-mode, we will replace use_fwd
into use_bwd
which is the method that has non-trivial logic. Both use
and use_fwd
are always intended to be passthroughs.
(We could do away with use
entirely and simply require the implementation of the backward part, which is just dadd
but not static)
@@ -0,0 +1,283 @@ | |||
## Problem |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should start with a title
SP #10: New differentiable type system
Unfortunately, this is a rather hacky workaround with several drawbacks: | ||
1. `DiffTensorView<T>` does not conform to `IDifferentiable`, but is used for derivatives. This makes our type system less useful as checks for `is_subtype` from applications using reflection need workarounds to account for corner cases like these. | ||
2. `DiffTensorView<T>` always has two buffer pointers even when used in non-differentiable methods. This is extra data in the struct, and potentially extra tensor allocations (we explicitly handle this case in `slangtorch` by leaving the diff part uninitialized if a primal method is invoked) | ||
3. Higher-order derivatives don't work well with this workaround. Differentiating a method twice needs a set of 4 pointers, but we need to account for this ahead of time by using new types like `DiffDiffTensorView` that worsens the problem of carrying around extra data where its not required. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How does higher order diff work with the new system?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this scenario, we can have a DTensorView<T> : IDifferentiablePtrType
which can have a single buffer in primal code.
It becomes DifferentialPair<DTensorView<T>>
under the first derivative pass (2 buffers), and then becomes DifferentialPair<DifferentialPair<<DTensorView<T>>>
under the second derivative pass (4 buffers).
// ---- | ||
|
||
// backward pass. | ||
float.Differential d_val = DPtr<float>::set_bwd(output); // set_bwd works on the entire pair. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
where are the argument to set
?
|
||
// trivial auto-synthesis (maybe we even prevent the user from overriding this). | ||
[ForwardDerivativeOf(use)] | ||
[deriv_method] void use_fwd() { return this; } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why can't this be just a static method where the first parameter is DiffPair<T>
? It seems that [deriv_method]
isn't necessary.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@csyonghe: It's syntactic sugar for a static method that operates on pairs. It's mostly to keep code clarity. If a primal function is a member method that operates on this
, then the derivative is also a member method that operates on a this-pair.
But we can certainly avoid adding a new decorator, if it doesn't really help with clarity.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would rather us not adding new decorations if it can be achieved with existing mechanisms to avoid increasing maintenance costs down the road.
Slang v2024.10 This release brings support for tuple types, variadic generics and depedent generic constraints. No breaking changes. Changes: e97e7e5 Revert "Fetch slang-llvm.so from correct release (shader-slang#4847)" (shader-slang#4893) 359e96c Proposal: A simpler and more flexible `IDifferentiable` system (shader-slang#4865) f9f6a28 Support dependent generic constraints. (shader-slang#4870) 03e1e17 Fix `tests\autodiff\reverse-while-loop-3.slang` test (shader-slang#4886) bcb5391 Exclude synthesized code from code auto documentation system (shader-slang#4889) 6b1b243 Track uninitialized values of `Ptr<Specialize<T>>` inside type `T` without hang (shader-slang#4885) 77e6c64 Fixes shader-slang#4879 (shader-slang#4881) 579d59c Fetch slang-llvm.so from correct release (shader-slang#4847) d286ff5 Implement Path::createDirectoryRecursive (shader-slang#4871) f77a5ac Remove using SpvStorageClass values casted into AddressSpace values (shader-slang#4861) 453683b Tuple swizzling, concat, comparison and `countof`. (shader-slang#4856) ecf85df Variadic Generics Part 2: IR lowering and specialization. (shader-slang#4849) ca5d303 Make sure to resolve overloaded expr for call args. (shader-slang#4864) 25bc5a3 Avoiding the use of the global AST builder in DeclRefType::create (shader-slang#4866) b411c05 Include inout cast operation as an aliasing instruction (shader-slang#4859) 9bf5dc9 Design proposal: IFunc interface. (shader-slang#4851) f447b74 Update documentation for #include to indicate it is for legacy code and new code should use modules (shader-slang#4862)
Designed to work for both value types as well as ref/ptr/buffer types.
Resolves #4034