module Transformations.CaseCompletion (completeCase) where
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative ((<$>), (<*>))
#endif
import Control.Monad (replicateM)
import qualified Control.Monad.State as S (State, evalState, gets, modify)
import Data.List (find)
import Data.Maybe (fromMaybe, listToMaybe)
import Curry.Base.Ident
import Curry.Base.Position (SrcRef)
import qualified Curry.Syntax as CS
import Base.Expr
import Base.Messages (internalError)
import Env.Interface (InterfaceEnv, lookupInterface)
import IL
-- | Complete the case expressions in all declarations of an IL module.
--
-- The module itself and the interface environment are kept in the state
-- (for constructor lookups); the fresh-identifier counter starts at 0.
completeCase :: InterfaceEnv -> Module -> Module
completeCase iEnv mdl@(Module mid is ds) = Module mid is completedDecls
  where
    completedDecls = S.evalState (mapM ccDecl ds) initialState
    initialState   = CCState { modul = mdl, interfaceEnv = iEnv, nextId = 0 }
-- | State threaded through the case completion.
data CCState = CCState
  { modul        :: Module       -- ^ the module being transformed
  , interfaceEnv :: InterfaceEnv -- ^ interfaces consulted for constructors
                                 --   of types from other modules
  , nextId       :: Int          -- ^ counter used by 'freshIdent'
  }
-- | The case completion monad: a state monad over 'CCState'.
type CCM = S.State CCState

-- | The module currently being transformed.
getModule :: CCM Module
getModule = S.gets modul

-- | The interface environment used to look up imported constructors.
getInterfaceEnv :: CCM InterfaceEnv
getInterfaceEnv = S.gets interfaceEnv
-- | Produce a fresh identifier named @_#comp@ followed by the current
-- counter value, then advance the counter.
freshIdent :: CCM Ident
freshIdent = do
  n <- S.gets nextId
  S.modify (bump n)
  return (mkIdent ("_#comp" ++ show n))
  where
    bump n st = st { nextId = succ n }
-- | Complete the case expressions of a declaration. Only function bodies
-- contain expressions; every other declaration is returned unchanged.
ccDecl :: Decl -> CCM Decl
ccDecl decl = case decl of
  FunctionDecl f params ty body -> FunctionDecl f params ty <$> ccExpr body
  DataDecl     {}               -> return decl
  NewtypeDecl  {}               -> return decl
  ExternalDecl {}               -> return decl
-- | Complete all case expressions inside an expression, bottom-up: the
-- scrutinee and the alternatives of a case are processed before 'ccCase'
-- completes the case itself. Leaves are returned unchanged.
ccExpr :: Expression -> CCM Expression
ccExpr expr = case expr of
  Literal     _          -> return expr
  Variable    _          -> return expr
  Function    _ _        -> return expr
  Constructor _ _        -> return expr
  Apply fun arg          -> Apply <$> ccExpr fun <*> ccExpr arg
  Case ref ev scrut alts -> do
    scrut' <- ccExpr scrut
    alts'  <- mapM ccAlt alts
    ccCase ref ev scrut' alts'
  Or lhs rhs             -> Or <$> ccExpr lhs <*> ccExpr rhs
  Exist x body           -> Exist x <$> ccExpr body
  Let bnd body           -> Let <$> ccBinding bnd <*> ccExpr body
  Letrec bnds body       -> Letrec <$> mapM ccBinding bnds <*> ccExpr body
  Typed body ty          -> (\body' -> Typed body' ty) <$> ccExpr body
-- | Complete the case expressions in the right-hand side of an alternative.
ccAlt :: Alt -> CCM Alt
ccAlt (Alt pat rhs) = fmap (Alt pat) (ccExpr rhs)
-- | Complete the case expressions in the right-hand side of a binding.
ccBinding :: Binding -> CCM Binding
ccBinding (Binding var rhs) = fmap (Binding var) (ccExpr rhs)
-- | Complete a single case expression whose parts are already completed.
--
-- Flexible cases are left as they are. A rigid case is completed according
-- to the pattern of its first alternative: constructor patterns by
-- 'completeConsAlts', literal patterns by 'completeLitAlts' and variable
-- patterns by 'completeVarAlts'. A rigid case without alternatives is an
-- internal error.
ccCase :: SrcRef -> Eval -> Expression -> [Alt] -> CCM Expression
ccCase ref ev scrut alts = case ev of
  Flex  -> return (Case ref Flex scrut alts)
  Rigid -> case alts of
    []            -> internalError $ "CaseCompletion.ccCase: "
                                  ++ "empty alternative list"
    Alt pat _ : _ -> case pat of
      ConstructorPattern _ _ -> completeConsAlts ref Rigid scrut alts
      LiteralPattern _       -> completeLitAlts  ref Rigid scrut alts
      VariablePattern _      -> completeVarAlts  scrut alts
