Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow scale to compose with transformations #971

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/analysis_and_optimization/Mir_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ let trans_bounds_values (trans : Expr.Typed.t Transformation.t) : bound_values =
| Simplex -> {lower= `Lit 0.; upper= `Lit 1.}
| PositiveOrdered -> {lower= `Lit 0.; upper= `None}
| UnitVector -> {lower= `Lit (-1.); upper= `Lit 1.}
| CholeskyCorr | CholeskyCov | Correlation | Covariance | Ordered | Offset _
|Multiplier _ | OffsetMultiplier _ | Identity ->
| CholeskyCorr | CholeskyCov | Correlation | Covariance | Ordered | Identity
->
{lower= `None; upper= `None}

let chop_dist_name (fname : string) : string Option.t =
Expand Down
22 changes: 13 additions & 9 deletions src/frontend/Ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ type ('e, 's, 'l, 'f) statement =
| Block of 's list
| VarDecl of
{ decl_type: 'e Middle.Type.t
; scale: 'e Scale.t
; transformation: 'e Transformation.t
; identifier: identifier
; initial_value: 'e option
Expand Down Expand Up @@ -321,15 +322,15 @@ let get_loc_dt (t : untyped_expression Type.t) =

let get_loc_tf (t : untyped_expression Transformation.t) =
match t with
| Lower e
|Upper e
|LowerUpper (e, _)
|Offset e
|Multiplier e
|OffsetMultiplier (e, _) ->
Some e.emeta.loc.begin_loc
| Lower e | Upper e | LowerUpper (e, _) -> Some e.emeta.loc.begin_loc
| _ -> None

let get_loc_scale (s : untyped_expression Scale.t) =
match s with
| Offset e | Multiplier e | OffsetMultiplier (e, _) ->
Some e.emeta.loc.begin_loc
| Native -> None

let get_first_loc (s : untyped_statement) =
match s.stmt with
| Assignment {assign_lhs; _} -> assign_lhs.lmeta.loc.end_loc
Expand All @@ -348,10 +349,13 @@ let get_first_loc (s : untyped_statement) =
| Tilde {arg; _} -> get_loc_expr arg
| Break | Continue | ReturnVoid | Print _ | Reject _ | Skip ->
s.smeta.loc.end_loc
| VarDecl {decl_type; transformation; identifier; _} -> (
| VarDecl {decl_type; scale; transformation; identifier; _} -> (
match get_loc_dt decl_type with
| Some loc -> loc
| None -> (
match get_loc_tf transformation with
| Some loc -> loc
| None -> identifier.id_loc.begin_loc ) )
| None -> (
match get_loc_scale scale with
| Some loc -> loc
| None -> identifier.id_loc.begin_loc ) ) )
39 changes: 22 additions & 17 deletions src/frontend/Ast_to_Mir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,7 @@ type decl_context =
{transform_action: transform_action; dadlevel: UnsizedType.autodifftype}

let constraint_forl = function
| Transformation.Identity | Offset _ | Multiplier _ | OffsetMultiplier _
|Lower _ | Upper _ | LowerUpper _ ->
| Transformation.Identity | Lower _ | Upper _ | LowerUpper _ ->
Stmt.Helpers.for_scalar
| Ordered | PositiveOrdered | Simplex | UnitVector | CholeskyCorr
|CholeskyCov | Correlation | Covariance ->
Expand All @@ -207,20 +206,23 @@ let same_shape decl_id decl_var id var meta =
; meta } ]

