Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallel reduce_sum #451

Merged
merged 80 commits into from
Apr 13, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
80 commits
Select commit Hold shift + click to select a range
55bf838
hierachiral reduce
rok-cesnovar Jan 30, 2020
818349d
first silly prototype
rok-cesnovar Jan 30, 2020
a45c715
fixed order of arguments
rok-cesnovar Jan 30, 2020
3d2c15f
added normal
rok-cesnovar Jan 30, 2020
b18bd88
promote
rok-cesnovar Jan 31, 2020
4c7064d
basic version
rok-cesnovar Jan 31, 2020
975fad6
parallel_reduce testing
rok-cesnovar Jan 31, 2020
ad67e35
promote test
rok-cesnovar Jan 31, 2020
d842ca9
brute force version one done
rok-cesnovar Jan 31, 2020
977b695
cleanup
rok-cesnovar Jan 31, 2020
ced0170
after cleanup
rok-cesnovar Jan 31, 2020
516d7f2
format code
rok-cesnovar Jan 31, 2020
d406f58
test 0 to 3 real args
rok-cesnovar Jan 31, 2020
c98ff7d
fix the original functor args order
rok-cesnovar Jan 31, 2020
549f72a
dune promote
rok-cesnovar Feb 1, 2020
f3edf9c
slighly nicer creation of duplicate functors
rok-cesnovar Feb 1, 2020
9ae8d04
comment
rok-cesnovar Feb 1, 2020
110c0c6
remove commented out code
rok-cesnovar Feb 1, 2020
b0f1871
fix unwanted change in order
rok-cesnovar Feb 1, 2020
83da2e3
changed to rsfunctor
rok-cesnovar Feb 1, 2020
28ec23a
dune promote
rok-cesnovar Feb 1, 2020
e9fec03
remove duplicated signature, nicer code
rok-cesnovar Feb 1, 2020
5b124be
dune promote
rok-cesnovar Feb 1, 2020
1d3f47f
try to squash the indexing bug
rok-cesnovar Feb 1, 2020
ba4e0ee
remove hpp files
rok-cesnovar Feb 1, 2020
949a696
remove unused expected files
rok-cesnovar Feb 1, 2020
b27fcc8
concat with ^
rok-cesnovar Feb 2, 2020
fa37548
added reduce sum semantic error
rok-cesnovar Feb 4, 2020
6405cd3
Merge remote-tracking branch 'origin/master' into parallel_reduce
rok-cesnovar Feb 4, 2020
77a9f2e
Revert "added reduce sum semantic error"
rok-cesnovar Feb 5, 2020
cc3b774
only add the rsfunctor if first two arguments are ints
rok-cesnovar Feb 5, 2020
8997183
Refactored parallel reduce code for scalability
rybern Feb 6, 2020
5312a59
Merge pull request #1 from rybern/parallel_reduce
rok-cesnovar Feb 7, 2020
4ef08cf
support up to 5 arguments
rok-cesnovar Feb 7, 2020
f712de6
move back to only 3 arguments to speedup dune runtest
rok-cesnovar Feb 7, 2020
f49d710
proper handling of creating rsfunctors for funs used in reduce_sum
rok-cesnovar Feb 7, 2020
34134ae
try adding Any type
rok-cesnovar Feb 7, 2020
8f9f3d7
added for up to 8 arguments
rok-cesnovar Feb 7, 2020
7e5b369
expand test
rok-cesnovar Feb 7, 2020
ff8b373
Merge branch 'master' into parallel_reduce
rok-cesnovar Feb 10, 2020
2874578
remove the code that used Any
rok-cesnovar Feb 10, 2020
d795755
format
rok-cesnovar Feb 10, 2020
fdcf2c5
more cleanup
rok-cesnovar Feb 10, 2020
e166c5b
skip semantic check on reduce_sum
rok-cesnovar Feb 10, 2020
21cd5cc
Merge remote-tracking branch 'bstatcomp/master' into parallel_reduce
rok-cesnovar Mar 11, 2020
0b414a4
fix the bug where functions only used inside functions were not label…
rok-cesnovar Mar 11, 2020
97ed338
cleanup
rok-cesnovar Mar 11, 2020
928118a
add missing slice type
rok-cesnovar Mar 11, 2020
454eca0
cleanup semantic_check function
rok-cesnovar Mar 11, 2020
80e150e
test two deep use of reduce_sum
rok-cesnovar Mar 11, 2020
eef5a5d
dune promote
rok-cesnovar Mar 11, 2020
69586e6
semantic check (with the same error message) done, added tests for ba…
rok-cesnovar Mar 14, 2020
f2b16d3
Merge branch 'master' into parallel_reduce
rok-cesnovar Mar 14, 2020
825e3f3
Merge branch 'master' into parallel_reduce
rok-cesnovar Apr 4, 2020
144210a
add reduce_sum_static
rok-cesnovar Apr 4, 2020
eecc0e5
add tests for reduce_sum_static and a new way of printing errors
rok-cesnovar Apr 5, 2020
871696d
handle return type separately
rok-cesnovar Apr 5, 2020
971a7a7
split the generic error message case
rok-cesnovar Apr 6, 2020
7fd279c
formatting
rok-cesnovar Apr 6, 2020
4af3ae4
cleanup
rok-cesnovar Apr 6, 2020
590e6f6
add error message for additiona arg mismatch
rok-cesnovar Apr 6, 2020
84e2722
add dots to the error message
rok-cesnovar Apr 6, 2020
bf77a47
add more thorough test models
rok-cesnovar Apr 9, 2020
fd5120c
add support for recursion and sliced types
rok-cesnovar Apr 9, 2020
6590fae
format
rok-cesnovar Apr 9, 2020
6114033
update test models
rok-cesnovar Apr 9, 2020
c59bcdb
cleanup matching, change allowed_slice_type name
rok-cesnovar Apr 10, 2020
33b61b0
changed arg match check, changed test, simplified pattern to match
rok-cesnovar Apr 10, 2020
9b4c1bc
format
rok-cesnovar Apr 10, 2020
0076b2a
dont allow user defined functions named reduce_sum or reduce_sum_stat…
rok-cesnovar Apr 10, 2020
9a05825
change the error message
rok-cesnovar Apr 10, 2020
50d3e36
extend the test model so a function is only used in compound functions
rok-cesnovar Apr 10, 2020
5697210
dune promote
rok-cesnovar Apr 10, 2020
f206b2c
add recursive call for all patterns in find_functors
rok-cesnovar Apr 10, 2020
9a0f0e7
nhuures fix for find_functors
rok-cesnovar Apr 10, 2020
f377675
remove the use of string literals
rok-cesnovar Apr 10, 2020
315f7db
move to sets and make sure functor declarations have a trailing newline
rok-cesnovar Apr 10, 2020
cc7b140
cleanup
rok-cesnovar Apr 10, 2020
8ec4f4a
fix off-by-one error in the error message
rok-cesnovar Apr 11, 2020
7a0e718
consistent naming, is_reduce_sum_fn helper, clearer names, cleanup an…
rok-cesnovar Apr 13, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/frontend/Debug_data_generation.ml
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ let rec pp_value_json ppf e =
Fmt.(pf ppf "[@[<hov 1>%a@]]" (list ~sep:comma pp_value_json) l)
| _ -> failwith "This should never happen."


