Skip to content

Commit

Permalink
Added and exported MappedArray
Browse files Browse the repository at this point in the history
  • Loading branch information
Andy Ferris committed Sep 7, 2017
1 parent f8745ec commit fcf5201
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 1 deletion.
3 changes: 3 additions & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ export
LinSpace,
LowerTriangular,
Irrational,
MappedArray,
MappedVector,
MappedMatrix,
Matrix,
MergeSort,
NTuple,
Expand Down
86 changes: 86 additions & 0 deletions base/mappedarray.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

module MappedArrays

export MappedArray, MappedVector, MappedMatrix

"""
MappedArray(f, a)
MappedArray(f, f_inv, a)
Returns a lazily mapped array where function `f` is applied the elements of array `a`.
`f_inv` is the inverse function to `f`, and should satisfy `f(f_inv(x)) == x`. It is
optional and used to enable `setindex!` on the output array, so that the appropriate values
can be stored in `a`. Some inverse functions are known (e.g. `conj` is its own inverse) and
users may overload the `Base.MappedArrays.inv_func` function with their own definitions
(e.g. `Base.MappedArrays.inv_func(::typeof(conj)) = conj`) so that `f_inv` is created
automatically by the constructor.
# Example
```julia
julia> a = [1, 2, 3]
3-element Array{Int64,1}:
1
2
3
julia> b = MappedArray(x -> x + 10, x -> x - 10, a)
3-element MappedArray{Int64,1,getfield(Main, Symbol("##3#5")),getfield(Main, Symbol("##4#6")),Array{Int64,1}}:
11
12
13
julia> b[2] = 20
20
julia> a[2]
10
```
"""
struct MappedArray{T, N, F, F_inv, A <: AbstractArray{<:Any, N}} <: AbstractArray{T, N}
f::F
f_inv::F_inv
parent::A
end

MappedArray(f, a::AbstractArray) = MappedArray(f, inv_func(f), a)
function MappedArray(f, f_inv, a::AbstractArray)
MappedArray{Base.promote_op(f, eltype(a)), ndims(a), typeof(f), typeof(f_inv), typeof(a)}(f, f_inv, a)
end

noinverse(x) = error("No inverse function defined")

inv_func(f) = noinverse
inv_func(::typeof(identity)) = identity
inv_func(::typeof(conj)) = conj

const MappedVector{T, F, F_inv, A <: AbstractVector} = MappedArray{T, 1, F, F_inv, A}
MappedVector(f, a::AbstractVector) = MappedVector(f, inv_func(f), a)
function MappedVector(f, f_inv, a::AbstractVector)
MappedArray{Base.promote_op(f, eltype(a)), 1, typeof(f), typeof(f_inv), typeof(a)}(f, f_inv, a)
end

const MappedMatrix{T, F, F_inv, A <: AbstractMatrix} = MappedArray{T, 2, F, F_inv, A}
MappedMatrix(f, a::AbstractMatrix) = MappedMatrix(f, inv_func(f), a)
function MappedMatrix(f, f_inv, a::AbstractMatrix)
MappedArray{Base.promote_op(f, eltype(a)), 2, typeof(f), typeof(f_inv), typeof(a)}(f, f_inv, a)
end

Base.parent(a::MappedArray) = a.parent
parent_type(::Type{<:MappedArray{<:Any, <:Any, <:Any, <:Any, A}}) where {A} = A

Base.size(a::MappedArray) = size(parent(a))
Base.indices(a::MappedArray) = indices(parent(a))
Base.IndexStyle(::Type{MA}) where {MA <: MappedArray} = Base.IndexStyle(parent_type(MA))

Base.@propagate_inbounds Base.getindex(a::MappedArray, i::Int...) = a.f(a.parent[i...])
Base.@propagate_inbounds function Base.setindex!(a::MappedArray{T}, v::T, i::Int...) where {T}
a.parent[i...] = a.f_inv(v)
end
Base.@propagate_inbounds function Base.setindex!(a::MappedArray{T}, v, i::Int...) where {T}
a[i...] = convert(T, v)
end

end # module
2 changes: 2 additions & 0 deletions base/sysimg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ using .Cartesian
include("multidimensional.jl")
include("permuteddimsarray.jl")
using .PermutedDimsArrays
include("mappedarray.jl")
using .MappedArrays

# nullable types
include("nullable.jl")
Expand Down
2 changes: 1 addition & 1 deletion test/choosetests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ function choosetests(choices = [])
"enums", "cmdlineargs", "i18n", "workspace", "libdl", "int",
"checked", "intset", "floatfuncs", "compile", "distributed", "inline",
"boundscheck", "error", "ambiguous", "cartesian", "asmvariant", "osutils",
"channels", "iostream", "specificity", "codegen", "codevalidation"
"channels", "iostream", "specificity", "codegen", "codevalidation", "mappedarray"
]
profile_skipped = false
if startswith(string(Sys.ARCH), "arm")
Expand Down
15 changes: 15 additions & 0 deletions test/mappedarray.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
@test @inferred(MappedArray(identity, [1,2,3]))::MappedArray == [1,2,3]
@test @inferred(MappedVector(identity, [1,2,3]))::MappedArray == [1,2,3]
@test @inferred(MappedArray(identity, [1 2; 3 4]))::MappedArray == [1 2; 3 4]
@test @inferred(MappedMatrix(identity, [1 2; 3 4]))::MappedArray == [1 2; 3 4]

a = [1,2,3]
b = @inferred(MappedArray(x -> x + 10, a))::MappedArray
@test b[2] === 12
@test_throws ErrorException b[2] = 20
c = @inferred(MappedArray(x -> x + 10, x -> x - 10, a))::MappedArray
@test c[2] === 12
@test (c[2] = 20; c[2] === 10)
@test (c[3] = 30.0; c[3] === 20)

@test Base.IndexStyle(b) === Base.IndexStyle(a)

0 comments on commit fcf5201

Please sign in to comment.