Skip to content

Commit

Permalink
Add support for categorical types (#121)
Browse files Browse the repository at this point in the history
* add support for categorical types

* fix some clippy issues

* fix issue with guide
  • Loading branch information
mt-caret authored Feb 21, 2024
1 parent 7a929b6 commit 596d61c
Show file tree
Hide file tree
Showing 17 changed files with 257 additions and 72 deletions.
56 changes: 31 additions & 25 deletions guide/expressions.ml
Original file line number Diff line number Diff line change
Expand Up @@ -561,11 +561,11 @@ let%expect_test "Casting" =
let%expect_test "Aggregation" =
let schema =
Schema.create
[ "first_name", Utf8
; "gender", Utf8
; "type", Utf8
; "state", Utf8
; "party", Utf8
[ "first_name", Categorical None
; "gender", Categorical None
; "type", Categorical None
; "state", Categorical None
; "party", Categorical None
; "birthday", Date
]
in
Expand All @@ -582,7 +582,7 @@ let%expect_test "Aggregation" =
┌────────────┬────────────┬────────────┬────────┬───┬───────────┬───────────┬──────────┬───────────┐
last_namefirst_namemiddle_namsuffix ┆ … ┆ ballotpedwashingtoicpsr_idwikipedia
------e--- ┆ ┆ ia_idn_post_id---_id
strstr---str ┆ ┆ ------i64---
strcat---str ┆ ┆ ------i64---
│ ┆ ┆ str ┆ ┆ ┆ strstr ┆ ┆ str
╞════════════╪════════════╪════════════╪════════╪═══╪═══════════╪═══════════╪══════════╪═══════════╡
BassettRichardnullnull ┆ … ┆ nullnull507Richard
Expand Down Expand Up @@ -627,7 +627,7 @@ let%expect_test "Aggregation" =
┌────────────┬───────┬───────────────────┬───────────┐
first_namecountgenderlast_name
------------
stru32list[str] ┆ str
catu32list[cat] ┆ str
╞════════════╪═══════╪═══════════════════╪═══════════╡
John1256 ┆ ["M", "M", … "M"] ┆ Walker
William1022 ┆ ["M", "M", … "M"] ┆ Few
Expand Down Expand Up @@ -655,7 +655,7 @@ let%expect_test "Aggregation" =
┌───────┬──────┬─────┐
stateantipro
---------
stru32u32
catu32u32
╞═══════╪══════╪═════╡
NJ03
CT03
Expand Down Expand Up @@ -684,7 +684,7 @@ let%expect_test "Aggregation" =
┌───────┬─────────────────────┬───────┐
statepartycount
---------
strstru32
catcatu32
╞═══════╪═════════════════════╪═══════╡
NJPro-Administration3
VAAnti-Administration3
Expand Down Expand Up @@ -720,7 +720,7 @@ let%expect_test "Aggregation" =
┌───────┬────────────────┬────────────────┬────────┬──────────┐
stateavg M birthdayavg F birthday ┆ # male ┆ # female
---------------
strf64f64u32u32
catf64f64u32u32
╞═══════╪════════════════╪════════════════╪════════╪══════════╡
DE182.593407null970
VA192.54278166.24305
Expand Down Expand Up @@ -749,7 +749,7 @@ let%expect_test "Aggregation" =
┌───────┬──────────────────┬───────────────────────┐
stateyoungestoldest
---------
strstrstr
catstrstr
╞═══════╪══════════════════╪═══════════════════════╡
NCMadison CawthornJohn Ashe
IAAbby FinkenauerBernhart Henn
Expand Down Expand Up @@ -778,7 +778,7 @@ let%expect_test "Aggregation" =
┌───────┬──────────────────┬───────────────────────┬────────────────────┐
stateyoungestoldestalphabetical_first
------------
strstrstrstr
catstrstrstr
╞═══════╪══════════════════╪═══════════════════════╪════════════════════╡
NCMadison CawthornJohn AsheAbraham Rencher
IAAbby FinkenauerBernhart HennAbby Finkenauer
Expand All @@ -797,29 +797,35 @@ let%expect_test "Aggregation" =
; get_person |> last |> alias ~name:"oldest"
; get_person |> sort |> first |> alias ~name:"alphabetical_first"
; col "gender"
|> sort_by ~by:[ col "first_name" ]
|> sort_by
~by:
[ (* The guide uses "first_name" to sort by, but I'm guessing
there's an nondeterminism bug causing output to be unstable
if we have multiple sorts or something, so I suspect
[get_person] is what we actually want *)
get_person
]
|> first
|> alias ~name:"gender"
]
|> Lazy_frame.sort ~by_column:"state"
|> Lazy_frame.limit ~n:5
|> Lazy_frame.collect_exn
in
Data_frame.print df;
[%expect
{|
shape: (5, 5)
┌───────┬──────────────────┬────────────────┬────────────────────┬────────┐
stateyoungestoldestalphabetical_firstgender
---------------
strstrstrstrstr
╞═══════╪══════════════════╪════════════════╪════════════════════╪════════╡
AKMark Begich Thomas Cale Anthony Dimond M
ALMartha RobyJohn McKeeAlbert Goodwyn M
ARTim Griffin Archibald YellAlbert RustM
ASEni FaleomavaegaFofó Sunia Eni FaleomavaegaM
AZBen Quayle Coles BashfordAnn Kirkpatrick F
└───────┴──────────────────┴────────────────┴────────────────────┴────────┘ |}]
┌───────┬──────────────────┬───────────────────────┬────────────────────┬────────┐
stateyoungestoldest alphabetical_firstgender
--------- ------
catstrstr strcat
╞═══════╪══════════════════╪═══════════════════════╪════════════════════╪════════╡
NCMadison CawthornJohn Ashe Abraham RencherM
IAAbby FinkenauerBernhart Henn Abby Finkenauer F
MIPeter Meijer Edward BradleyAaron BlissM
CAKatie HillEdward Gilbert Aaron Sargent M
NYMondaire Jones Cornelius SchoonmakerA. Foster M
└───────┴──────────────────┴───────────────────────┴────────────────────┴────────┘ |}]
;;

