Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Filter out unreachable functions in JuvixTree #2597

Merged
merged 1 commit into from
Jan 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Asm/Pipeline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ toReg' = validate >=> filterUnreachable >=> computeStackUsage >=> computePreallo
-- | Perform transformations on JuvixAsm necessary before the translation to
-- Nockma
toNockma' :: (Members '[Error AsmError, Reader Options] r) => InfoTable -> Sem r InfoTable
toNockma' = validate >=> filterUnreachable
toNockma' = validate

toReg :: (Members '[Error JuvixError, Reader EntryPoint] r) => InfoTable -> Sem r InfoTable
toReg = mapReader fromEntryPoint . mapError (JuvixError @AsmError) . toReg'
Expand Down
35 changes: 35 additions & 0 deletions src/Juvix/Compiler/Tree/Data/CallGraph.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
module Juvix.Compiler.Tree.Data.CallGraph where

import Data.HashSet qualified as HashSet
import Juvix.Compiler.Tree.Data.InfoTable
import Juvix.Compiler.Tree.Extra.Recursors

-- | Call graph type
type CallGraph = DependencyInfo Symbol

-- | Compute the call graph
createCallGraph :: InfoTable -> CallGraph
createCallGraph tab = createDependencyInfo (createCallGraphMap tab) startVertices
where
startVertices :: HashSet Symbol
startVertices = HashSet.fromList syms

syms :: [Symbol]
syms = maybe [] singleton (tab ^. infoMainFunction)

createCallGraphMap :: InfoTable -> HashMap Symbol (HashSet Symbol)
createCallGraphMap tab = fmap (getFunSymbols . (^. functionCode)) (tab ^. infoFunctions)

getFunSymbols :: Node -> HashSet Symbol
getFunSymbols = gather go mempty
where
go :: HashSet Symbol -> Node -> HashSet Symbol
go syms = \case
AllocClosure NodeAllocClosure {..} -> HashSet.insert _nodeAllocClosureFunSymbol syms
Call NodeCall {..} -> goCallType syms _nodeCallType
_ -> syms

goCallType :: HashSet Symbol -> CallType -> HashSet Symbol
goCallType syms = \case
CallFun sym -> HashSet.insert sym syms
CallClosure {} -> syms
4 changes: 3 additions & 1 deletion src/Juvix/Compiler/Tree/Data/TransformationId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ data TransformationId
| IdentityD
| Apply
| TempHeight
| FilterUnreachable
deriving stock (Data, Bounded, Enum, Show)

data PipelineId
Expand All @@ -20,7 +21,7 @@ data PipelineId
type TransformationLikeId = TransformationLikeId' TransformationId PipelineId

toNockmaTransformations :: [TransformationId]
toNockmaTransformations = [Apply, TempHeight]
toNockmaTransformations = [Apply, FilterUnreachable, TempHeight]

toAsmTransformations :: [TransformationId]
toAsmTransformations = []
Expand All @@ -33,6 +34,7 @@ instance TransformationId' TransformationId where
IdentityD -> strIdentityD
Apply -> strApply
TempHeight -> strTempHeight
FilterUnreachable -> strFilterUnreachable

instance PipelineId' TransformationId PipelineId where
pipelineText :: PipelineId -> Text
Expand Down
3 changes: 3 additions & 0 deletions src/Juvix/Compiler/Tree/Data/TransformationId/Strings.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@ strApply = "apply"

strTempHeight :: Text
strTempHeight = "temp-height"

strFilterUnreachable :: Text
strFilterUnreachable = "filter-unreachable"
2 changes: 2 additions & 0 deletions src/Juvix/Compiler/Tree/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ where
import Juvix.Compiler.Tree.Data.TransformationId
import Juvix.Compiler.Tree.Transformation.Apply
import Juvix.Compiler.Tree.Transformation.Base
import Juvix.Compiler.Tree.Transformation.FilterUnreachable
import Juvix.Compiler.Tree.Transformation.Identity
import Juvix.Compiler.Tree.Transformation.TempHeight

Expand All @@ -21,3 +22,4 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
IdentityD -> return . identityD
Apply -> return . computeApply
TempHeight -> return . computeTempHeight
FilterUnreachable -> return . filterUnreachable
12 changes: 12 additions & 0 deletions src/Juvix/Compiler/Tree/Transformation/FilterUnreachable.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
module Juvix.Compiler.Tree.Transformation.FilterUnreachable where

import Data.HashMap.Strict qualified as HashMap
import Juvix.Compiler.Tree.Data.CallGraph
import Juvix.Compiler.Tree.Data.InfoTable
import Juvix.Prelude

filterUnreachable :: InfoTable -> InfoTable
filterUnreachable tab =
over infoFunctions (HashMap.filterWithKey (const . isReachable graph)) tab
where
graph = createCallGraph tab
4 changes: 3 additions & 1 deletion test/Tree/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ module Tree.Transformation where
import Base
import Tree.Transformation.Apply qualified as Apply
import Tree.Transformation.Identity qualified as Identity
import Tree.Transformation.Reachability qualified as Reachability

allTests :: TestTree
allTests =
testGroup
"JuvixTree transformations"
[ Identity.allTests,
Apply.allTests
Apply.allTests,
Reachability.allTests
]
48 changes: 48 additions & 0 deletions test/Tree/Transformation/Reachability.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
module Tree.Transformation.Reachability (allTests) where

import Base
import Data.HashMap.Strict qualified as HashMap
import Juvix.Compiler.Tree.Transformation as Tree
import Tree.Eval.Positive qualified as Eval
import Tree.Transformation.Base

data ReachabilityTest = ReachabilityTest
{ _reachabilityTestReachable :: [Text],
_reachabilityTestEval :: Eval.PosTest
}

allTests :: TestTree
allTests =
testGroup "Reachability" $
map liftTest rtests

rtests :: [ReachabilityTest]
rtests =
[ ReachabilityTest
{ _reachabilityTestReachable = ["f", "f'", "g'", "h", "h'", "main"],
_reachabilityTestEval =
Eval.PosTest
"Test001: Reachability"
$(mkRelDir "reachability")
$(mkRelFile "test001.jvt")
$(mkRelFile "out/test001.out")
},
ReachabilityTest
{ _reachabilityTestReachable = ["f", "g", "id", "sum", "main"],
_reachabilityTestEval =
Eval.PosTest
"Test002: Reachability with loops & closures"
$(mkRelDir "reachability")
$(mkRelFile "test002.jvt")
$(mkRelFile "out/test002.out")
}
]

liftTest :: ReachabilityTest -> TestTree
liftTest ReachabilityTest {..} =
fromTest
Test
{ _testTransformations = [Tree.FilterUnreachable],
_testAssertion = \tab -> unless (nubSort (map (^. functionName) (HashMap.elems (tab ^. infoFunctions))) == nubSort _reachabilityTestReachable) (error "check reachable"),
_testEval = _reachabilityTestEval
}
1 change: 1 addition & 0 deletions tests/Tree/positive/reachability/out/test001.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
9
1 change: 1 addition & 0 deletions tests/Tree/positive/reachability/out/test002.out
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
5051
36 changes: 36 additions & 0 deletions tests/Tree/positive/reachability/test001.jvt
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@

function h(integer) : integer;
function h'(integer) : integer;
function f(integer) : integer;
function f'(integer) : integer;
function g(integer) : integer;
function g'(integer) : integer;
function main() : integer;

function h(integer) : integer {
arg[0]
}

function h'(integer) : integer {
arg[0]
}

function f(integer) : integer {
add(call[h](arg[0]), 1)
}

function f'(integer) : integer {
add(call[h'](arg[0]), 1)
}

function g(integer) : integer {
add(call[f](arg[0]), 2)
}

function g'(integer) : integer {
call[f'](arg[0])
}

function main() : integer {
call[g'](call[f](7))
}
39 changes: 39 additions & 0 deletions tests/Tree/positive/reachability/test002.jvt
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@

function f(*, integer) : integer;
function id(integer) : integer;
function g(integer) : integer;
function sum(integer) : integer;
function g'(integer) : integer;
function g''(integer) : integer;
function main() : integer;

function f(*, integer) : integer {
call(arg[0], arg[1])
}

function id(integer) : integer {
arg[0]
}

function g(integer) : integer {
add(call[f](calloc[id](), arg[0]), 1)
}

function sum(integer) : integer {
br(eq(0, arg[0])) {
true: call[g](0)
false: add(arg[0], call[sum](sub(arg[0], 1)))
}
}

function g'(integer) : integer {
add(call[id](arg[0]), 2)
}

function g''(integer) : integer {
call[sum](arg[0])
}

function main() : integer {
call[sum](100)
}
Loading