Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor unconstraining to use deserializer interface #872

Merged
merged 54 commits into from
Aug 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
4f87011
WIP; First pass at unconstrain refactor for deserializer
rybern Apr 5, 2021
02b1782
make the transform_inits() API to work with the new impl
SteveBronder Apr 6, 2021
035a665
Use serializer.write instead of var__.emplace_back in transform_inits…
rybern Apr 6, 2021
82e8af9
Add (commented) code to replace vars__.emplace_back with serializer.w…
rybern Apr 6, 2021
acf5c8e
Change initialization of decay_t in transform_inits
rybern Apr 6, 2021
39d22dd
use serialize in write_array()
SteveBronder Apr 6, 2021
1f0aac7
Don't inline deserializer calls in transform_inits
rybern Apr 7, 2021
a4810ba
Code cleanup
rybern Apr 7, 2021
e8a5960
update to master
SteveBronder Apr 7, 2021
f5e6f6d
Do unconstrain in writes rather than reads
rybern Apr 12, 2021
42d541a
Remove dimensions from serializer reads with Identity constraints
rybern Apr 12, 2021
0474b30
Structured reads/writes, but optimization doesn't find exprs
rybern Apr 12, 2021
7848b18
Include dimensions in serializer reads
rybern Apr 13, 2021
25e1ce8
update to master
SteveBronder Apr 21, 2021
dfe7564
Give transformations their own module to avoid circular dependencies
rybern Apr 23, 2021
5f2294b
Merge structured internal functions
rybern Apr 23, 2021
2d41f2e
Use transformations in internal functions; move transform codegen log…
rybern Apr 23, 2021
f602363
Pass transformation to FnCheck to handle in backend instead of ast_to…
rybern Apr 23, 2021
d45927d
Update optimizations' side effecting exprs and to find exprs in inter…
rybern Apr 23, 2021
8b6c9d6
Cleanup constraint generation code
rybern Apr 23, 2021
ba396d7
Merge back into unconstrain branch
rybern Apr 23, 2021
a39d3b0
Merge branch 'unconstrain-refactor' of github.com:rybern/stanc3 into …
rybern Apr 23, 2021
30c4c61
Merge structured internal functions
rybern Apr 23, 2021
21c81b2
Add optimization util test
rybern Apr 24, 2021
4c172c9
Fix optimizations with structured internal_fns
rybern Apr 24, 2021
1caeea6
Change name of Transformations.t to follow convention
rybern Apr 24, 2021
100f353
Adding Transformation module to avoid cyclic dependency
rybern Apr 24, 2021
bf5a3cf
update expected tests
SteveBronder Apr 27, 2021
8381bad
remove dims from write to serializer
SteveBronder Apr 28, 2021
7fd3b74
dune format
SteveBronder Apr 28, 2021
a14d5cc
adds back all dimensions for when transform_inits does the read from …
SteveBronder Apr 28, 2021
7ac6cda
Merge remote-tracking branch 'upstream/master' into HEAD
SteveBronder Apr 28, 2021
92f284e
Remove dimensions from serializer.write calls
rybern Apr 28, 2021
4d4e0c0
Remove dimensions from serializer.write calls
rybern Apr 28, 2021
e7be045
fix non-impl transform_inits body to construct instead of reserve the…
SteveBronder Apr 28, 2021
cfe9238
Initialize vars_vec with size to hold transformed and genquant vars
rybern May 5, 2021
65e1b66
add quantile signatures
adamhaber Apr 5, 2021
b4a208f
add test model
adamhaber Apr 6, 2021
5222f34
dune promote
adamhaber Apr 6, 2021
2f92c01
Revert "Add quantile signatures"
rok-cesnovar May 3, 2021
ba7c5dc
fix stanc copy
rok-cesnovar May 3, 2021
e847427
update name of num_params to num_params_r__
SteveBronder May 5, 2021
0818dc5
fix sizes for std::vector write_array
SteveBronder May 7, 2021
6349501
merge master
rybern May 14, 2021
24b018b
format
SteveBronder May 14, 2021
112c179
Array params now read from row-major to col-major in transform_inits.…
rybern May 20, 2021
7438345
Merge remote-tracking branch 'upstream/master' into HEAD
SteveBronder May 20, 2021
6100eeb
update ode expect tests
SteveBronder May 21, 2021
03fbd4f
update size of write_array serialized vector
SteveBronder May 25, 2021
270a8cc
force clean all before building cmdstan
SteveBronder May 25, 2021
042c13c
turn off precompiled headers for jenkins
SteveBronder May 25, 2021
e9cb92a
update to master
SteveBronder Aug 11, 2021
87a1c3b
fix Jenkinsfile
SteveBronder Aug 11, 2021
39e5826
use loop instead to handle column major to row major in transform_inits
SteveBronder Aug 11, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,6 @@ test/*.log

# Mac OS X hidden files
*.DS_Store

# .hpp files in test folder
test/**/*.hpp
19 changes: 12 additions & 7 deletions src/analysis_and_optimization/Mir_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ type bound_values =
{ lower: [`None | `Nonlit | `Lit of float]
; upper: [`None | `Nonlit | `Lit of float] }