(* Examples from https://pola-rs.github.io/polars-book/user-guide/expressions/null/ *)
Expand Down
66 changes: 32 additions & 34 deletions lib/data_type.ml
Original file line number Diff line number Diff line change
@@ -1,40 +1,38 @@
open! Core

module T = struct
type t =
| Boolean
| UInt8
| UInt16
| UInt32
| UInt64
| Int8
| Int16
| Int32
| Int64
| Float32
| Float64
| Utf8
| Binary
| Date
| Datetime of Time_unit.t * Tz.t option
| Duration of Time_unit.t
| Time
| List of t
(* We want this branch to be tested very well, since code dealing with
this recursive case is usually the most non-trivial portion of the
logic. *)
[@quickcheck.weight 10.]
| Null
| Struct of (string * t) list
| Unknown
[@@deriving compare, sexp, quickcheck]
end

include T
include Sexpable.To_stringable (T)
type t =
| Boolean
| UInt8
| UInt16
| UInt32
| UInt64
| Int8
| Int16
| Int32
| Int64
| Float32
| Float64
| Utf8
| Binary
| Date
| Datetime of Time_unit.t * Tz.t option
| Duration of Time_unit.t
| Time
| List of t
(* We want this branch to be tested very well, since code dealing with
this recursive case is usually the most non-trivial portion of the
logic. *)
[@quickcheck.weight 10.]
| Null
| Categorical of (Rev_mapping.t[@compare.ignore]) option [@quickcheck.do_not_generate]
| Struct of (string * t) list
| Unknown
[@@deriving compare, sexp_of, quickcheck]

let to_string t = Sexp.to_string ([%sexp_of: t] t)

module Typed = struct
type untyped = t [@@deriving compare, sexp, quickcheck]
type untyped = t [@@deriving compare, sexp_of, quickcheck]

(* TODO: Consider mapping to smaller OCaml values like Int8, Float32, etc instead of
casting up *)
Expand Down Expand Up @@ -157,7 +155,7 @@ module Typed = struct
| Duration time_unit -> Some (T (Duration time_unit))
| Time -> Some (T Time)
| List t -> of_untyped t |> Option.map ~f:(fun (T t) -> T (List t))
| Null | Struct _ | Unknown -> None
| Null | Categorical _ | Struct _ | Unknown -> None
;;

let rec sexp_of_packed (T t) =
Expand Down
5 changes: 3 additions & 2 deletions lib/data_type.mli
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ type t =
| Time
| List of t
| Null
| Categorical of Rev_mapping.t option
| Struct of (string * t) list
| Unknown
[@@deriving compare, sexp, quickcheck]
[@@deriving compare, sexp_of, quickcheck]

include Stringable.S with type t := t
val to_string : t -> string

module Typed : sig
type untyped
Expand Down
1 change: 1 addition & 0 deletions lib/polars.ml
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ module Naive_time = Naive_time
module Schema = Schema
module Series = Series
module Sql_context = Sql_context
module Rev_mapping = Rev_mapping
module Time_unit = Time_unit
module Tz = Tz
19 changes: 19 additions & 0 deletions lib/rev_mapping.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
open! Core

type t

let quickcheck_shrinker = Base_quickcheck.Shrinker.atomic

let quickcheck_observer =
Base_quickcheck.Observer.of_hash_fold (fun hash_state _t -> hash_state)
;;

external get_categories : t -> string list = "rust_rev_mapping_get_categories"

let sexp_of_t t =
get_categories t
|> (* It's difficult to get tests to be deterministic wrt the ordering of
categories, so we force them to be sorted here as a hacky workaround. *)
List.sort ~compare:String.compare
|> [%sexp_of: string list]
;;
7 changes: 7 additions & 0 deletions lib/rev_mapping.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
open! Core

type t [@@deriving sexp_of]

val quickcheck_shrinker : t Quickcheck.Shrinker.t
val quickcheck_observer : t Quickcheck.Observer.t
val get_categories : t -> string list
1 change: 0 additions & 1 deletion lib/schema.ml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,3 @@ external create : (string * Data_type.t) list -> t = "rust_schema_create"
external to_fields : t -> (string * Data_type.t) list = "rust_schema_to_fields"

let sexp_of_t t = to_fields t |> [%sexp_of: (string * Data_type.t) list]
let t_of_sexp sexp = [%of_sexp: (string * Data_type.t) list] sexp |> create
2 changes: 1 addition & 1 deletion lib/schema.mli
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
open! Core

type t [@@deriving sexp]
type t [@@deriving sexp_of]

val create : (string * Data_type.t) list -> t
val to_fields : t -> (string * Data_type.t) list
9 changes: 9 additions & 0 deletions lib/series.ml
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,15 @@ module T = struct
map input_data_type output_data_type t ~f |> Result.ok_exn
;;

external cast
: t
-> to_:Data_type.t
-> strict:bool
-> (t, string) result
= "rust_series_cast"

let cast ?(strict = true) t ~to_ = cast t ~to_ ~strict |> Utils.string_result_ok_exn

external name : t -> string = "rust_series_name"
external rename : t -> name:string -> unit = "rust_series_rename"
external dtype : t -> Data_type.t = "rust_series_dtype"
Expand Down
1 change: 1 addition & 0 deletions lib/series.mli
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ val map
-> f:('a option -> 'b option)
-> t

val cast : ?strict:bool -> t -> to_:Data_type.t -> t
val name : t -> string
val rename : t -> name:string -> unit
val dtype : t -> Data_type.t
Expand Down
1 change: 1 addition & 0 deletions rust/polars-ocaml/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ features = [
"dtype-i8",
"dtype-u16",
"dtype-u8",
"dtype-categorical",
"dynamic_groupby",
"horizontal_concat",
"interpolate",
Expand Down
1 change: 0 additions & 1 deletion rust/polars-ocaml/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use chrono::{Duration, NaiveDate, NaiveDateTime, NaiveTime};
use ocaml_interop::{
DynBox, OCaml, OCamlBytes, OCamlFloat, OCamlInt, OCamlList, OCamlRef, OCamlRuntime, ToOCaml,
};
use polars::lazy::dsl::GetOutput;
use polars::prelude::*;
use polars::series::IsSorted;
use polars_ocaml_macros::ocaml_interop_export;
Expand Down
11 changes: 11 additions & 0 deletions rust/polars-ocaml/src/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@ fn rust_schema_to_fields(
fields.to_ocaml(cr)
}

#[ocaml_interop_export]
fn rust_rev_mapping_get_categories(
cr: &mut &mut OCamlRuntime,
rev_mapping: OCamlRef<DynBox<Arc<RevMapping>>>,
) -> OCaml<OCamlList<String>> {
let Abstract(rev_mapping) = rev_mapping.to_rust(cr);
let rev_mapping: Vec<_> = rev_mapping.get_categories().values_iter().collect();

rev_mapping.to_ocaml(cr)
}

#[ocaml_interop_export]
fn rust_test_panic(cr: &mut &mut OCamlRuntime, error_message: OCamlRef<String>) -> OCaml<()> {
let error_message: String = error_message.to_rust(cr);
Expand Down
12 changes: 10 additions & 2 deletions rust/polars-ocaml/src/polars_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use ocaml_interop::{
polymorphic_variant_tag_hash, DynBox, FromOCaml, OCaml, OCamlInt, OCamlList, OCamlRuntime,
ToOCaml,
};
use polars::prelude::*;
use polars::series::IsSorted;
use polars::{lazy::dsl::WindowMapping, prelude::*};
use smartstring::{LazyCompact, SmartString};

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -71,6 +71,10 @@ unsafe impl FromOCaml<DataType> for PolarsDataType {
DataType::List(Box::new(datatype))
},
DataType::Null,
DataType::Categorical(local_rev_mapping_opt: Option<DynBox<Arc<RevMapping>>>) => {
let local_rev_mapping_opt: Option<Abstract<Arc<RevMapping>>> = local_rev_mapping_opt;
DataType::Categorical(local_rev_mapping_opt.map(Abstract::get))
},
DataType::Struct(fields: OCamlList<(String, DataType)>) => {
let fields_: Vec<(String, PolarsDataType)> = fields;
let fields: Vec<Field> =
Expand Down Expand Up @@ -134,12 +138,16 @@ unsafe impl ToOCaml<DataType> for PolarsDataType {
ocaml_alloc_tagged_block!(cr, 2, datatype: DataType)
}
DataType::Null => ocaml_value(cr, 15),
DataType::Categorical(local_rev_mapping_opt) => {
let local_rev_mapping_opt = local_rev_mapping_opt.clone().map(Abstract);
ocaml_alloc_tagged_block!(cr, 3, local_rev_mapping_opt: Option<DynBox<Arc<RevMapping>>>)
}
DataType::Struct(fields) => {
let fields: Vec<(String, PolarsDataType)> = fields
.iter()
.map(|field| (field.name.to_string(), PolarsDataType(field.dtype.clone())))
.collect();
ocaml_alloc_tagged_block!(cr, 3, fields: OCamlList<(String, DataType)>)
ocaml_alloc_tagged_block!(cr, 4, fields: OCamlList<(String, DataType)>)
}
DataType::Unknown => ocaml_value(cr, 16),
}
Expand Down
Loading

0 comments on commit 596d61c

Please sign in to comment.