1 {-# OPTIONS -fno-warn-missing-signatures #-}
14 import Vectorise.Convert
16 import Vectorise.Monad
17 import Vectorise.Builtins
18 import Vectorise.Type.Type
19 import Vectorise.Type.TyConDecl
20 import Vectorise.Type.Classify
21 import Vectorise.Type.Repr
22 import Vectorise.Type.PADict
23 import Vectorise.Utils.Closure
24 import Vectorise.Utils.Hoisting
26 import HscTypes ( TypeEnv, extendTypeEnvList, typeEnvTyCons )
36 import FamInstEnv ( FamInst, mkLocalFamInst )
41 import Name ( Name, getOccName )
-- | Debug-tracing helper. When @debug@ is set, print @s@ via 'pprTrace'
-- under the "VectType" heading and then return @x@; otherwise return @x@
-- unchanged. (@debug@ is defined outside this excerpt.)
56 dtrace s x = if debug then pprTrace "VectType" s x else x
59 -- | Vectorise a type environment.
60 -- The type environment contains all the type things defined in a module.
-- NOTE(review): this listing omits the function's name and argument lines
-- (source lines 61-62); only the result type is visible below.
63 -> VM ( TypeEnv -- Vectorised type environment.
64 , [FamInst] -- New type family instances.
65 , [(Var, CoreExpr)]) -- New top level bindings.
-- Turn the global vectorised-TyCon table into a Bool map via 'mk_map'
-- (defined in the where clause at the end of this block).
70 cs <- readGEnv $ mk_map . global_tycons
72 -- Split the list of TyCons into the ones we have to vectorise vs the
73 -- ones we can pass through unchanged. We also pass through algebraic
74 -- types that use non Haskell98 features, as we don't handle those.
75 let (conv_tcs, keep_tcs) = classifyTyCons cs groups
76 keep_dcs = concatMap tyConDataCons keep_tcs
-- Pass-through TyCons and DataCons are registered with themselves as their
-- vectorised counterparts (each one is paired with itself).
78 zipWithM_ defTyCon keep_tcs keep_tcs
79 zipWithM_ defDataCon keep_dcs keep_dcs
-- Build vectorised declarations for the TyCons that must be converted.
81 new_tcs <- vectTyConDecls conv_tcs
-- Original TyCons, kept first then converted. orig_tcs is zipped
-- positionally with vect_tcs below. NOTE(review): presumably vect_tcs is
-- built in the same order; its full definition is not visible here.
83 let orig_tcs = keep_tcs ++ conv_tcs
85 -- We don't need to make new representation types for dictionary
86 -- constructors. The constructors are always fully applied, and we don't
87 -- need to lift them to arrays as a dictionary of a particular type
88 -- always has the same value.
89 let vect_tcs = filter (not . isClassTyCon)
-- NOTE(review): the rest of this binding (source lines 90-91) is omitted
-- from this listing.
-- Knot-tying: the first component returned at the end of this computation
-- (dfuns) is fed back in as dfuns' so that defTyConPAs can register the PA
-- dfuns before they are built. The ~ keeps the pattern lazy; presumably this
-- is required so the fixpoint does not force its own result.
92 (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
94 defTyConPAs (zipLazy vect_tcs dfuns')
-- One representation per vectorised TyCon. It drives both the PRepr
-- synonym instance and the PData data instance built next.
95 reprs <- mapM tyConRepr vect_tcs
96 repr_tcs <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
97 pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
-- NOTE(review): source lines 98-99 and 101-107 (the rest of this
-- buildTyConBindings application and the binding of dfuns) are omitted.
100 $ zipWith5 buildTyConBindings
108 return (dfuns, binds, repr_tcs ++ pdata_tcs)
-- All TyCons created here: the vectorised declarations plus the PRepr and
-- PData instance TyCons returned from the fixpoint above.
110 let all_new_tcs = new_tcs ++ inst_tcs
-- Extend the input environment with every new TyCon and all of their
-- DataCons.
112 let new_env = extendTypeEnvList env
113 (map ATyCon all_new_tcs
114 ++ [ADataCon dc | tc <- all_new_tcs
115 , dc <- tyConDataCons tc])
-- Every PRepr/PData instance TyCon also becomes a local family instance.
117 return (new_env, map mkLocalFamInst inst_tcs, binds)
-- NOTE(review): the 'where' keyword (source line 118) is omitted here.
119 tycons = typeEnvTyCons env
-- Group the module's TyCons; the groups are passed to classifyTyCons above.
120 groups = tyConGroups tycons
-- For each (unique, Name) entry, map the key unique to whether the stored
-- Name's unique differs from the key. NOTE(review): presumably True means
-- the TyCon was mapped to a distinct vectorised TyCon; confirm.
122 mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
-- | Describe a family instance. The result pairs the family TyCon @fam_tc@
-- with a single type argument: @arg_tc@ applied to its own type variables,
-- i.e. the instance head @fam_tc (arg_tc a1 .. an)@.
125 mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
126 mk_fam_inst fam_tc arg_tc
127 = (fam_tc, [mkTyConApp arg_tc . mkTyVarTys $ tyConTyVars arg_tc])
-- | Build the PRepr type-synonym instance for a vectorised TyCon.
--
-- * The name is cloned from @orig_tc@'s name with 'mkPReprTyConOcc'.
-- * The right-hand side is the type that 'sumReprType' computes from @repr@.
-- * The TyCon is declared as an instance of the builtin PRepr family at
--   @vect_tc@ applied to its type variables (see 'mk_fam_inst').
--
-- NOTE(review): source lines 132, 138, 140 and 142 are omitted from this
-- listing, so some buildSynTyCon arguments are not visible.
130 buildPReprTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
131 buildPReprTyCon orig_tc vect_tc repr
133 name <- cloneName mkPReprTyConOcc (tyConName orig_tc)
-- The commented-out line below appears to be an older way of computing the
-- right-hand side; sumReprType is what is actually used.
134 -- rhs_ty <- buildPReprType vect_tc
135 rhs_ty <- sumReprType repr
136 prepr_tc <- builtin preprTyCon
137 liftDs $ buildSynTyCon name
139 (SynonymTyCon rhs_ty)
141 (Just $ mk_fam_inst prepr_tc vect_tc)
143 tyvars = tyConTyVars vect_tc
-- | Build the body of the @dictPRepr@ PA method: a PR dictionary for the
-- representation @r@, cast with a coercion.
--
-- @vect_tc@ is the vectorised TyCon and @prepr_tc@ its PRepr instance TyCon.
-- The third TyCon argument (the PData TyCon, as passed by buildPADict) is
-- ignored. The coercion applies the builtin PR coercion to @prepr_tc@'s
-- family coercion, instantiated at @vect_tc@'s type variables.
--
-- NOTE(review): several lines are omitted from this listing (e.g. 149-150,
-- 153, 156, 159, 163-166), including where @dict@ is bound.
147 buildPRDict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
148 buildPRDict vect_tc prepr_tc _ r
151 pr_co <- mkBuiltinCo prTyCon
152 let co = mkAppCoercion pr_co
154 $ mkTyConApp arg_co ty_args
155 return (mkCoerce co dict)
157 ty_args = mkTyVarTys (tyConTyVars vect_tc)
-- NOTE(review): partial pattern binding. It fails at runtime if prepr_tc has
-- no family coercion. It presumably always holds because buildPReprTyCon
-- declares its result as a family instance; consider an explicit error
-- message instead.
158 Just arg_co = tyConFamilyCoercion_maybe prepr_tc
-- Dictionaries for sums:
-- * an empty sum uses the PR dfun of the builtin Void TyCon;
-- * a unary sum is just its constructor's dictionary;
-- * otherwise the sum TyCon's PR dfun is applied to the types @tys@ and to
--   each constructor's dictionary.
160 sum_dict EmptySum = prDFunOfTyCon =<< builtin voidTyCon
161 sum_dict (UnarySum r) = con_dict r
162 sum_dict (Sum { repr_sum_tc = sum_tc
167 dicts <- mapM con_dict cons
168 dfun <- prDFunOfTyCon sum_tc
169 return $ dfun `mkTyApps` tys `mkApps` dicts
-- A constructor's dictionary is the dictionary of its product.
171 con_dict (ConRepr _ r) = prod_dict r
-- Products follow the same scheme as sums, using the tuple TyCon's PR dfun.
173 prod_dict EmptyProd = prDFunOfTyCon =<< builtin voidTyCon
174 prod_dict (UnaryProd r) = comp_dict r
175 prod_dict (Prod { repr_tup_tc = tup_tc
176 , repr_comp_tys = tys
177 , repr_comps = comps })
179 dicts <- mapM comp_dict comps
180 dfun <- prDFunOfTyCon tup_tc
181 return $ dfun `mkTyApps` tys `mkApps` dicts
-- A Keep component already carries its PR dictionary; a Wrap component gets
-- one from wrapPR.
183 comp_dict (Keep _ pr) = return pr
184 comp_dict (Wrap ty) = wrapPR ty
-- | Build the PData data-family instance TyCon for a vectorised TyCon.
--
-- 'fixV' lets the TyCon being built (@repr_tc@) be passed to
-- 'buildPDataTyConRhs'. That is needed because the result type of its data
-- constructor mentions @repr_tc@ (see buildPDataDataCon).
--
-- * The name is cloned from @orig_tc@'s name with 'mkPDataTyConOcc'.
-- * The TyCon is declared as an instance of the builtin PData family at
--   @vect_tc@ applied to its type variables.
-- * Its recursivity flag is copied from @vect_tc@.
--
-- NOTE(review): source lines 189, 193, 195, 197 and 202 are omitted from
-- this listing.
187 buildPDataTyCon :: TyCon -> TyCon -> SumRepr -> VM TyCon
188 buildPDataTyCon orig_tc vect_tc repr = fixV $ \repr_tc ->
190 name' <- cloneName mkPDataTyConOcc orig_name
191 rhs <- buildPDataTyConRhs orig_name vect_tc repr_tc repr
192 pdata <- builtin pdataTyCon
194 liftDs $ buildAlgTyCon name'
196 [] -- no stupid theta
198 rec_flag -- FIXME: is this ok?
199 False -- FIXME: no generics
200 False -- not GADT syntax
201 (Just $ mk_fam_inst pdata vect_tc)
203 orig_name = tyConName orig_tc
204 tyvars = tyConTyVars vect_tc
205 rec_flag = boolToRecFlag (isRecursiveTyCon vect_tc)
-- | Right-hand side of a PData instance: exactly one data constructor (built
-- by 'buildPDataDataCon'), not marked as an enumeration.
208 buildPDataTyConRhs :: Name -> TyCon -> TyCon -> SumRepr -> VM AlgTyConRhs
209 buildPDataTyConRhs orig_name vect_tc repr_tc repr
211 data_con <- buildPDataDataCon orig_name vect_tc repr_tc repr
212 return $ DataTyCon { data_cons = [data_con], is_enum = False }
-- | Build the single data constructor of a PData instance.
--
-- * The name is cloned from @orig_name@ with 'mkPDataDataConOcc'.
-- * Strictness: one 'HsNoBang' per type in @comp_tys@ (from @sum_tys@).
-- * No field labels and no existential type variables.
-- * The result type is @repr_tc@ applied to @vect_tc@'s type variables.
--
-- NOTE(review): source lines 216, 219, 221, 224, 226-228, 230-231 and 233
-- are omitted from this listing.
214 buildPDataDataCon :: Name -> TyCon -> TyCon -> SumRepr -> VM DataCon
215 buildPDataDataCon orig_name vect_tc repr_tc repr
217 dc_name <- cloneName mkPDataDataConOcc orig_name
218 comp_tys <- sum_tys repr
220 liftDs $ buildDataCon dc_name
222 (map (const HsNoBang) comp_tys)
223 [] -- no field labels
225 [] -- no existentials
229 (mkFamilyTyConApp repr_tc (mkTyVarTys tvs))
232 tvs = tyConTyVars vect_tc
-- Component types for a sum:
-- * none for an empty sum;
-- * a unary sum's constructor types;
-- * otherwise the selector type followed by every constructor's types in
--   order.
234 sum_tys EmptySum = return []
235 sum_tys (UnarySum r) = con_tys r
236 sum_tys (Sum { repr_sel_ty = sel_ty
237 , repr_cons = cons })
238 = liftM (sel_ty :) (concatMapM con_tys cons)
-- A constructor contributes the types of its product.
240 con_tys (ConRepr _ r) = prod_tys r
-- Product types: none, one, or one per component.
242 prod_tys EmptyProd = return []
243 prod_tys (UnaryProd r) = liftM singleton (comp_ty r)
244 prod_tys (Prod { repr_comps = comps }) = mapM comp_ty comps
-- Each component's type is the PData type of compOrigType of it.
246 comp_ty r = mkPDataType (compOrigType r)
-- | Generate the bindings for one vectorised TyCon. It first defines the
-- vectorised data-constructor workers, then builds the PA dictionary from
-- the vectorised, PRepr and PData TyCons and the representation.
-- NOTE(review): source lines 250 and 252 (the rest of the type signature
-- and the start of the body) are omitted from this listing.
249 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr
251 buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr
253 vectDataConWorkers orig_tc vect_tc pdata_tc
254 buildPADict vect_tc prepr_tc pdata_tc repr
-- | Define vectorised workers for every data constructor of @orig_tc@.
--
-- * For each constructor of @vect_tc@, 'mk_data_con' pairs the plain
--   constructor application with a lifted version built for the PData TyCon
--   @arr_tc@.
-- * 'def_worker' turns each pair into closures and clones the original
--   worker Id with an inline-rule unfolding.
-- * It then registers the clone with 'defGlobalVar'.
-- * The resulting (Id, body) pairs are passed to 'hoistBinding'.
--
-- NOTE(review): source lines 258-259, 262-263, 266, 271, 273, 275, 277-278,
-- 282, 284, 287, 289, 292, 296, 298, 301, 304, 306, 308, 310-311, 315 and
-- 321 are omitted from this listing.
256 vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
257 vectDataConWorkers orig_tc vect_tc arr_tc
260 . zipWith3 def_worker (tyConDataCons orig_tc) rep_tys
261 $ zipWith4 mk_data_con (tyConDataCons vect_tc)
264 (tail $ tails rep_tys)
265 mapM_ (uncurry hoistBinding) bs
267 tyvars = tyConTyVars vect_tc
268 var_tys = mkTyVarTys tyvars
269 ty_args = map Type var_tys
270 res_ty = mkTyConApp vect_tc var_tys
272 cons = tyConDataCons vect_tc
-- NOTE(review): partial pattern binding. It assumes arr_tc has exactly one
-- data constructor, which matches buildPDataTyConRhs building a single one.
274 [arr_dc] = tyConDataCons arr_tc
-- Representation argument types of each vectorised constructor.
276 rep_tys = map dataConRepArgTys $ tyConDataCons vect_tc
-- Pair the vectorised constructor application with its lifted version.
-- pre/post are lists of field-type lists for other constructors. post comes
-- from 'tail $ tails rep_tys'; presumably pre covers the preceding
-- constructors (that argument is omitted).
279 mk_data_con con tys pre post
280 = liftM2 (,) (vect_data_con con)
281 (lift_data_con tys pre post (mkDataConTag con))
-- Returns a one-element list: the builtin selReplicate for @arity@ applied
-- to the length and the tag, when the (omitted) guard on line 284 holds.
-- Otherwise returns no selector argument.
283 sel_replicate len tag
285 rep <- builtin (selReplicate arity)
286 return [rep `mkApps` [len, tag]]
288 | otherwise = return []
290 vect_data_con con = return $ mkConApp con ty_args
-- Lifted constructor body:
-- * a lambda over the lifting-context variable and one fresh "xs" variable
--   per field, each typed as PData of that field's type;
-- * the payload is the type arguments, the optional selector, empty PData
--   for the preceding constructors' fields, this constructor's arguments,
--   then empty PData for the following constructors' fields;
-- * the whole body is wrapped with wrapFamInstBody for arr_tc.
291 lift_data_con tys pre_tys post_tys tag
293 len <- builtin liftingContext
294 args <- mapM (newLocalVar (fsLit "xs"))
295 =<< mapM mkPDataType tys
297 sel <- sel_replicate (Var len) tag
299 pre <- mapM emptyPD (concat pre_tys)
300 post <- mapM emptyPD (concat post_tys)
302 return . mkLams (len : args)
303 . wrapFamInstBody arr_tc var_tys
305 $ ty_args ++ sel ++ pre ++ map Var args ++ post
-- Define the vectorised worker for one data constructor:
-- * build closures over the worker's argument types under polyAbstract;
-- * clone the original worker Id with mkVectOcc;
-- * attach an inline-rule unfolding with the polymorphic arity;
-- * register the clone as the vectorised form of the original worker.
307 def_worker data_con arg_tys mk_body
309 arity <- polyArity tyvars
312 . polyAbstract tyvars $ \args ->
313 liftM (mkLams (tyvars ++ args) . vectorised)
314 $ buildClosures tyvars [] arg_tys res_ty mk_body
316 raw_worker <- cloneId mkVectOcc orig_worker (exprType body)
317 let vect_worker = raw_worker `setIdUnfolding`
318 mkInlineRule body (Just arity)
319 defGlobalVar orig_worker vect_worker
320 return (vect_worker, body)
322 orig_worker = dataConWorkId data_con
-- | Build and hoist the PA dictionary function (dfun) for a vectorised TyCon.
--
-- * The work happens under 'polyAbstract' over @vect_tc@'s type variables.
-- * Each entry of 'paMethods' becomes an exported method binding with an
--   inline-rule unfolding and 'alwaysInlinePragma', passed to 'hoistBinding'.
-- * The dictionary is a lambda over the type variables and the abstracted
--   arguments. Its body applies (on omitted line 333, presumably with
--   @pa_dc@) the instance type and one call per method.
-- * The dfun has type @forall tvs. <argument types> -> PA (vect_tc tvs)@, a
--   DFun unfolding over the method Ids, and 'dfunInlinePragma'.
--
-- NOTE(review): source lines 327, 329, 333, 335, 338, 342, 344-345, 349,
-- 351, 353-354, 358, 362-363 and 365 are omitted from this listing.
324 buildPADict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
325 buildPADict vect_tc prepr_tc arr_tc repr
326 = polyAbstract tvs $ \args ->
328 method_ids <- mapM (method args) paMethods
330 pa_tc <- builtin paTyCon
331 pa_dc <- builtin paDataCon
332 let dict = mkLams (tvs ++ args)
334 $ Type inst_ty : map (method_call args) method_ids
336 dfun_ty = mkForAllTys tvs
337 $ mkFunTys (map varType args) (mkTyConApp pa_tc [inst_ty])
339 raw_dfun <- newExportedVar dfun_name dfun_ty
340 let dfun = raw_dfun `setIdUnfolding` mkDFunUnfolding dfun_ty (map Var method_ids)
341 `setInlinePragma` dfunInlinePragma
343 hoistBinding dfun dict
346 tvs = tyConTyVars vect_tc
347 arg_tys = mkTyVarTys tvs
348 inst_ty = mkTyConApp vect_tc arg_tys
350 dfun_name = mkPADFunOcc (getOccName vect_tc)
-- Build one method:
-- * the body is the builder's expression abstracted over tvs and args;
-- * it is bound to an exported variable with an inline-rule unfolding of
--   arity (length args) and an always-inline pragma, then hoisted.
352 method args (name, build)
355 expr <- build vect_tc prepr_tc arr_tc repr
356 let body = mkLams (tvs ++ args) expr
357 raw_var <- newExportedVar (method_name name) (exprType body)
359 `setIdUnfolding` mkInlineRule body (Just (length args))
360 `setInlinePragma` alwaysInlinePragma
361 hoistBinding var body
-- Apply a method Id to the instance's type arguments and the abstracted
-- arguments.
364 method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args)
-- A method is named after the dfun: <dfun occ name> ++ "$" ++ <method name>.
366 method_name name = mkVarOcc $ occNameString dfun_name ++ ('$' : name)
-- | The PA dictionary methods, in the order buildPADict passes them to the
-- dictionary. Each entry pairs a method name with its body builder.
-- * The name is used as the suffix of the generated binding's name.
-- * The builder takes the vectorised, PRepr and PData TyCons and the
--   representation.
-- NOTE(review): presumably this order must match the field order of the
-- builtin PA data constructor; confirm.
369 paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
370 paMethods = [("dictPRepr", buildPRDict),
371 ("toPRepr", buildToPRepr),
372 ("fromPRepr", buildFromPRepr),
373 ("toArrPRepr", buildToArrPRepr),
374 ("fromArrPRepr", buildFromArrPRepr)]