From a765f8f37e2bc20305e584b121feba4cfe65d428 Mon Sep 17 00:00:00 2001 From: Alexander Plavin Date: Sun, 21 Nov 2021 18:51:30 +0300 Subject: [PATCH 1/2] fix weighted computations for non-real arrays --- src/weights.jl | 4 +--- test/weights.jl | 3 +++ 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/weights.jl b/src/weights.jl index 0d03317fc..1781e4990 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -380,9 +380,7 @@ Base.:(==)(x::AbstractWeights, y::AbstractWeights) = false Compute the weighted sum of an array `v` with weights `w`, optionally over the dimension `dim`. """ -wsum(v::AbstractVector, w::AbstractVector) = dot(v, w) -wsum(v::AbstractArray, w::AbstractVector) = dot(vec(v), w) -wsum(v::AbstractArray, w::AbstractVector, dims::Colon) = wsum(v, w) +wsum(v::AbstractArray, w::AbstractVector, dims::Colon=:) = w' * vec(v) ## wsum along dimension # diff --git a/test/weights.jl b/test/weights.jl index 61d68320a..7fed9ef97 100644 --- a/test/weights.jl +++ b/test/weights.jl @@ -239,6 +239,8 @@ a = reshape(1.0:27.0, 3, 3, 3) @testset "Sum $f" for f in weight_funcs @test sum([1.0, 2.0, 3.0], f([1.0, 0.5, 0.5])) ≈ 3.5 @test sum(1:3, f([1.0, 1.0, 0.5])) ≈ 4.5 + @test sum([1 + 2im, 2 + 3im], f([1.0, 0.5])) ≈ 2 + 3.5im + @test sum([[1, 2], [3, 4]], f([2, 3])) == [11, 16] for wt in ([1.0, 1.0, 1.0], [1.0, 0.2, 0.0], [0.2, 0.0, 1.0]) @test sum(a, f(wt), dims=1) ≈ sum(a.*reshape(wt, length(wt), 1, 1), dims=1) @@ -250,6 +252,7 @@ end @testset "Mean $f" for f in weight_funcs @test mean([1:3;], f([1.0, 1.0, 0.5])) ≈ 1.8 @test mean(1:3, f([1.0, 1.0, 0.5])) ≈ 1.8 + @test mean([1 + 2im, 4 + 5im], f([1.0, 0.5])) ≈ 2 + 3im for wt in ([1.0, 1.0, 1.0], [1.0, 0.2, 0.0], [0.2, 0.0, 1.0]) @test mean(a, f(wt), dims=1) ≈ sum(a.*reshape(wt, length(wt), 1, 1), dims=1)/sum(wt) From 5617097c38f4b6a1bbda81d66beb33778f16a353 Mon Sep 17 00:00:00 2001 From: Alexander Date: Sun, 21 Nov 2021 21:10:58 +0300 Subject: [PATCH 2/2] Update src/weights.jl Co-authored-by: Milan Bouchet-Valat --- src/weights.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/weights.jl b/src/weights.jl index 1781e4990..3ff2ce1a0 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -380,7 +380,7 @@ Base.:(==)(x::AbstractWeights, y::AbstractWeights) = false Compute the weighted sum of an array `v` with weights `w`, optionally over the dimension `dim`. """ -wsum(v::AbstractArray, w::AbstractVector, dims::Colon=:) = w' * vec(v) +wsum(v::AbstractArray, w::AbstractVector, dims::Colon=:) = transpose(w) * vec(v) ## wsum along dimension #