Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Composable transforms #947

Closed
wants to merge 29 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
31fc2ed
First pass at composable transforms
WardBrian Aug 17, 2021
5b762ae
Merge branch 'stan-dev:master' into composable-transforms
WardBrian Aug 17, 2021
3441287
Checkpointing
WardBrian Aug 18, 2021
175794a
Merge branch 'master' into composable-transforms
WardBrian Aug 18, 2021
3a6f275
Checkpointing
WardBrian Aug 18, 2021
1923d34
Commenting
WardBrian Aug 18, 2021
e1cf057
update tests
SteveBronder Aug 20, 2021
724ae4f
Mark non-parameter uses as error states
WardBrian Aug 30, 2021
0bc2984
First pass at dimensions for reader
WardBrian Aug 30, 2021
404d9a8
Merge branch 'master' of github.com:stan-dev/stanc3 into composable-t…
WardBrian Aug 30, 2021
644eac7
Update test output
WardBrian Aug 30, 2021
727f115
Switch to composing function calls
WardBrian Aug 30, 2021
6db5276
Change write to match
WardBrian Aug 30, 2021
8a3bef8
Matrix constraints read in a vector
WardBrian Aug 31, 2021
30147a5
Fix formatting of semicolon
WardBrian Aug 31, 2021
3eed96c
Fix function naming
WardBrian Aug 31, 2021
21ce330
TODO work
WardBrian Aug 31, 2021
0f954d9
Fix dimensionality in nested types
WardBrian Aug 31, 2021
4cf7040
Add extra args to matrix constrain -- ugly
WardBrian Aug 31, 2021
8ec5be3
Change how special matrix types are handled internally
WardBrian Aug 31, 2021
ecfc6c5
Merge branch 'master' of github.com:stan-dev/stanc3 into composable-t…
WardBrian Sep 1, 2021
0c4906d
Merge branch 'master' into composable-transforms
WardBrian Sep 9, 2021
f384436
Merge branch 'master' of github.com:stan-dev/stanc3 into composable-t…
WardBrian Sep 13, 2021
43bc0f4
Work on semantic checks
WardBrian Sep 15, 2021
23428c1
Semantic check work
WardBrian Sep 15, 2021
16feb64
Merge branch 'master' of github.com:WardBrian/stanc3 into composable-…
WardBrian Sep 20, 2021
6c6b0c0
Merge branch 'master' of github.com:stan-dev/stanc3 into composable-t…
WardBrian Sep 20, 2021
61bbf29
Fix formatting
WardBrian Sep 20, 2021
05b13fb
update to use explicit namespaces for constraints
SteveBronder Sep 21, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 26 additions & 14 deletions src/analysis_and_optimization/Mir_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,34 @@ type bound_values =
{ lower: [`None | `Nonlit | `Lit of float]
; upper: [`None | `Nonlit | `Lit of float] }