let var_decl_id d =
match d.stmt with
| VarDecl {identifier; _} -> identifier.name
Expand Down
41 changes: 41 additions & 0 deletions src/frontend/Semantic_check.ml
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ let check_fresh_variable_basic id is_nullary_function =
Stan_math_signatures.is_stan_math_function_name id.name
&& ( is_nullary_function
|| Stan_math_signatures.stan_math_returntype id.name [] = None )
|| Stan_math_signatures.is_reduce_sum_fn id.name
then Semantic_error.ident_is_stanmath_name id.id_loc id.name |> error
else
match Symbol_table.look vm id.name with
Expand Down Expand Up @@ -314,6 +315,44 @@ let semantic_check_fn_stan_math ~is_cond_dist ~loc id es =
|> Semantic_error.illtyped_stanlib_fn_app loc id.name
|> Validate.error

let semantic_check_reduce_sum ~is_cond_dist ~loc id es =
let arg_match (x_ad, x_t) y =
UnsizedType.check_of_same_type_mod_conv "" x_t y.emeta.type_
&& UnsizedType.autodifftype_can_convert x_ad y.emeta.ad_level
in
let args_match a b =
List.length a = List.length b && List.for_all2_exn ~f:arg_match a b
in
match es with
| { emeta=
{ type_=
UnsizedType.UFun
( (_, UInt)
:: (_, UInt)
:: ((_, sliced_arg_fun_type) as sliced_arg_fun) :: fun_args
, ReturnType UReal ); _ }; _ }
rok-cesnovar marked this conversation as resolved.
Show resolved Hide resolved
:: sliced :: {emeta= {type_= UInt; _}; _} :: args
when arg_match sliced_arg_fun sliced
&& List.mem Stan_math_signatures.reduce_sum_slice_types
sliced.emeta.type_ ~equal:( = )
&& List.mem Stan_math_signatures.reduce_sum_slice_types
sliced_arg_fun_type ~equal:( = ) ->
if args_match fun_args args then
mk_typed_expression
~expr:(mk_fun_app ~is_cond_dist (StanLib, id, es))
~ad_level:(lub_ad_e es) ~type_:UnsizedType.UReal ~loc
|> Validate.ok
else
Semantic_error.illtyped_reduce_sum loc id.name
(List.map ~f:type_of_expr_typed es)
(sliced_arg_fun :: fun_args)
|> Validate.error
| _ ->
es
|> List.map ~f:type_of_expr_typed
|> Semantic_error.illtyped_reduce_sum_generic loc id.name
|> Validate.error

