Break out Repr and PADict stuff for vectorisation of ADTs to their own modules
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Type / PADict.hs
1
2 module Vectorise.Type.PADict
3         ( buildToPRepr
4         , buildFromPRepr
5         , buildToArrPRepr
6         , buildFromArrPRepr)
7 where
8 import VectUtils
9 import Vectorise.Monad
10 import Vectorise.Builtins
11 import Vectorise.Type.Repr
12 import CoreSyn
13 import CoreUtils
14 import MkCore            ( mkWildCase )
15 import TyCon
16 import Type
17 import Coercion
18 import MkId
19
20 import FastString
21 import MonadUtils
22 import Control.Monad
23
24
25 buildToPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
26 buildToPRepr vect_tc repr_tc _ repr
27   = do
28       let arg_ty = mkTyConApp vect_tc ty_args
29       res_ty <- mkPReprType arg_ty
30       arg    <- newLocalVar (fsLit "x") arg_ty
31       result <- to_sum (Var arg) arg_ty res_ty repr
32       return $ Lam arg result
33   where
34     ty_args = mkTyVarTys (tyConTyVars vect_tc)
35
36     wrap_repr_inst = wrapFamInstBody repr_tc ty_args
37
38     to_sum _ _ _ EmptySum
39       = do
40           void <- builtin voidVar
41           return $ wrap_repr_inst $ Var void
42
43     to_sum arg arg_ty res_ty (UnarySum r)
44       = do
45           (pat, vars, body) <- con_alt r
46           return $ mkWildCase arg arg_ty res_ty
47                    [(pat, vars, wrap_repr_inst body)]
48
49     to_sum arg arg_ty res_ty (Sum { repr_sum_tc  = sum_tc
50                                   , repr_con_tys = tys
51                                   , repr_cons    =  cons })
52       = do
53           alts <- mapM con_alt cons
54           let alts' = [(pat, vars, wrap_repr_inst
55                                    $ mkConApp sum_con (map Type tys ++ [body]))
56                         | ((pat, vars, body), sum_con)
57                             <- zip alts (tyConDataCons sum_tc)]
58           return $ mkWildCase arg arg_ty res_ty alts'
59
60     con_alt (ConRepr con r)
61       = do
62           (vars, body) <- to_prod r
63           return (DataAlt con, vars, body)
64
65     to_prod EmptyProd
66       = do
67           void <- builtin voidVar
68           return ([], Var void)
69
70     to_prod (UnaryProd comp)
71       = do
72           var  <- newLocalVar (fsLit "x") (compOrigType comp)
73           body <- to_comp (Var var) comp
74           return ([var], body)
75
76     to_prod(Prod { repr_tup_tc   = tup_tc
77                  , repr_comp_tys = tys
78                  , repr_comps    = comps })
79       = do
80           vars  <- newLocalVars (fsLit "x") (map compOrigType comps)
81           exprs <- zipWithM to_comp (map Var vars) comps
82           return (vars, mkConApp tup_con (map Type tys ++ exprs))
83       where
84         [tup_con] = tyConDataCons tup_tc
85
86     to_comp expr (Keep _ _) = return expr
87     to_comp expr (Wrap ty)  = do
88                                 wrap_tc <- builtin wrapTyCon
89                                 return $ wrapNewTypeBody wrap_tc [ty] expr
90
91
92 buildFromPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
93 buildFromPRepr vect_tc repr_tc _ repr
94   = do
95       arg_ty <- mkPReprType res_ty
96       arg <- newLocalVar (fsLit "x") arg_ty
97
98       result <- from_sum (unwrapFamInstScrut repr_tc ty_args (Var arg))
99                          repr
100       return $ Lam arg result
101   where
102     ty_args = mkTyVarTys (tyConTyVars vect_tc)
103     res_ty  = mkTyConApp vect_tc ty_args
104
105     from_sum _ EmptySum
106       = do
107           dummy <- builtin fromVoidVar
108           return $ Var dummy `App` Type res_ty
109
110     from_sum expr (UnarySum r) = from_con expr r
111     from_sum expr (Sum { repr_sum_tc  = sum_tc
112                        , repr_con_tys = tys
113                        , repr_cons    = cons })
114       = do
115           vars  <- newLocalVars (fsLit "x") tys
116           es    <- zipWithM from_con (map Var vars) cons
117           return $ mkWildCase expr (exprType expr) res_ty
118                    [(DataAlt con, [var], e)
119                       | (con, var, e) <- zip3 (tyConDataCons sum_tc) vars es]
120
121     from_con expr (ConRepr con r)
122       = from_prod expr (mkConApp con $ map Type ty_args) r
123
124     from_prod _ con EmptyProd = return con
125     from_prod expr con (UnaryProd r)
126       = do
127           e <- from_comp expr r
128           return $ con `App` e
129      
130     from_prod expr con (Prod { repr_tup_tc   = tup_tc
131                              , repr_comp_tys = tys
132                              , repr_comps    = comps
133                              })
134       = do
135           vars <- newLocalVars (fsLit "y") tys
136           es   <- zipWithM from_comp (map Var vars) comps
137           return $ mkWildCase expr (exprType expr) res_ty
138                    [(DataAlt tup_con, vars, con `mkApps` es)]
139       where
140         [tup_con] = tyConDataCons tup_tc  
141
142     from_comp expr (Keep _ _) = return expr
143     from_comp expr (Wrap ty)
144       = do
145           wrap <- builtin wrapTyCon
146           return $ unwrapNewTypeBody wrap [ty] expr
147
148
149 buildToArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
150 buildToArrPRepr vect_tc prepr_tc pdata_tc r
151   = do
152       arg_ty <- mkPDataType el_ty
153       res_ty <- mkPDataType =<< mkPReprType el_ty
154       arg    <- newLocalVar (fsLit "xs") arg_ty
155
156       pdata_co <- mkBuiltinCo pdataTyCon
157       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
158           co           = mkAppCoercion pdata_co
159                        . mkSymCoercion
160                        $ mkTyConApp repr_co ty_args
161
162           scrut   = unwrapFamInstScrut pdata_tc ty_args (Var arg)
163
164       (vars, result) <- to_sum r
165
166       return . Lam arg
167              $ mkWildCase scrut (mkTyConApp pdata_tc ty_args) res_ty
168                [(DataAlt pdata_dc, vars, mkCoerce co result)]
169   where
170     ty_args = mkTyVarTys $ tyConTyVars vect_tc
171     el_ty   = mkTyConApp vect_tc ty_args
172
173     [pdata_dc] = tyConDataCons pdata_tc
174
175
176     to_sum EmptySum = do
177                         pvoid <- builtin pvoidVar
178                         return ([], Var pvoid)
179     to_sum (UnarySum r) = to_con r
180     to_sum (Sum { repr_psum_tc = psum_tc
181                 , repr_sel_ty  = sel_ty
182                 , repr_con_tys = tys
183                 , repr_cons    = cons
184                 })
185       = do
186           (vars, exprs) <- mapAndUnzipM to_con cons
187           sel <- newLocalVar (fsLit "sel") sel_ty
188           return (sel : concat vars, mk_result (Var sel) exprs)
189       where
190         [psum_con] = tyConDataCons psum_tc
191         mk_result sel exprs = wrapFamInstBody psum_tc tys
192                             $ mkConApp psum_con
193                             $ map Type tys ++ (sel : exprs)
194
195     to_con (ConRepr _ r) = to_prod r
196
197     to_prod EmptyProd = do
198                           pvoid <- builtin pvoidVar
199                           return ([], Var pvoid)
200     to_prod (UnaryProd r)
201       = do
202           pty  <- mkPDataType (compOrigType r)
203           var  <- newLocalVar (fsLit "x") pty
204           expr <- to_comp (Var var) r
205           return ([var], expr)
206
207     to_prod (Prod { repr_ptup_tc  = ptup_tc
208                   , repr_comp_tys = tys
209                   , repr_comps    = comps })
210       = do
211           ptys <- mapM (mkPDataType . compOrigType) comps
212           vars <- newLocalVars (fsLit "x") ptys
213           es   <- zipWithM to_comp (map Var vars) comps
214           return (vars, mk_result es)
215       where
216         [ptup_con] = tyConDataCons ptup_tc
217         mk_result exprs = wrapFamInstBody ptup_tc tys
218                         $ mkConApp ptup_con
219                         $ map Type tys ++ exprs
220
221     to_comp expr (Keep _ _) = return expr
222
223     -- FIXME: this is bound to be wrong!
224     to_comp expr (Wrap ty)
225       = do
226           wrap_tc  <- builtin wrapTyCon
227           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
228           return $ wrapNewTypeBody pwrap_tc [ty] expr
229
230
231 buildFromArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
232 buildFromArrPRepr vect_tc prepr_tc pdata_tc r
233   = do
234       arg_ty <- mkPDataType =<< mkPReprType el_ty
235       res_ty <- mkPDataType el_ty
236       arg    <- newLocalVar (fsLit "xs") arg_ty
237
238       pdata_co <- mkBuiltinCo pdataTyCon
239       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
240           co           = mkAppCoercion pdata_co
241                        $ mkTyConApp repr_co var_tys
242
243           scrut  = mkCoerce co (Var arg)
244
245           mk_result args = wrapFamInstBody pdata_tc var_tys
246                          $ mkConApp pdata_con
247                          $ map Type var_tys ++ args
248
249       (expr, _) <- fixV $ \ ~(_, args) ->
250                      from_sum res_ty (mk_result args) scrut r
251
252       return $ Lam arg expr
253     
254       -- (args, mk) <- from_sum res_ty scrut r
255       
256       -- let result = wrapFamInstBody pdata_tc var_tys
257       --           . mkConApp pdata_dc
258       --           $ map Type var_tys ++ args
259
260       -- return $ Lam arg (mk result)
261   where
262     var_tys = mkTyVarTys $ tyConTyVars vect_tc
263     el_ty   = mkTyConApp vect_tc var_tys
264
265     [pdata_con] = tyConDataCons pdata_tc
266
267     from_sum _ res _ EmptySum = return (res, [])
268     from_sum res_ty res expr (UnarySum r) = from_con res_ty res expr r
269     from_sum res_ty res expr (Sum { repr_psum_tc = psum_tc
270                                   , repr_sel_ty  = sel_ty
271                                   , repr_con_tys = tys
272                                   , repr_cons    = cons })
273       = do
274           sel  <- newLocalVar (fsLit "sel") sel_ty
275           ptys <- mapM mkPDataType tys
276           vars <- newLocalVars (fsLit "xs") ptys
277           (res', args) <- fold from_con res_ty res (map Var vars) cons
278           let scrut = unwrapFamInstScrut psum_tc tys expr
279               body  = mkWildCase scrut (exprType scrut) res_ty
280                       [(DataAlt psum_con, sel : vars, res')]
281           return (body, Var sel : args)
282       where
283         [psum_con] = tyConDataCons psum_tc
284
285
286     from_con res_ty res expr (ConRepr _ r) = from_prod res_ty res expr r
287
288     from_prod _ res _ EmptyProd = return (res, [])
289     from_prod res_ty res expr (UnaryProd r)
290       = from_comp res_ty res expr r
291     from_prod res_ty res expr (Prod { repr_ptup_tc  = ptup_tc
292                                     , repr_comp_tys = tys
293                                     , repr_comps    = comps })
294       = do
295           ptys <- mapM mkPDataType tys
296           vars <- newLocalVars (fsLit "ys") ptys
297           (res', args) <- fold from_comp res_ty res (map Var vars) comps
298           let scrut = unwrapFamInstScrut ptup_tc tys expr
299               body  = mkWildCase scrut (exprType scrut) res_ty
300                       [(DataAlt ptup_con, vars, res')]
301           return (body, args)
302       where
303         [ptup_con] = tyConDataCons ptup_tc
304
305     from_comp _ res expr (Keep _ _) = return (res, [expr])
306     from_comp _ res expr (Wrap ty)
307       = do
308           wrap_tc  <- builtin wrapTyCon
309           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
310           return (res, [unwrapNewTypeBody pwrap_tc [ty]
311                         $ unwrapFamInstScrut pwrap_tc [ty] expr])
312
313     fold f res_ty res exprs rs = foldrM f' (res, []) (zip exprs rs)
314       where
315         f' (expr, r) (res, args) = do
316                                      (res', args') <- f res_ty res expr r
317                                      return (res', args' ++ args)