Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BroadcastThunk #615

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRulesCore"
uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
version = "1.15.7"
version = "1.16.0"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
4 changes: 2 additions & 2 deletions src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ export frule, rrule # core function
export RuleConfig, HasReverseMode, NoReverseMode, HasForwardsMode, NoForwardsMode
export frule_via_ad, rrule_via_ad
# definition helper macros
export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented
export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented, @bc_thunk
export ProjectTo, canonicalize, unthunk # tangent operations
export add!!, is_inplaceable_destination # gradient accumulation operations
export ignore_derivatives, @ignore_derivatives
# tangents
export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk
export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk, BroadcastThunk

include("debug_mode.jl")

Expand Down
12 changes: 12 additions & 0 deletions src/accumulation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,18 @@ function add!!(x, t::InplaceableThunk)
end
end

function add!!(x, t::BroadcastThunk)
return if is_inplaceable_destination(x)
if !debug_mode()
x .+= t.bc
mcabbott marked this conversation as resolved.
Show resolved Hide resolved
else
debug_add!(x, t)
end
else
x .+ t.bc
end
end

add!!(x::AbstractArray, y::Thunk) = add!!(x, unthunk(y))

function add!!(x::AbstractArray{<:Any,N}, y::AbstractArray{<:Any,N}) where {N}
Expand Down
40 changes: 40 additions & 0 deletions src/projection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,41 @@ function _projection_mismatch(axes_x::Tuple, size_dx::Tuple)
)
end

# Broadcasted shows up in forward pass of lazy broadcasting.
# Choose not to reshape until we get back to an array etc, just handle eltype:
function ProjectTo(x::Broadcasted)
T = Base.@default_eltype x
T <: Number || return identity
element = _eltype_projectto(T)
element isa ProjectTo{NoTangent} && return element
return ProjectTo{Broadcasted}(; element)
end

function (project::ProjectTo{Broadcasted})(dx::AbstractArray{S,M}) where {S,M}
T = project_type(project.element)
return S <: T ? dx : @bc_thunk project.element(dx)
end
function (project::ProjectTo{Broadcasted})(dx::BroadcastThunk{S}) where {S}
T = project_type(project.element)
return S <: T ? dx : @bc_thunk project.element(dx)
end

# We can allow BroadcastThunk as a gradient of an Array, but collect it to reshape etc.
function (project::ProjectTo{AbstractArray})(dx::BroadcastThunk{S}) where {S}
if axes(dx.bc) !== project.axes
return project(unthunk(dx))
end
dz = if hasproperty(project, :element)
T = project_type(project.element)
S <: T ? dx : @bc_thunk project.element(dx)
else
@bc_thunk (|>)(dx, project.elements)
end
return dz
end
# Also collect BroadcastThunk for any structured array:
(project::ProjectTo{<:AbstractArray})(dx::BroadcastThunk) = project(unthunk(dx))

#####
##### `Base`, part II: return of the Tangent
#####
Expand Down Expand Up @@ -385,6 +420,11 @@ function (project::ProjectTo{<:Tangent{<:Tuple}})(dx::AbstractArray)
end
end

# Since Tuples participate in broadcasting, we may get a BroadcastThunk.
# function (project::ProjectTo{<:Tangent{<:Union{Tuple,NamedTuple}}})(dx::BroadcastThunk)
# return project(unthunk(dx))
# end


#####
##### `LinearAlgebra`
Expand Down
120 changes: 119 additions & 1 deletion src/tangent_types/thunks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ struct Thunk{F} <: AbstractThunk
f::F
end

@inline unthunk(x::Thunk) = x.f()
@inline unthunk(x::Thunk) = unthunk(x.f())

function Base.show(io::IO, x::Thunk)
print(io, "Thunk(")
Expand Down Expand Up @@ -249,3 +249,121 @@ function Base.show(io::IO, x::InplaceableThunk)
show(io, x.val)
print(io, ")")
end


