Skip to content

Commit

Permalink
fix a TArray bug
Browse files Browse the repository at this point in the history
  • Loading branch information
KDr2 committed Jan 11, 2022
1 parent dd211e9 commit af672b5
Showing 1 changed file with 47 additions and 45 deletions.
92 changes: 47 additions & 45 deletions src/tarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ TArray{T,N}(::UndefInitializer, d::Vararg{<:Integer,N}) where {T,N} = TArray{T,N
TArray{T,N}(dim::NTuple{N,Int}) where {T,N} = TArray(T, dim)
TArray(T::Type, dim) = TArray(Array{T}(undef, dim))

localize(x) = x
localize(x::AbstractArray) = TArray(x)
getdata(x::TArray) = x.data
tape_copy(x::TArray) = TArray(deepcopy(x.data))

Expand Down Expand Up @@ -166,70 +168,70 @@ end
# Other methods from stdlib

Base.view(x::TArray, inds...; kwargs...) =
Base.view(getdata(x), inds...; kwargs...) |> TArray
Base.:-(x::TArray) = (-getdata(x)) |> TArray
Base.transpose(x::TArray) = transpose(getdata(x)) |> TArray
Base.adjoint(x::TArray) = adjoint(getdata(x)) |> TArray
Base.repeat(x::TArray; kw...) = repeat(getdata(x); kw...) |> TArray
Base.view(getdata(x), inds...; kwargs...) |> localize
Base.:-(x::TArray) = (-getdata(x)) |> localize
Base.transpose(x::TArray) = transpose(getdata(x)) |> localize
Base.adjoint(x::TArray) = adjoint(getdata(x)) |> localize
Base.repeat(x::TArray; kw...) = repeat(getdata(x); kw...) |> localize

Base.hcat(xs::Union{TArray{T,1}, TArray{T,2}}...) where T =
hcat(getdata.(xs)...) |> TArray
hcat(getdata.(xs)...) |> localize
Base.vcat(xs::Union{TArray{T,1}, TArray{T,2}}...) where T =
vcat(getdata.(xs)...) |> TArray
vcat(getdata.(xs)...) |> localize
Base.cat(xs::Union{TArray{T,1}, TArray{T,2}}...; dims) where T =
cat(getdata.(xs)...; dims = dims) |> TArray
cat(getdata.(xs)...; dims = dims) |> localize


Base.reshape(x::TArray, dims::Union{Colon,Int}...) = reshape(getdata(x), dims) |> TArray
Base.reshape(x::TArray, dims::Union{Colon,Int}...) = reshape(getdata(x), dims) |> localize
Base.reshape(x::TArray, dims::Tuple{Vararg{Union{Int,Colon}}}) =
reshape(getdata(x), Base._reshape_uncolon(getdata(x), dims)) |> TArray
Base.reshape(x::TArray, dims::Tuple{Vararg{Int}}) = reshape(getdata(x), dims) |> TArray

Base.permutedims(x::TArray, perm) = permutedims(getdata(x), perm) |> TArray
Base.PermutedDimsArray(x::TArray, perm) = PermutedDimsArray(getdata(x), perm) |> TArray
Base.reverse(x::TArray; dims) = reverse(getdata(x), dims = dims) |> TArray

Base.sum(x::TArray; dims = :) = sum(getdata(x), dims = dims) |> TArray
Base.sum(f::Union{Function,Type},x::TArray) = sum(f.(getdata(x))) |> TArray
Base.prod(x::TArray; dims=:) = prod(getdata(x); dims=dims) |> TArray
Base.prod(f::Union{Function, Type}, x::TArray) = prod(f.(getdata(x))) |> TArray

Base.findfirst(x::TArray, args...) = findfirst(getdata(x), args...) |> TArray
Base.maximum(x::TArray; dims = :) = maximum(getdata(x), dims = dims) |> TArray
Base.minimum(x::TArray; dims = :) = minimum(getdata(x), dims = dims) |> TArray

Base.:/(x::TArray, y::TArray) = getdata(x) / getdata(y) |> TArray
Base.:/(x::AbstractArray, y::TArray) = x / getdata(y) |> TArray
Base.:/(x::TArray, y::AbstractArray) = getdata(x) / y |> TArray
Base.:\(x::TArray, y::TArray) = getdata(x) \ getdata(y) |> TArray
Base.:\(x::AbstractArray, y::TArray) = x \ getdata(y) |> TArray
Base.:\(x::TArray, y::AbstractArray) = getdata(x) \ y |> TArray
Base.:*(x::TArray, y::TArray) = getdata(x) * getdata(y) |> TArray
Base.:*(x::AbstractArray, y::TArray) = x * getdata(y) |> TArray
Base.:*(x::TArray, y::AbstractArray) = getdata(x) * y |> TArray
reshape(getdata(x), Base._reshape_uncolon(getdata(x), dims)) |> localize
Base.reshape(x::TArray, dims::Tuple{Vararg{Int}}) = reshape(getdata(x), dims) |> localize

Base.permutedims(x::TArray, perm) = permutedims(getdata(x), perm) |> localize
Base.PermutedDimsArray(x::TArray, perm) = PermutedDimsArray(getdata(x), perm) |> localize
Base.reverse(x::TArray; dims) = reverse(getdata(x), dims = dims) |> localize

Base.sum(x::TArray; dims = :) = sum(getdata(x), dims = dims) |> localize
Base.sum(f::Union{Function,Type},x::TArray) = sum(f.(getdata(x))) |> localize
Base.prod(x::TArray; dims=:) = prod(getdata(x); dims=dims) |> localize
Base.prod(f::Union{Function, Type}, x::TArray) = prod(f.(getdata(x))) |> localize

Base.findfirst(x::TArray, args...) = findfirst(getdata(x), args...) |> localize
Base.maximum(x::TArray; dims = :) = maximum(getdata(x), dims = dims) |> localize
Base.minimum(x::TArray; dims = :) = minimum(getdata(x), dims = dims) |> localize

Base.:/(x::TArray, y::TArray) = getdata(x) / getdata(y) |> localize
Base.:/(x::AbstractArray, y::TArray) = x / getdata(y) |> localize
Base.:/(x::TArray, y::AbstractArray) = getdata(x) / y |> localize
Base.:\(x::TArray, y::TArray) = getdata(x) \ getdata(y) |> localize
Base.:\(x::AbstractArray, y::TArray) = x \ getdata(y) |> localize
Base.:\(x::TArray, y::AbstractArray) = getdata(x) \ y |> localize
Base.:*(x::TArray, y::TArray) = getdata(x) * getdata(y) |> localize
Base.:*(x::AbstractArray, y::TArray) = x * getdata(y) |> localize
Base.:*(x::TArray, y::AbstractArray) = getdata(x) * y |> localize

# broadcast
Base.BroadcastStyle(::Type{<:TArray}) = Broadcast.ArrayStyle{TArray}()
Broadcast.broadcasted(::Broadcast.ArrayStyle{TArray}, f, args...) = f.(getdata.(args)...) |> TArray
Broadcast.broadcasted(::Broadcast.ArrayStyle{TArray}, f, args...) = f.(getdata.(args)...) |> localize

import LinearAlgebra
import LinearAlgebra: \, /, inv, det, logdet, logabsdet, norm

LinearAlgebra.inv(x::TArray) = inv(getdata(x)) |> TArray
LinearAlgebra.det(x::TArray) = det(getdata(x)) |> TArray
LinearAlgebra.logdet(x::TArray) = logdet(getdata(x)) |> TArray
LinearAlgebra.logabsdet(x::TArray) = logabsdet(getdata(x)) |> TArray
LinearAlgebra.inv(x::TArray) = inv(getdata(x)) |> localize
LinearAlgebra.det(x::TArray) = det(getdata(x)) |> localize
LinearAlgebra.logdet(x::TArray) = logdet(getdata(x)) |> localize
LinearAlgebra.logabsdet(x::TArray) = logabsdet(getdata(x)) |> localize
LinearAlgebra.norm(x::TArray, p::Real = 2) =
LinearAlgebra.norm(getdata(x), p) |> TArray
LinearAlgebra.norm(getdata(x), p) |> localize

import LinearAlgebra: dot
dot(x::TArray, ys::TArray) = dot(getdata(x), getdata(ys)) |> TArray
dot(x::AbstractArray, ys::TArray) = dot(x, getdata(ys)) |> TArray
dot(x::TArray, ys::AbstractArray) = dot(getdata(x), ys) |> TArray
dot(x::TArray, ys::TArray) = dot(getdata(x), getdata(ys)) |> localize
dot(x::AbstractArray, ys::TArray) = dot(x, getdata(ys)) |> localize
dot(x::TArray, ys::AbstractArray) = dot(getdata(x), ys) |> localize

using Statistics
Statistics.mean(x::TArray; dims = :) = mean(getdata(x), dims = dims) |> TArray
Statistics.std(x::TArray; kw...) = std(getdata(x), kw...) |> TArray
Statistics.mean(x::TArray; dims = :) = mean(getdata(x), dims = dims) |> localize
Statistics.std(x::TArray; kw...) = std(getdata(x), kw...) |> localize

# TODO
# * NNlib

0 comments on commit af672b5

Please sign in to comment.