Skip to content

Commit

Permalink
Merge pull request #180 from MPLLang/auto-par-management
Browse files Browse the repository at this point in the history
(WIP) merge: automatic parallelism management
  • Loading branch information
shwestrick authored Feb 19, 2024
2 parents b59d9e6 + 32b51ff commit e5950f3
Show file tree
Hide file tree
Showing 142 changed files with 7,003 additions and 791 deletions.
3 changes: 2 additions & 1 deletion basis-library/fork-join.mlb
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
schedulers/shh.mlb
(* schedulers/shh.mlb *)
schedulers/pcall.mlb
1 change: 1 addition & 0 deletions basis-library/libs/basis-extra/top-level/basis.sig
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,7 @@ signature BASIS_EXTRA =
where type MLton.Pointer.t = MLton.Pointer.t
where type 'a MLton.Thread.t = 'a MLton.Thread.t
where type MLton.Thread.Runnable.t = MLton.Thread.Runnable.t
where type MLton.Thread.Basic.t = MLton.Thread.Basic.t

(* Types that must be exposed because constants denote them. *)
where type FixedInt.int = FixedInt.int
Expand Down
1 change: 1 addition & 0 deletions basis-library/mlton/signal.sig
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ signature MLTON_SIGNAL =
val isDefault: t -> bool
val isIgnore: t -> bool
val simple: (unit -> unit) -> t
val inspectInterrupted: (MLtonThread.Basic.t -> unit) -> t
end

structure Mask:
Expand Down
115 changes: 85 additions & 30 deletions basis-library/mlton/signal.sml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@
structure MLtonSignal: MLTON_SIGNAL_EXTRA =
struct

fun die (s: string): 'a =
(PrimitiveFFI.Stdio.print s
; PrimitiveFFI.Posix.Process.exit 1
; let exception DieFailed
in raise DieFailed
end)

open Posix.Signal
structure Prim = PrimitiveFFI.Posix.Signal
structure GCState = Primitive.MLton.GCState
Expand Down Expand Up @@ -88,7 +95,9 @@ structure Handler =
struct
datatype t =
Default
| Handler of MLtonThread.Runnable.t -> MLtonThread.Runnable.t
(* | Handler of MLtonThread.Runnable.t -> MLtonThread.Runnable.t *)
(* | Handler of unit -> unit *)
| Handler of MLtonThread.Basic.t -> unit
| Ignore
| InvalidSignal
end
Expand Down Expand Up @@ -143,7 +152,7 @@ structure Handler =
val isDefault = fn Default => true | _ => false
val isIgnore = fn Ignore => true | _ => false

val handler =
val handler_ =
(* This let is used so that Thread.setHandler is only used if
* Handler.handler is used. This prevents threads from being part
* of every program.
Expand All @@ -166,37 +175,82 @@ structure Handler =
fn () => Mask.setBlocked m
end)

val () =
MLtonThread.setSignalHandler
(fn t =>
let
val mask = Mask.getBlocked ()
val () = Mask.block (handled ())
val fs =
case !gcHandler of
Handler f => if Prim.isPendingGC (GCState.gcState ()) <> C_Int.zero
then [f]
else []
| _ => []
val fs =
Array.foldri
(fn (s, h, fs) =>
case h of
Handler f =>
if Prim.isPending (GCState.gcState (), repFromInt s) <> C_Int.zero
then f::fs
else fs
| _ => fs) fs handlers
val () = Prim.resetPending (GCState.gcState ())
val () = Mask.setBlocked mask
in
List.foldl (fn (f, t) => f t) t fs
end)

