From bbe8b5109752c85c49418ca83884a10812d90679 Mon Sep 17 00:00:00 2001 From: Rudi Grinberg Date: Thu, 30 Nov 2023 18:01:57 -0600 Subject: [PATCH] chore: use janestreet style and update ocamlformat (#33) Signed-off-by: Rudi Grinberg --- .ocamlformat | 13 +- fiber-lwt/fiber_lwt.ml | 22 +- fiber/bench/fiber_bench.ml | 49 ++-- fiber/src/cancel.ml | 57 ++-- fiber/src/core.ml | 190 +++++++------ fiber/src/fiber.ml | 2 +- fiber/src/fiber.mli | 50 ++-- fiber/src/mutex.ml | 6 +- fiber/src/mvar.ml | 35 +-- fiber/src/pool.ml | 39 +-- fiber/src/scheduler.ml | 96 ++++--- fiber/src/stream.ml | 67 +++-- fiber/src/svar.ml | 11 +- fiber/src/throttle.ml | 19 +- fiber/test/fiber_scheduler.ml | 13 +- fiber/test/fiber_tests.ml | 508 +++++++++++++++++++--------------- fiber/test/test_scheduler.ml | 20 +- fiber/test/test_scheduler.mli | 3 - flake.nix | 2 +- 19 files changed, 646 insertions(+), 556 deletions(-) diff --git a/.ocamlformat b/.ocamlformat index b5dbb66..ac1a5cd 100644 --- a/.ocamlformat +++ b/.ocamlformat @@ -1,12 +1,3 @@ -version=0.24.1 -profile=conventional +version=0.26.1 +profile=janestreet ocaml-version=4.08.0 -break-separators=before -dock-collection-brackets=false -doc-comments=before -let-and=sparse -type-decl=sparse -cases-exp-indent=2 -break-cases=fit-or-vertical -parse-docstrings=true -module-item-spacing=sparse diff --git a/fiber-lwt/fiber_lwt.ml b/fiber-lwt/fiber_lwt.ml index e6bd87c..a30b5d8 100644 --- a/fiber-lwt/fiber_lwt.ml +++ b/fiber-lwt/fiber_lwt.ml @@ -8,23 +8,25 @@ module Fiber_inside_lwt = struct | Fiber.Scheduler.Done x -> Lwt.return x | Fiber.Scheduler.Stalled stalled -> Lwt.bind (Lwt_stream.next fills) (fun fill -> - loop (Fiber.Scheduler.advance stalled [ fill ])) + loop (Fiber.Scheduler.advance stalled [ fill ])) in loop (Fiber.Scheduler.start fiber) + ;; let callback_to_lwt f = Fiber.bind (Fiber.Var.get key) ~f:(function - | None -> - failwith "Fiber_lwt.Fiber_inside_lwt.run_lwt: called outside [run]" + | None -> failwith "Fiber_lwt.Fiber_inside_lwt.run_lwt: called outside [run]" | Some push_fill -> let ivar = Fiber.Ivar.create () in Lwt.async (fun () -> - Lwt.bind - (Lwt.try_bind f - (fun x -> Lwt.return (Ok x)) - (fun exn -> Lwt.return (Error exn))) - (fun x -> - push_fill (Some (Fiber.Fill (ivar, x))); - Lwt.return_unit)); + Lwt.bind + (Lwt.try_bind + f + (fun x -> Lwt.return (Ok x)) + (fun exn -> Lwt.return (Error exn))) + (fun x -> + push_fill (Some (Fiber.Fill (ivar, x))); + Lwt.return_unit)); Fiber.Ivar.read ivar) + ;; end diff --git a/fiber/bench/fiber_bench.ml b/fiber/bench/fiber_bench.ml index 21af460..dbf3c18 100644 --- a/fiber/bench/fiber_bench.ml +++ b/fiber/bench/fiber_bench.ml @@ -4,7 +4,7 @@ open Fiber.O let n = 1000 let%bench_fun "bind" = - fun () -> + fun () -> Fiber.run ~iter:(fun () -> assert false) (let rec loop = function @@ -12,14 +12,16 @@ let%bench_fun "bind" = | n -> Fiber.return () >>= fun () -> loop (n - 1) in loop n) +;; let%bench_fun "create ivar and immediately read" = - fun () -> + fun () -> let ivar = Fiber.Ivar.create () in Fiber.run ~iter:(fun () -> [ Fiber.Fill (ivar, ()) ]) (Fiber.Ivar.read ivar) +;; let%bench_fun "ivar" = - fun () -> + fun () -> let ivar = ref (Fiber.Ivar.create ()) in Fiber.run ~iter:(fun () -> [ Fiber.Fill (!ivar, ()) ]) @@ -31,6 +33,7 @@ let%bench_fun "ivar" = loop (n - 1) in loop n) +;; let%bench_fun "Var.set" = let var = Fiber.Var.create () in @@ -42,6 +45,7 @@ let%bench_fun "Var.set" = | n -> Fiber.Var.set var n (fun () -> loop (n - 1)) in loop n) +;; let%bench_fun "Var.get" = let var = Fiber.Var.create () in @@ -55,12 +59,12 @@ let%bench_fun "Var.get" = loop (n - 1) in Fiber.Var.set var 0 (fun () -> loop n)) +;; let exns = List.init n ~f:(fun _ -> - { Exn_with_backtrace.exn = Exit - ; backtrace = Printexc.get_raw_backtrace () - }) + { Exn_with_backtrace.exn = Exit; backtrace = Printexc.get_raw_backtrace () }) +;; let%bench "catching exceptions" = Fiber.run @@ -70,6 +74,7 @@ let%bench "catching exceptions" = ~on_error:(fun _ -> Fiber.return ()) (fun () -> Fiber.reraise_all exns)) |> ignore +;; let%bench "installing handlers" = Fiber.run @@ -87,23 +92,23 @@ let%bench "installing handlers" = in loop n) |> ignore +;; let%bench_fun "Fiber.fork_and_join" = - fun () -> + fun () -> Fiber.run ~iter:(fun () -> assert false) (let rec loop = function | 0 -> Fiber.return () | n -> - let+ (), () = - Fiber.fork_and_join Fiber.return (fun () -> loop (n - 1)) - in + let+ (), () = Fiber.fork_and_join Fiber.return (fun () -> loop (n - 1)) in () in loop 1000) +;; let%bench_fun "Fiber.fork_and_join_unit" = - fun () -> + fun () -> Fiber.run ~iter:(fun () -> assert false) (let rec loop = function @@ -111,6 +116,7 @@ let%bench_fun "Fiber.fork_and_join_unit" = | n -> Fiber.fork_and_join_unit Fiber.return (fun () -> loop (n - 1)) in loop 1000) +;; let%bench_fun "Fiber.parallel_iter" = let l = List.init 1000 ~f:Fun.id in @@ -118,14 +124,14 @@ let%bench_fun "Fiber.parallel_iter" = Fiber.run ~iter:(fun () -> assert false) (Fiber.parallel_iter l ~f:(fun _ -> Fiber.return ())) +;; let%bench_fun "Fiber.parallel_map" = let l = List.init 1000 ~f:Fun.id in fun () -> - Fiber.run - ~iter:(fun () -> assert false) - (Fiber.parallel_map l ~f:Fiber.return) + Fiber.run ~iter:(fun () -> assert false) (Fiber.parallel_map l ~f:Fiber.return) |> ignore +;; let pool_run tasks = Fiber.run @@ -136,36 +142,37 @@ let pool_run tasks = (fun () -> let* () = Fiber.parallel_iter tasks ~f:(fun (_ : int) -> - Fiber.Pool.task pool ~f:Fiber.return) + Fiber.Pool.task pool ~f:Fiber.return) in Fiber.Pool.close pool)) |> ignore +;; (* some pools are used to run many fibers *) let%bench_fun "Fiber.Pool.run - big" = let l = List.init 1000 ~f:Fun.id in fun () -> pool_run l +;; (* other pools are one-off transients that are created and discarded *) let%bench_fun "Fiber.Pool.run - small" = let l = List.init 2 ~f:Fun.id in fun () -> pool_run l +;; module M = Fiber.Make_parallel_map (Int.Map) -let map = - List.init 1000 ~f:Fun.id - |> List.map ~f:(fun i -> (i, i)) - |> Int.Map.of_list_exn +let map = List.init 1000 ~f:Fun.id |> List.map ~f:(fun i -> i, i) |> Int.Map.of_list_exn let%bench "Fiber.parallel_iter_seq" = Fiber.run ~iter:(fun () -> assert false) - (Fiber.parallel_iter_seq (Int.Map.to_seq map) ~f:(fun (_, _) -> - Fiber.return ())) + (Fiber.parallel_iter_seq (Int.Map.to_seq map) ~f:(fun (_, _) -> Fiber.return ())) +;; let%bench "Fiber.Map.parallel_map" = Fiber.run ~iter:(fun () -> assert false) (M.parallel_map map ~f:(fun _ x -> Fiber.return x)) |> ignore +;; diff --git a/fiber/src/cancel.ml b/fiber/src/cancel.ml index 5fbef0b..98eabd0 100644 --- a/fiber/src/cancel.ml +++ b/fiber/src/cancel.ml @@ -28,19 +28,22 @@ let rec invoke_handlers = function let* () = Ivar.fill ivar (Cancelled ()) in invoke_handlers next | End_of_handlers -> return () +;; let fire t = of_thunk (fun () -> - match t.state with - | Cancelled -> return () - | Not_cancelled { handlers } -> - t.state <- Cancelled; - invoke_handlers handlers) + match t.state with + | Cancelled -> return () + | Not_cancelled { handlers } -> + t.state <- Cancelled; + invoke_handlers handlers) +;; let rec fills_of_handlers acc = function | Handler { ivar; next; prev = _ } -> fills_of_handlers (Scheduler.Fill (ivar, Cancelled ()) :: acc) next | End_of_handlers -> List.rev acc +;; let fire' t = match t.state with @@ -48,48 +51,52 @@ let fire' t = | Not_cancelled { handlers } -> t.state <- Cancelled; fills_of_handlers [] handlers +;; let fired t = match t.state with | Cancelled -> true | Not_cancelled _ -> false +;; let with_handler t f ~on_cancel = match t.state with | Cancelled -> let+ x, y = fork_and_join f on_cancel in - (x, Cancelled y) + x, Cancelled y | Not_cancelled h -> let ivar = Ivar.create () in let node = Handler { ivar; next = h.handlers; prev = End_of_handlers } in (match h.handlers with - | End_of_handlers -> () - | Handler first -> first.prev <- node); + | End_of_handlers -> () + | Handler first -> first.prev <- node); h.handlers <- node; fork_and_join (fun () -> let* y = f () in match t.state with | Cancelled -> return y - | Not_cancelled h -> ( - match node with - | End_of_handlers -> - (* We could avoid this [assert false] with GADT sorcery given that - we created [node] just above and we know for sure it is the - [Handler _] case, but it's not worth the code complexity. *) - assert false - | Handler node -> - (match node.prev with - | End_of_handlers -> h.handlers <- node.next - | Handler prev -> prev.next <- node.next); - (match node.next with - | End_of_handlers -> () - | Handler next -> next.prev <- node.prev); - let+ () = Ivar.fill ivar Not_cancelled in - y)) + | Not_cancelled h -> + (match node with + | End_of_handlers -> + (* We could avoid this [assert false] with GADT sorcery given that + we created [node] just above and we know for sure it is the + [Handler _] case, but it's not worth the code complexity. *) + assert false + | Handler node -> + (match node.prev with + | End_of_handlers -> h.handlers <- node.next + | Handler prev -> prev.next <- node.next); + (match node.next with + | End_of_handlers -> () + | Handler next -> next.prev <- node.prev); + let+ () = Ivar.fill ivar Not_cancelled in + y)) (fun () -> - Ivar.read ivar >>= function + Ivar.read ivar + >>= function | Cancelled () -> let+ x = on_cancel () in Cancelled x | Not_cancelled -> return Not_cancelled) +;; diff --git a/fiber/src/core.ml b/fiber/src/core.ml index 917be12..9943c67 100644 --- a/fiber/src/core.ml +++ b/fiber/src/core.ml @@ -10,9 +10,7 @@ and eff = | Get_var : 'a Univ_map.Key.t * ('a option -> eff) -> eff | Set_var : 'a Univ_map.Key.t * 'a * (unit -> eff) -> eff | Unset_var : 'a Univ_map.Key.t * (unit -> eff) -> eff - | With_error_handler : - (Exn_with_backtrace.t -> Nothing.t t) * (unit -> eff) - -> eff + | With_error_handler : (Exn_with_backtrace.t -> Nothing.t t) * (unit -> eff) -> eff | Unwind : ('a -> eff) * 'a -> eff | Map_reduce_errors : (module Monoid with type t = 'a) @@ -66,71 +64,74 @@ and 'a k = } let return x k = k x - let bind t ~f k = t (fun x -> f x k) - let map t ~f k = t (fun x -> k (f x)) let with_error_handler f ~on_error k = With_error_handler (on_error, fun () -> f () (fun x -> Unwind (k, x))) +;; let map_reduce_errors m ~on_error f k = Map_reduce_errors (m, on_error, (fun () -> f () (fun x -> Unwind_map_reduce (k, Ok x))), k) +;; let suspend f k = Suspend (f, k) - let resume suspended x k = Resume (suspended, x, k) - let end_of_fiber = End_of_fiber () - let never _k = Never () let apply f x = - try f x - with exn -> + try f x with + | exn -> let exn = Exn_with_backtrace.capture exn in Reraise exn +;; let apply2 f x y = - try f x y - with exn -> + try f x y with + | exn -> let exn = Exn_with_backtrace.capture exn in Reraise exn +;; let[@inline always] fork a b = match apply a () with | End_of_fiber () -> b () | eff -> Fork (eff, b) +;; let rec nfork x l f = match l with | [] -> f x - | y :: l -> ( + | y :: l -> (* Manually inline [fork] because the compiler is unfortunately not getting rid of the closures. *) - match apply f x with - | End_of_fiber () -> nfork y l f - | eff -> Fork (eff, fun () -> nfork y l f)) + (match apply f x with + | End_of_fiber () -> nfork y l f + | eff -> Fork (eff, fun () -> nfork y l f)) +;; let rec nforki i x l f = match l with | [] -> f i x - | y :: l -> ( - match apply2 f i x with - | End_of_fiber () -> nforki (i + 1) y l f - | eff -> Fork (eff, fun () -> nforki (i + 1) y l f)) + | y :: l -> + (match apply2 f i x with + | End_of_fiber () -> nforki (i + 1) y l f + | eff -> Fork (eff, fun () -> nforki (i + 1) y l f)) +;; let nforki x l f = nforki 0 x l f let rec nfork_seq left_over x (seq : _ Seq.t) f = match seq () with | Nil -> f x - | Cons (y, seq) -> ( + | Cons (y, seq) -> incr left_over; - match apply f x with - | End_of_fiber () -> nfork_seq left_over y seq f - | eff -> Fork (eff, fun () -> nfork_seq left_over y seq f)) + (match apply f x with + | End_of_fiber () -> nfork_seq left_over y seq f + | eff -> Fork (eff, fun () -> nfork_seq left_over y seq f)) +;; let parallel_iter_seq (seq : _ Seq.t) ~f k = match seq () with @@ -139,10 +140,11 @@ let parallel_iter_seq (seq : _ Seq.t) ~f k = let left_over = ref 1 in let f x = f x (fun () -> - decr left_over; - if !left_over = 0 then k () else end_of_fiber) + decr left_over; + if !left_over = 0 then k () else end_of_fiber) in nfork_seq left_over x seq f +;; type ('a, 'b) fork_and_join_state = | Nothing_yet @@ -169,17 +171,18 @@ let fork_and_join fa fb k = match apply2 fa () ka with | End_of_fiber () -> fb () kb | eff -> Fork (eff, fun () -> fb () kb) +;; let fork_and_join_unit fa fb k = let state = ref Nothing_yet in match apply2 fa () (fun () -> - match !state with - | Nothing_yet -> - state := Got_a (); - end_of_fiber - | Got_a _ -> assert false - | Got_b b -> k b) + match !state with + | Nothing_yet -> + state := Got_a (); + end_of_fiber + | Got_a _ -> assert false + | Got_b b -> k b) with | End_of_fiber () -> fb () k | eff -> @@ -187,17 +190,19 @@ let fork_and_join_unit fa fb k = ( eff , fun () -> fb () (fun b -> - match !state with - | Nothing_yet -> - state := Got_b b; - end_of_fiber - | Got_a () -> k b - | Got_b _ -> assert false) ) + match !state with + | Nothing_yet -> + state := Got_b b; + end_of_fiber + | Got_a () -> k b + | Got_b _ -> assert false) ) +;; let rec length_and_rev l len acc = match l with - | [] -> (len, acc) + | [] -> len, acc | x :: l -> length_and_rev l (len + 1) (x :: acc) +;; let length_and_rev l = length_and_rev l 0 [] @@ -206,21 +211,21 @@ let reraise_all l _k = | [] -> Never () | [ exn ] -> Exn_with_backtrace.reraise exn | _ -> Reraise_all l +;; module Ivar = struct type 'a t = 'a ivar let create () = { state = Empty } - let read t k = Read_ivar (t, k) - let fill t x k = Fill_ivar (t, x, k) let peek t k = k (match t.state with - | Empty | Empty_with_readers _ -> None - | Full x -> Some x) + | Empty | Empty_with_readers _ -> None + | Full x -> Some x) + ;; end module Var = struct @@ -232,11 +237,10 @@ module Var = struct map (get var) ~f:(function | None -> failwith "Fiber.Var.get_exn" | Some value -> value) + ;; let set var x f k = Set_var (var, x, fun () -> f () (fun x -> Unwind (k, x))) - let unset var f k = Unset_var (var, fun () -> f () (fun x -> Unwind (k, x))) - let create () = create ~name:"var" (fun _ -> Dyn.string "var") end @@ -244,17 +248,11 @@ let of_thunk f k = f () k module O = struct let ( >>> ) a b k = a (fun () -> b k) - let ( >>= ) t f k = t (fun x -> f x k) - let ( >>| ) t f k = t (fun x -> k (f x)) - let ( let+ ) = ( >>| ) - let ( let* ) = ( >>= ) - let ( and* ) a b = fork_and_join (fun () -> a) (fun () -> b) - let ( and+ ) = ( and* ) end @@ -264,6 +262,7 @@ let both a b = let* x = a in let* y = b in return (x, y) +;; let sequential_map l ~f = let rec loop l acc = @@ -274,6 +273,7 @@ let sequential_map l ~f = loop l (x :: acc) in loop l [] +;; let sequential_iter l ~f = let rec loop l = @@ -284,6 +284,7 @@ let sequential_iter l ~f = loop l in loop l +;; let parallel_iter l ~f k = match l with @@ -294,10 +295,11 @@ let parallel_iter l ~f k = let left_over = ref len in let f x = f x (fun () -> - decr left_over; - if !left_over = 0 then k () else end_of_fiber) + decr left_over; + if !left_over = 0 then k () else end_of_fiber) in nfork x l f +;; let parallel_array_of_list_map' x l ~f k = let len = List.length l + 1 in @@ -305,37 +307,38 @@ let parallel_array_of_list_map' x l ~f k = let results = ref [||] in let f i x = f x (fun y -> - let a = - match !results with - | [||] -> - let a = Array.make len y in - results := a; - a - | a -> - a.(i) <- y; - a - in - decr left_over; - if !left_over = 0 then k a else end_of_fiber) + let a = + match !results with + | [||] -> + let a = Array.make len y in + results := a; + a + | a -> + a.(i) <- y; + a + in + decr left_over; + if !left_over = 0 then k a else end_of_fiber) in nforki x l f +;; let parallel_array_of_list_map l ~f k = match l with | [] -> k [||] | [ x ] -> f x (fun x -> k [| x |]) | x :: l -> parallel_array_of_list_map' x l ~f k +;; let parallel_map l ~f k = match l with | [] -> k [] | [ x ] -> f x (fun x -> k [ x ]) | x :: l -> parallel_array_of_list_map' x l ~f (fun a -> k (Array.to_list a)) +;; let all = sequential_map ~f:Fun.id - let all_concurrently = parallel_map ~f:Fun.id - let all_concurrently_unit l = parallel_iter l ~f:Fun.id let rec sequential_iter_seq (seq : _ Seq.t) ~f = @@ -344,45 +347,48 @@ let rec sequential_iter_seq (seq : _ Seq.t) ~f = | Cons (x, seq) -> let* () = f x in sequential_iter_seq seq ~f - -let parallel_iter_set (type a s) - (module S : Set.S with type elt = a and type t = s) set ~(f : a -> unit t) = +;; + +let parallel_iter_set + (type a s) + (module S : Set.S with type elt = a and type t = s) + set + ~(f : a -> unit t) + = parallel_iter_seq (S.to_seq set) ~f +;; module Make_parallel_map (S : sig - type 'a t - - type key - - val empty : _ t - - val is_empty : _ t -> bool - - val to_list : 'a t -> (key * 'a) list - - val mapi : 'a t -> f:(key -> 'a -> 'b) -> 'b t -end) = + type 'a t + type key + + val empty : _ t + val is_empty : _ t -> bool + val to_list : 'a t -> (key * 'a) list + val mapi : 'a t -> f:(key -> 'a -> 'b) -> 'b t + end) = struct let parallel_map t ~f = - if S.is_empty t then return S.empty + if S.is_empty t + then return S.empty else - let+ a = - parallel_array_of_list_map (S.to_list t) ~f:(fun (k, v) -> f k v) - in + let+ a = parallel_array_of_list_map (S.to_list t) ~f:(fun (k, v) -> f k v) in let pos = ref 0 in S.mapi t ~f:(fun _ _ -> - let i = !pos in - pos := i + 1; - a.(i)) + let i = !pos in + pos := i + 1; + a.(i)) + ;; end [@@inline always] let rec repeat_while : 'a. f:('a -> 'a option t) -> init:'a -> unit t = - fun ~f ~init -> + fun ~f ~init -> let* result = f init in match result with | None -> return () | Some init -> repeat_while ~f ~init +;; module Exns = Monoid.Appendable_list (Exn_with_backtrace) @@ -396,12 +402,13 @@ let collect_errors f = match res with | Ok x -> Ok x | Error l -> Error (Appendable_list.to_list l) +;; let finalize f ~finally = let* res1 = collect_errors f in let* res2 = collect_errors finally in let res = - match (res1, res2) with + match res1, res2 with | Ok x, Ok () -> Ok x | Error l, Ok _ | Ok _, Error l -> Error l | Error l1, Error l2 -> Error (l1 @ l2) @@ -409,3 +416,4 @@ let finalize f ~finally = match res with | Ok x -> return x | Error l -> reraise_all l +;; diff --git a/fiber/src/fiber.ml b/fiber/src/fiber.ml index e5ac1c4..de590c9 100644 --- a/fiber/src/fiber.ml +++ b/fiber/src/fiber.ml @@ -15,6 +15,7 @@ let run = | Stalled w -> loop ~iter (Scheduler.advance w (iter ())) in fun t ~iter -> loop ~iter (Scheduler.start t) +;; type fill = Scheduler.fill = Fill : 'a ivar * 'a -> fill @@ -22,6 +23,5 @@ module Expert = struct type nonrec 'a k = 'a k let suspend f k = suspend f k - let resume a x k = resume a x k end diff --git a/fiber/src/fiber.mli b/fiber/src/fiber.mli index 93b26f7..bd4dbca 100644 --- a/fiber/src/fiber.mli +++ b/fiber/src/fiber.mli @@ -41,7 +41,6 @@ module O : sig val ( >>| ) : 'a t -> ('a -> 'b) -> 'b t val ( let* ) : 'a t -> ('a -> 'b t) -> 'b t - val ( let+ ) : 'a t -> ('a -> 'b) -> 'b t (** Similar to [fork_and_join] *) @@ -51,7 +50,6 @@ module O : sig end val map : 'a t -> f:('a -> 'b) -> 'b t - val bind : 'a t -> f:('a -> 'b t) -> 'b t (** {1 Joining} *) @@ -66,7 +64,6 @@ val both : 'a t -> 'b t -> ('a * 'b) t val all : 'a t list -> 'a list t val sequential_map : 'a list -> f:('a -> 'b t) -> 'b list t - val sequential_iter : 'a list -> f:('a -> unit t) -> unit t (** {1 Forking + joining} *) @@ -102,23 +99,18 @@ val all_concurrently_unit : unit t list -> unit t val parallel_iter : 'a list -> f:('a -> unit t) -> unit t val parallel_iter_seq : 'a Seq.t -> f:('a -> unit t) -> unit t - val sequential_iter_seq : 'a Seq.t -> f:('a -> unit t) -> unit t (** Provide an efficient parallel map function for maps. *) module Make_parallel_map (S : sig - type 'a t - - type key - - val empty : _ t - - val is_empty : _ t -> bool - - val to_list : 'a t -> (key * 'a) list + type 'a t + type key - val mapi : 'a t -> f:(key -> 'a -> 'b) -> 'b t -end) : sig + val empty : _ t + val is_empty : _ t -> bool + val to_list : 'a t -> (key * 'a) list + val mapi : 'a t -> f:(key -> 'a -> 'b) -> 'b t + end) : sig val parallel_map : 'a S.t -> f:(S.key -> 'a -> 'b t) -> 'b S.t t end @@ -159,19 +151,20 @@ end It is guaranteed that after the fiber has returned a value, [on_error] will never be called. *) -val with_error_handler : - (unit -> 'a t) -> on_error:(Exn_with_backtrace.t -> Nothing.t t) -> 'a t +val with_error_handler + : (unit -> 'a t) + -> on_error:(Exn_with_backtrace.t -> Nothing.t t) + -> 'a t -val map_reduce_errors : - (module Monoid with type t = 'a) +val map_reduce_errors + : (module Monoid with type t = 'a) -> on_error:(Exn_with_backtrace.t -> 'a t) -> (unit -> 'b t) -> ('b, 'a) result t (** [collect_errors f] is: - [fold_errors f ~init:\[\] ~on_error:(fun e l -> e :: l)] *) -val collect_errors : - (unit -> 'a t) -> ('a, Exn_with_backtrace.t list) Result.t t + [fold_errors f ~init:[] ~on_error:(fun e l -> e :: l)] *) +val collect_errors : (unit -> 'a t) -> ('a, Exn_with_backtrace.t list) Result.t t (** [finalize f ~finally] runs [finally] after [f ()] has terminated, whether it fails or succeeds. *) @@ -214,11 +207,8 @@ module Mvar : sig type 'a t val create : unit -> 'a t - val create_full : 'a -> 'a t - val read : 'a t -> 'a fiber - val write : 'a t -> 'a -> unit fiber end @@ -260,7 +250,6 @@ module Mutex : sig type t val create : unit -> t - val with_lock : t -> f:(unit -> 'a fiber) -> 'a fiber end @@ -325,13 +314,9 @@ module Stream : sig val read : 'a t -> 'a option fiber val filter_map : 'a t -> f:('a -> 'b option) -> 'b t - val sequential_iter : 'a t -> f:('a -> unit fiber) -> unit fiber - val parallel_iter : 'a t -> f:('a -> unit fiber) -> unit fiber - val append : 'a t -> 'a t -> 'a t - val cons : 'a -> 'a t -> 'a t end @@ -348,7 +333,6 @@ module Stream : sig val create : ('a option -> unit fiber) -> 'a t val write : 'a t -> 'a option -> unit fiber - val null : unit -> 'a t end @@ -462,8 +446,8 @@ module Cancel : sig If [f ()] finished before [t] is fired, then [on_cancel] will never be invoked. *) - val with_handler : - t + val with_handler + : t -> (unit -> 'a fiber) -> on_cancel:(unit -> 'b fiber) -> ('a * 'b outcome) fiber diff --git a/fiber/src/mutex.ml b/fiber/src/mutex.ml index 85d4d74..7678e68 100644 --- a/fiber/src/mutex.ml +++ b/fiber/src/mutex.ml @@ -8,10 +8,12 @@ type t = } let lock t k = - if t.locked then suspend (fun k -> Queue.push t.waiters k) k + if t.locked + then suspend (fun k -> Queue.push t.waiters k) k else ( t.locked <- true; k ()) +;; let unlock t k = assert t.locked; @@ -20,9 +22,11 @@ let unlock t k = t.locked <- false; k () | Some next -> resume next () k +;; let with_lock t ~f = let* () = lock t in finalize f ~finally:(fun () -> unlock t) +;; let create () = { locked = false; waiters = Queue.create () } diff --git a/fiber/src/mvar.ml b/fiber/src/mvar.ml index cf8ef66..9509a36 100644 --- a/fiber/src/mvar.ml +++ b/fiber/src/mvar.ml @@ -13,31 +13,34 @@ let _invariant t = match t.value with | None -> Queue.is_empty t.writers | Some _ -> Queue.is_empty t.readers +;; -let create () = - { value = None; writers = Queue.create (); readers = Queue.create () } +let create () = { value = None; writers = Queue.create (); readers = Queue.create () } let create_full x = { value = Some x; writers = Queue.create (); readers = Queue.create () } +;; let read t k = match t.value with | None -> suspend (fun k -> Queue.push t.readers k) k - | Some v -> ( - match Queue.pop t.writers with - | None -> - t.value <- None; - k v - | Some (v', w) -> - t.value <- Some v'; - resume w () (fun () -> k v)) + | Some v -> + (match Queue.pop t.writers with + | None -> + t.value <- None; + k v + | Some (v', w) -> + t.value <- Some v'; + resume w () (fun () -> k v)) +;; let write t x k = match t.value with | Some _ -> suspend (fun k -> Queue.push t.writers (x, k)) k - | None -> ( - match Queue.pop t.readers with - | None -> - t.value <- Some x; - k () - | Some r -> resume r x (fun () -> k ())) + | None -> + (match Queue.pop t.readers with + | None -> + t.value <- Some x; + k () + | Some r -> resume r x (fun () -> k ())) +;; diff --git a/fiber/src/pool.ml b/fiber/src/pool.ml index 4dc54c9..96eac20 100644 --- a/fiber/src/pool.ml +++ b/fiber/src/pool.ml @@ -21,8 +21,7 @@ type runner = type nonrec t = { tasks : (unit -> unit t) Queue.t (* pending tasks *) - ; mutable runner : runner - (* The continuation to resume the runner set by [run] *) + ; mutable runner : runner (* The continuation to resume the runner set by [run] *) ; mutable status : status } @@ -30,32 +29,33 @@ let running t k = match t.status with | Open -> k true | Closed -> k false +;; -let create () = - { tasks = Queue.create (); runner = Awaiting_run; status = Open } +let create () = { tasks = Queue.create (); runner = Awaiting_run; status = Open } let task t ~f k = match t.status with - | Closed -> - Code_error.raise "pool is closed. new tasks may not be submitted" [] - | Open -> ( + | Closed -> Code_error.raise "pool is closed. new tasks may not be submitted" [] + | Open -> Queue.push t.tasks f; - match t.runner with - | Running | Awaiting_run -> k () - | Awaiting_resume r -> - t.runner <- Running; - resume r () k) + (match t.runner with + | Running | Awaiting_run -> k () + | Awaiting_resume r -> + t.runner <- Running; + resume r () k) +;; let close t k = match t.status with | Closed -> k () - | Open -> ( + | Open -> t.status <- Closed; - match t.runner with - | Running | Awaiting_run -> k () - | Awaiting_resume r -> - t.runner <- Running; - resume r () k) + (match t.runner with + | Running | Awaiting_run -> k () + | Awaiting_resume r -> + t.runner <- Running; + resume r () k) +;; let run t k = match t.runner with @@ -64,7 +64,7 @@ let run t k = | Awaiting_run -> t.runner <- Running; (* The number of currently running fibers in the pool. Only when this - number reaches zero we may call the final continuation [k]. *) + number reaches zero we may call the final continuation [k]. *) let n = ref 1 in let done_fiber () = decr n; @@ -87,3 +87,4 @@ let run t k = | Open -> suspend suspend_k read_delayed in read t +;; diff --git a/fiber/src/scheduler.ml b/fiber/src/scheduler.ml index 3adc5ac..1d38ddf 100644 --- a/fiber/src/scheduler.ml +++ b/fiber/src/scheduler.ml @@ -10,15 +10,17 @@ module Jobs = struct | Concat : t * t -> t let concat a b = - match (a, b) with + match a, b with | Empty, x | x, Empty -> x | _ -> Concat (a, b) + ;; let rec enqueue_readers (readers : (_, [ `Empty ]) ivar_state) x jobs = match readers with | Empty -> jobs | Empty_with_readers (ctx, k, readers) -> enqueue_readers readers x (Job (ctx, k, x, jobs)) + ;; let fill_ivar ivar x jobs = match ivar.state with @@ -30,6 +32,7 @@ module Jobs = struct ivar.state <- Full x; let jobs = Job (ctx, k, x, jobs) in enqueue_readers readers x jobs + ;; let rec exec_fills fills acc = match fills with @@ -37,6 +40,7 @@ module Jobs = struct | Fill (ivar, x) :: fills -> let acc = fill_ivar ivar x acc in exec_fills fills acc + ;; let exec_fills fills = exec_fills (List.rev fills) Empty end @@ -47,7 +51,6 @@ type step' = module type Witness = sig type t - type value += X of t end @@ -69,7 +72,7 @@ and loop2 a b = | Concat (a1, a2) -> loop2 a1 (Jobs.concat a2 b) and exec : 'a. context -> ('a -> eff) -> 'a -> Jobs.t -> step' = - fun ctx k x jobs -> + fun ctx k x jobs -> match k x with | exception exn -> let exn = Exn_with_backtrace.capture exn in @@ -77,12 +80,12 @@ and exec : 'a. context -> ('a -> eff) -> 'a -> Jobs.t -> step' = | Done v -> Done v | Toplevel_exception exn -> Exn_with_backtrace.reraise exn | Unwind (k, x) -> exec ctx.parent k x jobs - | Read_ivar (ivar, k) -> ( - match ivar.state with - | (Empty | Empty_with_readers _) as readers -> - ivar.state <- Empty_with_readers (ctx, k, readers); - loop jobs - | Full x -> exec ctx k x jobs) + | Read_ivar (ivar, k) -> + (match ivar.state with + | (Empty | Empty_with_readers _) as readers -> + ivar.state <- Empty_with_readers (ctx, k, readers); + loop jobs + | Full x -> exec ctx k x jobs) | Fill_ivar (ivar, x, k) -> let jobs = Jobs.concat jobs (Jobs.fill_ivar ivar x Empty) in exec ctx k () jobs @@ -91,8 +94,7 @@ and exec : 'a. context -> ('a -> eff) -> 'a -> Jobs.t -> step' = f k; loop jobs | Resume (suspended, x, k) -> - exec ctx k () - (Jobs.concat jobs (Job (suspended.ctx, suspended.run, x, Empty))) + exec ctx k () (Jobs.concat jobs (Job (suspended.ctx, suspended.run, x, Empty))) | Get_var (key, k) -> exec ctx k (Univ_map.find ctx.vars key) jobs | Set_var (key, x, k) -> let ctx = { ctx with parent = ctx; vars = Univ_map.set ctx.vars key x } in @@ -101,13 +103,10 @@ and exec : 'a. context -> ('a -> eff) -> 'a -> Jobs.t -> step' = let ctx = { ctx with parent = ctx; vars = Univ_map.remove ctx.vars key } in exec ctx k () jobs | With_error_handler (on_error, k) -> - let on_error = - { ctx; run = (fun exn -> on_error exn Nothing.unreachable_code) } - in + let on_error = { ctx; run = (fun exn -> on_error exn Nothing.unreachable_code) } in let ctx = { ctx with parent = ctx; on_error } in exec ctx k () jobs - | Map_reduce_errors (m, on_error, f, k) -> - map_reduce_errors ctx m on_error f k jobs + | Map_reduce_errors (m, on_error, f, k) -> map_reduce_errors ctx m on_error f k jobs | End_of_fiber () -> let (Map_reduce_context r) = ctx.map_reduce_context in deref r jobs @@ -117,8 +116,7 @@ and exec : 'a. context -> ('a -> eff) -> 'a -> Jobs.t -> step' = r.ref_count <- ref_count; assert (ref_count = 0); exec ctx.parent k x jobs - | End_of_map_reduce_error_handler map_reduce_context -> - deref map_reduce_context jobs + | End_of_map_reduce_error_handler map_reduce_context -> deref map_reduce_context jobs | Never () -> loop jobs | Fork (a, b) -> let (Map_reduce_context r) = ctx.map_reduce_context in @@ -127,21 +125,21 @@ and exec : 'a. context -> ('a -> eff) -> 'a -> Jobs.t -> step' = | Reraise exn -> let { ctx; run } = ctx.on_error in exec ctx run exn jobs - | Reraise_all exns -> ( - match length_and_rev exns with - | 0, _ -> loop jobs - | n, exns -> - let (Map_reduce_context r) = ctx.map_reduce_context in - r.ref_count <- r.ref_count + (n - 1); - let { ctx; run } = ctx.on_error in - let jobs = - List.fold_left exns ~init:jobs ~f:(fun jobs exn -> - Jobs.Job (ctx, run, exn, jobs)) - in - loop jobs) + | Reraise_all exns -> + (match length_and_rev exns with + | 0, _ -> loop jobs + | n, exns -> + let (Map_reduce_context r) = ctx.map_reduce_context in + r.ref_count <- r.ref_count + (n - 1); + let { ctx; run } = ctx.on_error in + let jobs = + List.fold_left exns ~init:jobs ~f:(fun jobs exn -> + Jobs.Job (ctx, run, exn, jobs)) + in + loop jobs) and deref : 'a 'b. ('a, 'b) map_reduce_context' -> Jobs.t -> step' = - fun r jobs -> + fun r jobs -> let ref_count = r.ref_count - 1 in r.ref_count <- ref_count; match ref_count with @@ -150,26 +148,25 @@ and deref : 'a 'b. ('a, 'b) map_reduce_context' -> Jobs.t -> step' = assert (ref_count > 0); loop jobs -and map_reduce_errors : - type errors b. - context +and map_reduce_errors + : type errors b. + context -> (module Monoid with type t = errors) -> (Exn_with_backtrace.t -> errors t) -> (unit -> eff) -> ((b, errors) result -> eff) -> Jobs.t - -> step' = - fun ctx (module M : Monoid with type t = errors) on_error f k jobs -> - let map_reduce_context = - { k = { ctx; run = k }; ref_count = 1; errors = M.empty } - in + -> step' + = + fun ctx (module M : Monoid with type t = errors) on_error f k jobs -> + let map_reduce_context = { k = { ctx; run = k }; ref_count = 1; errors = M.empty } in let on_error = { ctx ; run = (fun exn -> on_error exn (fun m -> - map_reduce_context.errors <- M.combine map_reduce_context.errors m; - End_of_map_reduce_error_handler map_reduce_context)) + map_reduce_context.errors <- M.combine map_reduce_context.errors m; + End_of_map_reduce_error_handler map_reduce_context)) } in let ctx = @@ -180,36 +177,37 @@ and map_reduce_errors : } in exec ctx f () jobs +;; let repack_step (type a) (module W : Witness with type t = a) (step' : step') = match step' with | Done (W.X a) -> Done a | Done _ -> Code_error.raise - "advance: it's illegal to call advance with a fiber created in a \ - different scheduler" + "advance: it's illegal to call advance with a fiber created in a different \ + scheduler" [] | Stalled -> Stalled (module W) +;; let advance (type a) (module W : Witness with type t = a) fill : a step = fill |> Jobs.exec_fills |> loop |> repack_step (module W) +;; let start (type a) (t : a t) = let module W = struct type t = a - type value += X of a - end in + end + in let rec ctx = { parent = ctx ; on_error = { ctx; run = (fun exn -> Toplevel_exception exn) } ; vars = Univ_map.empty ; map_reduce_context = Map_reduce_context - { k = { ctx; run = (fun _ -> assert false) } - ; ref_count = 1 - ; errors = () - } + { k = { ctx; run = (fun _ -> assert false) }; ref_count = 1; errors = () } } in exec ctx t (fun x -> Done (W.X x)) Empty |> repack_step (module W) +;; diff --git a/fiber/src/stream.ml b/fiber/src/stream.ml index ba008e6..a7b0c21 100644 --- a/fiber/src/stream.ml +++ b/fiber/src/stream.ml @@ -20,10 +20,12 @@ module In = struct in t.read <- read; t + ;; let lock t = if t.reading then Code_error.raise "Fiber.Stream.In: already reading" []; t.reading <- true + ;; let unlock t = t.reading <- false @@ -32,6 +34,7 @@ module In = struct let+ x = t.read () in unlock t; x + ;; let empty () = create_unchecked (fun () -> return None) @@ -40,46 +43,51 @@ module In = struct let rec go () = match !remains with | [] -> return None - | x :: xs -> ( + | x :: xs -> let* v = read x in - match v with - | Some v -> return (Some v) - | None -> - remains := xs; - go ()) + (match v with + | Some v -> return (Some v) + | None -> + remains := xs; + go ()) in create go + ;; let append x y = concat [ x; y ] let of_list xs = let xs = ref xs in create_unchecked (fun () -> - match !xs with - | [] -> return None - | x :: xs' -> - xs := xs'; - return (Some x)) + match !xs with + | [] -> return None + | x :: xs' -> + xs := xs'; + return (Some x)) + ;; let cons a t = concat [ of_list [ a ]; t ] let filter_map t ~f = let rec read () = - t.read () >>= function + t.read () + >>= function | None -> unlock t; return None - | Some x -> ( - match f x with - | None -> read () - | Some y -> return (Some y)) + | Some x -> + (match f x with + | None -> read () + | Some y -> return (Some y)) in lock t; create_unchecked read + ;; let sequential_iter t ~f = let rec loop t ~f = - t.read () >>= function + t.read () + >>= function | None -> unlock t; return () @@ -89,12 +97,14 @@ module In = struct in lock t; loop t ~f + ;; let parallel_iter t ~f k = let n = ref 1 in let k () = decr n; - if !n = 0 then ( + if !n = 0 + then ( unlock t; k ()) else end_of_fiber @@ -108,6 +118,7 @@ module In = struct in lock t; loop t + ;; end module Out = struct @@ -119,27 +130,30 @@ module Out = struct let lock t = if t.writing then Code_error.raise "Fiber.Stream.Out: already writing" []; t.writing <- true + ;; let unlock t = t.writing <- false let create write = let t = { write; writing = false } in let write x = - if Option.is_none x then - t.write <- - (function - | None -> return () - | Some _ -> - Code_error.raise "Fiber.Stream.Out: stream output closed" []); + if Option.is_none x + then + t.write + <- (function + | None -> return () + | Some _ -> Code_error.raise "Fiber.Stream.Out: stream output closed" []); write x in t.write <- write; t + ;; let write t x = lock t; let+ () = t.write x in unlock t + ;; let null () = create (fun _ -> return ()) end @@ -158,6 +172,7 @@ let connect i o = | Some _ -> go () in go () +;; let supply i o = In.lock i; @@ -174,9 +189,11 @@ let supply i o = go () in go () +;; let pipe () = let mvar = Mvar.create () in let i = In.create (fun () -> Mvar.read mvar) in let o = Out.create (fun x -> Mvar.write mvar x) in - (i, o) + i, o +;; diff --git a/fiber/src/svar.ml b/fiber/src/svar.ml index 95d9866..7b828b0 100644 --- a/fiber/src/svar.ml +++ b/fiber/src/svar.ml @@ -10,16 +10,16 @@ type 'a t = let read t = t.current let wait = - let suspend t ~until = - suspend (fun k -> t.waiters <- (k, until) :: t.waiters) - in + let suspend t ~until = suspend (fun k -> t.waiters <- (k, until) :: t.waiters) in let rec wait t ~until = - if until t.current then return () + if until t.current + then return () else let* () = suspend t ~until in wait t ~until in fun t ~until -> wait t ~until +;; let create current = { current; waiters = [] } @@ -32,10 +32,11 @@ let write = t.current <- a; let sleep, awake = List.rev_partition_map t.waiters ~f:(fun (k, f) -> - if f t.current then Right k else Left (k, f)) + if f t.current then Right k else Left (k, f)) in match awake with | [] -> k () | awake -> t.waiters <- List.rev sleep; run_awakers k awake +;; diff --git a/fiber/src/throttle.ml b/fiber/src/throttle.ml index 0e27a20..cb7767c 100644 --- a/fiber/src/throttle.ml +++ b/fiber/src/throttle.ml @@ -9,24 +9,25 @@ type t = } let create size = { size; running = 0; waiting = Queue.create () } - let size t = t.size - let running t = t.running let rec restart t = - if t.running >= t.size then return () - else + if t.running >= t.size + then return () + else ( match Queue.pop t.waiting with | None -> return () | Some ivar -> t.running <- t.running + 1; let* () = Ivar.fill ivar () in - restart t + restart t) +;; let resize t n = t.size <- n; restart t +;; let run t ~f = finalize @@ -34,11 +35,13 @@ let run t ~f = t.running <- t.running - 1; restart t) (fun () -> - if t.running < t.size then ( + if t.running < t.size + then ( t.running <- t.running + 1; f ()) - else + else ( let waiting = Ivar.create () in Queue.push t.waiting waiting; let* () = Ivar.read waiting in - f ()) + f ())) +;; diff --git a/fiber/test/fiber_scheduler.ml b/fiber/test/fiber_scheduler.ml index a4eaa82..3cd24ec 100644 --- a/fiber/test/fiber_scheduler.ml +++ b/fiber/test/fiber_scheduler.ml @@ -9,12 +9,13 @@ let%expect_test "test fiber scheduler" = print_endline "ivar filled" in (match Scheduler.start (Fiber.of_thunk f) with - | Done _ -> assert false - | Stalled s -> ( - let step = Scheduler.advance s [ Fiber.Fill (ivar, ()) ] in - match step with - | Done () -> () - | Stalled _ -> assert false)); + | Done _ -> assert false + | Stalled s -> + let step = Scheduler.advance s [ Fiber.Fill (ivar, ()) ] in + (match step with + | Done () -> () + | Stalled _ -> assert false)); [%expect {| waiting for ivar ivar filled |}] +;; diff --git a/fiber/test/fiber_tests.ml b/fiber/test/fiber_tests.ml index 2a96b77..83a191b 100644 --- a/fiber/test/fiber_tests.ml +++ b/fiber/test/fiber_tests.ml @@ -3,70 +3,68 @@ open Fiber.O open Dyn let printf = Printf.printf - let print_dyn dyn = Dyn.to_string dyn |> print_endline - let () = Printexc.record_backtrace false module Scheduler = struct let t = Test_scheduler.create () - let yield () = Test_scheduler.yield t - let run f = Test_scheduler.run t f end let failing_fiber () : unit Fiber.t = let* () = Scheduler.yield () in raise Exit +;; let long_running_fiber () = let rec loop n = - if n = 0 then Fiber.return () + if n = 0 + then Fiber.return () else let* () = Scheduler.yield () in loop (n - 1) in loop 10 +;; let never_fiber () = Fiber.never - -let backtrace_result dyn_of_ok = - Result.to_dyn dyn_of_ok (list Exn_with_backtrace.to_dyn) - +let backtrace_result dyn_of_ok = Result.to_dyn dyn_of_ok (list Exn_with_backtrace.to_dyn) let unit_result dyn_of_ok = Result.to_dyn dyn_of_ok unit let test ?(expect_never = false) to_dyn f = let never_raised = ref false in - (try Scheduler.run f |> to_dyn |> print_dyn - with Test_scheduler.Never -> never_raised := true); - match (!never_raised, expect_never) with + (try Scheduler.run f |> to_dyn |> print_dyn with + | Test_scheduler.Never -> never_raised := true); + match !never_raised, expect_never with | false, false -> (* We don't raise in this case b/c we assume something else is being tested *) () | true, true -> print_endline "[PASS] Never raised as expected" - | false, true -> - print_endline "[FAIL] expected Never to be raised but it wasn't" + | false, true -> print_endline "[FAIL] expected Never to be raised but it wasn't" | true, false -> print_endline "[FAIL] unexpected Never raised" +;; let%expect_test "basics" = test unit (Fiber.return ()); [%expect {| () |}]; - - test unit + test + unit (let* () = Fiber.return () in Fiber.return ()); [%expect {| () |}]; - - test unit + test + unit (let* () = Scheduler.yield () in Fiber.return ()); [%expect {| () |}] +;; let%expect_test "collect_errors" = test (backtrace_result unit) (Fiber.collect_errors (fun () -> raise Exit)); [%expect {| Error [ { exn = "Stdlib.Exit"; backtrace = "" } ] |}] +;; let%expect_test "reraise_all" = let exns = @@ -82,11 +80,12 @@ let%expect_test "reraise_all" = ; { exn = "Stdlib.Exit"; backtrace = "" } ; { exn = "Stdlib.Exit"; backtrace = "" } ] |}]; - test (backtrace_result unit) + test + (backtrace_result unit) (Fiber.collect_errors (fun () -> - Fiber.finalize fail ~finally:(fun () -> - print_endline "finally"; - Fiber.return ()))); + Fiber.finalize fail ~finally:(fun () -> + print_endline "finally"; + Fiber.return ()))); [%expect {| finally @@ -95,12 +94,14 @@ let%expect_test "reraise_all" = ; { exn = "Stdlib.Exit"; backtrace = "" } ; { exn = "Stdlib.Exit"; backtrace = "" } ] |}]; - - test unit ~expect_never:true + test + unit + ~expect_never:true (let+ _ = Fiber.reraise_all [] in print_endline "finish"); [%expect {| [PASS] Never raised as expected |}] +;; let%expect_test "execution context of ivars" = (* The point of this test it show that the execution context is restored when @@ -111,17 +112,18 @@ let%expect_test "execution context of ivars" = let run_when_filled () = let var = Fiber.Var.create () in Fiber.Var.set var 42 (fun () -> - let* peek = Fiber.Ivar.peek ivar in - assert (peek = None); - let* () = Fiber.Ivar.read ivar in - let+ value = Fiber.Var.get_exn var in - Printf.printf "var value %d\n" value) + let* peek = Fiber.Ivar.peek ivar in + assert (peek = None); + let* () = Fiber.Ivar.read ivar in + let+ value = Fiber.Var.get_exn var in + Printf.printf "var value %d\n" value) in let run = Fiber.fork_and_join_unit run_when_filled (Fiber.Ivar.fill ivar) in test unit run; [%expect {| var value 42 () |}] +;; let%expect_test "fiber vars are preserved across yields" = let var = Fiber.Var.create () in @@ -129,16 +131,17 @@ let%expect_test "fiber vars are preserved across yields" = let* v = Fiber.Var.get var in assert (v = None); Fiber.Var.set var th (fun () -> - let* v = Fiber.Var.get var in - assert (v = Some th); - let* () = Scheduler.yield () in - let+ v = Fiber.Var.get var in - assert (v = Some th)) + let* v = Fiber.Var.get var in + assert (v = Some th); + let* () = Scheduler.yield () in + let+ v = Fiber.Var.get var in + assert (v = Some th)) in let run = Fiber.fork_and_join_unit (fiber 1) (fiber 2) in test unit run; [%expect {| () |}] +;; let%expect_test "fill returns a fiber that executes before waiters are awoken" = let ivar = Fiber.Ivar.create () in @@ -156,42 +159,46 @@ let%expect_test "fill returns a fiber that executes before waiters are awoken" = Printf.printf "ivar filled\n" in test unit (Fiber.fork_and_join_unit waiters run); - [%expect - {| + [%expect {| ivar filled waiter 1 resumed waiter 2 resumed () |}] +;; let%expect_test "collect_errors catches one error" = test (backtrace_result unit) (Fiber.collect_errors failing_fiber); [%expect {| Error [ { exn = "Stdlib.Exit"; backtrace = "" } ] |}] +;; let%expect_test "collect_errors doesn't terminate on [never]" = test ~expect_never:true opaque (Fiber.collect_errors never_fiber); [%expect {| [PASS] Never raised as expected |}] +;; let%expect_test "failing_fiber doesn't terminate" = - test (backtrace_result unit) + test + (backtrace_result unit) (Fiber.collect_errors (fun () -> - let* () = failing_fiber () in - failing_fiber ())); + let* () = failing_fiber () in + failing_fiber ())); [%expect {| Error [ { exn = "Stdlib.Exit"; backtrace = "" } ] |}] +;; let%expect_test "collect_errors fail one concurrent child fibers raises" = test (backtrace_result (pair unit unit)) - (Fiber.collect_errors (fun () -> - Fiber.fork_and_join failing_fiber long_running_fiber)); + (Fiber.collect_errors (fun () -> Fiber.fork_and_join failing_fiber long_running_fiber)); [%expect {| Error [ { exn = "Stdlib.Exit"; backtrace = "" } ] |}] +;; let%expect_test "collect_errors can run concurrently" = test @@ -202,9 +209,11 @@ let%expect_test "collect_errors can run concurrently" = [%expect {| (Error [ { exn = "Stdlib.Exit"; backtrace = "" } ], ()) |}] +;; let map_reduce_errors_unit ~on_error t = Fiber.map_reduce_errors (module Monoid.Unit) ~on_error t +;; let%expect_test "collect errors inside with_error_handler" = test @@ -227,46 +236,51 @@ let%expect_test "collect errors inside with_error_handler" = got 1 errors out of collect_errors captured the error Error () |}] +;; let%expect_test "collect_errors restores the execution context properly" = let var = Fiber.Var.create () in - test unit + test + unit (Fiber.Var.set var "a" (fun () -> - let* _res = - Fiber.Var.set var "b" (fun () -> - Fiber.collect_errors (fun () -> - Fiber.Var.set var "c" (fun () -> raise Exit))) - in - let* v = Fiber.Var.get_exn var in - print_endline v; - Fiber.return ())); + let* _res = + Fiber.Var.set var "b" (fun () -> + Fiber.collect_errors (fun () -> Fiber.Var.set var "c" (fun () -> raise Exit))) + in + let* v = Fiber.Var.get_exn var in + print_endline v; + Fiber.return ())); [%expect {| a () |}] +;; let%expect_test "handlers bubble up errors to parent handlers" = - test ~expect_never:false (unit_result unit) + test + ~expect_never:false + (unit_result unit) (Fiber.fork_and_join_unit long_running_fiber (fun () -> - let log_error by (e : Exn_with_backtrace.t) = - Printf.printf "%s: raised %s\n" by (Printexc.to_string e.exn) - in - map_reduce_errors_unit - ~on_error:(fun err -> - log_error "outer" err; - Fiber.return ()) - (fun () -> - Fiber.fork_and_join_unit failing_fiber (fun () -> - Fiber.with_error_handler - ~on_error:(fun exn -> - log_error "inner" exn; - raise Exit) - failing_fiber)))); + let log_error by (e : Exn_with_backtrace.t) = + Printf.printf "%s: raised %s\n" by (Printexc.to_string e.exn) + in + map_reduce_errors_unit + ~on_error:(fun err -> + log_error "outer" err; + Fiber.return ()) + (fun () -> + Fiber.fork_and_join_unit failing_fiber (fun () -> + Fiber.with_error_handler + ~on_error:(fun exn -> + log_error "inner" exn; + raise Exit) + failing_fiber)))); [%expect {| outer: raised Stdlib.Exit inner: raised Stdlib.Exit outer: raised Stdlib.Exit Error () |}] +;; let%expect_test "nested with_error_handler" = let fiber = @@ -281,11 +295,13 @@ let%expect_test "nested with_error_handler" = Exn_with_backtrace.reraise exn) (fun () -> raise Exit)) in - (try test unit fiber with Exit -> print_endline "[PASS] got Exit"); + (try test unit fiber with + | Exit -> print_endline "[PASS] got Exit"); [%expect {| inner handler outer handler [PASS] got Exit |}] +;; let must_set_flag f = let flag = ref false in @@ -296,9 +312,11 @@ let must_set_flag f = try f setter; check_set () - with e -> + with + | e -> check_set (); raise e +;; let%expect_test "finalize" = let fiber = @@ -311,16 +329,17 @@ let%expect_test "finalize" = finally () |}]; - let fiber = Fiber.finalize ~finally:(fun () -> Fiber.return (print_endline "finally")) (fun () -> raise Exit) in - (try test unit fiber with Exit -> print_endline "[PASS] got Exit"); + (try test unit fiber with + | Exit -> print_endline "[PASS] got Exit"); [%expect {| finally [PASS] got Exit |}] +;; let%expect_test "nested finalize" = let fiber = @@ -331,11 +350,13 @@ let%expect_test "nested finalize" = ~finally:(fun () -> Fiber.return (print_endline "inner finally")) (fun () -> raise Exit)) in - (try test unit fiber with Exit -> print_endline "[PASS] got Exit"); + (try test unit fiber with + | Exit -> print_endline "[PASS] got Exit"); [%expect {| inner finally outer finally [PASS] got Exit |}] +;; let%expect_test "context switch and raise inside finalize" = let fiber = @@ -354,13 +375,15 @@ let%expect_test "context switch and raise inside finalize" = printf "raising in second fiber\n"; raise Exit)) in - (try test unit fiber with Exit -> print_endline "[PASS] got Exit"); + (try test unit fiber with + | Exit -> print_endline "[PASS] got Exit"); [%expect {| Hello from first fiber! raising in second fiber finally [PASS] got Exit |}] +;; let%expect_test "sequential_iter error handling" = let fiber = @@ -370,8 +393,7 @@ let%expect_test "sequential_iter error handling" = map_reduce_errors_unit (fun () -> Fiber.sequential_iter [ 1; 2; 3 ] ~f:(fun x -> - if x = 2 then raise Exit - else Fiber.return (Printf.printf "count: %d\n" x))) + if x = 2 then raise Exit else Fiber.return (Printf.printf "count: %d\n" x))) ~on_error:(fun exn_with_bt -> printf "exn: %s\n%!" (Printexc.to_string exn_with_bt.exn); Fiber.return ())) @@ -382,6 +404,7 @@ let%expect_test "sequential_iter error handling" = exn: Stdlib.Exit finally Error () |}] +;; let%expect_test "sequential_iter" = let fiber = @@ -389,7 +412,7 @@ let%expect_test "sequential_iter" = ~finally:(fun () -> Fiber.return (print_endline "finally")) (fun () -> Fiber.sequential_iter [ 1; 2; 3 ] ~f:(fun x -> - Fiber.return (Printf.printf "count: %d\n" x))) + Fiber.return (Printf.printf "count: %d\n" x))) in test unit fiber; [%expect {| @@ -398,34 +421,36 @@ let%expect_test "sequential_iter" = count: 3 finally () |}] +;; let%expect_test _ = must_set_flag (fun setter -> - test ~expect_never:true unit - @@ Fiber.fork_and_join_unit never_fiber (fun () -> - let* res = Fiber.collect_errors failing_fiber in - print_dyn (backtrace_result unit res); - let* () = long_running_fiber () in - Fiber.return (setter ()))); + test ~expect_never:true unit + @@ Fiber.fork_and_join_unit never_fiber (fun () -> + let* res = Fiber.collect_errors failing_fiber in + print_dyn (backtrace_result unit res); + let* () = long_running_fiber () in + Fiber.return (setter ()))); [%expect {| Error [ { exn = "Stdlib.Exit"; backtrace = "" } ] [PASS] Never raised as expected [PASS] flag set |}] +;; let%expect_test _ = let forking_fiber () = Fiber.parallel_map [ 1; 2; 3; 4; 5 ] ~f:(fun x -> - let* () = Scheduler.yield () in - if x mod 2 = 1 then Fiber.return () else Printf.ksprintf failwith "%d" x) + let* () = Scheduler.yield () in + if x mod 2 = 1 then Fiber.return () else Printf.ksprintf failwith "%d" x) in must_set_flag (fun setter -> - test ~expect_never:true unit - @@ Fiber.fork_and_join_unit never_fiber (fun () -> - let* res = Fiber.collect_errors forking_fiber in - print_dyn (backtrace_result (list unit) res); - let* () = long_running_fiber () in - Fiber.return (setter ()))); + test ~expect_never:true unit + @@ Fiber.fork_and_join_unit never_fiber (fun () -> + let* res = Fiber.collect_errors forking_fiber in + print_dyn (backtrace_result (list unit) res); + let* () = long_running_fiber () in + Fiber.return (setter ()))); [%expect {| Error @@ -434,20 +459,25 @@ let%expect_test _ = ] [PASS] Never raised as expected [PASS] flag set |}] +;; (* Mvar tests *) module Mvar = Fiber.Mvar let%expect_test "created mvar is empty" = - test ~expect_never:true opaque + test + ~expect_never:true + opaque (let mvar : int Mvar.t = Mvar.create () in Mvar.read mvar); [%expect {| [PASS] Never raised as expected |}] +;; let%expect_test "reading from written mvar consumes value" = - test unit + test + unit (let mvar = Mvar.create () in let value = "foo" in let* () = Mvar.write mvar value in @@ -457,9 +487,11 @@ let%expect_test "reading from written mvar consumes value" = [%expect {| [PASS] mvar contains expected value () |}] +;; let%expect_test "reading from empty mvar blocks" = - test unit + test + unit (let mvar = Mvar.create () in let value = "foo" in Fiber.fork_and_join_unit @@ -480,9 +512,11 @@ let%expect_test "reading from empty mvar blocks" = written mvar [PASS] mvar contains expected value () |}] +;; let%expect_test "writing multiple values" = - test unit + test + unit (let mvar = Mvar.create () in let write (n : int) : unit Fiber.t = Printf.printf "writing %d\n" n; @@ -516,9 +550,11 @@ let%expect_test "writing multiple values" = read 1 read 0 () |}] +;; let%expect_test "writing multiple values" = - test unit + test + unit (let m = Mvar.create () in Fiber.fork_and_join_unit (fun () -> @@ -543,26 +579,28 @@ let%expect_test "writing multiple values" = reader1: writing reader2: got 1 () |}] +;; let stream a b = let n = ref a in Fiber.Stream.In.create (fun () -> - if !n > b then Fiber.return None - else - let x = !n in - n := x + 1; - Fiber.return (Some x)) + if !n > b + then Fiber.return None + else ( + let x = !n in + n := x + 1; + Fiber.return (Some x))) +;; let%expect_test "Stream.parallel_iter is indeed parallel" = let test ~iter_function = Scheduler.run (iter_function (stream 1 3) ~f:(fun n -> - Printf.printf "%d: enter\n" n; - let* () = long_running_fiber () in - Printf.printf "%d: leave\n" n; - Fiber.return ())) + Printf.printf "%d: enter\n" n; + let* () = long_running_fiber () in + Printf.printf "%d: leave\n" n; + Fiber.return ())) in - (* The [enter] amd [leave] messages must be interleaved to indicate that the calls to [f] are executed in parallel: *) test ~iter_function:Fiber.Stream.In.parallel_iter; @@ -574,7 +612,6 @@ let%expect_test "Stream.parallel_iter is indeed parallel" = 1: leave 2: leave 3: leave |}]; - (* With [sequential_iter] however, The [enter] amd [leave] messages must be paired in sequence: *) test ~iter_function:Fiber.Stream.In.sequential_iter; @@ -586,6 +623,7 @@ let%expect_test "Stream.parallel_iter is indeed parallel" = 2: leave 3: enter 3: leave |}] +;; let%expect_test "Stream.*_iter can be finalized" = let test ~iter_function = @@ -598,12 +636,13 @@ let%expect_test "Stream.*_iter can be finalized" = in test ~iter_function:Fiber.Stream.In.sequential_iter; [%expect {| finalized |}]; - test ~iter_function:Fiber.Stream.In.parallel_iter; [%expect {| finalized |}] +;; let rec naive_stream_parallel_iter (t : _ Fiber.Stream.In.t) ~f = - Fiber.Stream.In.read t >>= function + Fiber.Stream.In.read t + >>= function | None -> Fiber.return () | Some x -> Fiber.fork_and_join_unit @@ -612,6 +651,7 @@ let rec naive_stream_parallel_iter (t : _ Fiber.Stream.In.t) ~f = optimization in [fork_and_join_unit]*) Scheduler.yield () >>= fun () -> f x) (fun () -> naive_stream_parallel_iter t ~f) +;; let%expect_test "Stream.parallel_iter doesn't leak" = (* Check that a naive [parallel_iter] functions on streams is leaking memory, @@ -629,12 +669,13 @@ let%expect_test "Stream.parallel_iter doesn't leak" = let stream = let count = ref n in Fiber.Stream.In.create (fun () -> - if !count > 0 then ( - decr count; - Fiber.return (Some ())) - else - let* () = Fiber.Ivar.read finish_stream in - Fiber.return None) + if !count > 0 + then ( + decr count; + Fiber.return (Some ())) + else + let* () = Fiber.Ivar.read finish_stream in + Fiber.return None) in let awaiting = ref n in let iter_await = Fiber.Ivar.create () in @@ -669,8 +710,8 @@ let%expect_test "Stream.parallel_iter doesn't leak" = in let test ~pred ~iter_function = let results = List.map data_points ~f:(test ~iter_function) in - if pair_wise_check results ~f:pred then - print_endline "[PASS] memory usage is as expected" + if pair_wise_check results ~f:pred + then print_endline "[PASS] memory usage is as expected" else ( print_endline "[FAIL] memory usage is not as expected"; Dyn.(list int) results |> print_dyn) @@ -681,35 +722,37 @@ let%expect_test "Stream.parallel_iter doesn't leak" = [%expect {| [PASS] memory usage is as expected |}]; test ~pred:( = ) ~iter_function:Fiber.Stream.In.parallel_iter; [%expect {| [PASS] memory usage is as expected |}] +;; let sorted_failures v = - Result.map_error v + Result.map_error + v ~f: - (List.sort - ~compare:(fun (x : Exn_with_backtrace.t) (y : Exn_with_backtrace.t) -> - match (x.exn, y.exn) with - | Failure x, Failure y -> String.compare x y - | _, _ -> assert false)) + (List.sort ~compare:(fun (x : Exn_with_backtrace.t) (y : Exn_with_backtrace.t) -> + match x.exn, y.exn with + | Failure x, Failure y -> String.compare x y + | _, _ -> assert false)) +;; let%expect_test "fork - exceptions always thrown" = test (fun x -> sorted_failures x |> backtrace_result unit) (Fiber.collect_errors (fun () -> - Fiber.fork_and_join_unit - (fun () -> failwith "left") - (fun () -> failwith "right"))); + Fiber.fork_and_join_unit (fun () -> failwith "left") (fun () -> failwith "right"))); [%expect {| Error [ { exn = "Failure(\"left\")"; backtrace = "" } ; { exn = "Failure(\"right\")"; backtrace = "" } ] |}] +;; let test iter = test (fun x -> sorted_failures x |> backtrace_result unit) (Fiber.collect_errors (fun () -> - iter [ 1; 2; 3 ] ~f:(fun x -> failwith (Int.to_string x)))) + iter [ 1; 2; 3 ] ~f:(fun x -> failwith (Int.to_string x)))) +;; let%expect_test "parallel_iter - all exceptions raised" = test Fiber.parallel_iter; @@ -720,11 +763,13 @@ let%expect_test "parallel_iter - all exceptions raised" = ; { exn = "Failure(\"2\")"; backtrace = "" } ; { exn = "Failure(\"3\")"; backtrace = "" } ] |}] +;; let%expect_test "sequential_iter - stop after first exception" = test Fiber.sequential_iter; [%expect {| Error [ { exn = "Failure(\"1\")"; backtrace = "" } ] |}] +;; let%expect_test "Stream: multiple readers is an error" = (* [stream] is so that the first element takes longer to be produced. An @@ -733,15 +778,16 @@ let%expect_test "Stream: multiple readers is an error" = let stream = let n = ref 0 in Fiber.Stream.In.create (fun () -> - let x = !n in - n := x + 1; - let+ () = - if x = 0 then - let* () = long_running_fiber () in - long_running_fiber () - else Fiber.return () - in - Some ()) + let x = !n in + n := x + 1; + let+ () = + if x = 0 + then + let* () = long_running_fiber () in + long_running_fiber () + else Fiber.return () + in + Some ()) in Scheduler.run (Fiber.fork_and_join_unit @@ -754,13 +800,14 @@ let%expect_test "Stream: multiple readers is an error" = printf "Reader 2 reading\n"; let+ _x = Fiber.Stream.In.read stream in printf "Reader 2 done\n")) - [@@expect.uncaught_exn - {| +[@@expect.uncaught_exn + {| ("(\"Fiber.Stream.In: already reading\", {})") Trailing output --------------- Reader 1 reading Reader 2 reading |}] +;; let%expect_test "Stream: multiple writers is an error" = (* [stream] is so that the first element takes longer to be consumed. An @@ -784,29 +831,31 @@ let%expect_test "Stream: multiple writers is an error" = printf "Writer 2 writing\n"; let+ _x = Fiber.Stream.Out.write stream (Some 2) in printf "Writer 2 done\n")) - [@@expect.uncaught_exn - {| +[@@expect.uncaught_exn + {| ("(\"Fiber.Stream.Out: already writing\", {})") Trailing output --------------- Writer 1 writing Writer 2 writing |}] +;; let%expect_test "Stream: writing on a closed stream is an error" = Scheduler.run (let out = Fiber.Stream.Out.create (fun x -> - print_dyn ((option unit) x); - Fiber.return ()) + print_dyn ((option unit) x); + Fiber.return ()) in let* () = Fiber.Stream.Out.write out None in Fiber.Stream.Out.write out (Some ())) - [@@expect.uncaught_exn - {| +[@@expect.uncaught_exn + {| ("(\"Fiber.Stream.Out: stream output closed\", {})") Trailing output --------------- None |}] +;; module Pool = Fiber.Pool @@ -815,6 +864,7 @@ let%expect_test "start & stop pool" = (let pool = Pool.create () in Pool.close pool); [%expect {| |}] +;; let%expect_test "run 2 tasks" = Scheduler.run @@ -834,6 +884,7 @@ let%expect_test "run 2 tasks" = [%expect {| task 1 task 2 |}] +;; let%expect_test "raise exception" = Scheduler.run @@ -850,6 +901,7 @@ let%expect_test "raise exception" = | _ -> assert false) (fun () -> Pool.close pool)); [%expect {| Caught Exit |}] +;; let%expect_test "double run a pool" = (* Calling [Pool.run] twice on the same pool shouldn't be allowed @@ -858,105 +910,110 @@ let%expect_test "double run a pool" = sure only a single [run] will get the exceptions from all tasks in the pool *) (Scheduler.run - @@ - let pool = Pool.create () in - Fiber.fork_and_join_unit (fun () -> Pool.run pool) (fun () -> Pool.run pool)); + @@ + let pool = Pool.create () in + Fiber.fork_and_join_unit (fun () -> Pool.run pool) (fun () -> Pool.run pool)); [%expect.unreachable] - [@@expect.uncaught_exn - {| ("(\"Fiber.Pool.run: concurent calls to run aren't allowed\", {})") |}] +[@@expect.uncaught_exn + {| ("(\"Fiber.Pool.run: concurent calls to run aren't allowed\", {})") |}] +;; let%expect_test "run -> stop -> run a pool" = (* We shouldn't be able to call [Pool.run] again after we already called [Pool.run] and [Pool.close]. In other words, we can't reuse pools *) (Scheduler.run - @@ - let pool = Pool.create () in - let* () = - Fiber.fork_and_join_unit - (fun () -> Pool.run pool) - (fun () -> Fiber.Pool.task pool ~f:(fun () -> Pool.close pool)) - in - Pool.run pool); + @@ + let pool = Pool.create () in + let* () = + Fiber.fork_and_join_unit + (fun () -> Pool.run pool) + (fun () -> Fiber.Pool.task pool ~f:(fun () -> Pool.close pool)) + in + Pool.run pool); [%expect.unreachable] - [@@expect.uncaught_exn - {| ("(\"Fiber.Pool.run: concurent calls to run aren't allowed\", {})") |}] +[@@expect.uncaught_exn + {| ("(\"Fiber.Pool.run: concurent calls to run aren't allowed\", {})") |}] +;; let%expect_test "stop a pool and then run it" = (Scheduler.run - @@ - let pool = Pool.create () in - let* () = Pool.close pool in - Pool.run pool); + @@ + let pool = Pool.create () in + let* () = Pool.close pool in + Pool.run pool); [%expect {||}] +;; let%expect_test "pool - weird deadlock" = (* this doesn't dead lock *) (Scheduler.run - @@ - let pool = Pool.create () in - let* () = Pool.task pool ~f:Fiber.return in - Fiber.fork_and_join_unit (fun () -> Pool.close pool) (fun () -> Pool.run pool) - ); + @@ + let pool = Pool.create () in + let* () = Pool.task pool ~f:Fiber.return in + Fiber.fork_and_join_unit (fun () -> Pool.close pool) (fun () -> Pool.run pool)); [%expect {||}]; (* but this does *) (Scheduler.run - @@ - let pool = Pool.create () in - let* () = Pool.task pool ~f:Fiber.return in - let* () = Pool.task pool ~f:Fiber.return in - Fiber.fork_and_join_unit (fun () -> Pool.close pool) (fun () -> Pool.run pool) - ); + @@ + let pool = Pool.create () in + let* () = Pool.task pool ~f:Fiber.return in + let* () = Pool.task pool ~f:Fiber.return in + Fiber.fork_and_join_unit (fun () -> Pool.close pool) (fun () -> Pool.run pool)); [%expect {||}] +;; let%expect_test "nested run in task" = (Scheduler.run - @@ - let pool = Pool.create () in - let* () = Pool.task pool ~f:(fun () -> Pool.run pool) in - Fiber.fork_and_join_unit (fun () -> Pool.close pool) (fun () -> Pool.run pool) - ); + @@ + let pool = Pool.create () in + let* () = Pool.task pool ~f:(fun () -> Pool.run pool) in + Fiber.fork_and_join_unit (fun () -> Pool.close pool) (fun () -> Pool.run pool)); [%expect.unreachable] - [@@expect.uncaught_exn - {| ("(\"Fiber.Pool.run: concurent calls to run aren't allowed\", {})") |}] +[@@expect.uncaught_exn + {| ("(\"Fiber.Pool.run: concurent calls to run aren't allowed\", {})") |}] +;; let%expect_test "nested tasks" = (Scheduler.run - @@ - let pool = Pool.create () in - let* () = - Pool.task pool ~f:(fun () -> - print_endline "outer"; - let* () = - Pool.task pool ~f:(fun () -> - let+ () = Fiber.return () in - print_endline "inner") - in - Pool.close pool) - in - Pool.run pool); + @@ + let pool = Pool.create () in + let* () = + Pool.task pool ~f:(fun () -> + print_endline "outer"; + let* () = + Pool.task pool ~f:(fun () -> + let+ () = Fiber.return () in + print_endline "inner") + in + Pool.close pool) + in + Pool.run pool); [%expect {| outer inner |}] +;; let%expect_test "stopping inside a task" = (Scheduler.run - @@ - let pool = Pool.create () in - let* () = Pool.task pool ~f:(fun () -> Pool.close pool) in - Pool.run pool); + @@ + let pool = Pool.create () in + let* () = Pool.task pool ~f:(fun () -> Pool.close pool) in + Pool.run pool); [%expect {||}] +;; let%expect_test "stack usage with consecutive Ivar.fill" = let stack_size () = (Gc.stat ()).stack_size in let rec loop acc prev n = - if n = 0 then (acc, prev) - else + if n = 0 + then acc, prev + else ( let next = Fiber.Ivar.create () in let fiber = let* () = Fiber.Ivar.read prev in Fiber.Ivar.fill next () in - loop (fiber :: acc) next (n - 1) + loop (fiber :: acc) next (n - 1)) in let stack_usage n = let first = Fiber.Ivar.create () in @@ -971,32 +1028,31 @@ let%expect_test "stack usage with consecutive Ivar.fill" = in let n0 = Scheduler.run (stack_usage 0) in let n1000 = Scheduler.run (stack_usage 1000) in - if n0 = n1000 then printf "[PASS]" + if n0 = n1000 + then printf "[PASS]" else printf - "[FAIL]\n\ - Stack usage for n = 0: %d words\n\ - Stack usage for n = 1000: %d words\n" - n0 n1000; + "[FAIL]\nStack usage for n = 0: %d words\nStack usage for n = 1000: %d words\n" + n0 + n1000; [%expect {| [PASS] |}] +;; let%expect_test "all_concurrently_unit" = Scheduler.run (let+ () = Fiber.all_concurrently_unit [] in printf "empty list"); [%expect {| empty list |}]; - Scheduler.run (let+ () = Fiber.all_concurrently_unit [ Fiber.return () ] in printf "singleton list"); [%expect {| singleton list |}]; - Scheduler.run (let print i = Fiber.of_thunk (fun () -> - printfn "print: %i" i; - Fiber.return ()) + printfn "print: %i" i; + Fiber.return ()) in let+ () = Fiber.all_concurrently_unit [ print 1; print 2 ] in printf "multi element list"); @@ -1004,18 +1060,16 @@ let%expect_test "all_concurrently_unit" = print: 1 print: 2 multi element list |}]; - Scheduler.run (let print i = Fiber.of_thunk (fun () -> - printfn "print: %i" i; - Fiber.return ()) + printfn "print: %i" i; + Fiber.return ()) in let fail = Fiber.of_thunk (fun () -> raise Exit) in let+ () = let+ res = - Fiber.collect_errors (fun () -> - Fiber.all_concurrently_unit [ print 1; fail ]) + Fiber.collect_errors (fun () -> Fiber.all_concurrently_unit [ print 1; fail ]) in match res with | Error [ { exn = Exit; _ } ] -> printfn "successfully caught error" @@ -1023,11 +1077,11 @@ let%expect_test "all_concurrently_unit" = | Error _ -> assert false in printf "multi element list"); - [%expect - {| + [%expect {| print: 1 successfully caught error multi element list |}] +;; let%expect_test "cancel_test1" = let cancel = Fiber.Cancel.create () in @@ -1039,6 +1093,7 @@ let%expect_test "cancel_test1" = [%expect {| false true |}] +;; let%expect_test "cancel_test2" = let cancel = Fiber.Cancel.create () in @@ -1046,7 +1101,8 @@ let%expect_test "cancel_test2" = let ivar2 = Fiber.Ivar.create () in let (), what = Scheduler.run - (Fiber.Cancel.with_handler cancel + (Fiber.Cancel.with_handler + cancel (fun () -> let* () = Fiber.Ivar.fill ivar1 () in let* () = Fiber.Cancel.fire cancel in @@ -1055,40 +1111,45 @@ let%expect_test "cancel_test2" = in print_endline (match what with - | Cancelled () -> "PASS" - | Not_cancelled -> "FAIL"); + | Cancelled () -> "PASS" + | Not_cancelled -> "FAIL"); [%expect {| PASS |}] +;; let%expect_test "cancel_test3" = let cancel = Fiber.Cancel.create () in let (), what = Scheduler.run - (Fiber.Cancel.with_handler cancel + (Fiber.Cancel.with_handler + cancel (fun () -> Fiber.return ()) ~on_cancel:(fun () -> assert false)) in print_endline (match what with - | Cancelled () -> "FAIL" - | Not_cancelled -> "PASS"); + | Cancelled () -> "FAIL" + | Not_cancelled -> "PASS"); [%expect {| PASS |}] +;; let%expect_test "cancel_test4" = let cancel = Fiber.Cancel.create () in let (), what = Scheduler.run (let* () = Fiber.Cancel.fire cancel in - Fiber.Cancel.with_handler cancel + Fiber.Cancel.with_handler + cancel (fun () -> Fiber.return ()) ~on_cancel:(fun () -> Fiber.return ())) in print_endline (match what with - | Cancelled () -> "PASS" - | Not_cancelled -> "FAIL"); + | Cancelled () -> "PASS" + | Not_cancelled -> "FAIL"); [%expect {| PASS |}] +;; let%expect_test "svar" = let module Svar = Fiber.Svar in @@ -1115,3 +1176,4 @@ let%expect_test "svar" = waiter: waiting for value > 15 setter: modifying value to 17 wait: 17 |}] +;; diff --git a/fiber/test/test_scheduler.ml b/fiber/test/test_scheduler.ml index d0948ea..83c7b7e 100644 --- a/fiber/test/test_scheduler.ml +++ b/fiber/test/test_scheduler.ml @@ -1,30 +1,34 @@ open Stdune type job = Job : (unit -> 'a) * 'a Fiber.Ivar.t -> job - type t = job Queue.t let create () : t = Queue.create () let yield t = - Fiber.of_thunk @@ fun () -> + Fiber.of_thunk + @@ fun () -> let ivar = Fiber.Ivar.create () in Queue.push t (Job ((fun () -> ()), ivar)); Fiber.Ivar.read ivar +;; let yield_gen (t : t) ~do_in_scheduler = - Fiber.of_thunk @@ fun () -> + Fiber.of_thunk + @@ fun () -> let ivar = Fiber.Ivar.create () in Queue.push t (Job (do_in_scheduler, ivar)); Fiber.Ivar.read ivar +;; exception Never let run (t : t) fiber = Queue.clear t; Fiber.run fiber ~iter:(fun () -> - match Queue.pop t with - | None -> raise Never - | Some (Job (job, ivar)) -> - let v = job () in - [ Fiber.Fill (ivar, v) ]) + match Queue.pop t with + | None -> raise Never + | Some (Job (job, ivar)) -> + let v = job () in + [ Fiber.Fill (ivar, v) ]) +;; diff --git a/fiber/test/test_scheduler.mli b/fiber/test/test_scheduler.mli index 405484b..24f5f7e 100644 --- a/fiber/test/test_scheduler.mli +++ b/fiber/test/test_scheduler.mli @@ -5,9 +5,6 @@ type t exception Never val create : unit -> t - val yield : t -> unit Fiber.t - val yield_gen : t -> do_in_scheduler:(unit -> 'a) -> 'a Fiber.t - val run : t -> 'a Fiber.t -> 'a diff --git a/flake.nix b/flake.nix index b5281c2..dd23ebf 100644 --- a/flake.nix +++ b/flake.nix @@ -32,7 +32,7 @@ }; devShells.default = pkgs.mkShell { inputsFrom = pkgs.lib.attrValues packages; - buildInputs = with pkgs.ocamlPackages; [ ocaml-lsp pkgs.ocamlformat_0_24_1 ]; + buildInputs = with pkgs.ocamlPackages; [ ocaml-lsp pkgs.ocamlformat_0_26_1 ]; }; }); }