Skip to content

Commit

Permalink
backport basic refactor from PR#100 (#103)
Browse files Browse the repository at this point in the history
* back port basic refactor from PR#100

* remove unused code, and some minor changes

* move increase_counter

* minor update
  • Loading branch information
KDr2 authored Jan 6, 2022
1 parent 375e2f8 commit d27401a
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 66 deletions.
68 changes: 32 additions & 36 deletions src/tapedfunction.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,37 @@
mutable struct Instruction{F}
fun::F
input::Tuple
output
tape
end

abstract type AbstractInstruction end

mutable struct Tape
tape::Vector{Instruction}
tape::Vector{<:AbstractInstruction}
counter::Int
owner
end

Tape() = Tape(Vector{Instruction}(), nothing)
Tape(owner) = Tape(Vector{Instruction}(), owner)
mutable struct Instruction{F} <: AbstractInstruction
fun::F
input::Tuple
output
tape::Tape
end

Tape() = Tape(Vector{AbstractInstruction}(), 1, nothing)
Tape(owner) = Tape(Vector{AbstractInstruction}(), 1, owner)
MacroTools.@forward Tape.tape Base.iterate, Base.length
MacroTools.@forward Tape.tape Base.push!, Base.getindex, Base.lastindex
const NULL_TAPE = Tape()

function setowner!(tape::Tape, owner)
tape.owner = owner
return tape
end

mutable struct Box{T}
val::T
end

val(x) = x
val(x::Box) = x.val
box(x) = Box(x)
any_box(x) = Box{Any}(x)
box(x::Box) = x

gettape(x) = nothing
gettape(x::Instruction) = x.tape
Expand Down Expand Up @@ -63,11 +70,21 @@ function (instr::Instruction{F})() where F
instr.output.val = output
end

function increase_counter!(t::Tape)
t.counter > length(t) && return
# instr = t[t.counter]
t.counter += 1
return t
end

function run(tape::Tape, args...)
input = map(box, args)
tape[1].input = input
if length(args) > 0
input = map(box, args)
tape[1].input = input
end
for instruction in tape
instruction()
increase_counter!(tape)
end
end

Expand All @@ -77,21 +94,13 @@ function run_and_record!(tape::Tape, f, args...)
box(f(map(val, args)...))
catch e
@warn e
any_box(nothing)
Box{Any}(nothing)
end
ins = Instruction(f, args, output, tape)
push!(tape, ins)
return output
end

function dry_record!(tape::Tape, f, args...)
# We don't know the type of box.val now, so we use Box{Any}
output = any_box(nothing)
ins = Instruction(f, args, output, tape)
push!(tape, ins)
return output
end

function unbox_condition(ir)
for blk in IRTools.blocks(ir)
vars = keys(blk)
Expand Down Expand Up @@ -188,27 +197,14 @@ function (tf::TapedFunction)(args...)
tape = IRTools.evalir(ir, tf.func, args...)
tf.ir = ir
tf.tape = tape
tape.owner = tf
setowner!(tape, tf)
return result(tape)
end
# TODO: use cache
run(tf.tape, args...)
return result(tf.tape)
end

function dry_run(tf::TapedFunction)
isempty(tf.tape) || (return tf)
@assert tf.arity >= 0 "TapedFunction need a fixed arity to dry run."
args = fill(nothing, tf.arity)
ir = IRTools.@code_ir tf.func(args...)
ir = intercept(ir; recorder=:dry_record!)
tape = IRTools.evalir(ir, tf.func, args...)
tf.ir = ir
tf.tape = tape
tape.owner = tf
return tf
end

function Base.show(io::IO, tf::TapedFunction)
buf = IOBuffer()
println(buf, "TapedFunction:")
Expand Down
71 changes: 41 additions & 30 deletions src/tapedtask.jl
Original file line number Diff line number Diff line change
@@ -1,33 +1,32 @@
struct TapedTaskException
exc
exc::Exception
backtrace
end

struct TapedTask
task::Task
tf::TapedFunction
counter::Ref{Int}
produce_ch::Channel{Any}
consume_ch::Channel{Int}
produced_val::Vector{Any}

function TapedTask(
t::Task, tf::TapedFunction, counter, pch::Channel{Any}, cch::Channel{Int})
new(t, tf, counter, pch, cch, Any[])
t::Task, tf::TapedFunction, pch::Channel{Any}, cch::Channel{Int})
new(t, tf, pch, cch, Any[])
end
end

