Generate lots of __inline_me during vectorisation
[ghc-hetmet.git] / compiler / vectorise / VectUtils.hs
index 6a8f893..0d5585f 100644 (file)
@@ -12,6 +12,7 @@ module VectUtils (
   prDFunOfTyCon,
   paDictArgType, paDictOfType, paDFunType,
   paMethod, mkPR, lengthPA, replicatePA, emptyPA, packPA, combinePA, liftPA,
+  zipScalars, scalarClosure,
   polyAbstract, polyApply, polyVApply,
   hoistBinding, hoistExpr, hoistPolyVExpr, takeHoisted,
   buildClosure, buildClosures,
@@ -56,8 +57,8 @@ collectAnnTypeBinders expr = go [] expr
 collectAnnValBinders :: AnnExpr Var ann -> ([Var], AnnExpr Var ann)
 collectAnnValBinders expr = go [] expr
   where
-    go bs (_, AnnLam b e) | isIdVar b = go (b:bs) e
-    go bs e                           = (reverse bs, e)
+    go bs (_, AnnLam b e) | isId b = go (b:bs) e
+    go bs e                        = (reverse bs, e)
 
 isAnnTypeArg :: AnnExpr b ann -> Bool
 isAnnTypeArg (_, AnnType _) = True
@@ -270,6 +271,24 @@ liftPA x
       lc <- builtin liftingContext
       replicatePA (Var lc) x
 
+zipScalars :: [Type] -> Type -> VM CoreExpr
+zipScalars arg_tys res_ty
+  = do
+      scalar <- builtin scalarClass
+      (dfuns, _) <- mapAndUnzipM (\ty -> lookupInst scalar [ty]) ty_args
+      zipf <- builtin (scalarZip $ length arg_tys)
+      return $ Var zipf `mkTyApps` ty_args `mkApps` map Var dfuns
+    where
+      ty_args = arg_tys ++ [res_ty]
+
+scalarClosure :: [Type] -> Type -> CoreExpr -> CoreExpr -> VM CoreExpr
+scalarClosure arg_tys res_ty scalar_fun array_fun
+  = do
+      ctr <- builtin (closureCtrFun $ length arg_tys)
+      pas <- mapM paDictOfType (init arg_tys)
+      return $ Var ctr `mkTyApps` (arg_tys ++ [res_ty])
+                       `mkApps`   (pas ++ [scalar_fun, array_fun])
+
 newLocalVVar :: FastString -> Type -> VM VVar
 newLocalVVar fs vty
   = do
@@ -375,12 +394,13 @@ buildClosures :: [TyVar] -> [VVar] -> [Type] -> Type -> VM VExpr -> VM VExpr
 buildClosures _   _    [] _ mk_body
   = mk_body
 buildClosures tvs vars [arg_ty] res_ty mk_body
-  = buildClosure tvs vars arg_ty res_ty mk_body
+  = liftM vInlineMe (buildClosure tvs vars arg_ty res_ty mk_body)
 buildClosures tvs vars (arg_ty : arg_tys) res_ty mk_body
   = do
       res_ty' <- mkClosureTypes arg_tys res_ty
       arg <- newLocalVVar (fsLit "x") arg_ty
-      buildClosure tvs vars arg_ty res_ty'
+      liftM vInlineMe
+        . buildClosure tvs vars arg_ty res_ty'
         . hoistPolyVExpr tvs
         $ do
             lc <- builtin liftingContext
@@ -405,7 +425,7 @@ buildClosure tvs vars arg_ty res_ty mk_body
               body  <- mk_body
               body' <- bind (vVar env_bndr)
                             (vVarApps lc body (vars ++ [arg_bndr]))
-              return (vLamsWithoutLC [env_bndr, arg_bndr] body')
+              return . vInlineMe $ vLamsWithoutLC [env_bndr, arg_bndr] body'
 
       mkClosure arg_ty res_ty env_ty fn env