Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transform JuvixReg into SSA form #2646

Merged
merged 7 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 27 additions & 27 deletions src/Juvix/Compiler/Asm/Extra/Recursors.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ import Juvix.Compiler.Asm.Pretty
data RecursorSig m r a = RecursorSig
{ _recursorInfoTable :: InfoTable,
_recurseInstr :: m -> CmdInstr -> Sem r a,
_recurseBranch :: m -> CmdBranch -> [a] -> [a] -> Sem r a,
_recurseCase :: m -> CmdCase -> [[a]] -> Maybe [a] -> Sem r a,
_recurseBranch :: Bool -> m -> CmdBranch -> [a] -> [a] -> Sem r a,
_recurseCase :: Bool -> m -> CmdCase -> [[a]] -> Maybe [a] -> Sem r a,
_recurseSave :: m -> CmdSave -> [a] -> Sem r a
}

Expand Down Expand Up @@ -252,7 +252,7 @@ recurse' sig = go True
let mem0 = popValueStack 1 mem
(mem1, as1) <- go isTail mem0 _cmdBranchTrue
(mem2, as2) <- go isTail mem0 _cmdBranchFalse
a' <- (sig ^. recurseBranch) mem cmd as1 as2
a' <- (sig ^. recurseBranch) isTail mem cmd as1 as2
mem' <- unifyMemory' loc (sig ^. recursorInfoTable) mem1 mem2
checkBranchInvariant 1 loc mem0 mem'
return (mem', a')
Expand All @@ -268,7 +268,7 @@ recurse' sig = go True
rd <- maybe (return Nothing) (fmap Just . go isTail mem) _cmdCaseDefault
let md = fmap fst rd
ad = fmap snd rd
a' <- (sig ^. recurseCase) mem cmd ass ad
a' <- (sig ^. recurseCase) isTail mem cmd ass ad
case mems of
[] -> return (fromMaybe mem md, a')
mem0 : mems' -> do
Expand Down Expand Up @@ -333,25 +333,25 @@ recurseS :: forall r a. (Member (Error AsmError) r) => RecursorSig StackInfo r a
recurseS sig code = snd <$> recurseS' sig initialStackInfo code

recurseS' :: forall r a. (Member (Error AsmError) r) => RecursorSig StackInfo r a -> StackInfo -> Code -> Sem r (StackInfo, [a])
recurseS' sig = go
recurseS' sig = go True
where
go :: StackInfo -> Code -> Sem r (StackInfo, [a])
go si = \case
go :: Bool -> StackInfo -> Code -> Sem r (StackInfo, [a])
go isTail si = \case
[] -> return (si, [])
h : t -> case h of
Instr x -> do
goNextCmd (goInstr si x) t
goNextCmd isTail (goInstr si x) t
Branch x ->
goNextCmd (goBranch si x) t
goNextCmd isTail (goBranch (isTail && null t) si x) t
Case x ->
goNextCmd (goCase si x) t
goNextCmd isTail (goCase (isTail && null t) si x) t
Save x ->
goNextCmd (goSave si x) t
goNextCmd isTail (goSave si x) t

goNextCmd :: Sem r (StackInfo, a) -> Code -> Sem r (StackInfo, [a])
goNextCmd mp t = do
goNextCmd :: Bool -> Sem r (StackInfo, a) -> Code -> Sem r (StackInfo, [a])
goNextCmd isTail mp t = do
(si', r) <- mp
(si'', rs) <- go si' t
(si'', rs) <- go isTail si' t
return (si'', r : rs)

goInstr :: StackInfo -> CmdInstr -> Sem r (StackInfo, a)
Expand Down Expand Up @@ -433,26 +433,26 @@ recurseS' sig = go
fixStackCallClosures si InstrCallClosures {..} = do
return $ stackInfoPopValueStack _callClosuresArgsNum si

goBranch :: StackInfo -> CmdBranch -> Sem r (StackInfo, a)
goBranch si cmd@CmdBranch {..} = do
goBranch :: Bool -> StackInfo -> CmdBranch -> Sem r (StackInfo, a)
goBranch isTail si cmd@CmdBranch {..} = do
let si0 = stackInfoPopValueStack 1 si
(si1, as1) <- go si0 _cmdBranchTrue
(si2, as2) <- go si0 _cmdBranchFalse
a' <- (sig ^. recurseBranch) si cmd as1 as2
(si1, as1) <- go isTail si0 _cmdBranchTrue
(si2, as2) <- go isTail si0 _cmdBranchFalse
a' <- (sig ^. recurseBranch) isTail si cmd as1 as2
checkStackInfo loc si1 si2
return (si1, a')
where
loc = cmd ^. cmdBranchInfo . commandInfoLocation

goCase :: StackInfo -> CmdCase -> Sem r (StackInfo, a)
goCase si cmd@CmdCase {..} = do
rs <- mapM (go si . (^. caseBranchCode)) _cmdCaseBranches
goCase :: Bool -> StackInfo -> CmdCase -> Sem r (StackInfo, a)
goCase isTail si cmd@CmdCase {..} = do
rs <- mapM (go isTail si . (^. caseBranchCode)) _cmdCaseBranches
let sis = map fst rs
ass = map snd rs
rd <- maybe (return Nothing) (fmap Just . go si) _cmdCaseDefault
rd <- maybe (return Nothing) (fmap Just . go isTail si) _cmdCaseDefault
let sd = fmap fst rd
ad = fmap snd rd
a' <- (sig ^. recurseCase) si cmd ass ad
a' <- (sig ^. recurseCase) isTail si cmd ass ad
case sis of
[] -> return (fromMaybe si sd, a')
si0 : sis' -> do
Expand All @@ -465,7 +465,7 @@ recurseS' sig = go
goSave :: StackInfo -> CmdSave -> Sem r (StackInfo, a)
goSave si cmd@CmdSave {..} = do
let si1 = stackInfoPushTempStack 1 (stackInfoPopValueStack 1 si)
(si2, c) <- go si1 _cmdSaveCode
(si2, c) <- go _cmdSaveIsTail si1 _cmdSaveCode
c' <- (sig ^. recurseSave) si cmd c
let si' = if _cmdSaveIsTail then si2 else stackInfoPopTempStack 1 si2
return (si', c')
Expand Down Expand Up @@ -528,15 +528,15 @@ foldS' sig si code acc = do
RecursorSig
{ _recursorInfoTable = sig ^. foldInfoTable,
_recurseInstr = \s cmd -> return ((sig ^. foldInstr) s cmd),
_recurseBranch = \s cmd br1 br2 ->
_recurseBranch = \_ s cmd br1 br2 ->
return
( \a -> do
let a' = (sig ^. foldAdjust) a
a1 <- compose' br1 a'
a2 <- compose' br2 a'
(sig ^. foldBranch) s cmd a1 a2 a
),
_recurseCase = \s cmd brs md ->
_recurseCase = \_ s cmd brs md ->
return
( \a -> do
let a' = (sig ^. foldAdjust) a
Expand Down
4 changes: 2 additions & 2 deletions src/Juvix/Compiler/Asm/Transformation/StackUsage.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@ computeFunctionStackUsage tab fi = do
RecursorSig
{ _recursorInfoTable = tab,
_recurseInstr = \si _ -> return (si ^. stackInfoValueStackHeight, si ^. stackInfoTempStackHeight),
_recurseBranch = \si _ l r ->
_recurseBranch = \_ si _ l r ->
return
( max (si ^. stackInfoValueStackHeight) (max (maximum (map fst l)) (maximum (map fst r))),
max (si ^. stackInfoTempStackHeight) (max (maximum (map snd l)) (maximum (map snd r)))
),
_recurseCase = \si _ cs md ->
_recurseCase = \_ si _ cs md ->
return
( max (si ^. stackInfoValueStackHeight) (max (maximum (map (maximum . map fst) cs)) (maybe 0 (maximum . map fst) md)),
max (si ^. stackInfoTempStackHeight) (max (maximum (map (maximum . map snd) cs)) (maybe 0 (maximum . map snd) md))
Expand Down
4 changes: 2 additions & 2 deletions src/Juvix/Compiler/Asm/Transformation/Validate.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ validateCode tab fi code = do
RecursorSig
{ _recursorInfoTable = tab,
_recurseInstr = \_ _ -> return (),
_recurseBranch = \_ _ _ _ -> return (),
_recurseCase = \_ _ _ _ -> return (),
_recurseBranch = \_ _ _ _ _ -> return (),
_recurseCase = \_ _ _ _ _ -> return (),
_recurseSave = \_ _ _ -> return ()
}

Expand Down
58 changes: 58 additions & 0 deletions src/Juvix/Compiler/Reg/Data/IndexMap.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
module Juvix.Compiler.Reg.Data.IndexMap where

import Data.HashMap.Strict qualified as HashMap
import Juvix.Compiler.Reg.Language.Base hiding (lookup)

data IndexMap k = IndexMap
{ _indexMapFirstFree :: Int,
_indexMapTable :: HashMap k Index
}

makeLenses ''IndexMap

instance (Hashable k) => Semigroup (IndexMap k) where
m1 <> m2 =
IndexMap
{ _indexMapTable = m1 ^. indexMapTable <> m2 ^. indexMapTable,
_indexMapFirstFree = max (m1 ^. indexMapFirstFree) (m2 ^. indexMapFirstFree)
}

instance (Hashable k) => Monoid (IndexMap k) where
mempty =
IndexMap
{ _indexMapFirstFree = 0,
_indexMapTable = mempty
}

assign :: (Hashable k) => IndexMap k -> k -> (Index, IndexMap k)
assign IndexMap {..} k =
( _indexMapFirstFree,
IndexMap
{ _indexMapFirstFree = _indexMapFirstFree + 1,
_indexMapTable = HashMap.insert k _indexMapFirstFree _indexMapTable
}
)

lookup' :: (Hashable k) => IndexMap k -> k -> Maybe Index
lookup' IndexMap {..} k = HashMap.lookup k _indexMapTable

lookup :: (Hashable k) => IndexMap k -> k -> Index
lookup mp = fromJust . lookup' mp

combine :: forall k. (Hashable k) => IndexMap k -> IndexMap k -> IndexMap k
combine mp1 mp2 =
IndexMap
{ _indexMapFirstFree = max (mp1 ^. indexMapFirstFree) (mp2 ^. indexMapFirstFree),
_indexMapTable = mp
}
where
mp =
foldr
(\k -> HashMap.update (checkVal k) k)
(HashMap.intersection (mp1 ^. indexMapTable) (mp2 ^. indexMapTable))
(HashMap.keys (mp2 ^. indexMapTable))

checkVal :: k -> Index -> Maybe Index
checkVal k idx
| lookup mp2 k == idx = Just idx
| otherwise = Nothing
4 changes: 3 additions & 1 deletion src/Juvix/Compiler/Reg/Data/TransformationId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import Juvix.Prelude

data TransformationId
= Identity
| SSA
deriving stock (Data, Bounded, Enum, Show)

data PipelineId
Expand All @@ -19,12 +20,13 @@ toCTransformations :: [TransformationId]
toCTransformations = []

toCairoTransformations :: [TransformationId]
toCairoTransformations = []
toCairoTransformations = [SSA]

instance TransformationId' TransformationId where
transformationText :: TransformationId -> Text
transformationText = \case
Identity -> strIdentity
SSA -> strSSA

instance PipelineId' TransformationId PipelineId where
pipelineText :: PipelineId -> Text
Expand Down
3 changes: 3 additions & 0 deletions src/Juvix/Compiler/Reg/Data/TransformationId/Strings.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ strCairoPipeline = "pipeline-cairo"

strIdentity :: Text
strIdentity = "identity"

strSSA :: Text
strSSA = "ssa"
10 changes: 10 additions & 0 deletions src/Juvix/Compiler/Reg/Extra.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module Juvix.Compiler.Reg.Extra
( module Juvix.Compiler.Reg.Extra.Base,
module Juvix.Compiler.Reg.Extra.Recursors,
module Juvix.Compiler.Reg.Extra.Info,
)
where

import Juvix.Compiler.Reg.Extra.Base
import Juvix.Compiler.Reg.Extra.Info
import Juvix.Compiler.Reg.Extra.Recursors
Loading
Loading