Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sort out frule API #129

Merged
merged 14 commits into from
Feb 21, 2020
Merged

Sort out frule API #129

merged 14 commits into from
Feb 21, 2020

Conversation

willtebbutt
Copy link
Member

Implements option 2 from #128 . Adds API regression tests to ensure that the considerations made there don't slip through the cracks in any future API revisions.

Currently got a weird type instability thing @oxinabox any thoughts?

@oxinabox
Copy link
Member

Tell me more of your type-stability thing?

@willtebbutt
Copy link
Member Author

Tell me more of your type-stability thing?

see e.g. here

The inference test here is failing to properly infer, I think. Unclear to me why it should be doing so.

For example:

using ChainRulesCore, StaticArrays
sx = @SVector [1, 2]
sy = @SVector [3, 4]

very_nice(x, y) = x + y
@scalar_rule(very_nice(x, y), (One(), One()))

@code_warntype frule((Zero(), sx, sy), very_nice, 1, 2)

Variables
  #self#::Core.Compiler.Const(ChainRulesCore.frule, false)
  @_2::Tuple{Zero,SArray{Tuple{2},Int64,1,2},SArray{Tuple{2},Int64,1,2}}
  #unused#::Core.Compiler.Const(very_nice, false)
  x::Int64
  y::Int64
  @_6::Int64
  Ω::Int64

Body::Tuple{Int64,Any}
1%1  = Base.indexed_iterate(@_2, 1)::Core.Compiler.Const((Zero(), 2), false)
│   %2  = Core.getfield(%1, 1)::Core.Compiler.Const(Zero(), false)
│         (ChainRulesCore._ = %2)
│         (@_6 = Core.getfield(%1, 2))
│   %5  = Base.indexed_iterate(@_2, 2, @_6::Core.Compiler.Const(2, false))::Core.Compiler.PartialStruct(Tuple{SArray{Tuple{2},Int64,1,2},Int64}, Any[SArray{Tuple{2},Int64,1,2}, Core.Compiler.Const(3, false)])
│   %6  = Core.getfield(%5, 1)::SArray{Tuple{2},Int64,1,2}
│         (ChainRulesCore.Δ1 = %6)
│         (@_6 = Core.getfield(%5, 2))
│   %9  = Base.indexed_iterate(@_2, 3, @_6::Core.Compiler.Const(3, false))::Core.Compiler.PartialStruct(Tuple{SArray{Tuple{2},Int64,1,2},Int64}, Any[SArray{Tuple{2},Int64,1,2}, Core.Compiler.Const(4, false)])
│   %10 = Core.getfield(%9, 1)::SArray{Tuple{2},Int64,1,2}
│         (ChainRulesCore.Δ2 = %10)
│         (Ω = Main.very_nice(x, y))
│   %13 = Ω::Int64%14 = Base.broadcasted(Main.One)::Core.Compiler.Const(One(), false)
│   %15 = Base.broadcasted(Main.One)::Core.Compiler.Const(One(), false)
│   %16 = Base.broadcasted(ChainRulesCore.:*, %15, ChainRulesCore.Δ1)::Any%17 = Base.broadcasted(muladd, %14, ChainRulesCore.Δ2, %16)::Any%18 = Base.materialize(%17)::Any%19 = Core.tuple(%13, %18)::Tuple{Int64,Any}
└──       return %19

@willtebbutt willtebbutt mentioned this pull request Feb 16, 2020
@willtebbutt
Copy link
Member Author

This is generally passing now. @oxinabox @nickrobinson251 what would you like to happen with the nightly failures?

@simeonschaub
Copy link
Member

There's something weird going on with @allocated since JuliaLang/julia#33717. We might need to adjust our tests eventually, but I think it's okay to ignore failures on nightly for now

@willtebbutt
Copy link
Member Author

@shashi @YingboMa does this change work for you?

test/rules.jl Outdated Show resolved Hide resolved
test/rules.jl Outdated Show resolved Hide resolved
test/rules.jl Outdated Show resolved Hide resolved
@oxinabox
Copy link
Member

Main documentation in /docs/src need to updated

@willtebbutt
Copy link
Member Author

@nickrobinson251 @oxinabox shall we merge?

@nickrobinson251
Copy link
Contributor

Would be good to see the partner PR to ChainRules

Copy link
Contributor

@nickrobinson251 nickrobinson251 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems good to me -- we just need to decide if we're releasing this as 0.7 (i don't seee why not) or keeping it at 0.7-DEV until the macros added too

Also, tests need to pass before merge

Project.toml Outdated Show resolved Hide resolved
@willtebbutt
Copy link
Member Author

Have tweaked the CI slightly to require that tests pass on 1.3. Will merge provided that everything seems fine.

@willtebbutt
Copy link
Member Author

Before merging, @oxinabox are you happy with just calling this merge 0.7.0? Adding the convenience macros isn't breaking, so 🤷‍♂

@shashi
Copy link
Collaborator

shashi commented Feb 20, 2020

Yeah it should work for ForwardDiff2. cc @YingboMa

.travis.yml Show resolved Hide resolved
test/rules.jl Outdated Show resolved Hide resolved
Copy link
Member

@oxinabox oxinabox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just fix to not use import and then LGTM

willtebbutt and others added 2 commits February 21, 2020 17:09
Co-Authored-By: Lyndon White <oxinabox@ucc.asn.au>
@willtebbutt willtebbutt merged commit 2ef3f20 into master Feb 21, 2020
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants