Skip to content

Commit

Permalink
add more setters (#115)
Browse files Browse the repository at this point in the history
* @set SVector(tup) = ...

* using instead of import

* more general StaticArray signatures

* remove redundant definitions of set() for invertible functions

* splat(atan)

* reorder

* delete(first/last) on ranges

* generalize norm setter

* add hypot

* add Pair setter

* fix tests

* insert and delete on cartesianindex

* minor improvements

* add Dates.value setter
  • Loading branch information
aplavin authored Oct 1, 2023
1 parent 869fb75 commit d0418a3
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 19 deletions.
17 changes: 9 additions & 8 deletions ext/AccessorsStaticArraysExt.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
module AccessorsStaticArraysExt
isdefined(Base, :get_extension) ? (import StaticArrays) : (import ..StaticArrays)
isdefined(Base, :get_extension) ? (using StaticArrays) : (using ..StaticArrays)
using Accessors
import Accessors: setindex, delete, insert

@inline setindex(a::StaticArrays.StaticArray, args...) = Base.setindex(a, args...)
@inline delete(obj::StaticArrays.SVector, l::IndexLens) = StaticArrays.deleteat(obj, only(l.indices))
@inline insert(obj::StaticArrays.SVector, l::IndexLens, val) = StaticArrays.insert(obj, only(l.indices), val)
@inline setindex(a::StaticArray, args...) = Base.setindex(a, args...)
@inline delete(obj::StaticVector, l::IndexLens) = StaticArrays.deleteat(obj, only(l.indices))
@inline insert(obj::StaticVector, l::IndexLens, val) = StaticArrays.insert(obj, only(l.indices), val)

Accessors.set(obj::StaticArrays.SVector, ::Type{Tuple}, val::Tuple) = StaticArrays.SVector(val)
Accessors.set(obj::StaticVector, ::Type{Tuple}, val::Tuple) = constructorof(typeof(obj))(val...)
Accessors.set(obj::Tuple, ::Type{<:StaticVector}, val::StaticVector) = Tuple(val)

Accessors.getall(obj::StaticArrays.StaticArray, ::Elements) = Tuple(obj)
Accessors.setall(obj::StaticArrays.StaticArray, ::Elements, vs::AbstractArray) = constructorof(typeof(obj))(vs...) # just for disambiguation
Accessors.setall(obj::StaticArrays.StaticArray, ::Elements, vs) = constructorof(typeof(obj))(vs...)
Accessors.getall(obj::StaticArray, ::Elements) = Tuple(obj)
Accessors.setall(obj::StaticArray, ::Elements, vs::AbstractArray) = constructorof(typeof(obj))(vs...) # just for disambiguation
Accessors.setall(obj::StaticArray, ::Elements, vs) = constructorof(typeof(obj))(vs...)

end
30 changes: 23 additions & 7 deletions src/functionlenses.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
using LinearAlgebra: norm, normalize
using Dates

set(obj, ::typeof(last), val) = @set obj[lastindex(obj)] = val
# first and last on general indexable collections
set(obj, ::typeof(first), val) = @set obj[firstindex(obj)] = val
delete(obj, ::typeof(last)) = delete(obj, IndexLens((lastindex(obj),)))
set(obj, ::typeof(last), val) = @set obj[lastindex(obj)] = val
delete(obj, ::typeof(first)) = delete(obj, IndexLens((firstindex(obj),)))
insert(obj, ::typeof(last), val) = insert(obj, IndexLens((lastindex(obj) + 1,)), val)
delete(obj, ::typeof(last)) = delete(obj, IndexLens((lastindex(obj),)))
insert(obj, ::typeof(first), val) = insert(obj, IndexLens((firstindex(obj),)), val)
insert(obj, ::typeof(last), val) = insert(obj, IndexLens((lastindex(obj) + 1,)), val)

set(obj, o::Base.Fix2{typeof(first)}, val) = @set obj[firstindex(obj):(firstindex(obj) + o.x - 1)] = val
set(obj, o::Base.Fix2{typeof(last)}, val) = @set obj[(lastindex(obj) - o.x + 1):lastindex(obj)] = val
Expand All @@ -15,12 +16,17 @@ delete(obj, o::Base.Fix2{typeof(last)}) = @delete obj[(lastindex(obj) - o.x + 1)
insert(obj, o::Base.Fix2{typeof(first)}, val) = @insert obj[firstindex(obj):(firstindex(obj) + o.x - 1)] = val
insert(obj, o::Base.Fix2{typeof(last)}, val) = @insert obj[(lastindex(obj) + 1):(lastindex(obj) + o.x)] = val

# first and last on ranges
# they don't support delete() with arbitrary index, so special casing is needed
delete(obj::AbstractRange, ::typeof(first)) = obj[begin+1:end]
delete(obj::AbstractRange, ::typeof(last)) = obj[begin:end-1]
delete(obj::AbstractRange, o::Base.Fix2{typeof(first)}) = obj[begin+o.x:end]
delete(obj::AbstractRange, o::Base.Fix2{typeof(last)}) = obj[begin:end-o.x]


set(obj::Tuple, ::typeof(Base.front), val::Tuple) = (val..., last(obj))
set(obj::Tuple, ::typeof(Base.tail), val::Tuple) = (first(obj), val...)

set(obj, ::typeof(identity), val) = val
set(obj, ::typeof(inv), new_inv) = inv(new_inv)

function set(obj, ::typeof(only), val)
only(obj) # error check
set(obj, first, val)
Expand All @@ -42,6 +48,8 @@ function set(obj::NamedTuple, ::Type{NamedTuple{KS}}, val::NamedTuple) where {KS
setproperties(obj, NamedTuple{KS}(val))
end

set(obj, ::typeof(Base.splat(=>)), val::Pair) = @set Tuple(obj) = Tuple(val)

set(obj, ::typeof(getproperties), val::NamedTuple) = setproperties(obj, val)

################################################################################
Expand Down Expand Up @@ -125,17 +133,25 @@ set(x, f::Base.Fix2{typeof(rem)}, y) = set(x, @optic(last(divrem(_, f.x))), y)
set(x::AbstractString, f::Base.Fix1{typeof(parse), Type{T}}, y::T) where {T} = string(y)

set(arr, ::typeof(normalize), val) = norm(arr) * val
set(arr, ::typeof(norm), val) = val/norm(arr) * arr # should we check val is positive?
set(arr, ::typeof(norm), val) = map(Base.Fix2(*, val / norm(arr)), arr) # should we check val is positive?

set(f, ::typeof(inverse), invf) = setinverse(f, invf)

set(obj, ::typeof(Base.splat(atan)), val) = @set Tuple(obj) = norm(obj) .* sincos(val)
set(obj, ::typeof(Base.splat(hypot)), val) = @set norm(obj) = val

################################################################################
##### dates
################################################################################
set(x::DateTime, ::Type{Date}, y) = DateTime(y, Time(x))
set(x::DateTime, ::Type{Time}, y) = DateTime(Date(x), y)
set(x::T, ::Type{T}, y) where {T <: Union{Date, Time}} = y

# directly mirrors Dates.value implementation in stdlib
set(x::Date, ::typeof(Dates.value), y) = @set x.instant.periods.value = y
set(x::DateTime, ::typeof(Dates.value), y) = @set x.instant.periods.value = y
set(x::Time, ::typeof(Dates.value), y) = @set x.instant.value = y

set(x::Date, ::typeof(year), y) = Date(y, month(x), day(x))
set(x::Date, ::typeof(month), y) = Date(year(x), y, day(x))
set(x::Date, ::typeof(day), y) = Date(year(x), month(x), y)
Expand Down
3 changes: 3 additions & 0 deletions src/optics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,9 @@ end
@inline insert(obj::NamedTuple, l::IndexLens{Tuple{Symbol}}, val) = merge(obj, NamedTuple{l.indices}((val,)))
@inline insert(obj::NamedTuple, l::IndexLens{<:Tuple{Tuple{Vararg{Symbol}}}}, vals) = merge(obj, NamedTuple{only(l.indices)}(vals))

@inline delete(obj::CartesianIndex, l::IndexLens{Tuple{Int}}) = delete(obj, l Tuple)
@inline insert(obj::CartesianIndex, l::IndexLens{Tuple{Int}}, val) = insert(obj, l Tuple, val)

struct DynamicIndexLens{F}
f::F
end
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
IntervalSets = "8197267c-284f-5f27-9208-e0e47529a953"
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
PerformanceTestTools = "dc46b164-d16f-48ec-a853-60448fc869fe"
QuickTypes = "ae2dfa86-617c-530c-b392-ef20fdad97bb"
Expand Down
5 changes: 5 additions & 0 deletions test/test_delete.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ using StaticArrays
@test @inferred(delete( [1, 2, 3], @optic(first(_, 2)))) == [3]
@test @inferred(delete( [1, 2, 3], @optic(last(_, 2)))) == [1]

@test @inferred(delete(CartesianIndex(1, 2, 3), @optic(_[1]))) == CartesianIndex(2, 3)

@test @inferred(delete(1:4, last)) === 1:3
@test @inferred(delete(1:4, (@optic first(_, 2)))) === 3:4

l = @optic first(_, 2)
@test l((1,2,3)) == [1,2]
@test delete((1,2,3), l) === (3,)
Expand Down
16 changes: 16 additions & 0 deletions test/test_extensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ end
# requires ConstructionBase extension:
VERSION >= v"1.9-" && @test (@set v.x = 10) === @SVector [10.,2,3]

v = @MVector [1.,2,3]
@test (@set v[1] = 10)::MVector == @MVector [10.,2,3]

@testset "Multi-dynamic indexing" begin
two = 2
plusone(x) = x + 1
Expand All @@ -106,6 +109,19 @@ end
v = @set StaticArrays.normalize(@SVector [10, 0,0]) = @SVector[0,1,0]
@test v @SVector[0,10,0]
@test @set(StaticArrays.norm([1,0]) = 20) [20, 0]

cmp(a::NamedTuple, b::NamedTuple) = Set(keys(a)) == Set(keys(b)) && NamedTuple{keys(b)}(a) === b
cmp(a::T, b::T) where {T} = a == b

if VERSION >= v"1.9-"
# require ConstructionBase extension
test_getset_laws(Tuple, SVector(0, 1), ('x', 'y'), (1, 2); cmp=cmp)
test_getset_laws(Tuple, MVector(0, 1), ('x', 'y'), (1, 2); cmp=cmp)
test_getset_laws(NamedTuple{(:x, :y)}, SVector(0, 1), (x='x', y='y'), (x=1, y=2); cmp=cmp)
test_getset_laws(NamedTuple{(:x, :y)}, SVector(0, 1), (y='x', x='y'), (x=1, y=2); cmp=cmp)
end
test_getset_laws(SVector, (0, 1), SVector('x', 'y'), SVector(1, 2); cmp=cmp)
test_getset_laws(MVector, (0, 1), MVector('x', 'y'), MVector(1, 2); cmp=cmp)
end


Expand Down
22 changes: 18 additions & 4 deletions test/test_functionlenses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ module TestFunctionLenses
using Test
using Dates
using Unitful
using LinearAlgebra: norm
using InverseFunctions: inverse
using Accessors: test_getset_laws, test_modify_law
using Accessors
using StaticArrays: SVector


@testset "os" begin
Expand Down Expand Up @@ -86,11 +86,14 @@ end

cmp(a::NamedTuple, b::NamedTuple) = Set(keys(a)) == Set(keys(b)) && NamedTuple{keys(b)}(a) === b
cmp(a::T, b::T) where {T} = a == b

test_getset_laws(Base.splat(=>), (1, 'a'), 'b' => 2, 3 => 'c'; cmp=cmp)
test_getset_laws(Base.splat(Pair), (1, 'a'), 'b' => 2, 3 => 'c'; cmp=cmp)
test_getset_laws(Base.splat(=>), [1, 2], 3 => 2, 3 => 4; cmp=cmp)

test_getset_laws(Tuple, (1, 'a'), ('x', 'y'), (1, 2))
test_getset_laws(Tuple, (a=1, b='a'), ('x', 'y'), (1, 2))
test_getset_laws(Tuple, [0, 1], ('x', 'y'), (1, 2); cmp=cmp)
test_getset_laws(Tuple, SVector(0, 1), ('x', 'y'), (1, 2); cmp=cmp)
test_getset_laws(Tuple, CartesianIndex(1, 2), (3, 4), (5, 6))

test_getset_laws(NamedTuple{(:x, :y)}, (1, 'a'), (x='x', y='y'), (x=1, y=2); cmp=cmp)
Expand All @@ -101,8 +104,6 @@ end
test_getset_laws(NamedTuple{(:x, :y)}, (y=1, z=10, x='a'), (y='x', x='y'), (x=1, y=2); cmp=cmp)
test_getset_laws(NamedTuple{(:x, :y)}, [0, 1], (x='x', y='y'), (x=1, y=2); cmp=cmp)
test_getset_laws(NamedTuple{(:x, :y)}, [0, 1], (y='x', x='y'), (x=1, y=2); cmp=cmp)
test_getset_laws(NamedTuple{(:x, :y)}, SVector(0, 1), (x='x', y='y'), (x=1, y=2); cmp=cmp)
test_getset_laws(NamedTuple{(:x, :y)}, SVector(0, 1), (y='x', x='y'), (x=1, y=2); cmp=cmp)
test_getset_laws(NamedTuple{(:x, :y)}, CartesianIndex(1, 2), (x=3, y=4), (x=5, y=6); cmp=cmp)
test_getset_laws(NamedTuple{(:x, :y)}, CartesianIndex(1, 2), (y=3, x=4), (x=5, y=6); cmp=cmp)

Expand Down Expand Up @@ -198,6 +199,9 @@ end
test_getset_laws(mod2pi, 5.3, 1, 2; cmp=isapprox)
test_getset_laws(mod2pi, -5.3, 1, 2; cmp=isapprox)

test_getset_laws(Base.splat(atan), (3, 4), 1, 2)
test_getset_laws(Base.splat(atan), (a=3, b=4), 1, 2)

test_getset_laws(!, true, true, false)
@testset for o in [
# invertible lenses below: no need for extensive testing, simply forwarded to InverseFunctions
Expand Down Expand Up @@ -241,6 +245,12 @@ end
f = @set inverse(sin) = myasin
@test f(2) == sin(2)
@test inverse(f)(0.5) == asin(0.5) + 2π

@test set([3, 4], norm, 10) == [6, 8]
@test set((3, 4), norm, 10) === (6., 8.)
@test set((a=3, b=4), norm, 10) === (a=6., b=8.)
test_getset_laws(norm, (3, 4), 10, 12)
test_getset_laws(Base.splat(hypot), (3, 4), 10, 12)
end

@testset "dates" begin
Expand All @@ -264,6 +274,10 @@ end
test_getset_laws(yearmonthday, x, (rand(1:5000), rand(1:12), rand(1:28)), (rand(1:5000), rand(1:12), rand(1:28)))
end

@testset for x in [DateTime(2020, 1, 2, 3, 4, 5, 6), Date(2020, 1, 2), Time(1, 2, 3, 4, 5, 6)]
test_getset_laws(Dates.value, x, 123, 456)
end

l = @optic DateTime(_, dateformat"yyyy_mm_dd")
@test @inferred(set("2020_03_04", month l, 10)) == "2020_10_04"
test_getset_laws(month l, "2020_03_04", 10, 11)
Expand Down
1 change: 1 addition & 0 deletions test/test_insert.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ using Accessors: insert
@test insert(A, @optic(last(_, 2)), [3, 4]) == [1, 2, 3, 4]
@test A == [1, 2] # not changed
end
@test @inferred(insert(CartesianIndex(1, 2, 3), @optic(_[2]), 4)) == CartesianIndex(1, 4, 2, 3)
@test insert((1,2), last, 3) == (1, 2, 3)
@inferred(insert((1,2), last, 3))
@test @inferred(insert(SVector(1,2), @optic(_[1]), 3)) == SVector(3, 1, 2)
Expand Down

0 comments on commit d0418a3

Please sign in to comment.