let fn_kind_from_application id es =
(* We need to check an application here, rather than a mere name of the
function because, technically, user defined functions can shadow
Expand All @@ -332,6 +371,8 @@ let fn_kind_from_application id es =
*)
let semantic_check_fn ~is_cond_dist ~loc id es =
match fn_kind_from_application id es with
| StanLib when Stan_math_signatures.is_reduce_sum_fn id.name ->
semantic_check_reduce_sum ~is_cond_dist ~loc id es
| StanLib -> semantic_check_fn_stan_math ~is_cond_dist ~loc id es
| UserDefined -> semantic_check_fn_normal ~is_cond_dist ~loc id es

Expand Down
52 changes: 52 additions & 0 deletions src/frontend/Semantic_error.ml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ module TypeError = struct
| IllTypedAssignment of
Ast.assignmentoperator * UnsizedType.t * UnsizedType.t
| IllTypedTernaryIf of UnsizedType.t * UnsizedType.t * UnsizedType.t
| IllTypedReduceSum of
string
* UnsizedType.t list
* (UnsizedType.autodifftype * UnsizedType.t) list
| IllTypedReduceSumGeneric of string * UnsizedType.t list
| ReturningFnExpectedNonReturningFound of string
| ReturningFnExpectedNonFnFound of string
| ReturningFnExpectedUndeclaredIdentFound of string
Expand Down Expand Up @@ -88,6 +93,47 @@ module TypeError = struct
Fmt.pf ppf
"Condition in ternary expression must be primitive int; found type=%a"
UnsizedType.pp ut1
| IllTypedReduceSum (name, arg_tys, args) ->
let arg_types = List.map ~f:(fun (_, t) -> t) args in
let first, rest = List.split_n arg_types 1 in
let generate_reduce_sum_sig =
List.concat
[ [ UnsizedType.UFun
( (AutoDiffable, UInt) :: (AutoDiffable, UInt) :: args
, ReturnType UReal ) ]
; first; [UInt]; rest ]
in
Fmt.pf ppf
"Ill-typed arguments supplied to function '%s'. Expected \
arguments:@[<h>%a@]\n\
@[<h>Instead supplied arguments of incompatible type: %a@]"
name
Fmt.(list UnsizedType.pp ~sep:comma)
generate_reduce_sum_sig
Fmt.(list UnsizedType.pp ~sep:comma)
arg_tys
| IllTypedReduceSumGeneric (name, arg_tys) ->
let rec n_commas n = if n = 0 then "" else "," ^ n_commas (n - 1) in
let type_string (a, b, c, d) i =
Fmt.strf "(%a, %a, T[%s], ...) => %a, %a, T[%s], ...\n"
Pretty_printing.pp_unsizedtype a Pretty_printing.pp_unsizedtype b
(n_commas (i - 1))
Pretty_printing.pp_unsizedtype c Pretty_printing.pp_unsizedtype d
(n_commas i)
in
let lines =
List.map
~f:(fun i -> type_string (UInt, UInt, UReal, UInt) i)
Stan_math_signatures.reduce_sum_allowed_dimensionalities
in
Fmt.pf ppf
"Ill-typed arguments supplied to function '%s'. Available arguments:\n\
%sWhere T is any one of int, real, vector, row_vector or \
matrix.@[<h>Instead supplied arguments of incompatible type: %a@]"
name
(String.concat ~sep:"" lines)
Fmt.(list UnsizedType.pp ~sep:comma)
arg_tys
| NotIndexable ut ->
Fmt.pf ppf
"Only expressions of array, matrix, row_vector and vector type may \
Expand Down Expand Up @@ -401,6 +447,12 @@ let illtyped_ternary_if loc predt lt rt =
let returning_fn_expected_nonreturning_found loc name =
TypeError (loc, TypeError.ReturningFnExpectedNonReturningFound name)