let check_transform_shape decl_id decl_var meta = function
| Transformation.Offset e -> same_shape decl_id decl_var "offset" e meta
| Multiplier e -> same_shape decl_id decl_var "multiplier" e meta
| Lower e -> same_shape decl_id decl_var "lower" e meta
| Transformation.Lower e -> same_shape decl_id decl_var "lower" e meta
| Upper e -> same_shape decl_id decl_var "upper" e meta
| OffsetMultiplier (e1, e2) ->
same_shape decl_id decl_var "offset" e1 meta
@ same_shape decl_id decl_var "multiplier" e2 meta
| LowerUpper (e1, e2) ->
same_shape decl_id decl_var "lower" e1 meta
@ same_shape decl_id decl_var "upper" e2 meta
| Covariance | Correlation | CholeskyCov | CholeskyCorr | Ordered
|PositiveOrdered | Simplex | UnitVector | Identity ->
[]

let check_scale_shape decl_id decl_var meta = function
| Scale.Native -> []
| Offset e -> same_shape decl_id decl_var "offset" e meta
| Multiplier e -> same_shape decl_id decl_var "multiplier" e meta
| OffsetMultiplier (e1, e2) ->
same_shape decl_id decl_var "offset" e1 meta
@ same_shape decl_id decl_var "multiplier" e2 meta

let copy_indices indexed (var : Expr.Typed.t) =
if UnsizedType.is_scalar_type var.meta.type_ then var
else
Expand All @@ -236,10 +238,7 @@ let copy_indices indexed (var : Expr.Typed.t) =

let extract_transform_args var = function
| Transformation.Lower a | Upper a -> [copy_indices var a]
| Offset a -> [copy_indices var a; {a with Expr.Fixed.pattern= Lit (Int, "1")}]
| Multiplier a -> [{a with pattern= Lit (Int, "0")}; copy_indices var a]
| LowerUpper (a1, a2) | OffsetMultiplier (a1, a2) ->
[copy_indices var a1; copy_indices var a2]
| LowerUpper (a1, a2) -> [copy_indices var a1; copy_indices var a2]
| Covariance | Correlation | CholeskyCov | CholeskyCorr | Ordered
|PositiveOrdered | Simplex | UnitVector | Identity ->
[]
Expand Down Expand Up @@ -269,8 +268,6 @@ let param_size transform sizedtype =
match transform with
| Transformation.Identity | Lower _ | Upper _
|LowerUpper (_, _)
|Offset _ | Multiplier _
|OffsetMultiplier (_, _)
|Ordered | PositiveOrdered | UnitVector ->
sizedtype
| Simplex ->
Expand Down Expand Up @@ -346,8 +343,8 @@ let check_sizedtype name =
(ll, Type.Sized st)
| Unsized ut -> ([], Unsized ut)

