diff --git a/src/Impute.jl b/src/Impute.jl index e72a9fd..6c63747 100644 --- a/src/Impute.jl +++ b/src/Impute.jl @@ -28,7 +28,7 @@ function __init__() @warn( """ - All matrix imputation methods will be switching to the JuliaStats column-major convention + All matrix imputation methods will be switching to the column-major convention (e.g., each column corresponds to an observation, and each row corresponds to a variable). """ ) @@ -51,31 +51,27 @@ Base.showerror(io::IO, err::ImputeError) = println(io, "ImputeError: $(err.msg)" include("context.jl") include("imputors.jl") -const global imputation_methods = Dict{Symbol, Type}( - :drop => DropObs, - :dropobs => DropObs, - :dropvars => DropVars, - :interp => Interpolate, - :fill => Fill, - :locf => LOCF, - :nocb => NOCB, +const global imputation_methods = ( + drop = DropObs, + dropobs = DropObs, + dropvars = DropVars, + interp = Interpolate, + fill = Fill, + locf = LOCF, + nocb = NOCB, ) include("deprecated.jl") -let - for (k, v) in imputation_methods - local typename = nameof(v) - local f = k - local f! = Symbol(k, :!) - - # NOTE: The - @eval begin - $f(data; kwargs...) = impute($typename(; _extract_context_kwargs(kwargs...)...), data) - $f!(data; kwargs...) = impute!($typename(; _extract_context_kwargs(kwargs...)...), data) - $f(; kwargs...) = data -> impute($typename(; _extract_context_kwargs(kwargs...)...), data) - $f!(; kwargs...) = data -> impute!($typename(; _extract_context_kwargs(kwargs...)...), data) - end +for (f, v) in pairs(imputation_methods) + typename = nameof(v) + f! = Symbol(f, :!) + + @eval begin + $f(data; kwargs...) = impute($typename(; _extract_context_kwargs(kwargs...)...), data) + $f!(data; kwargs...) = impute!($typename(; _extract_context_kwargs(kwargs...)...), data) + $f(; kwargs...) = data -> impute($typename(; _extract_context_kwargs(kwargs...)...), data) + $f!(; kwargs...) = data -> impute!($typename(; _extract_context_kwargs(kwargs...)...), data) end end diff --git a/src/context.jl b/src/context.jl index 6645174..2e2df5a 100644 --- a/src/context.jl +++ b/src/context.jl @@ -5,6 +5,10 @@ An imputation context records summary information about missing data for an impu """ abstract type AbstractContext end +# We implement a version of copy for all contexts which reconstructs the context from the +# raw fields. +Base.copy(ctx::T) where {T <: AbstractContext} = T(fieldvalues(ctx)...) + """ ismissing(ctx::AbstractContext, x) -> Bool @@ -18,23 +22,23 @@ exceeds our `ctx.limit` we throw an `ImputeError` * `x`: the value to check (may be an single values, abstract array or row) """ function Base.ismissing(ctx::AbstractContext, x) - missing = if isa(x, NamedTuple) - any(entry -> ctx.is_missing(entry[2]), pairs(x)) + was_missing = if isa(x, NamedTuple) + any(ctx.is_missing, Tuple(x)) elseif isa(x, AbstractArray) any(ctx.is_missing, x) else ctx.is_missing(x) end - missing_update!(ctx, missing) + missing_update!(ctx, was_missing) - return missing + return was_missing end """ findfirst(ctx::AbstractContext, data::AbstractVector) -> Int -Returns the first not missing index in `data`. +Returns the first non-missing index in `data`. # Arguments * `ctx::AbstractContext`: the context to pass into `ismissing` @@ -50,7 +54,7 @@ end """ findlast(ctx::AbstractContext, data::AbstractVector) -> Int -Returns the last not missing index in `data`. +Returns the last non-missing index in `data`. # Arguments * `ctx::AbstractContext`: the context to pass into `ismissing` @@ -66,7 +70,7 @@ end """ findnext(ctx::AbstractContext, data::AbstractVector) -> Int -Returns the next not missing index in `data`. +Returns the next non-missing index in `data`. # Arguments * `ctx::AbstractContext`: the context to pass into `ismissing` @@ -88,7 +92,7 @@ weighted. # Fields * `n::Int`: number of observations * `count::Int`: number of missing values found -* `limit::Float64`: allowable limit for missing values to impute +* `limit::Float64`: allowable portion of total values allowed to be imputed (should be between 0.0 and 1.0). * `is_missing::Function`: returns a Bool if the value counts as missing * `on_complete::Function`: a function to run when imputation is complete """ @@ -105,30 +109,26 @@ function Context(; is_missing::Function=ismissing, on_complete::Function=complete ) - Context(0, 0, limit, is_missing, on_complete) + return Context(0, 0, limit, is_missing, on_complete) end -function (ctx::Context)(f::Function) +function Base.empty(ctx::Context) _ctx = copy(ctx) _ctx.num = 0 _ctx.count = 0 - result = f(_ctx) - ctx.on_complete(_ctx) - return result + return _ctx end -Base.copy(x::Context) = Context(x.num, x.count, x.limit, x.is_missing, x.on_complete) - -function missing_update!(ctx::Context, miss) +function missing_update!(ctx::Context, was_missing) ctx.num += 1 - if miss + if was_missing ctx.count += 1 end end -function complete(ctx::Context) +function complete(ctx::Context, data) missing_ratio = ctx.count / ctx.num if missing_ratio > ctx.limit @@ -136,6 +136,8 @@ function complete(ctx::Context) "More than $(ctx.limit * 100)% of values were missing ($missing_ratio)." )) end + + return data end @@ -149,11 +151,11 @@ This context type can be useful if some missing observation are more important t # Fields * `num::Int`: number of observations * `s::Float64`: sum of missing values weights -* `limit::Float64`: allowable limit for missing values to impute +* `limit::Float64`: allowable portion of total values allowed to be imputed (should be between 0.0 and 1.0). * `is_missing::Function`: returns a Bool if the value counts as missing -* `on_complete::Function`: a function to run when imputation is complete +* `on_complete::Function`: allowable portion of total values allowed to be imputed (should be between 0.0 and 1.0). * `wv::AbstractWeights`: a set of statistical weights to use when evaluating the importance - of each observation + of each observation. Will be accumulated during imputation. """ mutable struct WeightedContext <: AbstractContext num::Int @@ -170,32 +172,26 @@ function WeightedContext( is_missing::Function=ismissing, on_complete::Function=complete ) - WeightedContext(0, 0.0, limit, is_missing, on_complete, wv) + return WeightedContext(0, 0.0, limit, is_missing, on_complete, wv) end -function (ctx::WeightedContext)(f::Function) +function Base.empty(ctx::WeightedContext) _ctx = copy(ctx) _ctx.num = 0 _ctx.s = 0.0 - result = f(_ctx) - ctx.on_complete(_ctx) - return result -end - -function Base.copy(x::WeightedContext) - WeightedContext(x.num, x.s, x.limit, x.is_missing, x.on_complete, x.wv) + return _ctx end -function missing_update!(ctx::WeightedContext, miss) +function missing_update!(ctx::WeightedContext, was_missing) ctx.num += 1 - if miss + if was_missing ctx.s += ctx.wv[ctx.num] end end -function complete(ctx::WeightedContext) +function complete(ctx::WeightedContext, data) missing_ratio = ctx.s / sum(ctx.wv) if missing_ratio > ctx.limit @@ -203,4 +199,15 @@ function complete(ctx::WeightedContext) "More than $(ctx.limit * 100)% of weighted values were missing ($missing_ratio)." )) end + + return data +end + +for T in [Context, WeightedContext] + @eval begin + function (ctx::$T)(f::Function) + _ctx = empty(ctx) + return ctx.on_complete(_ctx, f(_ctx)) + end + end end diff --git a/src/deprecated.jl b/src/deprecated.jl index a5174df..6f62a14 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -10,6 +10,7 @@ Base.@deprecate( impute!(imp::Imputor, context::AbstractContext, data; kwargs...), impute!(typeof(imp)(; context=context), data; kwargs...) ) + ##################################################################### # Deprecate all impute calls where the first argument is an Imputor # ##################################################################### @@ -88,6 +89,7 @@ function impute(data, args...; kwargs...) """, :impute ) + # Call `deepcopy` because we can trust that it's available for all types. return impute!(deepcopy(data), args...; kwargs...) end @@ -148,8 +150,8 @@ function chain(data, args...; kwargs...) """, :chain ) - result = deepcopy(data) - return chain!(data, args...; kwargs...) + # Call `deepcopy` because we can trust that it's available for all types. + return chain!(deepcopy(data), args...; kwargs...) end ##################### diff --git a/src/imputors.jl b/src/imputors.jl index 769926e..1ed2783 100644 --- a/src/imputors.jl +++ b/src/imputors.jl @@ -13,10 +13,11 @@ abstract type Imputor end """ impute(imp::Imputor, data) -Copies the `data` before calling the corresponding `impute!(imp, ...)` call. +Returns a new copy of the `data` with the missing data imputed by the imputor `imp`. """ function impute(imp::Imputor, data) - impute!(imp, deepcopy(data)) + # Call `deepcopy` because we can trust that it's available for all types. + return impute!(imp, deepcopy(data)) end """ @@ -33,7 +34,7 @@ if this is not the desired behaviour custom imputor methods should overload this * `AbstractMatrix`: the input `data` with values imputed """ function impute!(imp::Imputor, data::AbstractMatrix) - for i in 1:size(data, 2) + for i in axes(data, 2) impute!(imp, view(data, :, i)) end return data diff --git a/src/imputors/drop.jl b/src/imputors/drop.jl index 8f02b88..e15f88c 100644 --- a/src/imputors/drop.jl +++ b/src/imputors/drop.jl @@ -11,7 +11,6 @@ struct DropObs <: Imputor context::AbstractContext end -"""DropObs(; context=Context()) -> DropObs""" DropObs(; context=Context()) = DropObs(context) """ @@ -54,7 +53,9 @@ NOTES (or premature optimizations): """ function impute!(imp::DropObs, data::AbstractMatrix) imp.context() do c - mask = map(i -> !ismissing(c, data[i, :]), 1:size(data, 1)) + mask = map(axes(data, 1)) do i + !ismissing(c, view(data, i, :)) + end return data[mask, :] end end @@ -101,7 +102,6 @@ struct DropVars <: Imputor context::AbstractContext end -"""DropVars(; context=Context()) -> DropVars""" DropVars(; context=Context()) = DropVars(context) """ @@ -119,10 +119,10 @@ requires copying the matrix. * `AbstractMatrix`: a new matrix with missing columns removed """ function impute!(imp::DropVars, data::AbstractMatrix) - mask = map(1:size(data, 2)) do i + mask = map(axes(data, 2)) do i try imp.context() do c - for j in 1:size(data, 1) + for j in axes(data, 1) ismissing(c, data[j, i]) end end diff --git a/src/imputors/fill.jl b/src/imputors/fill.jl index fa68d31..ee94955 100644 --- a/src/imputors/fill.jl +++ b/src/imputors/fill.jl @@ -20,13 +20,15 @@ Fill(; value=mean, context=Context()) = Fill(value, context) """ impute!(imp::Fill, data::AbstractVector) -Computes the fill value if `imp.value` is a `Function` (i.e., `imp.value(drop(copy(data)))`) -and replaces all missing values in the `data` with that value. +Fill in missing values with a values determined by `imp.value`. +If `imp.value` is a function then the fill values calculated by invoking that function on +the collection of all nonmissing values. """ function impute!(imp::Fill, data::AbstractVector) imp.context() do c fill_val = if isa(imp.value, Function) - imp.value(Iterators.drop(copy(data); context=c)) + # Call `deepcopy` because we can trust that it's available for all types. + imp.value(Iterators.drop(deepcopy(data); context=c)) else imp.value end diff --git a/src/imputors/interp.jl b/src/imputors/interp.jl index a6c178a..21f81f8 100644 --- a/src/imputors/interp.jl +++ b/src/imputors/interp.jl @@ -23,7 +23,7 @@ function impute!(imp::Interpolate, data::AbstractVector{<:Union{T, Missing}}) wh imp.context() do c i = findfirst(c, data) + 1 - while i < length(data) + while i < lastindex(data) if ismissing(c, data[i]) prev_idx = i - 1 next_idx = findnext(c, data, i + 1) diff --git a/src/imputors/locf.jl b/src/imputors/locf.jl index 10ccb8b..56a7438 100644 --- a/src/imputors/locf.jl +++ b/src/imputors/locf.jl @@ -1,3 +1,12 @@ +""" + LOCF <: Imputor + +Last observation carried forward. Fill in missing values with the most recent +nonmissing value. + +See also: +- [NOCB](@ref): Next Observation Carried Backward +""" struct LOCF <: Imputor context::AbstractContext end @@ -14,16 +23,11 @@ existing observation. WARNING: missing elements at the head of the array may not be imputed if there is no existing observation to carry forward. As a result, this method does not guarantee that all missing values will be imputed. - -# Usage -``` - -``` """ function impute!(imp::LOCF, data::AbstractVector) imp.context() do c start_idx = findfirst(c, data) + 1 - for i in start_idx:length(data) + for i in start_idx:lastindex(data) if ismissing(c, data[i]) data[i] = data[i-1] end diff --git a/src/imputors/nocb.jl b/src/imputors/nocb.jl index 32690a1..9475ac0 100644 --- a/src/imputors/nocb.jl +++ b/src/imputors/nocb.jl @@ -2,6 +2,9 @@ NOCB <: Imputor Fills in missing data using the Next Observation Carried Backward (NOCB) approach. + +See also: +- [LOCF](@ref): Last Observation Carried Forward """ struct NOCB <: Imputor context::AbstractContext @@ -28,7 +31,7 @@ that all missing values will be imputed. function impute!(imp::NOCB, data::AbstractVector) imp.context() do c end_idx = findlast(c, data) - 1 - for i in end_idx:-1:1 + for i in end_idx:-1:firstindex(data) if ismissing(c, data[i]) data[i] = data[i+1] end