let illtyped_reduce_sum loc name arg_tys args =
TypeError (loc, TypeError.IllTypedReduceSum (name, arg_tys, args))

let illtyped_reduce_sum_generic loc name arg_tys =
TypeError (loc, TypeError.IllTypedReduceSumGeneric (name, arg_tys))

let returning_fn_expected_nonfn_found loc name =
TypeError (loc, TypeError.ReturningFnExpectedNonFnFound name)

Expand Down
10 changes: 10 additions & 0 deletions src/frontend/Semantic_error.mli
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ val returning_fn_expected_nonfn_found : Location_span.t -> string -> t
val returning_fn_expected_undeclaredident_found :
Location_span.t -> string -> t

val illtyped_reduce_sum :
Location_span.t
-> string
-> UnsizedType.t list
-> (UnsizedType.autodifftype * UnsizedType.t) list
-> t

val illtyped_reduce_sum_generic :
Location_span.t -> string -> UnsizedType.t list -> t

val nonreturning_fn_expected_returning_found : Location_span.t -> string -> t
val nonreturning_fn_expected_nonfn_found : Location_span.t -> string -> t

Expand Down
34 changes: 26 additions & 8 deletions src/middle/Stan_math_signatures.ml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
(** The signatures of the Stan Math library, which are used for type checking *)
open Core_kernel

(** The "dimensionality" (bad name?) is supposed to help us represent the
vectorized nature of many Stan functions. It allows us to represent when
a function argument can be just a real or matrix, or some common forms of
Expand Down Expand Up @@ -82,6 +81,18 @@ let rec ints_to_real = function
| UArray t -> UArray (ints_to_real t)
| x -> x

let reduce_sum_allowed_dimensionalities = [1; 2; 3; 4; 5; 6; 7]

let reduce_sum_slice_types =
let base_slice_type i =
[ bare_array_type (UnsizedType.UReal, i)
; bare_array_type (UnsizedType.UInt, i)
; bare_array_type (UnsizedType.UMatrix, i)
; bare_array_type (UnsizedType.UVector, i)
; bare_array_type (UnsizedType.URowVector, i) ]
in
List.concat (List.map ~f:base_slice_type reduce_sum_allowed_dimensionalities)

let mk_declarative_sig (fnkinds, name, args) =
let sfxes = function
| Lpmf -> ["_lpmf"; "_log"]
Expand Down Expand Up @@ -128,6 +139,10 @@ let mk_declarative_sig (fnkinds, name, args) =
let full_lpdf = [Lpdf; Rng; Ccdf; Cdf]
let full_lpmf = [Lpmf; Rng; Ccdf; Cdf]

let reduce_sum_functions = ["reduce_sum"; "reduce_sum_static"]
rok-cesnovar marked this conversation as resolved.
Show resolved Hide resolved

let is_reduce_sum_fn f = List.mem ~equal:String.equal reduce_sum_functions f

let distributions =
[ (full_lpmf, "beta_binomial", [DVInt; DVInt; DVReal; DVReal])
; (full_lpdf, "beta", [DVReal; DVReal; DVReal])
Expand Down Expand Up @@ -251,13 +266,16 @@ let stan_math_returntype name args =
UnsizedType.check_compatible_arguments_mod_conv name (snd x) args )
namematches
in
if List.length filteredmatches = 0 then None
(* Return the least return type in case there are multiple options (due to implicit UInt-UReal conversion), where UInt<UReal *)
else
Some
(List.hd_exn
(List.sort ~compare:UnsizedType.compare_returntype
(List.map ~f:fst filteredmatches)))
match name with
| x when is_reduce_sum_fn x -> Some (UnsizedType.ReturnType UReal)
| _ ->
if List.length filteredmatches = 0 then None
(* Return the least return type in case there are multiple options (due to implicit UInt-UReal conversion), where UInt<UReal *)
else
Some
(List.hd_exn
(List.sort ~compare:UnsizedType.compare_returntype
(List.map ~f:fst filteredmatches)))

let is_stan_math_function_name name =
let name = Utils.stdlib_distribution_name name in
Expand Down
12 changes: 11 additions & 1 deletion src/stan_math_backend/Expression_gen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ let fn_renames =
*)
let map_rect_counter = ref 0
let functor_suffix = "_functor__"
let reduce_sum_functor_suffix = "_rsfunctor__"

