1 {-# OPTIONS -fno-warn-missing-signatures #-}
3 module Vectorise.Type.Env (
10 import Vectorise.Builtins
11 import Vectorise.Type.TyConDecl
12 import Vectorise.Type.Classify
13 import Vectorise.Type.PADict
14 import Vectorise.Type.PData
15 import Vectorise.Type.PRepr
16 import Vectorise.Type.Repr
17 import Vectorise.Utils
-- | Debug tracing helper: when 'debug' is switched on, pass the value
--   through 'pprTrace' under the "VectType" tag; otherwise hand the value
--   back untouched.
dtrace s x
  | debug     = pprTrace "VectType" s x
  | otherwise = x
46 -- | Vectorise a type environment.
47 -- The type environment contains all the type things defined in a module.
-- The result is a VM action producing a triple; each component is described
-- on its own line below.
50 -> VM ( TypeEnv -- Vectorised type environment.
51 , [FamInst] -- New type family instances.
52 , [(Var, CoreExpr)]) -- New top level bindings.
-- Read the global TyCon table from the global environment and reshape it
-- with 'mk_map' (defined further down) before classification.
57 cs <- readGEnv $ mk_map . global_tycons
59 -- Split the list of TyCons into the ones we have to vectorise vs the
60 -- ones we can pass through unchanged. We also pass through algebraic
61 -- types that use non Haskell98 features, as we don't handle those.
62 let (conv_tcs, keep_tcs) = classifyTyCons cs groups
-- Data constructors of every TyCon that is passed through unchanged.
63 keep_dcs = concatMap tyConDataCons keep_tcs
-- Both list arguments are the same list, so each kept TyCon / DataCon is
-- paired with itself (presumably recording "vectorises to itself" --
-- confirm against defTyCon / defDataCon).
65 zipWithM_ defTyCon keep_tcs keep_tcs
66 zipWithM_ defDataCon keep_dcs keep_dcs
-- Produce the new TyCons for those classified as needing conversion.
68 new_tcs <- vectTyConDecls conv_tcs
-- Original TyCons, passed-through ones first, then the converted ones.
70 let orig_tcs = keep_tcs ++ conv_tcs
72 -- We don't need to make new representation types for dictionary
73 -- constructors. The constructors are always fully applied, and we don't
74 -- need to lift them to arrays as a dictionary of a particular type
75 -- always has the same value.
76 let vect_tcs = filter (not . isClassTyCon)
-- The lazy pattern (~) keeps the fed-back first component, dfuns', from
-- being forced when the body starts; it is only paired with vect_tcs via
-- zipLazy. NOTE(review): fixV is presumably a monadic fixpoint -- confirm.
-- The first component of the final result is discarded by the outer pattern.
79 (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
81 defTyConPAs (zipLazy vect_tcs dfuns')
-- One representation per entry of vect_tcs, then buildPReprTyCon and
-- buildPDataTyCon are each run over (original, vectorised, representation)
-- triples.
82 reprs <- mapM tyConRepr vect_tcs
83 repr_tcs <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
84 pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
-- buildTyConBindings is applied across five parallel lists.
87 $ zipWith5 buildTyConBindings
-- The third component (repr_tcs ++ pdata_tcs) is what the outer pattern
-- binds as inst_tcs.
95 return (dfuns, binds, repr_tcs ++ pdata_tcs)
-- Every TyCon created here: the converted ones plus the instance TyCons.
97 let all_new_tcs = new_tcs ++ inst_tcs
-- Extend the incoming environment with each new TyCon and with each data
-- constructor of those TyCons.
99 let new_env = extendTypeEnvList env
100 (map ATyCon all_new_tcs
101 ++ [ADataCon dc | tc <- all_new_tcs
102 , dc <- tyConDataCons tc])
-- The family instances are built from inst_tcs with mkLocalFamInst.
104 return (new_env, map mkLocalFamInst inst_tcs, binds)
-- All TyCons of the incoming environment, and their grouping as computed by
-- tyConGroups.
106 tycons = typeEnvTyCons env
107 groups = tyConGroups tycons
-- Build a UFM keyed directly by each entry's unique; the stored Bool is True
-- exactly when the Name's own unique differs from the key it is stored under.
109 mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
-- | Build the bindings belonging to one type constructor.
--
--   The data constructor workers are set up first (using the original,
--   vectorised and PData type constructors); after that the PA dictionary
--   is built from the vectorised, PRepr and PData type constructors
--   together with the representation. The variable returned by
--   'buildPADict' is the result of the whole action.
buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr
  = vectDataConWorkers orig_tc vect_tc pdata_tc
      >> buildPADict vect_tc prepr_tc pdata_tc repr
-- | Set up the workers for the data constructors of a type.
--   Arguments, in order: the original TyCon, its vectorised TyCon, and the
--   array TyCon (buildTyConBindings passes the PData TyCon here).
--   One worker is built per data constructor of orig_tc (see def_worker) and
--   the resulting (variable, body) pairs are handed to hoistBinding.
119 vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
120 vectDataConWorkers orig_tc vect_tc arr_tc
-- def_worker receives, per constructor: the original data constructor, its
-- representation argument types, and the body produced by mk_data_con.
122 . zipWith3 def_worker (tyConDataCons orig_tc) rep_tys
123 $ zipWith4 mk_data_con (tyConDataCons vect_tc)
-- For the i-th constructor: the argument-type lists of all constructors
-- after it (presumably the 'post' argument of mk_data_con -- confirm).
126 (tail $ tails rep_tys)
-- Each element of bs is a pair, uncurried into hoistBinding's two arguments.
127 mapM_ (uncurry hoistBinding) bs
-- Type variables of the vectorised TyCon, as variables, as types, and as
-- Core type arguments.
129 tyvars = tyConTyVars vect_tc
130 var_tys = mkTyVarTys tyvars
131 ty_args = map Type var_tys
-- The vectorised TyCon applied to its own type variables.
132 res_ty = mkTyConApp vect_tc var_tys
134 cons = tyConDataCons vect_tc
-- Partial pattern: requires arr_tc to have exactly one data constructor.
136 [arr_dc] = tyConDataCons arr_tc
-- One list of representation argument types per constructor of vect_tc.
138 rep_tys = map dataConRepArgTys $ tyConDataCons vect_tc
-- Pair the plain constructor expression with its lifted counterpart; both
-- are computed in the monad and combined with liftM2 (,).
141 mk_data_con con tys pre post
142 = liftM2 (,) (vect_data_con con)
143 (lift_data_con tys pre post (mkDataConTag con))
-- Yields either a one-element list (the selReplicate builtin applied to
-- len and tag) or, in the otherwise case, the empty list.
145 sel_replicate len tag
147 rep <- builtin (selReplicate arity)
148 return [rep `mkApps` [len, tag]]
150 | otherwise = return []
-- The constructor applied to the type arguments only.
152 vect_data_con con = return $ mkConApp con ty_args
153 lift_data_con tys pre_tys post_tys tag
155 len <- builtin liftingContext
-- One fresh local named "xs" per entry of tys, typed by mkPDataType.
156 args <- mapM (newLocalVar (fsLit "xs"))
157 =<< mapM mkPDataType tys
159 sel <- sel_replicate (Var len) tag
-- emptyPD is run once for every argument type of the constructors before
-- (pre) and after (post) this one.
161 pre <- mapM emptyPD (concat pre_tys)
162 post <- mapM emptyPD (concat post_tys)
-- The result abstracts over len and the fresh args; the argument list is
-- type args, selector, pre, this constructor's args, post.
164 return . mkLams (len : args)
165 . wrapFamInstBody arr_tc var_tys
167 $ ty_args ++ sel ++ pre ++ map Var args ++ post
-- Build the worker for one data constructor; returns the new worker Id
-- together with its body.
169 def_worker data_con arg_tys mk_body
171 arity <- polyArity tyvars
174 . polyAbstract tyvars $ \args ->
175 liftM (mkLams (tyvars ++ args) . vectorised)
176 $ buildClosures tyvars [] arg_tys res_ty mk_body
-- Clone the original worker Id (mkVectOcc, with the type of body), then
-- attach body as its unfolding via mkInlineRule.
178 raw_worker <- cloneId mkVectOcc orig_worker (exprType body)
179 let vect_worker = raw_worker `setIdUnfolding`
180 mkInlineRule body (Just arity)
-- Associate the original worker with the new one (presumably in the global
-- variable map -- confirm against defGlobalVar).
181 defGlobalVar orig_worker vect_worker
182 return (vect_worker, body)
-- The worker Id of the original data constructor.
184 orig_worker = dataConWorkId data_con