Sort all the PADict/PData/PRDict/PRepr stuff into their own modules
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Type / PRepr.hs
1
2 module Vectorise.Type.PRepr
3         ( buildPReprTyCon
4         , buildToPRepr
5         , buildFromPRepr
6         , buildToArrPRepr
7         , buildFromArrPRepr)
8 where
9 import VectUtils
10 import Vectorise.Monad
11 import Vectorise.Builtins
12 import Vectorise.Type.Repr
13 import CoreSyn
14 import CoreUtils
15 import MkCore            ( mkWildCase )
16 import TyCon
17 import Type
18 import BuildTyCl
19 import OccName
20 import Coercion
21 import MkId
22
23 import FastString
24 import MonadUtils
25 import Control.Monad
26
27
28 mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
29 mk_fam_inst fam_tc arg_tc
30   = (fam_tc, [mkTyConApp arg_tc . mkTyVarTys $ tyConTyVars arg_tc])
31
32
33 buildPReprTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
34 buildPReprTyCon orig_tc vect_tc repr
35   = do
36       name     <- cloneName mkPReprTyConOcc (tyConName orig_tc)
37       -- rhs_ty   <- buildPReprType vect_tc
38       rhs_ty   <- sumReprType repr
39       prepr_tc <- builtin preprTyCon
40       liftDs $ buildSynTyCon name
41                              tyvars
42                              (SynonymTyCon rhs_ty)
43                              (typeKind rhs_ty)
44                              (Just $ mk_fam_inst prepr_tc vect_tc)
45   where
46     tyvars = tyConTyVars vect_tc
47
48
49 buildToPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
50 buildToPRepr vect_tc repr_tc _ repr
51   = do
52       let arg_ty = mkTyConApp vect_tc ty_args
53       res_ty <- mkPReprType arg_ty
54       arg    <- newLocalVar (fsLit "x") arg_ty
55       result <- to_sum (Var arg) arg_ty res_ty repr
56       return $ Lam arg result
57   where
58     ty_args = mkTyVarTys (tyConTyVars vect_tc)
59
60     wrap_repr_inst = wrapFamInstBody repr_tc ty_args
61
62     to_sum _ _ _ EmptySum
63       = do
64           void <- builtin voidVar
65           return $ wrap_repr_inst $ Var void
66
67     to_sum arg arg_ty res_ty (UnarySum r)
68       = do
69           (pat, vars, body) <- con_alt r
70           return $ mkWildCase arg arg_ty res_ty
71                    [(pat, vars, wrap_repr_inst body)]
72
73     to_sum arg arg_ty res_ty (Sum { repr_sum_tc  = sum_tc
74                                   , repr_con_tys = tys
75                                   , repr_cons    =  cons })
76       = do
77           alts <- mapM con_alt cons
78           let alts' = [(pat, vars, wrap_repr_inst
79                                    $ mkConApp sum_con (map Type tys ++ [body]))
80                         | ((pat, vars, body), sum_con)
81                             <- zip alts (tyConDataCons sum_tc)]
82           return $ mkWildCase arg arg_ty res_ty alts'
83
84     con_alt (ConRepr con r)
85       = do
86           (vars, body) <- to_prod r
87           return (DataAlt con, vars, body)
88
89     to_prod EmptyProd
90       = do
91           void <- builtin voidVar
92           return ([], Var void)
93
94     to_prod (UnaryProd comp)
95       = do
96           var  <- newLocalVar (fsLit "x") (compOrigType comp)
97           body <- to_comp (Var var) comp
98           return ([var], body)
99
100     to_prod(Prod { repr_tup_tc   = tup_tc
101                  , repr_comp_tys = tys
102                  , repr_comps    = comps })
103       = do
104           vars  <- newLocalVars (fsLit "x") (map compOrigType comps)
105           exprs <- zipWithM to_comp (map Var vars) comps
106           return (vars, mkConApp tup_con (map Type tys ++ exprs))
107       where
108         [tup_con] = tyConDataCons tup_tc
109
110     to_comp expr (Keep _ _) = return expr
111     to_comp expr (Wrap ty)  = do
112                                 wrap_tc <- builtin wrapTyCon
113                                 return $ wrapNewTypeBody wrap_tc [ty] expr
114
115
116 buildFromPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
117 buildFromPRepr vect_tc repr_tc _ repr
118   = do
119       arg_ty <- mkPReprType res_ty
120       arg <- newLocalVar (fsLit "x") arg_ty
121
122       result <- from_sum (unwrapFamInstScrut repr_tc ty_args (Var arg))
123                          repr
124       return $ Lam arg result
125   where
126     ty_args = mkTyVarTys (tyConTyVars vect_tc)
127     res_ty  = mkTyConApp vect_tc ty_args
128
129     from_sum _ EmptySum
130       = do
131           dummy <- builtin fromVoidVar
132           return $ Var dummy `App` Type res_ty
133
134     from_sum expr (UnarySum r) = from_con expr r
135     from_sum expr (Sum { repr_sum_tc  = sum_tc
136                        , repr_con_tys = tys
137                        , repr_cons    = cons })
138       = do
139           vars  <- newLocalVars (fsLit "x") tys
140           es    <- zipWithM from_con (map Var vars) cons
141           return $ mkWildCase expr (exprType expr) res_ty
142                    [(DataAlt con, [var], e)
143                       | (con, var, e) <- zip3 (tyConDataCons sum_tc) vars es]
144
145     from_con expr (ConRepr con r)
146       = from_prod expr (mkConApp con $ map Type ty_args) r
147
148     from_prod _ con EmptyProd = return con
149     from_prod expr con (UnaryProd r)
150       = do
151           e <- from_comp expr r
152           return $ con `App` e
153      
154     from_prod expr con (Prod { repr_tup_tc   = tup_tc
155                              , repr_comp_tys = tys
156                              , repr_comps    = comps
157                              })
158       = do
159           vars <- newLocalVars (fsLit "y") tys
160           es   <- zipWithM from_comp (map Var vars) comps
161           return $ mkWildCase expr (exprType expr) res_ty
162                    [(DataAlt tup_con, vars, con `mkApps` es)]
163       where
164         [tup_con] = tyConDataCons tup_tc  
165
166     from_comp expr (Keep _ _) = return expr
167     from_comp expr (Wrap ty)
168       = do
169           wrap <- builtin wrapTyCon
170           return $ unwrapNewTypeBody wrap [ty] expr
171
172
173 buildToArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
174 buildToArrPRepr vect_tc prepr_tc pdata_tc r
175   = do
176       arg_ty <- mkPDataType el_ty
177       res_ty <- mkPDataType =<< mkPReprType el_ty
178       arg    <- newLocalVar (fsLit "xs") arg_ty
179
180       pdata_co <- mkBuiltinCo pdataTyCon
181       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
182           co           = mkAppCoercion pdata_co
183                        . mkSymCoercion
184                        $ mkTyConApp repr_co ty_args
185
186           scrut   = unwrapFamInstScrut pdata_tc ty_args (Var arg)
187
188       (vars, result) <- to_sum r
189
190       return . Lam arg
191              $ mkWildCase scrut (mkTyConApp pdata_tc ty_args) res_ty
192                [(DataAlt pdata_dc, vars, mkCoerce co result)]
193   where
194     ty_args = mkTyVarTys $ tyConTyVars vect_tc
195     el_ty   = mkTyConApp vect_tc ty_args
196
197     [pdata_dc] = tyConDataCons pdata_tc
198
199
200     to_sum EmptySum = do
201                         pvoid <- builtin pvoidVar
202                         return ([], Var pvoid)
203     to_sum (UnarySum r) = to_con r
204     to_sum (Sum { repr_psum_tc = psum_tc
205                 , repr_sel_ty  = sel_ty
206                 , repr_con_tys = tys
207                 , repr_cons    = cons
208                 })
209       = do
210           (vars, exprs) <- mapAndUnzipM to_con cons
211           sel <- newLocalVar (fsLit "sel") sel_ty
212           return (sel : concat vars, mk_result (Var sel) exprs)
213       where
214         [psum_con] = tyConDataCons psum_tc
215         mk_result sel exprs = wrapFamInstBody psum_tc tys
216                             $ mkConApp psum_con
217                             $ map Type tys ++ (sel : exprs)
218
219     to_con (ConRepr _ r) = to_prod r
220
221     to_prod EmptyProd = do
222                           pvoid <- builtin pvoidVar
223                           return ([], Var pvoid)
224     to_prod (UnaryProd r)
225       = do
226           pty  <- mkPDataType (compOrigType r)
227           var  <- newLocalVar (fsLit "x") pty
228           expr <- to_comp (Var var) r
229           return ([var], expr)
230
231     to_prod (Prod { repr_ptup_tc  = ptup_tc
232                   , repr_comp_tys = tys
233                   , repr_comps    = comps })
234       = do
235           ptys <- mapM (mkPDataType . compOrigType) comps
236           vars <- newLocalVars (fsLit "x") ptys
237           es   <- zipWithM to_comp (map Var vars) comps
238           return (vars, mk_result es)
239       where
240         [ptup_con] = tyConDataCons ptup_tc
241         mk_result exprs = wrapFamInstBody ptup_tc tys
242                         $ mkConApp ptup_con
243                         $ map Type tys ++ exprs
244
245     to_comp expr (Keep _ _) = return expr
246
247     -- FIXME: this is bound to be wrong!
248     to_comp expr (Wrap ty)
249       = do
250           wrap_tc  <- builtin wrapTyCon
251           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
252           return $ wrapNewTypeBody pwrap_tc [ty] expr
253
254
255 buildFromArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
256 buildFromArrPRepr vect_tc prepr_tc pdata_tc r
257   = do
258       arg_ty <- mkPDataType =<< mkPReprType el_ty
259       res_ty <- mkPDataType el_ty
260       arg    <- newLocalVar (fsLit "xs") arg_ty
261
262       pdata_co <- mkBuiltinCo pdataTyCon
263       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
264           co           = mkAppCoercion pdata_co
265                        $ mkTyConApp repr_co var_tys
266
267           scrut  = mkCoerce co (Var arg)
268
269           mk_result args = wrapFamInstBody pdata_tc var_tys
270                          $ mkConApp pdata_con
271                          $ map Type var_tys ++ args
272
273       (expr, _) <- fixV $ \ ~(_, args) ->
274                      from_sum res_ty (mk_result args) scrut r
275
276       return $ Lam arg expr
277     
278       -- (args, mk) <- from_sum res_ty scrut r
279       
280       -- let result = wrapFamInstBody pdata_tc var_tys
281       --           . mkConApp pdata_dc
282       --           $ map Type var_tys ++ args
283
284       -- return $ Lam arg (mk result)
285   where
286     var_tys = mkTyVarTys $ tyConTyVars vect_tc
287     el_ty   = mkTyConApp vect_tc var_tys
288
289     [pdata_con] = tyConDataCons pdata_tc
290
291     from_sum _ res _ EmptySum = return (res, [])
292     from_sum res_ty res expr (UnarySum r) = from_con res_ty res expr r
293     from_sum res_ty res expr (Sum { repr_psum_tc = psum_tc
294                                   , repr_sel_ty  = sel_ty
295                                   , repr_con_tys = tys
296                                   , repr_cons    = cons })
297       = do
298           sel  <- newLocalVar (fsLit "sel") sel_ty
299           ptys <- mapM mkPDataType tys
300           vars <- newLocalVars (fsLit "xs") ptys
301           (res', args) <- fold from_con res_ty res (map Var vars) cons
302           let scrut = unwrapFamInstScrut psum_tc tys expr
303               body  = mkWildCase scrut (exprType scrut) res_ty
304                       [(DataAlt psum_con, sel : vars, res')]
305           return (body, Var sel : args)
306       where
307         [psum_con] = tyConDataCons psum_tc
308
309
310     from_con res_ty res expr (ConRepr _ r) = from_prod res_ty res expr r
311
312     from_prod _ res _ EmptyProd = return (res, [])
313     from_prod res_ty res expr (UnaryProd r)
314       = from_comp res_ty res expr r
315     from_prod res_ty res expr (Prod { repr_ptup_tc  = ptup_tc
316                                     , repr_comp_tys = tys
317                                     , repr_comps    = comps })
318       = do
319           ptys <- mapM mkPDataType tys
320           vars <- newLocalVars (fsLit "ys") ptys
321           (res', args) <- fold from_comp res_ty res (map Var vars) comps
322           let scrut = unwrapFamInstScrut ptup_tc tys expr
323               body  = mkWildCase scrut (exprType scrut) res_ty
324                       [(DataAlt ptup_con, vars, res')]
325           return (body, args)
326       where
327         [ptup_con] = tyConDataCons ptup_tc
328
329     from_comp _ res expr (Keep _ _) = return (res, [expr])
330     from_comp _ res expr (Wrap ty)
331       = do
332           wrap_tc  <- builtin wrapTyCon
333           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
334           return (res, [unwrapNewTypeBody pwrap_tc [ty]
335                         $ unwrapFamInstScrut pwrap_tc [ty] expr])
336
337     fold f res_ty res exprs rs = foldrM f' (res, []) (zip exprs rs)
338       where
339         f' (expr, r) (res, args) = do
340                                      (res', args') <- f res_ty res expr r
341                                      return (res', args' ++ args)