let trans_bounds_values (trans : Expr.Typed.t transformation) : bound_values =
let trans_bounds_values (trans : Expr.Typed.t Transformation.t) : bound_values
=
let bound_value e =
match num_expr_value e with None -> `Nonlit | Some (f, _) -> `Lit f
in
Expand Down Expand Up @@ -132,8 +133,7 @@ let map_rec_expr_state f state e =

let rec map_rec_stmt_loc f stmt =
let recurse = map_rec_stmt_loc f in
Stmt.Fixed.
{stmt with pattern= f (Pattern.map (fun x -> x) recurse stmt.pattern)}
Stmt.Fixed.{stmt with pattern= f (Pattern.map Fn.id recurse stmt.pattern)}

let rec top_down_map_rec_stmt_loc f stmt =
let recurse = top_down_map_rec_stmt_loc f in
Expand Down Expand Up @@ -250,7 +250,7 @@ let rec expr_var_set Expr.Fixed.({pattern; meta}) =
match pattern with
| Var s -> Set.Poly.singleton (VVar s, meta)
| Lit _ -> Set.Poly.empty
| FunApp (_, exprs) -> union_recur exprs
| FunApp (kind, exprs) -> union_recur (exprs @ Fun_kind.collect_exprs kind)
| TernaryIf (expr1, expr2, expr3) -> union_recur [expr1; expr2; expr3]
| Indexed (expr, ix) ->
Set.Poly.union_list (expr_var_set expr :: List.map ix ~f:index_var_set)
Expand All @@ -268,7 +268,8 @@ and index_var_set ix =
let stmt_rhs stmt =
match stmt with
| Stmt.Fixed.Pattern.For vars -> Set.Poly.of_list [vars.lower; vars.upper]
| NRFunApp (_, exprs) -> Set.Poly.of_list exprs
| NRFunApp (kind, exprs) ->
Set.Poly.of_list (exprs @ Fun_kind.collect_exprs kind)
| IfElse (rhs, _, _)
|While (rhs, _)
|Assignment (_, rhs)
Expand Down Expand Up @@ -354,7 +355,8 @@ let expr_subst_stmt m = map_rec_stmt_loc (expr_subst_stmt_base m)
let rec expr_depth Expr.Fixed.({pattern; _}) =
match pattern with
| Var _ | Lit (_, _) -> 0
| FunApp (_, l) ->
| FunApp (kind, args) ->
let l = args @ Fun_kind.collect_exprs kind in
1
+ Option.value ~default:0
(List.max_elt ~compare:compare_int (List.map ~f:expr_depth l))
Expand Down Expand Up @@ -393,8 +395,11 @@ let rec update_expr_ad_levels autodiffable_variables
else {e with meta= {e.meta with adlevel= DataOnly}}
| Lit (_, _) -> {e with meta= {e.meta with adlevel= DataOnly}}
| FunApp (kind, l) ->
let kind' =
Fun_kind.map (update_expr_ad_levels autodiffable_variables) kind
in
let l = List.map ~f:(update_expr_ad_levels autodiffable_variables) l in
{pattern= FunApp (kind, l); meta= {e.meta with adlevel= ad_level_sup l}}
{pattern= FunApp (kind', l); meta= {e.meta with adlevel= ad_level_sup l}}
| TernaryIf (e1, e2, e3) ->
let e1 = update_expr_ad_levels autodiffable_variables e1 in
let e2 = update_expr_ad_levels autodiffable_variables e2 in
Expand Down
4 changes: 2 additions & 2 deletions src/analysis_and_optimization/Mir_utils.mli
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ type bound_values =
{ lower: [`None | `Nonlit | `Lit of float]
; upper: [`None | `Nonlit | `Lit of float] }

val trans_bounds_values : Expr.Typed.t Program.transformation -> bound_values
val trans_bounds_values : Expr.Typed.t Transformation.t -> bound_values
val chop_dist_name : string -> string Option.t
val top_var_declarations : Stmt.Located.t -> string Set.Poly.t

Expand All @@ -22,7 +22,7 @@ val data_set :
val parameter_set :
?include_transformed:bool
-> Program.Typed.t
-> (string * Expr.Typed.t Program.transformation) Set.Poly.t
-> (string * Expr.Typed.t Transformation.t) Set.Poly.t

val parameter_names_set :
?include_transformed:bool -> Program.Typed.t -> string Set.Poly.t
Expand Down
15 changes: 10 additions & 5 deletions src/analysis_and_optimization/Monotone_framework.ml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ and free_vars_idx (i : Expr.Typed.t Index.t) =
| Between (e1, e2) -> Set.Poly.union (free_vars_expr e1) (free_vars_expr e2)

and free_vars_fnapp kind l =
let arg_vars = List.map ~f:free_vars_expr l in
let arg_vars =
List.map ~f:free_vars_expr (l @ Fun_kind.collect_exprs kind)
in
match kind with
| Fun_kind.UserDefined (f, _) ->
Set.Poly.union_list (Set.Poly.singleton f :: List.map ~f:free_vars_expr l)
Expand Down Expand Up @@ -547,8 +549,9 @@ let rec used_subexpressions_expr (e : Expr.Typed.t) =
(Expr.Typed.Set.singleton e)
( match e.pattern with
| Var _ | Lit (_, _) -> Expr.Typed.Set.empty
| FunApp (_, l) ->
Expr.Typed.Set.union_list (List.map ~f:used_subexpressions_expr l)
| FunApp (k, l) ->
Expr.Typed.Set.union_list
(List.map ~f:used_subexpressions_expr (l @ Fun_kind.collect_exprs k))
| TernaryIf (e1, e2, e3) ->
Expr.Typed.Set.union_list
[ used_subexpressions_expr e1
Expand Down Expand Up @@ -585,7 +588,8 @@ let rec used_expressions_stmt_help f
[ f e
; used_expressions_stmt_help f b1.pattern
; used_expressions_stmt_help f b2.pattern ]
| NRFunApp (_, l) -> Expr.Typed.Set.union_list (List.map ~f l)
| NRFunApp (k, l) ->
Expr.Typed.Set.union_list (List.map ~f (l @ Fun_kind.collect_exprs k))
| Decl _ | Return None | Break | Continue | Skip -> Expr.Typed.Set.empty
| IfElse (e, b, None) | While (e, b) ->
Expr.Typed.Set.union (f e) (used_expressions_stmt_help f b.pattern)
Expand Down Expand Up @@ -619,7 +623,8 @@ let top_used_expressions_stmt_help f
(Expr.Typed.Set.union_list
(List.map ~f:(used_expressions_idx_help f) l))
| While (e, _) | IfElse (e, _, _) -> f e
| NRFunApp (_, l) -> Expr.Typed.Set.union_list (List.map ~f l)
| NRFunApp (k, l) ->
Expr.Typed.Set.union_list (List.map ~f (l @ Fun_kind.collect_exprs k))
| Profile _ | Block _ | SList _ | Decl _
|Return None
|Break | Continue | Skip ->
Expand Down
9 changes: 2 additions & 7 deletions src/analysis_and_optimization/Optimize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -690,13 +690,8 @@ and accum_any pred b e = b || expr_any pred e
let can_side_effect_top_expr (e : Expr.Typed.t) =
match e.pattern with
| FunApp ((UserDefined (_, FnTarget) | StanLib (_, FnTarget)), _) -> true
| FunApp
( CompilerInternal
( FnReadParam _ | FnReadData | FnWriteParam | FnConstrain _
| FnValidateSize | FnValidateSizeSimplex | FnValidateSizeUnitVector
| FnUnconstrain _ )
, _ ) ->
true
| FunApp (CompilerInternal internal_fn, _) ->
Internal_fun.can_side_effect internal_fn
| _ -> false

let cannot_duplicate_expr (e : Expr.Typed.t) =
Expand Down
2 changes: 1 addition & 1 deletion src/analysis_and_optimization/Pedantic_analysis.ml
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ let list_non_one_priors (fg : factor_graph) (mir : Program.Typed.t) :
(* Collect useful information about an expression that's available at
compile-time into a convenient form. *)
let compiletime_value_of_expr
(params : (string * Expr.Typed.t Program.transformation) Set.Poly.t)
(params : (string * Expr.Typed.t Transformation.t) Set.Poly.t)
(data : string Set.Poly.t) (expr : Expr.Typed.t) :
compiletime_val * Expr.Typed.Meta.t =
let v =
Expand Down
21 changes: 10 additions & 11 deletions src/analysis_and_optimization/Pedantic_dist_warnings.ml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ open Mir_utils
type compiletime_val =
| Opaque
| Number of (float * string)
| Param of (string * Expr.Typed.t Program.transformation)
| Param of (string * Expr.Typed.t Transformation.t)
| Data of string

(* Info about a distribution occurrences that's useful for checking that
Expand Down Expand Up @@ -79,18 +79,17 @@ let bounds_out_of_range (range : range) (bounds : bound_values) : bool =
(* Check for inconsistency between a distribution argument's constraint and the
constraint transformation of a variable *)
let transform_mismatch_constraint (constr : var_constraint)
(trans : Expr.Typed.t Program.transformation) : bool =
(trans : Expr.Typed.t Transformation.t) : bool =
match constr with
| Range range -> bounds_out_of_range range (trans_bounds_values trans)
| Ordered -> trans <> Program.Ordered
| PositiveOrdered -> trans <> Program.PositiveOrdered
| Simplex -> trans <> Program.Simplex
| UnitVector -> trans <> Program.UnitVector
| CholeskyCorr -> trans <> Program.CholeskyCorr
| CholeskyCov ->
trans <> Program.CholeskyCov && trans <> Program.CholeskyCorr
| Correlation -> trans <> Program.Correlation
| Covariance -> trans <> Program.Covariance && trans <> Program.Correlation
| Ordered -> trans <> Transformation.Ordered
| PositiveOrdered -> trans <> PositiveOrdered
| Simplex -> trans <> Simplex
| UnitVector -> trans <> UnitVector
| CholeskyCorr -> trans <> CholeskyCorr
| CholeskyCov -> trans <> CholeskyCov && trans <> CholeskyCorr
| Correlation -> trans <> Correlation
| Covariance -> trans <> Covariance && trans <> Correlation

(* Check for inconsistency between a distribution argument's range and
a literal value *)
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/Ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ type ('e, 's, 'l, 'f) statement =
| Block of 's list
| VarDecl of
{ decl_type: 'e Middle.Type.t
; transformation: 'e Middle.Program.transformation
; transformation: 'e Transformation.t
; identifier: identifier
; initial_value: 'e option
; is_global: bool }
Expand Down
Loading