let functor_suffix_select hof =
if Stan_math_signatures.is_reduce_sum_fn hof then reduce_sum_functor_suffix
else functor_suffix

let rec pp_index ppf = function
| Index.All -> pf ppf "index_omni()"
Expand Down Expand Up @@ -269,7 +274,9 @@ and gen_fun_app ppf fname es =
let convert_hof_vars = function
| {Expr.Fixed.pattern= Var name; meta= {Expr.Typed.Meta.type_= UFun _; _}}
as e ->
{e with pattern= FunApp (StanLib, name ^ functor_suffix, [])}
{ e with
pattern= FunApp (StanLib, name ^ functor_suffix_select fname, [])
}
| e -> e
in
let converted_es = List.map ~f:convert_hof_vars es in
Expand Down Expand Up @@ -300,6 +307,9 @@ and gen_fun_app ppf fname es =
, "integrate_ode_rk45"
, f :: y0 :: t0 :: ts :: theta :: x :: x_int :: tl ) ->
(fname, f :: y0 :: t0 :: ts :: theta :: x :: x_int :: msgs :: tl)
| true, x, {pattern= FunApp (_, f, _); _} :: grainsize :: container :: tl
when Stan_math_signatures.is_reduce_sum_fn x ->
(strf "%s<%s>" fname f, grainsize :: container :: msgs :: tl)
| true, "map_rect", {pattern= FunApp (_, f, _); _} :: tl ->
incr map_rect_counter ;
(strf "%s<%d, %s>" fname !map_rect_counter f, tl @ [msgs])
Expand Down
56 changes: 52 additions & 4 deletions src/stan_math_backend/Stan_math_code_gen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ let pp_template_decorator ppf = function
(* XXX refactor this please - one idea might be to have different functions for
printing user defined distributions vs rngs vs regular functions.
*)
let pp_fun_def ppf Program.({fdrt; fdname; fdargs; fdbody; _}) =
let pp_fun_def ppf Program.({fdrt; fdname; fdargs; fdbody; _})
funs_used_in_reduce_sum =
let is_lp = is_user_lp fdname in
let is_dist = is_user_dist fdname in
let is_rng = String.is_suffix fdname ~suffix:"_rng" in
Expand Down Expand Up @@ -164,6 +165,18 @@ let pp_fun_def ppf Program.({fdrt; fdname; fdargs; fdbody; _}) =
in
pf ppf "%s(@[<hov>%a@]) " name (list ~sep:comma string) arg_strs
in
let pp_sig_rs ppf name =
pp_template_decorator ppf templates ;
pp_returntype ppf fdargs fdrt ;
let first_three, rest = List.split_n args 3 in
let arg_strs =
first_three
@ ["std::ostream* pstream__"]
@ rest
@ mk_extra_args extra_templates extra
in
pf ppf "%s(@[<hov>%a@]) " name (list ~sep:comma string) arg_strs
in
pp_sig ppf fdname ;
match Stmt.Fixed.(fdbody.pattern) with
| Skip -> pf ppf ";@ "
Expand All @@ -173,7 +186,17 @@ let pp_fun_def ppf Program.({fdrt; fdname; fdargs; fdbody; _}) =
functor_suffix pp_sig "operator()" pp_call_str
( fdname
, List.map ~f:(fun (_, name, _) -> name) fdargs @ extra @ ["pstream__"]
)
) ;
if String.Set.mem funs_used_in_reduce_sum fdname then
(* Produces the reduce_sum functors that has the pstream argument
as the third and not last argument *)
let first_two, rest_fdargs = List.split_n fdargs 2 in
pf ppf "@,@,struct %s%s {@,%a const @,{@,return %a;@,}@,};@," fdname
reduce_sum_functor_suffix pp_sig_rs "operator()" pp_call_str
( fdname
, List.map ~f:(fun (_, name, _) -> name ^ " + 1") first_two
@ List.map ~f:(fun (_, name, _) -> name) rest_fdargs
@ extra @ ["pstream__"] )

