diff --git a/src/primitives.jl b/src/primitives.jl index bdfb048..3917cb9 100644 --- a/src/primitives.jl +++ b/src/primitives.jl @@ -134,6 +134,30 @@ Returns the last entry of `X`. """ -> Base.endof(X::NullableArray) = endof(X.values) # -> Int +@doc """ +`==(A::NullableArray, B::NullableArray)` + +When none of the arrays contain missing values, returns `Nullable(true)` +if all elements of the two arrays are equal according to `==`, and +`Nullable(false)` otherwise. Returns `Nullable{Bool}()` if a missing +value is present. +""" -> +function Base.(:(==))(A::NullableArray, B::NullableArray) + if size(A) != size(B) + return false + end + # Short-circuit is only possible after finding a missing element + ret = true + for i in eachindex(A,B) + if A.isnull[i] || B.isnull[i] + return Nullable{Bool}() + elseif A.values[i] != B.values[i] + ret = false + end + end + return Nullable(ret) +end + @doc """ """ -> diff --git a/test/primitives.jl b/test/primitives.jl index f13d340..392c00e 100644 --- a/test/primitives.jl +++ b/test/primitives.jl @@ -149,6 +149,17 @@ module TestPrimitives @test endof(NullableArray(collect(1:10))) == 10 @test endof(NullableArray([1, 2, nothing, 4, nothing])) == 5 +# ----- test Base.== ------------------------------------------------------# + x = NullableArray(collect(1:3)) + @test get(x == NullableArray([1.0, 2.0, 3.0])) + @test get(x != NullableArray([1.1, 2.0, 3.0])) + y = NullableArray([1.0, 2.0, 3.0], [false, false, true]) + z = NullableArray([1.1, 2.0, 3.0], [false, true, false]) + @test isnull(x == y) + @test isnull(x != y) + @test isnull(x == z) + @test isnull(x != z) + # ----- test Base.find -------------------------------------------------------# z = NullableArray(rand(Bool, 10))