Skip to content

Commit

Permalink
JuvixTree validation (#2616)
Browse files Browse the repository at this point in the history
* Validation (type checking) of JuvixTree. Similar to JuvixAsm
validation, will help with debugging.
* Depends on #2608
  • Loading branch information
lukaszcz authored Feb 6, 2024
1 parent 10e2a23 commit 795212b
Show file tree
Hide file tree
Showing 14 changed files with 412 additions and 112 deletions.
11 changes: 7 additions & 4 deletions app/Commands/Dev/Tree/Read.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@ runCommand opts = do
case Tree.runParser (toFilePath afile) s of
Left err -> exitJuvixError (JuvixError err)
Right tab -> do
tab' <- Tree.applyTransformations (project opts ^. treeReadTransformations) tab
unless (project opts ^. treeReadNoPrint) $
renderStdOut (Tree.ppOutDefault tab' tab')
doEval tab'
r <- runError @JuvixError (Tree.applyTransformations (project opts ^. treeReadTransformations) tab)
case r of
Left err -> exitJuvixError (JuvixError err)
Right tab' -> do
unless (project opts ^. treeReadNoPrint) $
renderStdOut (Tree.ppOutDefault tab' tab')
doEval tab'
where
file :: AppPath File
file = opts ^. treeReadInputFile
Expand Down
6 changes: 3 additions & 3 deletions src/Juvix/Compiler/Asm/Extra/Memory.hs
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,11 @@ unifyMemory' loc tab mem1 mem2 = do
unless (length (mem1 ^. memoryValueStack) == length (mem2 ^. memoryValueStack)) $
throw $
AsmError loc "value stack height mismatch"
vs <- zipWithM (unifyTypes' loc tab) (toList (mem1 ^. memoryValueStack)) (toList (mem2 ^. memoryValueStack))
vs <- zipWithM (unifyTypes'' loc tab) (toList (mem1 ^. memoryValueStack)) (toList (mem2 ^. memoryValueStack))
unless (length (mem1 ^. memoryTempStack) == length (mem2 ^. memoryTempStack)) $
throw $
AsmError loc "temporary stack height mismatch"
ts <- zipWithM (unifyTypes' loc tab) (toList (mem1 ^. memoryTempStack)) (toList (mem2 ^. memoryTempStack))
ts <- zipWithM (unifyTypes'' loc tab) (toList (mem1 ^. memoryTempStack)) (toList (mem2 ^. memoryTempStack))
unless
( length (mem1 ^. memoryArgumentArea) == length (mem2 ^. memoryArgumentArea)
&& mem1 ^. memoryArgsNum == mem2 ^. memoryArgsNum
Expand All @@ -183,7 +183,7 @@ unifyMemory' loc tab mem1 mem2 = do
args <-
mapM
( \off ->
unifyTypes'
unifyTypes''
loc
tab
(fromJust $ HashMap.lookup off (mem1 ^. memoryArgumentArea))
Expand Down
4 changes: 2 additions & 2 deletions src/Juvix/Compiler/Asm/Extra/Recursors.hs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ recurse' sig = go True
checkValueStack' loc (sig ^. recursorInfoTable) tyargs mem
tys <-
zipWithM
(\ty idx -> unifyTypes' loc (sig ^. recursorInfoTable) ty (topValueStack' idx mem))
(\ty idx -> unifyTypes'' loc (sig ^. recursorInfoTable) ty (topValueStack' idx mem))
tyargs
[0 ..]
return $
Expand Down Expand Up @@ -226,7 +226,7 @@ recurse' sig = go True
checkValueStack' loc (sig ^. recursorInfoTable) (take argsNum (typeArgs ty)) mem'
let tyargs = topValuesFromValueStack' argsNum mem'
-- `typeArgs ty` may be shorter than `tyargs` only if `ty` is dynamic
zipWithM_ (unifyTypes' loc (sig ^. recursorInfoTable)) tyargs (typeArgs ty)
zipWithM_ (unifyTypes'' loc (sig ^. recursorInfoTable)) tyargs (typeArgs ty)
return $
pushValueStack (mkTypeFun (drop argsNum (typeArgs ty)) (typeTarget ty)) $
popValueStack argsNum mem'
Expand Down
84 changes: 9 additions & 75 deletions src/Juvix/Compiler/Asm/Extra/Type.hs
Original file line number Diff line number Diff line change
Expand Up @@ -4,84 +4,18 @@ module Juvix.Compiler.Asm.Extra.Type
)
where

import Data.List.NonEmpty qualified as NonEmpty
import Juvix.Compiler.Asm.Data.InfoTable
import Juvix.Compiler.Asm.Error
import Juvix.Compiler.Asm.Language
import Juvix.Compiler.Asm.Pretty
import Juvix.Compiler.Tree.Error
import Juvix.Compiler.Tree.Extra.Type

unifyTypes :: forall r. (Members '[Error AsmError, Reader (Maybe Location), Reader InfoTable] r) => Type -> Type -> Sem r Type
unifyTypes ty1 ty2 = case (ty1, ty2) of
(TyDynamic, x) -> return x
(x, TyDynamic) -> return x
(TyInductive TypeInductive {..}, TyConstr TypeConstr {..})
| _typeInductiveSymbol == _typeConstrInductive ->
return ty1
(TyConstr {}, TyInductive {}) -> unifyTypes ty2 ty1
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive
&& c1 ^. typeConstrTag == c2 ^. typeConstrTag -> do
flds <- zipWithM unifyTypes (c1 ^. typeConstrFields) (c2 ^. typeConstrFields)
return $ TyConstr (set typeConstrFields flds c1)
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive ->
return $ TyInductive (TypeInductive (c1 ^. typeConstrInductive))
(TyFun t1, TyFun t2)
| length (t1 ^. typeFunArgs) == length (t2 ^. typeFunArgs) -> do
let args1 = toList (t1 ^. typeFunArgs)
args2 = toList (t2 ^. typeFunArgs)
tgt1 = t1 ^. typeFunTarget
tgt2 = t2 ^. typeFunTarget
args <- zipWithM unifyTypes args1 args2
tgt <- unifyTypes tgt1 tgt2
return $ TyFun (TypeFun (NonEmpty.fromList args) tgt)
(TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) ->
return $ TyInteger (TypeInteger (unifyBounds min l1 l2) (unifyBounds max u1 u2))
where
unifyBounds :: (Integer -> Integer -> Integer) -> Maybe Integer -> Maybe Integer -> Maybe Integer
unifyBounds _ Nothing _ = Nothing
unifyBounds _ _ Nothing = Nothing
unifyBounds f (Just x) (Just y) = Just (f x y)
(TyBool {}, TyBool {})
| ty1 == ty2 -> return ty1
(TyString, TyString) -> return TyString
(TyUnit, TyUnit) -> return TyUnit
(TyVoid, TyVoid) -> return TyVoid
(TyInductive {}, TyInductive {})
| ty1 == ty2 -> return ty1
(TyUnit, _) -> err
(_, TyUnit) -> err
(TyVoid, _) -> err
(_, TyVoid) -> err
(TyInteger {}, _) -> err
(_, TyInteger {}) -> err
(TyString, _) -> err
(_, TyString) -> err
(TyBool {}, _) -> err
(_, TyBool {}) -> err
(TyFun {}, _) -> err
(_, TyFun {}) -> err
(TyInductive {}, _) -> err
(_, TyConstr {}) -> err
unifyTypes'' :: forall t e r. (Member (Error AsmError) r) => Maybe Location -> InfoTable' t e -> Type -> Type -> Sem r Type
unifyTypes'' loc tab ty1 ty2 = mapError toAsmError $ unifyTypes' loc tab ty1 ty2
where
err :: Sem r a
err = do
loc <- ask
tab <- ask
throw $ AsmError loc ("not unifiable: " <> ppTrace tab ty1 <> ", " <> ppTrace tab ty2)

unifyTypes' :: (Member (Error AsmError) r) => Maybe Location -> InfoTable -> Type -> Type -> Sem r Type
unifyTypes' loc tab ty1 ty2 =
runReader loc $
runReader tab $
-- The `if` is to ensure correct behaviour with dynamic type targets. E.g.
-- `(A, B) -> *` should unify with `A -> B -> C -> D`.
if
| tgt1 == TyDynamic || tgt2 == TyDynamic ->
unifyTypes (curryType ty1) (curryType ty2)
| otherwise ->
unifyTypes ty1 ty2
where
tgt1 = typeTarget (uncurryType ty1)
tgt2 = typeTarget (uncurryType ty2)
toAsmError :: TreeError -> AsmError
toAsmError TreeError {..} =
AsmError
{ _asmErrorLoc = _treeErrorLoc,
_asmErrorMsg = _treeErrorMsg
}
2 changes: 1 addition & 1 deletion src/Juvix/Compiler/Pipeline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ coreToVampIR' = Core.toStored' >=> storedCoreToVampIR'
-- Other workflows
--------------------------------------------------------------------------------

treeToAsm :: Tree.InfoTable -> Sem r Asm.InfoTable
treeToAsm :: (Member (Error JuvixError) r) => Tree.InfoTable -> Sem r Asm.InfoTable
treeToAsm = Tree.toAsm >=> return . Asm.fromTree

treeToNockma :: (Members '[Error JuvixError, Reader EntryPoint] r) => Tree.InfoTable -> Sem r (Nockma.Cell Natural)
Expand Down
6 changes: 4 additions & 2 deletions src/Juvix/Compiler/Tree/Data/TransformationId.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ data TransformationId
| Apply
| TempHeight
| FilterUnreachable
| Validate
deriving stock (Data, Bounded, Enum, Show)

data PipelineId
Expand All @@ -21,10 +22,10 @@ data PipelineId
type TransformationLikeId = TransformationLikeId' TransformationId PipelineId

toNockmaTransformations :: [TransformationId]
toNockmaTransformations = [Apply, FilterUnreachable, TempHeight]
toNockmaTransformations = [Validate, Apply, FilterUnreachable, TempHeight]

toAsmTransformations :: [TransformationId]
toAsmTransformations = []
toAsmTransformations = [Validate]

instance TransformationId' TransformationId where
transformationText :: TransformationId -> Text
Expand All @@ -35,6 +36,7 @@ instance TransformationId' TransformationId where
Apply -> strApply
TempHeight -> strTempHeight
FilterUnreachable -> strFilterUnreachable
Validate -> strValidate

instance PipelineId' TransformationId PipelineId where
pipelineText :: PipelineId -> Text
Expand Down
3 changes: 3 additions & 0 deletions src/Juvix/Compiler/Tree/Data/TransformationId/Strings.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,6 @@ strTempHeight = "temp-height"

strFilterUnreachable :: Text
strFilterUnreachable = "filter-unreachable"

strValidate :: Text
strValidate = "validate"
85 changes: 84 additions & 1 deletion src/Juvix/Compiler/Tree/Extra/Type.hs
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-}

{-# HLINT ignore "Avoid restricted extensions" #-}
{-# HLINT ignore "Avoid restricted flags" #-}

module Juvix.Compiler.Tree.Extra.Type where

import Juvix.Compiler.Tree.Data.InfoTable.Base
import Juvix.Compiler.Tree.Error
import Juvix.Compiler.Tree.Language.Base
import Juvix.Compiler.Tree.Language.Type
import Juvix.Compiler.Tree.Pretty

mkTypeInteger :: Type
mkTypeInteger = TyInteger (TypeInteger Nothing Nothing)
Expand Down Expand Up @@ -98,3 +106,78 @@ isSubtype' ty1 ty2
tgt2 = typeTarget (uncurryType ty2)
isSubtype' ty1 ty2 =
isSubtype ty1 ty2

unifyTypes :: forall t e r. (Members '[Error TreeError, Reader (Maybe Location), Reader (InfoTable' t e)] r) => Type -> Type -> Sem r Type
unifyTypes ty1 ty2 = case (ty1, ty2) of
(TyDynamic, x) -> return x
(x, TyDynamic) -> return x
(TyInductive TypeInductive {..}, TyConstr TypeConstr {..})
| _typeInductiveSymbol == _typeConstrInductive ->
return ty1
(TyConstr {}, TyInductive {}) -> unifyTypes @t @e ty2 ty1
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive
&& c1 ^. typeConstrTag == c2 ^. typeConstrTag -> do
flds <- zipWithM (unifyTypes @t @e) (c1 ^. typeConstrFields) (c2 ^. typeConstrFields)
return $ TyConstr (set typeConstrFields flds c1)
(TyConstr c1, TyConstr c2)
| c1 ^. typeConstrInductive == c2 ^. typeConstrInductive ->
return $ TyInductive (TypeInductive (c1 ^. typeConstrInductive))
(TyFun t1, TyFun t2)
| length (t1 ^. typeFunArgs) == length (t2 ^. typeFunArgs) -> do
let args1 = toList (t1 ^. typeFunArgs)
args2 = toList (t2 ^. typeFunArgs)
tgt1 = t1 ^. typeFunTarget
tgt2 = t2 ^. typeFunTarget
args <- zipWithM (unifyTypes @t @e) args1 args2
tgt <- unifyTypes @t @e tgt1 tgt2
return $ TyFun (TypeFun (nonEmpty' args) tgt)
(TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) ->
return $ TyInteger (TypeInteger (unifyBounds min l1 l2) (unifyBounds max u1 u2))
where
unifyBounds :: (Integer -> Integer -> Integer) -> Maybe Integer -> Maybe Integer -> Maybe Integer
unifyBounds _ Nothing _ = Nothing
unifyBounds _ _ Nothing = Nothing
unifyBounds f (Just x) (Just y) = Just (f x y)
(TyBool {}, TyBool {})
| ty1 == ty2 -> return ty1
(TyString, TyString) -> return TyString
(TyUnit, TyUnit) -> return TyUnit
(TyVoid, TyVoid) -> return TyVoid
(TyInductive {}, TyInductive {})
| ty1 == ty2 -> return ty1
(TyUnit, _) -> err
(_, TyUnit) -> err
(TyVoid, _) -> err
(_, TyVoid) -> err
(TyInteger {}, _) -> err
(_, TyInteger {}) -> err
(TyString, _) -> err
(_, TyString) -> err
(TyBool {}, _) -> err
(_, TyBool {}) -> err
(TyFun {}, _) -> err
(_, TyFun {}) -> err
(TyInductive {}, _) -> err
(_, TyConstr {}) -> err
where
err :: Sem r a
err = do
loc <- ask
tab <- ask @(InfoTable' t e)
throw $ TreeError loc ("not unifiable: " <> ppTrace' (defaultOptions tab) ty1 <> ", " <> ppTrace' (defaultOptions tab) ty2)

unifyTypes' :: forall t e r. (Member (Error TreeError) r) => Maybe Location -> InfoTable' t e -> Type -> Type -> Sem r Type
unifyTypes' loc tab ty1 ty2 =
runReader loc $
runReader tab $
-- The `if` is to ensure correct behaviour with dynamic type targets. E.g.
-- `(A, B) -> *` should unify with `A -> B -> C -> D`.
if
| tgt1 == TyDynamic || tgt2 == TyDynamic ->
unifyTypes @t @e (curryType ty1) (curryType ty2)
| otherwise ->
unifyTypes @t @e ty1 ty2
where
tgt1 = typeTarget (uncurryType ty1)
tgt2 = typeTarget (uncurryType ty2)
4 changes: 2 additions & 2 deletions src/Juvix/Compiler/Tree/Pipeline.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ where
import Juvix.Compiler.Tree.Data.InfoTable
import Juvix.Compiler.Tree.Transformation

toNockma :: InfoTable -> Sem r InfoTable
toNockma :: (Member (Error JuvixError) r) => InfoTable -> Sem r InfoTable
toNockma = applyTransformations toNockmaTransformations

toAsm :: InfoTable -> Sem r InfoTable
toAsm :: (Member (Error JuvixError) r) => InfoTable -> Sem r InfoTable
toAsm = applyTransformations toAsmTransformations
5 changes: 4 additions & 1 deletion src/Juvix/Compiler/Tree/Transformation.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ module Juvix.Compiler.Tree.Transformation
where

import Juvix.Compiler.Tree.Data.TransformationId
import Juvix.Compiler.Tree.Error
import Juvix.Compiler.Tree.Transformation.Apply
import Juvix.Compiler.Tree.Transformation.Base
import Juvix.Compiler.Tree.Transformation.FilterUnreachable
import Juvix.Compiler.Tree.Transformation.Identity
import Juvix.Compiler.Tree.Transformation.TempHeight
import Juvix.Compiler.Tree.Transformation.Validate

applyTransformations :: forall r. [TransformationId] -> InfoTable -> Sem r InfoTable
applyTransformations :: forall r. (Member (Error JuvixError) r) => [TransformationId] -> InfoTable -> Sem r InfoTable
applyTransformations ts tbl = foldM (flip appTrans) tbl ts
where
appTrans :: TransformationId -> InfoTable -> Sem r InfoTable
Expand All @@ -23,3 +25,4 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts
Apply -> return . computeApply
TempHeight -> return . computeTempHeight
FilterUnreachable -> return . filterUnreachable
Validate -> mapError (JuvixError @TreeError) . validate
Loading

0 comments on commit 795212b

Please sign in to comment.