diff --git a/src/rollouts.jl b/src/rollouts.jl index 2bfe0ff0..ddaf77fd 100644 --- a/src/rollouts.jl +++ b/src/rollouts.jl @@ -45,6 +45,13 @@ function rollout( return Ψ̃ end +function rollout( + ψ̃₁s::AbstractVector{AbstractVector}, args...; kwargs... +) + return vcat([rollout(ψ̃₁, args...; kwargs...) for ψ̃₁ ∈ ψ̃₁s]...) +end + + function unitary_rollout( Ũ⃗₁::AbstractVector{<:Real}, controls::AbstractMatrix{Float64},