let version = "// Code generated by %%NAME%% %%VERSION%%"
let includes = "#include <stan/model/model_header.hpp>"
Expand Down Expand Up @@ -657,12 +680,37 @@ let pp_register_map_rect_functors ppf p =
(list ~sep:cut pp_register_functor)
(List.mapi ~f:(fun i f -> (i + 1, f)) functors)

let fun_used_in_reduce_sum p =
let rec find_functors_expr accum Expr.Fixed.({pattern; _}) =
String.Set.union accum
( match pattern with
| FunApp (StanLib, x, {pattern= Var f; _} :: _)
when Stan_math_signatures.is_reduce_sum_fn x ->
String.Set.of_list [f]
| x -> Expr.Fixed.Pattern.fold find_functors_expr accum x )
in
let rec find_functors_stmt accum stmt =
Stmt.Fixed.(
Pattern.fold find_functors_expr find_functors_stmt accum stmt.pattern)
in
Program.fold find_functors_expr find_functors_stmt String.Set.empty p

let pp_prog ppf (p : Program.Typed.t) =
(* First, do some transformations on the MIR itself before we begin printing it.*)
let p, s = Locations.prepare_prog p in
pf ppf "@[<v>@ %s@ %s@ namespace %s {@ %s@ %s@ %a@ %a@ %a@ }@ @]" version
let pp_fun_def_with_rs_list ppf fblock =
pp_fun_def ppf fblock (fun_used_in_reduce_sum p)
in
let reduce_sum_struct_decl =
String.Set.map
~f:(fun x -> "struct " ^ x ^ reduce_sum_functor_suffix ^ ";")
(fun_used_in_reduce_sum p)
in
pf ppf "@[<v>@ %s@ %s@ namespace %s {@ %s@ %s@ %a@ %s@ %a@ %a@ }@ @]" version
includes (namespace p) custom_functions usings Locations.pp_globals s
(list ~sep:cut pp_fun_def) p.functions_block pp_model p ;
(String.concat ~sep:"\n" (String.Set.elements reduce_sum_struct_decl))
(list ~sep:cut pp_fun_def_with_rs_list)
p.functions_block pp_model p ;
pf ppf "@,typedef %s_namespace::%s stan_model;@," p.prog_name p.prog_name ;
pf ppf
{|
Expand Down
4 changes: 3 additions & 1 deletion src/stan_math_backend/Transform_Mir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ let rec switch_expr_to_opencl available_cl_vars (Expr.Fixed.({pattern; _}) as e)
| true -> List.mapi args ~f:(move_cl_args cl_args)
| false -> args
in
let trim_propto f = String.substr_replace_all ~pattern:"_propto_" ~with_:"_" f in
let trim_propto f =
String.substr_replace_all ~pattern:"_propto_" ~with_:"_" f
in
match pattern with
| FunApp (StanLib, f, args) when Map.mem opencl_triggers (trim_propto f) ->
let trigger = Map.find_exn opencl_triggers (trim_propto f) in
Expand Down
13 changes: 13 additions & 0 deletions test/integration/bad/reduce_sum/bad_args_length_mismatch.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
functions {
real my_func(int start, int end, real[] y_slice, real mu, real sigma) {
return normal_lpdf(y_slice | mu, sigma);
}
}

parameters {
real a[5];
}

model {
target += reduce_sum(my_func, a, 1, 0.0);
}
13 changes: 13 additions & 0 deletions test/integration/bad/reduce_sum/bad_args_length_mismatch2.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
functions {
real my_func(int start, int end, real[] y_slice, real mu) {
return normal_lpdf(y_slice | mu, 0.0);
}
}

parameters {
real a[5];
}

model {
target += reduce_sum(my_func, a, 1, 0.0, 0.0);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
functions {
real my_func(int start, int end, real[] y_slice, real mu) {
return normal_lpdf(y_slice | mu, 0.0);
}
}

parameters {
real a[5];
}

model {
target += reduce_sum_static(my_func, a, 1, 0.0, 0.0);
}
Loading