de832793f654e3243486ded9dab5f79ff7ac247e
[ghc-hetmet.git] / compiler / vectorise / VectCore.hs
1 module VectCore (
2   Vect, VVar, VExpr, VBind,
3
4   vectorised, lifted,
5   mapVect,
6
7   vNonRec, vRec,
8
9   vVar, vType, vNote, vLet,
10   vLams, vLamsWithoutLC, vVarApps,
11   vCaseDEFAULT, vCaseProd
12 ) where
13
14 #include "HsVersions.h"
15
16 import CoreSyn
17 import CoreUtils      ( exprType )
18 import DataCon        ( DataCon )
19 import Type           ( Type )
20 import Id             ( mkWildId )
21 import Var
22
23 type Vect a = (a,a)
24 type VVar   = Vect Var
25 type VExpr  = Vect CoreExpr
26 type VBind  = Vect CoreBind
27
28 vectorised :: Vect a -> a
29 vectorised = fst
30
31 lifted :: Vect a -> a
32 lifted = snd
33
34 mapVect :: (a -> b) -> Vect a -> Vect b
35 mapVect f (x,y) = (f x, f y)
36
37 zipWithVect :: (a -> b -> c) -> Vect a -> Vect b -> Vect c
38 zipWithVect f (x1,y1) (x2,y2) = (f x1 x2, f y1 y2)
39
40 vVar :: VVar -> VExpr
41 vVar = mapVect Var
42
43 vType :: Type -> VExpr
44 vType ty = (Type ty, Type ty)
45
46 vNote :: Note -> VExpr -> VExpr
47 vNote = mapVect . Note
48
49 vNonRec :: VVar -> VExpr -> VBind
50 vNonRec = zipWithVect NonRec
51
52 vRec :: [VVar] -> [VExpr] -> VBind
53 vRec vs es = (Rec (zip vvs ves), Rec (zip lvs les))
54   where
55     (vvs, lvs) = unzip vs
56     (ves, les) = unzip es
57
58 vLet :: VBind -> VExpr -> VExpr
59 vLet = zipWithVect Let
60
61 vLams :: Var -> [VVar] -> VExpr -> VExpr
62 vLams lc vs (ve, le) = (mkLams vvs ve, mkLams (lc:lvs) le)
63   where
64     (vvs,lvs) = unzip vs
65
66 vLamsWithoutLC :: [VVar] -> VExpr -> VExpr
67 vLamsWithoutLC vvs (ve,le) = (mkLams vs ve, mkLams ls le)
68   where
69     (vs,ls) = unzip vvs
70
71 vVarApps :: Var -> VExpr -> [VVar] -> VExpr
72 vVarApps lc (ve, le) vvs = (ve `mkVarApps` vs, le `mkVarApps` (lc : ls))
73   where
74     (vs,ls) = unzip vvs 
75
76 vCaseDEFAULT :: VExpr -> VVar -> Type -> Type -> VExpr -> VExpr
77 vCaseDEFAULT (vscrut, lscrut) (vbndr, lbndr) vty lty (vbody, lbody)
78   = (Case vscrut vbndr vty (mkDEFAULT vbody),
79      Case lscrut lbndr lty (mkDEFAULT lbody))
80   where
81     mkDEFAULT e = [(DEFAULT, [], e)]
82
83 vCaseProd :: VExpr -> Type -> Type
84           -> DataCon -> DataCon -> [Var] -> [VVar] -> VExpr -> VExpr
85 vCaseProd (vscrut, lscrut) vty lty vdc ldc sh_bndrs bndrs
86           (vbody,lbody)
87   = (Case vscrut (mkWildId $ exprType vscrut) vty
88           [(DataAlt vdc, vbndrs, vbody)],
89      Case lscrut (mkWildId $ exprType lscrut) lty
90           [(DataAlt ldc, sh_bndrs ++ lbndrs, lbody)])
91   where
92     (vbndrs, lbndrs) = unzip bndrs