2 module Vectorise.Type.PRepr
10 import Vectorise.Monad
11 import Vectorise.Builtins
12 import Vectorise.Type.Repr
15 import MkCore ( mkWildCase )
-- | Pair a family 'TyCon' with its single index type: @arg_tc@ saturated
-- with its own type variables.
mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
mk_fam_inst fam_tc arg_tc = (fam_tc, [saturated])
  where
    saturated = mkTyConApp arg_tc (mkTyVarTys (tyConTyVars arg_tc))
-- | Build the representation type constructor for a vectorised type.
--
-- The new synonym 'TyCon' is named by cloning @orig_tc@'s name with
-- 'mkPReprTyConOcc'.  Its right-hand side is the type that 'sumReprType'
-- computes from @repr@.  It is declared as an instance of the builtin
-- 'preprTyCon' family, indexed by @vect_tc@ applied to its own type
-- variables (see 'mk_fam_inst').
--
-- NOTE(review): the remaining 'buildSynTyCon' arguments are not visible
-- here; presumably they include @tyvars@ and @rhs_ty@ — confirm.
34 buildPReprTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
35 buildPReprTyCon orig_tc vect_tc repr
37 name <- cloneName mkPReprTyConOcc (tyConName orig_tc)
-- Alternative way of computing the RHS, left commented out; the RHS is
-- now derived from the supplied 'SumRepr'.
38 -- rhs_ty <- buildPReprType vect_tc
39 rhs_ty <- sumReprType repr
40 prepr_tc <- builtin preprTyCon
41 liftDs $ buildSynTyCon name
46 (Just $ mk_fam_inst prepr_tc vect_tc)
-- The type variables of the vectorised tycon.
48 tyvars = tyConTyVars vect_tc
-- | Build the Core function converting a value of a vectorised data type
-- to its generic representation.
--
-- The result is @\\x -> ...@, where @x@ has type @vect_tc@ applied to its
-- own type variables.  The body has the type 'mkPReprType' gives for that
-- type.  For non-empty representations the body is a case on @x@ shaped
-- by @repr@.  Every produced value is wrapped with 'wrapFamInstBody' for
-- @repr_tc@.  The third 'TyCon' argument is ignored.
51 buildToPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
52 buildToPRepr vect_tc repr_tc _ repr
54 let arg_ty = mkTyConApp vect_tc ty_args
55 res_ty <- mkPReprType arg_ty
56 arg <- newLocalVar (fsLit "x") arg_ty
57 result <- to_sum (Var arg) arg_ty res_ty repr
58 return $ Lam arg result
-- Type arguments: the vectorised tycon's own type variables.
60 ty_args = mkTyVarTys (tyConTyVars vect_tc)
-- Wrap a representation-typed body as the repr_tc family instance.
62 wrap_repr_inst = wrapFamInstBody repr_tc ty_args
-- Presumably the empty-sum case: produce the builtin void value, wrapped
-- as the repr_tc instance.
66 void <- builtin voidVar
67 return $ wrap_repr_inst $ Var void
-- Single constructor: one case alternative whose body is wrapped directly.
69 to_sum arg arg_ty res_ty (UnarySum r)
71 (pat, vars, body) <- con_alt r
72 return $ mkWildCase arg arg_ty res_ty
73 [(pat, vars, wrap_repr_inst body)]
-- Several constructors: the i-th alternative's body is injected with the
-- i-th data constructor of sum_tc (applied to tys) and then wrapped.
-- NOTE(review): zip silently drops alternatives if 'cons' and the data
-- constructors of sum_tc differ in length.
75 to_sum arg arg_ty res_ty (Sum { repr_sum_tc = sum_tc
79 alts <- mapM con_alt cons
80 let alts' = [(pat, vars, wrap_repr_inst
81 $ mkConApp sum_con (map Type tys ++ [body]))
82 | ((pat, vars, body), sum_con)
83 <- zip alts (tyConDataCons sum_tc)]
84 return $ mkWildCase arg arg_ty res_ty alts'
-- Case alternative for one constructor: match on it and bind the field
-- variables that to_prod produces.
86 con_alt (ConRepr con r)
88 (vars, body) <- to_prod r
89 return (DataAlt con, vars, body)
-- Presumably the empty-product case; it uses the builtin void value.
93 void <- builtin voidVar
-- One field: bind it to a fresh variable and convert it with to_comp.
96 to_prod (UnaryProd comp)
98 var <- newLocalVar (fsLit "x") (compOrigType comp)
99 body <- to_comp (Var var) comp
-- Several fields: bind each one to a fresh variable and rebuild them as a
-- tuple using the sole data constructor of tup_tc.
102 to_prod(Prod { repr_tup_tc = tup_tc
103 , repr_comp_tys = tys
104 , repr_comps = comps })
106 vars <- newLocalVars (fsLit "x") (map compOrigType comps)
107 exprs <- zipWithM to_comp (map Var vars) comps
108 return (vars, mkConApp tup_con (map Type tys ++ exprs))
110 [tup_con] = tyConDataCons tup_tc
-- Convert one field.  'Keep' leaves it unchanged.  'Wrap' wraps it in the
-- builtin wrapTyCon newtype at the field type.
112 to_comp expr (Keep _ _) = return expr
113 to_comp expr (Wrap ty) = do
114 wrap_tc <- builtin wrapTyCon
115 return $ wrapNewTypeBody wrap_tc [ty] expr
-- | Build the Core function converting a generic representation back to a
-- value of the vectorised data type.
--
-- The result is @\\x -> ...@.  Here @x@ has the type 'mkPReprType' gives
-- for @vect_tc@ applied to its own type variables, and the body has that
-- @vect_tc@ type.  The argument is first unwrapped from the @repr_tc@
-- family instance and then taken apart according to the representation.
-- The third 'TyCon' argument is ignored.
118 buildFromPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
119 buildFromPRepr vect_tc repr_tc _ repr
121 arg_ty <- mkPReprType res_ty
122 arg <- newLocalVar (fsLit "x") arg_ty
124 result <- from_sum (unwrapFamInstScrut repr_tc ty_args (Var arg))
126 return $ Lam arg result
-- Type arguments (the tycon's own type variables) and the result type.
128 ty_args = mkTyVarTys (tyConTyVars vect_tc)
129 res_ty = mkTyConApp vect_tc ty_args
-- Presumably the empty-sum case: apply the builtin fromVoidVar to the
-- result type.
133 dummy <- builtin fromVoidVar
134 return $ Var dummy `App` Type res_ty
-- Single constructor: no sum to take apart.
136 from_sum expr (UnarySum r) = from_con expr r
-- Several constructors: case on the sum with one alternative per data
-- constructor of sum_tc.  Each alternative binds a single variable and
-- rebuilds the corresponding original constructor from it.
137 from_sum expr (Sum { repr_sum_tc = sum_tc
139 , repr_cons = cons })
141 vars <- newLocalVars (fsLit "x") tys
142 es <- zipWithM from_con (map Var vars) cons
143 return $ mkWildCase expr (exprType expr) res_ty
144 [(DataAlt con, [var], e)
145 | (con, var, e) <- zip3 (tyConDataCons sum_tc) vars es]
-- Rebuild one constructor: start from the constructor applied to the type
-- arguments and let from_prod supply its fields.
147 from_con expr (ConRepr con r)
148 = from_prod expr (mkConApp con $ map Type ty_args) r
-- No fields: the partially built constructor application is the result.
150 from_prod _ con EmptyProd = return con
151 from_prod expr con (UnaryProd r)
153 e <- from_comp expr r
-- Several fields: case on the tuple (sole data constructor of tup_tc) and
-- apply the constructor to the converted fields.
156 from_prod expr con (Prod { repr_tup_tc = tup_tc
157 , repr_comp_tys = tys
161 vars <- newLocalVars (fsLit "y") tys
162 es <- zipWithM from_comp (map Var vars) comps
163 return $ mkWildCase expr (exprType expr) res_ty
164 [(DataAlt tup_con, vars, con `mkApps` es)]
166 [tup_con] = tyConDataCons tup_tc
-- Convert one field back.  'Keep' is the identity.  'Wrap' unwraps the
-- builtin wrapTyCon newtype.
168 from_comp expr (Keep _ _) = return expr
169 from_comp expr (Wrap ty)
171 wrap <- builtin wrapTyCon
172 return $ unwrapNewTypeBody wrap [ty] expr
-- | Build the conversion from the parallel-array ('mkPDataType') form of a
-- vectorised type to the parallel-array form of its generic representation.
--
-- The argument @xs@ has the 'mkPDataType' of @vect_tc@ applied to its own
-- type variables.  The generated case has the 'mkPDataType' of that type's
-- 'mkPReprType' as its result type.  The code does the following:
--
-- * unwraps the @pdata_tc@ family instance of @xs@;
-- * matches its single data constructor;
-- * rebuilds the fields according to @r@;
-- * casts the result with a coercion built from the builtin 'pdataTyCon'
--   coercion and @prepr_tc@'s family axiom at the type arguments.
175 buildToArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
176 buildToArrPRepr vect_tc prepr_tc pdata_tc r
178 arg_ty <- mkPDataType el_ty
179 res_ty <- mkPDataType =<< mkPReprType el_ty
180 arg <- newLocalVar (fsLit "xs") arg_ty
182 pdata_co <- mkBuiltinCo pdataTyCon
-- NOTE(review): lazy, partial pattern binding.  It fails at the first use
-- of repr_co if prepr_tc has no family coercion.
183 let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
184 co = mkAppCo pdata_co
186 $ mkAxInstCo repr_co ty_args
188 scrut = unwrapFamInstScrut pdata_tc ty_args (Var arg)
-- vars are bound by the PData constructor's alternative; result is the
-- rebuilt representation, which is then cast with co.
190 (vars, result) <- to_sum r
193 $ mkWildCase scrut (mkTyConApp pdata_tc ty_args) res_ty
194 [(DataAlt pdata_dc, vars, mkCoerce co result)]
196 ty_args = mkTyVarTys $ tyConTyVars vect_tc
197 el_ty = mkTyConApp vect_tc ty_args
-- pdata_tc is assumed to have exactly one data constructor.
199 [pdata_dc] = tyConDataCons pdata_tc
-- Presumably the empty-sum case: no pattern variables, the builtin pvoid.
203 pvoid <- builtin pvoidVar
204 return ([], Var pvoid)
205 to_sum (UnarySum r) = to_con r
-- Several constructors: a fresh selector variable of sel_ty comes first,
-- followed by every constructor's field variables.  The result passes tys,
-- the selector and the converted constructors to psum_tc's representation,
-- wrapped with wrapFamInstBody.
206 to_sum (Sum { repr_psum_tc = psum_tc
207 , repr_sel_ty = sel_ty
212 (vars, exprs) <- mapAndUnzipM to_con cons
213 sel <- newLocalVar (fsLit "sel") sel_ty
214 return (sel : concat vars, mk_result (Var sel) exprs)
216 [psum_con] = tyConDataCons psum_tc
217 mk_result sel exprs = wrapFamInstBody psum_tc tys
219 $ map Type tys ++ (sel : exprs)
-- Only the constructor's product representation matters here.
221 to_con (ConRepr _ r) = to_prod r
-- No fields: no variables, the builtin pvoid value.
223 to_prod EmptyProd = do
224 pvoid <- builtin pvoidVar
225 return ([], Var pvoid)
-- One field: a fresh variable of the field's PData type, converted by
-- to_comp.
226 to_prod (UnaryProd r)
228 pty <- mkPDataType (compOrigType r)
229 var <- newLocalVar (fsLit "x") pty
230 expr <- to_comp (Var var) r
-- Several fields: one fresh PData-typed variable per field.  The converted
-- fields are combined into the ptup_tc family instance.
233 to_prod (Prod { repr_ptup_tc = ptup_tc
234 , repr_comp_tys = tys
235 , repr_comps = comps })
237 ptys <- mapM (mkPDataType . compOrigType) comps
238 vars <- newLocalVars (fsLit "x") ptys
239 es <- zipWithM to_comp (map Var vars) comps
240 return (vars, mk_result es)
242 [ptup_con] = tyConDataCons ptup_tc
243 mk_result exprs = wrapFamInstBody ptup_tc tys
245 $ map Type tys ++ exprs
247 to_comp expr (Keep _ _) = return expr
249 -- FIXME: this is bound to be wrong!
-- NOTE(review): the inverse, from_comp in buildFromArrPRepr, applies both
-- unwrapFamInstScrut and unwrapNewTypeBody for pwrap_tc.  This case
-- applies only wrapNewTypeBody.  Confirm whether a matching
-- wrapFamInstBody is needed.
250 to_comp expr (Wrap ty)
252 wrap_tc <- builtin wrapTyCon
253 (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
254 return $ wrapNewTypeBody pwrap_tc [ty] expr
-- | Build the conversion in the opposite direction to 'buildToArrPRepr'.
-- It turns the parallel-array form of the generic representation back into
-- the parallel-array form of the vectorised type.
--
-- The argument is first cast with a coercion built from the builtin
-- 'pdataTyCon' coercion and @prepr_tc@'s family axiom.  It is then taken
-- apart according to @r@ with nested cases.
--
-- The fields of the final @pdata_tc@ value are only known once those cases
-- have been generated.  So the result is built with 'fixV': 'from_sum' is
-- handed the final result expression, which is itself built from the
-- argument list that 'from_sum' returns.
257 buildFromArrPRepr :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
258 buildFromArrPRepr vect_tc prepr_tc pdata_tc r
260 arg_ty <- mkPDataType =<< mkPReprType el_ty
261 res_ty <- mkPDataType el_ty
262 arg <- newLocalVar (fsLit "xs") arg_ty
264 pdata_co <- mkBuiltinCo pdataTyCon
-- NOTE(review): lazy, partial pattern binding.  It fails at the first use
-- of repr_co if prepr_tc has no family coercion.
265 let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
266 co = mkAppCo pdata_co
267 $ mkAxInstCo repr_co var_tys
269 scrut = mkCoerce co (Var arg)
-- Build the final pdata_tc value from the collected field expressions.
271 mk_result args = wrapFamInstBody pdata_tc var_tys
273 $ map Type var_tys ++ args
-- Knot-tying: the irrefutable ~ pattern stops fixV from demanding the
-- tuple before from_sum has produced its argument list.
275 (expr, _) <- fixV $ \ ~(_, args) ->
276 from_sum res_ty (mk_result args) scrut r
278 return $ Lam arg expr
-- Commented-out alternative formulation (not used):
280 -- (args, mk) <- from_sum res_ty scrut r
282 -- let result = wrapFamInstBody pdata_tc var_tys
283 -- . mkConApp pdata_dc
284 -- $ map Type var_tys ++ args
286 -- return $ Lam arg (mk result)
288 var_tys = mkTyVarTys $ tyConTyVars vect_tc
289 el_ty = mkTyConApp vect_tc var_tys
291 [pdata_con] = tyConDataCons pdata_tc
-- Each helper below takes:
--   * the result type;
--   * res, the expression placed in the innermost alternative;
--   * the scrutinee;
--   * the representation.
-- It returns res with the generated cases wrapped around it, plus the
-- field expressions this part contributes to the final result.
293 from_sum _ res _ EmptySum = return (res, [])
294 from_sum res_ty res expr (UnarySum r) = from_con res_ty res expr r
-- Several constructors: case on the unwrapped psum_tc value, binding the
-- selector and one PData variable per constructor.  The selector is the
-- first contributed field.
295 from_sum res_ty res expr (Sum { repr_psum_tc = psum_tc
296 , repr_sel_ty = sel_ty
298 , repr_cons = cons })
300 sel <- newLocalVar (fsLit "sel") sel_ty
301 ptys <- mapM mkPDataType tys
302 vars <- newLocalVars (fsLit "xs") ptys
303 (res', args) <- fold from_con res_ty res (map Var vars) cons
304 let scrut = unwrapFamInstScrut psum_tc tys expr
305 body = mkWildCase scrut (exprType scrut) res_ty
306 [(DataAlt psum_con, sel : vars, res')]
307 return (body, Var sel : args)
309 [psum_con] = tyConDataCons psum_tc
312 from_con res_ty res expr (ConRepr _ r) = from_prod res_ty res expr r
314 from_prod _ res _ EmptyProd = return (res, [])
315 from_prod res_ty res expr (UnaryProd r)
316 = from_comp res_ty res expr r
-- Several fields: case on the unwrapped ptup_tc value, binding one PData
-- variable per field, and process each field in turn.
317 from_prod res_ty res expr (Prod { repr_ptup_tc = ptup_tc
318 , repr_comp_tys = tys
319 , repr_comps = comps })
321 ptys <- mapM mkPDataType tys
322 vars <- newLocalVars (fsLit "ys") ptys
323 (res', args) <- fold from_comp res_ty res (map Var vars) comps
324 let scrut = unwrapFamInstScrut ptup_tc tys expr
325 body = mkWildCase scrut (exprType scrut) res_ty
326 [(DataAlt ptup_con, vars, res')]
329 [ptup_con] = tyConDataCons ptup_tc
-- A single field contributes its expression unchanged, or unwrapped from
-- the pwrap_tc family instance and newtype for 'Wrap'.
331 from_comp _ res expr (Keep _ _) = return (res, [expr])
332 from_comp _ res expr (Wrap ty)
334 wrap_tc <- builtin wrapTyCon
335 (pwrap_tc, _) <- pdataReprTyCon (mkTyConApp wrap_tc [ty])
336 return (res, [unwrapNewTypeBody pwrap_tc [ty]
337 $ unwrapFamInstScrut pwrap_tc [ty] expr])
-- Thread res through f for each (expr, repr) pair with foldrM.  The last
-- pair is processed first, so the first pair's cases end up outermost.
-- The contributed fields are concatenated in left-to-right order.
339 fold f res_ty res exprs rs = foldrM f' (res, []) (zip exprs rs)
341 f' (expr, r) (res, args) = do
342 (res', args') <- f res_ty res expr r
343 return (res', args' ++ args)