2 module Vectorise.Type.PRepr
10 import Vectorise.Monad
11 import Vectorise.Builtins
12 import Vectorise.Type.Repr
15 import MkCore ( mkWildCase )
-- | Describe a family instance with exactly one type index: the family
-- tycon paired with the argument tycon saturated by its own type
-- variables.  'buildPReprTyCon' passes the result to 'buildSynTyCon'.
mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
mk_fam_inst fam_tc arg_tc = (fam_tc, [instHead])
  where
    -- T a1 .. an, where a1 .. an are T's own type variables.
    instHead = mkTyConApp arg_tc (mkTyVarTys (tyConTyVars arg_tc))
-- | Build the type-synonym tycon that serves as the PRepr family instance
-- for a vectorised tycon.
--
-- * The synonym's name is cloned from the original tycon's name with
--   'mkPReprTyConOcc'.
-- * The representation type is computed from the 'SumRepr' by
--   'sumReprType'.
-- * The family-instance specification handed to 'buildSynTyCon' is the
--   builtin PRepr family tycon applied to @vect_tc@ over its own type
--   variables (see 'mk_fam_inst').
--
-- NOTE(review): @rhs_ty@ and @tyvars@ are presumably the synonym's
-- right-hand side and binders passed to 'buildSynTyCon'; confirm.
33 buildPReprTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
34 buildPReprTyCon orig_tc vect_tc repr
36 name <- cloneName mkPReprTyConOcc (tyConName orig_tc)
-- Older approach, kept for reference: derive the type from vect_tc directly.
37 -- rhs_ty <- buildPReprType vect_tc
38 rhs_ty <- sumReprType repr
39 prepr_tc <- builtin preprTyCon
40 liftDs $ buildSynTyCon name
45 (Just $ mk_fam_inst prepr_tc vect_tc)
47 tyvars = tyConTyVars vect_tc
-- | Build the Core function that converts a value of the vectorised type
-- (@vect_tc@ applied to its type variables) into its PRepr representation.
--
-- The result has the form @\\x -> case x of ...@.  For each constructor
-- described by the 'SumRepr', the constructor's fields are rebuilt as the
-- representation:
--
-- * sums become data constructors of the representation's sum tycon;
-- * products become the single data constructor of its tuple tycon;
-- * 'Wrap' components are wrapped in the builtin Wrap newtype.
--
-- Every alternative's result is wrapped as a body of the @repr_tc@ family
-- instance.  The third 'TyCon' argument is ignored.
50 buildToPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
51 buildToPRepr vect_tc repr_tc _ repr
53 let arg_ty = mkTyConApp vect_tc ty_args
54 res_ty <- mkPReprType arg_ty
55 arg <- newLocalVar (fsLit "x") arg_ty
56 result <- to_sum (Var arg) arg_ty res_ty repr
57 return $ Lam arg result
59 ty_args = mkTyVarTys (tyConTyVars vect_tc)
61 wrap_repr_inst = wrapFamInstBody repr_tc ty_args
-- Yields the builtin void value, wrapped as a family-instance body.
65 void <- builtin voidVar
66 return $ wrap_repr_inst $ Var void
-- A single constructor: one case alternative and no sum constructor.
68 to_sum arg arg_ty res_ty (UnarySum r)
70 (pat, vars, body) <- con_alt r
71 return $ mkWildCase arg arg_ty res_ty
72 [(pat, vars, wrap_repr_inst body)]
-- Several constructors: each alternative's converted fields are injected
-- with the sum tycon's data constructor at the same position (via zip).
74 to_sum arg arg_ty res_ty (Sum { repr_sum_tc = sum_tc
78 alts <- mapM con_alt cons
79 let alts' = [(pat, vars, wrap_repr_inst
80 $ mkConApp sum_con (map Type tys ++ [body]))
81 | ((pat, vars, body), sum_con)
82 <- zip alts (tyConDataCons sum_tc)]
83 return $ mkWildCase arg arg_ty res_ty alts'
-- One case alternative: match the original constructor, binding its
-- fields to fresh variables, and return the converted body.
85 con_alt (ConRepr con r)
87 (vars, body) <- to_prod r
88 return (DataAlt con, vars, body)
92 void <- builtin voidVar
-- One field: bind it to a fresh variable of the component's original type.
95 to_prod (UnaryProd comp)
97 var <- newLocalVar (fsLit "x") (compOrigType comp)
98 body <- to_comp (Var var) comp
-- Several fields: bind each one and rebuild them with the tuple tycon's
-- data constructor.
101 to_prod(Prod { repr_tup_tc = tup_tc
102 , repr_comp_tys = tys
103 , repr_comps = comps })
105 vars <- newLocalVars (fsLit "x") (map compOrigType comps)
106 exprs <- zipWithM to_comp (map Var vars) comps
107 return (vars, mkConApp tup_con (map Type tys ++ exprs))
-- Assumes the tuple tycon has exactly one data constructor; the pattern
-- binding fails at runtime otherwise.
109 [tup_con] = tyConDataCons tup_tc
-- 'Keep' components pass through unchanged; 'Wrap' components are wrapped
-- in the builtin Wrap newtype.
111 to_comp expr (Keep _ _) = return expr
112 to_comp expr (Wrap ty) = do
113 wrap_tc <- builtin wrapTyCon
114 return $ wrapNewTypeBody wrap_tc [ty] expr
-- | Build the Core function that converts a PRepr representation back into
-- the vectorised type (@vect_tc@ applied to its type variables).
--
-- The lambda's argument has the PRepr type of the result.  It is unwrapped
-- from the @repr_tc@ family instance and then taken apart according to the
-- 'SumRepr':
--
-- * sum and tuple constructors are matched with case expressions;
-- * 'Wrap' components are unwrapped from the builtin Wrap newtype;
-- * the vectorised constructor, applied to the type arguments, is rebuilt
--   from the recovered fields.
--
-- The third 'TyCon' argument is ignored.
117 buildFromPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
118 buildFromPRepr vect_tc repr_tc _ repr
120 arg_ty <- mkPReprType res_ty
121 arg <- newLocalVar (fsLit "x") arg_ty
123 result <- from_sum (unwrapFamInstScrut repr_tc ty_args (Var arg))
125 return $ Lam arg result
127 ty_args = mkTyVarTys (tyConTyVars vect_tc)
128 res_ty = mkTyConApp vect_tc ty_args
-- Yields the builtin fromVoid value instantiated at the result type.
132 dummy <- builtin fromVoidVar
133 return $ Var dummy `App` Type res_ty
135 from_sum expr (UnarySum r) = from_con expr r
-- Several constructors: case over the sum tycon's data constructors,
-- binding one fresh variable per alternative and converting it with the
-- corresponding constructor representation.
136 from_sum expr (Sum { repr_sum_tc = sum_tc
138 , repr_cons = cons })
140 vars <- newLocalVars (fsLit "x") tys
141 es <- zipWithM from_con (map Var vars) cons
142 return $ mkWildCase expr (exprType expr) res_ty
143 [(DataAlt con, [var], e)
144 | (con, var, e) <- zip3 (tyConDataCons sum_tc) vars es]
-- Start from the vectorised constructor applied to the type arguments;
-- from_prod applies it to the recovered fields.
146 from_con expr (ConRepr con r)
147 = from_prod expr (mkConApp con $ map Type ty_args) r
149 from_prod _ con EmptyProd = return con
150 from_prod expr con (UnaryProd r)
152 e <- from_comp expr r
155 from_prod expr con (Prod { repr_tup_tc = tup_tc
156 , repr_comp_tys = tys
160 vars <- newLocalVars (fsLit "y") tys
161 es <- zipWithM from_comp (map Var vars) comps
162 return $ mkWildCase expr (exprType expr) res_ty
163 [(DataAlt tup_con, vars, con `mkApps` es)]
-- Assumes the tuple tycon has exactly one data constructor.
165 [tup_con] = tyConDataCons tup_tc
-- 'Keep' components pass through unchanged; 'Wrap' components are
-- unwrapped from the builtin Wrap newtype.
167 from_comp expr (Keep _ _) = return expr
168 from_comp expr (Wrap ty)
170 wrap <- builtin wrapTyCon
171 return $ unwrapNewTypeBody wrap [ty] expr
-- | Build the Core function that converts @PData@ of the vectorised type
-- into @PData@ of its PRepr type.
--
-- The argument is unwrapped from the @pdata_tc@ family instance and
-- matched against that tycon's single data constructor.  The fields bound
-- there are reassembled according to the 'SumRepr':
--
-- * sums use the @psum_tc@ family instance plus a selector variable;
-- * products use the @ptup_tc@ family instance;
-- * empty cases use the builtin pvoid value.
--
-- The reassembled result is cast with a coercion built from the builtin
-- PData coercion and the family coercion of @prepr_tc@.
174 buildToArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
175 buildToArrPRepr vect_tc prepr_tc pdata_tc r
177 arg_ty <- mkPDataType el_ty
178 res_ty <- mkPDataType =<< mkPReprType el_ty
179 arg <- newLocalVar (fsLit "xs") arg_ty
181 pdata_co <- mkBuiltinCo pdataTyCon
-- Assumes prepr_tc has a family coercion; the Just pattern fails at
-- runtime if it does not.
182 let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
183 co = mkAppCoercion pdata_co
185 $ mkTyConApp repr_co ty_args
187 scrut = unwrapFamInstScrut pdata_tc ty_args (Var arg)
-- to_sum returns the variables to bind in the PData pattern together with
-- the representation expression built from them.
189 (vars, result) <- to_sum r
192 $ mkWildCase scrut (mkTyConApp pdata_tc ty_args) res_ty
193 [(DataAlt pdata_dc, vars, mkCoerce co result)]
195 ty_args = mkTyVarTys $ tyConTyVars vect_tc
196 el_ty = mkTyConApp vect_tc ty_args
-- Assumes pdata_tc has exactly one data constructor.
198 [pdata_dc] = tyConDataCons pdata_tc
-- No binders; the result is the builtin pvoid value.
202 pvoid <- builtin pvoidVar
203 return ([], Var pvoid)
204 to_sum (UnarySum r) = to_con r
-- Several constructors: a fresh selector variable comes first in the
-- binder list, followed by every constructor's binders in order.
205 to_sum (Sum { repr_psum_tc = psum_tc
206 , repr_sel_ty = sel_ty
211 (vars, exprs) <- mapAndUnzipM to_con cons
212 sel <- newLocalVar (fsLit "sel") sel_ty
213 return (sel : concat vars, mk_result (Var sel) exprs)
215 [psum_con] = tyConDataCons psum_tc
216 mk_result sel exprs = wrapFamInstBody psum_tc tys
218 $ map Type tys ++ (sel : exprs)
220 to_con (ConRepr _ r) = to_prod r
222 to_prod EmptyProd = do
223 pvoid <- builtin pvoidVar
224 return ([], Var pvoid)
-- One field: bind a fresh variable of type PData of the component's
-- original type.
225 to_prod (UnaryProd r)
227 pty <- mkPDataType (compOrigType r)
228 var <- newLocalVar (fsLit "x") pty
229 expr <- to_comp (Var var) r
-- Several fields: one PData-typed variable per component, combined with
-- the ptup_tc family instance.
232 to_prod (Prod { repr_ptup_tc = ptup_tc
233 , repr_comp_tys = tys
234 , repr_comps = comps })
236 ptys <- mapM (mkPDataType . compOrigType) comps
237 vars <- newLocalVars (fsLit "x") ptys
238 es <- zipWithM to_comp (map Var vars) comps
239 return (vars, mk_result es)
241 [ptup_con] = tyConDataCons ptup_tc
242 mk_result exprs = wrapFamInstBody ptup_tc tys
244 $ map Type tys ++ exprs
-- 'Keep' components pass through unchanged.  'Wrap' components are wrapped
-- in the newtype of the PData representation tycon of (Wrap ty).
246 to_comp expr (Keep _ _) = return expr
248 -- FIXME: this is bound to be wrong!
249 to_comp expr (Wrap ty)
251 wrap_tc <- builtin wrapTyCon
252 (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
253 return $ wrapNewTypeBody pwrap_tc [ty] expr
-- | Build the Core function that converts @PData@ of the PRepr type back
-- into @PData@ of the vectorised type.
--
-- The argument is first cast with a coercion built from the builtin PData
-- coercion and the family coercion of @prepr_tc@.  It is then taken apart
-- by nested case expressions that follow the 'SumRepr'.  The innermost
-- expression is the @pdata_tc@ family-instance body ('mk_result'), built
-- from the leaf expressions collected while descending.
--
-- The leaf expressions are only known after the traversal, but the
-- innermost body must be supplied to 'from_sum' before it runs.  The
-- traversal's own output @args@ is therefore fed back in through 'fixV'
-- (presumably a monadic fixpoint).  The lazy pattern @~(_, args)@ keeps
-- that result from being forced too early.
256 buildFromArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
257 buildFromArrPRepr vect_tc prepr_tc pdata_tc r
259 arg_ty <- mkPDataType =<< mkPReprType el_ty
260 res_ty <- mkPDataType el_ty
261 arg <- newLocalVar (fsLit "xs") arg_ty
263 pdata_co <- mkBuiltinCo pdataTyCon
-- Assumes prepr_tc has a family coercion; the Just pattern fails at
-- runtime if it does not.
264 let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
265 co = mkAppCoercion pdata_co
266 $ mkTyConApp repr_co var_tys
268 scrut = mkCoerce co (Var arg)
270 mk_result args = wrapFamInstBody pdata_tc var_tys
272 $ map Type var_tys ++ args
274 (expr, _) <- fixV $ \ ~(_, args) ->
275 from_sum res_ty (mk_result args) scrut r
277 return $ Lam arg expr
-- Earlier non-fixpoint formulation, kept for reference:
279 -- (args, mk) <- from_sum res_ty scrut r
281 -- let result = wrapFamInstBody pdata_tc var_tys
282 -- . mkConApp pdata_dc
283 -- $ map Type var_tys ++ args
285 -- return $ Lam arg (mk result)
287 var_tys = mkTyVarTys $ tyConTyVars vect_tc
288 el_ty = mkTyConApp vect_tc var_tys
-- Assumes pdata_tc has exactly one data constructor.
290 [pdata_con] = tyConDataCons pdata_tc
-- Each helper takes the innermost result @res@ and returns the expression
-- built around it, together with the leaf expressions it contributes.
292 from_sum _ res _ EmptySum = return (res, [])
293 from_sum res_ty res expr (UnarySum r) = from_con res_ty res expr r
-- Several constructors: unwrap the psum_tc family instance and bind the
-- selector and one PData variable per constructor.  The selector is also
-- the first leaf expression contributed.
294 from_sum res_ty res expr (Sum { repr_psum_tc = psum_tc
295 , repr_sel_ty = sel_ty
297 , repr_cons = cons })
299 sel <- newLocalVar (fsLit "sel") sel_ty
300 ptys <- mapM mkPDataType tys
301 vars <- newLocalVars (fsLit "xs") ptys
302 (res', args) <- fold from_con res_ty res (map Var vars) cons
303 let scrut = unwrapFamInstScrut psum_tc tys expr
304 body = mkWildCase scrut (exprType scrut) res_ty
305 [(DataAlt psum_con, sel : vars, res')]
306 return (body, Var sel : args)
308 [psum_con] = tyConDataCons psum_tc
311 from_con res_ty res expr (ConRepr _ r) = from_prod res_ty res expr r
313 from_prod _ res _ EmptyProd = return (res, [])
314 from_prod res_ty res expr (UnaryProd r)
315 = from_comp res_ty res expr r
-- Several fields: unwrap the ptup_tc family instance and bind one PData
-- variable per component.
316 from_prod res_ty res expr (Prod { repr_ptup_tc = ptup_tc
317 , repr_comp_tys = tys
318 , repr_comps = comps })
320 ptys <- mapM mkPDataType tys
321 vars <- newLocalVars (fsLit "ys") ptys
322 (res', args) <- fold from_comp res_ty res (map Var vars) comps
323 let scrut = unwrapFamInstScrut ptup_tc tys expr
324 body = mkWildCase scrut (exprType scrut) res_ty
325 [(DataAlt ptup_con, vars, res')]
328 [ptup_con] = tyConDataCons ptup_tc
-- A 'Keep' component contributes its expression as a leaf.  A 'Wrap'
-- component is unwrapped from the PData representation of (Wrap ty),
-- first from the family instance and then from the newtype.
330 from_comp _ res expr (Keep _ _) = return (res, [expr])
331 from_comp _ res expr (Wrap ty)
333 wrap_tc <- builtin wrapTyCon
334 (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
335 return (res, [unwrapNewTypeBody pwrap_tc [ty]
336 $ unwrapFamInstScrut pwrap_tc [ty] expr])
-- Thread @res@ through the components with foldrM.  The last component is
-- processed first, so the first component's wrapping ends up outermost.
-- Leaf expressions are concatenated in component order (args' ++ args).
338 fold f res_ty res exprs rs = foldrM f' (res, []) (zip exprs rs)
340 f' (expr, r) (res, args) = do
341 (res', args') <- f res_ty res expr r
342 return (res', args' ++ args)