Super-monster patch implementing the new typechecker -- at last
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Type / PRepr.hs
1
2 module Vectorise.Type.PRepr
3         ( buildPReprTyCon
4         , buildToPRepr
5         , buildFromPRepr
6         , buildToArrPRepr
7         , buildFromArrPRepr)
8 where
9 import Vectorise.Utils
10 import Vectorise.Monad
11 import Vectorise.Builtins
12 import Vectorise.Type.Repr
13 import CoreSyn
14 import CoreUtils
15 import MkCore            ( mkWildCase )
16 import TyCon
17 import Type
18 import BuildTyCl
19 import OccName
20 import Coercion
21 import MkId
22
23 import FastString
24 import MonadUtils
25 import Control.Monad
26
27
28 mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
29 mk_fam_inst fam_tc arg_tc
30   = (fam_tc, [mkTyConApp arg_tc . mkTyVarTys $ tyConTyVars arg_tc])
31
32
33 buildPReprTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
34 buildPReprTyCon orig_tc vect_tc repr
35   = do
36       name     <- cloneName mkPReprTyConOcc (tyConName orig_tc)
37       -- rhs_ty   <- buildPReprType vect_tc
38       rhs_ty   <- sumReprType repr
39       prepr_tc <- builtin preprTyCon
40       liftDs $ buildSynTyCon name
41                              tyvars
42                              (SynonymTyCon rhs_ty)
43                              (typeKind rhs_ty)
44                              NoParentTyCon
45                              (Just $ mk_fam_inst prepr_tc vect_tc)
46   where
47     tyvars = tyConTyVars vect_tc
48
49
50 buildToPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
51 buildToPRepr vect_tc repr_tc _ repr
52   = do
53       let arg_ty = mkTyConApp vect_tc ty_args
54       res_ty <- mkPReprType arg_ty
55       arg    <- newLocalVar (fsLit "x") arg_ty
56       result <- to_sum (Var arg) arg_ty res_ty repr
57       return $ Lam arg result
58   where
59     ty_args = mkTyVarTys (tyConTyVars vect_tc)
60
61     wrap_repr_inst = wrapFamInstBody repr_tc ty_args
62
63     to_sum _ _ _ EmptySum
64       = do
65           void <- builtin voidVar
66           return $ wrap_repr_inst $ Var void
67
68     to_sum arg arg_ty res_ty (UnarySum r)
69       = do
70           (pat, vars, body) <- con_alt r
71           return $ mkWildCase arg arg_ty res_ty
72                    [(pat, vars, wrap_repr_inst body)]
73
74     to_sum arg arg_ty res_ty (Sum { repr_sum_tc  = sum_tc
75                                   , repr_con_tys = tys
76                                   , repr_cons    =  cons })
77       = do
78           alts <- mapM con_alt cons
79           let alts' = [(pat, vars, wrap_repr_inst
80                                    $ mkConApp sum_con (map Type tys ++ [body]))
81                         | ((pat, vars, body), sum_con)
82                             <- zip alts (tyConDataCons sum_tc)]
83           return $ mkWildCase arg arg_ty res_ty alts'
84
85     con_alt (ConRepr con r)
86       = do
87           (vars, body) <- to_prod r
88           return (DataAlt con, vars, body)
89
90     to_prod EmptyProd
91       = do
92           void <- builtin voidVar
93           return ([], Var void)
94
95     to_prod (UnaryProd comp)
96       = do
97           var  <- newLocalVar (fsLit "x") (compOrigType comp)
98           body <- to_comp (Var var) comp
99           return ([var], body)
100
101     to_prod(Prod { repr_tup_tc   = tup_tc
102                  , repr_comp_tys = tys
103                  , repr_comps    = comps })
104       = do
105           vars  <- newLocalVars (fsLit "x") (map compOrigType comps)
106           exprs <- zipWithM to_comp (map Var vars) comps
107           return (vars, mkConApp tup_con (map Type tys ++ exprs))
108       where
109         [tup_con] = tyConDataCons tup_tc
110
111     to_comp expr (Keep _ _) = return expr
112     to_comp expr (Wrap ty)  = do
113                                 wrap_tc <- builtin wrapTyCon
114                                 return $ wrapNewTypeBody wrap_tc [ty] expr
115
116
117 buildFromPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
118 buildFromPRepr vect_tc repr_tc _ repr
119   = do
120       arg_ty <- mkPReprType res_ty
121       arg <- newLocalVar (fsLit "x") arg_ty
122
123       result <- from_sum (unwrapFamInstScrut repr_tc ty_args (Var arg))
124                          repr
125       return $ Lam arg result
126   where
127     ty_args = mkTyVarTys (tyConTyVars vect_tc)
128     res_ty  = mkTyConApp vect_tc ty_args
129
130     from_sum _ EmptySum
131       = do
132           dummy <- builtin fromVoidVar
133           return $ Var dummy `App` Type res_ty
134
135     from_sum expr (UnarySum r) = from_con expr r
136     from_sum expr (Sum { repr_sum_tc  = sum_tc
137                        , repr_con_tys = tys
138                        , repr_cons    = cons })
139       = do
140           vars  <- newLocalVars (fsLit "x") tys
141           es    <- zipWithM from_con (map Var vars) cons
142           return $ mkWildCase expr (exprType expr) res_ty
143                    [(DataAlt con, [var], e)
144                       | (con, var, e) <- zip3 (tyConDataCons sum_tc) vars es]
145
146     from_con expr (ConRepr con r)
147       = from_prod expr (mkConApp con $ map Type ty_args) r
148
149     from_prod _ con EmptyProd = return con
150     from_prod expr con (UnaryProd r)
151       = do
152           e <- from_comp expr r
153           return $ con `App` e
154      
155     from_prod expr con (Prod { repr_tup_tc   = tup_tc
156                              , repr_comp_tys = tys
157                              , repr_comps    = comps
158                              })
159       = do
160           vars <- newLocalVars (fsLit "y") tys
161           es   <- zipWithM from_comp (map Var vars) comps
162           return $ mkWildCase expr (exprType expr) res_ty
163                    [(DataAlt tup_con, vars, con `mkApps` es)]
164       where
165         [tup_con] = tyConDataCons tup_tc  
166
167     from_comp expr (Keep _ _) = return expr
168     from_comp expr (Wrap ty)
169       = do
170           wrap <- builtin wrapTyCon
171           return $ unwrapNewTypeBody wrap [ty] expr
172
173
174 buildToArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
175 buildToArrPRepr vect_tc prepr_tc pdata_tc r
176   = do
177       arg_ty <- mkPDataType el_ty
178       res_ty <- mkPDataType =<< mkPReprType el_ty
179       arg    <- newLocalVar (fsLit "xs") arg_ty
180
181       pdata_co <- mkBuiltinCo pdataTyCon
182       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
183           co           = mkAppCoercion pdata_co
184                        . mkSymCoercion
185                        $ mkTyConApp repr_co ty_args
186
187           scrut   = unwrapFamInstScrut pdata_tc ty_args (Var arg)
188
189       (vars, result) <- to_sum r
190
191       return . Lam arg
192              $ mkWildCase scrut (mkTyConApp pdata_tc ty_args) res_ty
193                [(DataAlt pdata_dc, vars, mkCoerce co result)]
194   where
195     ty_args = mkTyVarTys $ tyConTyVars vect_tc
196     el_ty   = mkTyConApp vect_tc ty_args
197
198     [pdata_dc] = tyConDataCons pdata_tc
199
200
201     to_sum EmptySum = do
202                         pvoid <- builtin pvoidVar
203                         return ([], Var pvoid)
204     to_sum (UnarySum r) = to_con r
205     to_sum (Sum { repr_psum_tc = psum_tc
206                 , repr_sel_ty  = sel_ty
207                 , repr_con_tys = tys
208                 , repr_cons    = cons
209                 })
210       = do
211           (vars, exprs) <- mapAndUnzipM to_con cons
212           sel <- newLocalVar (fsLit "sel") sel_ty
213           return (sel : concat vars, mk_result (Var sel) exprs)
214       where
215         [psum_con] = tyConDataCons psum_tc
216         mk_result sel exprs = wrapFamInstBody psum_tc tys
217                             $ mkConApp psum_con
218                             $ map Type tys ++ (sel : exprs)
219
220     to_con (ConRepr _ r) = to_prod r
221
222     to_prod EmptyProd = do
223                           pvoid <- builtin pvoidVar
224                           return ([], Var pvoid)
225     to_prod (UnaryProd r)
226       = do
227           pty  <- mkPDataType (compOrigType r)
228           var  <- newLocalVar (fsLit "x") pty
229           expr <- to_comp (Var var) r
230           return ([var], expr)
231
232     to_prod (Prod { repr_ptup_tc  = ptup_tc
233                   , repr_comp_tys = tys
234                   , repr_comps    = comps })
235       = do
236           ptys <- mapM (mkPDataType . compOrigType) comps
237           vars <- newLocalVars (fsLit "x") ptys
238           es   <- zipWithM to_comp (map Var vars) comps
239           return (vars, mk_result es)
240       where
241         [ptup_con] = tyConDataCons ptup_tc
242         mk_result exprs = wrapFamInstBody ptup_tc tys
243                         $ mkConApp ptup_con
244                         $ map Type tys ++ exprs
245
246     to_comp expr (Keep _ _) = return expr
247
248     -- FIXME: this is bound to be wrong!
249     to_comp expr (Wrap ty)
250       = do
251           wrap_tc  <- builtin wrapTyCon
252           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
253           return $ wrapNewTypeBody pwrap_tc [ty] expr
254
255
256 buildFromArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
257 buildFromArrPRepr vect_tc prepr_tc pdata_tc r
258   = do
259       arg_ty <- mkPDataType =<< mkPReprType el_ty
260       res_ty <- mkPDataType el_ty
261       arg    <- newLocalVar (fsLit "xs") arg_ty
262
263       pdata_co <- mkBuiltinCo pdataTyCon
264       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
265           co           = mkAppCoercion pdata_co
266                        $ mkTyConApp repr_co var_tys
267
268           scrut  = mkCoerce co (Var arg)
269
270           mk_result args = wrapFamInstBody pdata_tc var_tys
271                          $ mkConApp pdata_con
272                          $ map Type var_tys ++ args
273
274       (expr, _) <- fixV $ \ ~(_, args) ->
275                      from_sum res_ty (mk_result args) scrut r
276
277       return $ Lam arg expr
278     
279       -- (args, mk) <- from_sum res_ty scrut r
280       
281       -- let result = wrapFamInstBody pdata_tc var_tys
282       --           . mkConApp pdata_dc
283       --           $ map Type var_tys ++ args
284
285       -- return $ Lam arg (mk result)
286   where
287     var_tys = mkTyVarTys $ tyConTyVars vect_tc
288     el_ty   = mkTyConApp vect_tc var_tys
289
290     [pdata_con] = tyConDataCons pdata_tc
291
292     from_sum _ res _ EmptySum = return (res, [])
293     from_sum res_ty res expr (UnarySum r) = from_con res_ty res expr r
294     from_sum res_ty res expr (Sum { repr_psum_tc = psum_tc
295                                   , repr_sel_ty  = sel_ty
296                                   , repr_con_tys = tys
297                                   , repr_cons    = cons })
298       = do
299           sel  <- newLocalVar (fsLit "sel") sel_ty
300           ptys <- mapM mkPDataType tys
301           vars <- newLocalVars (fsLit "xs") ptys
302           (res', args) <- fold from_con res_ty res (map Var vars) cons
303           let scrut = unwrapFamInstScrut psum_tc tys expr
304               body  = mkWildCase scrut (exprType scrut) res_ty
305                       [(DataAlt psum_con, sel : vars, res')]
306           return (body, Var sel : args)
307       where
308         [psum_con] = tyConDataCons psum_tc
309
310
311     from_con res_ty res expr (ConRepr _ r) = from_prod res_ty res expr r
312
313     from_prod _ res _ EmptyProd = return (res, [])
314     from_prod res_ty res expr (UnaryProd r)
315       = from_comp res_ty res expr r
316     from_prod res_ty res expr (Prod { repr_ptup_tc  = ptup_tc
317                                     , repr_comp_tys = tys
318                                     , repr_comps    = comps })
319       = do
320           ptys <- mapM mkPDataType tys
321           vars <- newLocalVars (fsLit "ys") ptys
322           (res', args) <- fold from_comp res_ty res (map Var vars) comps
323           let scrut = unwrapFamInstScrut ptup_tc tys expr
324               body  = mkWildCase scrut (exprType scrut) res_ty
325                       [(DataAlt ptup_con, vars, res')]
326           return (body, args)
327       where
328         [ptup_con] = tyConDataCons ptup_tc
329
330     from_comp _ res expr (Keep _ _) = return (res, [expr])
331     from_comp _ res expr (Wrap ty)
332       = do
333           wrap_tc  <- builtin wrapTyCon
334           (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
335           return (res, [unwrapNewTypeBody pwrap_tc [ty]
336                         $ unwrapFamInstScrut pwrap_tc [ty] expr])
337
338     fold f res_ty res exprs rs = foldrM f' (res, []) (zip exprs rs)
339       where
340         f' (expr, r) (res, args) = do
341                                      (res', args') <- f res_ty res expr r
342                                      return (res', args' ++ args)