-- | Complete a rigid case whose alternatives match on constructors.
--
-- The constructors not covered by any constructor alternative are computed
-- with 'getComplConstrs', and a pattern with fresh argument variables is
-- generated for each of them. If at least one constructor is missing and
-- the case has a variable (default) alternative, the missing constructors
-- are added as alternatives whose right-hand side is the default's:
--
-- * the default's pattern variable is renamed to a fresh variable which,
--   only if it is actually used, is bound to the scrutinee via 'mkBinding'
--   and becomes the new scrutinee;
--
-- * with more than one missing constructor, the default right-hand side is
--   bound to a second fresh variable via 'mkBinding', and every added
--   alternative refers to that variable instead of a copy of the expression.
--
-- Otherwise the case keeps only its constructor alternatives.
completeConsAlts :: SrcRef -> Eval -> Expression -> [Alt] -> CCM Expression
completeConsAlts r ea ce alts = do
  mdl  <- getModule
  menv <- getInterfaceEnv
  missingPats <- mapM freshPattern (getComplConstrs mdl menv coveredConstrs)
  -- both identifiers are drawn unconditionally, after the pattern variables
  dfltVar  <- freshIdent
  sharedId <- freshIdent
  return $ case missingPats of
    [] -> consOnlyCase
    _  -> maybe consOnlyCase (withDefault dfltVar sharedId missingPats)
                (defaultRhs dfltVar)
  where
    -- the alternatives matching on a constructor, in their original order
    consAlts       = [ a | a@(Alt (ConstructorPattern _ _) _) <- alts ]
    coveredConstrs = [ c | Alt (ConstructorPattern c _) _ <- consAlts ]
    consOnlyCase   = Case r ea ce consAlts

    -- constructor pattern with a fresh variable for every argument
    freshPattern (qid, arity) =
      ConstructorPattern qid <$> replicateM arity freshIdent

    -- right-hand side of the first variable alternative, with its pattern
    -- variable renamed to dv
    defaultRhs dv = listToMaybe
      [ replaceVar x (Variable dv) rhs | Alt (VariablePattern x) rhs <- alts ]

    withDefault dv sv pats rhs
      | dv `elem` fv rhs = mkBinding dv ce (caseOn (Variable dv))
      | otherwise        = caseOn ce
      where
        caseOn scrut = case pats of
          [p] -> Case r ea scrut (consAlts ++ [Alt p rhs])
          _   -> mkBinding sv rhs $
                   Case r ea scrut (consAlts ++ [ Alt q (Variable sv) | q <- pats ])
-- | Complete a rigid case over literal patterns by turning it into a chain
-- of equality tests.
--
-- The scrutinee is bound to a fresh variable @x@ (via 'mkBinding'). Each
-- literal alternative becomes a case on @x == l@ whose 'truePatt' branch is
-- the alternative's right-hand side and whose 'falsePatt' branch continues
-- with the remaining alternatives. A variable alternative ends the chain
-- (its variable is replaced by @x@; later alternatives are ignored), and
-- running out of alternatives ends it with 'failedExpr'.
completeLitAlts :: SrcRef -> Eval -> Expression -> [Alt] -> CCM Expression
completeLitAlts r ea ce alts = do
  x <- freshIdent
  return $ mkBinding x ce (foldr (test x) failedExpr alts)
  where
    -- 'rest' is only demanded for literal alternatives, so the lazy fold
    -- stops at the first variable alternative
    test x (Alt p rhs) rest = case p of
      LiteralPattern l  -> Case r ea (Variable x `eqExpr` Literal l)
                             [ Alt truePatt rhs
                             , Alt falsePatt rest
                             ]
      VariablePattern v -> replaceVar v (Variable x) rhs
      _ -> internalError "CaseCompletion.completeLitAlts: illegal alternative"
-- | Complete a rigid case whose first alternative is a variable pattern:
-- the case becomes that alternative's right-hand side, with the pattern
-- variable bound to the scrutinee via 'mkBinding'. An empty alternative
-- list yields 'failedExpr'.
completeVarAlts :: Expression -> [Alt] -> CCM Expression
completeVarAlts _     []                = return failedExpr
completeVarAlts scrut (Alt pat rhs : _)
  | VariablePattern x <- pat = return (mkBinding x scrut rhs)
  | otherwise                = internalError $
      "CaseCompletion.completeVarAlts: variable pattern expected"
