2 module Vectorise.Type.PADict
10 import Vectorise.Builtins
11 import Vectorise.Type.Repr
14 import MkCore ( mkWildCase )
-- | Build the Core conversion function from a value of the vectorised type
-- to its generic representation.
--
-- The result is a lambda @\x -> ...@.  The binder @x@ has type @vect_tc@
-- applied to its own type variables.  The body has the type that
-- 'mkPReprType' returns for that type.  The body case-analyses @x@ following
-- the sum-of-products shape described by the 'SumRepr'.
--
-- Parameters:
--   * @vect_tc@ is the vectorised type constructor.
--   * @repr_tc@ is the tycon passed to 'wrapFamInstBody' when wrapping each
--     result.
--   * The third 'TyCon' is ignored.
--   * @repr@ describes the representation shape.
--
-- NOTE(review): this listing omits some source lines (e.g. @= do@ lines and
-- some clause heads).  The comments below describe only what is visible.
25 buildToPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
26 buildToPRepr vect_tc repr_tc _ repr
28 let arg_ty = mkTyConApp vect_tc ty_args
29 res_ty <- mkPReprType arg_ty
-- Lambda binder for the value being converted.
30 arg <- newLocalVar (fsLit "x") arg_ty
31 result <- to_sum (Var arg) arg_ty res_ty repr
32 return $ Lam arg result
-- The vectorised type's own type variables, used as type arguments throughout.
34 ty_args = mkTyVarTys (tyConTyVars vect_tc)
-- Wrap a representation-level body as the family instance @repr_tc ty_args@.
-- NOTE(review): what 'wrapFamInstBody' inserts (e.g. a cast) is defined
-- elsewhere; confirm.
36 wrap_repr_inst = wrapFamInstBody repr_tc ty_args
-- Clause head not visible here (presumably the empty-sum case): the result is
-- the builtin void value, wrapped as the family instance.
40 void <- builtin voidVar
41 return $ wrap_repr_inst $ Var void
-- Single-constructor representation: a case with exactly one alternative.
-- The converted fields are wrapped directly, with no sum injection.
43 to_sum arg arg_ty res_ty (UnarySum r)
45 (pat, vars, body) <- con_alt r
46 return $ mkWildCase arg arg_ty res_ty
47 [(pat, vars, wrap_repr_inst body)]
-- Multi-constructor representation: one alternative per original constructor.
-- Each alternative's body is injected with the sum tycon's data constructor
-- at the same position; the pairing is purely positional, via zip with
-- tyConDataCons sum_tc.  The constructor is applied to the type arguments
-- @tys@ (bound on a line not shown here).
49 to_sum arg arg_ty res_ty (Sum { repr_sum_tc = sum_tc
53 alts <- mapM con_alt cons
54 let alts' = [(pat, vars, wrap_repr_inst
55 $ mkConApp sum_con (map Type tys ++ [body]))
56 | ((pat, vars, body), sum_con)
57 <- zip alts (tyConDataCons sum_tc)]
58 return $ mkWildCase arg arg_ty res_ty alts'
-- Case alternative for one original constructor.  It matches @DataAlt con@,
-- binds the field variables produced by to_prod, and uses the converted
-- product as the body.
60 con_alt (ConRepr con r)
62 (vars, body) <- to_prod r
63 return (DataAlt con, vars, body)
-- Clause head not visible here (presumably the empty-product case): uses the
-- builtin void value.
67 void <- builtin voidVar
-- Single field: one fresh variable of the field's original type, converted
-- by to_comp.
70 to_prod (UnaryProd comp)
72 var <- newLocalVar (fsLit "x") (compOrigType comp)
73 body <- to_comp (Var var) comp
-- Several fields: one fresh variable per field.  The converted fields are
-- packed with the product tycon's data constructor, applied to @tys@.
76 to_prod(Prod { repr_tup_tc = tup_tc
78 , repr_comps = comps })
80 vars <- newLocalVars (fsLit "x") (map compOrigType comps)
81 exprs <- zipWithM to_comp (map Var vars) comps
82 return (vars, mkConApp tup_con (map Type tys ++ exprs))
-- Partial pattern binding: assumes tup_tc has exactly one data constructor.
-- Otherwise this fails when tup_con is first used.
84 [tup_con] = tyConDataCons tup_tc
-- Convert one field.  Keep passes the expression through unchanged.  Wrap ty
-- boxes it in the builtin Wrap newtype at type ty.
86 to_comp expr (Keep _ _) = return expr
87 to_comp expr (Wrap ty) = do
88 wrap_tc <- builtin wrapTyCon
89 return $ wrapNewTypeBody wrap_tc [ty] expr
-- | Build the Core conversion function from the generic representation back
-- to the vectorised type.  This is the inverse direction of buildToPRepr.
--
-- The result is a lambda @\x -> ...@.  The binder @x@ has the type that
-- 'mkPReprType' returns for @res_ty@.  The body has type @res_ty@, which is
-- @vect_tc@ applied to its own type variables.
--
-- Parameters:
--   * @vect_tc@ is the vectorised type constructor.
--   * @repr_tc@ is used to unwrap the argument via 'unwrapFamInstScrut'.
--   * The third 'TyCon' is ignored.
--   * @repr@ describes the representation shape.
--
-- NOTE(review): this listing omits some source lines (clause heads,
-- continuation and return lines).  The comments describe only what is visible.
92 buildFromPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
93 buildFromPRepr vect_tc repr_tc _ repr
95 arg_ty <- mkPReprType res_ty
96 arg <- newLocalVar (fsLit "x") arg_ty
-- Scrutinise the argument only after unwrapping it from the repr_tc family
-- instance.
98 result <- from_sum (unwrapFamInstScrut repr_tc ty_args (Var arg))
100 return $ Lam arg result
102 ty_args = mkTyVarTys (tyConTyVars vect_tc)
103 res_ty = mkTyConApp vect_tc ty_args
-- Clause head not visible here (presumably the empty-sum case): the result is
-- the builtin fromVoidVar applied to the type res_ty.
107 dummy <- builtin fromVoidVar
108 return $ Var dummy `App` Type res_ty
-- Single constructor: there is no sum to match, so convert the product
-- directly.
110 from_sum expr (UnarySum r) = from_con expr r
-- Several constructors: a case with one alternative per sum data constructor.
-- Each alternative binds one variable (of the types @tys@, bound on a line
-- not shown here) and converts it with the matching ConRepr.
111 from_sum expr (Sum { repr_sum_tc = sum_tc
113 , repr_cons = cons })
115 vars <- newLocalVars (fsLit "x") tys
116 es <- zipWithM from_con (map Var vars) cons
117 return $ mkWildCase expr (exprType expr) res_ty
118 [(DataAlt con, [var], e)
119 | (con, var, e) <- zip3 (tyConDataCons sum_tc) vars es]
-- Start from the original constructor applied to the type arguments.
-- from_prod then supplies the field values.
121 from_con expr (ConRepr con r)
122 = from_prod expr (mkConApp con $ map Type ty_args) r
-- No fields: the constructor expression is already the result.
124 from_prod _ con EmptyProd = return con
-- One field: convert it with from_comp.  The application to con happens on
-- a line not shown here.
125 from_prod expr con (UnaryProd r)
127 e <- from_comp expr r
-- Several fields: case on the product tuple, binding one variable per field
-- (types @tys@), then apply con to the converted fields.
130 from_prod expr con (Prod { repr_tup_tc = tup_tc
131 , repr_comp_tys = tys
135 vars <- newLocalVars (fsLit "y") tys
136 es <- zipWithM from_comp (map Var vars) comps
137 return $ mkWildCase expr (exprType expr) res_ty
138 [(DataAlt tup_con, vars, con `mkApps` es)]
-- Partial pattern binding: assumes tup_tc has exactly one data constructor.
140 [tup_con] = tyConDataCons tup_tc
-- Inverse of the to-direction field conversion.  Keep is the identity.
-- Wrap ty unwraps the builtin Wrap newtype at type ty.
142 from_comp expr (Keep _ _) = return expr
143 from_comp expr (Wrap ty)
145 wrap <- builtin wrapTyCon
146 return $ unwrapNewTypeBody wrap [ty] expr
-- | Build the array-level conversion into the representation type.
--
-- The binder @xs@ has the PData type of @el_ty@, where @el_ty@ is @vect_tc@
-- applied to its own type variables.  The result has the PData type of the
-- 'mkPReprType' of @el_ty@.
--
-- The input is unwrapped from the @pdata_tc@ family instance and matched
-- against its single data constructor @pdata_dc@.  The representation built
-- from the bound fields is then cast with @co@.
--
-- @prepr_tc@ must have a family coercion (see the partial @let Just@ below).
--
-- NOTE(review): this listing omits some source lines.  Part of the coercion
-- expression and the line preceding the final mkWildCase are not shown, so
-- their exact form is not documented here.
149 buildToArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
150 buildToArrPRepr vect_tc prepr_tc pdata_tc r
152 arg_ty <- mkPDataType el_ty
153 res_ty <- mkPDataType =<< mkPReprType el_ty
154 arg <- newLocalVar (fsLit "xs") arg_ty
-- Coercion used to cast the result.  It is built from the builtin PData
-- coercion and the family coercion of prepr_tc at ty_args.
156 pdata_co <- mkBuiltinCo pdataTyCon
-- Partial pattern: fails if prepr_tc has no family coercion.
157 let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
158 co = mkAppCoercion pdata_co
160 $ mkTyConApp repr_co ty_args
162 scrut = unwrapFamInstScrut pdata_tc ty_args (Var arg)
-- to_sum returns the variables to bind in the pdata_dc alternative, together
-- with the representation expression built from them.
164 (vars, result) <- to_sum r
167 $ mkWildCase scrut (mkTyConApp pdata_tc ty_args) res_ty
168 [(DataAlt pdata_dc, vars, mkCoerce co result)]
170 ty_args = mkTyVarTys $ tyConTyVars vect_tc
171 el_ty = mkTyConApp vect_tc ty_args
-- Partial pattern binding: assumes pdata_tc has exactly one data constructor.
173 [pdata_dc] = tyConDataCons pdata_tc
-- Clause head not visible here (presumably the empty-sum case): binds no
-- variables and uses the builtin pvoid value.
177 pvoid <- builtin pvoidVar
178 return ([], Var pvoid)
179 to_sum (UnarySum r) = to_con r
-- Several constructors: convert every summand, then bind a fresh selector
-- variable of type sel_ty.  The selector comes first in the binder list,
-- followed by each summand's variables in order.
180 to_sum (Sum { repr_psum_tc = psum_tc
181 , repr_sel_ty = sel_ty
186 (vars, exprs) <- mapAndUnzipM to_con cons
187 sel <- newLocalVar (fsLit "sel") sel_ty
188 return (sel : concat vars, mk_result (Var sel) exprs)
-- Partial pattern binding: assumes psum_tc has exactly one data constructor.
190 [psum_con] = tyConDataCons psum_tc
-- Builds the psum_tc family instance from the type arguments @tys@, the
-- selector and the summand expressions.  The line applying the constructor
-- is not shown (presumably psum_con).
191 mk_result sel exprs = wrapFamInstBody psum_tc tys
193 $ map Type tys ++ (sel : exprs)
-- The original constructor is not needed here; only its product is converted.
195 to_con (ConRepr _ r) = to_prod r
-- No fields: bind nothing and use the builtin pvoid value.
197 to_prod EmptyProd = do
198 pvoid <- builtin pvoidVar
199 return ([], Var pvoid)
-- One field: bind a variable of the PData type of the field's original type.
200 to_prod (UnaryProd r)
202 pty <- mkPDataType (compOrigType r)
203 var <- newLocalVar (fsLit "x") pty
204 expr <- to_comp (Var var) r
-- Several fields: bind one PData-typed variable per field and pack the
-- converted fields into the ptup_tc family instance.
207 to_prod (Prod { repr_ptup_tc = ptup_tc
208 , repr_comp_tys = tys
209 , repr_comps = comps })
211 ptys <- mapM (mkPDataType . compOrigType) comps
212 vars <- newLocalVars (fsLit "x") ptys
213 es <- zipWithM to_comp (map Var vars) comps
214 return (vars, mk_result es)
-- Partial pattern binding: assumes ptup_tc has exactly one data constructor.
216 [ptup_con] = tyConDataCons ptup_tc
217 mk_result exprs = wrapFamInstBody ptup_tc tys
219 $ map Type tys ++ exprs
221 to_comp expr (Keep _ _) = return expr
223 -- FIXME: this is bound to be wrong!
-- NOTE(review): this case looks up the PData representation tycon of
-- @Wrap ty@ via pdataReprTyCon.  It then wraps the component with
-- wrapNewTypeBody as though that tycon were a newtype at [ty].  Confirm this
-- is valid before relying on Wrap components in arrays.
224 to_comp expr (Wrap ty)
226 wrap_tc <- builtin wrapTyCon
227 (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
228 return $ wrapNewTypeBody pwrap_tc [ty] expr
-- | Build the array-level conversion from the representation type.  This is
-- the inverse direction of buildToArrPRepr.
--
-- The binder @xs@ has the PData type of the 'mkPReprType' of @el_ty@.  The
-- result has the PData type of @el_ty@, where @el_ty@ is @vect_tc@ applied to
-- its own type variables.
--
-- The argument is first cast with @co@.  It is then taken apart by nested
-- cases, one per psum/ptup layer.  The innermost case body is the final
-- @pdata_tc@ value, built from all bound variables.
--
-- @prepr_tc@ must have a family coercion (see the partial @let Just@ below).
--
-- NOTE(review): this listing omits some source lines; comments describe only
-- what is visible.
231 buildFromArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
232 buildFromArrPRepr vect_tc prepr_tc pdata_tc r
234 arg_ty <- mkPDataType =<< mkPReprType el_ty
235 res_ty <- mkPDataType el_ty
236 arg <- newLocalVar (fsLit "xs") arg_ty
238 pdata_co <- mkBuiltinCo pdataTyCon
-- Partial pattern: fails if prepr_tc has no family coercion.
239 let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
240 co = mkAppCoercion pdata_co
241 $ mkTyConApp repr_co var_tys
243 scrut = mkCoerce co (Var arg)
-- Builds the pdata_tc family instance from the type arguments and the
-- collected arguments.  The constructor-application line is not shown
-- (presumably pdata_con).
245 mk_result args = wrapFamInstBody pdata_tc var_tys
247 $ map Type var_tys ++ args
-- Knot-tying: the innermost case body is mk_result args.  However, args (every
-- variable bound by the nested cases) is only known once from_sum has walked
-- the whole representation.  fixV (presumably a monadic fixpoint for VM;
-- confirm) and the lazy pattern ~(_, args) feed the args returned by from_sum
-- back into its own input.
-- NOTE(review): this relies on from_sum and its helpers not forcing the
-- result expression while building it.  The visible helpers only thread it
-- through or place it in case alternatives.
249 (expr, _) <- fixV $ \ ~(_, args) ->
250 from_sum res_ty (mk_result args) scrut r
252 return $ Lam arg expr
-- Alternative formulation without knot-tying, left commented out:
254 -- (args, mk) <- from_sum res_ty scrut r
256 -- let result = wrapFamInstBody pdata_tc var_tys
257 -- . mkConApp pdata_dc
258 -- $ map Type var_tys ++ args
260 -- return $ Lam arg (mk result)
262 var_tys = mkTyVarTys $ tyConTyVars vect_tc
263 el_ty = mkTyConApp vect_tc var_tys
-- Partial pattern binding: assumes pdata_tc has exactly one data constructor.
265 [pdata_con] = tyConDataCons pdata_tc
-- Each from_* helper takes the result type, the inner result expression and
-- the expression to decompose.  It returns the (possibly case-wrapped) result
-- and the list of arguments this layer contributes.
-- Empty sum: contributes no arguments and leaves the result unchanged.
267 from_sum _ res _ EmptySum = return (res, [])
268 from_sum res_ty res expr (UnarySum r) = from_con res_ty res expr r
-- Several summands: case on the unwrapped psum, binding the selector
-- followed by one PData variable per summand (types @tys@, bound on a line
-- not shown here).  The selector is the first contributed argument.
269 from_sum res_ty res expr (Sum { repr_psum_tc = psum_tc
270 , repr_sel_ty = sel_ty
272 , repr_cons = cons })
274 sel <- newLocalVar (fsLit "sel") sel_ty
275 ptys <- mapM mkPDataType tys
276 vars <- newLocalVars (fsLit "xs") ptys
277 (res', args) <- fold from_con res_ty res (map Var vars) cons
278 let scrut = unwrapFamInstScrut psum_tc tys expr
279 body = mkWildCase scrut (exprType scrut) res_ty
280 [(DataAlt psum_con, sel : vars, res')]
281 return (body, Var sel : args)
-- Partial pattern binding: assumes psum_tc has exactly one data constructor.
283 [psum_con] = tyConDataCons psum_tc
-- The original constructor is not needed here; only its product is converted.
286 from_con res_ty res expr (ConRepr _ r) = from_prod res_ty res expr r
288 from_prod _ res _ EmptyProd = return (res, [])
289 from_prod res_ty res expr (UnaryProd r)
290 = from_comp res_ty res expr r
-- Several fields: case on the unwrapped ptup, binding one PData variable per
-- field.  The return line is not shown here.
291 from_prod res_ty res expr (Prod { repr_ptup_tc = ptup_tc
292 , repr_comp_tys = tys
293 , repr_comps = comps })
295 ptys <- mapM mkPDataType tys
296 vars <- newLocalVars (fsLit "ys") ptys
297 (res', args) <- fold from_comp res_ty res (map Var vars) comps
298 let scrut = unwrapFamInstScrut ptup_tc tys expr
299 body = mkWildCase scrut (exprType scrut) res_ty
300 [(DataAlt ptup_con, vars, res')]
-- Partial pattern binding: assumes ptup_tc has exactly one data constructor.
303 [ptup_con] = tyConDataCons ptup_tc
-- Keep: the expression itself becomes one argument.
305 from_comp _ res expr (Keep _ _) = return (res, [expr])
-- Wrap: unwrap the PData representation of @Wrap ty@.  This unwraps the
-- family instance first, then the newtype layer.
306 from_comp _ res expr (Wrap ty)
308 wrap_tc <- builtin wrapTyCon
309 (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
310 return (res, [unwrapNewTypeBody pwrap_tc [ty]
311 $ unwrapFamInstScrut pwrap_tc [ty] expr])
-- Apply f to each (expr, repr) pair, threading the result expression through
-- foldrM.  foldrM processes pairs from the right, so the leftmost pair's
-- wrapping ends up outermost.  Arguments are concatenated as args' ++ args,
-- so they come out in left-to-right order.
313 fold f res_ty res exprs rs = foldrM f' (res, []) (zip exprs rs)
315 f' (expr, r) (res, args) = do
316 (res', args') <- f res_ty res expr r
317 return (res', args' ++ args)