"""
BroadcastThunk{T}(bc::Broadcasted)

A new kind of thunk, which wraps Base's lazy broadcasting. `T` is the eltype.
mcabbott marked this conversation as resolved.
Show resolved Hide resolved

Calling `unthunk` will materialise the array. But inserting it directly
into a broadcast will fuse the old and the new, which is the whole point.

That means that rules accepting thunks should not always `unthunk` them
before use. Instead, they may rely on `broadcastable` to do so, as long as
they only do this in one place.

Does that mean we should use `BroadcastThunk` even for expensive operations,
and rely on downstreadm rules to use it just once? Or would a better policy be
to only use this for cheap operations, and downstreadm rules may then dispatch
on `::BroadcastThunk` & fuse it into two different places?
"""
struct BroadcastThunk{T, B<:Broadcast.Broadcasted} <: AbstractThunk
bc::B
end
function BroadcastThunk(bc::Broadcast.Broadcasted)
T = Base.@default_eltype(bc)
return if T <: Number # applicable(zero, T) # will SVector work?
BroadcastThunk{T, typeof(bc)}(bc)
else
# We need init=zero(T) for unbroadcast to work.
# For things like arrays of arrays, we just don't thunk?
# copy(bc)
# Or perhaps we make a boring thunk?
InplaceableThunk(dx -> dx .+= bc, @thunk copy(bc))
end
end
function BroadcastThunk(x::AbstractArray{T}) where {T}
return if T <: Number # applicable(zero, T)
bc = Broadcast.instantiate(Broadcast.broadcasted(identity, x))
BroadcastThunk{T,typeof(bc)}(bc)
else
x
end
end
Base.eltype(x::BroadcastThunk{T}) where {T} = T

@inline unthunk(x::BroadcastThunk) = copy(x.bc)

# This is the whole point:
Base.Broadcast.broadcastable(x::BroadcastThunk) = x.bc

function Base.show(io::IO, x::BroadcastThunk)
print(io, "BroadcastThunk{")
show(io, eltype(x))
print(io, "}(")
str = sprint(show, x.bc, context = io)
if length(str) < 80
printstyled(io, str, color=:light_black)
else
printstyled(io, str[1:70], "...", color=:light_black)
end
print(io, ")")
end

"""
@bc_thunk f(a, g(b, c))

This works like `@.` to produce something like `@thunk f.(a, g.(b, c))`.
Except that instead of a `Thunk`, it's a `BroadcastThunk`.
"""
macro bc_thunk(ex)
bc = esc(Broadcast.__dot__(ex))
:($_lazy_bc.($bc))
end

function _lazy_bc end
Broadcast.broadcasted(::typeof(_lazy_bc), x) = _Lazy_BC(x)
struct _Lazy_BC{T}; bc::T; end
Broadcast.materialize(x::_Lazy_BC) = BroadcastThunk(Broadcast.instantiate(x.bc))

macro bc_thunk(s::Symbol)
error("cannot apply @bc_thunk to one symbol, there is nothing to broadcast!")
end

# These prove useful for writing rules:

Base.:(-)(x::BroadcastThunk) = @bc_thunk -(x.bc)

for fun in [:conj, :real, :imag, :complex]
@eval Base.$fun(x::BroadcastThunk) = BroadcastThunk(Broadcast.instantiate(Broadcast.broadcasted($fun, x.bc)))
end

Base.sum(x::BroadcastThunk) = sum(x.bc)
Base.sum(x::BroadcastThunk; dims=:) = sum(x.bc; dims, init=zero(eltype(x)))
Base.sum(f, x::BroadcastThunk) = sum(f, x.bc)

LinearAlgebra.dot(x::Base.AbstractArrayOrBroadcasted, y::BroadcastThunk{<:Number}) = sum(@bc_thunk conj(x) * y.bc)

"""
unthunk_or_bc(dx)

This removes most thunks, but turns a `BroadcastThunk` into a `Broadcasted`.
For use in rrules which can handle the latter.
"""
unthunk_or_bc(x) = unthunk(x)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe unthunk should do this on BroadcastedThunks?
Rather than a seperate function?
idk maybe not since Broadcasted isn't a linear algrebra type.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If your rule used to call unthunk and get an array, then it may fail if a new kind of thunk doesn't respect that.

Agree it adds complexity to have two unwrap-this functions.

unthunk_or_bc(x::BroadcastThunk) = x.bc
unthunk_or_bc(x::Thunk) = unthunk_or_bc(x.f()) # bct = @bc_thunk 1+[2,3]; unthunk_or_bc(@thunk -bct) isa Broadcasted

# The accumulation of thunks is a mess, no AD actually calls add!!, and + always unthunks.
# But... these should be safe, and make more BroadcastThunks:
Base.:(+)(x::BroadcastThunk, y::BroadcastThunk) = @bc_thunk x.bc + y.bc

Base.:(+)(x::BroadcastThunk, y::AbstractArray) = @bc_thunk x.bc + y
Base.:(+)(x::AbstractArray, y::BroadcastThunk) = @bc_thunk x + y.bc

