Utility functions for vectorisation
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
index 27dd330..2757cbc 100644 (file)
@@ -1,12 +1,15 @@
 module VectUtils (
   collectAnnTypeBinders, collectAnnTypeArgs, isAnnTypeArg,
   collectAnnValBinders,
+  mkDataConTag,
   splitClosureTy,
+  mkPlusType, mkPlusTypes, mkCrossType, mkCrossTypes, mkEmbedType,
+  mkPlusAlts, mkCrosses, mkEmbed,
   mkPADictType, mkPArrayType,
+  parrayReprTyCon, parrayReprDataCon, mkVScrut,
   paDictArgType, paDictOfType, paDFunType,
   paMethod, lengthPA, replicatePA, emptyPA, liftPA,
   polyAbstract, polyApply, polyVApply,
-  lookupPArrayFamInst,
   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
   buildClosure, buildClosures,
   mkClosureApp
@@ -23,7 +26,7 @@ import CoreUtils
 import Type
 import TypeRep
 import TyCon
-import DataCon            ( dataConWrapId )
+import DataCon            ( DataCon, dataConWrapId, dataConTag )
 import Var
 import Id                 ( mkWildId )
 import MkId               ( unwrapFamInstScrut )
@@ -58,6 +61,9 @@ isAnnTypeArg :: AnnExpr b ann -> Bool
 isAnnTypeArg (_, AnnType t) = True
 isAnnTypeArg _              = False
 
+mkDataConTag :: DataCon -> CoreExpr
+mkDataConTag dc = mkConApp intDataCon [mkIntLitInt $ dataConTag dc]
+
 isClosureTyCon :: TyCon -> Bool
 isClosureTyCon tc = tyConName tc == closureTyConName
 
@@ -80,31 +86,117 @@ splitPArrayTy ty
 
   | otherwise = pprPanic "splitPArrayTy" (ppr ty)
 
-mkClosureType :: Type -> Type -> VM Type
-mkClosureType arg_ty res_ty
+mkBuiltinTyConApp :: (Builtins -> TyCon) -> [Type] -> VM Type
+mkBuiltinTyConApp get_tc tys
   = do
-      tc <- builtin closureTyCon
-      return $ mkTyConApp tc [arg_ty, res_ty]
+      tc <- builtin get_tc
+      return $ mkTyConApp tc tys
 
-mkClosureTypes :: [Type] -> Type -> VM Type
-mkClosureTypes arg_tys res_ty
+mkBuiltinTyConApps :: (Builtins -> TyCon) -> [Type] -> Type -> VM Type
+mkBuiltinTyConApps get_tc tys ty
   = do
-      tc <- builtin closureTyCon
-      return $ foldr (mk tc) res_ty arg_tys
+      tc <- builtin get_tc
+      return $ foldr (mk tc) ty tys
   where
-    mk tc arg_ty res_ty = mkTyConApp tc [arg_ty, res_ty]
+    mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
 
-mkPADictType :: Type -> VM Type
-mkPADictType ty
+mkBuiltinTyConApps1 :: (Builtins -> TyCon) -> Type -> [Type] -> VM Type
+mkBuiltinTyConApps1 get_tc dft [] = return dft
+mkBuiltinTyConApps1 get_tc dft tys
+  = do
+      tc <- builtin get_tc
+      case tys of
+        [] -> pprPanic "mkBuiltinTyConApps1" (ppr tc)
+        _  -> return $ foldr1 (mk tc) tys
+  where
+    mk tc ty1 ty2 = mkTyConApp tc [ty1,ty2]
+
+mkBuiltinDataConApp :: (Builtins -> DataCon) -> [CoreExpr] -> VM CoreExpr
+mkBuiltinDataConApp get_dc args
   = do
-      tc <- builtin paTyCon
-      return $ TyConApp tc [ty]
+      dc <- builtin get_dc
+      return $ mkConApp dc args
+
+mkPlusType :: Type -> Type -> VM Type
+mkPlusType ty1 ty2 = mkBuiltinTyConApp plusTyCon [ty1, ty2]
+
+mkPlusTypes :: Type -> [Type] -> VM Type
+mkPlusTypes = mkBuiltinTyConApps1 plusTyCon
+
+mkPlusAlts :: [CoreExpr] -> VM [CoreExpr]
+mkPlusAlts [] = return []
+mkPlusAlts exprs
+  = do
+      plus_tc  <- builtin plusTyCon
+      left_dc  <- builtin leftDataCon
+      right_dc <- builtin rightDataCon
+
+      let go [expr] = ([expr], exprType expr)
+          go (expr : exprs)
+            | (alts, right_ty) <- go exprs
+            = (mkConApp left_dc [Type left_ty, Type right_ty, expr]
+               : [mkConApp right_dc [Type left_ty, Type right_ty, alt]
+                    | alt <- alts],
+               mkTyConApp plus_tc [left_ty, right_ty])
+            where
+              left_ty = exprType expr
+
+      return . fst $ go exprs
+
+mkCrossType :: Type -> Type -> VM Type
+mkCrossType ty1 ty2 = mkBuiltinTyConApp crossTyCon [ty1, ty2]
+
+mkCrossTypes :: Type -> [Type] -> VM Type
+mkCrossTypes = mkBuiltinTyConApps1 crossTyCon
+
+mkCrosses :: [CoreExpr] -> VM CoreExpr
+mkCrosses [] = return (Var unitDataConId)
+mkCrosses exprs
+  = do
+      cross_tc <- builtin crossTyCon
+      cross_dc <- builtin crossDataCon
+
+      let mk (left, left_ty) (right, right_ty)
+            = (mkConApp   cross_dc [Type left_ty, Type right_ty, left, right],
+               mkTyConApp cross_tc [left_ty, right_ty])
+
+      return . fst
+             $ foldr1 mk [(expr, exprType expr) | expr <- exprs]
+
+mkEmbedType :: Type -> VM Type
+mkEmbedType ty = mkBuiltinTyConApp embedTyCon [ty]
+
+mkEmbed :: CoreExpr -> VM CoreExpr
+mkEmbed expr = mkBuiltinDataConApp embedDataCon
+                                   [Type $ exprType expr, expr]
+
+mkClosureType :: Type -> Type -> VM Type
+mkClosureType arg_ty res_ty = mkBuiltinTyConApp closureTyCon [arg_ty, res_ty]
+
+mkClosureTypes :: [Type] -> Type -> VM Type
+mkClosureTypes = mkBuiltinTyConApps closureTyCon
+
+mkPADictType :: Type -> VM Type
+mkPADictType ty = mkBuiltinTyConApp paTyCon [ty]
 
 mkPArrayType :: Type -> VM Type
-mkPArrayType ty
+mkPArrayType ty = mkBuiltinTyConApp parrayTyCon [ty]
+
+parrayReprTyCon :: Type -> VM (TyCon, [Type])
+parrayReprTyCon ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
+
+parrayReprDataCon :: Type -> VM (DataCon, [Type])
+parrayReprDataCon ty
   = do
-      tc <- builtin parrayTyCon
-      return $ TyConApp tc [ty]
+      (tc, arg_tys) <- parrayReprTyCon ty
+      let [dc] = tyConDataCons tc
+      return (dc, arg_tys)
+
+mkVScrut :: VExpr -> VM (VExpr, TyCon, [Type])
+mkVScrut (ve, le)
+  = do
+      (tc, arg_tys) <- parrayReprTyCon (exprType ve)
+      return ((ve, unwrapFamInstScrut tc arg_tys le), tc, arg_tys)
 
 paDictArgType :: TyVar -> VM (Maybe Type)
 paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
@@ -140,7 +232,7 @@ paDictOfTyApp (TyVarTy tv) ty_args
       paDFunApply dfun ty_args
 paDictOfTyApp (TyConApp tc _) ty_args
   = do
-      dfun <- maybeV (lookupTyConPA tc)
+      dfun <- traceMaybeV "paDictOfTyApp" (ppr tc) (lookupTyConPA tc)
       paDFunApply (Var dfun) ty_args
 paDictOfTyApp ty ty_args = pprPanic "paDictOfTyApp" (ppr ty)
 
@@ -222,9 +314,6 @@ polyVApply expr tys
       dicts <- mapM paDictOfType tys
       return $ mapVect (\e -> e `mkTyApps` tys `mkApps` dicts) expr
 
-lookupPArrayFamInst :: Type -> VM (TyCon, [Type])
-lookupPArrayFamInst ty = builtin parrayTyCon >>= (`lookupFamInst` [ty])
-
 hoistBinding :: Var -> CoreExpr -> VM ()
 hoistBinding v e = updGEnv $ \env ->
   env { global_bindings = (v,e) : global_bindings env }
@@ -279,6 +368,8 @@ mkClosureApp (vclo, lclo) (varg, larg)
     (arg_ty, res_ty) = splitClosureTy (exprType vclo)
 
 buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
+buildClosures tvs vars [] res_ty mk_body
+  = mk_body
 buildClosures tvs vars [arg_ty] res_ty mk_body
   = buildClosure tvs vars arg_ty res_ty mk_body
 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
@@ -350,7 +441,7 @@ mkLiftEnv lc [ty] [v]
 -- NOTE: this transparently deals with empty environments
 mkLiftEnv lc tys vs
   = do
-      (env_tc, env_tyargs) <- lookupPArrayFamInst vty
+      (env_tc, env_tyargs) <- parrayReprTyCon vty
       let [env_con] = tyConDataCons env_tc
           
           env = Var (dataConWrapId env_con)