function TapedTask(tf::TapedFunction, args...)
tf.owner != nothing && error("TapedFunction is owned to another task.")
# dry_run(tf)
isempty(tf.tape) && tf(args...)
counter = Ref{Int}(1)
produce_ch = Channel()
consume_ch = Channel{Int}()
task = @task try
step_in(tf, counter, args)
step_in(tf.tape, args)
catch e
put!(produce_ch, TapedTaskException(e))
# @error "TapedTask Error: " exception=(e, catch_backtrace())
bt = catch_backtrace()
put!(produce_ch, TapedTaskException(e, bt))
# @error "TapedTask Error: " exception=(e, bt)
rethrow()
finally
@static if VERSION >= v"1.4"
Expand All @@ -40,7 +39,7 @@ function TapedTask(tf::TapedFunction, args...)
close(produce_ch)
close(consume_ch)
end
t = TapedTask(task, tf, counter, produce_ch, consume_ch)
t = TapedTask(task, tf, produce_ch, consume_ch)
task.storage === nothing && (task.storage = IdDict())
task.storage[:tapedtask] = t
tf.owner = t
Expand All @@ -53,25 +52,31 @@ TapedTask(f, args...) = TapedTask(TapedFunction(f, arity=length(args)), args...)
TapedTask(t::TapedTask, args...) = TapedTask(func(t), args...)
func(t::TapedTask) = t.tf.func

function step_in(tf::TapedFunction, counter::Ref{Int}, args)
len = length(tf.tape)
if(counter[] <= 1 && length(args) > 0)

function step_in(t::Tape, args)
len = length(t)
if(t.counter <= 1 && length(args) > 0)
input = map(box, args)
tf.tape[1].input = input
t[1].input = input
end
while counter[] <= len
tf.tape[counter[]]()
while t.counter <= len
t[t.counter]()
# produce and wait after an instruction is done
ttask = tf.owner
ttask = t.owner.owner
if length(ttask.produced_val) > 0
val = pop!(ttask.produced_val)
put!(ttask.produce_ch, val)
take!(ttask.consume_ch) # wait for next consumer
end
counter[] += 1
increase_counter!(t)
end
end

function next_step!(t::TapedTask)
increase_counter!(t.tf.tape)
return t
end

#=
# ** Approach (A) to implement `produce`:
# Make`produce` a standalone instturction. This approach does NOT
Expand Down Expand Up @@ -186,18 +191,21 @@ function copy_box(old_box::Box{T}, roster::Dict{UInt64, Any}) where T
end
copy_box(o, roster::Dict{UInt64, Any}) = o

function Base.copy(t::Tape)
function Base.copy(x::Instruction, on_tape::Tape, roster::Dict{UInt64, Any})
input = map(x.input) do ob
copy_box(ob, roster)
end
output = copy_box(x.output, roster)
Instruction(x.fun, input, output, on_tape)
end

function Base.copy(t::Tape, roster::Dict{UInt64, Any})
old_data = t.tape
new_data = Vector{Instruction}()
new_tape = Tape(new_data, t.owner)
new_data = Vector{AbstractInstruction}()
new_tape = Tape(new_data, t.counter, t.owner)

roster = Dict{UInt64, Any}()
for x in old_data
input = map(x.input) do ob
copy_box(ob, roster)
end
output = copy_box(x.output, roster)
new_ins = Instruction(x.fun, input, output, new_tape)
new_ins = copy(x, new_tape, roster)
push!(new_data, new_ins)
end

Expand All @@ -207,8 +215,9 @@ end
function Base.copy(tf::TapedFunction)
new_tf = TapedFunction(tf.func; arity=tf.arity)
new_tf.ir = tf.ir
new_tape = copy(tf.tape)
new_tape.owner = new_tf
roster = Dict{UInt64, Any}()
new_tape = copy(tf.tape, roster)
setowner!(new_tape, new_tf)
new_tf.tape = new_tape
return new_tf
end
Expand All @@ -217,6 +226,8 @@ function Base.copy(t::TapedTask)
# t.counter[] <= 1 && error("Can't copy a TapedTask which is not running.")
tf = copy(t.tf)
new_t = TapedTask(tf)
new_t.counter[] = t.counter[] + 1
new_t.task.storage = copy(t.task.storage)
new_t.task.storage[:tapedtask] = new_t
next_step!(new_t)
return new_t
end

0 comments on commit d27401a

Please sign in to comment.