Skip to content

Commit

Permalink
Review changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
rofinn committed Jul 9, 2019
1 parent 5f1f4d8 commit 6043da8
Show file tree
Hide file tree
Showing 9 changed files with 90 additions and 75 deletions.
40 changes: 18 additions & 22 deletions src/Impute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function __init__()

@warn(
"""
All matrix imputation methods will be switching to the JuliaStats column-major convention
All matrix imputation methods will be switching to the column-major convention
(e.g., each column corresponds to an observation, and each row corresponds to a variable).
"""
)
Expand All @@ -51,31 +51,27 @@ Base.showerror(io::IO, err::ImputeError) = println(io, "ImputeError: $(err.msg)"
include("context.jl")
include("imputors.jl")

const global imputation_methods = Dict{Symbol, Type}(
:drop => DropObs,
:dropobs => DropObs,
:dropvars => DropVars,
:interp => Interpolate,
:fill => Fill,
:locf => LOCF,
:nocb => NOCB,
const global imputation_methods = (
drop = DropObs,
dropobs = DropObs,
dropvars = DropVars,
interp = Interpolate,
fill = Fill,
locf = LOCF,
nocb = NOCB,
)

include("deprecated.jl")

let
for (k, v) in imputation_methods
local typename = nameof(v)
local f = k
local f! = Symbol(k, :!)

# NOTE: The
@eval begin
$f(data; kwargs...) = impute($typename(; _extract_context_kwargs(kwargs...)...), data)
$f!(data; kwargs...) = impute!($typename(; _extract_context_kwargs(kwargs...)...), data)
$f(; kwargs...) = data -> impute($typename(; _extract_context_kwargs(kwargs...)...), data)
$f!(; kwargs...) = data -> impute!($typename(; _extract_context_kwargs(kwargs...)...), data)
end
for (f, v) in pairs(imputation_methods)
typename = nameof(v)
f! = Symbol(f, :!)

@eval begin
$f(data; kwargs...) = impute($typename(; _extract_context_kwargs(kwargs...)...), data)
$f!(data; kwargs...) = impute!($typename(; _extract_context_kwargs(kwargs...)...), data)
$f(; kwargs...) = data -> impute($typename(; _extract_context_kwargs(kwargs...)...), data)
$f!(; kwargs...) = data -> impute!($typename(; _extract_context_kwargs(kwargs...)...), data)
end
end

Expand Down
71 changes: 39 additions & 32 deletions src/context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@ An imputation context records summary information about missing data for an impu
"""
abstract type AbstractContext end

# Generic `copy` shared by every context type: a new context of the same concrete type `T`
# is rebuilt by splatting the values returned by `fieldvalues(ctx)` into `T`'s constructor.
function Base.copy(ctx::T) where {T <: AbstractContext}
    return T(fieldvalues(ctx)...)
end

"""
ismissing(ctx::AbstractContext, x) -> Bool
Expand All @@ -18,23 +22,23 @@ exceeds our `ctx.limit` we throw an `ImputeError`
* `x`: the value to check (may be a single value, an abstract array or a row)
"""
function Base.ismissing(ctx::AbstractContext, x)
missing = if isa(x, NamedTuple)
any(entry -> ctx.is_missing(entry[2]), pairs(x))
was_missing = if isa(x, NamedTuple)
any(ctx.is_missing, Tuple(x))
elseif isa(x, AbstractArray)
any(ctx.is_missing, x)
else
ctx.is_missing(x)
end

missing_update!(ctx, missing)
missing_update!(ctx, was_missing)

return missing
return was_missing
end

"""
findfirst(ctx::AbstractContext, data::AbstractVector) -> Int
Returns the first not missing index in `data`.
Returns the first non-missing index in `data`.
# Arguments
* `ctx::AbstractContext`: the context to pass into `ismissing`
Expand All @@ -50,7 +54,7 @@ end
"""
findlast(ctx::AbstractContext, data::AbstractVector) -> Int
Returns the last not missing index in `data`.
Returns the last non-missing index in `data`.
# Arguments
* `ctx::AbstractContext`: the context to pass into `ismissing`
Expand All @@ -66,7 +70,7 @@ end
"""
findnext(ctx::AbstractContext, data::AbstractVector) -> Int
Returns the next not missing index in `data`.
Returns the next non-missing index in `data`.
# Arguments
* `ctx::AbstractContext`: the context to pass into `ismissing`
Expand All @@ -88,7 +92,7 @@ weighted.
# Fields
* `n::Int`: number of observations
* `count::Int`: number of missing values found
* `limit::Float64`: allowable limit for missing values to impute
* `limit::Float64`: maximum proportion of the total values that may be missing and imputed (should be between 0.0 and 1.0).
* `is_missing::Function`: returns a Bool if the value counts as missing
* `on_complete::Function`: a function to run when imputation is complete
"""
Expand All @@ -105,37 +109,35 @@ function Context(;
is_missing::Function=ismissing,
on_complete::Function=complete
)
Context(0, 0, limit, is_missing, on_complete)
return Context(0, 0, limit, is_missing, on_complete)
end

function (ctx::Context)(f::Function)
function Base.empty(ctx::Context)
_ctx = copy(ctx)
_ctx.num = 0
_ctx.count = 0

result = f(_ctx)
ctx.on_complete(_ctx)
return result
return _ctx
end

Base.copy(x::Context) = Context(x.num, x.count, x.limit, x.is_missing, x.on_complete)

function missing_update!(ctx::Context, miss)
function missing_update!(ctx::Context, was_missing)
ctx.num += 1

if miss
if was_missing
ctx.count += 1
end
end

function complete(ctx::Context)
function complete(ctx::Context, data)
missing_ratio = ctx.count / ctx.num

if missing_ratio > ctx.limit
throw(ImputeError(
"More than $(ctx.limit * 100)% of values were missing ($missing_ratio)."
))
end

return data
end


Expand All @@ -151,9 +153,9 @@ This context type can be useful if some missing observation are more important t
* `s::Float64`: sum of missing values weights
* `limit::Float64`: maximum proportion of the total weight that may belong to missing values (should be between 0.0 and 1.0)
* `is_missing::Function`: returns a Bool if the value counts as missing
* `on_complete::Function`: a function to run when imputation is complete
* `on_complete::Function`: a function to run when imputation is complete
* `wv::AbstractWeights`: a set of statistical weights to use when evaluating the importance
of each observation
of each observation. Will be accumulated during imputation.
"""
mutable struct WeightedContext <: AbstractContext
num::Int
Expand All @@ -170,37 +172,42 @@ function WeightedContext(
is_missing::Function=ismissing,
on_complete::Function=complete
)
WeightedContext(0, 0.0, limit, is_missing, on_complete, wv)
return WeightedContext(0, 0.0, limit, is_missing, on_complete, wv)
end

function (ctx::WeightedContext)(f::Function)
function Base.empty(ctx::WeightedContext)
_ctx = copy(ctx)
_ctx.num = 0
_ctx.s = 0.0

result = f(_ctx)
ctx.on_complete(_ctx)
return result
return _ctx
end

function Base.copy(x::WeightedContext)
WeightedContext(x.num, x.s, x.limit, x.is_missing, x.on_complete, x.wv)
end

function missing_update!(ctx::WeightedContext, miss)
function missing_update!(ctx::WeightedContext, was_missing)
ctx.num += 1

if miss
if was_missing
ctx.s += ctx.wv[ctx.num]
end
end

function complete(ctx::WeightedContext)
function complete(ctx::WeightedContext, data)
missing_ratio = ctx.s / sum(ctx.wv)

if missing_ratio > ctx.limit
throw(ImputeError(
"More than $(ctx.limit * 100)% of weighted values were missing ($missing_ratio)."
))
end

return data
end

# Make each concrete context type callable. `ctx(f)` (also usable as `ctx() do c ... end`)
# builds a fresh context with `empty(ctx)`, runs `f` on it, and then hands that fresh
# context together with the result of `f` to the `on_complete` callback stored on `ctx`.
# Whatever the callback returns is the value of the call.
for T in (Context, WeightedContext)
    @eval function (ctx::$T)(f::Function)
        fresh = empty(ctx)
        # Look the callback up before running `f`, matching the original evaluation order.
        callback = ctx.on_complete
        return callback(fresh, f(fresh))
    end
end
6 changes: 4 additions & 2 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Base.@deprecate(
impute!(imp::Imputor, context::AbstractContext, data; kwargs...),
impute!(typeof(imp)(; context=context), data; kwargs...)
)

#####################################################################
# Deprecate all impute calls where the first argument is an Imputor #
#####################################################################
Expand Down Expand Up @@ -88,6 +89,7 @@ function impute(data, args...; kwargs...)
""",
:impute
)
# Call `deepcopy` because we can trust that it's available for all types.
return impute!(deepcopy(data), args...; kwargs...)
end

Expand Down Expand Up @@ -148,8 +150,8 @@ function chain(data, args...; kwargs...)
""",
:chain
)
result = deepcopy(data)
return chain!(data, args...; kwargs...)
# Call `deepcopy` because we can trust that it's available for all types.
return chain!(deepcopy(data), args...; kwargs...)
end

#####################
Expand Down
7 changes: 4 additions & 3 deletions src/imputors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ abstract type Imputor end
"""
impute(imp::Imputor, data)
Copies the `data` before calling the corresponding `impute!(imp, ...)` call.
Returns a new copy of the `data` with the missing data imputed by the imputor `imp`.
"""
function impute(imp::Imputor, data)
impute!(imp, deepcopy(data))
# Call `deepcopy` because we can trust that it's available for all types.
return impute!(imp, deepcopy(data))
end

"""
Expand All @@ -33,7 +34,7 @@ if this is not the desired behaviour custom imputor methods should overload this
* `AbstractMatrix`: the input `data` with values imputed
"""
function impute!(imp::Imputor, data::AbstractMatrix)
for i in 1:size(data, 2)
for i in axes(data, 2)
impute!(imp, view(data, :, i))
end
return data
Expand Down
10 changes: 5 additions & 5 deletions src/imputors/drop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ struct DropObs <: Imputor
context::AbstractContext
end

"""DropObs(; context=Context()) -> DropObs"""
DropObs(; context=Context()) = DropObs(context)

"""
Expand Down Expand Up @@ -54,7 +53,9 @@ NOTES (or premature optimizations):
"""
function impute!(imp::DropObs, data::AbstractMatrix)
imp.context() do c
mask = map(i -> !ismissing(c, data[i, :]), 1:size(data, 1))
mask = map(axes(data, 1)) do i
!ismissing(c, view(data, i, :))
end
return data[mask, :]
end
end
Expand Down Expand Up @@ -101,7 +102,6 @@ struct DropVars <: Imputor
context::AbstractContext
end

"""DropVars(; context=Context()) -> DropVars"""
DropVars(; context=Context()) = DropVars(context)

"""
Expand All @@ -119,10 +119,10 @@ requires copying the matrix.
* `AbstractMatrix`: a new matrix with missing columns removed
"""
function impute!(imp::DropVars, data::AbstractMatrix)
mask = map(1:size(data, 2)) do i
mask = map(axes(data, 2)) do i
try
imp.context() do c
for j in 1:size(data, 1)
for j in axes(data, 1)
ismissing(c, data[j, i])
end
end
Expand Down
8 changes: 5 additions & 3 deletions src/imputors/fill.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ Fill(; value=mean, context=Context()) = Fill(value, context)
"""
impute!(imp::Fill, data::AbstractVector)
Computes the fill value if `imp.value` is a `Function` (i.e., `imp.value(drop(copy(data)))`)
and replaces all missing values in the `data` with that value.
Fill in missing values with a value determined by `imp.value`.
If `imp.value` is a function, then the fill value is calculated by invoking that function on
the collection of all nonmissing values.
"""
function impute!(imp::Fill, data::AbstractVector)
imp.context() do c
fill_val = if isa(imp.value, Function)
imp.value(Iterators.drop(copy(data); context=c))
# Call `deepcopy` because we can trust that it's available for all types.
imp.value(Iterators.drop(deepcopy(data); context=c))
else
imp.value
end
Expand Down
2 changes: 1 addition & 1 deletion src/imputors/interp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function impute!(imp::Interpolate, data::AbstractVector{<:Union{T, Missing}}) wh
imp.context() do c
i = findfirst(c, data) + 1

while i < length(data)
while i < lastindex(data)
if ismissing(c, data[i])
prev_idx = i - 1
next_idx = findnext(c, data, i + 1)
Expand Down
16 changes: 10 additions & 6 deletions src/imputors/locf.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
"""
LOCF <: Imputor
Last observation carried forward. Fill in missing values with the most recent
nonmissing value.
See also:
- [NOCB](@ref): Next Observation Carried Backward
"""
struct LOCF <: Imputor
context::AbstractContext
end
Expand All @@ -14,16 +23,11 @@ existing observation.
WARNING: missing elements at the head of the array may not be imputed if there is no
existing observation to carry forward. As a result, this method does not guarantee
that all missing values will be imputed.
# Usage
```
```
"""
function impute!(imp::LOCF, data::AbstractVector)
imp.context() do c
start_idx = findfirst(c, data) + 1
for i in start_idx:length(data)
for i in start_idx:lastindex(data)
if ismissing(c, data[i])
data[i] = data[i-1]
end
Expand Down
Loading

0 comments on commit 6043da8

Please sign in to comment.