This BIG PATCH contains most of the work for the New Coercion Representation
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Type / PRepr.hs
1
2 module Vectorise.Type.PRepr
3         ( buildPReprTyCon
4         , buildToPRepr
5         , buildFromPRepr
6         , buildToArrPRepr
7         , buildFromArrPRepr)
8 where
9 import Vectorise.Utils
10 import Vectorise.Monad
11 import Vectorise.Builtins
12 import Vectorise.Type.Repr
13 import CoreSyn
14 import CoreUtils
15 import MkCore            ( mkWildCase )
16 import TyCon
17 import Type
18 import Kind
19 import BuildTyCl
20 import OccName
21 import Coercion
22 import MkId
23
24 import FastString
25 import MonadUtils
26 import Control.Monad
27
28
29 mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
30 mk_fam_inst fam_tc arg_tc
31   = (fam_tc, [mkTyConApp arg_tc . mkTyVarTys $ tyConTyVars arg_tc])
32
33
34 buildPReprTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
35 buildPReprTyCon orig_tc vect_tc repr
36   = do
37       name     <- cloneName mkPReprTyConOcc (tyConName orig_tc)
38       -- rhs_ty   <- buildPReprType vect_tc
39       rhs_ty   <- sumReprType repr
40       prepr_tc <- builtin preprTyCon
41       liftDs $ buildSynTyCon name
42                              tyvars
43                              (SynonymTyCon rhs_ty)
44                              (typeKind rhs_ty)
45                              NoParentTyCon
46                              (Just $ mk_fam_inst prepr_tc vect_tc)
47   where
48     tyvars = tyConTyVars vect_tc
49
50
51 buildToPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
52 buildToPRepr vect_tc repr_tc _ repr
53   = do
54       let arg_ty = mkTyConApp vect_tc ty_args
55       res_ty <- mkPReprType arg_ty
56       arg    <- newLocalVar (fsLit "x") arg_ty
57       result <- to_sum (Var arg) arg_ty res_ty repr
58       return $ Lam arg result
59   where
60     ty_args = mkTyVarTys (tyConTyVars vect_tc)
61
62     wrap_repr_inst = wrapFamInstBody repr_tc ty_args
63
64     to_sum _ _ _ EmptySum
65       = do
66           void <- builtin voidVar
67           return $ wrap_repr_inst $ Var void
68
69     to_sum arg arg_ty res_ty (UnarySum r)
70       = do
71           (pat, vars, body) <- con_alt r
72           return $ mkWildCase arg arg_ty res_ty
73                    [(pat, vars, wrap_repr_inst body)]
74
75     to_sum arg arg_ty res_ty (Sum { repr_sum_tc  = sum_tc
76                                   , repr_con_tys = tys
77                                   , repr_cons    =  cons })
78       = do
79           alts <- mapM con_alt cons
80           let alts' = [(pat, vars, wrap_repr_inst
81                                    $ mkConApp sum_con (map Type tys ++ [body]))
82                         | ((pat, vars, body), sum_con)
83                             <- zip alts (tyConDataCons sum_tc)]
84           return $ mkWildCase arg arg_ty res_ty alts'
85
86     con_alt (ConRepr con r)
87       = do
88           (vars, body) <- to_prod r
89           return (DataAlt con, vars, body)
90
91     to_prod EmptyProd
92       = do
93           void <- builtin voidVar
94           return ([], Var void)
95
96     to_prod (UnaryProd comp)
97       = do
98           var  <- newLocalVar (fsLit "x") (compOrigType comp)
99           body <- to_comp (Var var) comp
100           return ([var], body)
101
102     to_prod(Prod { repr_tup_tc   = tup_tc
103                  , repr_comp_tys = tys
104                  , repr_comps    = comps })
105       = do
106           vars  <- newLocalVars (fsLit "x") (map compOrigType comps)
107           exprs <- zipWithM to_comp (map Var vars) comps
108           return (vars, mkConApp tup_con (map Type tys ++ exprs))
109       where
110         [tup_con] = tyConDataCons tup_tc
111
112     to_comp expr (Keep _ _) = return expr
113     to_comp expr (Wrap ty)  = do
114                                 wrap_tc <- builtin wrapTyCon
115                                 return $ wrapNewTypeBody wrap_tc [ty] expr
116
117
118 buildFromPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
119 buildFromPRepr vect_tc repr_tc _ repr
120   = do
121       arg_ty <- mkPReprType res_ty
122       arg <- newLocalVar (fsLit "x") arg_ty
123
124       result <- from_sum (unwrapFamInstScrut repr_tc ty_args (Var arg))
125                          repr
126       return $ Lam arg result
127   where
128     ty_args = mkTyVarTys (tyConTyVars vect_tc)
129     res_ty  = mkTyConApp vect_tc ty_args
130
131     from_sum _ EmptySum
132       = do
133           dummy <- builtin fromVoidVar
134           return $ Var dummy `App` Type res_ty
135
136     from_sum expr (UnarySum r) = from_con expr r
137     from_sum expr (Sum { repr_sum_tc  = sum_tc
138                        , repr_con_tys = tys
139                        , repr_cons    = cons })
140       = do
141           vars  <- newLocalVars (fsLit "x") tys
142           es    <- zipWithM from_con (map Var vars) cons
143           return $ mkWildCase expr (exprType expr) res_ty
144                    [(DataAlt con, [var], e)
145                       | (con, var, e) <- zip3 (tyConDataCons sum_tc) vars es]
146
147     from_con expr (ConRepr con r)
148       = from_prod expr (mkConApp con $ map Type ty_args) r
149
150     from_prod _ con EmptyProd = return con
151     from_prod expr con (UnaryProd r)
152       = do
153           e <- from_comp expr r
154           return $ con `App` e
155      
156     from_prod expr con (Prod { repr_tup_tc   = tup_tc
157                              , repr_comp_tys = tys
158                              , repr_comps    = comps
159                              })
160       = do
161           vars <- newLocalVars (fsLit "y") tys
162           es   <- zipWithM from_comp (map Var vars) comps
163           return $ mkWildCase expr (exprType expr) res_ty
164                    [(DataAlt tup_con, vars, con `mkApps` es)]
165       where
166         [tup_con] = tyConDataCons tup_tc  
167
168     from_comp expr (Keep _ _) = return expr
169     from_comp expr (Wrap ty)
170       = do
171           wrap <- builtin wrapTyCon
172           return $ unwrapNewTypeBody wrap [ty] expr
173
174
175 buildToArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
176 buildToArrPRepr vect_tc prepr_tc pdata_tc r
177   = do
178       arg_ty <- mkPDataType el_ty
179       res_ty <- mkPDataType =<< mkPReprType el_ty
180       arg    <- newLocalVar (fsLit "xs") arg_ty
181
182       pdata_co <- mkBuiltinCo pdataTyCon
183       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
184           co           = mkAppCo pdata_co
185                        . mkSymCo
186                        $ mkAxInstCo repr_co ty_args
187
188           scrut   = unwrapFamInstScrut pdata_tc ty_args (Var arg)
189
190       (vars, result) <- to_sum r
191
192       return . Lam arg
193              $ mkWildCase scrut (mkTyConApp pdata_tc ty_args) res_ty
194                [(DataAlt pdata_dc, vars, mkCoerce co result)]
195   where
196     ty_args = mkTyVarTys $ tyConTyVars vect_tc
197     el_ty   = mkTyConApp vect_tc ty_args
198
199     [pdata_dc] = tyConDataCons pdata_tc
200
201
202     to_sum EmptySum = do
203                         pvoid <- builtin pvoidVar
204                         return ([], Var pvoid)
205     to_sum (UnarySum r) = to_con r
206     to_sum (Sum { repr_psum_tc = psum_tc
207                 , repr_sel_ty  = sel_ty
208                 , repr_con_tys = tys
209                 , repr_cons    = cons
210                 })
211       = do
212           (vars, exprs) <- mapAndUnzipM to_con cons
213           sel <- newLocalVar (fsLit "sel") sel_ty
214           return (sel : concat vars, mk_result (Var sel) exprs)
215       where
216         [psum_con] = tyConDataCons psum_tc
217         mk_result sel exprs = wrapFamInstBody psum_tc tys
218                             $ mkConApp psum_con
219                             $ map Type tys ++ (sel : exprs)
220
221     to_con (ConRepr _ r) = to_prod r
222
223     to_prod EmptyProd = do
224                           pvoid <- builtin pvoidVar
225                           return ([], Var pvoid)
226     to_prod (UnaryProd r)
227       = do
228           pty  <- mkPDataType (compOrigType r)
229           var  <- newLocalVar (fsLit "x") pty
230           expr <- to_comp (Var var) r
231           return ([var], expr)
232
233     to_prod (Prod { repr_ptup_tc  = ptup_tc
234                   , repr_comp_tys = tys
235                   , repr_comps    = comps })
236       = do
237           ptys <- mapM (mkPDataType . compOrigType) comps
238           vars <- newLocalVars (fsLit "x") ptys
239           es   <- zipWithM to_comp (map Var vars) comps
240           return (vars, mk_result es)
241       where
242         [ptup_con] = tyConDataCons ptup_tc
243         mk_result exprs = wrapFamInstBody ptup_tc tys
244                         $ mkConApp ptup_con
245                         $ map Type tys ++ exprs
246
247     to_comp expr (Keep _ _) = return expr
248
249     -- FIXME: this is bound to be wrong!
250     to_comp expr (Wrap ty)
251       = do
252           wrap_tc  <- builtin wrapTyCon
253           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
254           return $ wrapNewTypeBody pwrap_tc [ty] expr
255
256
257 buildFromArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
258 buildFromArrPRepr vect_tc prepr_tc pdata_tc r
259   = do
260       arg_ty <- mkPDataType =<< mkPReprType el_ty
261       res_ty <- mkPDataType el_ty
262       arg    <- newLocalVar (fsLit "xs") arg_ty
263
264       pdata_co <- mkBuiltinCo pdataTyCon
265       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
266           co           = mkAppCo pdata_co
267                        $ mkAxInstCo repr_co var_tys
268
269           scrut  = mkCoerce co (Var arg)
270
271           mk_result args = wrapFamInstBody pdata_tc var_tys
272                          $ mkConApp pdata_con
273                          $ map Type var_tys ++ args
274
275       (expr, _) <- fixV $ \ ~(_, args) ->
276                      from_sum res_ty (mk_result args) scrut r
277
278       return $ Lam arg expr
279     
280       -- (args, mk) <- from_sum res_ty scrut r
281       
282       -- let result = wrapFamInstBody pdata_tc var_tys
283       --           . mkConApp pdata_dc
284       --           $ map Type var_tys ++ args
285
286       -- return $ Lam arg (mk result)
287   where
288     var_tys = mkTyVarTys $ tyConTyVars vect_tc
289     el_ty   = mkTyConApp vect_tc var_tys
290
291     [pdata_con] = tyConDataCons pdata_tc
292
293     from_sum _ res _ EmptySum = return (res, [])
294     from_sum res_ty res expr (UnarySum r) = from_con res_ty res expr r
295     from_sum res_ty res expr (Sum { repr_psum_tc = psum_tc
296                                   , repr_sel_ty  = sel_ty
297                                   , repr_con_tys = tys
298                                   , repr_cons    = cons })
299       = do
300           sel  <- newLocalVar (fsLit "sel") sel_ty
301           ptys <- mapM mkPDataType tys
302           vars <- newLocalVars (fsLit "xs") ptys
303           (res', args) <- fold from_con res_ty res (map Var vars) cons
304           let scrut = unwrapFamInstScrut psum_tc tys expr
305               body  = mkWildCase scrut (exprType scrut) res_ty
306                       [(DataAlt psum_con, sel : vars, res')]
307           return (body, Var sel : args)
308       where
309         [psum_con] = tyConDataCons psum_tc
310
311
312     from_con res_ty res expr (ConRepr _ r) = from_prod res_ty res expr r
313
314     from_prod _ res _ EmptyProd = return (res, [])
315     from_prod res_ty res expr (UnaryProd r)
316       = from_comp res_ty res expr r
317     from_prod res_ty res expr (Prod { repr_ptup_tc  = ptup_tc
318                                     , repr_comp_tys = tys
319                                     , repr_comps    = comps })
320       = do
321           ptys <- mapM mkPDataType tys
322           vars <- newLocalVars (fsLit "ys") ptys
323           (res', args) <- fold from_comp res_ty res (map Var vars) comps
324           let scrut = unwrapFamInstScrut ptup_tc tys expr
325               body  = mkWildCase scrut (exprType scrut) res_ty
326                       [(DataAlt ptup_con, vars, res')]
327           return (body, args)
328       where
329         [ptup_con] = tyConDataCons ptup_tc
330
331     from_comp _ res expr (Keep _ _) = return (res, [expr])
332     from_comp _ res expr (Wrap ty)
333       = do
334           wrap_tc  <- builtin wrapTyCon
335           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
336           return (res, [unwrapNewTypeBody pwrap_tc [ty]
337                         $ unwrapFamInstScrut pwrap_tc [ty] expr])
338
339     fold f res_ty res exprs rs = foldrM f' (res, []) (zip exprs rs)
340       where
341         f' (expr, r) (res, args) = do
342                                      (res', args') <- f res_ty res expr r
343                                      return (res', args' ++ args)