let trans_bounds_values (trans : Expr.Typed.t Transformation.t) : bound_values
=
let bound_value e =
match num_expr_value e with None -> `Nonlit | Some (f, _) -> `Lit f
let trans_bounds_values (trans : Expr.Typed.t Transformation.t) :
bound_values list =
let single_bound (t : 'e Transformation.primitive) =
let bound_value e =
match num_expr_value e with None -> `Nonlit | Some (f, _) -> `Lit f
in
match t with
| Lower lower -> {lower= bound_value lower; upper= `None}
| Upper upper -> {lower= `None; upper= bound_value upper}
| LowerUpper (lower, upper) ->
{lower= bound_value lower; upper= bound_value upper}
| Simplex -> {lower= `Lit 0.; upper= `Lit 1.}
| PositiveOrdered -> {lower= `Lit 0.; upper= `None}
| UnitVector -> {lower= `Lit (-1.); upper= `Lit 1.}
| CholeskyCorr | CholeskyCov | Correlation | Covariance | Ordered
|Offset _ | Multiplier _ | OffsetMultiplier _ ->
{lower= `None; upper= `None}
in
match trans with
| Lower lower -> {lower= bound_value lower; upper= `None}
| Upper upper -> {lower= `None; upper= bound_value upper}
| LowerUpper (lower, upper) ->
{lower= bound_value lower; upper= bound_value upper}
| Simplex -> {lower= `Lit 0.; upper= `Lit 1.}
| PositiveOrdered -> {lower= `Lit 0.; upper= `None}
| UnitVector -> {lower= `Lit (-1.); upper= `Lit 1.}
| CholeskyCorr | CholeskyCov | Correlation | Covariance | Ordered
|Offset _ | Multiplier _ | OffsetMultiplier _ | Identity ->
{lower= `None; upper= `None}
| Identity -> [{lower= `None; upper= `None}]
| Single t -> [single_bound t]
| Chain ts -> List.map ~f:single_bound ts

let trans_domain_bounds trans : bound_values =
(* in distribution constraints, we only care that the final constraint
applied satisfies
*)
List.last_exn (trans_bounds_values trans)

let chop_dist_name (fname : string) : string Option.t =
(* Slightly inefficient, would be better to short-circuit *)
Expand Down
3 changes: 2 additions & 1 deletion src/analysis_and_optimization/Mir_utils.mli
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ type bound_values =
{ lower: [`None | `Nonlit | `Lit of float]
; upper: [`None | `Nonlit | `Lit of float] }

val trans_bounds_values : Expr.Typed.t Transformation.t -> bound_values
val trans_bounds_values : Expr.Typed.t Transformation.t -> bound_values list
val trans_domain_bounds : Expr.Typed.t Transformation.t -> bound_values
val chop_dist_name : string -> string Option.t
val top_var_declarations : Stmt.Located.t -> string Set.Poly.t

Expand Down
12 changes: 6 additions & 6 deletions src/analysis_and_optimization/Pedantic_analysis.ml
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ let list_hard_constrained (mir : Program.Typed.t) :
| {lower= `Lit _; upper= `Lit _} -> Some `HardConstraint
| _ -> None
in
Set.Poly.filter_map
~f:(fun (name, trans) ->
Option.map
~f:(fun c -> (name, c))
(constrained (trans_bounds_values trans)) )
(parameter_set mir)
parameter_set mir |> Set.Poly.to_list
|> List.concat_map ~f:(fun (name, trans) ->
trans_bounds_values trans
|> List.filter_map ~f:constrained
|> List.map ~f:(fun el -> (name, el)) )
|> Set.Poly.of_list

let list_multi_twiddles (mir : Program.Typed.t) :
(string * Location_span.t Set.Poly.t) Set.Poly.t =
Expand Down
25 changes: 15 additions & 10 deletions src/analysis_and_optimization/Pedantic_dist_warnings.ml
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,21 @@ let bounds_out_of_range (range : range) (bounds : bound_values) : bool =
constraint transformation of a variable *)
let transform_mismatch_constraint (constr : var_constraint)
(trans : Expr.Typed.t Transformation.t) : bool =
let open Transformation in
match constr with
| Range range -> bounds_out_of_range range (trans_bounds_values trans)
| Ordered -> trans <> Transformation.Ordered
| PositiveOrdered -> trans <> PositiveOrdered
| Simplex -> trans <> Simplex
| UnitVector -> trans <> UnitVector
| CholeskyCorr -> trans <> CholeskyCorr
| CholeskyCov -> trans <> CholeskyCov && trans <> CholeskyCorr
| Correlation -> trans <> Correlation
| Covariance -> trans <> Covariance && trans <> Correlation
| Range range -> bounds_out_of_range range (trans_domain_bounds trans)
| Ordered -> not (domains_match trans Ordered)
| PositiveOrdered -> not (domains_match trans PositiveOrdered)
| Simplex -> not (domains_match trans Simplex)
| UnitVector -> not (domains_match trans UnitVector)
| CholeskyCorr -> not (domains_match trans CholeskyCorr)
| CholeskyCov ->
(not (domains_match trans CholeskyCov))
&& not (domains_match trans CholeskyCorr)
| Correlation -> not (domains_match trans Correlation)
| Covariance ->
(not (domains_match trans Covariance))
&& not (domains_match trans Correlation)

(* Check for inconsistency between a distribution argument's range and
a literal value *)
Expand Down Expand Up @@ -204,7 +209,7 @@ let uniform_dist_warning (dist_info : dist_info) :
match dist_info with
| {args= (Param (pname, trans), _) :: (arg1, _) :: (arg2, _) :: _; _} -> (
let warning = Some (dist_info.loc, uniform_dist_message pname) in
match (arg1, arg2, trans_bounds_values trans) with
match (arg1, arg2, trans_domain_bounds trans) with
| _, _, {upper= `None; _} | _, _, {lower= `None; _} ->
(* the variate is unbounded *)
warning
Expand Down
25 changes: 17 additions & 8 deletions src/frontend/Ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -313,15 +313,24 @@ let get_loc_dt (t : untyped_expression Type.t) =
Some e.emeta.loc.begin_loc

let get_loc_tf (t : untyped_expression Transformation.t) =
let get_loc_prim (t : untyped_expression Transformation.primitive) =
match t with
| Lower e
|Upper e
|LowerUpper (e, _)
|Offset e
|Multiplier e
|OffsetMultiplier (e, _) ->
Some e.emeta.loc.begin_loc
| _ -> None
in
match t with
| Lower e
|Upper e
|LowerUpper (e, _)
|Offset e
|Multiplier e
|OffsetMultiplier (e, _) ->
Some e.emeta.loc.begin_loc
| _ -> None
| Identity -> None
| Single t -> get_loc_prim t
| Chain ts ->
let locs = List.map ~f:get_loc_prim ts in
(* return first location encountered *)
List.fold ~f:(Option.merge ~f:(fun x _ -> x)) ~init:None locs

let get_first_loc (s : untyped_statement) =
match s.stmt with
Expand Down
132 changes: 86 additions & 46 deletions src/frontend/Ast_to_Mir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,8 @@ type decl_context =
{transform_action: transform_action; dadlevel: UnsizedType.autodifftype}

let constraint_forl = function
| Transformation.Identity | Offset _ | Multiplier _ | OffsetMultiplier _
|Lower _ | Upper _ | LowerUpper _ ->
| Transformation.Offset _ | Multiplier _ | OffsetMultiplier _ | Lower _
|Upper _ | LowerUpper _ ->
Stmt.Helpers.for_scalar
| Ordered | PositiveOrdered | Simplex | UnitVector | CholeskyCorr
|CholeskyCov | Correlation | Covariance ->
Expand All @@ -216,7 +216,7 @@ let same_shape decl_id decl_var id var meta =
[str "constraint"; str decl_id; decl_var; str id; var] )
; meta } ]

let check_transform_shape decl_id decl_var meta = function
let check_transform_shape_prim decl_id decl_var meta = function
| Transformation.Offset e -> same_shape decl_id decl_var "offset" e meta
| Multiplier e -> same_shape decl_id decl_var "multiplier" e meta
| Lower e -> same_shape decl_id decl_var "lower" e meta
Expand All @@ -228,9 +228,15 @@ let check_transform_shape decl_id decl_var meta = function
same_shape decl_id decl_var "lower" e1 meta
@ same_shape decl_id decl_var "upper" e2 meta
| Covariance | Correlation | CholeskyCov | CholeskyCorr | Ordered
|PositiveOrdered | Simplex | UnitVector | Identity ->
|PositiveOrdered | Simplex | UnitVector ->
[]

let check_transform_shape decl_id decl_var meta = function
| Transformation.Identity -> []
| Single t -> check_transform_shape_prim decl_id decl_var meta t
| Chain ts ->
List.concat_map ~f:(check_transform_shape_prim decl_id decl_var meta) ts

let copy_indices indexed (var : Expr.Typed.t) =
if UnsizedType.is_scalar_type var.meta.type_ then var
else
Expand All @@ -253,7 +259,7 @@ let extract_transform_args var = function
| LowerUpper (a1, a2) | OffsetMultiplier (a1, a2) ->
[copy_indices var a1; copy_indices var a2]
| Covariance | Correlation | CholeskyCov | CholeskyCorr | Ordered
|PositiveOrdered | Simplex | UnitVector | Identity ->
|PositiveOrdered | Simplex | UnitVector ->
[]

let param_size transform sizedtype =
Expand All @@ -278,30 +284,48 @@ let param_size transform sizedtype =
let k_choose_2 k =
Expr.Helpers.(binop (binop k Times (binop k Minus (int 1))) Divide (int 2))
in
match transform with
| Transformation.Identity | Lower _ | Upper _
|LowerUpper (_, _)
|Offset _ | Multiplier _
|OffsetMultiplier (_, _)
|Ordered | PositiveOrdered | UnitVector ->
sizedtype
| Simplex ->
shrink_eigen (fun d -> Expr.Helpers.(binop d Minus (int 1))) sizedtype
| CholeskyCorr | Correlation -> shrink_eigen k_choose_2 sizedtype
| CholeskyCov ->
(* (N * (N + 1)) / 2 + (M - N) * N *)
shrink_eigen_mat
(fun m n ->
Expr.Helpers.(
binop
(binop (k_choose_2 n) Plus n)
Plus
(binop (binop m Minus n) Times n)) )
sizedtype
| Covariance ->
shrink_eigen
(fun k -> Expr.Helpers.(binop k Plus (k_choose_2 k)))
let resize sizedtype trans =
match trans with
| Transformation.Lower _ | Upper _
|LowerUpper (_, _)
|Offset _ | Multiplier _
|OffsetMultiplier (_, _)
|Ordered | PositiveOrdered | UnitVector ->
sizedtype
| Simplex ->
shrink_eigen (fun d -> Expr.Helpers.(binop d Minus (int 1))) sizedtype
| CholeskyCorr | Correlation -> shrink_eigen k_choose_2 sizedtype
| CholeskyCov ->
(* (N * (N + 1)) / 2 + (M - N) * N *)
shrink_eigen_mat
(fun m n ->
Expr.Helpers.(
binop
(binop (k_choose_2 n) Plus n)
Plus
(binop (binop m Minus n) Times n)) )
sizedtype
| Covariance ->
shrink_eigen
(fun k -> Expr.Helpers.(binop k Plus (k_choose_2 k)))
sizedtype
in
match transform with
| Transformation.Identity -> sizedtype
| Single t -> resize sizedtype t
| Chain ts ->
(* Finds the first constraint which changes sizes and goes with that
* Similar to Transform_Mir.read_constrain_dims, this won't work in all cases,
* like if you ever did a simplex corr_matrix. In our standard library
* This is fine, may need more complicated logic for user defined transforms.
*)
let rec loop st = function
| [] -> st
| t :: ts ->
let st' = resize st t in
if st' = st then loop st ts else st'
in
loop sizedtype ts

let remove_possibly_exn pst action loc =
match pst with
Expand All @@ -313,20 +337,36 @@ let remove_possibly_exn pst action loc =

let rec check_decl var decl_type' decl_id decl_trans smeta adlevel =
let decl_type = remove_possibly_exn decl_type' "check" smeta in
let check_single trans =
match trans with
| Transformation.LowerUpper (lb, ub) ->
check_decl var decl_type' decl_id (Transformation.Single (Lower lb))
smeta adlevel
@ check_decl var decl_type' decl_id (Single (Upper ub)) smeta adlevel
| _ when Transformation.primitive_has_check trans ->
let check_id id =
let var_name = Fmt.strf "%a" Expr.Typed.pp id in
let args = extract_transform_args id trans in
Stmt.Helpers.internal_nrfunapp
(FnCheck {trans; var_name; var= id})
args smeta
in
[(constraint_forl trans) decl_type check_id var smeta]
| _ -> []
in
match decl_trans with
| Transformation.LowerUpper (lb, ub) ->
check_decl var decl_type' decl_id (Lower lb) smeta adlevel
@ check_decl var decl_type' decl_id (Upper ub) smeta adlevel
| _ when Transformation.has_check decl_trans ->
let check_id id =
let var_name = Fmt.strf "%a" Expr.Typed.pp id in
let args = extract_transform_args id decl_trans in
Stmt.Helpers.internal_nrfunapp
(FnCheck {trans= decl_trans; var_name; var= id})
args smeta
in
[(constraint_forl decl_trans) decl_type check_id var smeta]
| _ -> []
| Transformation.Identity -> []
| Single t -> check_single t
(* NB: We only allow chain transforms in parameters, which are never checked.
* REM: A naive attempt at doing this would be `List.concat_map ~f:check_single ts`
* This currently concatinates all checks in a way that is invalid,
* e.g. lower(0), upper(10), will not yield values in [0,10],
* one must perform unconstraining transform between sucessive checks in general
*)
| Chain _ ->
raise_s
[%message
"Attempting to check a chained transform. This should never happen"]

let check_sizedtype name =
let check x = function
Expand Down Expand Up @@ -362,8 +402,8 @@ let check_sizedtype name =
(ll, Type.Sized st)
| Unsized ut -> ([], Unsized ut)

let trans_decl {transform_action; dadlevel} smeta decl_type transform
identifier initial_value =
let trans_decl {transform_action; dadlevel} smeta decl_type
(transform : Expr.Typed.t Transformation.t) identifier initial_value =
let decl_id = identifier.Ast.name in
let rhs = Option.map ~f:trans_expr initial_value in
let size_checks, dt = check_sizedtype identifier.name decl_type in
Expand Down Expand Up @@ -641,9 +681,9 @@ let trans_sizedtype_decl declc tr name =
| SVector (mem_pattern, s) ->
let fn =
match (declc.transform_action, tr) with
| Constrain, Transformation.Simplex ->
| Constrain, Transformation.Single Simplex ->
Internal_fun.FnValidateSizeSimplex
| Constrain, UnitVector -> FnValidateSizeUnitVector
| Constrain, Single UnitVector -> FnValidateSizeUnitVector
| _ -> FnValidateSize
in
let l, s = grab_size fn n s in
Expand All @@ -656,7 +696,7 @@ let trans_sizedtype_decl declc tr name =
let l2, c = grab_size FnValidateSize (n + 1) c in
let cf_cov =
match (declc.transform_action, tr) with
| Constrain, CholeskyCov ->
| Constrain, Single CholeskyCov ->
[ { Stmt.Fixed.pattern=
NRFunApp
( StanLib ("check_greater_or_equal", FnPlain, AoS)
Expand Down
Loading