Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sum_to_zero transform #1443

Merged
merged 2 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/analysis_and_optimization/Mir_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@
| PositiveOrdered -> {lower= `Lit 0.; upper= `None}
| UnitVector -> {lower= `Lit (-1.); upper= `Lit 1.}
| CholeskyCorr | CholeskyCov | Correlation | Covariance | Ordered | Offset _
|Multiplier _ | OffsetMultiplier _
|Identity
|Multiplier _ | OffsetMultiplier _ | Identity
|SumToZero

Check warning on line 60 in src/analysis_and_optimization/Mir_utils.ml

View check run for this annotation

Codecov / codecov/patch

src/analysis_and_optimization/Mir_utils.ml#L59-L60

Added lines #L59 - L60 were not covered by tests
(* This is a stub, but,
until we define a distribution which accepts a tuple,
this doesn't matter.
Expand Down
10 changes: 5 additions & 5 deletions src/frontend/Ast_to_Mir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,8 @@
same_shape decl_id decl_var "lower" e1 meta
@ same_shape decl_id decl_var "upper" e2 meta
| Covariance | Correlation | CholeskyCov | CholeskyCorr | Ordered
|PositiveOrdered | Simplex | UnitVector | Identity | TupleTransformation _
|StochasticRow | StochasticColumn ->
|PositiveOrdered | Simplex | UnitVector | SumToZero | Identity
|TupleTransformation _ | StochasticRow | StochasticColumn ->
[]

let copy_indices indexed (var : Expr.Typed.t) =
Expand All @@ -295,8 +295,8 @@
| LowerUpper (a1, a2) | OffsetMultiplier (a1, a2) ->
[copy_indices var a1; copy_indices var a2]
| Covariance | Correlation | CholeskyCov | CholeskyCorr | Ordered
|PositiveOrdered | Simplex | UnitVector | Identity | TupleTransformation _
|StochasticRow | StochasticColumn ->
|PositiveOrdered | Simplex | UnitVector | SumToZero | Identity
|TupleTransformation _ | StochasticRow | StochasticColumn ->

Check warning on line 299 in src/frontend/Ast_to_Mir.ml

View check run for this annotation

Codecov / codecov/patch

src/frontend/Ast_to_Mir.ml#L298-L299

Added lines #L298 - L299 were not covered by tests
[]

let rec param_size transform sizedtype =
Expand Down Expand Up @@ -349,7 +349,7 @@
(SizedType.STuple
(List.map subtypes_transforms ~f:(fun (st, trans) ->
param_size trans st)))
| Simplex ->
| Simplex | SumToZero ->
shrink_eigen (fun d -> Expr.Helpers.(binop d Minus (int 1))) sizedtype
| CholeskyCorr | Correlation -> shrink_eigen k_choose_2 sizedtype
| StochasticRow -> stoch_size Fn.id min_one sizedtype
Expand Down
7 changes: 4 additions & 3 deletions src/frontend/Pretty_printing.ml
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,9 @@
| Multiplier e -> pf ppf "<@[multiplier=%a@]>" pp_expression e
| OffsetMultiplier (e1, e2) ->
pf ppf "<@[offset=%a,@ multiplier=%a@]>" pp_expression e1 pp_expression e2
| Identity | Ordered | PositiveOrdered | Simplex | UnitVector | CholeskyCorr
|CholeskyCov | Correlation | Covariance | TupleTransformation _
|StochasticColumn
| Identity | Ordered | PositiveOrdered | Simplex | UnitVector | SumToZero
|CholeskyCorr | CholeskyCov | Correlation | Covariance
|TupleTransformation _ | StochasticColumn

Check warning on line 325 in src/frontend/Pretty_printing.ml

View check run for this annotation

Codecov / codecov/patch

src/frontend/Pretty_printing.ml#L323-L325

Added lines #L323 - L325 were not covered by tests
|StochasticRow (* tuple transformations are handled in pp_transformed_type *)
->
()
Expand Down Expand Up @@ -360,6 +360,7 @@
| PositiveOrdered -> pf ppf "positive_ordered%a" sizes_fmt ()
| Simplex -> pf ppf "simplex%a" sizes_fmt ()
| UnitVector -> pf ppf "unit_vector%a" sizes_fmt ()
| SumToZero -> pf ppf "sum_to_zero_vector%a" sizes_fmt ()
| CholeskyCorr -> pf ppf "cholesky_factor_corr%a" cov_sizes_fmt ()
| CholeskyCov -> pf ppf "cholesky_factor_cov%a" cov_sizes_fmt ()
| Correlation -> pf ppf "corr_matrix%a" cov_sizes_fmt ()
Expand Down
1 change: 1 addition & 0 deletions src/frontend/Typechecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1647,6 +1647,7 @@ and check_transformation cf tenv ut trans =
| PositiveOrdered -> PositiveOrdered
| Simplex -> Simplex
| UnitVector -> UnitVector
| SumToZero -> SumToZero
| CholeskyCorr -> CholeskyCorr
| CholeskyCov -> CholeskyCov
| Correlation -> Correlation
Expand Down
1 change: 1 addition & 0 deletions src/frontend/lexer.mll
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ rule token = parse
Parser.POSITIVEORDERED }
| "simplex" { lexer_logger "simplex" ; Parser.SIMPLEX }
| "unit_vector" { lexer_logger "unit_vector" ; Parser.UNITVECTOR }
| "sum_to_zero_vector" { lexer_logger "sum_to_zero_vector" ; Parser.SUMTOZERO }
| "cholesky_factor_corr" { lexer_logger "cholesky_factor_corr" ;
Parser.CHOLESKYFACTORCORR }
| "cholesky_factor_cov" { lexer_logger "cholesky_factor_cov" ;
Expand Down
692 changes: 355 additions & 337 deletions src/frontend/parser.messages

Large diffs are not rendered by default.

10 changes: 7 additions & 3 deletions src/frontend/parser.mly
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,10 @@ let nest_unsized_array basic_type n =
ROWVECTOR "row_vector" ARRAY "array" TUPLE "tuple" MATRIX "matrix" ORDERED "ordered"
COMPLEXVECTOR "complex_vector" COMPLEXROWVECTOR "complex_row_vector"
POSITIVEORDERED "positive_ordered" SIMPLEX "simplex" UNITVECTOR "unit_vector"
CHOLESKYFACTORCORR "cholesky_factor_corr" CHOLESKYFACTORCOV "cholesky_factor_cov"
CORRMATRIX "corr_matrix" COVMATRIX "cov_matrix" COMPLEXMATRIX "complex_matrix"
STOCHASTICCOLUMNMATRIX "column_stochastic_matrix" STOCHASTICROWMATRIX "row_stochastic_matrix"
SUMTOZERO "sum_to_zero_vector" CHOLESKYFACTORCORR "cholesky_factor_corr"
CHOLESKYFACTORCOV "cholesky_factor_cov" CORRMATRIX "corr_matrix" COVMATRIX "cov_matrix"
COMPLEXMATRIX "complex_matrix" STOCHASTICCOLUMNMATRIX "column_stochastic_matrix"
STOCHASTICROWMATRIX "row_stochastic_matrix"
%token LOWER "lower" UPPER "upper" OFFSET "offset" MULTIPLIER "multiplier"
%token JACOBIAN "jacobian"
%token <string> INTNUMERAL "24"
Expand Down Expand Up @@ -254,6 +255,7 @@ reserved_word:
| POSITIVEORDERED { "positive_ordered", $loc, true }
| SIMPLEX { "simplex", $loc, true }
| UNITVECTOR { "unit_vector", $loc, true }
| SUMTOZERO { "sum_to_zero_vector", $loc, true }
| CHOLESKYFACTORCORR { "cholesky_factor_corr", $loc, true }
| CHOLESKYFACTORCOV { "cholesky_factor_cov", $loc, true }
| CORRMATRIX { "corr_matrix", $loc, true }
Expand Down Expand Up @@ -520,6 +522,8 @@ top_var_type:
{ grammar_logger "SIMPLEX_top_var_type" ; (SVector (AoS, e), Simplex) }
| UNITVECTOR LBRACK e=expression RBRACK
{ grammar_logger "UNITVECTOR_top_var_type" ; (SVector (AoS, e), UnitVector) }
| SUMTOZERO LBRACK e=expression RBRACK
{ grammar_logger "SUMTOZERO_top_var_type" ; (SVector (AoS, e), SumToZero) }
| CHOLESKYFACTORCORR LBRACK e=expression RBRACK
{
grammar_logger "CHOLESKYFACTORCORR_top_var_type" ;
Expand Down
1 change: 1 addition & 0 deletions src/middle/Transformation.ml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
| PositiveOrdered
| Simplex
| UnitVector
| SumToZero

Check warning on line 19 in src/middle/Transformation.ml

View check run for this annotation

Codecov / codecov/patch

src/middle/Transformation.ml#L19

Added line #L19 was not covered by tests
| CholeskyCorr
| CholeskyCov
| Correlation
Expand Down
1 change: 1 addition & 0 deletions src/stan_math_backend/Lower_expr.ml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ let constraint_to_string = function
| PositiveOrdered -> Some "positive_ordered"
| Simplex -> Some "simplex"
| UnitVector -> Some "unit_vector"
| SumToZero -> Some "sum_to_zero"
| CholeskyCorr -> Some "cholesky_factor_corr"
| CholeskyCov -> Some "cholesky_factor_cov"
| Correlation -> Some "corr_matrix"
Expand Down
10 changes: 10 additions & 0 deletions test/integration/bad/reserved/stanc.expected
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,16 @@ Semantic error in 'struct.stan', line 2, column 7 to column 13:
-------------------------------------------------

Identifier 'struct' clashes with reserved keyword.
$ ../../../../../install/default/bin/stanc sum_to_zero.stan
Syntax error in 'sum_to_zero.stan', line 2, column 7 to column 25, parsing error:
-------------------------------------------------
1: data {
2: real sum_to_zero_vector;
^
3: }
-------------------------------------------------

Expected a new identifier but found reserved keyword 'sum_to_zero_vector'.
$ ../../../../../install/default/bin/stanc then.stan
Semantic error in 'then.stan', line 2, column 7 to column 11:
-------------------------------------------------
Expand Down
3 changes: 3 additions & 0 deletions test/integration/bad/reserved/sum_to_zero.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
data {
real sum_to_zero_vector;
}
6 changes: 3 additions & 3 deletions test/integration/bad/stanc.expected
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Syntax error in 'array-expr-decl-bad2.stan', line 2, column 2 to column 10, pars

Invalid type in declaration. Valid types:
int, real, vector, row_vector, matrix,
unit_vector, simplex, ordered, positive_ordered,
unit_vector, simplex, sum_to_zero_vector, ordered, positive_ordered,
corr_matrix, cov_matrix, cholesky_factor_corr, cholesky_factor_cov,
row_stochastic_matrix, column_stochastic_matrix, tuple(...)
optionally preceded by a single array[...]
Expand All @@ -35,7 +35,7 @@ Syntax error in 'array-expr-decl-bad3.stan', line 2, column 2 to column 6, parsi

Invalid type in declaration. Valid types:
int, real, vector, row_vector, matrix,
unit_vector, simplex, ordered, positive_ordered,
unit_vector, simplex, sum_to_zero_vector, ordered, positive_ordered,
corr_matrix, cov_matrix, cholesky_factor_corr, cholesky_factor_cov,
row_stochastic_matrix, column_stochastic_matrix, tuple(...)
optionally preceded by a single array[...]
Expand Down Expand Up @@ -782,7 +782,7 @@ Syntax error in 'err-decl-double.stan', line 2, column 2 to column 8, parsing er

Invalid type in declaration. Valid types:
int, real, vector, row_vector, matrix,
unit_vector, simplex, ordered, positive_ordered,
unit_vector, simplex, sum_to_zero_vector, ordered, positive_ordered,
corr_matrix, cov_matrix, cholesky_factor_corr, cholesky_factor_cov,
row_stochastic_matrix, column_stochastic_matrix, tuple(...)
optionally preceded by a single array[...]
Expand Down
Loading