Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for categorical types #121

Merged
merged 3 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 31 additions & 25 deletions guide/expressions.ml
Original file line number Diff line number Diff line change
Expand Up @@ -561,11 +561,11 @@ let%expect_test "Casting" =
let%expect_test "Aggregation" =
let schema =
Schema.create
[ "first_name", Utf8
; "gender", Utf8
; "type", Utf8
; "state", Utf8
; "party", Utf8
[ "first_name", Categorical None
; "gender", Categorical None
; "type", Categorical None
; "state", Categorical None
; "party", Categorical None
; "birthday", Date
]
in
Expand All @@ -582,7 +582,7 @@ let%expect_test "Aggregation" =
┌────────────┬────────────┬────────────┬────────┬───┬───────────┬───────────┬──────────┬───────────┐
│ last_name ┆ first_name ┆ middle_nam ┆ suffix ┆ … ┆ ballotped ┆ washingto ┆ icpsr_id ┆ wikipedia │
│ --- ┆ --- ┆ e ┆ --- ┆ ┆ ia_id ┆ n_post_id ┆ --- ┆ _id │
│ str ┆ str ┆ --- ┆ str ┆ ┆ --- ┆ --- ┆ i64 ┆ --- │
│ str ┆ cat ┆ --- ┆ str ┆ ┆ --- ┆ --- ┆ i64 ┆ --- │
│ ┆ ┆ str ┆ ┆ ┆ str ┆ str ┆ ┆ str │
╞════════════╪════════════╪════════════╪════════╪═══╪═══════════╪═══════════╪══════════╪═══════════╡
│ Bassett ┆ Richard ┆ null ┆ null ┆ … ┆ null ┆ null ┆ 507 ┆ Richard │
Expand Down Expand Up @@ -627,7 +627,7 @@ let%expect_test "Aggregation" =
┌────────────┬───────┬───────────────────┬───────────┐
│ first_name ┆ count ┆ gender ┆ last_name │
│ --- ┆ --- ┆ --- ┆ --- │
str ┆ u32 ┆ list[str] ┆ str │
cat ┆ u32 ┆ list[cat] ┆ str │
╞════════════╪═══════╪═══════════════════╪═══════════╡
│ John ┆ 1256 ┆ ["M", "M", … "M"] ┆ Walker │
│ William ┆ 1022 ┆ ["M", "M", … "M"] ┆ Few │
Expand Down Expand Up @@ -655,7 +655,7 @@ let%expect_test "Aggregation" =
┌───────┬──────┬─────┐
│ state ┆ anti ┆ pro │
│ --- ┆ --- ┆ --- │
str ┆ u32 ┆ u32 │
cat ┆ u32 ┆ u32 │
╞═══════╪══════╪═════╡
│ NJ ┆ 0 ┆ 3 │
│ CT ┆ 0 ┆ 3 │
Expand Down Expand Up @@ -684,7 +684,7 @@ let%expect_test "Aggregation" =
┌───────┬─────────────────────┬───────┐
│ state ┆ party ┆ count │
│ --- ┆ --- ┆ --- │
strstr ┆ u32 │
catcat ┆ u32 │
╞═══════╪═════════════════════╪═══════╡
│ NJ ┆ Pro-Administration ┆ 3 │
│ VA ┆ Anti-Administration ┆ 3 │
Expand Down Expand Up @@ -720,7 +720,7 @@ let%expect_test "Aggregation" =
┌───────┬────────────────┬────────────────┬────────┬──────────┐
│ state ┆ avg M birthday ┆ avg F birthday ┆ # male ┆ # female │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
str ┆ f64 ┆ f64 ┆ u32 ┆ u32 │
cat ┆ f64 ┆ f64 ┆ u32 ┆ u32 │
╞═══════╪════════════════╪════════════════╪════════╪══════════╡
│ DE ┆ 182.593407 ┆ null ┆ 97 ┆ 0 │
│ VA ┆ 192.542781 ┆ 66.2 ┆ 430 ┆ 5 │
Expand Down Expand Up @@ -749,7 +749,7 @@ let%expect_test "Aggregation" =
┌───────┬──────────────────┬───────────────────────┐
│ state ┆ youngest ┆ oldest │
│ --- ┆ --- ┆ --- │
str ┆ str ┆ str │
cat ┆ str ┆ str │
╞═══════╪══════════════════╪═══════════════════════╡
│ NC ┆ Madison Cawthorn ┆ John Ashe │
│ IA ┆ Abby Finkenauer ┆ Bernhart Henn │
Expand Down Expand Up @@ -778,7 +778,7 @@ let%expect_test "Aggregation" =
┌───────┬──────────────────┬───────────────────────┬────────────────────┐
│ state ┆ youngest ┆ oldest ┆ alphabetical_first │
│ --- ┆ --- ┆ --- ┆ --- │
str ┆ str ┆ str ┆ str │
cat ┆ str ┆ str ┆ str │
╞═══════╪══════════════════╪═══════════════════════╪════════════════════╡
│ NC ┆ Madison Cawthorn ┆ John Ashe ┆ Abraham Rencher │
│ IA ┆ Abby Finkenauer ┆ Bernhart Henn ┆ Abby Finkenauer │
Expand All @@ -797,29 +797,35 @@ let%expect_test "Aggregation" =
; get_person |> last |> alias ~name:"oldest"
; get_person |> sort |> first |> alias ~name:"alphabetical_first"
; col "gender"
|> sort_by ~by:[ col "first_name" ]
|> sort_by
~by:
[ (* The guide uses "first_name" to sort by, but I'm guessing
there's an nondeterminism bug causing output to be unstable
if we have multiple sorts or something, so I suspect
[get_person] is what we actually want *)
get_person
]
|> first
|> alias ~name:"gender"
]
|> Lazy_frame.sort ~by_column:"state"
|> Lazy_frame.limit ~n:5
|> Lazy_frame.collect_exn
in
Data_frame.print df;
[%expect
{|
shape: (5, 5)
┌───────┬──────────────────┬────────────────┬────────────────────┬────────┐
│ state ┆ youngest ┆ oldest ┆ alphabetical_first ┆ gender │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
str ┆ str ┆ str ┆ str ┆ str
╞═══════╪══════════════════╪════════════════╪════════════════════╪════════╡
AKMark Begich ┆ Thomas Cale ┆ Anthony Dimond ┆ M │
ALMartha Roby ┆ John McKeeAlbert Goodwyn ┆ M
ARTim Griffin ┆ Archibald Yell ┆ Albert Rust ┆ M │
ASEni Faleomavaega ┆ Fofó Sunia ┆ Eni Faleomavaega ┆ M │
AZBen Quayle ┆ Coles Bashford ┆ Ann Kirkpatrick ┆ F
└───────┴──────────────────┴────────────────┴────────────────────┴────────┘ |}]
┌───────┬──────────────────┬───────────────────────┬────────────────────┬────────┐
│ state ┆ youngest ┆ oldest ┆ alphabetical_first ┆ gender │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- │
cat ┆ str ┆ str ┆ str ┆ cat
╞═══════╪══════════════════╪═══════════════════════╪════════════════════╪════════╡
NCMadison Cawthorn ┆ John Ashe ┆ Abraham Rencher ┆ M │
IAAbby Finkenauer ┆ Bernhart Henn Abby Finkenauer ┆ F
MIPeter Meijer ┆ Edward Bradley ┆ Aaron Bliss ┆ M │
CAKatie Hill ┆ Edward Gilbert ┆ Aaron Sargent ┆ M │
NYMondaire Jones ┆ Cornelius Schoonmaker ┆ A. Foster ┆ M
└───────┴──────────────────┴───────────────────────┴────────────────────┴────────┘ |}]
;;

