module Curry.ExtendedFlat.TypeInference
( dispType, adjustTypeInfo, labelVarsWithTypes, uniqueTypeIndices
, genEquations
) where
import Control.Monad.Reader
import Control.Monad.State
import qualified Data.IntMap as IntMap
import Text.PrettyPrint.HughesPJ
import Curry.ExtendedFlat.Type
import Curry.ExtendedFlat.Goodies
trace' :: String -> b -> b
trace' _ x = x
adjustTypeInfo :: Prog -> Prog
adjustTypeInfo = genEquations . uniqueTypeIndices . labelVarsWithTypes
dispType :: TypeExpr -> String
dispType = render . prettyType
prettyType :: TypeExpr -> Doc
prettyType (TVar i) = text ('t':show i)
prettyType (FuncType f x) = parens (prettyType f) <+> text "->" <+> prettyType x
prettyType (TCons qn ts) = let n = let (m,l) = qnOf qn in m ++ '.' : l
in text n <+> hsep (map (parens . prettyType) ts)
prettyAllEqns :: ((String, String), TypeExpr, [(TVarIndex, TypeExpr)]) -> String
prettyAllEqns = render . prettyEqns
where
prettyEqn ::(TVarIndex, TypeExpr) -> Doc
prettyEqn (l, r) = char 't' <> int l <+> text "->" <+> prettyType r
prettyEqns ((m,l), t, eqns)
= text m <> char '.' <> text l <+> text "::" <+> prettyType t <> char ':'
$$ nest 5 (vcat (map prettyEqn eqns))
postOrderExpr :: Monad m => (Expr -> m Expr) -> Expr -> m Expr
postOrderExpr f = po
where
po e@(Var _) = f e
po e@(Lit _) = f e
po (Comb t n es) = do es' <- mapM po es
f (Comb t n es')
po (Free vs e) = do e' <- po e
f (Free vs e')
po (Let bs e) = do bs' <- mapM poBind bs
e' <- po e
f (Let bs' e')
po (Or l r) = liftM2 Or (po l) (po r) >>= f
po (Case r t e bs) = do e' <- po e
bs' <- mapM poBranch bs
f (Case r t e' bs')
po (Typed e ty) = do e' <- po e
f (Typed e' ty)
poBind (v, rhs) = do rhs' <- po rhs
return (v, rhs')
poBranch (Branch p rhs) = do rhs' <- po rhs
return (Branch p rhs')
postOrderType :: Monad m => (TypeExpr -> m TypeExpr) -> TypeExpr -> m TypeExpr
postOrderType f = po
where po e@(TVar _) = f e
po (FuncType t1 t2) = do t1' <- po t1
t2' <- po t2
f (FuncType t1' t2')
po (TCons qn ts) = do ts' <- mapM po ts
f (TCons qn ts')
visitTVars :: Monad m => (TVarIndex -> m TypeExpr) -> TypeExpr -> m TypeExpr
visitTVars f = postOrderType f'
where f' (TVar i) = f i
f' t = return t
type TDictM = ReaderT TypeMap (State Int)
labelVarsWithTypes :: Prog -> Prog
labelVarsWithTypes = updProgFuncs updateFunc
where
updateFunc = map (\func -> let maxtvi = maxFuncTV func + 1
in trFunc (foo maxtvi) func)
foo _ qn arity visty te r@(External _) = Func qn arity visty te r
foo maxtv qn arity visty te (Rule vs expr)
= let expr' = evalState (runReaderT (withVS vs (po expr)) typeMap) maxtv
typeMap = trace' (show argTypes') $ IntMap.fromList argTypes'
argTypes' = [ (vi, t) | VarIndex (Just t) vi <- vs ]
in Func qn arity visty te (Rule vs expr')
po :: Expr -> TDictM Expr
po e@(Var vi)
= do vt <- asks (IntMap.lookup $ idxOf vi)
trace' ("labelVarsWithTypes " ++ show e ++" :: "++ show vt)(
case vt of
Just t -> return (Var vi { typeofVar = Just t })
Nothing -> case typeofVar vi of
Nothing -> error $ "no type for var " ++ show e
_ -> liftM Var (poVarIndex vi))
po e@(Lit _)
= return e
po (Comb t n es)
= do es' <- mapM po es
n' <- poQName n
return (Comb t n' es')
po (Free vs e)
= do vs' <- mapM poVarIndex vs
e' <- po e
return (Free vs' e')
po (Let bs e)
= do let (vs, es) = unzip bs
vs' <- mapM poVarIndex vs
withVS vs' (do es' <- mapM po es
e' <- po e
return (Let (zip vs' es') e'))
po (Or l r)
= liftM2 Or (po l) (po r)
po (Case r t e bs)
= do e' <- po e
bs' <- mapM poBranch bs
return (Case r t e' bs')
po (Typed e ty) = do e' <- po e
return (Typed e' ty)
poBranch :: BranchExpr -> TDictM BranchExpr
poBranch (Branch (Pattern qn vs) rhs)
= do qn' <- poQName qn
vs' <- mapM poVarIndex vs
withVS vs' (do rhs' <- po rhs
return (Branch (Pattern qn' vs') rhs'))
poBranch (Branch (LPattern l) e)
= do rhs' <- po e
return (Branch (LPattern l) rhs')
poVarIndex :: VarIndex -> TDictM VarIndex
poVarIndex vi
= do t <- maybe (lift$freshTVar) return . typeofVar $ vi
return vi{typeofVar = Just t }
poQName :: QName -> TDictM QName
poQName qn
= do t <- maybe (lift$freshTVar)
return . typeofQName $ qn
return qn{typeofQName = Just t }
withVS :: MonadReader TypeMap m => [VarIndex] -> m a -> m a
withVS vs = local (\ m -> foldr add m vs)
where
add v m = case typeofVar v of
Nothing -> error "Curry.ExtendedFlat.TypeInference.withVS"
Just ty -> IntMap.insert (idxOf v) ty m
uniqueTypeIndices :: Prog -> Prog
uniqueTypeIndices = updProgFuncs (map updateFunc)
where
updateFunc func = let firstfree = maxFuncTV func + 1
in updFuncRule (trRule (ruleFoo firstfree) External) func
ruleFoo firstfree args expr
= let expr' = evalState (postOrderExpr relabelTypes expr) firstfree
in Rule args expr'
relabelTypes :: Expr -> State TVarIndex Expr
relabelTypes (Comb ct qname args)
= do t' <- case typeofQName qname of
Just lt -> relabelType lt
Nothing -> freshTVar
return (Comb ct qname {typeofQName = Just t'} args)
relabelTypes (Var v)
| typeofVar v == Nothing
= do t <- freshTVar
return (Var v{typeofVar = Just t})
relabelTypes (Case r t e bs)
= do bs' <- mapM relabelPatType bs
return (Case r t e bs')
where relabelPatType (Branch (Pattern qn vis) e')
= do t' <- case typeofQName qn of
Just lt -> relabelType lt
Nothing -> freshTVar
return (Branch (Pattern qn {typeofQName = Just t'} vis) e')
relabelPatType be = return be
relabelTypes t = return t
relabelType :: TypeExpr -> State TVarIndex TypeExpr
relabelType t = evalStateT (visitTVars typeFoo t) IntMap.empty
where typeFoo i = do m <- get
case IntMap.lookup i m of
Just v -> return v
Nothing -> do v <- lift freshTVar
modify (IntMap.insert i v)
return v
type TypeMap = IntMap.IntMap TypeExpr
type EqnMonad = StateT TypeMap (State TVarIndex)
genEquations :: Prog -> Prog
genEquations = updProgFuncs updateFunc
where
updateFunc = map (\func -> let maxtvi = maxFuncTV func + 1
in trFunc (foo maxtvi) func)
foo _ qn arity visty te r@(External _) = Func qn arity visty te r
foo maxtv qn arity visty te (Rule vs expr)
= let h = evalState (execStateT (do argTypes' <- mapM varIndexType vs
etype <- equations expr
qnt <- qnType qn
_ <- qnt =:= foldr FuncType etype argTypes'
return()
) IntMap.empty) maxtv
in trace' (prettyAllEqns (qnOf qn,te,IntMap.toList h)) Func qn arity visty (specialiseType h te) (specInRule h (Rule vs expr))
equations :: Expr -> EqnMonad TypeExpr
equations = trExpr varIndexType (return . typeofLiteral) combEqn letEqn frEqn orEqn casEqn branchEqn typedEqn
where
combEqn :: (CombType -> QName -> [EqnMonad TypeExpr] -> EqnMonad TypeExpr)
combEqn _ qn args = do
resultType' <- lift$freshTVar
argTypes' <- sequence args
tqn <- qnType qn
_ <- tqn =:= foldr FuncType resultType' argTypes'
return resultType'
letEqn :: ([(VarIndex, EqnMonad TypeExpr)] -> EqnMonad TypeExpr -> EqnMonad TypeExpr)
letEqn bs = (mapM_ bindEqn bs >>)
frEqn _ e = e
orEqn l r = do
l' <- l
r' <- r
l' =:= r'
casEqn :: SrcRef -> CaseType -> EqnMonad TypeExpr -> [EqnMonad (Pattern, TypeExpr)] -> EqnMonad TypeExpr
casEqn _ _ scr [] = scr >> (lift$freshTVar)
casEqn _ _ scr ps = do
scrt <- scr
branches <- sequence ps
let pats = map fst branches
let (p:ps') = map snd branches
mapM_ (unifLhs scrt) pats
foldM (=:=) p ps'
unifLhs scrt (LPattern l) = typeofLiteral l =:= scrt
unifLhs scrt (Pattern qn vs) = do
qnt <- qnType qn
argTypes' <- mapM varIndexType vs
qnt =:= foldr FuncType scrt argTypes'
branchEqn :: Pattern -> EqnMonad TypeExpr -> EqnMonad (Pattern, TypeExpr)
branchEqn p e = do trhs <- e
return (p, trhs)
typedEqn :: EqnMonad TypeExpr -> TypeExpr -> EqnMonad TypeExpr
typedEqn e ty = do
ety <- e
ety =:= ty
bindEqn :: (VarIndex, EqnMonad TypeExpr) -> EqnMonad TypeExpr
bindEqn (vi, rhs) = do vit <- varIndexType vi
rvi <- rhs
vit =:= rvi
unify :: TypeExpr -> TypeExpr -> TypeMap -> TypeMap
unify (TVar i) t tm
| Just s <- IntMap.lookup i tm
= unify s t tm
unify s (TVar j) tm
| Just t <- IntMap.lookup j tm
= unify s t tm
unify s@(TVar i) t@(TVar j) tm
| i == j = tm
| i < j = IntMap.insert j s tm
| i > j = IntMap.insert i t tm
unify (TVar i) t tm
= IntMap.insert i t tm
unify s (TVar j) tm
= IntMap.insert j s tm
unify (FuncType f x) (FuncType g y) tm
= unify x y (unify f g tm)
unify (TCons m as) (TCons n bs) tm
| m == n = foldr ($) tm (zipWith unify as bs)
unify s t _
= error . render $
text "Types differ: " <+> prettyType s <+> text "/=" <+> prettyType t
(=:=) :: TypeExpr -> TypeExpr -> EqnMonad TypeExpr
a =:= b = modify (unify a b) >> return a
varIndexType :: VarIndex -> EqnMonad TypeExpr
varIndexType = maybe (lift$freshTVar) return . typeofVar
qnType :: QName -> EqnMonad TypeExpr
qnType = maybe (lift$freshTVar) return . typeofQName
freshTVar :: MonadState Int m => m TypeExpr
freshTVar = do nextIdx <- get
modify succ
return (TVar nextIdx)
maxFuncTV :: FuncDecl -> TVarIndex
maxFuncTV = trFunc (\qn _ _ te r -> max (maxQNameTV qn) (max (maxTypeTV te) (maxRuleTV r)))
where
maxRuleTV = trRule (\vis e -> maximum (maxExprTV e : map maxVarIndexTV vis)) (const (1))
maxExprTV :: Expr -> Int
maxExprTV = trExpr var lit comb lt fr max cas branch typed
where var = maxVarIndexTV
lit = const (1)
comb _ qn ms = maximum (maxQNameTV qn : ms)
lt bs e = maximum (e : map maxBindTV bs)
fr vs e = maximum (e : map maxVarIndexTV vs)
cas _ _ e ps = maximum (e : ps)
branch p e = max e (maxPatternTV p)
typed e _ = e
maxQNameTV = maybe (1) maxTypeTV . typeofQName
maxVarIndexTV = maybe (1) maxTypeTV . typeofVar
maxBindTV (vi, e) = max e (maxVarIndexTV vi)
maxPatternTV (Pattern qn vis) = maximum (maxQNameTV qn : map maxVarIndexTV vis)
maxPatternTV (LPattern _) = 1
maxTypeTV = trTypeExpr id tapp max
where tapp _ args = maximum (1:args)
specialiseType :: TypeMap -> TypeExpr -> TypeExpr
specialiseType m t = trTypeExpr (foo m) TCons FuncType t
where foo m' i = maybe (TVar i) (specialiseType m') (IntMap.lookup i m')
specInRule :: TypeMap -> Rule -> Rule
specInRule = modifyType . specialiseType
modifyType :: (TypeExpr -> TypeExpr) -> Rule -> Rule
modifyType f = updRule (map specInVarIndex) specInExpr id
where
specInExpr = trExpr var Lit comb letexp free Or Case alt typed
var = Var . specInVarIndex
comb ct = Comb ct . specInQName
letexp = Let . map specInBind
free = Free . map specInVarIndex
alt = Branch . specInPattern
typed e = Typed e . f
specInBind (vi, e) = (specInVarIndex vi, e)
specInPattern (Pattern qn vis) = Pattern (specInQName qn) (map specInVarIndex vis)
specInPattern p = p
specInVarIndex vi = vi { typeofVar = fmap f (typeofVar vi)}
specInQName qn = qn { typeofQName = fmap f (typeofQName qn)}