-- | @mkBinding v e body@ makes @v@ stand for @e@ in @body@. If @e@ is itself
-- a variable it is substituted for @v@ directly ('replaceVar'); otherwise a
-- 'Let' binding is introduced.
mkBinding :: Ident -> Expression -> Expression -> Expression
mkBinding v e body
  | isVar e   = replaceVar v e body
  | otherwise = Let (Binding v e) body
  where
    isVar (Variable _) = True
    isVar _            = False
-- | @replaceVar v e body@ replaces the free occurrences of variable @v@ in
-- @body@ by @e@.
--
-- Substitution stops below any construct that binds @v@ again: 'Exist',
-- 'Let', 'Letrec', and case alternatives whose pattern binds @v@. No
-- renaming is done, so @e@ must not be captured by binders in @body@. All
-- call sites visible in this module pass a 'Variable'.
--
-- NOTE(review): when a 'Let' binds @v@, its right-hand side is left
-- untouched as well, i.e. the binding is treated like a recursive one.
-- Confirm this matches the scoping of IL's 'Let'.
replaceVar :: Ident -> Expression -> Expression -> Expression
replaceVar v e x@(Variable w)
  | v == w    = e
  | otherwise = x
replaceVar v e (Apply e1 e2)
  = Apply (replaceVar v e e1) (replaceVar v e e2)
