diff --git a/stdlib/LinearAlgebra/src/generic.jl b/stdlib/LinearAlgebra/src/generic.jl index a93325ed5b883..df4112134bc81 100644 --- a/stdlib/LinearAlgebra/src/generic.jl +++ b/stdlib/LinearAlgebra/src/generic.jl @@ -2014,19 +2014,24 @@ function copytrito!(B::AbstractMatrix, A::AbstractMatrix, uplo::AbstractChar) BLAS.chkuplo(uplo) m,n = size(A) m1,n1 = size(B) - (m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least the same number of rows and columns than A of size ($m,$n)")) A = Base.unalias(B, A) if uplo == 'U' - for j=1:n - for i=1:min(j,m) - @inbounds B[i,j] = A[i,j] - end + if n < m + (m1 < n || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($n,$n)")) + else + (m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$n)")) end - else # uplo == 'L' - for j=1:n - for i=j:m - @inbounds B[i,j] = A[i,j] - end + for j in 1:n, i in 1:min(j,m) + @inbounds B[i,j] = A[i,j] + end + else # uplo == 'L' + if m < n + (m1 < m || n1 < m) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$m)")) + else + (m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$n)")) + end + for j in 1:n, i in j:m + @inbounds B[i,j] = A[i,j] end end return B diff --git a/stdlib/LinearAlgebra/src/lapack.jl b/stdlib/LinearAlgebra/src/lapack.jl index b7c624dbb494d..c9f018e21175a 100644 --- a/stdlib/LinearAlgebra/src/lapack.jl +++ b/stdlib/LinearAlgebra/src/lapack.jl @@ -7157,9 +7157,23 @@ for (fn, elty) in ((:dlacpy_, :Float64), function lacpy!(B::AbstractMatrix{$elty}, A::AbstractMatrix{$elty}, uplo::AbstractChar) require_one_based_indexing(A, B) chkstride1(A, B) - m,n = size(A) - m1,n1 = size(B) - (m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least the same number of rows and columns than A of size ($m,$n)")) + m, n = size(A) + m1, n1 = size(B) + if uplo == 'U' + if n < m + (m1 < n || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($n,$n)")) + else + (m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$n)")) + end + elseif uplo == 'L' + if m < n + (m1 < m || n1 < m) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$m)")) + else + (m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$n)")) + end + else + (m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$n)")) + end lda = max(1, stride(A, 2)) ldb = max(1, stride(B, 2)) ccall((@blasfunc($fn), libblastrampoline), Cvoid, diff --git a/stdlib/LinearAlgebra/test/generic.jl b/stdlib/LinearAlgebra/test/generic.jl index f4eff3b2e355f..e0a1704913f78 100644 --- a/stdlib/LinearAlgebra/test/generic.jl +++ b/stdlib/LinearAlgebra/test/generic.jl @@ -654,12 +654,54 @@ end @testset "copytrito!" begin n = 10 - for A in (rand(n, n), rand(Int8, n, n)), uplo in ('L', 'U') - for AA in (A, view(A, reverse.(axes(A))...)) - for B in (zeros(n, n), zeros(n+1, n+2)) - copytrito!(B, AA, uplo) + @testset "square" begin + for A in (rand(n, n), rand(Int8, n, n)), uplo in ('L', 'U') + for AA in (A, view(A, reverse.(axes(A))...)) C = uplo == 'L' ? tril(AA) : triu(AA) - @test view(B, 1:n, 1:n) == C + for B in (zeros(n, n), zeros(n+1, n+2)) + copytrito!(B, AA, uplo) + @test view(B, 1:n, 1:n) == C + end + end + end + end + @testset "wide" begin + for A in (rand(n, 2n), rand(Int8, n, 2n)) + for AA in (A, view(A, reverse.(axes(A))...)) + C = tril(AA) + for (M, N) in ((n, n), (n+1, n), (n, n+1), (n+1, n+1)) + B = zeros(M, N) + copytrito!(B, AA, 'L') + @test view(B, 1:n, 1:n) == view(C, 1:n, 1:n) + end + @test_throws DimensionMismatch copytrito!(zeros(n-1, 2n), AA, 'L') + C = triu(AA) + for (M, N) in ((n, 2n), (n+1, 2n), (n, 2n+1), (n+1, 2n+1)) + B = zeros(M, N) + copytrito!(B, AA, 'U') + @test view(B, 1:n, 1:2n) == view(C, 1:n, 1:2n) + end + @test_throws DimensionMismatch copytrito!(zeros(n+1, 2n-1), AA, 'U') + end + end + end + @testset "tall" begin + for A in (rand(2n, n), rand(Int8, 2n, n)) + for AA in (A, view(A, reverse.(axes(A))...)) + C = triu(AA) + for (M, N) in ((n, n), (n+1, n), (n, n+1), (n+1, n+1)) + B = zeros(M, N) + copytrito!(B, AA, 'U') + @test view(B, 1:n, 1:n) == view(C, 1:n, 1:n) + end + @test_throws DimensionMismatch copytrito!(zeros(n-1, n+1), AA, 'U') + C = tril(AA) + for (M, N) in ((2n, n), (2n, n+1), (2n+1, n), (2n+1, n+1)) + B = zeros(M, N) + copytrito!(B, AA, 'L') + @test view(B, 1:2n, 1:n) == view(C, 1:2n, 1:n) + end + @test_throws DimensionMismatch copytrito!(zeros(n-1, n+1), AA, 'L') end end end diff --git a/stdlib/LinearAlgebra/test/lapack.jl b/stdlib/LinearAlgebra/test/lapack.jl index fe0dd92d0a69a..fd14dad4634a8 100644 --- a/stdlib/LinearAlgebra/test/lapack.jl +++ b/stdlib/LinearAlgebra/test/lapack.jl @@ -805,8 +805,26 @@ end B = zeros(elty, n, n) LinearAlgebra.LAPACK.lacpy!(B, A, uplo) C = uplo == 'L' ? tril(A) : (uplo == 'U' ? triu(A) : A) - @test B ≈ C + @test B == C + B = zeros(elty, n+1, n+1) + LinearAlgebra.LAPACK.lacpy!(B, A, uplo) + C = uplo == 'L' ? tril(A) : (uplo == 'U' ? triu(A) : A) + @test view(B, 1:n, 1:n) == C end + A = rand(elty, n, n+1) + B = zeros(elty, n, n) + LinearAlgebra.LAPACK.lacpy!(B, A, 'L') + @test B == view(tril(A), 1:n, 1:n) + B = zeros(elty, n, n+1) + LinearAlgebra.LAPACK.lacpy!(B, A, 'U') + @test B == triu(A) + A = rand(elty, n+1, n) + B = zeros(elty, n, n) + LinearAlgebra.LAPACK.lacpy!(B, A, 'U') + @test B == view(triu(A), 1:n, 1:n) + B = zeros(elty, n+1, n) + LinearAlgebra.LAPACK.lacpy!(B, A, 'L') + @test B == tril(A) end end