From 4ff467de060cf9ca2f6e4a4012625d8c42ffb029 Mon Sep 17 00:00:00 2001 From: Shane O'Brien Date: Fri, 6 Oct 2023 13:31:08 +0100 Subject: [PATCH] Add ordered set aggregation functions --- .../20231009_170616_shane.obrien_mode.md | 3 + src/Rel8.hs | 9 +- src/Rel8/Expr/Aggregate.hs | 147 +++++++++++++++++- src/Rel8/Expr/Num.hs | 2 +- src/Rel8/Query/Aggregate.hs | 15 -- src/Rel8/Schema/HTable.hs | 24 ++- src/Rel8/Table/Opaleye.hs | 17 +- src/Rel8/Type/Eq.hs | 5 +- src/Rel8/Type/Num.hs | 4 + src/Rel8/Type/Ord.hs | 5 + src/Rel8/Type/Sum.hs | 3 + 11 files changed, 207 insertions(+), 27 deletions(-) create mode 100644 changelog.d/20231009_170616_shane.obrien_mode.md diff --git a/changelog.d/20231009_170616_shane.obrien_mode.md b/changelog.d/20231009_170616_shane.obrien_mode.md new file mode 100644 index 00000000..158dea7b --- /dev/null +++ b/changelog.d/20231009_170616_shane.obrien_mode.md @@ -0,0 +1,3 @@ +### Added + +- Add support for ordered-set aggregation functions, including `mode`, `percentile`, `percentileContinuous`, `hypotheticalRank`, `hypotheticalDenseRank`, `hypotheticalPercentRank` and `hypotheticalCumeDist`. diff --git a/src/Rel8.hs b/src/Rel8.hs index 38cd1f3a..722ffa07 100644 --- a/src/Rel8.hs +++ b/src/Rel8.hs @@ -272,7 +272,6 @@ module Rel8 , groupBy, groupByOn , listAgg, listAggOn, listAggExpr, listAggExprOn , listCat, listCatOn, listCatExpr, listCatExprOn - , mode , nonEmptyAgg, nonEmptyAggOn, nonEmptyAggExpr, nonEmptyAggExprOn , nonEmptyCat, nonEmptyCatOn, nonEmptyCatExpr, nonEmptyCatExprOn , DBMax, max, maxOn @@ -286,6 +285,14 @@ module Rel8 , and, andOn , or, orOn + , mode, modeOn + , percentile, percentileOn + , percentileContinuous, percentileContinuousOn + , hypotheticalRank + , hypotheticalDenseRank + , hypotheticalPercentRank + , hypotheticalCumeDist + -- ** Ordering , orderBy , Order diff --git a/src/Rel8/Expr/Aggregate.hs b/src/Rel8/Expr/Aggregate.hs index a9aaeeaf..849cb12c 100644 --- a/src/Rel8/Expr/Aggregate.hs +++ b/src/Rel8/Expr/Aggregate.hs @@ -4,6 +4,7 @@ {-# language NamedFieldPuns #-} {-# language OverloadedStrings #-} {-# language ScopedTypeVariables #-} +{-# language TypeApplications #-} {-# language TypeFamilies #-} {-# options_ghc -fno-warn-redundant-constraints #-} @@ -17,6 +18,13 @@ module Rel8.Expr.Aggregate , sum, sumOn, sumWhere , avg, avgOn , stringAgg, stringAggOn + , mode, modeOn + , percentile, percentileOn + , percentileContinuous, percentileContinuousOn + , hypotheticalRank + , hypotheticalDenseRank + , hypotheticalPercentRank + , hypotheticalCumeDist , groupByExpr, groupByExprOn , distinctAggregate , filterWhereExplicit @@ -28,6 +36,7 @@ module Rel8.Expr.Aggregate where -- base +import Data.Functor.Contravariant ((>$<)) import Data.Int ( Int64 ) import Data.List.NonEmpty ( NonEmpty ) import Data.String (IsString) @@ -36,6 +45,7 @@ import Prelude hiding (and, max, min, null, or, show, sum) -- opaleye import qualified Opaleye.Aggregate as Opaleye import qualified Opaleye.Internal.Aggregate as Opaleye +import qualified Opaleye.Internal.HaskellDB.PrimQuery as Opaleye import qualified Opaleye.Internal.Operators as Opaleye -- profunctors @@ -59,17 +69,22 @@ import Rel8.Expr.Opaleye , fromPrimExpr , toColumn , toPrimExpr + , unsafeCastExpr ) +import Rel8.Expr.Order (asc) import Rel8.Expr.Read (sread) import Rel8.Expr.Show (show) import qualified Rel8.Expr.Text as Text +import Rel8.Order (Order (Order)) import Rel8.Schema.Null ( Sql, Unnullify ) +import Rel8.Table.Opaleye (fromOrder, unpackspec) +import Rel8.Table.Order (ascTable) import Rel8.Type ( DBType, typeInformation ) import Rel8.Type.Array (arrayTypeName, encodeArrayElement) import Rel8.Type.Eq ( DBEq ) import Rel8.Type.Information (TypeInformation) -import Rel8.Type.Num ( DBNum ) -import Rel8.Type.Ord ( DBMax, DBMin ) +import Rel8.Type.Num (DBFractional, DBNum) +import Rel8.Type.Ord (DBMax, DBMin, DBOrd) import Rel8.Type.String ( DBString ) import Rel8.Type.Sum ( DBSum ) @@ -239,6 +254,132 @@ stringAggOn :: (Sql IsString a, Sql DBString a) stringAggOn delimiter f = lmap f (stringAgg delimiter) +-- | Corresponds to @mode() WITHIN GROUP (ORDER BY _)@. +mode :: Sql DBOrd a => Aggregator1 (Expr a) (Expr a) +mode = + unsafeMakeAggregator + id + (fromPrimExpr . fromColumn) + Empty + (Opaleye.withinGroup ((\(Order o) -> o) ascTable) + (Opaleye.makeAggrExplicit (pure ()) (Opaleye.AggrOther "mode"))) + + +-- | Applies 'mode' to the column selected by the given function. +modeOn :: Sql DBOrd a => (i -> Expr a) -> Aggregator1 i (Expr a) +modeOn f = lmap f mode + + +-- | Corresponds to @percentile_disc(_) WITHIN GROUP (ORDER BY _)@. +percentile :: Sql DBOrd a => Expr Double -> Aggregator1 (Expr a) (Expr a) +percentile fraction = + unsafeMakeAggregator + (\a -> (fraction, a)) + (castExpr . fromPrimExpr . fromColumn) + Empty + (Opaleye.withinGroup ((\(Order o) -> o) (snd >$< ascTable)) + (Opaleye.makeAggrExplicit + (lmap fst unpackspec) + (Opaleye.AggrOther "percentile_disc"))) + + +-- | Applies 'percentile' to the column selected by the given function. +percentileOn :: + Sql DBOrd a => + Expr Double -> + (i -> Expr a) -> + Aggregator1 i (Expr a) +percentileOn fraction f = lmap f (percentile fraction) + + +-- | Corresponds to @percentile_cont(_) WITHIN GROUP (ORDER BY _)@. +percentileContinuous :: + Sql DBFractional a => + Expr Double -> + Aggregator1 (Expr a) (Expr a) +percentileContinuous fraction = + unsafeMakeAggregator + (\a -> (fraction, a)) + (castExpr . fromPrimExpr . fromColumn) + Empty + (Opaleye.withinGroup ((\(Order o) -> o) (unsafeCastExpr @Double . snd >$< asc)) + (Opaleye.makeAggrExplicit + (lmap fst unpackspec) + (Opaleye.AggrOther "percentile_disc"))) + + + +-- | Applies 'percentileContinuous' to the column selected by the given +-- function. +percentileContinuousOn :: + Sql DBFractional a => + Expr Double -> + (i -> Expr a) -> + Aggregator1 i (Expr a) +percentileContinuousOn fraction f = lmap f (percentileContinuous fraction) + + +-- | Corresponds to @rank(_) WITHIN GROUP (ORDER BY _)@. +hypotheticalRank :: + Order a -> + a -> + Aggregator' fold a (Expr Int64) +hypotheticalRank (Order order) args = + unsafeMakeAggregator + (\a -> (args, a)) + (castExpr . fromPrimExpr . fromColumn) + (Fallback 1) + (Opaleye.withinGroup (snd >$< order) + (Opaleye.makeAggrExplicit + (fromOrder (fst >$< order)) + (Opaleye.AggrOther "rank"))) + + +-- | Corresponds to @dense_rank(_) WITHIN GROUP (ORDER BY _)@. +hypotheticalDenseRank :: + Order a -> + a -> + Aggregator' fold a (Expr Int64) +hypotheticalDenseRank (Order order) args = + unsafeMakeAggregator + (const args) + (castExpr . fromPrimExpr . fromColumn) + (Fallback 1) + (Opaleye.withinGroup order + (Opaleye.makeAggrExplicit (fromOrder order) + (Opaleye.AggrOther "dense_rank"))) + + +-- | Corresponds to @percent_rank(_) WITHIN GROUP (ORDER BY _)@. +hypotheticalPercentRank :: + Order a -> + a -> + Aggregator' fold a (Expr Double) +hypotheticalPercentRank (Order order) args = + unsafeMakeAggregator + (const args) + (castExpr . fromPrimExpr . fromColumn) + (Fallback 0) + (Opaleye.withinGroup order + (Opaleye.makeAggrExplicit (fromOrder order) + (Opaleye.AggrOther "percent_rank"))) + + +-- | Corresponds to @cume_dist(_) WITHIN GROUP (ORDER BY _)@. +hypotheticalCumeDist :: + Order a -> + a -> + Aggregator' fold a (Expr Double) +hypotheticalCumeDist (Order order) args = + unsafeMakeAggregator + (const args) + (castExpr . fromPrimExpr . fromColumn) + (Fallback 1) + (Opaleye.withinGroup order + (Opaleye.makeAggrExplicit (fromOrder order) + (Opaleye.AggrOther "cume_dist"))) + + -- | Aggregate a value by grouping by it. groupByExpr :: Sql DBEq a => Aggregator1 (Expr a) (Expr a) groupByExpr = @@ -249,7 +390,7 @@ groupByExpr = Opaleye.groupBy --- | Applies 'groupByExprOn' to the column selected by the given function. +-- | Applies 'groupByExpr' to the column selected by the given function. groupByExprOn :: Sql DBEq a => (i -> Expr a) -> Aggregator1 i (Expr a) groupByExprOn f = lmap f groupByExpr diff --git a/src/Rel8/Expr/Num.hs b/src/Rel8/Expr/Num.hs index 32c06c66..f1f58767 100644 --- a/src/Rel8/Expr/Num.hs +++ b/src/Rel8/Expr/Num.hs @@ -33,7 +33,7 @@ fromIntegral :: (Sql DBIntegral a, Sql DBNum b, Homonullable a b) fromIntegral (Expr a) = castExpr (Expr a) --- | Cast 'DBNum' types to 'DBFractional' types. For example, his can be useful +-- | Cast 'DBNum' types to 'DBFractional' types. For example, this can be useful -- to convert @Expr Float@ to @Expr Double@. realToFrac :: (Sql DBNum a, Sql DBFractional b, Homonullable a b) => Expr a -> Expr b diff --git a/src/Rel8/Query/Aggregate.hs b/src/Rel8/Query/Aggregate.hs index 2e3a215d..1a5348a5 100644 --- a/src/Rel8/Query/Aggregate.hs +++ b/src/Rel8/Query/Aggregate.hs @@ -6,13 +6,11 @@ module Rel8.Query.Aggregate ( aggregate , aggregate1 , countRows - , mode ) where -- base import Control.Applicative (liftA2) -import Data.Functor.Contravariant ( (>$<) ) import Data.Int ( Int64 ) import Prelude @@ -24,15 +22,10 @@ import Rel8.Aggregate (Aggregator' (Aggregator), Aggregator) import Rel8.Aggregate.Fold (Fallback (Fallback)) import Rel8.Expr ( Expr ) import Rel8.Expr.Aggregate ( countStar ) -import Rel8.Expr.Order ( desc ) import Rel8.Query ( Query ) -import Rel8.Query.Limit ( limit ) import Rel8.Query.Maybe ( optional ) import Rel8.Query.Opaleye ( mapOpaleye ) -import Rel8.Query.Order ( orderBy ) import Rel8.Table (Table) -import Rel8.Table.Aggregate (groupBy) -import Rel8.Table.Eq (EqTable) import Rel8.Table.Maybe (fromMaybeTable) @@ -55,11 +48,3 @@ aggregate1 (Aggregator _ aggregator) = mapOpaleye (Opaleye.aggregate aggregator) -- will return @0@. countRows :: Query a -> Query (Expr Int64) countRows = aggregate countStar - - --- | Return the most common row in a query. -mode :: forall a. EqTable a => Query a -> Query a -mode rows = - limit 1 $ fmap snd $ - orderBy (fst >$< desc) $ do - aggregate1 (liftA2 (,) countStar groupBy) rows diff --git a/src/Rel8/Schema/HTable.hs b/src/Rel8/Schema/HTable.hs index 477fad14..2642f205 100644 --- a/src/Rel8/Schema/HTable.hs +++ b/src/Rel8/Schema/HTable.hs @@ -15,15 +15,17 @@ module Rel8.Schema.HTable ( HTable (HField, HConstrainTable) - , hfield, htabulate, htraverse, hdicts, hspecs - , hfoldMap, hmap, htabulateA, htabulateP, htraverseP, htraversePWithField + , hfield, htabulate, hdicts, hspecs + , hfoldMap, hmap, htabulateA, htabulateP + , htraverse, htraverse_, htraverseP, htraversePWithField ) where -- base +import Data.Functor (void) +import Data.Functor.Compose ( Compose( Compose ), getCompose ) import Data.Functor.Const ( Const( Const ), getConst ) import Data.Kind ( Constraint, Type ) -import Data.Functor.Compose ( Compose( Compose ), getCompose ) import Data.Proxy ( Proxy ) import GHC.Generics ( (:*:)( (:*:) ) @@ -46,7 +48,7 @@ import Rel8.Schema.HTable.Product ( HProduct( HProduct ) ) import qualified Rel8.Schema.Kind as K -- semigroupoids -import Data.Functor.Apply ( Apply, (<.>) ) +import Data.Functor.Apply (Apply, (<.>), liftF2) -- | A @HTable@ is a functor-indexed/higher-kinded data type that is -- representable ('htabulate'/'hfield'), constrainable ('hdicts'), and @@ -130,6 +132,20 @@ hmap :: HTable t hmap f a = htabulate $ \field -> f (hfield a field) +newtype Ap f a = Ap + { getAp :: f a + } + + +instance (Apply f, Semigroup a) => Semigroup (Ap f a) where + Ap a <> Ap b = Ap (liftF2 (<>) a b) + + +htraverse_ :: (HTable t, Apply f) + => (forall a. context a -> f b) -> t context -> f () +htraverse_ f a = getAp $ hfoldMap (Ap . void . f) a + + htabulateA :: (HTable t, Apply m) => (forall a. HField t a -> m (context a)) -> m (t context) htabulateA f = htraverse getCompose $ htabulate $ Compose . f diff --git a/src/Rel8/Table/Opaleye.hs b/src/Rel8/Table/Opaleye.hs index 8a168962..3f468b26 100644 --- a/src/Rel8/Table/Opaleye.hs +++ b/src/Rel8/Table/Opaleye.hs @@ -23,10 +23,12 @@ module Rel8.Table.Opaleye , valuesspec , view , castTable + , fromOrder ) where -- base +import Data.Foldable (traverse_) import Data.Functor.Const ( Const( Const ), getConst ) import Data.List.NonEmpty ( NonEmpty ) import Prelude @@ -36,6 +38,9 @@ import qualified Opaleye.Adaptors as Opaleye import qualified Opaleye.Field as Opaleye ( Field_ ) import qualified Opaleye.Internal.HaskellDB.PrimQuery as Opaleye import qualified Opaleye.Internal.Operators as Opaleye +import qualified Opaleye.Internal.Order as Opaleye +import qualified Opaleye.Internal.PackMap as Opaleye +import qualified Opaleye.Internal.Unpackspec as Opaleye import qualified Opaleye.Internal.Values as Opaleye import qualified Opaleye.Table as Opaleye @@ -48,8 +53,10 @@ import Rel8.Expr.Opaleye ( fromPrimExpr, toPrimExpr , scastExpr, traverseFieldP ) -import Rel8.Schema.HTable ( htabulateA, hfield, hspecs, htabulate, - htraverseP, htraversePWithField ) +import Rel8.Schema.HTable + ( htabulateA, hfield, hspecs, htabulate + , htraverseP, htraversePWithField + ) import Rel8.Schema.Name ( Name( Name ), Selects, ppColumn ) import Rel8.Schema.QualifiedName (QualifiedName (QualifiedName)) import Rel8.Schema.Spec ( Spec(..) ) @@ -153,3 +160,9 @@ castTable (toColumns -> as) = fromColumns $ htabulate \field -> case hfield hspecs field of Spec {info} -> case hfield as field of expr -> scastExpr info expr + + +fromOrder :: Opaleye.Order a -> Opaleye.Unpackspec a a +fromOrder (Opaleye.Order o) = + Opaleye.Unpackspec $ Opaleye.PackMap $ \f a -> + a <$ traverse_ (f . snd) (o a) diff --git a/src/Rel8/Type/Eq.hs b/src/Rel8/Type/Eq.hs index 02e11273..426f8c65 100644 --- a/src/Rel8/Type/Eq.hs +++ b/src/Rel8/Type/Eq.hs @@ -14,9 +14,10 @@ where import Data.Aeson ( Value ) -- base -import Data.List.NonEmpty ( NonEmpty ) +import Data.Fixed (Fixed) import Data.Int ( Int16, Int32, Int64 ) import Data.Kind ( Constraint, Type ) +import Data.List.NonEmpty ( NonEmpty ) import Prelude -- bytestring @@ -29,6 +30,7 @@ import Data.CaseInsensitive ( CI ) -- rel8 import Rel8.Schema.Null ( Sql ) import Rel8.Type ( DBType ) +import Rel8.Type.Decimal (PowerOf10) -- scientific import Data.Scientific ( Scientific ) @@ -58,6 +60,7 @@ instance DBEq Char instance DBEq Int16 instance DBEq Int32 instance DBEq Int64 +instance PowerOf10 n => DBEq (Fixed n) instance DBEq Float instance DBEq Double instance DBEq Scientific diff --git a/src/Rel8/Type/Num.hs b/src/Rel8/Type/Num.hs index da13ef09..98c04e89 100644 --- a/src/Rel8/Type/Num.hs +++ b/src/Rel8/Type/Num.hs @@ -12,12 +12,14 @@ module Rel8.Type.Num where -- base +import Data.Fixed (Fixed) import Data.Int ( Int16, Int32, Int64 ) import Data.Kind ( Constraint, Type ) import Prelude -- rel8 import Rel8.Type ( DBType ) +import Rel8.Type.Decimal (PowerOf10) import Rel8.Type.Ord ( DBOrd ) -- scientific @@ -31,6 +33,7 @@ class DBType a => DBNum a instance DBNum Int16 instance DBNum Int32 instance DBNum Int64 +instance PowerOf10 n => DBNum (Fixed n) instance DBNum Float instance DBNum Double instance DBNum Scientific @@ -49,6 +52,7 @@ instance DBIntegral Int64 -- | The class of database types that support the @/@ operator. type DBFractional :: Type -> Constraint class DBNum a => DBFractional a +instance PowerOf10 n => DBFractional (Fixed n) instance DBFractional Float instance DBFractional Double instance DBFractional Scientific diff --git a/src/Rel8/Type/Ord.hs b/src/Rel8/Type/Ord.hs index a89c40fa..5b67d8eb 100644 --- a/src/Rel8/Type/Ord.hs +++ b/src/Rel8/Type/Ord.hs @@ -12,6 +12,7 @@ module Rel8.Type.Ord where -- base +import Data.Fixed (Fixed) import Data.Int ( Int16, Int32, Int64 ) import Data.Kind ( Constraint, Type ) import Data.List.NonEmpty ( NonEmpty ) @@ -26,6 +27,7 @@ import Data.CaseInsensitive ( CI ) -- rel8 import Rel8.Schema.Null ( Sql ) +import Rel8.Type.Decimal (PowerOf10) import Rel8.Type.Eq ( DBEq ) -- scientific @@ -53,6 +55,7 @@ instance DBOrd Char instance DBOrd Int16 instance DBOrd Int32 instance DBOrd Int64 +instance PowerOf10 n => DBOrd (Fixed n) instance DBOrd Float instance DBOrd Double instance DBOrd Scientific @@ -79,6 +82,7 @@ instance DBMax Char instance DBMax Int16 instance DBMax Int32 instance DBMax Int64 +instance PowerOf10 n => DBMax (Fixed n) instance DBMax Float instance DBMax Double instance DBMax Scientific @@ -104,6 +108,7 @@ instance DBMin Char instance DBMin Int16 instance DBMin Int32 instance DBMin Int64 +instance PowerOf10 n => DBMin (Fixed n) instance DBMin Float instance DBMin Double instance DBMin Scientific diff --git a/src/Rel8/Type/Sum.hs b/src/Rel8/Type/Sum.hs index 3808b321..3daeb360 100644 --- a/src/Rel8/Type/Sum.hs +++ b/src/Rel8/Type/Sum.hs @@ -12,12 +12,14 @@ module Rel8.Type.Sum where -- base +import Data.Fixed (Fixed) import Data.Int ( Int16, Int32, Int64 ) import Data.Kind ( Constraint, Type ) import Prelude -- rel8 import Rel8.Type ( DBType ) +import Rel8.Type.Decimal (PowerOf10) -- scientific import Data.Scientific ( Scientific ) @@ -32,6 +34,7 @@ class DBType a => DBSum a instance DBSum Int16 instance DBSum Int32 instance DBSum Int64 +instance PowerOf10 n => DBSum (Fixed n) instance DBSum Float instance DBSum Double instance DBSum Scientific