Separate length from data in DPH arrays
[ghc-hetmet.git] / compiler / vectorise / VectCore.hs
1 module VectCore (
2   Vect, VVar, VExpr, VBind,
3
4   vectorised, lifted,
5   mapVect,
6
7   vVarType,
8
9   vNonRec, vRec,
10
11   vVar, vType, vNote, vLet,
12   vLams, vLamsWithoutLC, vVarApps,
13   vCaseDEFAULT, vInlineMe
14 ) where
15
16 #include "HsVersions.h"
17
18 import CoreSyn
19 import CoreUtils      ( mkInlineMe )
20 import MkCore         ( mkWildCase )
21 import CoreUtils      ( exprType )
22 import DataCon        ( DataCon )
23 import Type           ( Type )
24 import Var
25
26 type Vect a = (a,a)
27 type VVar   = Vect Var
28 type VExpr  = Vect CoreExpr
29 type VBind  = Vect CoreBind
30
31 vectorised :: Vect a -> a
32 vectorised = fst
33
34 lifted :: Vect a -> a
35 lifted = snd
36
37 mapVect :: (a -> b) -> Vect a -> Vect b
38 mapVect f (x,y) = (f x, f y)
39
40 zipWithVect :: (a -> b -> c) -> Vect a -> Vect b -> Vect c
41 zipWithVect f (x1,y1) (x2,y2) = (f x1 x2, f y1 y2)
42
43 vVarType :: VVar -> Type
44 vVarType = varType . vectorised
45
46 vVar :: VVar -> VExpr
47 vVar = mapVect Var
48
49 vType :: Type -> VExpr
50 vType ty = (Type ty, Type ty)
51
52 vNote :: Note -> VExpr -> VExpr
53 vNote = mapVect . Note
54
55 vNonRec :: VVar -> VExpr -> VBind
56 vNonRec = zipWithVect NonRec
57
58 vRec :: [VVar] -> [VExpr] -> VBind
59 vRec vs es = (Rec (zip vvs ves), Rec (zip lvs les))
60   where
61     (vvs, lvs) = unzip vs
62     (ves, les) = unzip es
63
64 vLet :: VBind -> VExpr -> VExpr
65 vLet = zipWithVect Let
66
67 vLams :: Var -> [VVar] -> VExpr -> VExpr
68 vLams lc vs (ve, le) = (mkLams vvs ve, mkLams (lc:lvs) le)
69   where
70     (vvs,lvs) = unzip vs
71
72 vLamsWithoutLC :: [VVar] -> VExpr -> VExpr
73 vLamsWithoutLC vvs (ve,le) = (mkLams vs ve, mkLams ls le)
74   where
75     (vs,ls) = unzip vvs
76
77 vVarApps :: Var -> VExpr -> [VVar] -> VExpr
78 vVarApps lc (ve, le) vvs = (ve `mkVarApps` vs, le `mkVarApps` (lc : ls))
79   where
80     (vs,ls) = unzip vvs 
81
82 vCaseDEFAULT :: VExpr -> VVar -> Type -> Type -> VExpr -> VExpr
83 vCaseDEFAULT (vscrut, lscrut) (vbndr, lbndr) vty lty (vbody, lbody)
84   = (Case vscrut vbndr vty (mkDEFAULT vbody),
85      Case lscrut lbndr lty (mkDEFAULT lbody))
86   where
87     mkDEFAULT e = [(DEFAULT, [], e)]
88
89 vInlineMe :: VExpr -> VExpr
90 vInlineMe (vexpr, lexpr) = (mkInlineMe vexpr, mkInlineMe lexpr)
91