Generate lots of __inline_me during vectorisation
[ghc-hetmet.git] / compiler / vectorise / VectCore.hs
1 module VectCore (
2   Vect, VVar, VExpr, VBind,
3
4   vectorised, lifted,
5   mapVect,
6
7   vNonRec, vRec,
8
9   vVar, vType, vNote, vLet,
10   vLams, vLamsWithoutLC, vVarApps,
11   vCaseDEFAULT, vCaseProd, vInlineMe
12 ) where
13
14 #include "HsVersions.h"
15
16 import CoreSyn
17 import CoreUtils      ( mkInlineMe )
18 import MkCore         ( mkWildCase )
19 import CoreUtils      ( exprType )
20 import DataCon        ( DataCon )
21 import Type           ( Type )
22 import Var
23
24 type Vect a = (a,a)
25 type VVar   = Vect Var
26 type VExpr  = Vect CoreExpr
27 type VBind  = Vect CoreBind
28
29 vectorised :: Vect a -> a
30 vectorised = fst
31
32 lifted :: Vect a -> a
33 lifted = snd
34
35 mapVect :: (a -> b) -> Vect a -> Vect b
36 mapVect f (x,y) = (f x, f y)
37
38 zipWithVect :: (a -> b -> c) -> Vect a -> Vect b -> Vect c
39 zipWithVect f (x1,y1) (x2,y2) = (f x1 x2, f y1 y2)
40
41 vVar :: VVar -> VExpr
42 vVar = mapVect Var
43
44 vType :: Type -> VExpr
45 vType ty = (Type ty, Type ty)
46
47 vNote :: Note -> VExpr -> VExpr
48 vNote = mapVect . Note
49
50 vNonRec :: VVar -> VExpr -> VBind
51 vNonRec = zipWithVect NonRec
52
53 vRec :: [VVar] -> [VExpr] -> VBind
54 vRec vs es = (Rec (zip vvs ves), Rec (zip lvs les))
55   where
56     (vvs, lvs) = unzip vs
57     (ves, les) = unzip es
58
59 vLet :: VBind -> VExpr -> VExpr
60 vLet = zipWithVect Let
61
62 vLams :: Var -> [VVar] -> VExpr -> VExpr
63 vLams lc vs (ve, le) = (mkLams vvs ve, mkLams (lc:lvs) le)
64   where
65     (vvs,lvs) = unzip vs
66
67 vLamsWithoutLC :: [VVar] -> VExpr -> VExpr
68 vLamsWithoutLC vvs (ve,le) = (mkLams vs ve, mkLams ls le)
69   where
70     (vs,ls) = unzip vvs
71
72 vVarApps :: Var -> VExpr -> [VVar] -> VExpr
73 vVarApps lc (ve, le) vvs = (ve `mkVarApps` vs, le `mkVarApps` (lc : ls))
74   where
75     (vs,ls) = unzip vvs 
76
77 vCaseDEFAULT :: VExpr -> VVar -> Type -> Type -> VExpr -> VExpr
78 vCaseDEFAULT (vscrut, lscrut) (vbndr, lbndr) vty lty (vbody, lbody)
79   = (Case vscrut vbndr vty (mkDEFAULT vbody),
80      Case lscrut lbndr lty (mkDEFAULT lbody))
81   where
82     mkDEFAULT e = [(DEFAULT, [], e)]
83
84 vCaseProd :: VExpr -> Type -> Type
85           -> DataCon -> DataCon -> [Var] -> [VVar] -> VExpr -> VExpr
86 vCaseProd (vscrut, lscrut) vty lty vdc ldc sh_bndrs bndrs
87           (vbody,lbody)
88   = (mkWildCase vscrut (exprType vscrut) vty
89           [(DataAlt vdc, vbndrs, vbody)],
90      mkWildCase lscrut (exprType lscrut) lty
91           [(DataAlt ldc, sh_bndrs ++ lbndrs, lbody)])
92   where
93     (vbndrs, lbndrs) = unzip bndrs
94
95 vInlineMe :: VExpr -> VExpr
96 vInlineMe (vexpr, lexpr) = (mkInlineMe vexpr, mkInlineMe lexpr)
97