34ffdb7cae37a44bbd594b7f3a107df2293bb510
[ghc-hetmet.git] / compiler / vectorise / VectCore.hs
1 module VectCore (
2   Vect, VVar, VExpr, VBind,
3
4   vectorised, lifted,
5   mapVect,
6
7   vNonRec, vRec,
8
9   vVar, vType, vNote, vLet,
10   vLams, vLamsWithoutLC, vVarApps,
11   vCaseDEFAULT, vCaseProd
12 ) where
13
14 -- XXX This define is a bit of a hack, and should be done more nicely
15 #define FAST_STRING_NOT_NEEDED 1
16 #include "HsVersions.h"
17
18 import CoreSyn
19 import CoreUtils      ( exprType )
20 import DataCon        ( DataCon )
21 import Type           ( Type )
22 import Id             ( mkWildId )
23 import Var
24
25 type Vect a = (a,a)
26 type VVar   = Vect Var
27 type VExpr  = Vect CoreExpr
28 type VBind  = Vect CoreBind
29
30 vectorised :: Vect a -> a
31 vectorised = fst
32
33 lifted :: Vect a -> a
34 lifted = snd
35
36 mapVect :: (a -> b) -> Vect a -> Vect b
37 mapVect f (x,y) = (f x, f y)
38
39 zipWithVect :: (a -> b -> c) -> Vect a -> Vect b -> Vect c
40 zipWithVect f (x1,y1) (x2,y2) = (f x1 x2, f y1 y2)
41
42 vVar :: VVar -> VExpr
43 vVar = mapVect Var
44
45 vType :: Type -> VExpr
46 vType ty = (Type ty, Type ty)
47
48 vNote :: Note -> VExpr -> VExpr
49 vNote = mapVect . Note
50
51 vNonRec :: VVar -> VExpr -> VBind
52 vNonRec = zipWithVect NonRec
53
54 vRec :: [VVar] -> [VExpr] -> VBind
55 vRec vs es = (Rec (zip vvs ves), Rec (zip lvs les))
56   where
57     (vvs, lvs) = unzip vs
58     (ves, les) = unzip es
59
60 vLet :: VBind -> VExpr -> VExpr
61 vLet = zipWithVect Let
62
63 vLams :: Var -> [VVar] -> VExpr -> VExpr
64 vLams lc vs (ve, le) = (mkLams vvs ve, mkLams (lc:lvs) le)
65   where
66     (vvs,lvs) = unzip vs
67
68 vLamsWithoutLC :: [VVar] -> VExpr -> VExpr
69 vLamsWithoutLC vvs (ve,le) = (mkLams vs ve, mkLams ls le)
70   where
71     (vs,ls) = unzip vvs
72
73 vVarApps :: Var -> VExpr -> [VVar] -> VExpr
74 vVarApps lc (ve, le) vvs = (ve `mkVarApps` vs, le `mkVarApps` (lc : ls))
75   where
76     (vs,ls) = unzip vvs 
77
78 vCaseDEFAULT :: VExpr -> VVar -> Type -> Type -> VExpr -> VExpr
79 vCaseDEFAULT (vscrut, lscrut) (vbndr, lbndr) vty lty (vbody, lbody)
80   = (Case vscrut vbndr vty (mkDEFAULT vbody),
81      Case lscrut lbndr lty (mkDEFAULT lbody))
82   where
83     mkDEFAULT e = [(DEFAULT, [], e)]
84
85 vCaseProd :: VExpr -> Type -> Type
86           -> DataCon -> DataCon -> [Var] -> [VVar] -> VExpr -> VExpr
87 vCaseProd (vscrut, lscrut) vty lty vdc ldc sh_bndrs bndrs
88           (vbody,lbody)
89   = (Case vscrut (mkWildId $ exprType vscrut) vty
90           [(DataAlt vdc, vbndrs, vbody)],
91      Case lscrut (mkWildId $ exprType lscrut) lty
92           [(DataAlt ldc, sh_bndrs ++ lbndrs, lbody)])
93   where
94     (vbndrs, lbndrs) = unzip bndrs