diff --git a/Project.toml b/Project.toml index febafd4..b415ac6 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" license = "MIT" desc = "Tape based task copying in Turing" repo = "https://github.com/TuringLang/Libtask.jl.git" -version = "0.8.2" +version = "0.8.3" [deps] FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" diff --git a/src/tapedtask.jl b/src/tapedtask.jl index c5e5d22..c96f472 100644 --- a/src/tapedtask.jl +++ b/src/tapedtask.jl @@ -63,10 +63,17 @@ function TapedTask(tf::TapedFunction, args...) return t end +BASE_COPY_TYPES = Union{Array, Ref} + # NOTE: evaluating model without a trace, see # https://github.com/TuringLang/Turing.jl/pull/1757#diff-8d16dd13c316055e55f300cd24294bb2f73f46cbcb5a481f8936ff56939da7ceR329 -function TapedTask(f, args...; deepcopy_types=Union{Array, Ref}) # deepcoy Array and Ref by default. - tf = TapedFunction(f, args...; cache=true, deepcopy_types=deepcopy_types) +function TapedTask(f, args...; deepcopy_types=nothing) # deepcoy Array and Ref by default. + if isnothing(deepcopy_types) + deepcopy = BASE_COPY_TYPES + else + deepcopy = Union{BASE_COPY_TYPES, deepcopy_types} + end + tf = TapedFunction(f, args...; cache=true, deepcopy_types=deepcopy) TapedTask(tf, args...) end diff --git a/test/tape_copy.jl b/test/tape_copy.jl index 5c3af3b..6edc3f7 100644 --- a/test/tape_copy.jl +++ b/test/tape_copy.jl @@ -171,4 +171,25 @@ y[][2] = 19 @test y[][2] == 19 end + + @testset "override deepcopy_types #57" begin + struct DummyType end + + function f(start::Int) + t = [start] + while true + produce(t[1]) + t[1] = 1 + t[1] + end + end + + ttask = TapedTask(f, 0; deepcopy_types=DummyType) + consume(ttask) + + ttask2 = copy(ttask) + consume(ttask2) + + @test consume(ttask) == 1 + @test consume(ttask2) == 2 + end end