Skip to content

Commit

Permalink
Unify rounding code (#1540)
Browse files Browse the repository at this point in the history
Systematically provide methods for ceil, floor, round, trunc with
various source and/or target types we provide. Remove some (accidental?)
type piracy.

For `m` a `Matrix{BigFloat}` instead of `trunc(m)` one now needs to
write `trunc(ZZMatrix, m)` which avoids type piracy. On the upside,
this is now supported for any `Matrix{<:Real}`.

Also add methods `is_positive(::QQFieldElem)` and `is_negative(::QQFieldElem)`,
move some functions to more appropriate places, and "optimize" (?)
`sign(::Type{Int}, a::QQFieldElem)`.
  • Loading branch information
fingolfin authored Oct 8, 2023
1 parent dcad522 commit ffc0d7d
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 103 deletions.
89 changes: 20 additions & 69 deletions src/HeckeMiscInteger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,35 +215,8 @@ function (::ZZRing)(x::Rational{Int})
return ZZRingElem(numerator(x))
end

function ceil(::Type{ZZRingElem}, a::BigFloat)
return ZZRingElem(ceil(BigInt, a))
end

function ceil(::Type{Int}, a::QQFieldElem)
return Int(ceil(ZZRingElem, a))
end

function floor(::Type{ZZRingElem}, a::BigFloat)
return ZZRingElem(floor(BigInt, a))
end

function floor(::Type{Int}, a::QQFieldElem)
return Int(floor(ZZRingElem, a))
end

function round(::Type{ZZRingElem}, a::BigFloat)
return ZZRingElem(round(BigInt, a))
end

function round(::Type{Int}, a::BigFloat)
return Int(round(ZZRingElem, a))
end

/(a::BigFloat, b::ZZRingElem) = a / BigInt(b)

is_negative(n::ZZRingElem) = cmp(n, 0) < 0
is_positive(n::ZZRingElem) = cmp(n, 0) > 0


################################################################################
#
Expand Down Expand Up @@ -278,54 +251,32 @@ end
#
################################################################################

Base.floor(::Type{ZZRingElem}, x::Int) = ZZRingElem(x)

Base.ceil(::Type{ZZRingElem}, x::Int) = ZZRingElem(x)

Base.floor(::Type{ZZRingElem}, x::QQFieldElem) = fdiv(numerator(x), denominator(x))
export trunc, round, ceil, floor

Base.ceil(::Type{ZZRingElem}, x::QQFieldElem) = cdiv(numerator(x), denominator(x))
for sym in (:trunc, :round, :ceil, :floor)
@eval begin
# support `trunc(ZZRingElem, 1.23)` etc. for arbitrary reals
Base.$sym(::Type{ZZRingElem}, a::Real) = ZZRingElem(Base.$sym(BigInt, a))
Base.$sym(::Type{ZZRingElem}, a::Rational) = ZZRingElem(Base.$sym(BigInt, a))

Base.round(x::QQFieldElem, ::RoundingMode{:Up}) = ceil(x)
# for integers we don't need to round in between
Base.$sym(::Type{ZZRingElem}, a::Integer) = ZZRingElem(a)

Base.round(::Type{ZZRingElem}, x::QQFieldElem, ::RoundingMode{:Up}) = ceil(ZZRingElem, x)

Base.round(x::QQFieldElem, ::RoundingMode{:Down}) = floor(x)

Base.round(::Type{ZZRingElem}, x::QQFieldElem, ::RoundingMode{:Down}) = floor(ZZRingElem, x)

function Base.round(x::QQFieldElem, ::RoundingMode{:Nearest})
d = denominator(x)
n = numerator(x)
if d == 2
if mod(n, 4) == 1
if n > 0
return Base.div(n, d)
else
return Base.div(n, d) - 1
end
else
if n > 0
return Base.div(n, d) + 1
else
return Base.div(n, d)
# support `trunc(ZZRingElem, m)` etc. where m is a matrix of reals
function Base.$sym(::Type{ZZMatrix}, a::Matrix{<:Real})
s = Base.size(a)
m = zero_matrix(FlintZZ, s[1], s[2])
for i = 1:s[1], j = 1:s[2]
m[i, j] = Base.$sym(ZZRingElem, a[i, j])
end
return m
end
end

return floor(x + 1 // 2)
end

Base.round(x::QQFieldElem, ::RoundingMode{:NearestTiesAway}) = sign(x) * floor(abs(x) + 1 // 2)

Base.round(::Type{ZZRingElem}, x::QQFieldElem, ::RoundingMode{:NearestTiesAway}) = sign(x) == 1 ? floor(ZZRingElem, abs(x) + 1 // 2) : -floor(ZZRingElem, abs(x) + 1 // 2)

function Base.round(::Type{ZZRingElem}, a::QQFieldElem)
return round(ZZRingElem, a, RoundNearestTiesAway)
end

function Base.round(a::QQFieldElem)
return round(ZZRingElem, a)
# rounding QQFieldElem to integer via ZZRingElem
function Base.$sym(::Type{T}, a::QQFieldElem) where T <: Integer
return T(Base.$sym(ZZRingElem, a))
end
end
end

clog(a::Int, b::Int) = clog(ZZRingElem(a), b)
Expand Down
21 changes: 0 additions & 21 deletions src/HeckeMoreStuff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,22 +65,6 @@ function evaluate!(z::fqPolyRepFieldElem, f::ZZPolyRingElem, r::fqPolyRepFieldEl
return z
end

export trunc, round, ceil, floor

for (s, f) in ((:trunc, Base.trunc), (:round, Base.round), (:ceil, Base.ceil), (:floor, Base.floor))
@eval begin
function ($s)(a::Matrix{BigFloat})
s = Base.size(a)
m = zero_matrix(FlintZZ, s[1], s[2])
for i = 1:s[1]
for j = 1:s[2]
m[i, j] = FlintZZ(BigInt(($f)(a[i, j])))
end
end
return m
end
end
end

function norm(v::arb_mat)
return sqrt(sum([a^2 for a in v]))
Expand Down Expand Up @@ -158,11 +142,6 @@ order(::ZZRingElem) = FlintZZ

export rem!

function is_negative(x::QQFieldElem)
c = ccall((:fmpq_sgn, libflint), Cint, (Ref{QQFieldElem},), x)
return c < 0
end

function sub!(z::Vector{QQFieldElem}, x::Vector{QQFieldElem}, y::Vector{ZZRingElem})
for i in 1:length(z)
sub!(z[i], x[i], y[i])
Expand Down
64 changes: 59 additions & 5 deletions src/flint/fmpq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,13 @@ Return the sign of $a$ ($-1$, $0$ or $1$) as a fraction.
"""
sign(a::QQFieldElem) = QQFieldElem(sign(numerator(a)))

sign(::Type{Int}, a::QQFieldElem) = sign(Int, numerator(a))
sign(::Type{Int}, a::QQFieldElem) = Int(ccall((:fmpq_sgn, libflint), Cint, (Ref{QQFieldElem},), a))

Base.signbit(a::QQFieldElem) = signbit(sign(Int, a))

is_negative(n::QQFieldElem) = sign(Int, n) < 0
is_positive(n::QQFieldElem) = sign(Int, n) > 0

function abs(a::QQFieldElem)
z = QQFieldElem()
ccall((:fmpq_abs, libflint), Nothing, (Ref{QQFieldElem}, Ref{QQFieldElem}), z, a)
Expand Down Expand Up @@ -168,15 +171,62 @@ characteristic(::QQField) = 0
Return the greatest integer that is less than or equal to $a$. The result is
returned as a rational with denominator $1$.
"""
Base.floor(a::QQFieldElem) = QQFieldElem(fdiv(numerator(a), denominator(a)), 1)
Base.floor(a::QQFieldElem) = floor(QQFieldElem, a)
Base.floor(::Type{QQFieldElem}, a::QQFieldElem) = QQFieldElem(floor(ZZRingElem, a), 1)
Base.floor(::Type{ZZRingElem}, a::QQFieldElem) = fdiv(numerator(a), denominator(a))

@doc raw"""
ceil(a::QQFieldElem)
Return the least integer that is greater than or equal to $a$. The result is
returned as a rational with denominator $1$.
"""
Base.ceil(a::QQFieldElem) = QQFieldElem(cdiv(numerator(a), denominator(a)), 1)
Base.ceil(a::QQFieldElem) = ceil(QQFieldElem, a)
Base.ceil(::Type{QQFieldElem}, a::QQFieldElem) = QQFieldElem(ceil(ZZRingElem, a), 1)
Base.ceil(::Type{ZZRingElem}, a::QQFieldElem) = cdiv(numerator(a), denominator(a))

Base.trunc(a::QQFieldElem) = trunc(QQFieldElem, a)
Base.trunc(::Type{QQFieldElem}, a::QQFieldElem) = QQFieldElem(trunc(ZZRingElem, a), 1)
Base.trunc(::Type{ZZRingElem}, a::QQFieldElem) = is_positive(a) ? floor(ZZRingElem, a) : ceil(ZZRingElem, a)

Base.round(x::QQFieldElem, ::RoundingMode{:Up}) = ceil(x)
Base.round(::Type{T}, x::QQFieldElem, ::RoundingMode{:Up}) where T = ceil(T, x)

Base.round(x::QQFieldElem, ::RoundingMode{:Down}) = floor(x)
Base.round(::Type{T}, x::QQFieldElem, ::RoundingMode{:Down}) where T = floor(T, x)

Base.round(x::QQFieldElem, ::RoundingMode{:Nearest}) = round(QQFieldElem, x, RoundNearest)
function Base.round(::Type{T}, x::QQFieldElem, ::RoundingMode{:Nearest}) where T
d = denominator(x)
n = numerator(x)
if d == 2
if mod(n, 4) == 1
if n > 0
return Base.div(n, d)
else
return Base.div(n, d) - 1
end
else
if n > 0
return Base.div(n, d) + 1
else
return Base.div(n, d)
end
end
end

return floor(T, x + 1 // 2)
end

Base.round(x::QQFieldElem, ::RoundingMode{:NearestTiesAway}) = sign(x) * floor(abs(x) + 1 // 2)
function Base.round(::Type{T}, x::QQFieldElem, ::RoundingMode{:NearestTiesAway}) where T
tmp = floor(T, abs(x) + 1 // 2)
return is_positive(x) ? tmp : -tmp
end

Base.round(a::QQFieldElem) = round(QQFieldElem, a)
Base.round(::Type{T}, a::QQFieldElem) where T = round(T, a, RoundNearestTiesAway)


nbits(a::QQFieldElem) = nbits(numerator(a)) + nbits(denominator(a))

Expand Down Expand Up @@ -1137,17 +1187,21 @@ end

convert(::Type{Rational{BigInt}}, a::QQFieldElem) = Rational(a)

function Rational(z::QQFieldElem)
function Base.Rational{BigInt}(z::QQFieldElem)
r = Rational{BigInt}(0)
ccall((:fmpq_get_mpz_frac, libflint), Nothing,
(Ref{BigInt}, Ref{BigInt}, Ref{QQFieldElem}), r.num, r.den, z)
return r
end

function Rational(z::ZZRingElem)
Rational(z::QQFieldElem) = Rational{BigInt}(z)

function Base.Rational{BigInt}(z::ZZRingElem)
return Rational{BigInt}(BigInt(z))
end

Rational(z::ZZRingElem) = Rational{BigInt}(z)


###############################################################################
#
Expand Down
9 changes: 7 additions & 2 deletions src/flint/fmpz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ sign(::Type{Int}, a::ZZRingElem) = Int(ccall((:fmpz_sgn, libflint), Cint, (Ref{Z

Base.signbit(a::ZZRingElem) = signbit(sign(Int, a))

is_negative(n::ZZRingElem) = sign(Int, n) < 0
is_positive(n::ZZRingElem) = sign(Int, n) > 0

@doc raw"""
fits(::Type{Int}, a::ZZRingElem)
Expand Down Expand Up @@ -312,12 +315,14 @@ function abs(x::ZZRingElem)
end

floor(x::ZZRingElem) = x

ceil(x::ZZRingElem) = x
trunc(x::ZZRingElem) = x
round(x::ZZRingElem) = x

floor(::Type{ZZRingElem}, x::ZZRingElem) = x

ceil(::Type{ZZRingElem}, x::ZZRingElem) = x
trunc(::Type{ZZRingElem}, x::ZZRingElem) = x
round(::Type{ZZRingElem}, x::ZZRingElem) = x

###############################################################################
#
Expand Down
52 changes: 50 additions & 2 deletions test/flint/fmpq-test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ end

@test characteristic(R) == 0

@test nbits(QQFieldElem(12, 1)) == 5
@test nbits(QQFieldElem(1, 3)) == 3
end

@testset "QQFieldElem.rounding" begin
@test floor(QQFieldElem(2, 3)) == 0
@test floor(QQFieldElem(-1, 3)) == -1
@test floor(QQFieldElem(2, 1)) == 2
Expand All @@ -176,8 +181,51 @@ end
@test ceil(QQFieldElem(-1, 3)) == 0
@test ceil(QQFieldElem(2, 1)) == 2

@test nbits(QQFieldElem(12, 1)) == 5
@test nbits(QQFieldElem(1, 3)) == 3
@test trunc(QQFieldElem(2, 3)) == 0
@test trunc(QQFieldElem(-1, 3)) == 0
@test trunc(QQFieldElem(2, 1)) == 2

@testset "$func" for func in (trunc, round, ceil, floor)
for d in -15:15
val = d//3
valQ = QQFieldElem(val)
@test func(valQ) isa QQFieldElem
@test func(valQ) == func(val)

@test func(QQFieldElem, valQ) isa QQFieldElem
@test func(QQFieldElem, valQ) == func(QQFieldElem, val)

@test func(ZZRingElem, valQ) isa ZZRingElem
@test func(ZZRingElem, valQ) == func(ZZRingElem, val)

@test func(BigInt, valQ) isa BigInt
@test func(BigInt, valQ) == func(BigInt, val)

@test func(Int, valQ) isa Int
@test func(Int, valQ) == func(Int, val)
end
end

@testset "$mode" for mode in (RoundUp, RoundDown, RoundNearest, RoundNearestTiesAway)
for d in -5:5
val = d//3
valQ = QQFieldElem(val)
@test round(valQ, mode) isa QQFieldElem
@test round(valQ, mode) == round(val, mode)

@test round(QQFieldElem, valQ, mode) isa QQFieldElem
@test round(QQFieldElem, valQ, mode) == round(val, mode)

@test round(ZZRingElem, valQ, mode) isa ZZRingElem
@test round(ZZRingElem, valQ, mode) == round(val, mode)

@test round(BigInt, valQ, mode) isa BigInt
@test round(BigInt, valQ, mode) == round(val, mode)

@test round(Int, valQ, mode) isa Int
@test round(Int, valQ, mode) == round(val, mode)
end
end
end

@testset "QQFieldElem.unary_ops" begin
Expand Down
26 changes: 22 additions & 4 deletions test/flint/fmpz-test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,6 @@ end

@test denominator(ZZRingElem(12)) == ZZRingElem(1)

@test floor(ZZRingElem(12)) == ZZRingElem(12)

@test ceil(ZZRingElem(12)) == ZZRingElem(12)

@test iseven(ZZRingElem(12))
@test isodd(ZZRingElem(13))
b = big(2)
Expand All @@ -213,6 +209,28 @@ end
@test characteristic(ZZ) == 0
end

@testset "QQFieldElem.rounding" begin
@test floor(ZZRingElem(12)) == ZZRingElem(12)
@test ceil(ZZRingElem(12)) == ZZRingElem(12)
@test trunc(ZZRingElem(12)) == ZZRingElem(12)

@test floor(ZZRingElem, ZZRingElem(12)) == ZZRingElem(12)
@test ceil(ZZRingElem, ZZRingElem(12)) == ZZRingElem(12)
@test trunc(ZZRingElem, ZZRingElem(12)) == ZZRingElem(12)


@testset "$func" for func in (trunc, round, ceil, floor)
for val in -5:5
valZ = ZZRingElem(val)
@test func(valZ) isa ZZRingElem
@test func(valZ) == func(val)
@test func(ZZRingElem, valZ) isa ZZRingElem
@test func(ZZRingElem, valZ) == func(ZZRingElem, val)
end
end

end

@testset "ZZRingElem.binary_ops" begin
a = ZZRingElem(12)
b = ZZRingElem(26)
Expand Down
Loading

0 comments on commit ffc0d7d

Please sign in to comment.