replaceVar v e (Case r ev e' bs)
  = Case r ev (replaceVar v e e') (map (replaceVarInAlt v e) bs)
replaceVar v e (Or e1 e2)
  = Or (replaceVar v e e1) (replaceVar v e e2)
replaceVar v e (Exist w e')
  | v == w    = Exist w e'
  | otherwise = Exist w (replaceVar v e e')
replaceVar v e (Let b e')
  | v `occursInBinding` b = Let b e'
  | otherwise             = Let (replaceVarInBinding v e b) (replaceVar v e e')
replaceVar v e (Letrec bs e')
  | any (occursInBinding v) bs = Letrec bs e'
  | otherwise = Letrec (map (replaceVarInBinding v e) bs) (replaceVar v e e')
-- The annotated expression may mention v ('ccExpr' also descends into
-- 'Typed'), so substitution must continue below the annotation.
replaceVar v e (Typed e' ty) = Typed (replaceVar v e e') ty
-- literals, functions and constructors contain no variables
replaceVar _ _ e' = e'
-- | Substitute in the right-hand side of an alternative, unless its pattern
-- binds the variable.
replaceVarInAlt :: Ident -> Expression -> Alt -> Alt
replaceVarInAlt v e (Alt p rhs) = Alt p rhs'
  where
    rhs' | v `occursInPattern` p = rhs
         | otherwise             = replaceVar v e rhs
-- | Substitute in the right-hand side of a binding, unless the binding
-- binds the variable itself.
replaceVarInBinding :: Ident -> Expression -> Binding -> Binding
replaceVarInBinding v e (Binding w rhs) = Binding w rhs'
  where rhs' = if v == w then rhs else replaceVar v e rhs
-- | Does the pattern bind the given variable?
occursInPattern :: Ident -> ConstrTerm -> Bool
occursInPattern v pat = case pat of
  VariablePattern w       -> v == w
  ConstructorPattern _ ws -> v `elem` ws
  _                       -> False
-- | Is the given variable the one bound by the binding?
occursInBinding :: Ident -> Binding -> Bool
occursInBinding v bnd = case bnd of
  Binding w _ -> v == w
-- | Reference to the nullary function @failed@ qualified with
-- 'preludeMIdent'; used where no alternative applies.
failedExpr :: Expression
failedExpr = Function failedId 0
  where failedId = qualifyWith preludeMIdent (mkIdent "failed")
-- | Apply the binary function @==@ qualified with 'preludeMIdent' to two
-- expressions.
eqExpr :: Expression -> Expression -> Expression
eqExpr lhs rhs = (eqFun `Apply` lhs) `Apply` rhs
  where eqFun = Function (qualifyWith preludeMIdent (mkIdent "==")) 2
-- | Argument-less constructor pattern for 'qTrueId'.
truePatt :: ConstrTerm
truePatt = nullaryPatt qTrueId

-- | Argument-less constructor pattern for 'qFalseId'.
falsePatt :: ConstrTerm
falsePatt = nullaryPatt qFalseId

-- | Constructor pattern without argument variables.
nullaryPatt :: QualIdent -> ConstrTerm
nullaryPatt c = ConstructorPattern c []
-- | Given a non-empty list of constructors, return the remaining
-- constructors of their data type together with their arities. Only the
-- first constructor is used to identify the type.
--
-- List constructors are handled directly. Constructors of the current
-- module are looked up in its declarations, and others in the interface of
-- their module. If that interface is not available, no constructors are
-- returned.
getComplConstrs :: Module -> InterfaceEnv -> [QualIdent] -> [(QualIdent, Int)]
getComplConstrs mdl menv cs = case cs of
  [] -> internalError "CaseCompletion.getComplConstrs: empty constructor list"
  c : _
    | c `elem` [qNilId, qConsId] -> complementary cs [(qNilId, 0), (qConsId, 2)]
    | otherwise                  -> fromModule (fromMaybe mid (qidModule c))
  where
    Module mid _ ds = mdl
    fromModule m
      | m == mid  = getCCFromDecls cs ds
      | otherwise = case lookupInterface m menv of
          Nothing   -> []
          Just intf -> getCCFromIDecls m cs intf
-- | The constructors of a data type declared in the current module that are
-- missing from the given list, paired with their arities.
--
-- The type is located through the first constructor of the list. Only a
-- 'DataDecl' contributes constructors. If the type comes from a
-- 'NewtypeDecl' (whose single constructor is the one just matched) or no
-- declaration is found, the result is empty. An empty constructor list is
-- an internal error; the type could not be identified otherwise.
getCCFromDecls :: [QualIdent] -> [Decl] -> [(QualIdent, Int)]
getCCFromDecls [] _
  = internalError "CaseCompletion.getCCFromDecls: empty constructor list"
getCCFromDecls cs@(c:_) ds = complementary cs cinfos
  where
    cinfos = map constrInfo
           $ maybe [] extractConstrDecls (find (`declares` c) ds)
    decl `declares` qid = case decl of
      DataDecl    _ _ cs' -> any (`declaresConstr` qid) cs'
      NewtypeDecl _ _ nc  -> nc `declaresConstr` qid
      _                   -> False
    declaresConstr (ConstrDecl cid _) qid = cid == qid
    extractConstrDecls (DataDecl _ _ cs') = cs'
    extractConstrDecls _                  = []
    constrInfo (ConstrDecl cid tys) = (cid, length tys)
-- | The constructors of a data type declared in the interface of module
-- @mid@ that are missing from the given list, paired with their arities.
-- The resulting constructors are qualified with @mid@.
--
-- The type is located through the first constructor of the list, compared
-- by its unqualified name. Only a 'CS.IDataDecl' contributes constructors.
-- A type declared by 'CS.INewtypeDecl', or one that is not found, yields
-- none. An empty constructor list is an internal error.
getCCFromIDecls :: ModuleIdent -> [QualIdent] -> CS.Interface -> [(QualIdent, Int)]
getCCFromIDecls _ [] _
  = internalError "CaseCompletion.getCCFromIDecls: empty constructor list"
getCCFromIDecls mid cs@(c:_) (CS.Interface _ _ ds) = complementary cs cinfos
  where
    cinfos = map constrInfo
           $ maybe [] extractConstrDecls (find (`declares` c) ds)
    decl `declares` qid = case decl of
      CS.IDataDecl    _ _ _ cs' _ -> any (`declaresConstr` qid) cs'
      CS.INewtypeDecl _ _ _ nc  _ -> isNewConstrDecl qid nc
      _                           -> False
    declaresConstr (CS.ConstrDecl _ _ cid _)   qid = unqualify qid == cid
    declaresConstr (CS.ConOpDecl  _ _ _ oid _) qid = unqualify qid == oid
    declaresConstr (CS.RecordDecl _ _ cid _)   qid = unqualify qid == cid
    isNewConstrDecl qid (CS.NewConstrDecl _ _ cid _) = unqualify qid == cid
    isNewConstrDecl qid (CS.NewRecordDecl _ _ cid _) = unqualify qid == cid
    extractConstrDecls (CS.IDataDecl _ _ _ cs' _) = cs'
    extractConstrDecls _                          = []
    constrInfo (CS.ConstrDecl _ _ cid tys) = (qualifyWith mid cid, length tys)
    constrInfo (CS.ConOpDecl _ _ _ oid _)  = (qualifyWith mid oid, 2)
    constrInfo (CS.RecordDecl _ _ cid fs)  = (qualifyWith mid cid, length labels)
      where labels = [ l | CS.FieldDecl _ ls _ <- fs, l <- ls ]
-- | Keep the constructor entries whose name is not among the known ones,
-- preserving their order.
complementary :: [QualIdent] -> [(QualIdent, Int)] -> [(QualIdent, Int)]
complementary known others =
  [ entry | entry@(qid, _) <- others, qid `notElem` known ]