diff --git a/app/Commands/Dev/Tree/Read.hs b/app/Commands/Dev/Tree/Read.hs index eddbe6b02d..1d47a13cba 100644 --- a/app/Commands/Dev/Tree/Read.hs +++ b/app/Commands/Dev/Tree/Read.hs @@ -15,10 +15,13 @@ runCommand opts = do case Tree.runParser (toFilePath afile) s of Left err -> exitJuvixError (JuvixError err) Right tab -> do - tab' <- Tree.applyTransformations (project opts ^. treeReadTransformations) tab - unless (project opts ^. treeReadNoPrint) $ - renderStdOut (Tree.ppOutDefault tab' tab') - doEval tab' + r <- runError @JuvixError (Tree.applyTransformations (project opts ^. treeReadTransformations) tab) + case r of + Left err -> exitJuvixError (JuvixError err) + Right tab' -> do + unless (project opts ^. treeReadNoPrint) $ + renderStdOut (Tree.ppOutDefault tab' tab') + doEval tab' where file :: AppPath File file = opts ^. treeReadInputFile diff --git a/src/Juvix/Compiler/Asm/Extra/Memory.hs b/src/Juvix/Compiler/Asm/Extra/Memory.hs index 6e08f8319d..96aab847dd 100644 --- a/src/Juvix/Compiler/Asm/Extra/Memory.hs +++ b/src/Juvix/Compiler/Asm/Extra/Memory.hs @@ -168,11 +168,11 @@ unifyMemory' loc tab mem1 mem2 = do unless (length (mem1 ^. memoryValueStack) == length (mem2 ^. memoryValueStack)) $ throw $ AsmError loc "value stack height mismatch" - vs <- zipWithM (unifyTypes' loc tab) (toList (mem1 ^. memoryValueStack)) (toList (mem2 ^. memoryValueStack)) + vs <- zipWithM (unifyTypes'' loc tab) (toList (mem1 ^. memoryValueStack)) (toList (mem2 ^. memoryValueStack)) unless (length (mem1 ^. memoryTempStack) == length (mem2 ^. memoryTempStack)) $ throw $ AsmError loc "temporary stack height mismatch" - ts <- zipWithM (unifyTypes' loc tab) (toList (mem1 ^. memoryTempStack)) (toList (mem2 ^. memoryTempStack)) + ts <- zipWithM (unifyTypes'' loc tab) (toList (mem1 ^. memoryTempStack)) (toList (mem2 ^. memoryTempStack)) unless ( length (mem1 ^. memoryArgumentArea) == length (mem2 ^. memoryArgumentArea) && mem1 ^. memoryArgsNum == mem2 ^. memoryArgsNum @@ -183,7 +183,7 @@ unifyMemory' loc tab mem1 mem2 = do args <- mapM ( \off -> - unifyTypes' + unifyTypes'' loc tab (fromJust $ HashMap.lookup off (mem1 ^. memoryArgumentArea)) diff --git a/src/Juvix/Compiler/Asm/Extra/Recursors.hs b/src/Juvix/Compiler/Asm/Extra/Recursors.hs index 3bf2550030..6f2ce88ddc 100644 --- a/src/Juvix/Compiler/Asm/Extra/Recursors.hs +++ b/src/Juvix/Compiler/Asm/Extra/Recursors.hs @@ -130,7 +130,7 @@ recurse' sig = go True checkValueStack' loc (sig ^. recursorInfoTable) tyargs mem tys <- zipWithM - (\ty idx -> unifyTypes' loc (sig ^. recursorInfoTable) ty (topValueStack' idx mem)) + (\ty idx -> unifyTypes'' loc (sig ^. recursorInfoTable) ty (topValueStack' idx mem)) tyargs [0 ..] return $ @@ -226,7 +226,7 @@ recurse' sig = go True checkValueStack' loc (sig ^. recursorInfoTable) (take argsNum (typeArgs ty)) mem' let tyargs = topValuesFromValueStack' argsNum mem' -- `typeArgs ty` may be shorter than `tyargs` only if `ty` is dynamic - zipWithM_ (unifyTypes' loc (sig ^. recursorInfoTable)) tyargs (typeArgs ty) + zipWithM_ (unifyTypes'' loc (sig ^. recursorInfoTable)) tyargs (typeArgs ty) return $ pushValueStack (mkTypeFun (drop argsNum (typeArgs ty)) (typeTarget ty)) $ popValueStack argsNum mem' diff --git a/src/Juvix/Compiler/Asm/Extra/Type.hs b/src/Juvix/Compiler/Asm/Extra/Type.hs index c8e040296c..26b94359c5 100644 --- a/src/Juvix/Compiler/Asm/Extra/Type.hs +++ b/src/Juvix/Compiler/Asm/Extra/Type.hs @@ -4,84 +4,18 @@ module Juvix.Compiler.Asm.Extra.Type ) where -import Data.List.NonEmpty qualified as NonEmpty import Juvix.Compiler.Asm.Data.InfoTable import Juvix.Compiler.Asm.Error import Juvix.Compiler.Asm.Language -import Juvix.Compiler.Asm.Pretty +import Juvix.Compiler.Tree.Error import Juvix.Compiler.Tree.Extra.Type -unifyTypes :: forall r. (Members '[Error AsmError, Reader (Maybe Location), Reader InfoTable] r) => Type -> Type -> Sem r Type -unifyTypes ty1 ty2 = case (ty1, ty2) of - (TyDynamic, x) -> return x - (x, TyDynamic) -> return x - (TyInductive TypeInductive {..}, TyConstr TypeConstr {..}) - | _typeInductiveSymbol == _typeConstrInductive -> - return ty1 - (TyConstr {}, TyInductive {}) -> unifyTypes ty2 ty1 - (TyConstr c1, TyConstr c2) - | c1 ^. typeConstrInductive == c2 ^. typeConstrInductive - && c1 ^. typeConstrTag == c2 ^. typeConstrTag -> do - flds <- zipWithM unifyTypes (c1 ^. typeConstrFields) (c2 ^. typeConstrFields) - return $ TyConstr (set typeConstrFields flds c1) - (TyConstr c1, TyConstr c2) - | c1 ^. typeConstrInductive == c2 ^. typeConstrInductive -> - return $ TyInductive (TypeInductive (c1 ^. typeConstrInductive)) - (TyFun t1, TyFun t2) - | length (t1 ^. typeFunArgs) == length (t2 ^. typeFunArgs) -> do - let args1 = toList (t1 ^. typeFunArgs) - args2 = toList (t2 ^. typeFunArgs) - tgt1 = t1 ^. typeFunTarget - tgt2 = t2 ^. typeFunTarget - args <- zipWithM unifyTypes args1 args2 - tgt <- unifyTypes tgt1 tgt2 - return $ TyFun (TypeFun (NonEmpty.fromList args) tgt) - (TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) -> - return $ TyInteger (TypeInteger (unifyBounds min l1 l2) (unifyBounds max u1 u2)) - where - unifyBounds :: (Integer -> Integer -> Integer) -> Maybe Integer -> Maybe Integer -> Maybe Integer - unifyBounds _ Nothing _ = Nothing - unifyBounds _ _ Nothing = Nothing - unifyBounds f (Just x) (Just y) = Just (f x y) - (TyBool {}, TyBool {}) - | ty1 == ty2 -> return ty1 - (TyString, TyString) -> return TyString - (TyUnit, TyUnit) -> return TyUnit - (TyVoid, TyVoid) -> return TyVoid - (TyInductive {}, TyInductive {}) - | ty1 == ty2 -> return ty1 - (TyUnit, _) -> err - (_, TyUnit) -> err - (TyVoid, _) -> err - (_, TyVoid) -> err - (TyInteger {}, _) -> err - (_, TyInteger {}) -> err - (TyString, _) -> err - (_, TyString) -> err - (TyBool {}, _) -> err - (_, TyBool {}) -> err - (TyFun {}, _) -> err - (_, TyFun {}) -> err - (TyInductive {}, _) -> err - (_, TyConstr {}) -> err +unifyTypes'' :: forall t e r. (Member (Error AsmError) r) => Maybe Location -> InfoTable' t e -> Type -> Type -> Sem r Type +unifyTypes'' loc tab ty1 ty2 = mapError toAsmError $ unifyTypes' loc tab ty1 ty2 where - err :: Sem r a - err = do - loc <- ask - tab <- ask - throw $ AsmError loc ("not unifiable: " <> ppTrace tab ty1 <> ", " <> ppTrace tab ty2) - -unifyTypes' :: (Member (Error AsmError) r) => Maybe Location -> InfoTable -> Type -> Type -> Sem r Type -unifyTypes' loc tab ty1 ty2 = - runReader loc $ - runReader tab $ - -- The `if` is to ensure correct behaviour with dynamic type targets. E.g. - -- `(A, B) -> *` should unify with `A -> B -> C -> D`. - if - | tgt1 == TyDynamic || tgt2 == TyDynamic -> - unifyTypes (curryType ty1) (curryType ty2) - | otherwise -> - unifyTypes ty1 ty2 - where - tgt1 = typeTarget (uncurryType ty1) - tgt2 = typeTarget (uncurryType ty2) + toAsmError :: TreeError -> AsmError + toAsmError TreeError {..} = + AsmError + { _asmErrorLoc = _treeErrorLoc, + _asmErrorMsg = _treeErrorMsg + } diff --git a/src/Juvix/Compiler/Pipeline.hs b/src/Juvix/Compiler/Pipeline.hs index 28e2ec550f..bdba8ab49b 100644 --- a/src/Juvix/Compiler/Pipeline.hs +++ b/src/Juvix/Compiler/Pipeline.hs @@ -176,7 +176,7 @@ coreToVampIR' = Core.toStored' >=> storedCoreToVampIR' -- Other workflows -------------------------------------------------------------------------------- -treeToAsm :: Tree.InfoTable -> Sem r Asm.InfoTable +treeToAsm :: (Member (Error JuvixError) r) => Tree.InfoTable -> Sem r Asm.InfoTable treeToAsm = Tree.toAsm >=> return . Asm.fromTree treeToNockma :: (Members '[Error JuvixError, Reader EntryPoint] r) => Tree.InfoTable -> Sem r (Nockma.Cell Natural) diff --git a/src/Juvix/Compiler/Tree/Data/TransformationId.hs b/src/Juvix/Compiler/Tree/Data/TransformationId.hs index e001e0535a..6cd45861d5 100644 --- a/src/Juvix/Compiler/Tree/Data/TransformationId.hs +++ b/src/Juvix/Compiler/Tree/Data/TransformationId.hs @@ -11,6 +11,7 @@ data TransformationId | Apply | TempHeight | FilterUnreachable + | Validate deriving stock (Data, Bounded, Enum, Show) data PipelineId @@ -21,10 +22,10 @@ data PipelineId type TransformationLikeId = TransformationLikeId' TransformationId PipelineId toNockmaTransformations :: [TransformationId] -toNockmaTransformations = [Apply, FilterUnreachable, TempHeight] +toNockmaTransformations = [Validate, Apply, FilterUnreachable, TempHeight] toAsmTransformations :: [TransformationId] -toAsmTransformations = [] +toAsmTransformations = [Validate] instance TransformationId' TransformationId where transformationText :: TransformationId -> Text @@ -35,6 +36,7 @@ instance TransformationId' TransformationId where Apply -> strApply TempHeight -> strTempHeight FilterUnreachable -> strFilterUnreachable + Validate -> strValidate instance PipelineId' TransformationId PipelineId where pipelineText :: PipelineId -> Text diff --git a/src/Juvix/Compiler/Tree/Data/TransformationId/Strings.hs b/src/Juvix/Compiler/Tree/Data/TransformationId/Strings.hs index 7e454ccbdc..41ed6a3974 100644 --- a/src/Juvix/Compiler/Tree/Data/TransformationId/Strings.hs +++ b/src/Juvix/Compiler/Tree/Data/TransformationId/Strings.hs @@ -25,3 +25,6 @@ strTempHeight = "temp-height" strFilterUnreachable :: Text strFilterUnreachable = "filter-unreachable" + +strValidate :: Text +strValidate = "validate" diff --git a/src/Juvix/Compiler/Tree/Extra/Type.hs b/src/Juvix/Compiler/Tree/Extra/Type.hs index 9e25042770..f3382f12d8 100644 --- a/src/Juvix/Compiler/Tree/Extra/Type.hs +++ b/src/Juvix/Compiler/Tree/Extra/Type.hs @@ -1,7 +1,15 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# OPTIONS_GHC -Wno-unrecognised-pragmas #-} + +{-# HLINT ignore "Avoid restricted extensions" #-} +{-# HLINT ignore "Avoid restricted flags" #-} + module Juvix.Compiler.Tree.Extra.Type where +import Juvix.Compiler.Tree.Data.InfoTable.Base +import Juvix.Compiler.Tree.Error import Juvix.Compiler.Tree.Language.Base -import Juvix.Compiler.Tree.Language.Type +import Juvix.Compiler.Tree.Pretty mkTypeInteger :: Type mkTypeInteger = TyInteger (TypeInteger Nothing Nothing) @@ -98,3 +106,78 @@ isSubtype' ty1 ty2 tgt2 = typeTarget (uncurryType ty2) isSubtype' ty1 ty2 = isSubtype ty1 ty2 + +unifyTypes :: forall t e r. (Members '[Error TreeError, Reader (Maybe Location), Reader (InfoTable' t e)] r) => Type -> Type -> Sem r Type +unifyTypes ty1 ty2 = case (ty1, ty2) of + (TyDynamic, x) -> return x + (x, TyDynamic) -> return x + (TyInductive TypeInductive {..}, TyConstr TypeConstr {..}) + | _typeInductiveSymbol == _typeConstrInductive -> + return ty1 + (TyConstr {}, TyInductive {}) -> unifyTypes @t @e ty2 ty1 + (TyConstr c1, TyConstr c2) + | c1 ^. typeConstrInductive == c2 ^. typeConstrInductive + && c1 ^. typeConstrTag == c2 ^. typeConstrTag -> do + flds <- zipWithM (unifyTypes @t @e) (c1 ^. typeConstrFields) (c2 ^. typeConstrFields) + return $ TyConstr (set typeConstrFields flds c1) + (TyConstr c1, TyConstr c2) + | c1 ^. typeConstrInductive == c2 ^. typeConstrInductive -> + return $ TyInductive (TypeInductive (c1 ^. typeConstrInductive)) + (TyFun t1, TyFun t2) + | length (t1 ^. typeFunArgs) == length (t2 ^. typeFunArgs) -> do + let args1 = toList (t1 ^. typeFunArgs) + args2 = toList (t2 ^. typeFunArgs) + tgt1 = t1 ^. typeFunTarget + tgt2 = t2 ^. typeFunTarget + args <- zipWithM (unifyTypes @t @e) args1 args2 + tgt <- unifyTypes @t @e tgt1 tgt2 + return $ TyFun (TypeFun (nonEmpty' args) tgt) + (TyInteger (TypeInteger l1 u1), TyInteger (TypeInteger l2 u2)) -> + return $ TyInteger (TypeInteger (unifyBounds min l1 l2) (unifyBounds max u1 u2)) + where + unifyBounds :: (Integer -> Integer -> Integer) -> Maybe Integer -> Maybe Integer -> Maybe Integer + unifyBounds _ Nothing _ = Nothing + unifyBounds _ _ Nothing = Nothing + unifyBounds f (Just x) (Just y) = Just (f x y) + (TyBool {}, TyBool {}) + | ty1 == ty2 -> return ty1 + (TyString, TyString) -> return TyString + (TyUnit, TyUnit) -> return TyUnit + (TyVoid, TyVoid) -> return TyVoid + (TyInductive {}, TyInductive {}) + | ty1 == ty2 -> return ty1 + (TyUnit, _) -> err + (_, TyUnit) -> err + (TyVoid, _) -> err + (_, TyVoid) -> err + (TyInteger {}, _) -> err + (_, TyInteger {}) -> err + (TyString, _) -> err + (_, TyString) -> err + (TyBool {}, _) -> err + (_, TyBool {}) -> err + (TyFun {}, _) -> err + (_, TyFun {}) -> err + (TyInductive {}, _) -> err + (_, TyConstr {}) -> err + where + err :: Sem r a + err = do + loc <- ask + tab <- ask @(InfoTable' t e) + throw $ TreeError loc ("not unifiable: " <> ppTrace' (defaultOptions tab) ty1 <> ", " <> ppTrace' (defaultOptions tab) ty2) + +unifyTypes' :: forall t e r. (Member (Error TreeError) r) => Maybe Location -> InfoTable' t e -> Type -> Type -> Sem r Type +unifyTypes' loc tab ty1 ty2 = + runReader loc $ + runReader tab $ + -- The `if` is to ensure correct behaviour with dynamic type targets. E.g. + -- `(A, B) -> *` should unify with `A -> B -> C -> D`. + if + | tgt1 == TyDynamic || tgt2 == TyDynamic -> + unifyTypes @t @e (curryType ty1) (curryType ty2) + | otherwise -> + unifyTypes @t @e ty1 ty2 + where + tgt1 = typeTarget (uncurryType ty1) + tgt2 = typeTarget (uncurryType ty2) diff --git a/src/Juvix/Compiler/Tree/Pipeline.hs b/src/Juvix/Compiler/Tree/Pipeline.hs index 001b92bc50..96ce8becac 100644 --- a/src/Juvix/Compiler/Tree/Pipeline.hs +++ b/src/Juvix/Compiler/Tree/Pipeline.hs @@ -7,8 +7,8 @@ where import Juvix.Compiler.Tree.Data.InfoTable import Juvix.Compiler.Tree.Transformation -toNockma :: InfoTable -> Sem r InfoTable +toNockma :: (Member (Error JuvixError) r) => InfoTable -> Sem r InfoTable toNockma = applyTransformations toNockmaTransformations -toAsm :: InfoTable -> Sem r InfoTable +toAsm :: (Member (Error JuvixError) r) => InfoTable -> Sem r InfoTable toAsm = applyTransformations toAsmTransformations diff --git a/src/Juvix/Compiler/Tree/Transformation.hs b/src/Juvix/Compiler/Tree/Transformation.hs index c16456c838..3f2e078e3d 100644 --- a/src/Juvix/Compiler/Tree/Transformation.hs +++ b/src/Juvix/Compiler/Tree/Transformation.hs @@ -6,13 +6,15 @@ module Juvix.Compiler.Tree.Transformation where import Juvix.Compiler.Tree.Data.TransformationId +import Juvix.Compiler.Tree.Error import Juvix.Compiler.Tree.Transformation.Apply import Juvix.Compiler.Tree.Transformation.Base import Juvix.Compiler.Tree.Transformation.FilterUnreachable import Juvix.Compiler.Tree.Transformation.Identity import Juvix.Compiler.Tree.Transformation.TempHeight +import Juvix.Compiler.Tree.Transformation.Validate -applyTransformations :: forall r. [TransformationId] -> InfoTable -> Sem r InfoTable +applyTransformations :: forall r. (Member (Error JuvixError) r) => [TransformationId] -> InfoTable -> Sem r InfoTable applyTransformations ts tbl = foldM (flip appTrans) tbl ts where appTrans :: TransformationId -> InfoTable -> Sem r InfoTable @@ -23,3 +25,4 @@ applyTransformations ts tbl = foldM (flip appTrans) tbl ts Apply -> return . computeApply TempHeight -> return . computeTempHeight FilterUnreachable -> return . filterUnreachable + Validate -> mapError (JuvixError @TreeError) . validate diff --git a/src/Juvix/Compiler/Tree/Transformation/Validate.hs b/src/Juvix/Compiler/Tree/Transformation/Validate.hs new file mode 100644 index 0000000000..0fc486768d --- /dev/null +++ b/src/Juvix/Compiler/Tree/Transformation/Validate.hs @@ -0,0 +1,266 @@ +module Juvix.Compiler.Tree.Transformation.Validate where + +import Juvix.Compiler.Core.Data.BinderList qualified as BL +import Juvix.Compiler.Tree.Error +import Juvix.Compiler.Tree.Extra.Base (getNodeLocation) +import Juvix.Compiler.Tree.Extra.Recursors +import Juvix.Compiler.Tree.Extra.Type +import Juvix.Compiler.Tree.Transformation.Base + +inferType :: forall r. (Member (Error TreeError) r) => InfoTable -> FunctionInfo -> Node -> Sem r Type +inferType tab funInfo = goInfer mempty + where + goInfer :: BinderList Type -> Node -> Sem r Type + goInfer bl = \case + Binop x -> goBinop bl x + Unop x -> goUnop bl x + Const x -> goConst bl x + MemRef x -> goMemRef bl x + AllocConstr x -> goAllocConstr bl x + AllocClosure x -> goAllocClosure bl x + ExtendClosure x -> goExtendClosure bl x + Call x -> goCall bl x + CallClosures x -> goCallClosures bl x + Branch x -> goBranch bl x + Case x -> goCase bl x + Save x -> goSave bl x + + goBinop :: BinderList Type -> NodeBinop -> Sem r Type + goBinop bl NodeBinop {..} = case _nodeBinopOpcode of + IntAdd -> checkBinop mkTypeInteger mkTypeInteger mkTypeInteger + IntSub -> checkBinop mkTypeInteger mkTypeInteger mkTypeInteger + IntMul -> checkBinop mkTypeInteger mkTypeInteger mkTypeInteger + IntDiv -> checkBinop mkTypeInteger mkTypeInteger mkTypeInteger + IntMod -> checkBinop mkTypeInteger mkTypeInteger mkTypeInteger + IntLt -> checkBinop mkTypeInteger mkTypeInteger mkTypeBool + IntLe -> checkBinop mkTypeInteger mkTypeInteger mkTypeBool + ValEq -> checkBinop TyDynamic TyDynamic mkTypeBool + StrConcat -> checkBinop TyString TyString TyString + OpSeq -> do + checkType bl _nodeBinopArg1 TyDynamic + goInfer bl _nodeBinopArg2 + where + loc = _nodeBinopInfo ^. nodeInfoLocation + + checkBinop :: Type -> Type -> Type -> Sem r Type + checkBinop ty1' ty2' rty = do + ty1 <- goInfer bl _nodeBinopArg1 + ty2 <- goInfer bl _nodeBinopArg2 + void $ unifyTypes' loc tab ty1 ty1' + void $ unifyTypes' loc tab ty2 ty2' + return rty + + goUnop :: BinderList Type -> NodeUnop -> Sem r Type + goUnop bl NodeUnop {..} = case _nodeUnopOpcode of + OpShow -> checkUnop TyDynamic TyString + OpStrToInt -> checkUnop TyString mkTypeInteger + OpTrace -> goInfer bl _nodeUnopArg + OpFail -> checkUnop TyDynamic TyDynamic + OpArgsNum -> checkUnop TyDynamic mkTypeInteger + where + loc = _nodeUnopInfo ^. nodeInfoLocation + + checkUnop :: Type -> Type -> Sem r Type + checkUnop ty rty = do + ty' <- goInfer bl _nodeUnopArg + void $ unifyTypes' loc tab ty ty' + return rty + + goConst :: BinderList Type -> NodeConstant -> Sem r Type + goConst _ NodeConstant {..} = case _nodeConstant of + ConstInt {} -> return mkTypeInteger + ConstBool {} -> return mkTypeBool + ConstString {} -> return TyString + ConstUnit {} -> return TyUnit + ConstVoid {} -> return TyVoid + + goMemRef :: BinderList Type -> NodeMemRef -> Sem r Type + goMemRef bl NodeMemRef {..} = case _nodeMemRef of + DRef d -> goDirectRef (_nodeMemRefInfo ^. nodeInfoLocation) bl d + ConstrRef x -> goField bl x + + goDirectRef :: Maybe Location -> BinderList Type -> DirectRef -> Sem r Type + goDirectRef loc bl = \case + ArgRef x -> goArgRef loc bl x + TempRef RefTemp {..} -> goTempRef bl _refTempOffsetRef + + goArgRef :: Maybe Location -> BinderList Type -> OffsetRef -> Sem r Type + goArgRef loc _ OffsetRef {..} + | _offsetRefOffset < length tys = return $ tys !! _offsetRefOffset + | typeTarget (funInfo ^. functionType) == TyDynamic = return TyDynamic + | otherwise = + throw $ + TreeError + { _treeErrorLoc = loc, + _treeErrorMsg = "Wrong target type" + } + where + tys = typeArgs (funInfo ^. functionType) + + goTempRef :: BinderList Type -> OffsetRef -> Sem r Type + goTempRef bl OffsetRef {..} = return $ BL.lookupLevel _offsetRefOffset bl + + goField :: BinderList Type -> Field -> Sem r Type + goField _ Field {..} + | _fieldOffset < length tys = return $ tys !! _fieldOffset + | otherwise = return TyDynamic + where + ci = lookupConstrInfo tab _fieldTag + tys = typeArgs (ci ^. constructorType) + + goAllocConstr :: BinderList Type -> NodeAllocConstr -> Sem r Type + goAllocConstr bl NodeAllocConstr {..} + | length _nodeAllocConstrArgs == length tys = do + forM_ (zipExact _nodeAllocConstrArgs tys) (uncurry (checkType bl)) + return $ typeTarget (ci ^. constructorType) + | otherwise = + throw $ + TreeError + { _treeErrorLoc = _nodeAllocConstrInfo ^. nodeInfoLocation, + _treeErrorMsg = "" + } + where + ci = lookupConstrInfo tab _nodeAllocConstrTag + tys = typeArgs (ci ^. constructorType) + + goAllocClosure :: BinderList Type -> NodeAllocClosure -> Sem r Type + goAllocClosure bl NodeAllocClosure {..} + | n <= fi ^. functionArgsNum = do + forM_ (zipExact _nodeAllocClosureArgs (take n tys)) (uncurry (checkType bl)) + return $ mkTypeFun (drop n tys) (typeTarget (fi ^. functionType)) + | otherwise = + throw $ + TreeError + { _treeErrorLoc = _nodeAllocClosureInfo ^. nodeInfoLocation, + _treeErrorMsg = "Wrong number of arguments" + } + where + n = length _nodeAllocClosureArgs + fi = lookupFunInfo tab _nodeAllocClosureFunSymbol + tys = typeArgs (fi ^. functionType) + + goExtendClosure :: BinderList Type -> NodeExtendClosure -> Sem r Type + goExtendClosure bl NodeExtendClosure {..} = do + ty <- goInfer bl _nodeExtendClosureFun + let tys = typeArgs ty + m = length tys + n = length _nodeExtendClosureArgs + if + | n < m -> do + forM_ (zipExact (toList _nodeExtendClosureArgs) (take n tys)) (uncurry (checkType bl)) + return $ mkTypeFun (drop n tys) (typeTarget ty) + | typeTarget ty == TyDynamic -> do + let tys' = tys ++ replicate (n - m) TyDynamic + forM_ (zipExact (toList _nodeExtendClosureArgs) tys') (uncurry (checkType bl)) + return $ typeTarget ty + | otherwise -> + throw $ + TreeError + { _treeErrorLoc = _nodeExtendClosureInfo ^. nodeInfoLocation, + _treeErrorMsg = "Too many arguments" + } + + goCall :: BinderList Type -> NodeCall -> Sem r Type + goCall bl NodeCall {..} = case _nodeCallType of + CallFun sym + | n == fi ^. functionArgsNum -> do + unless (n == 0) $ + forM_ (zipExact _nodeCallArgs tys) (uncurry (checkType bl)) + return $ mkTypeFun (drop n tys) (typeTarget (fi ^. functionType)) + | otherwise -> + throw $ + TreeError + { _treeErrorLoc = _nodeCallInfo ^. nodeInfoLocation, + _treeErrorMsg = "Wrong number of arguments" + } + where + n = length _nodeCallArgs + fi = lookupFunInfo tab sym + tys = typeArgs (fi ^. functionType) + CallClosure cl -> do + ty <- goInfer bl cl + let tys = typeArgs ty + n = length _nodeCallArgs + when (length tys > n) $ + throw $ + TreeError + { _treeErrorLoc = _nodeCallInfo ^. nodeInfoLocation, + _treeErrorMsg = "Too few arguments" + } + when (length tys < n && typeTarget ty /= TyDynamic) $ + throw $ + TreeError + { _treeErrorLoc = _nodeCallInfo ^. nodeInfoLocation, + _treeErrorMsg = "Too many arguments" + } + let tys' = tys ++ replicate (n - length tys) TyDynamic + forM_ (zipExact _nodeCallArgs tys') (uncurry (checkType bl)) + return $ typeTarget ty + + goCallClosures :: BinderList Type -> NodeCallClosures -> Sem r Type + goCallClosures bl NodeCallClosures {..} = do + ty <- goInfer bl _nodeCallClosuresFun + go ty (toList _nodeCallClosuresArgs) + where + go :: Type -> [Node] -> Sem r Type + go ty args + | m == 0 = + return ty + | m <= n = do + forM_ (zipExact (take m args) tys) (uncurry (checkType bl)) + go (typeTarget ty) (drop m args) + | otherwise = do + forM_ (zipExact args (take n tys)) (uncurry (checkType bl)) + return $ mkTypeFun (drop n tys) (typeTarget ty) + where + tys = typeArgs ty + m = length tys + n = length args + + goBranch :: BinderList Type -> NodeBranch -> Sem r Type + goBranch bl NodeBranch {..} = do + checkType bl _nodeBranchArg mkTypeBool + ty1 <- goInfer bl _nodeBranchTrue + ty2 <- goInfer bl _nodeBranchFalse + unifyTypes' (_nodeBranchInfo ^. nodeInfoLocation) tab ty1 ty2 + + goCase :: BinderList Type -> NodeCase -> Sem r Type + goCase bl NodeCase {..} = do + ity <- goInfer bl _nodeCaseArg + unless (ity == mkTypeInductive _nodeCaseInductive || ity == TyDynamic) $ + throw $ + TreeError + { _treeErrorLoc = _nodeCaseInfo ^. nodeInfoLocation, + _treeErrorMsg = "Inductive type mismatch" + } + ty <- maybe (return TyDynamic) (goInfer bl) _nodeCaseDefault + go ity ty _nodeCaseBranches + where + go :: Type -> Type -> [CaseBranch] -> Sem r Type + go ity ty = \case + [] -> return ty + CaseBranch {..} : brs -> do + let bl' = if _caseBranchSave then BL.cons ity bl else bl + ty' <- goInfer bl' _caseBranchBody + ty'' <- unifyTypes' (_nodeCaseInfo ^. nodeInfoLocation) tab ty ty' + go ity ty'' brs + + goSave :: BinderList Type -> NodeSave -> Sem r Type + goSave bl NodeSave {..} = do + ty <- goInfer bl _nodeSaveArg + goInfer (BL.cons ty bl) _nodeSaveBody + + checkType :: BinderList Type -> Node -> Type -> Sem r () + checkType bl node ty = do + ty' <- goInfer bl node + void $ unifyTypes' (getNodeLocation node) tab ty ty' + +validateFunction :: (Member (Error TreeError) r) => InfoTable -> FunctionInfo -> Sem r FunctionInfo +validateFunction tab funInfo = do + ty <- inferType tab funInfo (funInfo ^. functionCode) + let ty' = if funInfo ^. functionArgsNum == 0 then funInfo ^. functionType else typeTarget (funInfo ^. functionType) + void $ unifyTypes' (funInfo ^. functionLocation) tab ty ty' + return funInfo + +validate :: (Member (Error TreeError) r) => InfoTable -> Sem r InfoTable +validate tab = mapFunctionsM (validateFunction tab) tab diff --git a/test/Tree/Eval/Base.hs b/test/Tree/Eval/Base.hs index 4c2d93cfa7..6f98965f7d 100644 --- a/test/Tree/Eval/Base.hs +++ b/test/Tree/Eval/Base.hs @@ -35,25 +35,31 @@ treeEvalAssertionParam evalParam mainFile expectedFile trans testTrans step = do case runParser (toFilePath mainFile) s of Left err -> assertFailure (show (pretty err)) Right tab0 -> do - unless (null trans) $ - step "Transform" - let tab = run $ applyTransformations trans tab0 - testTrans tab - case tab ^. infoMainFunction of - Just sym -> do - withTempDir' - ( \dirPath -> do - let outputFile = dirPath $(mkRelFile "out.out") - hout <- openFile (toFilePath outputFile) WriteMode - step "Evaluate" - evalParam hout sym tab - hClose hout - actualOutput <- readFile (toFilePath outputFile) - step "Compare expected and actual program output" - expected <- readFile (toFilePath expectedFile) - assertEqDiffText ("Check: RUN output = " <> toFilePath expectedFile) actualOutput expected - ) - Nothing -> assertFailure "no 'main' function" + step "Validate" + case run $ runError @JuvixError $ applyTransformations [Validate] tab0 of + Left err -> assertFailure (show (pretty (fromJuvixError @GenericError err))) + Right tab1 -> do + unless (null trans) $ + step "Transform" + case run $ runError @JuvixError $ applyTransformations trans tab1 of + Left err -> assertFailure (show (pretty (fromJuvixError @GenericError err))) + Right tab -> do + testTrans tab + case tab ^. infoMainFunction of + Just sym -> do + withTempDir' + ( \dirPath -> do + let outputFile = dirPath $(mkRelFile "out.out") + hout <- openFile (toFilePath outputFile) WriteMode + step "Evaluate" + evalParam hout sym tab + hClose hout + actualOutput <- readFile (toFilePath outputFile) + step "Compare expected and actual program output" + expected <- readFile (toFilePath expectedFile) + assertEqDiffText ("Check: RUN output = " <> toFilePath expectedFile) actualOutput expected + ) + Nothing -> assertFailure "no 'main' function" evalAssertion :: Handle -> Symbol -> InfoTable -> IO () evalAssertion hout sym tab = do diff --git a/tests/Asm/positive/test032.jva b/tests/Asm/positive/test032.jva index d0d08b3368..4449224805 100644 --- a/tests/Asm/positive/test032.jva +++ b/tests/Asm/positive/test032.jva @@ -99,7 +99,7 @@ function uncurry(*, * -> *, *) { tccall 2; } -function pred_step(Pair) : (* -> *, *) -> * { +function pred_step(Pair) : Pair { push arg[0].pair[1]; call isZero; br { diff --git a/tests/Tree/positive/test032.jvt b/tests/Tree/positive/test032.jvt index 41e99ffb64..4a4c50b650 100644 --- a/tests/Tree/positive/test032.jvt +++ b/tests/Tree/positive/test032.jvt @@ -101,7 +101,7 @@ function uncurry(*, * → *, *) : * { ccall(arg[0], arg[1], arg[2]) } -function pred_step(Pair) : (* → *, *) → * { +function pred_step(Pair) : Pair { br(call[isZero](arg[0].pair[1])) { true: alloc[pair](arg[0].pair[0], calloc[uncurry](calloc[succ](arg[0].pair[1]))) false: alloc[pair](calloc[uncurry](calloc[succ](arg[0].pair[0])), calloc[uncurry](calloc[succ](arg[0].pair[1])))