summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSimon Peyton Jones <simonpj@microsoft.com>2017-02-16 09:42:32 (GMT)
committerSimon Peyton Jones <simonpj@microsoft.com>2017-02-16 14:24:57 (GMT)
commit6bab649bde653f13c15eba30d5007bef4a9a9d3a (patch)
tree9732155c1110fa3e2b3d5e68f249eee4c47a35ed
parentfc9d152b058f21ab03986ea722d0c94688b9969f (diff)
downloadghc-6bab649bde653f13c15eba30d5007bef4a9a9d3a.zip
ghc-6bab649bde653f13c15eba30d5007bef4a9a9d3a.tar.gz
ghc-6bab649bde653f13c15eba30d5007bef4a9a9d3a.tar.bz2
Improve checking of joins in Core Lint
This patch addresses the rather expensive treatment of join points, identified in Trac #13220 comment:17 Before we were tracking the "bad joins". Now we track the good ones. That is easier to think about, and much more efficient; see CoreLint Note [Join points]. On the way I did some other modest refactoring, among other things removing a duplicated call of lintIdBndr for let-bindings. On teh
-rw-r--r--compiler/coreSyn/CoreLint.hs253
1 files changed, 130 insertions, 123 deletions
diff --git a/compiler/coreSyn/CoreLint.hs b/compiler/coreSyn/CoreLint.hs
index f87989d..053ac21 100644
--- a/compiler/coreSyn/CoreLint.hs
+++ b/compiler/coreSyn/CoreLint.hs
@@ -151,7 +151,6 @@ find an occurrence of an Id, we fetch it from the in-scope set.
Note [Bad unsafe coercion]
~~~~~~~~~~~~~~~~~~~~~~~~~~
-
For discussion see https://ghc.haskell.org/trac/ghc/wiki/BadUnsafeCoercions
Linter introduces additional rules that checks improper coercion between
different types, called bad coercions. Following coercions are forbidden:
@@ -170,12 +169,10 @@ different types, called bad coercions. Following coercions are forbidden:
Note [Join points]
~~~~~~~~~~~~~~~~~~
-
We check the rules listed in Note [Invariants on join points] in CoreSyn. The
only one that causes any difficulty is the first: All occurrences must be tail
-calls. To this end, along with the in-scope set, we remember in le_bad_joins the
-subset of join ids that are no longer allowed because they were declared "too
-far away." For example:
+calls. To this end, along with the in-scope set, we remember in le_joins the
+subset of in-scope Ids that are valid join ids. For example:
join j x = ... in
case e of
@@ -184,11 +181,11 @@ far away." For example:
C -> join h = jump j w in ... -- good
D -> let x = jump j v in ... -- BAD
-A join point remains valid in case branches, so when checking the A branch, j
-is still valid. When we check the scrutinee of the inner case, however, we add j
-to le_bad_joins and catch the error. Similarly, join points can occur free in
-RHSes of other join points but not the RHSes of value bindings (thunks and
-functions).
+A join point remains valid in case branches, so when checking the A
+branch, j is still valid. When we check the scrutinee of the inner
+case, however, we set le_joins to empty, and catch the
+error. Similarly, join points can occur free in RHSes of other join
+points but not the RHSes of value bindings (thunks and functions).
************************************************************************
* *
@@ -387,10 +384,9 @@ lintCoreBindings :: DynFlags -> CoreToDo -> [Var] -> CoreProgram -> (Bag MsgDoc,
-- If you edit this function, you may need to update the GHC formalism
-- See Note [GHC Formalism]
lintCoreBindings dflags pass local_in_scope binds
- = initL dflags flags $
- addLoc TopLevelBindings $
- addInScopeVars local_in_scope $
- addInScopeVars binders $
+ = initL dflags flags in_scope_set $
+ addLoc TopLevelBindings $
+ lintIdBndrs TopLevel binders $
-- Put all the top-level binders in scope at the start
-- This is because transformation rules can bring something
-- into use 'unexpectedly'
@@ -398,6 +394,8 @@ lintCoreBindings dflags pass local_in_scope binds
; checkL (null ext_dups) (dupExtVars ext_dups)
; mapM lint_bind binds }
where
+ in_scope_set = mkInScopeSet (mkVarSet local_in_scope)
+
flags = LF { lf_check_global_ids = check_globals
, lf_check_inline_loop_breakers = check_lbs
, lf_check_static_ptrs = check_static_ptrs }
@@ -463,9 +461,9 @@ lintUnfolding dflags locn vars expr
| isEmptyBag errs = Nothing
| otherwise = Just (pprMessageBag errs)
where
- (_warns, errs) = initL dflags defaultLintFlags linter
+ in_scope = mkInScopeSet vars
+ (_warns, errs) = initL dflags defaultLintFlags in_scope linter
linter = addLoc (ImportedUnfolding locn) $
- addInScopeVarSet vars $
lintCoreExpr expr
lintExpr :: DynFlags
@@ -477,9 +475,9 @@ lintExpr dflags vars expr
| isEmptyBag errs = Nothing
| otherwise = Just (pprMessageBag errs)
where
- (_warns, errs) = initL dflags defaultLintFlags linter
+ in_scope = mkInScopeSet (mkVarSet vars)
+ (_warns, errs) = initL dflags defaultLintFlags in_scope linter
linter = addLoc TopLevelBindings $
- addInScopeVars vars $
lintCoreExpr expr
{-
@@ -499,7 +497,6 @@ lintSingleBinding top_lvl_flag rec_flag (binder,rhs)
= addLoc (RhsOf binder) $
-- Check the rhs
do { ty <- lintRhs binder rhs
- ; lint_bndr binder -- Check match to RHS type
; binder_ty <- applySubstTy (idType binder)
; ensureEqTys binder_ty ty (mkRhsMsg binder (text "RHS") ty)
@@ -571,11 +568,6 @@ lintSingleBinding top_lvl_flag rec_flag (binder,rhs)
-- We should check the unfolding, if any, but this is tricky because
-- the unfolding is a SimplifiableCoreExpr. Give up for now.
- where
- -- If you edit this function, you may need to update the GHC formalism
- -- See Note [GHC Formalism]
- lint_bndr var | isId var = lintIdBndr top_lvl_flag var $ \_ -> return ()
- | otherwise = return ()
-- | Checks the RHS of bindings. It only differs from 'lintCoreExpr'
-- in that it doesn't reject occurrences of the function 'makeStatic' when they
@@ -680,7 +672,7 @@ lintCoreExpr :: CoreExpr -> LintM OutType
-- If you edit this function, you may need to update the GHC formalism
-- See Note [GHC Formalism]
lintCoreExpr (Var var)
- = lintCoreVar var 0
+ = lintVarOcc var 0
lintCoreExpr (Lit lit)
= return (literalType lit)
@@ -726,13 +718,16 @@ lintCoreExpr (Let (NonRec bndr rhs) body)
| isId bndr
= do { lintSingleBinding NotTopLevel NonRecursive (bndr,rhs)
; addLoc (BodyOfLetRec [bndr])
- (lintIdBndr NotTopLevel bndr $ \_ -> lintCoreExpr body) }
+ (lintIdBndr NotTopLevel bndr $ \_ ->
+ addGoodJoins [bndr] $
+ lintCoreExpr body) }
| otherwise
= failWithL (mkLetErr bndr rhs) -- Not quite accurate
lintCoreExpr (Let (Rec pairs) body)
- = lintIdBndrs bndrs $ \_ ->
+ = lintIdBndrs NotTopLevel bndrs $
+ addGoodJoins bndrs $
do { checkL (null dups) (dupVars dups)
; checkL (all isJoinId bndrs || all (not . isJoinId) bndrs) $
mkInconsistentRecMsg bndrs
@@ -812,51 +807,38 @@ lintCoreExpr (Coercion co)
= do { (k1, k2, ty1, ty2, role) <- lintInCo co
; return (mkHeteroCoercionType role k1 k2 ty1 ty2) }
-lintCoreVar :: Var -> Int -- Number of arguments (type or value) being passed
+----------------------
+lintVarOcc :: Var -> Int -- Number of arguments (type or value) being passed
-> LintM Type -- returns type of the *variable*
-lintCoreVar var nargs
+lintVarOcc var nargs
= do { checkL (isNonCoVarId var)
(text "Non term variable" <+> ppr var)
- ; lf <- getLintFlags
+ -- Cneck that the type of the occurrence is the same
+ -- as the type of the binding site
+ ; ty <- applySubstTy (idType var)
+ ; var' <- lookupIdInScope var
+ ; let ty' = idType var'
+ ; ensureEqTys ty ty' $ mkBndrOccTypeMismatchMsg var' var ty' ty
+
-- Check for a nested occurrence of the StaticPtr constructor.
-- See Note [Checking StaticPtrs].
+ ; lf <- getLintFlags
; when (nargs /= 0 && lf_check_static_ptrs lf /= AllowAnywhere) $
checkL (idName var /= makeStaticName) $
text "Found makeStatic nested in an expression"
; checkDeadIdOcc var
- ; ty <- applySubstTy (idType var)
- ; var' <- lookupIdInScope var
- ; let ty' = idType var'
- ; ensureEqTys ty ty' $ mkBndrOccTypeMismatchMsg var' var ty' ty
- ; mb_join_arity
- <- case isJoinId_maybe var' of
- Just join_arity ->
- do { checkL (isJoinId_maybe var == Just join_arity) $
- mkJoinBndrOccMismatchMsg var' var
- ; return $ Just join_arity }
- Nothing ->
- case tailCallInfo (idOccInfo var') of
- AlwaysTailCalled join_arity -> return $ Just join_arity
- -- This function will be turned into a join point by the
- -- simplifier; typecheck it as if it already were one
- NoTailCallInfo -> return $ Nothing
- ; case mb_join_arity of
- Just join_arity ->
- do { bad <- isBadJoin var'
- ; checkL (not bad) $ mkJoinOutOfScopeMsg var'
- ; checkL (nargs == join_arity) $
- mkBadJumpMsg var' join_arity nargs }
- Nothing ->
- do { checkL (not (isJoinId var)) $
- mkJoinBndrOccMismatchMsg var' var }
+ ; checkJoinOcc var nargs
+
; return (idType var') }
-lintCoreFun :: CoreExpr -> Int -- Number of arguments (type or val) being passed
- -> LintM Type -- returns type of the *function*
+lintCoreFun :: CoreExpr
+ -> Int -- Number of arguments (type or val) being passed
+ -> LintM Type -- Returns type of the *function*
lintCoreFun (Var var) nargs
- = lintCoreVar var nargs
+ = lintVarOcc var nargs
+
lintCoreFun (Lam var body) nargs
-- Act like lintCoreExpr of Lam, but *don't* call markAllJoinsBad; see
-- Note [Beta redexes]
@@ -865,10 +847,47 @@ lintCoreFun (Lam var body) nargs
lintBinder var $ \ var' ->
do { body_ty <- lintCoreFun body (nargs - 1)
; return $ mkLamType var' body_ty }
+
lintCoreFun expr nargs
= markAllJoinsBadIf (nargs /= 0) $
lintCoreExpr expr
+------------------
+checkDeadIdOcc :: Id -> LintM ()
+-- Occurrences of an Id should never be dead....
+-- except when we are checking a case pattern
+checkDeadIdOcc id
+ | isDeadOcc (idOccInfo id)
+ = do { in_case <- inCasePat
+ ; checkL in_case
+ (text "Occurrence of a dead Id" <+> ppr id) }
+ | otherwise
+ = return ()
+
+------------------
+checkJoinOcc :: Id -> JoinArity -> LintM ()
+-- Check that if the occurrence is a JoinId, then so is the
+-- binding site, and it's a valid join Id
+checkJoinOcc var n_args
+ | Just join_arity_occ <- isJoinId_maybe var
+ = do { mb_join_arity_bndr <- lookupJoinId var
+ ; case mb_join_arity_bndr of {
+ Nothing -> -- Binder is not a join point
+ addErrL (invalidJoinOcc var) ;
+
+ Just join_arity_bndr ->
+
+ do { checkL (join_arity_bndr == join_arity_occ) $
+ -- Arity differs at binding site and occurrence
+ mkJoinBndrOccMismatchMsg var join_arity_bndr join_arity_occ
+
+ ; checkL (n_args == join_arity_occ) $
+ -- Arity doesn't match #args
+ mkBadJumpMsg var join_arity_occ n_args } } }
+
+ | otherwise
+ = return ()
+
{-
Note [No alternatives lint check]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1010,17 +1029,6 @@ lintTyKind tyvar arg_ty
where
tyvar_kind = tyVarKind tyvar
-checkDeadIdOcc :: Id -> LintM ()
--- Occurrences of an Id should never be dead....
--- except when we are checking a case pattern
-checkDeadIdOcc id
- | isDeadOcc (idOccInfo id)
- = do { in_case <- inCasePat
- ; checkL in_case
- (text "Occurrence of a dead Id" <+> ppr id) }
- | otherwise
- = return ()
-
{-
************************************************************************
* *
@@ -1152,21 +1160,22 @@ lintCoBndr cv thing_inside
(text "CoVar with non-coercion type:" <+> pprTyVar cv)
; updateTCvSubst subst' (thing_inside cv') }
-lintIdBndrs :: [Var] -> ([Var] -> LintM a) -> LintM a
-lintIdBndrs ids linterF
+lintIdBndrs :: TopLevelFlag -> [Var] -> LintM a -> LintM a
+lintIdBndrs top_lvl ids linterF
= go ids
where
- go [] = linterF []
- go (id:ids) = lintIdBndr NotTopLevel id $ \id ->
- lintIdBndrs ids $ \ids ->
- linterF (id:ids)
+ go [] = linterF
+ go (id:ids) = lintIdBndr top_lvl id $ \_ ->
+ lintIdBndrs top_lvl ids $
+ linterF
lintIdBndr :: TopLevelFlag -> InVar -> (OutVar -> LintM a) -> LintM a
-- Do substitution on the type of a binder and add the var with this
-- new type to the in-scope set of the second argument
-- ToDo: lint its rules
lintIdBndr top_lvl id linterF
- = do { flags <- getLintFlags
+ = ASSERT2( isId id, ppr id )
+ do { flags <- getLintFlags
; checkL (not (lf_check_global_ids flags) || isLocalId id)
(text "Non-local Id binder" <+> ppr id)
-- See Note [Checking for global Ids]
@@ -1784,7 +1793,8 @@ data LintEnv
, le_subst :: TCvSubst -- Current type substitution; we also use this
-- to keep track of all the variables in scope,
-- both Ids and TyVars
- , le_bad_joins :: IdSet -- Join points that are no longer valid
+ , le_joins :: IdSet -- Join points in scope that are valid
+ -- A subset of teh InScopeSet in le_subst
-- See Note [Join points]
, le_dynflags :: DynFlags -- DynamicFlags
}
@@ -1891,13 +1901,17 @@ data LintLocInfo
| InType Type -- Inside a type
| InCo Coercion -- Inside a coercion
-initL :: DynFlags -> LintFlags -> LintM a -> WarnsAndErrs -- Errors and warnings
-initL dflags flags m
+initL :: DynFlags -> LintFlags -> InScopeSet
+ -> LintM a -> WarnsAndErrs -- Errors and warnings
+initL dflags flags in_scope m
= case unLintM m env (emptyBag, emptyBag) of
(_, errs) -> errs
where
- env = LE { le_flags = flags, le_subst = emptyTCvSubst, le_loc = []
- , le_dynflags = dflags, le_bad_joins = emptyVarSet }
+ env = LE { le_flags = flags
+ , le_subst = mkEmptyTCvSubst in_scope
+ , le_joins = emptyVarSet
+ , le_loc = []
+ , le_dynflags = dflags }
getLintFlags :: LintM LintFlags
getLintFlags = LintM $ \ env errs -> (Just (le_flags env), errs)
@@ -1952,29 +1966,12 @@ inCasePat = LintM $ \ env errs -> (Just (is_case_pat env), errs)
is_case_pat (LE { le_loc = CasePat {} : _ }) = True
is_case_pat _other = False
-addInScopeVars :: [Var] -> LintM a -> LintM a
-addInScopeVars vars m
- = LintM $ \ env errs ->
- unLintM m (env { le_subst = extendTCvInScopeList (le_subst env) vars
- , le_bad_joins = bad_joins' env })
- errs
- where
- bad_joins' env = delVarSetList (le_bad_joins env) (filter isJoinId vars)
-
-addInScopeVarSet :: VarSet -> LintM a -> LintM a
-addInScopeVarSet vars m
- = LintM $ \ env errs ->
- unLintM m (env { le_subst = extendTCvInScopeSet (le_subst env) vars })
- errs
-
addInScopeVar :: Var -> LintM a -> LintM a
addInScopeVar var m
= LintM $ \ env errs ->
- unLintM m (env { le_subst = extendTCvInScope (le_subst env) var
- , le_bad_joins = bad_joins' env }) errs
- where
- bad_joins' env | isJoinId var = delVarSet (le_bad_joins env) var
- | otherwise = le_bad_joins env
+ unLintM m (env { le_subst = extendTCvInScope (le_subst env) var
+ , le_joins = delVarSet (le_joins env) var
+ }) errs
extendSubstL :: TyVar -> Type -> LintM a -> LintM a
extendSubstL tv ty m
@@ -1987,16 +1984,25 @@ updateTCvSubst subst' m
markAllJoinsBad :: LintM a -> LintM a
markAllJoinsBad m
- = LintM $ \ env errs -> unLintM m (marked env) errs
- where
- marked env = env { le_bad_joins = filterVarSet isJoinId in_set }
- where
- in_set = getInScopeVars (getTCvInScope (le_subst env))
+ = LintM $ \ env errs -> unLintM m (env { le_joins = emptyVarSet }) errs
markAllJoinsBadIf :: Bool -> LintM a -> LintM a
markAllJoinsBadIf True m = markAllJoinsBad m
markAllJoinsBadIf False m = m
+addGoodJoins :: [Var] -> LintM a -> LintM a
+addGoodJoins vars thing_inside
+ | null join_ids
+ = thing_inside
+ | otherwise
+ = LintM $ \ env errs -> unLintM thing_inside (add_joins env) errs
+ where
+ add_joins env = env { le_joins = le_joins env `extendVarSetList` join_ids }
+ join_ids = filter isJoinId vars
+
+getValidJoins :: LintM IdSet
+getValidJoins = LintM (\ env errs -> (Just (le_joins env), errs))
+
getTCvSubst :: LintM TCvSubst
getTCvSubst = LintM (\ env errs -> (Just (le_subst env), errs))
@@ -2022,9 +2028,14 @@ lookupIdInScope id
where
out_of_scope = pprBndr LetBind id <+> text "is out of scope"
-isBadJoin :: Id -> LintM Bool
-isBadJoin id = LintM $ \env errs -> (Just (id `elemVarSet` le_bad_joins env),
- errs)
+lookupJoinId :: Id -> LintM (Maybe JoinArity)
+-- Look up an Id which should be a join point, valid here
+-- If so, return its arity, if not return Nothing
+lookupJoinId id
+ = do { join_set <- getValidJoins
+ ; case lookupVarSet join_set id of
+ Just id' -> return (isJoinId_maybe id')
+ Nothing -> return Nothing }
lintTyCoVarInScope :: Var -> LintM ()
lintTyCoVarInScope v = lintInScope (text "is out of scope") v
@@ -2294,9 +2305,10 @@ mkBadJoinArityMsg var ar nlams
text "Join arity:" <+> ppr ar,
text "Number of lambdas:" <+> ppr nlams ]
-mkJoinOutOfScopeMsg :: Var -> SDoc
-mkJoinOutOfScopeMsg var
- = text "Join variable no longer in scope:" <+> ppr var
+invalidJoinOcc :: Var -> SDoc
+invalidJoinOcc var
+ = vcat [ text "Invalid occurrence of a join variable:" <+> ppr var
+ , text "The binder is either not a join point, or not valid here" ]
mkBadJumpMsg :: Var -> Int -> Int -> SDoc
mkBadJumpMsg var ar nargs
@@ -2312,17 +2324,12 @@ mkInconsistentRecMsg bndrs
where
ppr_with_details bndr = ppr bndr <> ppr (idDetails bndr)
-mkJoinBndrOccMismatchMsg :: Var -> Var -> SDoc
-mkJoinBndrOccMismatchMsg bndr var
- = vcat [ text "Mismatch in join point status between binder and occurrence",
- text "Var:" <+> ppr bndr,
- text "Binder:" <+> ppr_join_status bndr,
- text "Occ:" <+> ppr_join_status var ]
- where
- ppr_join_status v = case details of JoinId _ -> ppr details
- _ -> text "not a join id"
- where
- details = idDetails v
+mkJoinBndrOccMismatchMsg :: Var -> JoinArity -> JoinArity -> SDoc
+mkJoinBndrOccMismatchMsg bndr join_arity_bndr join_arity_occ
+ = vcat [ text "Mismatch in join point arity between binder and occurrence"
+ , text "Var:" <+> ppr bndr
+ , text "Arity at binding site:" <+> ppr join_arity_bndr
+ , text "Arity at occurrence: " <+> ppr join_arity_occ ]
mkBndrOccTypeMismatchMsg :: Var -> Var -> OutType -> OutType -> SDoc
mkBndrOccTypeMismatchMsg bndr var bndr_ty var_ty