-vectTopRhs :: Var -> CoreExpr -> VM CoreExpr
-vectTopRhs var expr
- = do
- closedV . liftM vectorised
- . inBind var
- $ vectPolyExpr (freeVars expr)
-
--- ----------------------------------------------------------------------------
--- Bindings
-
-vectBndr :: Var -> VM VVar
-vectBndr v
- = do
- vty <- vectType (idType v)
- lty <- mkPArrayType vty
- let vv = v `Id.setIdType` vty
- lv = v `Id.setIdType` lty
- updLEnv (mapTo vv lv)
- return (vv, lv)
- where
- mapTo vv lv env = env { local_vars = extendVarEnv (local_vars env) v (vv, lv) }
-
-vectBndrIn :: Var -> VM a -> VM (VVar, a)
-vectBndrIn v p
- = localV
- $ do
- vv <- vectBndr v
- x <- p
- return (vv, x)
-
-vectBndrIn' :: Var -> (VVar -> VM a) -> VM (VVar, a)
-vectBndrIn' v p
- = localV
- $ do
- vv <- vectBndr v
- x <- p vv
- return (vv, x)
-
-vectBndrsIn :: [Var] -> VM a -> VM ([VVar], a)
-vectBndrsIn vs p
- = localV
- $ do
- vvs <- mapM vectBndr vs
- x <- p
- return (vvs, x)
-
--- ----------------------------------------------------------------------------
--- Expressions
-
-vectVar :: Var -> VM VExpr
-vectVar v
- = do
- r <- lookupVar v
- case r of
- Local (vv,lv) -> return (Var vv, Var lv)
- Global vv -> do
- let vexpr = Var vv
- lexpr <- liftPA vexpr
- return (vexpr, lexpr)
-
-vectPolyVar :: Var -> [Type] -> VM VExpr
-vectPolyVar v tys
- = do
- vtys <- mapM vectType tys
- r <- lookupVar v
- case r of
- Local (vv, lv) -> liftM2 (,) (polyApply (Var vv) vtys)
- (polyApply (Var lv) vtys)
- Global poly -> do
- vexpr <- polyApply (Var poly) vtys
- lexpr <- liftPA vexpr
- return (vexpr, lexpr)
-
-vectLiteral :: Literal -> VM VExpr
-vectLiteral lit
- = do
- lexpr <- liftPA (Lit lit)
- return (Lit lit, lexpr)
-
-vectPolyExpr :: CoreExprWithFVs -> VM VExpr
-vectPolyExpr expr
- = polyAbstract tvs $ \abstract ->
- do
- mono' <- vectExpr mono
- return $ mapVect abstract mono'
- where
- (tvs, mono) = collectAnnTypeBinders expr
-
-vectExpr :: CoreExprWithFVs -> VM VExpr
-vectExpr (_, AnnType ty)
- = liftM vType (vectType ty)
-
-vectExpr (_, AnnVar v) = vectVar v
-
-vectExpr (_, AnnLit lit) = vectLiteral lit
-
-vectExpr (_, AnnNote note expr)
- = liftM (vNote note) (vectExpr expr)
-
-vectExpr e@(_, AnnApp _ arg)
- | isAnnTypeArg arg
- = vectTyAppExpr fn tys
- where
- (fn, tys) = collectAnnTypeArgs e
-
-vectExpr (_, AnnApp fn arg)
- = do
- arg_ty' <- vectType arg_ty
- res_ty' <- vectType res_ty
- fn' <- vectExpr fn
- arg' <- vectExpr arg
- mkClosureApp arg_ty' res_ty' fn' arg'
- where
- (arg_ty, res_ty) = splitFunTy . exprType $ deAnnotate fn
-
-vectExpr (_, AnnCase scrut bndr ty alts)
- | isAlgType scrut_ty
- = vectAlgCase scrut bndr ty alts
- where
- scrut_ty = exprType (deAnnotate scrut)
-
-vectExpr (_, AnnCase expr bndr ty alts)
- = panic "vectExpr: case"
-
-vectExpr (_, AnnLet (AnnNonRec bndr rhs) body)
- = do
- vrhs <- localV . inBind bndr $ vectPolyExpr rhs
- (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
- return $ vLet (vNonRec vbndr vrhs) vbody
-
-vectExpr (_, AnnLet (AnnRec bs) body)
- = do
- (vbndrs, (vrhss, vbody)) <- vectBndrsIn bndrs
- $ liftM2 (,)
- (zipWithM vect_rhs bndrs rhss)
- (vectPolyExpr body)
- return $ vLet (vRec vbndrs vrhss) vbody
- where
- (bndrs, rhss) = unzip bs
-
- vect_rhs bndr rhs = localV
- . inBind bndr
- $ vectExpr rhs
-
-vectExpr e@(fvs, AnnLam bndr _)
- | not (isId bndr) = pprPanic "vectExpr" (ppr $ deAnnotate e)
- | otherwise = vectLam fvs bs body
+-- | Make the vectorised version of this top level binder, and add the mapping
+-- between it and the original to the state. For some binder @foo@ the vectorised
+-- version is @$v_foo@
+--
+-- NOTE: vectTopBinder *MUST* be lazy in inline and expr because of how it is
+-- used inside of fixV in vectTopBind
+--
+vectTopBinder :: Var -- ^ Name of the binding.
+ -> Inline -- ^ Whether it should be inlined, used to annotate it.
+ -> CoreExpr -- ^ RHS of binding, used to set the 'Unfolding' of the returned 'Var'.
+ -> VM Var -- ^ Name of the vectorised binding.
+vectTopBinder var inline expr
+ = do { -- Vectorise the type attached to the var.
+ ; vty <- vectType (idType var)
+
+ -- If there is a vectorisation declartion for this binding, make sure that its type
+ -- matches
+ ; vectDecl <- lookupVectDecl var
+ ; case vectDecl of
+ Nothing -> return ()
+ Just (vdty, _)
+ | coreEqType vty vdty -> return ()
+ | otherwise ->
+ cantVectorise ("Type mismatch in vectorisation pragma for " ++ show var) $
+ (text "Expected type" <+> ppr vty)
+ $$
+ (text "Inferred type" <+> ppr vdty)
+
+ -- Make the vectorised version of binding's name, and set the unfolding used for inlining
+ ; var' <- liftM (`setIdUnfoldingLazily` unfolding)
+ $ cloneId mkVectOcc var vty
+
+ -- Add the mapping between the plain and vectorised name to the state.
+ ; defGlobalVar var var'
+
+ ; return var'
+ }