Base.:(+)(x::BroadcastThunk, y::Thunk) = x + unthunk(y)
Base.:(+)(x::Thunk, y::BroadcastThunk) = unthunk(x) + y

Base.:(+)(x::BroadcastThunk, y::InplaceableThunk) = BroadcastThunk(add!!(unthunk(x), y))
Base.:(+)(x::InplaceableThunk, y::BroadcastThunk) = BroadcastThunk(add!!(unthunk(y), x))
73 changes: 72 additions & 1 deletion test/tangent_types/thunks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@

@testset "unthunk" begin
@test unthunk(@thunk(3)) == 3
@test unthunk(@thunk(@thunk(3))) isa Thunk
# @test unthunk(@thunk(@thunk(3))) isa Thunk
@test unthunk(@thunk(@thunk(3))) == 3
# Changed to ensure that un-thunking `@thunk unbroadcast(x, dx::BroadcastThunk)` predictably gives a non-thunk.
# That could be done more narrowly, to allow some double-thunks.
end

@testset "erroring thunks should include the source in the backtrack" begin
Expand Down Expand Up @@ -217,3 +220,71 @@
@test length(findall("...", str)) == 2 # now both halves shortened
end
end

@testset "BroadcastThunk" begin
@testset "basics" begin
bth = @bc_thunk (1, 2) + 3
@test bth isa BroadcastThunk{Int}
@test unthunk(bth) === (4, 5)
@test eltype(bth) === Int

@test BroadcastThunk([1,2]) isa BroadcastThunk{Int}
@test BroadcastThunk([[1,2], [3,4]]) isa Vector{Vector{Int}}

nobc = @bc_thunk [[1,2], [3,4]] * [5,6]
@test !(nobc isa BroadcastThunk)
@test nobc isa InplaceableThunk
@test unthunk(nobc) == [[5, 10], [18, 24]]

@test unthunk(@thunk 1 .+ bth) === (5, 6)
@test ChainRulesCore.unthunk_or_bc(@thunk bth) isa Broadcast.Broadcasted
@test ChainRulesCore.unthunk_or_bc(@thunk 1 .+ bth) === (5, 6)

@test bth .+ 1 === (5, 6) # in fact fused, but this isn't tested
@test Broadcast.broadcastable(bth) isa Broadcast.Broadcasted
@test sum(bth) === 9 # in fact lazy
end

@testset "preservation" begin
bth = @bc_thunk [1, 2] + 3im
@test -bth isa BroadcastThunk
@test unthunk(-bth) == [-1-3im, -2-3im]

for f in [real, imag, conj]
@test f(bth) isa BroadcastThunk
@test unthunk(f(bth)) == f.([1+3im, 2+3im])
end
end

@testset "accumulation" begin
bth = @bc_thunk [1, 2] + 3
bth2 = @bc_thunk [4, 5] - 6
arr = [7, 8]

@test bth + bth2 isa BroadcastThunk
@test bth + arr isa BroadcastThunk
@test arr + bth2 isa BroadcastThunk
end

@testset "ProjectTo" begin
bth = @bc_thunk [1, 2, 3] + 4

# When sizes match, we can be lazy:
@test ProjectTo([1,2,3])(@bc_thunk 4 + [5,6,7]) isa BroadcastThunk
@test ProjectTo(bth.bc)(@bc_thunk 4 + [5,6,7]) isa BroadcastThunk
@test ProjectTo(Float32[1,2,3])(@bc_thunk 4im + [5,6,7]) isa BroadcastThunk{Float32}
@test ProjectTo(bth.bc)(@bc_thunk 4im + [5,6,7]) isa BroadcastThunk{Float64}

# But when we must resize, or make a special matrix, give up & materialise:
@test ProjectTo([1; 2;;])(@bc_thunk 3 + [4, 5]) isa Matrix{Float64}
@test ProjectTo([1, 2]')(@bc_thunk 3 + [4 5]) isa Adjoint{Float64}

# There is also ProjectTo{Broadcasted} for fused forward pass.
# It makes a BroadcastThunk when it has to change eltype, but does not fix shape.
@test ProjectTo(bth.bc)(hcat([1.0, 2.0, 3.0])) isa Matrix{Float64}
@test ProjectTo(bth.bc)(hcat([1, 2, 3])) isa BroadcastThunk{Float64}
@test unthunk(ProjectTo(bth.bc)([4, 5im, 6])) == [4, 0, 6]

@test ProjectTo(@bc_thunk([1; 2;;] + 0).bc)(@bc_thunk 3 + [4, 5]) isa BroadcastThunk{Float64}
end
end