Skip to content

Commit

Permalink
reachability
Browse files Browse the repository at this point in the history
  • Loading branch information
lukaszcz committed Jan 25, 2024
1 parent f3848aa commit 4fe9538
Show file tree
Hide file tree
Showing 12 changed files with 184 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Asm/Pipeline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ toReg' = validate >=> filterUnreachable >=> computeStackUsage >=> computePreallo
-- | Perform transformations on JuvixAsm necessary before the translation to
-- Nockma
toNockma' :: (Members '[Error AsmError, Reader Options] r) => InfoTable -> Sem r InfoTable
toNockma' = validate >=> filterUnreachable
toNockma' = validate

toReg :: (Members '[Error JuvixError, Reader EntryPoint] r) => InfoTable -> Sem r InfoTable
toReg = mapReader fromEntryPoint . mapError (JuvixError @AsmError) . toReg'
Expand Down
35 changes: 35 additions & 0 deletions src/Juvix/Compiler/Tree/Data/CallGraph.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
module Juvix.Compiler.Tree.Data.CallGraph where

import Data.HashSet qualified as HashSet
import Juvix.Compiler.Tree.Data.InfoTable
import Juvix.Compiler.Tree.Extra.Recursors

-- | Call graph type
type CallGraph = DependencyInfo Symbol

-- | Compute the call graph
createCallGraph :: InfoTable -> CallGraph
createCallGraph tab = createDependencyInfo (createCallGraphMap tab) startVertices
where
startVertices :: HashSet Symbol
startVertices = HashSet.fromList syms

syms :: [Symbol]
syms = maybe [] singleton (tab ^. infoMainFunction)

createCallGraphMap :: InfoTable -> HashMap Symbol (HashSet Symbol)
createCallGraphMap tab = fmap (getFunSymbols . (^. functionCode)) (tab ^. infoFunctions)

getFunSymbols :: Node -> HashSet Symbol
getFunSymbols = gather go mempty
where
go :: HashSet Symbol -> Node -> HashSet Symbol
go syms = \case
AllocClosure NodeAllocClosure {..} -> HashSet.insert _nodeAllocClosureFunSymbol syms
Call NodeCall {..} -> goCallType syms _nodeCallType
_ -> syms

goCallType :: HashSet Symbol -> CallType -> HashSet Symbol
goCallType syms = \case
CallFun sym -> HashSet.insert sym syms
CallClosure {} -> syms
4 changes: 3 additions & 1 deletion src/Juvix/Compiler/Tree/Data/TransformationId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ data TransformationId
| IdentityD
| Apply
| TempHeight
| FilterUnreachable
deriving stock (Data, Bounded, Enum, Show)

data PipelineId
Expand All @@ -20,7 +21,7 @@ data PipelineId
type TransformationLikeId = TransformationLikeId' TransformationId PipelineId

toNockmaTransformations :: [TransformationId]
toNockmaTransformations = [Apply, TempHeight]
toNockmaTransformations = [Apply, FilterUnreachable, TempHeight]

toAsmTransformations :: [TransformationId]
toAsmTransformations = []
Expand All @@ -33,6 +34,7 @@ instance TransformationId' TransformationId where
IdentityD -> strIdentityD
Apply -> strApply
TempHeight -> strTempHeight
FilterUnreachable -> strFilterUnreachable

instance PipelineId' TransformationId PipelineId where
pipelineText :: PipelineId -> Text
Expand Down
3 changes: 3 additions & 0 deletions src/Juvix/Compiler/Tree/Data/TransformationId/Strings.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ strApply = "apply"

strTempHeight :: Text
strTempHeight = "temp-height"

strFilterUnreachable :: Text
strFilterUnreachable = "filter-unreachable"
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Tree/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ where
import Juvix.Compiler.Tree.Data.TransformationId
import Juvix.Compiler.Tree.Transformation.Apply
import Juvix.Compiler.Tree.Transformation.Base
import Juvix.Compiler.Tree.Transformation.FilterUnreachable
import Juvix.Compiler.Tree.Transformation.Identity
import Juvix.Compiler.Tree.Transformation.TempHeight

Expand All @@ -21,3 +22,4 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
IdentityD -> return . identityD
Apply -> return . computeApply
TempHeight -> return . computeTempHeight
FilterUnreachable -> return . filterUnreachable
12 changes: 12 additions & 0 deletions src/Juvix/Compiler/Tree/Transformation/FilterUnreachable.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module Juvix.Compiler.Tree.Transformation.FilterUnreachable where

import Data.HashMap.Strict qualified as HashMap
import Juvix.Compiler.Tree.Data.CallGraph
import Juvix.Compiler.Tree.Data.InfoTable
import Juvix.Prelude

filterUnreachable :: InfoTable -> InfoTable
filterUnreachable tab =
over infoFunctions (HashMap.filterWithKey (const . isReachable graph)) tab
where
graph = createCallGraph tab
4 changes: 3 additions & 1 deletion test/Tree/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ module Tree.Transformation where
import Base
import Tree.Transformation.Apply qualified as Apply
import Tree.Transformation.Identity qualified as Identity
import Tree.Transformation.Reachability qualified as Reachability

allTests :: TestTree
allTests =
testGroup
"JuvixTree transformations"
[ Identity.allTests,
Apply.allTests
Apply.allTests,
Reachability.allTests
]
48 changes: 48 additions & 0 deletions test/Tree/Transformation/Reachability.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
module Tree.Transformation.Reachability (allTests) where

import Base
import Data.HashMap.Strict qualified as HashMap
import Juvix.Compiler.Tree.Transformation as Tree
import Tree.Eval.Positive qualified as Eval
import Tree.Transformation.Base

data ReachabilityTest = ReachabilityTest
{ _reachabilityTestReachable :: [Text],
_reachabilityTestEval :: Eval.PosTest
}

allTests :: TestTree
allTests =
testGroup "Reachability" $
map liftTest rtests

rtests :: [ReachabilityTest]
rtests =
[ ReachabilityTest
{ _reachabilityTestReachable = ["f", "f'", "g'", "h", "h'", "main"],
_reachabilityTestEval =
Eval.PosTest
"Test001: Reachability"
$(mkRelDir "reachability")
$(mkRelFile "test001.jvt")
$(mkRelFile "out/test001.out")
},
ReachabilityTest
{ _reachabilityTestReachable = ["f", "g", "id", "sum", "main"],
_reachabilityTestEval =
Eval.PosTest
"Test002: Reachability with loops & closures"
$(mkRelDir "reachability")
$(mkRelFile "test002.jvt")
$(mkRelFile "out/test002.out")
}
]

liftTest :: ReachabilityTest -> TestTree
liftTest ReachabilityTest {..} =
fromTest
Test
{ _testTransformations = [Tree.FilterUnreachable],
_testAssertion = \tab -> unless (nubSort (map (^. functionName) (HashMap.elems (tab ^. infoFunctions))) == nubSort _reachabilityTestReachable) (error "check reachable"),
_testEval = _reachabilityTestEval
}
1 change: 1 addition & 0 deletions tests/Tree/positive/reachability/out/test001.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
9
1 change: 1 addition & 0 deletions tests/Tree/positive/reachability/out/test002.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
5051
36 changes: 36 additions & 0 deletions tests/Tree/positive/reachability/test001.jvt
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@

function h(integer) : integer;
function h'(integer) : integer;
function f(integer) : integer;
function f'(integer) : integer;
function g(integer) : integer;
function g'(integer) : integer;
function main() : integer;

function h(integer) : integer {
arg[0]
}

function h'(integer) : integer {
arg[0]
}

function f(integer) : integer {
add(call[h](arg[0]), 1)
}

function f'(integer) : integer {
add(call[h'](arg[0]), 1)
}

function g(integer) : integer {
add(call[f](arg[0]), 2)
}

function g'(integer) : integer {
call[f'](arg[0])
}

function main() : integer {
call[g'](call[f](7))
}
39 changes: 39 additions & 0 deletions tests/Tree/positive/reachability/test002.jvt
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@

function f(*, integer) : integer;
function id(integer) : integer;
function g(integer) : integer;
function sum(integer) : integer;
function g'(integer) : integer;
function g''(integer) : integer;
function main() : integer;

function f(*, integer) : integer {
call(arg[0], arg[1])
}

function id(integer) : integer {
arg[0]
}

function g(integer) : integer {
add(call[f](calloc[id](), arg[0]), 1)
}

function sum(integer) : integer {
br(eq(0, arg[0])) {
true: call[g](0)
false: add(arg[0], call[sum](sub(arg[0], 1)))
}
}

function g'(integer) : integer {
add(call[id](arg[0]), 2)
}

function g''(integer) : integer {
call[sum](arg[0])
}

function main() : integer {
call[sum](100)
}

0 comments on commit 4fe9538

Please sign in to comment.