(* Examples from https://pola-rs.github.io/polars-book/user-guide/expressions/null/ *)
Expand Down
66 changes: 32 additions & 34 deletions lib/data_type.ml
Original file line number Diff line number Diff line change
@@ -1,40 +1,38 @@
open! Core

module T = struct
type t =
| Boolean
| UInt8
| UInt16
| UInt32
| UInt64
| Int8
| Int16
| Int32
| Int64
| Float32
| Float64
| Utf8
| Binary
| Date
| Datetime of Time_unit.t * Tz.t option
| Duration of Time_unit.t
| Time
| List of t
(* We want this branch to be tested very well, since code dealing with
this recursive case is usually the most non-trivial portion of the
logic. *)
[@quickcheck.weight 10.]
| Null
| Struct of (string * t) list
| Unknown
[@@deriving compare, sexp, quickcheck]
end

include T
include Sexpable.To_stringable (T)
type t =
| Boolean
| UInt8
| UInt16
| UInt32
| UInt64
| Int8
| Int16
| Int32
| Int64
| Float32
| Float64
| Utf8
| Binary
| Date
| Datetime of Time_unit.t * Tz.t option
| Duration of Time_unit.t
| Time
| List of t
(* We want this branch to be tested very well, since code dealing with
this recursive case is usually the most non-trivial portion of the
logic. *)
[@quickcheck.weight 10.]
| Null
| Categorical of (Rev_mapping.t[@compare.ignore]) option [@quickcheck.do_not_generate]
| Struct of (string * t) list
| Unknown
[@@deriving compare, sexp_of, quickcheck]

let to_string t = Sexp.to_string ([%sexp_of: t] t)

module Typed = struct
type untyped = t [@@deriving compare, sexp, quickcheck]
type untyped = t [@@deriving compare, sexp_of, quickcheck]

(* TODO: Consider mapping to smaller OCaml values like Int8, Float32, etc instead of
casting up *)
Expand Down Expand Up @@ -157,7 +155,7 @@ module Typed = struct
| Duration time_unit -> Some (T (Duration time_unit))
| Time -> Some (T Time)
| List t -> of_untyped t |> Option.map ~f:(fun (T t) -> T (List t))
| Null | Struct _ | Unknown -> None
| Null | Categorical _ | Struct _ | Unknown -> None
;;

let rec sexp_of_packed (T t) =
Expand Down
5 changes: 3 additions & 2 deletions lib/data_type.mli
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ type t =
| Time
| List of t
| Null
| Categorical of Rev_mapping.t option
| Struct of (string * t) list
| Unknown
[@@deriving compare, sexp, quickcheck]
[@@deriving compare, sexp_of, quickcheck]

include Stringable.S with type t := t
val to_string : t -> string

module Typed : sig
type untyped
Expand Down
1 change: 1 addition & 0 deletions lib/polars.ml
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ module Naive_time = Naive_time
module Schema = Schema
module Series = Series
module Sql_context = Sql_context
module Rev_mapping = Rev_mapping
module Time_unit = Time_unit
module Tz = Tz
19 changes: 19 additions & 0 deletions lib/rev_mapping.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
open! Core

type t

let quickcheck_shrinker = Base_quickcheck.Shrinker.atomic

let quickcheck_observer =
Base_quickcheck.Observer.of_hash_fold (fun hash_state _t -> hash_state)
;;

external get_categories : t -> string list = "rust_rev_mapping_get_categories"

let sexp_of_t t =
get_categories t
|> (* It's difficult to get tests to be deterministic wrt the ordering of
categories, so we force them to be sorted here as a hacky workaround. *)
List.sort ~compare:String.compare
|> [%sexp_of: string list]
;;
7 changes: 7 additions & 0 deletions lib/rev_mapping.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
open! Core

type t [@@deriving sexp_of]

val quickcheck_shrinker : t Quickcheck.Shrinker.t
val quickcheck_observer : t Quickcheck.Observer.t
val get_categories : t -> string list
1 change: 0 additions & 1 deletion lib/schema.ml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,3 @@ external create : (string * Data_type.t) list -> t = "rust_schema_create"
external to_fields : t -> (string * Data_type.t) list = "rust_schema_to_fields"

let sexp_of_t t = to_fields t |> [%sexp_of: (string * Data_type.t) list]
let t_of_sexp sexp = [%of_sexp: (string * Data_type.t) list] sexp |> create
2 changes: 1 addition & 1 deletion lib/schema.mli
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
open! Core

type t [@@deriving sexp]
type t [@@deriving sexp_of]

val create : (string * Data_type.t) list -> t
val to_fields : t -> (string * Data_type.t) list
9 changes: 9 additions & 0 deletions lib/series.ml
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,15 @@ module T = struct
map input_data_type output_data_type t ~f |> Result.ok_exn
;;

external cast
: t
-> to_:Data_type.t
-> strict:bool
-> (t, string) result
= "rust_series_cast"

let cast ?(strict = true) t ~to_ = cast t ~to_ ~strict |> Utils.string_result_ok_exn

external name : t -> string = "rust_series_name"
external rename : t -> name:string -> unit = "rust_series_rename"
external dtype : t -> Data_type.t = "rust_series_dtype"
Expand Down
1 change: 1 addition & 0 deletions lib/series.mli
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ val map
-> f:('a option -> 'b option)
-> t

val cast : ?strict:bool -> t -> to_:Data_type.t -> t
val name : t -> string
val rename : t -> name:string -> unit
val dtype : t -> Data_type.t
Expand Down
1 change: 1 addition & 0 deletions rust/polars-ocaml/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ features = [
"dtype-i8",
"dtype-u16",
"dtype-u8",
"dtype-categorical",
"dynamic_groupby",
"horizontal_concat",
"interpolate",
Expand Down
1 change: 0 additions & 1 deletion rust/polars-ocaml/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use chrono::{Duration, NaiveDate, NaiveDateTime, NaiveTime};
use ocaml_interop::{
DynBox, OCaml, OCamlBytes, OCamlFloat, OCamlInt, OCamlList, OCamlRef, OCamlRuntime, ToOCaml,
};
use polars::lazy::dsl::GetOutput;
use polars::prelude::*;
use polars::series::IsSorted;
use polars_ocaml_macros::ocaml_interop_export;
Expand Down
11 changes: 11 additions & 0 deletions rust/polars-ocaml/src/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@ fn rust_schema_to_fields(
fields.to_ocaml(cr)
}

#[ocaml_interop_export]
fn rust_rev_mapping_get_categories(
cr: &mut &mut OCamlRuntime,
rev_mapping: OCamlRef<DynBox<Arc<RevMapping>>>,
) -> OCaml<OCamlList<String>> {
let Abstract(rev_mapping) = rev_mapping.to_rust(cr);
let rev_mapping: Vec<_> = rev_mapping.get_categories().values_iter().collect();

rev_mapping.to_ocaml(cr)
}

#[ocaml_interop_export]
fn rust_test_panic(cr: &mut &mut OCamlRuntime, error_message: OCamlRef<String>) -> OCaml<()> {
let error_message: String = error_message.to_rust(cr);
Expand Down
12 changes: 10 additions & 2 deletions rust/polars-ocaml/src/polars_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use ocaml_interop::{
polymorphic_variant_tag_hash, DynBox, FromOCaml, OCaml, OCamlInt, OCamlList, OCamlRuntime,
ToOCaml,
};
use polars::prelude::*;
use polars::series::IsSorted;
use polars::{lazy::dsl::WindowMapping, prelude::*};
use smartstring::{LazyCompact, SmartString};

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -71,6 +71,10 @@ unsafe impl FromOCaml<DataType> for PolarsDataType {
DataType::List(Box::new(datatype))
},
DataType::Null,
DataType::Categorical(local_rev_mapping_opt: Option<DynBox<Arc<RevMapping>>>) => {
let local_rev_mapping_opt: Option<Abstract<Arc<RevMapping>>> = local_rev_mapping_opt;
DataType::Categorical(local_rev_mapping_opt.map(Abstract::get))
},
DataType::Struct(fields: OCamlList<(String, DataType)>) => {
let fields_: Vec<(String, PolarsDataType)> = fields;
let fields: Vec<Field> =
Expand Down Expand Up @@ -134,12 +138,16 @@ unsafe impl ToOCaml<DataType> for PolarsDataType {
ocaml_alloc_tagged_block!(cr, 2, datatype: DataType)
}
DataType::Null => ocaml_value(cr, 15),
DataType::Categorical(local_rev_mapping_opt) => {
let local_rev_mapping_opt = local_rev_mapping_opt.clone().map(Abstract);
ocaml_alloc_tagged_block!(cr, 3, local_rev_mapping_opt: Option<DynBox<Arc<RevMapping>>>)
}
DataType::Struct(fields) => {
let fields: Vec<(String, PolarsDataType)> = fields
.iter()
.map(|field| (field.name.to_string(), PolarsDataType(field.dtype.clone())))
.collect();
ocaml_alloc_tagged_block!(cr, 3, fields: OCamlList<(String, DataType)>)
ocaml_alloc_tagged_block!(cr, 4, fields: OCamlList<(String, DataType)>)
}
DataType::Unknown => ocaml_value(cr, 16),
}
Expand Down
Loading
Loading