summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Pickering <matthewtpickering@gmail.com>2016-11-29 19:43:43 (GMT)
committerBen Gamari <ben@smart-cactus.org>2016-11-29 19:43:44 (GMT)
commitc2268ba0eeb36a48da77ba95c72525c398c8b306 (patch)
treeb0b550bc91132d81b10db5e904da3b76ee52fd9d
parent3ec856308cbfb89299daba56337eda866ac88d6e (diff)
downloadghc-c2268ba0eeb36a48da77ba95c72525c398c8b306.zip
ghc-c2268ba0eeb36a48da77ba95c72525c398c8b306.tar.gz
ghc-c2268ba0eeb36a48da77ba95c72525c398c8b306.tar.bz2
Refactor Pattern Match Checker to use ListT
Reviewers: bgamari, austin Reviewed By: bgamari Subscribers: thomie Differential Revision: https://phabricator.haskell.org/D2725
-rw-r--r--compiler/deSugar/Check.hs326
-rw-r--r--compiler/ghc.cabal.in1
-rw-r--r--compiler/utils/ListT.hs71
3 files changed, 287 insertions, 111 deletions
diff --git a/compiler/deSugar/Check.hs b/compiler/deSugar/Check.hs
index b5f6eac..04ba568 100644
--- a/compiler/deSugar/Check.hs
+++ b/compiler/deSugar/Check.hs
@@ -50,6 +50,8 @@ import Coercion
import TcEvidence
import IOEnv
+import ListT (ListT(..), fold)
+
{-
This module checks pattern matches for:
\begin{enumerate}
@@ -72,7 +74,25 @@ The algorithm is based on the paper:
%************************************************************************
-}
-type PmM a = DsM a
+-- We use the non-determinism monad to apply the algorithm to several
+-- possible sets of constructors. Users can specify complete sets of
+-- constructors by using COMPLETE pragmas.
+-- The algorithm only picks out constructor
+-- sets deep in the bowels which makes a simpler `mapM` more difficult to
+-- implement. The non-determinism is only used in one place, see the ConVar
+-- case in `pmCheckHd`.
+
+type PmM a = ListT DsM a
+
+liftD :: DsM a -> PmM a
+liftD m = ListT $ \sk fk -> m >>= \a -> sk a fk
+
+
+myRunListT :: PmM a -> DsM [a]
+myRunListT pm = fold pm go (return [])
+ where
+ go a mas =
+ mas >>= \as -> return (a:as)
data PatTy = PAT | VA -- Used only as a kind, to index PmPat
@@ -122,14 +142,64 @@ type Uncovered = ValSetAbs
-- C = True ==> Useful clause (no warning)
-- C = False, D = True ==> Clause with inaccessible RHS
-- C = False, D = False ==> Redundant clause
-type Triple = (Bool, Uncovered, Bool)
+
+data Covered = Covered | NotCovered
+ deriving Show
+
+instance Outputable Covered where
+ ppr (Covered) = text "Covered"
+ ppr (NotCovered) = text "NotCovered"
+
+-- Like the or monoid for booleans
+-- Covered = True, Uncovered = False
+instance Monoid Covered where
+ mempty = NotCovered
+ Covered `mappend` _ = Covered
+ _ `mappend` Covered = Covered
+ NotCovered `mappend` NotCovered = NotCovered
+
+data Diverged = Diverged | NotDiverged
+ deriving Show
+
+instance Outputable Diverged where
+ ppr Diverged = text "Diverged"
+ ppr NotDiverged = text "NotDiverged"
+
+instance Monoid Diverged where
+ mempty = NotDiverged
+ Diverged `mappend` _ = Diverged
+ _ `mappend` Diverged = Diverged
+ NotDiverged `mappend` NotDiverged = NotDiverged
+
+data PartialResult = PartialResult {
+ presultCovered :: Covered
+ , presultUncovered :: Uncovered
+ , presultDivergent :: Diverged }
+
+instance Outputable PartialResult where
+ ppr (PartialResult c vsa d) = text "PartialResult" <+> ppr c
+ <+> ppr d <+> ppr vsa
+
+instance Monoid PartialResult where
+ mempty = PartialResult mempty [] mempty
+ (PartialResult cs1 vsa1 ds1)
+ `mappend` (PartialResult cs2 vsa2 ds2)
+ = PartialResult (cs1 `mappend` cs2)
+ (vsa1 `mappend` vsa2)
+ (ds1 `mappend` ds2)
+
+-- newtype ChoiceOf a = ChoiceOf [a]
-- | Pattern check result
--
-- * Redundant clauses
-- * Not-covered clauses
-- * Clauses with inaccessible RHS
-type PmResult = ([Located [LPat Id]], Uncovered, [Located [LPat Id]])
+data PmResult =
+ PmResult {
+ pmresultRedundant :: [Located [LPat Id]]
+ , pmresultUncovered :: Uncovered
+ , pmresultInaccessible :: [Located [LPat Id]] }
{-
%************************************************************************
@@ -142,63 +212,67 @@ type PmResult = ([Located [LPat Id]], Uncovered, [Located [LPat Id]])
-- | Check a single pattern binding (let)
checkSingle :: DynFlags -> DsMatchContext -> Id -> Pat Id -> DsM ()
checkSingle dflags ctxt@(DsMatchContext _ locn) var p = do
- tracePm "checkSingle" (vcat [ppr ctxt, ppr var, ppr p])
- mb_pm_res <- tryM (checkSingle' locn var p)
+ tracePmD "checkSingle" (vcat [ppr ctxt, ppr var, ppr p])
+ mb_pm_res <- tryM (head <$> myRunListT (checkSingle' locn var p))
case mb_pm_res of
Left _ -> warnPmIters dflags ctxt
Right res -> dsPmWarn dflags ctxt res
-- | Check a single pattern binding (let)
-checkSingle' :: SrcSpan -> Id -> Pat Id -> DsM PmResult
+checkSingle' :: SrcSpan -> Id -> Pat Id -> PmM PmResult
checkSingle' locn var p = do
- resetPmIterDs -- set the iter-no to zero
- fam_insts <- dsGetFamInstEnvs
- clause <- translatePat fam_insts p
+ liftD resetPmIterDs -- set the iter-no to zero
+ fam_insts <- liftD dsGetFamInstEnvs
+ clause <- liftD $ translatePat fam_insts p
missing <- mkInitialUncovered [var]
tracePm "checkSingle: missing" (vcat (map pprValVecDebug missing))
- (cs,us,ds) <- runMany (pmcheckI clause []) missing -- no guards
+ PartialResult cs us ds <- runMany (pmcheckI clause []) missing -- no guards
return $ case (cs,ds) of
- (True, _ ) -> ([], us, []) -- useful
- (False, False) -> ( m, us, []) -- redundant
- (False, True ) -> ([], us, m) -- inaccessible rhs
+ (Covered, _ ) -> PmResult [] us [] -- useful
+ (NotCovered, NotDiverged) -> PmResult m us [] -- redundant
+ (NotCovered, Diverged ) -> PmResult [] us m -- inaccessible rhs
where m = [L locn [L locn p]]
-- | Check a matchgroup (case, functions, etc.)
checkMatches :: DynFlags -> DsMatchContext
-> [Id] -> [LMatch Id (LHsExpr Id)] -> DsM ()
checkMatches dflags ctxt vars matches = do
- tracePm "checkMatches" (hang (vcat [ppr ctxt
+ tracePmD "checkMatches" (hang (vcat [ppr ctxt
, ppr vars
, text "Matches:"])
2
(vcat (map ppr matches)))
- mb_pm_res <- tryM (checkMatches' vars matches)
+ mb_pm_res <- tryM (head <$> myRunListT (checkMatches' vars matches))
case mb_pm_res of
Left _ -> warnPmIters dflags ctxt
Right res -> dsPmWarn dflags ctxt res
-- | Check a matchgroup (case, functions, etc.)
-checkMatches' :: [Id] -> [LMatch Id (LHsExpr Id)] -> DsM PmResult
+checkMatches' :: [Id] -> [LMatch Id (LHsExpr Id)] -> PmM PmResult
checkMatches' vars matches
- | null matches = return ([], [], [])
+ | null matches = return $ PmResult [] [] []
| otherwise = do
- resetPmIterDs -- set the iter-no to zero
+ liftD resetPmIterDs -- set the iter-no to zero
missing <- mkInitialUncovered vars
tracePm "checkMatches: missing" (vcat (map pprValVecDebug missing))
(rs,us,ds) <- go matches missing
- return (map hsLMatchToLPats rs, us, map hsLMatchToLPats ds)
+ return $ PmResult (map hsLMatchToLPats rs) us (map hsLMatchToLPats ds)
where
+ go :: [LMatch Id (LHsExpr Id)] -> Uncovered
+ -> PmM ([LMatch Id (LHsExpr Id)] , Uncovered , [LMatch Id (LHsExpr Id)])
go [] missing = return ([], missing, [])
go (m:ms) missing = do
tracePm "checMatches': go" (ppr m $$ ppr missing)
- fam_insts <- dsGetFamInstEnvs
- (clause, guards) <- translateMatch fam_insts m
- (cs, missing', ds) <- runMany (pmcheckI clause guards) missing
+ fam_insts <- liftD dsGetFamInstEnvs
+ (clause, guards) <- liftD $ translateMatch fam_insts m
+ r@(PartialResult cs missing' ds)
+ <- runMany (pmcheckI clause guards) missing
+ tracePm "checMatches': go: res" (ppr r)
(rs, final_u, is) <- go ms missing'
return $ case (cs, ds) of
- (True, _ ) -> ( rs, final_u, is) -- useful
- (False, False) -> (m:rs, final_u, is) -- redundant
- (False, True ) -> ( rs, final_u, m:is) -- inaccessible
+ (Covered, _ ) -> ( rs, final_u, is) -- useful
+ (NotCovered, NotDiverged) -> (m:rs, final_u, is) -- redundant
+ (NotCovered, Diverged ) -> ( rs, final_u, m:is) -- inaccessible
hsLMatchToLPats :: LMatch id body -> Located [LPat id]
hsLMatchToLPats (L l (Match _ pats _ _)) = L l pats
@@ -239,7 +313,7 @@ isFakeGuard [PmCon { pm_con_con = c }] (PmExprOther EWildPat)
isFakeGuard _pats _e = False
-- | Generate a `canFail` pattern vector of a specific type
-mkCanFailPmPat :: Type -> PmM PatVec
+mkCanFailPmPat :: Type -> DsM PatVec
mkCanFailPmPat ty = do
var <- mkPmVar ty
return [var, fake_pat]
@@ -274,7 +348,7 @@ mkLitPattern lit = PmLit { pm_lit_lit = PmSLit lit }
-- -----------------------------------------------------------------------
-- * Transform (Pat Id) into of (PmPat Id)
-translatePat :: FamInstEnvs -> Pat Id -> PmM PatVec
+translatePat :: FamInstEnvs -> Pat Id -> DsM PatVec
translatePat fam_insts pat = case pat of
WildPat ty -> mkPmVars [ty]
VarPat id -> return [PmVar (unLoc id)]
@@ -389,7 +463,7 @@ translatePat fam_insts pat = case pat of
-- | Translate an overloaded literal (see `tidyNPat' in deSugar/MatchLit.hs)
translateNPat :: FamInstEnvs
- -> HsOverLit Id -> Maybe (SyntaxExpr Id) -> Type -> PmM PatVec
+ -> HsOverLit Id -> Maybe (SyntaxExpr Id) -> Type -> DsM PatVec
translateNPat fam_insts (OverLit val False _ ty) mb_neg outer_ty
| not type_change, isStringTy ty, HsIsString src s <- val, Nothing <- mb_neg
= translatePat fam_insts (LitPat (HsString src s))
@@ -407,12 +481,12 @@ translateNPat _ ol mb_neg _
-- | Translate a list of patterns (Note: each pattern is translated
-- to a pattern vector but we do not concatenate the results).
-translatePatVec :: FamInstEnvs -> [Pat Id] -> PmM [PatVec]
+translatePatVec :: FamInstEnvs -> [Pat Id] -> DsM [PatVec]
translatePatVec fam_insts pats = mapM (translatePat fam_insts) pats
-- | Translate a constructor pattern
translateConPatVec :: FamInstEnvs -> [Type] -> [TyVar]
- -> DataCon -> HsConPatDetails Id -> PmM PatVec
+ -> DataCon -> HsConPatDetails Id -> DsM PatVec
translateConPatVec fam_insts _univ_tys _ex_tvs _ (PrefixCon ps)
= concat <$> translatePatVec fam_insts (map unLoc ps)
translateConPatVec fam_insts _univ_tys _ex_tvs _ (InfixCon p1 p2)
@@ -467,7 +541,7 @@ translateConPatVec fam_insts univ_tys ex_tvs c (RecCon (HsRecFields fs _))
| otherwise = subsetOf (x:xs) ys
-- Translate a single match
-translateMatch :: FamInstEnvs -> LMatch Id (LHsExpr Id) -> PmM (PatVec,[PatVec])
+translateMatch :: FamInstEnvs -> LMatch Id (LHsExpr Id) -> DsM (PatVec,[PatVec])
translateMatch fam_insts (L _ (Match _ lpats _ grhss)) = do
pats' <- concat <$> translatePatVec fam_insts pats
guards' <- mapM (translateGuards fam_insts) guards
@@ -483,7 +557,7 @@ translateMatch fam_insts (L _ (Match _ lpats _ grhss)) = do
-- * Transform source guards (GuardStmt Id) to PmPats (Pattern)
-- | Translate a list of guard statements to a pattern vector
-translateGuards :: FamInstEnvs -> [GuardStmt Id] -> PmM PatVec
+translateGuards :: FamInstEnvs -> [GuardStmt Id] -> DsM PatVec
translateGuards fam_insts guards = do
all_guards <- concat <$> mapM (translateGuard fam_insts) guards
return (replace_unhandled all_guards)
@@ -523,7 +597,7 @@ cantFailPattern (PmGrd pv _e)
cantFailPattern _ = False
-- | Translate a guard statement to Pattern
-translateGuard :: FamInstEnvs -> GuardStmt Id -> PmM PatVec
+translateGuard :: FamInstEnvs -> GuardStmt Id -> DsM PatVec
translateGuard fam_insts guard = case guard of
BodyStmt e _ _ _ -> translateBoolGuard e
LetStmt binds -> translateLet (unLoc binds)
@@ -535,17 +609,17 @@ translateGuard fam_insts guard = case guard of
ApplicativeStmt {} -> panic "translateGuard ApplicativeLastStmt"
-- | Translate let-bindings
-translateLet :: HsLocalBinds Id -> PmM PatVec
+translateLet :: HsLocalBinds Id -> DsM PatVec
translateLet _binds = return []
-- | Translate a pattern guard
-translateBind :: FamInstEnvs -> LPat Id -> LHsExpr Id -> PmM PatVec
+translateBind :: FamInstEnvs -> LPat Id -> LHsExpr Id -> DsM PatVec
translateBind fam_insts (L _ p) e = do
ps <- translatePat fam_insts p
return [mkGuard ps (unLoc e)]
-- | Translate a boolean guard
-translateBoolGuard :: LHsExpr Id -> PmM PatVec
+translateBoolGuard :: LHsExpr Id -> DsM PatVec
translateBoolGuard e
| isJust (isTrueLHsExpr e) = return []
-- The formal thing to do would be to generate (True <- True)
@@ -675,7 +749,7 @@ pmPatType (PmGrd { pm_grd_pv = pv })
-- | Generate a value abstraction for a given constructor (generate
-- fresh variables of the appropriate type for arguments)
-mkOneConFull :: Id -> DataCon -> PmM (ValAbs, ComplexEq, Bag EvVar)
+mkOneConFull :: Id -> DataCon -> DsM (ValAbs, ComplexEq, Bag EvVar)
-- * x :: T tys, where T is an algebraic data type
-- NB: in the case of a data familiy, T is the *representation* TyCon
-- e.g. data instance T (a,b) = T1 a b
@@ -738,17 +812,17 @@ mkPosEq x l = (PmExprVar (idName x), PmExprLit l)
{-# INLINE mkPosEq #-}
-- | Generate a variable pattern of a given type
-mkPmVar :: Type -> PmM (PmPat p)
+mkPmVar :: Type -> DsM (PmPat p)
mkPmVar ty = PmVar <$> mkPmId ty
{-# INLINE mkPmVar #-}
-- | Generate many variable patterns, given a list of types
-mkPmVars :: [Type] -> PmM PatVec
+mkPmVars :: [Type] -> DsM PatVec
mkPmVars tys = mapM mkPmVar tys
{-# INLINE mkPmVars #-}
-- | Generate a fresh `Id` of a given type
-mkPmId :: Type -> PmM Id
+mkPmId :: Type -> DsM Id
mkPmId ty = getUniqueM >>= \unique ->
let occname = mkVarOccFS (fsLit (show unique))
name = mkInternalName unique occname noSrcSpan
@@ -757,7 +831,7 @@ mkPmId ty = getUniqueM >>= \unique ->
-- | Generate a fresh term variable of a given and return it in two forms:
-- * A variable pattern
-- * A variable expression
-mkPmId2Forms :: Type -> PmM (Pattern, LHsExpr Id)
+mkPmId2Forms :: Type -> DsM (Pattern, LHsExpr Id)
mkPmId2Forms ty = do
x <- mkPmId ty
return (PmVar x, noLoc (HsVar (noLoc x)))
@@ -802,7 +876,7 @@ allConstructors = tyConDataCons . dataConTyCon
newEvVar :: Name -> Type -> EvVar
newEvVar name ty = mkLocalId name (toTcType ty)
-nameType :: String -> Type -> PmM EvVar
+nameType :: String -> Type -> DsM EvVar
nameType name ty = do
unique <- getUniqueM
let occname = mkVarOccFS (fsLit (name++"_"++show unique))
@@ -820,7 +894,8 @@ nameType name ty = do
-- | Check whether a set of type constraints is satisfiable.
tyOracle :: Bag EvVar -> PmM Bool
tyOracle evs
- = do { ((_warns, errs), res) <- initTcDsForSolver $ tcCheckSatisfiability evs
+ = liftD $
+ do { ((_warns, errs), res) <- initTcDsForSolver $ tcCheckSatisfiability evs
; case res of
Just sat -> return sat
Nothing -> pprPanic "tyOracle" (vcat $ pprErrMsgBagWithLoc errs) }
@@ -861,7 +936,7 @@ Main functions are:
are checked, if they are inconsistent, the set is empty, otherwise, the
set contains only a vector of variables with the constraints in scope.
-* pmcheck :: PatVec -> [PatVec] -> ValVec -> PmM Triple
+* pmcheck :: PatVec -> [PatVec] -> ValVec -> PmM PartialResult
Checks redundancy, coverage and inaccessibility, using auxilary functions
`pmcheckGuards` and `pmcheckHd`. Mainly handles the guard case which is
@@ -869,12 +944,12 @@ Main functions are:
whole clause is checked, or `pmcheckHd` when the pattern vector does not
start with a guard.
-* pmcheckGuards :: [PatVec] -> ValVec -> PmM Triple
+* pmcheckGuards :: [PatVec] -> ValVec -> PmM PartialResult
Processes the guards.
* pmcheckHd :: Pattern -> PatVec -> [PatVec]
- -> ValAbs -> ValVec -> PmM Triple
+ -> ValAbs -> ValVec -> PmM PartialResult
Worker: This function implements functions `covered`, `uncovered` and
`divergent` from the paper at once. Slightly different from the paper because
@@ -886,17 +961,20 @@ Main functions are:
-- | Lift a pattern matching action from a single value vector abstration to a
-- value set abstraction, but calling it on every vector and the combining the
-- results.
-runMany :: (ValVec -> PmM Triple) -> (Uncovered -> PmM Triple)
-runMany pm us = mapAndUnzip3M pm us >>= \(css, uss, dss) ->
- return (or css, concat uss, or dss)
+runMany :: (ValVec -> PmM PartialResult) -> (Uncovered -> PmM PartialResult)
+runMany _ [] = return $ PartialResult mempty mempty mempty
+runMany pm (m:ms) = do
+ (PartialResult c v d) <- pm m
+ (PartialResult cs vs ds) <- runMany pm ms
+ return (PartialResult (c `mappend` cs) (v `mappend` vs) (d `mappend` ds))
{-# INLINE runMany #-}
-- | Generate the initial uncovered set. It initializes the
-- delta with all term and type constraints in scope.
mkInitialUncovered :: [Id] -> PmM Uncovered
mkInitialUncovered vars = do
- ty_cs <- getDictsDs
- tm_cs <- map toComplex . bagToList <$> getTmCsDs
+ ty_cs <- liftD getDictsDs
+ tm_cs <- map toComplex . bagToList <$> liftD getTmCsDs
sat_ty <- tyOracle ty_cs
return $ case (sat_ty, tmOracle initialTmState tm_cs) of
(True, Just tm_state) -> [ValVec patterns (MkDelta ty_cs tm_state)]
@@ -908,41 +986,45 @@ mkInitialUncovered vars = do
-- | Increase the counter for elapsed algorithm iterations, check that the
-- limit is not exceeded and call `pmcheck`
-pmcheckI :: PatVec -> [PatVec] -> ValVec -> PmM Triple
+pmcheckI :: PatVec -> [PatVec] -> ValVec -> PmM PartialResult
pmcheckI ps guards vva = do
- n <- incrCheckPmIterDs
+ n <- liftD incrCheckPmIterDs
tracePm "pmCheck" (ppr n <> colon <+> pprPatVec ps
$$ hang (text "guards:") 2 (vcat (map pprPatVec guards))
$$ pprValVecDebug vva)
- pmcheck ps guards vva
+ res <- pmcheck ps guards vva
+ tracePm "pmCheckResult:" (ppr res)
+ return res
{-# INLINE pmcheckI #-}
-- | Increase the counter for elapsed algorithm iterations, check that the
-- limit is not exceeded and call `pmcheckGuards`
-pmcheckGuardsI :: [PatVec] -> ValVec -> PmM Triple
-pmcheckGuardsI gvs vva = incrCheckPmIterDs >> pmcheckGuards gvs vva
+pmcheckGuardsI :: [PatVec] -> ValVec -> PmM PartialResult
+pmcheckGuardsI gvs vva = liftD incrCheckPmIterDs >> pmcheckGuards gvs vva
{-# INLINE pmcheckGuardsI #-}
-- | Increase the counter for elapsed algorithm iterations, check that the
-- limit is not exceeded and call `pmcheckHd`
-pmcheckHdI :: Pattern -> PatVec -> [PatVec] -> ValAbs -> ValVec -> PmM Triple
+pmcheckHdI :: Pattern -> PatVec -> [PatVec] -> ValAbs -> ValVec -> PmM PartialResult
pmcheckHdI p ps guards va vva = do
- n <- incrCheckPmIterDs
+ n <- liftD incrCheckPmIterDs
tracePm "pmCheckHdI" (ppr n <> colon <+> pprPmPatDebug p
$$ pprPatVec ps
$$ hang (text "guards:") 2 (vcat (map pprPatVec guards))
$$ pprPmPatDebug va
$$ pprValVecDebug vva)
- pmcheckHd p ps guards va vva
+ res <- pmcheckHd p ps guards va vva
+ tracePm "pmCheckHdI: res" (ppr res)
+ return res
{-# INLINE pmcheckHdI #-}
-- | Matching function: Check simultaneously a clause (takes separately the
-- patterns and the list of guards) for exhaustiveness, redundancy and
-- inaccessibility.
-pmcheck :: PatVec -> [PatVec] -> ValVec -> PmM Triple
+pmcheck :: PatVec -> [PatVec] -> ValVec -> PmM PartialResult
pmcheck [] guards vva@(ValVec [] _)
- | null guards = return (True, [], False)
+ | null guards = return $ mempty { presultCovered = Covered }
| otherwise = pmcheckGuardsI guards vva
-- Guard
@@ -953,7 +1035,7 @@ pmcheck (p@(PmGrd pv e) : ps) guards vva@(ValVec vas delta)
-- though. So just have these two cases but do not do all the boilerplate
| isFakeGuard pv e = forces . mkCons vva <$> pmcheckI ps guards vva
| otherwise = do
- y <- mkPmId (pmPatType p)
+ y <- liftD $ mkPmId (pmPatType p)
let tm_state = extendSubst y e (delta_tm_cs delta)
delta' = delta { delta_tm_cs = tm_state }
utail <$> pmcheckI (pv ++ ps) guards (ValVec (PmVar y : vas) delta')
@@ -965,41 +1047,44 @@ pmcheck (p:ps) guards (ValVec (va:vva) delta)
= pmcheckHdI p ps guards va (ValVec vva delta)
-- | Check the list of guards
-pmcheckGuards :: [PatVec] -> ValVec -> PmM Triple
-pmcheckGuards [] vva = return (False, [vva], False)
+pmcheckGuards :: [PatVec] -> ValVec -> PmM PartialResult
+pmcheckGuards [] vva = return (usimple [vva])
pmcheckGuards (gv:gvs) vva = do
- (cs, vsa, ds ) <- pmcheckI gv [] vva
- (css, vsas, dss) <- runMany (pmcheckGuardsI gvs) vsa
- return (cs || css, vsas, ds || dss)
+ (PartialResult cs vsa ds) <- pmcheckI gv [] vva
+ (PartialResult css vsas dss) <- runMany (pmcheckGuardsI gvs) vsa
+ return $ PartialResult (cs `mappend` css) vsas (ds `mappend` dss)
-- | Worker function: Implements all cases described in the paper for all three
-- functions (`covered`, `uncovered` and `divergent`) apart from the `Guard`
-- cases which are handled by `pmcheck`
-pmcheckHd :: Pattern -> PatVec -> [PatVec] -> ValAbs -> ValVec -> PmM Triple
+pmcheckHd :: Pattern -> PatVec -> [PatVec] -> ValAbs -> ValVec -> PmM PartialResult
-- Var
pmcheckHd (PmVar x) ps guards va (ValVec vva delta)
| Just tm_state <- solveOneEq (delta_tm_cs delta)
(PmExprVar (idName x), vaToPmExpr va)
= ucon va <$> pmcheckI ps guards (ValVec vva (delta {delta_tm_cs = tm_state}))
- | otherwise = return (False, [], False)
+ | otherwise = return mempty
-- ConCon
pmcheckHd ( p@(PmCon {pm_con_con = c1, pm_con_args = args1})) ps guards
(va@(PmCon {pm_con_con = c2, pm_con_args = args2})) (ValVec vva delta)
- | c1 /= c2 = return (False, [ValVec (va:vva) delta], False)
+ | c1 /= c2 =
+ return (usimple [ValVec (va:vva) delta])
| otherwise = kcon c1 (pm_con_arg_tys p) (pm_con_tvs p) (pm_con_dicts p)
<$> pmcheckI (args1 ++ ps) guards (ValVec (args2 ++ vva) delta)
-- LitLit
-pmcheckHd (PmLit l1) ps guards (va@(PmLit l2)) vva = case eqPmLit l1 l2 of
- True -> ucon va <$> pmcheckI ps guards vva
- False -> return $ ucon va (False, [vva], False)
+pmcheckHd (PmLit l1) ps guards (va@(PmLit l2)) vva =
+ case eqPmLit l1 l2 of
+ True -> ucon va <$> pmcheckI ps guards vva
+ False -> return $ ucon va (usimple [vva])
-- ConVar
pmcheckHd (p@(PmCon { pm_con_con = con })) ps guards
(PmVar x) (ValVec vva delta) = do
- cons_cs <- mapM (mkOneConFull x) (allConstructors con)
+ cons_cs <- mapM (liftD . mkOneConFull x) (allConstructors con)
+
inst_vsa <- flip concatMapM cons_cs $ \(va, tm_ct, ty_cs) -> do
let ty_state = ty_cs `unionBags` delta_ty_cs delta -- not actually a state
sat_ty <- if isEmptyBag ty_cs then return True
@@ -1018,13 +1103,13 @@ pmcheckHd (p@(PmLit l)) ps guards (PmVar x) (ValVec vva delta)
case solveOneEq (delta_tm_cs delta) (mkPosEq x l) of
Just tm_state -> pmcheckHdI p ps guards (PmLit l) $
ValVec vva (delta {delta_tm_cs = tm_state})
- Nothing -> return (False, [], False)
+ Nothing -> return mempty
where
us | Just tm_state <- solveOneEq (delta_tm_cs delta) (mkNegEq x l)
= [ValVec (PmNLit x [l] : vva) (delta { delta_tm_cs = tm_state })]
| otherwise = []
- non_matched = (False, us, False)
+ non_matched = usimple us
-- LitNLit
pmcheckHd (p@(PmLit l)) ps guards
@@ -1044,7 +1129,7 @@ pmcheckHd (p@(PmLit l)) ps guards
= [ValVec (PmNLit x (l:lits) : vva) (delta { delta_tm_cs = tm_state })]
| otherwise = []
- non_matched = (False, us, False)
+ non_matched = usimple us
-- ----------------------------------------------------------------------------
-- The following three can happen only in cases like #322 where constructors
@@ -1055,14 +1140,14 @@ pmcheckHd (p@(PmLit l)) ps guards
-- LitCon
pmcheckHd (PmLit l) ps guards (va@(PmCon {})) (ValVec vva delta)
- = do y <- mkPmId (pmPatType va)
+ = do y <- liftD $ mkPmId (pmPatType va)
let tm_state = extendSubst y (PmExprLit l) (delta_tm_cs delta)
delta' = delta { delta_tm_cs = tm_state }
pmcheckHdI (PmVar y) ps guards va (ValVec vva delta')
-- ConLit
pmcheckHd (p@(PmCon {})) ps guards (PmLit l) (ValVec vva delta)
- = do y <- mkPmId (pmPatType p)
+ = do y <- liftD $ mkPmId (pmPatType p)
let tm_state = extendSubst y (PmExprLit l) (delta_tm_cs delta)
delta' = delta { delta_tm_cs = tm_state }
pmcheckHdI p ps guards (PmVar y) (ValVec vva delta')
@@ -1077,54 +1162,66 @@ pmcheckHd (PmGrd {}) _ _ _ _ = panic "pmcheckHd: Guard"
-- ----------------------------------------------------------------------------
-- * Utilities for main checking
+updateVsa :: (ValSetAbs -> ValSetAbs) -> (PartialResult -> PartialResult)
+updateVsa f p@(PartialResult { presultUncovered = old })
+ = p { presultUncovered = f old }
+
+
+-- | Initialise with default values for covering and divergent information.
+usimple :: ValSetAbs -> PartialResult
+usimple vsa = mempty { presultUncovered = vsa }
+
-- | Take the tail of all value vector abstractions in the uncovered set
-utail :: Triple -> Triple
-utail (cs, vsa, ds) = (cs, vsa', ds)
- where vsa' = [ ValVec vva delta | ValVec (_:vva) delta <- vsa ]
+utail :: PartialResult -> PartialResult
+utail = updateVsa upd
+ where upd vsa = [ ValVec vva delta | ValVec (_:vva) delta <- vsa ]
-- | Prepend a value abstraction to all value vector abstractions in the
-- uncovered set
-ucon :: ValAbs -> Triple -> Triple
-ucon va (cs, vsa, ds) = (cs, vsa', ds)
- where vsa' = [ ValVec (va:vva) delta | ValVec vva delta <- vsa ]
+ucon :: ValAbs -> PartialResult -> PartialResult
+ucon va = updateVsa upd
+ where
+ upd vsa = [ ValVec (va:vva) delta | ValVec vva delta <- vsa ]
-- | Given a data constructor of arity `a` and an uncovered set containing
-- value vector abstractions of length `(a+n)`, pass the first `n` value
-- abstractions to the constructor (Hence, the resulting value vector
-- abstractions will have length `n+1`)
-kcon :: DataCon -> [Type] -> [TyVar] -> [EvVar] -> Triple -> Triple
-kcon con arg_tys ex_tvs dicts (cs, vsa, ds)
- = (cs, [ ValVec (va:vva) delta
- | ValVec vva' delta <- vsa
- , let (args, vva) = splitAt n vva'
- , let va = PmCon { pm_con_con = con
- , pm_con_arg_tys = arg_tys
- , pm_con_tvs = ex_tvs
- , pm_con_dicts = dicts
- , pm_con_args = args } ]
- , ds)
- where n = dataConSourceArity con
+kcon :: DataCon -> [Type] -> [TyVar] -> [EvVar]
+ -> PartialResult -> PartialResult
+kcon con arg_tys ex_tvs dicts
+ = let n = dataConSourceArity con
+ upd vsa =
+ [ ValVec (va:vva) delta
+ | ValVec vva' delta <- vsa
+ , let (args, vva) = splitAt n vva'
+ , let va = PmCon { pm_con_con = con
+ , pm_con_arg_tys = arg_tys
+ , pm_con_tvs = ex_tvs
+ , pm_con_dicts = dicts
+ , pm_con_args = args } ]
+ in updateVsa upd
-- | Get the union of two covered, uncovered and divergent value set
-- abstractions. Since the covered and divergent sets are represented by a
-- boolean, union means computing the logical or (at least one of the two is
-- non-empty).
-mkUnion :: Triple -> Triple -> Triple
-mkUnion (cs1, vsa1, ds1) (cs2, vsa2, ds2)
- = (cs1 || cs2, vsa1 ++ vsa2, ds1 || ds2)
+
+mkUnion :: PartialResult -> PartialResult -> PartialResult
+mkUnion = mappend
-- | Add a value vector abstraction to a value set abstraction (uncovered).
-mkCons :: ValVec -> Triple -> Triple
-mkCons vva (cs, vsa, ds) = (cs, vva:vsa, ds)
+mkCons :: ValVec -> PartialResult -> PartialResult
+mkCons vva = updateVsa (vva:)
-- | Set the divergent set to not empty
-forces :: Triple -> Triple
-forces (cs, us, _) = (cs, us, True)
+forces :: PartialResult -> PartialResult
+forces pres = pres { presultDivergent = Diverged }
-- | Set the divergent set to non-empty if the flag is `True`
-force_if :: Bool -> Triple -> Triple
-force_if True (cs,us,_) = (cs,us,True)
-force_if False triple = triple
+force_if :: Bool -> PartialResult -> PartialResult
+force_if True pres = forces pres
+force_if False pres = pres
-- ----------------------------------------------------------------------------
-- * Propagation of term constraints inwards when checking nested matches
@@ -1133,7 +1230,7 @@ force_if False triple = triple
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
When checking a match it would be great to have all type and term information
available so we can get more precise results. For this reason we have functions
-`addDictsDs' and `addTmCsDs' in DsMonad that store in the environment type and
+`addDictsDs' and `addTmCsDs' in PmMonad that store in the environment type and
term constraints (respectively) as we go deeper.
The type constraints we propagate inwards are collected by `collectEvVarsPats'
@@ -1275,7 +1372,10 @@ dsPmWarn dflags ctx@(DsMatchContext kind loc) pm_result
when exists_u $
putSrcSpanDs loc (warnDs flag_u_reason (pprEqns uncovered))
where
- (redundant, uncovered, inaccessible) = pm_result
+ PmResult
+ { pmresultRedundant = redundant
+ , pmresultUncovered = uncovered
+ , pmresultInaccessible = inaccessible } = pm_result
flag_i = wopt Opt_WarnOverlappingPatterns dflags
flag_u = exhaustive dflags kind
@@ -1298,7 +1398,7 @@ dsPmWarn dflags ctx@(DsMatchContext kind loc) pm_result
-- | Issue a warning when the predefined number of iterations is exceeded
-- for the pattern match checker
-warnPmIters :: DynFlags -> DsMatchContext -> PmM ()
+warnPmIters :: DynFlags -> DsMatchContext -> DsM ()
warnPmIters dflags (DsMatchContext kind loc)
= when (flag_i || flag_u) $ do
iters <- maxPmCheckIterations <$> getDynFlags
@@ -1441,7 +1541,11 @@ involved.
-- Debugging Infrastructre
tracePm :: String -> SDoc -> PmM ()
-tracePm herald doc = do
+tracePm herald doc = liftD $ tracePmD herald doc
+
+
+tracePmD :: String -> SDoc -> DsM ()
+tracePmD herald doc = do
dflags <- getDynFlags
printer <- mkPrintUnqualifiedDs
liftIO $ dumpIfSet_dyn_printer printer dflags
diff --git a/compiler/ghc.cabal.in b/compiler/ghc.cabal.in
index 9538e2c..c5ca313 100644
--- a/compiler/ghc.cabal.in
+++ b/compiler/ghc.cabal.in
@@ -490,6 +490,7 @@ Library
GraphPpr
IOEnv
ListSetOps
+ ListT
Maybes
MonadUtils
OrdList
diff --git a/compiler/utils/ListT.hs b/compiler/utils/ListT.hs
new file mode 100644
index 0000000..2b81db1
--- /dev/null
+++ b/compiler/utils/ListT.hs
@@ -0,0 +1,71 @@
+{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE Rank2Types #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+
+-------------------------------------------------------------------------
+-- |
+-- Module : Control.Monad.Logic
+-- Copyright : (c) Dan Doel
+-- License : BSD3
+--
+-- Maintainer : dan.doel@gmail.com
+-- Stability : experimental
+-- Portability : non-portable (multi-parameter type classes)
+--
+-- A backtracking, logic programming monad.
+--
+-- Adapted from the paper
+-- /Backtracking, Interleaving, and Terminating
+-- Monad Transformers/, by
+-- Oleg Kiselyov, Chung-chieh Shan, Daniel P. Friedman, Amr Sabry
+-- (<http://www.cs.rutgers.edu/~ccshan/logicprog/ListT-icfp2005.pdf>).
+-------------------------------------------------------------------------
+
+module ListT (
+ ListT(..),
+ runListT,
+ select,
+ fold
+ ) where
+
+import Control.Applicative
+
+import Control.Monad
+
+-------------------------------------------------------------------------
+-- | A monad transformer for performing backtracking computations
+-- layered over another monad 'm'
+newtype ListT m a =
+ ListT { unListT :: forall r. (a -> m r -> m r) -> m r -> m r }
+
+select :: Monad m => [a] -> ListT m a
+select xs = foldr (<|>) mzero (map pure xs)
+
+fold :: ListT m a -> (a -> m r -> m r) -> m r -> m r
+fold = runListT
+
+-------------------------------------------------------------------------
+-- | Runs a ListT computation with the specified initial success and
+-- failure continuations.
+runListT :: ListT m a -> (a -> m r -> m r) -> m r -> m r
+runListT = unListT
+
+instance Functor (ListT f) where
+ fmap f lt = ListT $ \sk fk -> unListT lt (sk . f) fk
+
+instance Applicative (ListT f) where
+ pure a = ListT $ \sk fk -> sk a fk
+ f <*> a = ListT $ \sk fk -> unListT f (\g fk' -> unListT a (sk . g) fk') fk
+
+instance Alternative (ListT f) where
+ empty = ListT $ \_ fk -> fk
+ f1 <|> f2 = ListT $ \sk fk -> unListT f1 sk (unListT f2 sk fk)
+
+instance Monad (ListT m) where
+ m >>= f = ListT $ \sk fk -> unListT m (\a fk' -> unListT (f a) sk fk') fk
+ fail _ = ListT $ \_ fk -> fk
+
+instance MonadPlus (ListT m) where
+ mzero = ListT $ \_ fk -> fk
+ m1 `mplus` m2 = ListT $ \sk fk -> unListT m1 sk (unListT m2 sk fk)