2 module Vectorise.Type.PRepr
10 import Vectorise.Monad
11 import Vectorise.Builtins
12 import Vectorise.Type.Repr
15 import MkCore ( mkWildCase )
-- | Build a family-instance head: the family tycon @fam_tc@ paired with a
-- single type argument, namely @arg_tc@ applied to its own type variables
-- (the fully saturated type @arg_tc a1 .. an@).
28 mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
29 mk_fam_inst fam_tc arg_tc
30 = (fam_tc, [mkTyConApp arg_tc . mkTyVarTys $ tyConTyVars arg_tc])
-- | Build a type synonym for the vectorised tycon @vect_tc@ whose optional
-- family-instance head (the 'Just' argument to 'buildSynTyCon') is the
-- builtin 'preprTyCon' applied to @vect_tc@'s saturated type, as built by
-- 'mk_fam_inst'.
--
-- * The synonym's name is cloned from @orig_tc@'s name via 'mkPReprTyConOcc'.
-- * The right-hand side type is computed from @repr@ by 'sumReprType'.
--
-- NOTE(review): this excerpt omits several lines of the definition (the
-- embedded line numbers jump, e.g. 34 -> 36 and 40 -> 44), so how @rhs_ty@
-- and @tyvars@ are passed to 'buildSynTyCon' is not visible here.
33 buildPReprTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
34 buildPReprTyCon orig_tc vect_tc repr
36 name <- cloneName mkPReprTyConOcc (tyConName orig_tc)
37 -- rhs_ty <- buildPReprType vect_tc
-- The line above is a disabled alternative. The active code below derives
-- the right-hand side from the SumRepr rather than from vect_tc.
38 rhs_ty <- sumReprType repr
39 prepr_tc <- builtin preprTyCon
40 liftDs $ buildSynTyCon name
44 (Just $ mk_fam_inst prepr_tc vect_tc)
-- Type variables of the vectorised tycon.
46 tyvars = tyConTyVars vect_tc
-- | Generate the Core conversion from a vectorised type to its PRepr
-- representation. The result is a lambda taking @x :: vect_tc tys@ whose
-- body has type @mkPReprType (vect_tc tys)@.
--
-- * @vect_tc@ is the vectorised tycon.
-- * @repr_tc@ is the tycon that 'wrapFamInstBody' uses to wrap each result
--   into the representation family instance.
-- * @repr@ describes the shape of the representation.
-- * The third argument is ignored.
--
-- NOTE(review): this excerpt omits some lines of the definition (the
-- embedded line numbers are not contiguous). Comments below that say
-- "presumably" refer to cases whose header lines are not visible.
49 buildToPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
50 buildToPRepr vect_tc repr_tc _ repr
52 let arg_ty = mkTyConApp vect_tc ty_args
53 res_ty <- mkPReprType arg_ty
54 arg <- newLocalVar (fsLit "x") arg_ty
55 result <- to_sum (Var arg) arg_ty res_ty repr
56 return $ Lam arg result
-- The vectorised tycon's own type variables, used as the type arguments.
58 ty_args = mkTyVarTys (tyConTyVars vect_tc)
-- Wrap an expression into the repr_tc family instance at ty_args.
60 wrap_repr_inst = wrapFamInstBody repr_tc ty_args
-- Presumably the EmptySum case (its header is not visible). The result is
-- the builtin 'voidVar', wrapped into the representation instance.
64 void <- builtin voidVar
65 return $ wrap_repr_inst $ Var void
-- Single-constructor sum: a case with one alternative, whose body is
-- wrapped into the representation instance.
67 to_sum arg arg_ty res_ty (UnarySum r)
69 (pat, vars, body) <- con_alt r
70 return $ mkWildCase arg arg_ty res_ty
71 [(pat, vars, wrap_repr_inst body)]
-- Multi-constructor sum. The alternatives built from @cons@ are paired
-- positionally with the data constructors of @sum_tc@. Each body is
-- injected with the matching @sum_con@ (applied to type args @tys@) and
-- then wrapped.
-- NOTE(review): 'zip' silently truncates if the two lists differ in length.
73 to_sum arg arg_ty res_ty (Sum { repr_sum_tc = sum_tc
77 alts <- mapM con_alt cons
78 let alts' = [(pat, vars, wrap_repr_inst
79 $ mkConApp sum_con (map Type tys ++ [body]))
80 | ((pat, vars, body), sum_con)
81 <- zip alts (tyConDataCons sum_tc)]
82 return $ mkWildCase arg arg_ty res_ty alts'
-- Case alternative for one constructor: match on @con@ and bind the
-- variables produced by to_prod.
84 con_alt (ConRepr con r)
86 (vars, body) <- to_prod r
87 return (DataAlt con, vars, body)
-- Presumably part of the EmptyProd case of to_prod (its header is not
-- visible); it uses the builtin 'voidVar'.
91 void <- builtin voidVar
-- Single-component product: one fresh binder of the component's original
-- type, converted by to_comp.
94 to_prod (UnaryProd comp)
96 var <- newLocalVar (fsLit "x") (compOrigType comp)
97 body <- to_comp (Var var) comp
-- Multi-component product: one fresh binder per component. The converted
-- components are packed with the single data constructor of @tup_tc@,
-- applied to type args @tys@.
100 to_prod(Prod { repr_tup_tc = tup_tc
101 , repr_comp_tys = tys
102 , repr_comps = comps })
104 vars <- newLocalVars (fsLit "x") (map compOrigType comps)
105 exprs <- zipWithM to_comp (map Var vars) comps
106 return (vars, mkConApp tup_con (map Type tys ++ exprs))
-- NOTE(review): refutable pattern binding. It fails when tup_con is first
-- demanded if @tup_tc@ does not have exactly one data constructor.
108 [tup_con] = tyConDataCons tup_tc
-- Components: 'Keep' passes the expression through unchanged. 'Wrap' wraps
-- it in the builtin 'wrapTyCon' newtype at type @ty@.
110 to_comp expr (Keep _ _) = return expr
111 to_comp expr (Wrap ty) = do
112 wrap_tc <- builtin wrapTyCon
113 return $ wrapNewTypeBody wrap_tc [ty] expr
-- | Generate the Core conversion from a PRepr representation back to the
-- vectorised type. The result is a lambda taking
-- @x :: mkPReprType (vect_tc tys)@ and returning a value of type
-- @vect_tc tys@.
--
-- * The argument is first unwrapped from the @repr_tc@ family instance via
--   'unwrapFamInstScrut'.
-- * @repr@ describes the representation shape.
-- * The third argument is ignored.
--
-- NOTE(review): this excerpt omits some lines of the definition (the
-- embedded line numbers are not contiguous).
116 buildFromPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
117 buildFromPRepr vect_tc repr_tc _ repr
119 arg_ty <- mkPReprType res_ty
120 arg <- newLocalVar (fsLit "x") arg_ty
122 result <- from_sum (unwrapFamInstScrut repr_tc ty_args (Var arg))
124 return $ Lam arg result
-- The vectorised tycon's type variables, and the saturated result type.
126 ty_args = mkTyVarTys (tyConTyVars vect_tc)
127 res_ty = mkTyConApp vect_tc ty_args
-- Presumably the EmptySum case (its header is not visible). The result is
-- the builtin 'fromVoidVar' instantiated at @res_ty@.
131 dummy <- builtin fromVoidVar
132 return $ Var dummy `App` Type res_ty
-- Single-constructor sum: convert the constructor directly.
134 from_sum expr (UnarySum r) = from_con expr r
-- Multi-constructor sum: a case on the sum value with one alternative per
-- data constructor of @sum_tc@. Each alternative binds exactly one field,
-- which from_con converts.
-- NOTE(review): 'zip3' silently truncates on a length mismatch.
135 from_sum expr (Sum { repr_sum_tc = sum_tc
137 , repr_cons = cons })
139 vars <- newLocalVars (fsLit "x") tys
140 es <- zipWithM from_con (map Var vars) cons
141 return $ mkWildCase expr (exprType expr) res_ty
142 [(DataAlt con, [var], e)
143 | (con, var, e) <- zip3 (tyConDataCons sum_tc) vars es]
-- Rebuild one constructor: start from @con@ applied to the type arguments,
-- then let from_prod supply its fields.
145 from_con expr (ConRepr con r)
146 = from_prod expr (mkConApp con $ map Type ty_args) r
-- Empty product: the (type-applied) constructor on its own.
148 from_prod _ con EmptyProd = return con
-- Single component: convert it with from_comp. The line that combines it
-- with @con@ is not visible in this excerpt.
149 from_prod expr con (UnaryProd r)
151 e <- from_comp expr r
-- Multi-component product: take the tuple apart with a case on the single
-- @tup_tc@ constructor, then apply @con@ to the converted components.
154 from_prod expr con (Prod { repr_tup_tc = tup_tc
155 , repr_comp_tys = tys
159 vars <- newLocalVars (fsLit "y") tys
160 es <- zipWithM from_comp (map Var vars) comps
161 return $ mkWildCase expr (exprType expr) res_ty
162 [(DataAlt tup_con, vars, con `mkApps` es)]
-- NOTE(review): refutable pattern binding. It fails when tup_con is first
-- demanded if @tup_tc@ does not have exactly one data constructor.
164 [tup_con] = tyConDataCons tup_tc
-- Components: 'Keep' passes the expression through unchanged. 'Wrap'
-- unwraps it from the builtin 'wrapTyCon' newtype at type @ty@.
166 from_comp expr (Keep _ _) = return expr
167 from_comp expr (Wrap ty)
169 wrap <- builtin wrapTyCon
170 return $ unwrapNewTypeBody wrap [ty] expr
-- | Generate the Core conversion from the PData array of a vectorised type to
-- the PData array of its PRepr representation.
--
-- * The argument has type @PData (vect_tc tys)@ (via 'mkPDataType').
-- * The result has type @PData (PRepr (vect_tc tys))@ (via 'mkPDataType' and
--   'mkPReprType').
--
-- The argument is unwrapped from the @pdata_tc@ family instance and matched
-- against the single data constructor of @pdata_tc@. The body built by
-- to_sum is then cast with the coercion @co@, which is built from the builtin
-- 'pdataTyCon' coercion and @prepr_tc@'s family coercion.
--
-- NOTE(review): this excerpt omits some lines of the definition (the
-- embedded line numbers are not contiguous), including part of @co@ and the
-- head of the final return expression.
173 buildToArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
174 buildToArrPRepr vect_tc prepr_tc pdata_tc r
176 arg_ty <- mkPDataType el_ty
177 res_ty <- mkPDataType =<< mkPReprType el_ty
178 arg <- newLocalVar (fsLit "xs") arg_ty
180 pdata_co <- mkBuiltinCo pdataTyCon
-- NOTE(review): refutable lazy binding. It fails when repr_co is demanded
-- if @prepr_tc@ has no family coercion.
181 let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
182 co = mkAppCoercion pdata_co
184 $ mkTyConApp repr_co ty_args
186 scrut = unwrapFamInstScrut pdata_tc ty_args (Var arg)
188 (vars, result) <- to_sum r
191 $ mkWildCase scrut (mkTyConApp pdata_tc ty_args) res_ty
192 [(DataAlt pdata_dc, vars, mkCoerce co result)]
-- The vectorised tycon's type variables, and the saturated element type.
194 ty_args = mkTyVarTys $ tyConTyVars vect_tc
195 el_ty = mkTyConApp vect_tc ty_args
-- NOTE(review): refutable pattern binding; assumes @pdata_tc@ has exactly
-- one data constructor.
197 [pdata_dc] = tyConDataCons pdata_tc
-- Presumably the EmptySum case (its header is not visible). There are no
-- binders, and the body is the builtin 'pvoidVar'.
201 pvoid <- builtin pvoidVar
202 return ([], Var pvoid)
203 to_sum (UnarySum r) = to_con r
-- Multi-constructor sum. It binds a fresh selector of type @sel_ty@ plus
-- the binders of every constructor. The result applies @psum_tc@'s data
-- constructor to the selector and converted components (presumably via
-- @psum_con@; that line is not visible) and wraps it into the @psum_tc@
-- family instance.
204 to_sum (Sum { repr_psum_tc = psum_tc
205 , repr_sel_ty = sel_ty
210 (vars, exprs) <- mapAndUnzipM to_con cons
211 sel <- newLocalVar (fsLit "sel") sel_ty
212 return (sel : concat vars, mk_result (Var sel) exprs)
214 [psum_con] = tyConDataCons psum_tc
215 mk_result sel exprs = wrapFamInstBody psum_tc tys
217 $ map Type tys ++ (sel : exprs)
-- The constructor itself is not used here; only its product representation
-- is converted.
219 to_con (ConRepr _ r) = to_prod r
221 to_prod EmptyProd = do
222 pvoid <- builtin pvoidVar
223 return ([], Var pvoid)
-- Single component: one fresh binder of type PData of the component's
-- original type.
224 to_prod (UnaryProd r)
226 pty <- mkPDataType (compOrigType r)
227 var <- newLocalVar (fsLit "x") pty
228 expr <- to_comp (Var var) r
-- Multi-component product: one fresh PData binder per component. The
-- converted components are wrapped into the @ptup_tc@ family instance.
231 to_prod (Prod { repr_ptup_tc = ptup_tc
232 , repr_comp_tys = tys
233 , repr_comps = comps })
235 ptys <- mapM (mkPDataType . compOrigType) comps
236 vars <- newLocalVars (fsLit "x") ptys
237 es <- zipWithM to_comp (map Var vars) comps
238 return (vars, mk_result es)
240 [ptup_con] = tyConDataCons ptup_tc
241 mk_result exprs = wrapFamInstBody ptup_tc tys
243 $ map Type tys ++ exprs
-- Components: 'Keep' passes the expression through unchanged.
245 to_comp expr (Keep _ _) = return expr
247 -- FIXME: this is bound to be wrong!
-- NOTE(review): the FIXME above gives no reason. Compare the inverse case in
-- buildFromArrPRepr's from_comp. That case applies both unwrapFamInstScrut
-- and unwrapNewTypeBody with @pwrap_tc@, whereas this case applies only
-- wrapNewTypeBody and no wrapFamInstBody. Possibly related; confirm.
248 to_comp expr (Wrap ty)
250 wrap_tc <- builtin wrapTyCon
251 (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
252 return $ wrapNewTypeBody pwrap_tc [ty] expr
-- | Generate the Core conversion from the PData array of a PRepr
-- representation back to the PData array of the vectorised type.
--
-- * The argument has type @PData (PRepr (vect_tc tys))@.
-- * The result has type @PData (vect_tc tys)@.
--
-- The argument is first cast with @co@, which is built from the builtin
-- 'pdataTyCon' coercion and @prepr_tc@'s family coercion. It is then taken
-- apart by the nested cases that from_sum / from_prod produce. The innermost
-- expression is the final result built by mk_result, whose arguments are
-- exactly the expressions collected during that traversal; 'fixV' ties this
-- knot.
--
-- NOTE(review): this excerpt omits some lines of the definition (the
-- embedded line numbers are not contiguous).
255 buildFromArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
256 buildFromArrPRepr vect_tc prepr_tc pdata_tc r
258 arg_ty <- mkPDataType =<< mkPReprType el_ty
259 res_ty <- mkPDataType el_ty
260 arg <- newLocalVar (fsLit "xs") arg_ty
262 pdata_co <- mkBuiltinCo pdataTyCon
-- NOTE(review): refutable lazy binding. It fails when repr_co is demanded
-- if @prepr_tc@ has no family coercion.
263 let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
264 co = mkAppCoercion pdata_co
265 $ mkTyConApp repr_co var_tys
267 scrut = mkCoerce co (Var arg)
-- Build the final PData value from the collected argument expressions.
269 mk_result args = wrapFamInstBody pdata_tc var_tys
271 $ map Type var_tys ++ args
-- The lazy pattern ~(_, args) lets the innermost result refer to the
-- argument list that from_sum itself returns. This relies on from_sum not
-- forcing @res@ while building that list; the visible cases only thread
-- @res@ through without inspecting it.
273 (expr, _) <- fixV $ \ ~(_, args) ->
274 from_sum res_ty (mk_result args) scrut r
276 return $ Lam arg expr
-- Disabled earlier version (partly elided in this excerpt), kept for
-- reference.
278 -- (args, mk) <- from_sum res_ty scrut r
280 -- let result = wrapFamInstBody pdata_tc var_tys
281 -- . mkConApp pdata_dc
282 -- $ map Type var_tys ++ args
284 -- return $ Lam arg (mk result)
-- The vectorised tycon's type variables, and the saturated element type.
286 var_tys = mkTyVarTys $ tyConTyVars vect_tc
287 el_ty = mkTyConApp vect_tc var_tys
-- NOTE(review): refutable pattern binding; assumes @pdata_tc@ has exactly
-- one data constructor. Presumably used by mk_result on a line not visible
-- in this excerpt.
289 [pdata_con] = tyConDataCons pdata_tc
-- Shared interface of from_sum, from_con, from_prod and from_comp.
--   Arguments: the result type, the expression to place innermost (@res@),
--   the expression to scrutinise, and a representation.
--   Result: the (possibly case-wrapped) expression, paired with the argument
--   expressions collected for mk_result.
291 from_sum _ res _ EmptySum = return (res, [])
292 from_sum res_ty res expr (UnarySum r) = from_con res_ty res expr r
-- Multi-constructor sum: a case on the unwrapped @psum_tc@ value that binds
-- the selector and one PData variable per type in @tys@. The selector is
-- prepended to the collected arguments.
293 from_sum res_ty res expr (Sum { repr_psum_tc = psum_tc
294 , repr_sel_ty = sel_ty
296 , repr_cons = cons })
298 sel <- newLocalVar (fsLit "sel") sel_ty
299 ptys <- mapM mkPDataType tys
300 vars <- newLocalVars (fsLit "xs") ptys
301 (res', args) <- fold from_con res_ty res (map Var vars) cons
302 let scrut = unwrapFamInstScrut psum_tc tys expr
303 body = mkWildCase scrut (exprType scrut) res_ty
304 [(DataAlt psum_con, sel : vars, res')]
305 return (body, Var sel : args)
307 [psum_con] = tyConDataCons psum_tc
-- The constructor itself is not used; only its product representation is.
310 from_con res_ty res expr (ConRepr _ r) = from_prod res_ty res expr r
312 from_prod _ res _ EmptyProd = return (res, [])
313 from_prod res_ty res expr (UnaryProd r)
314 = from_comp res_ty res expr r
-- Multi-component product: a case on the unwrapped @ptup_tc@ value that
-- binds one PData variable per component type.
315 from_prod res_ty res expr (Prod { repr_ptup_tc = ptup_tc
316 , repr_comp_tys = tys
317 , repr_comps = comps })
319 ptys <- mapM mkPDataType tys
320 vars <- newLocalVars (fsLit "ys") ptys
321 (res', args) <- fold from_comp res_ty res (map Var vars) comps
322 let scrut = unwrapFamInstScrut ptup_tc tys expr
323 body = mkWildCase scrut (exprType scrut) res_ty
324 [(DataAlt ptup_con, vars, res')]
327 [ptup_con] = tyConDataCons ptup_tc
-- Components: 'Keep' contributes the expression itself as an argument.
-- 'Wrap' contributes it after unwrapping both the @pwrap_tc@ family
-- instance and the @pwrap_tc@ newtype at type @ty@.
329 from_comp _ res expr (Keep _ _) = return (res, [expr])
330 from_comp _ res expr (Wrap ty)
332 wrap_tc <- builtin wrapTyCon
333 (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
334 return (res, [unwrapNewTypeBody pwrap_tc [ty]
335 $ unwrapFamInstScrut pwrap_tc [ty] expr])
-- Fold a per-element converter @f@ over @exprs@ paired with @rs@ using
-- 'foldrM'.
--   The last pair is processed first, against the original @res@, so any
--   case introduced for the first pair encloses those introduced for later
--   pairs.
--   The collected arguments stay in left-to-right order.
-- NOTE(review): 'zip' silently truncates on a length mismatch.
337 fold f res_ty res exprs rs = foldrM f' (res, []) (zip exprs rs)
339 f' (expr, r) (res, args) = do
340 (res', args') <- f res_ty res expr r
341 return (res', args' ++ args)