Skip to content

Commit

Permalink
Transform JuvixReg into SSA form (#2646)
Browse files Browse the repository at this point in the history
* Closes #2560 
* Adds a transformation of JuvixReg into SSA form.
* Adds an "output variable" field to branching instructions (`Case`,
`Branch`) which indicates the output variable to which the result is
assigned in both branches. The output variable corresponds to top of
stack in JuvixAsm after executing the branches. In the SSA
transformation, differently renamed output variables are unified by
inserting assignment instructions at the end of branches.
* Adds tests for the SSA transformation.
* Depends on #2641.
  • Loading branch information
lukaszcz authored Feb 20, 2024
1 parent 9a48f1f commit cb808c1
Show file tree
Hide file tree
Showing 21 changed files with 522 additions and 74 deletions.
54 changes: 27 additions & 27 deletions src/Juvix/Compiler/Asm/Extra/Recursors.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ import Juvix.Compiler.Asm.Pretty
data RecursorSig m r a = RecursorSig
{ _recursorInfoTable :: InfoTable,
_recurseInstr :: m -> CmdInstr -> Sem r a,
_recurseBranch :: m -> CmdBranch -> [a] -> [a] -> Sem r a,
_recurseCase :: m -> CmdCase -> [[a]] -> Maybe [a] -> Sem r a,
_recurseBranch :: Bool -> m -> CmdBranch -> [a] -> [a] -> Sem r a,
_recurseCase :: Bool -> m -> CmdCase -> [[a]] -> Maybe [a] -> Sem r a,
_recurseSave :: m -> CmdSave -> [a] -> Sem r a
}

Expand Down Expand Up @@ -252,7 +252,7 @@ recurse' sig = go True
let mem0 = popValueStack 1 mem
(mem1, as1) <- go isTail mem0 _cmdBranchTrue
(mem2, as2) <- go isTail mem0 _cmdBranchFalse
a' <- (sig ^. recurseBranch) mem cmd as1 as2
a' <- (sig ^. recurseBranch) isTail mem cmd as1 as2
mem' <- unifyMemory' loc (sig ^. recursorInfoTable) mem1 mem2
checkBranchInvariant 1 loc mem0 mem'
return (mem', a')
Expand All @@ -268,7 +268,7 @@ recurse' sig = go True
rd <- maybe (return Nothing) (fmap Just . go isTail mem) _cmdCaseDefault
let md = fmap fst rd
ad = fmap snd rd
a' <- (sig ^. recurseCase) mem cmd ass ad
a' <- (sig ^. recurseCase) isTail mem cmd ass ad
case mems of
[] -> return (fromMaybe mem md, a')
mem0 : mems' -> do
Expand Down Expand Up @@ -333,25 +333,25 @@ recurseS :: forall r a. (Member (Error AsmError) r) => RecursorSig StackInfo r a
recurseS sig code = snd <$> recurseS' sig initialStackInfo code

recurseS' :: forall r a. (Member (Error AsmError) r) => RecursorSig StackInfo r a -> StackInfo -> Code -> Sem r (StackInfo, [a])
recurseS' sig = go
recurseS' sig = go True
where
go :: StackInfo -> Code -> Sem r (StackInfo, [a])
go si = \case
go :: Bool -> StackInfo -> Code -> Sem r (StackInfo, [a])
go isTail si = \case
[] -> return (si, [])
h : t -> case h of
Instr x -> do
goNextCmd (goInstr si x) t
goNextCmd isTail (goInstr si x) t
Branch x ->
goNextCmd (goBranch si x) t
goNextCmd isTail (goBranch (isTail && null t) si x) t
Case x ->
goNextCmd (goCase si x) t
goNextCmd isTail (goCase (isTail && null t) si x) t
Save x ->
goNextCmd (goSave si x) t
goNextCmd isTail (goSave si x) t

goNextCmd :: Sem r (StackInfo, a) -> Code -> Sem r (StackInfo, [a])
goNextCmd mp t = do
goNextCmd :: Bool -> Sem r (StackInfo, a) -> Code -> Sem r (StackInfo, [a])
goNextCmd isTail mp t = do
(si', r) <- mp
(si'', rs) <- go si' t
(si'', rs) <- go isTail si' t
return (si'', r : rs)

goInstr :: StackInfo -> CmdInstr -> Sem r (StackInfo, a)
Expand Down Expand Up @@ -433,26 +433,26 @@ recurseS' sig = go
fixStackCallClosures si InstrCallClosures {..} = do
return $ stackInfoPopValueStack _callClosuresArgsNum si

goBranch :: StackInfo -> CmdBranch -> Sem r (StackInfo, a)
goBranch si cmd@CmdBranch {..} = do
goBranch :: Bool -> StackInfo -> CmdBranch -> Sem r (StackInfo, a)
goBranch isTail si cmd@CmdBranch {..} = do
let si0 = stackInfoPopValueStack 1 si
(si1, as1) <- go si0 _cmdBranchTrue
(si2, as2) <- go si0 _cmdBranchFalse
a' <- (sig ^. recurseBranch) si cmd as1 as2
(si1, as1) <- go isTail si0 _cmdBranchTrue
(si2, as2) <- go isTail si0 _cmdBranchFalse
a' <- (sig ^. recurseBranch) isTail si cmd as1 as2
checkStackInfo loc si1 si2
return (si1, a')
where
loc = cmd ^. cmdBranchInfo . commandInfoLocation

goCase :: StackInfo -> CmdCase -> Sem r (StackInfo, a)
goCase si cmd@CmdCase {..} = do
rs <- mapM (go si . (^. caseBranchCode)) _cmdCaseBranches
goCase :: Bool -> StackInfo -> CmdCase -> Sem r (StackInfo, a)
goCase isTail si cmd@CmdCase {..} = do
rs <- mapM (go isTail si . (^. caseBranchCode)) _cmdCaseBranches
let sis = map fst rs
ass = map snd rs
rd <- maybe (return Nothing) (fmap Just . go si) _cmdCaseDefault
rd <- maybe (return Nothing) (fmap Just . go isTail si) _cmdCaseDefault
let sd = fmap fst rd
ad = fmap snd rd
a' <- (sig ^. recurseCase) si cmd ass ad
a' <- (sig ^. recurseCase) isTail si cmd ass ad
case sis of
[] -> return (fromMaybe si sd, a')
si0 : sis' -> do
Expand All @@ -465,7 +465,7 @@ recurseS' sig = go
goSave :: StackInfo -> CmdSave -> Sem r (StackInfo, a)
goSave si cmd@CmdSave {..} = do
let si1 = stackInfoPushTempStack 1 (stackInfoPopValueStack 1 si)
(si2, c) <- go si1 _cmdSaveCode
(si2, c) <- go _cmdSaveIsTail si1 _cmdSaveCode
c' <- (sig ^. recurseSave) si cmd c
let si' = if _cmdSaveIsTail then si2 else stackInfoPopTempStack 1 si2
return (si', c')
Expand Down Expand Up @@ -528,15 +528,15 @@ foldS' sig si code acc = do
RecursorSig
{ _recursorInfoTable = sig ^. foldInfoTable,
_recurseInstr = \s cmd -> return ((sig ^. foldInstr) s cmd),
_recurseBranch = \s cmd br1 br2 ->
_recurseBranch = \_ s cmd br1 br2 ->
return
( \a -> do
let a' = (sig ^. foldAdjust) a
a1 <- compose' br1 a'
a2 <- compose' br2 a'
(sig ^. foldBranch) s cmd a1 a2 a
),
_recurseCase = \s cmd brs md ->
_recurseCase = \_ s cmd brs md ->
return
( \a -> do
let a' = (sig ^. foldAdjust) a
Expand Down
4 changes: 2 additions & 2 deletions src/Juvix/Compiler/Asm/Transformation/StackUsage.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ computeFunctionStackUsage tab fi = do
RecursorSig
{ _recursorInfoTable = tab,
_recurseInstr = \si _ -> return (si ^. stackInfoValueStackHeight, si ^. stackInfoTempStackHeight),
_recurseBranch = \si _ l r ->
_recurseBranch = \_ si _ l r ->
return
( max (si ^. stackInfoValueStackHeight) (max (maximum (map fst l)) (maximum (map fst r))),
max (si ^. stackInfoTempStackHeight) (max (maximum (map snd l)) (maximum (map snd r)))
),
_recurseCase = \si _ cs md ->
_recurseCase = \_ si _ cs md ->
return
( max (si ^. stackInfoValueStackHeight) (max (maximum (map (maximum . map fst) cs)) (maybe 0 (maximum . map fst) md)),
max (si ^. stackInfoTempStackHeight) (max (maximum (map (maximum . map snd) cs)) (maybe 0 (maximum . map snd) md))
Expand Down
4 changes: 2 additions & 2 deletions src/Juvix/Compiler/Asm/Transformation/Validate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ validateCode tab fi code = do
RecursorSig
{ _recursorInfoTable = tab,
_recurseInstr = \_ _ -> return (),
_recurseBranch = \_ _ _ _ -> return (),
_recurseCase = \_ _ _ _ -> return (),
_recurseBranch = \_ _ _ _ _ -> return (),
_recurseCase = \_ _ _ _ _ -> return (),
_recurseSave = \_ _ _ -> return ()
}

Expand Down
58 changes: 58 additions & 0 deletions src/Juvix/Compiler/Reg/Data/IndexMap.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
module Juvix.Compiler.Reg.Data.IndexMap where

import Data.HashMap.Strict qualified as HashMap
import Juvix.Compiler.Reg.Language.Base hiding (lookup)

data IndexMap k = IndexMap
{ _indexMapFirstFree :: Int,
_indexMapTable :: HashMap k Index
}

makeLenses ''IndexMap

instance (Hashable k) => Semigroup (IndexMap k) where
m1 <> m2 =
IndexMap
{ _indexMapTable = m1 ^. indexMapTable <> m2 ^. indexMapTable,
_indexMapFirstFree = max (m1 ^. indexMapFirstFree) (m2 ^. indexMapFirstFree)
}

instance (Hashable k) => Monoid (IndexMap k) where
mempty =
IndexMap
{ _indexMapFirstFree = 0,
_indexMapTable = mempty
}

assign :: (Hashable k) => IndexMap k -> k -> (Index, IndexMap k)
assign IndexMap {..} k =
( _indexMapFirstFree,
IndexMap
{ _indexMapFirstFree = _indexMapFirstFree + 1,
_indexMapTable = HashMap.insert k _indexMapFirstFree _indexMapTable
}
)

lookup' :: (Hashable k) => IndexMap k -> k -> Maybe Index
lookup' IndexMap {..} k = HashMap.lookup k _indexMapTable

lookup :: (Hashable k) => IndexMap k -> k -> Index
lookup mp = fromJust . lookup' mp

combine :: forall k. (Hashable k) => IndexMap k -> IndexMap k -> IndexMap k
combine mp1 mp2 =
IndexMap
{ _indexMapFirstFree = max (mp1 ^. indexMapFirstFree) (mp2 ^. indexMapFirstFree),
_indexMapTable = mp
}
where
mp =
foldr
(\k -> HashMap.update (checkVal k) k)
(HashMap.intersection (mp1 ^. indexMapTable) (mp2 ^. indexMapTable))
(HashMap.keys (mp2 ^. indexMapTable))

checkVal :: k -> Index -> Maybe Index
checkVal k idx
| lookup mp2 k == idx = Just idx
| otherwise = Nothing
4 changes: 3 additions & 1 deletion src/Juvix/Compiler/Reg/Data/TransformationId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import Juvix.Prelude

data TransformationId
= Identity
| SSA
deriving stock (Data, Bounded, Enum, Show)

data PipelineId
Expand All @@ -19,12 +20,13 @@ toCTransformations :: [TransformationId]
toCTransformations = []

toCairoTransformations :: [TransformationId]
toCairoTransformations = []
toCairoTransformations = [SSA]

instance TransformationId' TransformationId where
transformationText :: TransformationId -> Text
transformationText = \case
Identity -> strIdentity
SSA -> strSSA

instance PipelineId' TransformationId PipelineId where
pipelineText :: PipelineId -> Text
Expand Down
3 changes: 3 additions & 0 deletions src/Juvix/Compiler/Reg/Data/TransformationId/Strings.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ strCairoPipeline = "pipeline-cairo"

strIdentity :: Text
strIdentity = "identity"

strSSA :: Text
strSSA = "ssa"
10 changes: 10 additions & 0 deletions src/Juvix/Compiler/Reg/Extra.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module Juvix.Compiler.Reg.Extra
( module Juvix.Compiler.Reg.Extra.Base,
module Juvix.Compiler.Reg.Extra.Recursors,
module Juvix.Compiler.Reg.Extra.Info,
)
where

import Juvix.Compiler.Reg.Extra.Base
import Juvix.Compiler.Reg.Extra.Info
import Juvix.Compiler.Reg.Extra.Recursors
Loading

0 comments on commit cb808c1

Please sign in to comment.