Skip to content

Commit

Permalink
Introduce flintify helper for "optimal" dispatch on integer and rat…
Browse files Browse the repository at this point in the history
…ional inputs (#1867)
  • Loading branch information
fingolfin authored Sep 26, 2024
1 parent 1b6961c commit 7894c21
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 67 deletions.
63 changes: 63 additions & 0 deletions src/flint/FlintTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1362,6 +1362,8 @@ mutable struct ZZMPolyRingElem <: MPolyRingElem{ZZRingElem}
z.parent = ctx
return z
end

ZZMPolyRingElem(ctx::ZZMPolyRing, a::Integer) = ZZMPolyRingElem(ctx, flintify(a))
end

function _fmpz_mpoly_clear_fn(a::ZZMPolyRingElem)
Expand Down Expand Up @@ -6453,8 +6455,69 @@ end
#
################################################################################

"""
IntegerUnion = Union{Integer, ZZRingElem}
The `IntegerUnion` type union allows convenient and compact declaration
of methods that accept both Julia and Nemo integers.
"""
const IntegerUnion = Union{Integer, ZZRingElem}

"""
RationalUnion = Union{Integer, ZZRingElem, Rational, QQFieldElem}
The `RationalUnion` type union allows convenient and compact declaration
of methods that accept both Julia and Nemo integers or rationals.
"""
const RationalUnion = Union{Integer, ZZRingElem, Rational, QQFieldElem}

"""
flintify(x::RationalUnion)
Return either an `Int`, `ZZRingElem` or `QQFieldElem` equal to `x`.
This internal helper allow us to cleanly and compactly implement efficient
dispatch for arithmetics that involve native Nemo objects plus a Julia
integer.
Indeed, many internal arithmetics functions in FLINT have optimize variants
for the case when one operand is an `ZZRingElem` or an `Int` (sometimes also
`UInt` is supported). E.g. there are special methods for adding one of these
to a `ZZRingPolyElem`.
In order to handling adding an arbitrary Julia integer to a `ZZRingPolyElem`,
further dispatch is needed. The easiest is to provide a method
+(a::ZZRingPolyElem, b::Integer) = a + ZZ(b)
However this is inefficient when `b` is e.g. an `UInt16`. So we could
do this (at least on a 64 bit machine):
+(a::ZZRingPolyElem, b::Integer) = a + ZZ(b)
+(a::ZZRingPolyElem, b::{Int64,Int32,Int16,Int8,UInt32,UInt16,UInt8}) = a + Int(b)
Doing this repeatedly is cumbersome and error prone. This can be avoided by
using `flintify`, which allows us to write
+(a::ZZRingPolyElem, b::Integer) = a + flintify(b)
to get optimal dispatch.
This also works for Nemo types that also have special handlers for `UInt`,
as their method for `b::UInt` takes precedence over the fallback method.
"""
flintify(x::ZZRingElem) = x
flintify(x::QQFieldElem) = x
flintify(x::Int) = x
flintify(x::Integer) = ZZRingElem(x)::ZZRingElem
flintify(x::Rational) = QQFieldElem(x)::QQFieldElem
@static if Int === Int64
flintify(x::Union{Int64,Int32,Int16,Int8,UInt32,UInt16,UInt8}) = Int(x)
else
flintify(x::Union{Int32,Int16,Int8,UInt16,UInt8}) = Int(x)
end


const ZmodNFmpzPolyRing = Union{ZZModPolyRing, FpPolyRing}

const Zmodn_poly = Union{zzModPolyRingElem, fpPolyRingElem}
Expand Down
41 changes: 10 additions & 31 deletions src/flint/fmpq_mpoly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -369,35 +369,35 @@ for (jT, cN, cT) in ((QQFieldElem, :fmpq, Ref{QQFieldElem}), (ZZRingElem, :fmpz,
end
end

+(a::QQMPolyRingElem, b::Integer) = a + ZZRingElem(b)
+(a::QQMPolyRingElem, b::Integer) = a + flintify(b)

+(a::Integer, b::QQMPolyRingElem) = b + a

-(a::QQMPolyRingElem, b::Integer) = a - ZZRingElem(b)
-(a::QQMPolyRingElem, b::Integer) = a - flintify(b)

-(a::Integer, b::QQMPolyRingElem) = -(b - a)
-(a::Integer, b::QQMPolyRingElem) = neg!(b - a)

+(a::QQMPolyRingElem, b::Rational{<:Integer}) = a + QQFieldElem(b)

+(a::Rational{<:Integer}, b::QQMPolyRingElem) = b + a

-(a::QQMPolyRingElem, b::Rational{<:Integer}) = a - QQFieldElem(b)

-(a::Rational{<:Integer}, b::QQMPolyRingElem) = -(b - a)
-(a::Rational{<:Integer}, b::QQMPolyRingElem) = neg!(b - a)

*(a::QQMPolyRingElem, b::Integer) = a * ZZRingElem(b)
*(a::QQMPolyRingElem, b::Integer) = a * flintify(b)

*(a::Integer, b::QQMPolyRingElem) = b * a

*(a::QQMPolyRingElem, b::Rational{<:Integer}) = a * QQFieldElem(b)

*(a::Rational{<:Integer}, b::QQMPolyRingElem) = b * a

divexact(a::QQMPolyRingElem, b::Integer; check::Bool=true) = divexact(a, ZZRingElem(b); check=check)
divexact(a::QQMPolyRingElem, b::Integer; check::Bool=true) = divexact(a, flintify(b); check=check)

divexact(a::QQMPolyRingElem, b::Rational{<:Integer}; check::Bool=true) = divexact(a, QQFieldElem(b); check=check)

//(a::QQMPolyRingElem, b::Integer) = //(a, ZZRingElem(b))
//(a::QQMPolyRingElem, b::Integer) = //(a, flintify(b))

//(a::QQMPolyRingElem, b::Rational{<:Integer}) = //(a, QQFieldElem(b))

Expand Down Expand Up @@ -578,7 +578,7 @@ end

==(a::Int, b::QQMPolyRingElem) = b == a

==(a::QQMPolyRingElem, b::Integer) = a == ZZRingElem(b)
==(a::QQMPolyRingElem, b::Integer) = a == flintify(b)

==(a::Integer, b::QQMPolyRingElem) = b == a

Expand Down Expand Up @@ -1186,33 +1186,12 @@ function (R::QQMPolyRing)()
return z
end

function (R::QQMPolyRing)(b::QQFieldElem)
function (R::QQMPolyRing)(b::RationalUnion)
z = QQMPolyRingElem(R, b)
return z
end

function (R::QQMPolyRing)(b::ZZRingElem)
z = QQMPolyRingElem(R, b)
return z
end

function (R::QQMPolyRing)(b::Int)
z = QQMPolyRingElem(R, b)
return z
end

function (R::QQMPolyRing)(b::UInt)
z = QQMPolyRingElem(R, b)
return z
end

function (R::QQMPolyRing)(b::Integer)
return R(ZZRingElem(b))
end

function (R::QQMPolyRing)(b::Rational{<:Integer})
return R(QQFieldElem(b))
end
QQMPolyRingElem(ctx::QQMPolyRing, a::RationalUnion) = QQMPolyRingElem(ctx, flintify(a))

function (R::QQMPolyRing)(a::QQMPolyRingElem)
parent(a) != R && error("Unable to coerce polynomial")
Expand Down
17 changes: 1 addition & 16 deletions src/flint/fmpz_mpoly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1066,26 +1066,11 @@ function (R::ZZMPolyRing)()
return z
end

function (R::ZZMPolyRing)(b::ZZRingElem)
function (R::ZZMPolyRing)(b::IntegerUnion)
z = ZZMPolyRingElem(R, b)
return z
end

function (R::ZZMPolyRing)(b::Int)
z = ZZMPolyRingElem(R, b)
return z
end

function (R::ZZMPolyRing)(b::UInt)
z = ZZMPolyRingElem(R, b)
return z
end

function (R::ZZMPolyRing)(b::Integer)
return R(ZZRingElem(b))
end


function (R::ZZMPolyRing)(a::ZZMPolyRingElem)
parent(a) != R && error("Unable to coerce polynomial")
return a
Expand Down
28 changes: 8 additions & 20 deletions src/flint/fmpz_poly.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,17 +225,17 @@ end

*(x::ZZPolyRingElem, y::ZZRingElem) = y*x

+(x::Integer, y::ZZPolyRingElem) = y + ZZRingElem(x)
+(x::Integer, y::ZZPolyRingElem) = y + flintify(x)

-(x::Integer, y::ZZPolyRingElem) = ZZRingElem(x) - y
-(x::Integer, y::ZZPolyRingElem) = flintify(x) - y

*(x::Integer, y::ZZPolyRingElem) = ZZRingElem(x)*y
*(x::Integer, y::ZZPolyRingElem) = flintify(x)*y

+(x::ZZPolyRingElem, y::Integer) = x + ZZRingElem(y)
+(x::ZZPolyRingElem, y::Integer) = x + flintify(y)

-(x::ZZPolyRingElem, y::Integer) = x - ZZRingElem(y)
-(x::ZZPolyRingElem, y::Integer) = x - flintify(y)

*(x::ZZPolyRingElem, y::Integer) = ZZRingElem(y)*x
*(x::ZZPolyRingElem, y::Integer) = flintify(y)*x

###############################################################################
#
Expand Down Expand Up @@ -942,20 +942,8 @@ function (a::ZZPolyRing)()
return z
end

function (a::ZZPolyRing)(b::Int)
z = ZZPolyRingElem(b)
z.parent = a
return z
end

function (a::ZZPolyRing)(b::Integer)
z = ZZPolyRingElem(ZZRingElem(b))
z.parent = a
return z
end

function (a::ZZPolyRing)(b::ZZRingElem)
z = ZZPolyRingElem(b)
function (a::ZZPolyRing)(b::IntegerUnion)
z = ZZPolyRingElem(flintify(b))
z.parent = a
return z
end
Expand Down

0 comments on commit 7894c21

Please sign in to comment.