let trans_decl {transform_action; dadlevel} smeta decl_type transform identifier
initial_value =
let trans_decl {transform_action; dadlevel} smeta decl_type transform scale
identifier initial_value =
let decl_id = identifier.Ast.name in
let rhs = Option.map ~f:trans_expr initial_value in
let size_checks, dt = check_sizedtype identifier.name decl_type in
Expand Down Expand Up @@ -377,6 +374,7 @@ let trans_decl {transform_action; dadlevel} smeta decl_type transform identifier
| Constrain | Unconstrain -> Common.FatalError.fatal_error ()
| Check ->
check_transform_shape decl_id decl_var smeta transform
@ check_scale_shape decl_id decl_var smeta scale
@ check_decl decl_var dt decl_id transform smeta dadlevel
| IgnoreTransform -> [] in
size_checks @ (decl :: rhs_assignment) @ constrain_checks
Expand Down Expand Up @@ -519,9 +517,11 @@ let rec trans_stmt ud_dists (declc : decl_context) (ts : Ast.typed_statement) =
[%message
"Found function definition statement outside of function block"]
| Ast.VarDecl
{decl_type; transformation; identifier; initial_value; is_global= _} ->
{decl_type; scale; transformation; identifier; initial_value; is_global= _}
->
trans_decl declc smeta decl_type
(Transformation.map trans_expr transformation)
(Scale.map trans_expr scale)
identifier initial_value
| Ast.Block stmts -> Block (List.concat_map ~f:trans_stmt stmts) |> swrap
| Ast.Profile (name, stmts) ->
Expand Down Expand Up @@ -636,11 +636,13 @@ let trans_block ud_dists declc block prog =
VarDecl
{ decl_type= Sized type_
; identifier
; scale
; transformation
; initial_value
; is_global= true }
; smeta } ->
let decl_id = identifier.Ast.name in
let scale = Scale.map trans_expr scale in
let transform = Transformation.map trans_expr transformation in
let rhs = Option.map ~f:trans_expr initial_value in
let size, type_ =
Expand Down Expand Up @@ -676,15 +678,18 @@ let trans_block ud_dists declc block prog =
{ out_constrained_st= type_
; out_unconstrained_st= param_size transform type_
; out_block= block
; out_scale= scale
; out_trans= transform } ) in
let stmts =
if Utils.is_user_ident decl_id then
let constrain_checks =
match declc.transform_action with
| Constrain | Unconstrain ->
check_transform_shape decl_id decl_var smeta.loc transform
@ check_scale_shape decl_id decl_var smeta.loc scale
| Check ->
check_transform_shape decl_id decl_var smeta.loc transform
@ check_scale_shape decl_id decl_var smeta.loc scale
@ check_decl decl_var (Sized type_) decl_id transform
smeta.loc declc.dadlevel
| IgnoreTransform -> [] in
Expand Down
2 changes: 2 additions & 0 deletions src/frontend/Canonicalize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,14 @@ let rec parens_stmt ({stmt; smeta} : typed_statement) : typed_statement =
match stmt with
| VarDecl
{ decl_type= d
; scale= s
; transformation= t
; identifier
; initial_value= init
; is_global } ->
VarDecl
{ decl_type= Middle.Type.map no_parens d
; scale= Middle.Scale.map keep_parens s
; transformation= Middle.Transformation.map keep_parens t
; identifier
; initial_value= Option.map ~f:no_parens init
Expand Down
36 changes: 20 additions & 16 deletions src/frontend/Pretty_printing.ml
Original file line number Diff line number Diff line change
Expand Up @@ -321,11 +321,6 @@ let pp_transformation ppf = function
| Upper e -> Fmt.pf ppf "<@[upper=%a@]>" pp_expression e
| LowerUpper (e1, e2) ->
Fmt.pf ppf "<@[lower=%a,@ upper=%a@]>" pp_expression e1 pp_expression e2
| Offset e -> Fmt.pf ppf "<@[offset=%a@]>" pp_expression e
| Multiplier e -> Fmt.pf ppf "<@[multiplier=%a@]>" pp_expression e
| OffsetMultiplier (e1, e2) ->
Fmt.pf ppf "<@[offset=%a,@ multiplier=%a@]>" pp_expression e1
pp_expression e2
| Ordered -> Fmt.pf ppf ""
| PositiveOrdered -> Fmt.pf ppf ""
| Simplex -> Fmt.pf ppf ""
Expand All @@ -335,7 +330,15 @@ let pp_transformation ppf = function
| Correlation -> Fmt.pf ppf ""
| Covariance -> Fmt.pf ppf ""

let pp_transformed_type ppf (pst, trans) =
let pp_scale ppf = function
| Middle.Scale.Native -> Fmt.pf ppf ""
| Offset e -> Fmt.pf ppf "<@[offset=%a@]>" pp_expression e
| Multiplier e -> Fmt.pf ppf "<@[multiplier=%a@]>" pp_expression e
| OffsetMultiplier (e1, e2) ->
Fmt.pf ppf "<@[offset=%a,@ multiplier=%a@]>" pp_expression e1
pp_expression e2

let pp_transformed_type ppf (pst, trans, scale) =
let rec discard_arrays pst =
match pst with
| Middle.Type.Sized st ->
Expand Down Expand Up @@ -374,15 +377,15 @@ let pp_transformed_type ppf (pst, trans) =
| _ -> Fmt.nop in
match trans with
| Middle.Transformation.Identity ->
Fmt.pf ppf "%a%a" unsizedtype_fmt () sizes_fmt ()
| Lower _ | Upper _ | LowerUpper _ | Offset _ | Multiplier _
|OffsetMultiplier _ ->
Fmt.pf ppf "%a%a%a" unsizedtype_fmt () pp_transformation trans sizes_fmt
()
| Ordered -> Fmt.pf ppf "ordered%a" sizes_fmt ()
| PositiveOrdered -> Fmt.pf ppf "positive_ordered%a" sizes_fmt ()
| Simplex -> Fmt.pf ppf "simplex%a" sizes_fmt ()
| UnitVector -> Fmt.pf ppf "unit_vector%a" sizes_fmt ()
Fmt.pf ppf "%a%a%a" unsizedtype_fmt () pp_scale scale sizes_fmt ()
| Lower _ | Upper _ | LowerUpper _ ->
Fmt.pf ppf "%a%a%a%a" unsizedtype_fmt () pp_transformation trans pp_scale
scale sizes_fmt ()
| Ordered -> Fmt.pf ppf "ordered%a%a" pp_scale scale sizes_fmt ()
| PositiveOrdered ->
Fmt.pf ppf "positive_ordered%a%a" pp_scale scale sizes_fmt ()
| Simplex -> Fmt.pf ppf "simplex%a%a" pp_scale scale sizes_fmt ()
| UnitVector -> Fmt.pf ppf "unit_vector%a%a" pp_scale scale sizes_fmt ()
| CholeskyCorr -> Fmt.pf ppf "cholesky_factor_corr%a" cov_sizes_fmt ()
| CholeskyCov -> Fmt.pf ppf "cholesky_factor_cov%a" cov_sizes_fmt ()
| Correlation -> Fmt.pf ppf "corr_matrix%a" cov_sizes_fmt ()
Expand Down Expand Up @@ -483,6 +486,7 @@ and pp_statement ppf ({stmt= s_content; smeta= {loc}} as ss : untyped_statement)
Fmt.pf ppf "}"
| VarDecl
{ decl_type= pst
; scale
; transformation= trans
; identifier= id
; initial_value= init
Expand All @@ -497,7 +501,7 @@ and pp_statement ppf ({stmt= s_content; smeta= {loc}} as ss : untyped_statement)
| Unsized _ -> [] in
with_hbox ppf (fun () ->
Fmt.pf ppf "%a%a %a%a;" pp_array_dims es pp_transformed_type
(pst, trans) pp_identifier id pp_init init )
(pst, trans, scale) pp_identifier id pp_init init )
| FunDef {returntype= rt; funname= id; arguments= args; body= b} -> (
Fmt.pf ppf "%a %a(" pp_returntype rt pp_identifier id ;
let loc_of (_, _, id) = id.id_loc in
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/Semantic_error.ml
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ module StatementError = struct
Fmt.pf ppf
"Bounds of integer variable must be of type int. Found type real."
| ComplexTransform ->
Fmt.pf ppf "Complex types do not support transformations."
Fmt.pf ppf "Complex types do not support bounds or scales."
| TransformedParamsInt ->
Fmt.pf ppf "(Transformed) Parameters cannot be integers."
| MismatchFunDefDecl (name, Some ut) ->
Expand Down
41 changes: 28 additions & 13 deletions src/frontend/Typechecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1117,7 +1117,7 @@ and check_profile loc cf tenv name stmts =
mk_typed_statement ~stmt:(Profile (name, stmts)) ~return_type ~loc

(* variable declarations *)
and verify_valid_transformation_for_type loc is_global sized_ty trans =
and verify_valid_transformation_for_type loc is_global sized_ty trans scale =
let is_real {emeta; _} = emeta.type_ = UReal in
let is_real_transformation =
match trans with
Expand All @@ -1127,10 +1127,13 @@ and verify_valid_transformation_for_type loc is_global sized_ty trans =
| _ -> false in
if is_global && sized_ty = SizedType.SInt && is_real_transformation then
Semantic_error.non_int_bounds loc |> error ;
let is_transformation =
let is_transformed =
match trans with Transformation.Identity -> false | _ -> true in
let is_scaled = match scale with Scale.Native -> false | _ -> true in
if
is_global && SizedType.(inner_type sized_ty = SComplex) && is_transformation
is_global
&& SizedType.(inner_type sized_ty = SComplex)
&& (is_transformed || is_scaled)
then Semantic_error.complex_transform loc |> error

and verify_transformed_param_ty loc cf is_global unsized_ty =
Expand Down Expand Up @@ -1187,10 +1190,6 @@ and check_transformation cf tenv ut trans =
| Upper e -> check e "Upper bound" |> Upper
| LowerUpper (e1, e2) ->
(check e1 "Lower bound", check e2 "Upper bound") |> LowerUpper
| Offset e -> check e "Offset" |> Offset
| Multiplier e -> check e "Multiplier" |> Multiplier
| OffsetMultiplier (e1, e2) ->
(check e1 "Offset", check e2 "Multiplier") |> OffsetMultiplier
| Ordered -> Ordered
| PositiveOrdered -> PositiveOrdered
| Simplex -> Simplex
Expand All @@ -1200,24 +1199,36 @@ and check_transformation cf tenv ut trans =
| Correlation -> Correlation
| Covariance -> Covariance

and check_var_decl loc cf tenv sized_ty trans id init is_global =
and check_scaling cf tenv ut scale =
let check e msg = check_expression_of_scalar_or_type cf tenv ut e msg in
match scale with
| Scale.Native -> Scale.Native
| Offset e -> check e "Offset" |> Offset
| Multiplier e -> check e "Multiplier" |> Multiplier
| OffsetMultiplier (e1, e2) ->
(check e1 "Offset", check e2 "Multiplier") |> OffsetMultiplier

and check_var_decl loc cf tenv sized_ty trans scale id init is_global =
let checked_type =
check_sizedtype {cf with in_toplevel_decl= is_global} tenv sized_ty in
let unsized_type = SizedType.to_unsized checked_type in
let checked_trans = check_transformation cf tenv unsized_type trans in
let checked_scale = check_scaling cf tenv unsized_type scale in
verify_identifier id ;
verify_name_fresh tenv id ~is_udf:false ;
let tenv =
Env.add tenv id.name unsized_type
(`Variable {origin= cf.current_block; global= is_global; readonly= false})
in
let tinit = check_var_decl_initial_value loc cf tenv id init in
verify_valid_transformation_for_type loc is_global checked_type checked_trans ;
verify_valid_transformation_for_type loc is_global checked_type checked_trans
checked_scale ;
verify_transformed_param_ty loc cf is_global unsized_type ;
let stmt =
VarDecl
{ decl_type= Sized checked_type
; transformation= checked_trans
; scale= checked_scale
; identifier= id
; initial_value= tinit
; is_global } in
Expand Down Expand Up @@ -1394,10 +1405,14 @@ and check_statement (cf : context_flags_record) (tenv : Env.t)
[%message "Don't support unsized declarations yet."]
(* these two are special in that they're allowed to change the type environment *)
| VarDecl
{decl_type= Sized st; transformation; identifier; initial_value; is_global}
->
check_var_decl loc cf tenv st transformation identifier initial_value
is_global
{ decl_type= Sized st
; transformation
; scale
; identifier
; initial_value
; is_global } ->
check_var_decl loc cf tenv st transformation scale identifier
initial_value is_global
| FunDef {returntype; funname; arguments; body} ->
check_fundef loc cf tenv returntype funname arguments body

Expand Down
Loading