fun normalHandler t =
let
val mask = Mask.getBlocked ()
val () = Mask.block (handled ())
val fs =
case !gcHandler of
Handler f => if Prim.isPendingGC (GCState.gcState ()) <> C_Int.zero
then [f]
else []
| _ => []
val fs =
Array.foldri
(fn (s, h, fs) =>
case h of
Handler f =>
if Prim.isPending (GCState.gcState (), repFromInt s) <> C_Int.zero
then f::fs
else fs
| _ => fs) fs handlers
val () = Prim.resetPending (GCState.gcState ())
val () = Mask.setBlocked mask
in
(* List.foldl (fn (f, t) => f t) t fs *)
List.app (fn f => f t) fs
end


(* SAM_NOTE: TODO: Check with Matthew. This is probably subtly
* incorrect.
*
* The goal is to do a "fast path" for the signal handling code.
* The common case is that just a single signal handler needs to
* run, and we should be able to find this handler without any
* intermediate allocation if possible.
*)
fun maybeFastHandler t =
let
val n = Array.length handlers
fun loop f count i =
if i >= n then (f, count)
else
case Array.sub (handlers, i) of
Handler g =>
if Prim.isPending (GCState.gcState (), repFromInt i) <> C_Int.zero
then loop g (count+1) (i+1)
else loop f count (i+1)

| _ => loop f count (i+1)
in
case loop (fn t => ()) 0 0 of
(f, 1) =>
(* fast path succeeds if we find just a single handler that
* needs to be run.
*)
let
val _ = Prim.resetPending (GCState.gcState ())
in
f t
end

| _ => normalHandler t
end


val () = MLtonThread.setSimpleSignalHandler maybeFastHandler
(* val () = MLtonThread.setSimpleSignalHandler normalHandler *)
in
Handler
end

fun simple (f: unit -> unit) = handler (fn t => (f (); t))
fun handler _ = raise Fail "Signal.Handler.handler not supported"

fun simple (f: unit -> unit) = handler_ (fn t => f ())

fun inspectInterrupted f = handler_ f
end

val setHandler = fn (s, h) =>
Expand All @@ -222,7 +276,8 @@ fun suspend m =
; MLtonThread.switchToSignalHandler ())

fun handleGC f =
(Prim.handleGC (GCState.gcState ())
( ignore (die ("MLton.Signal.handleGC unsupported\n"))
; Prim.handleGC (GCState.gcState ())
; gcHandler := Handler.simple f)

end
27 changes: 25 additions & 2 deletions basis-library/mlton/thread.sig
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,31 @@ signature MLTON_THREAD =
val updateBytesPinnedEntangledWatermark: unit -> unit

(* "put a new thread in the hierarchy *)
val moveNewThreadToDepth : thread * int -> unit
val moveNewThreadToDepth : thread * Word64.word * int -> unit

val checkFinishedCCReadyToJoin: unit -> bool

val canForkThread: thread -> bool

(* second arg is joinpoint *)
(* val forkThread: thread * 'a ref -> Basic.p *)

val joinIntoParentBeforeFastClone:
{ thread: thread
, newDepth: int
, tidLeft: Word64.word
, tidRight: Word64.word
}
-> unit

val joinIntoParent:
{ thread: thread
, rightSideThread: thread
, newDepth: int
, tidLeft: Word64.word
, tidRight: Word64.word
}
-> unit
end

(* disentanglement checking *)
Expand All @@ -107,7 +129,7 @@ signature MLTON_THREAD =
val decheckGetTid : thread -> Word64.word

(* arguments are (victim thread, steal depth) *)
val copySyncDepthsFromThread : thread * int -> unit
val copySyncDepthsFromThread : thread * thread * int -> unit
end

type 'a t
Expand Down Expand Up @@ -151,6 +173,7 @@ signature MLTON_THREAD_EXTRA =
val amInSignalHandler: unit -> bool
val register: int * (MLtonPointer.t -> unit) -> unit
val setSignalHandler: (Runnable.t -> Runnable.t) -> unit
val setSimpleSignalHandler: (Basic.t -> unit) -> unit
val switchToSignalHandler: unit -> unit

val initPrimitive: unit t -> Runnable.t
Expand Down
Loading

0 comments on commit e5950f3

Please sign in to comment.