1 {-# OPTIONS -fno-warn-missing-signatures #-}
3 module Vectorise.Type.Env (
10 import Vectorise.Monad
11 import Vectorise.Builtins
12 import Vectorise.Type.TyConDecl
13 import Vectorise.Type.Classify
14 import Vectorise.Type.PADict
15 import Vectorise.Type.PData
16 import Vectorise.Type.PRepr
17 import Vectorise.Type.Repr
18 import Vectorise.Utils.Closure
19 import Vectorise.Utils.Hoisting
-- | Module-local debug tracing. When 'debug' is set, the document @s@ is
--   printed with 'pprTrace' under the label @VectType@ before @x@ is
--   returned; otherwise @x@ is returned with no output.
dtrace s x
  | debug     = pprTrace "VectType" s x
  | otherwise = x
48 -- | Vectorise a type environment.
49 -- The type environment contains all the type things defined in a module.
-- Returns the vectorised environment, the new type family instances and the
-- new top-level bindings (see the result tuple below).
-- NOTE(review): this listing is missing several source lines, including the
-- head of the type signature and the function equation; the comments here
-- describe only the lines that are present.
52 -> VM ( TypeEnv -- Vectorised type environment.
53 , [FamInst] -- New type family instances.
54 , [(Var, CoreExpr)]) -- New top level bindings.
-- Build a map from each key's unique in global_tycons to a flag that is True
-- when the entry stored under that key has a different unique (see mk_map).
59 cs <- readGEnv $ mk_map . global_tycons
61 -- Split the list of TyCons into the ones we have to vectorise and the
62 -- ones we can pass through unchanged. We also pass through algebraic
63 -- types that use non-Haskell98 features, as we don't handle those.
64 let (conv_tcs, keep_tcs) = classifyTyCons cs groups
65 keep_dcs = concatMap tyConDataCons keep_tcs
-- Each pass-through TyCon / DataCon is handed to defTyCon / defDataCon paired
-- with itself (presumably recording "original -> vectorised" as the identity;
-- confirm in Vectorise.Monad).
67 zipWithM_ defTyCon keep_tcs keep_tcs
68 zipWithM_ defDataCon keep_dcs keep_dcs
-- Vectorise the declarations of the TyCons that need converting.
70 new_tcs <- vectTyConDecls conv_tcs
-- NOTE(review): orig_tcs is not filtered by isClassTyCon, while vect_tcs
-- (below) is. The two lists are zipped positionally further down
-- (zipWith3M, zipWith5), so if any class TyCons are present the pairs may
-- not line up. Verify against the full definition of vect_tcs.
72 let orig_tcs = keep_tcs ++ conv_tcs
74 -- We don't need to make new representation types for dictionary
75 -- constructors. The constructors are always fully applied, and we don't
76 -- need to lift them to arrays, because a dictionary of a particular type
77 -- always has the same value.
78 let vect_tcs = filter (not . isClassTyCon)
-- Tie the knot with fixV: the dfuns returned at the end of this block are fed
-- back in as dfuns'. The irrefutable pattern (~) keeps the tuple from being
-- demanded up front (zipLazy is presumably lazy in its second argument, too),
-- so dfuns' must not be forced before the block has produced it.
81 (_, binds, inst_tcs) <- fixV $ \ ~(dfuns', _, _) ->
83 defTyConPAs (zipLazy vect_tcs dfuns')
84 reprs <- mapM tyConRepr vect_tcs
85 repr_tcs <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
86 pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
-- One buildTyConBindings call per TyCon; its argument lists are on lines not
-- shown in this listing.
89 $ zipWith5 buildTyConBindings
-- repr_tcs ++ pdata_tcs is what the caller binds as inst_tcs.
97 return (dfuns, binds, repr_tcs ++ pdata_tcs)
99 let all_new_tcs = new_tcs ++ inst_tcs
-- Extend the input environment with every new TyCon and all of its DataCons.
101 let new_env = extendTypeEnvList env
102 (map ATyCon all_new_tcs
103 ++ [ADataCon dc | tc <- all_new_tcs
104 , dc <- tyConDataCons tc])
-- Each instance TyCon is also returned as a local family instance.
106 return (new_env, map mkLocalFamInst inst_tcs, binds)
-- All TyCons of the input environment, grouped with tyConGroups for
-- classifyTyCons above.
108 tycons = typeEnvTyCons env
109 groups = tyConGroups tycons
-- mk_map's parameter shadows the outer env; it is applied to global_tycons,
-- not to the type environment being vectorised.
111 mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
-- | Produce the bindings for a single vectorised type constructor.
--
--   First the workers of the data constructors are vectorised with
--   'vectDataConWorkers' (which receives @pdata_tc@ as its array tycon).
--   Then 'buildPADict' is run on @vect_tc@, @prepr_tc@, @pdata_tc@ and the
--   sum representation @repr@. The 'Var' it yields is the result.
buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc repr =
  vectDataConWorkers orig_tc vect_tc pdata_tc
    >> buildPADict vect_tc prepr_tc pdata_tc repr
-- | Define vectorised workers for the data constructors of a type constructor.
--
-- Each data constructor of @orig_tc@ is paired positionally with the
-- constructor terms built from @vect_tc@. For each pair, 'def_worker' creates
-- a new worker Id and records it with 'defGlobalVar' as the vectorised
-- counterpart of the original worker. It returns the Id together with its
-- body, and those pairs are hoisted with 'hoistBinding'.
-- @arr_tc@ is the array tycon ('buildTyConBindings' passes its PData tycon
-- here). It is expected to have exactly one data constructor.
-- NOTE(review): several lines of this definition are missing from this
-- listing; comments describe only the lines that are present.
121 vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
122 vectDataConWorkers orig_tc vect_tc arr_tc
124 . zipWith3 def_worker (tyConDataCons orig_tc) rep_tys
125 $ zipWith4 mk_data_con (tyConDataCons vect_tc)
-- For each constructor: the field types of every constructor after it.
-- ('tail' is safe here because 'tails' never returns an empty list.)
128 (tail $ tails rep_tys)
-- Hoist each (worker, body) pair produced by def_worker.
129 mapM_ (uncurry hoistBinding) bs
-- Type variables of the vectorised tycon, plus the types and Core type
-- arguments built from them.
131 tyvars = tyConTyVars vect_tc
132 var_tys = mkTyVarTys tyvars
133 ty_args = map Type var_tys
-- The vectorised tycon applied to its own type variables.
134 res_ty = mkTyConApp vect_tc var_tys
136 cons = tyConDataCons vect_tc
-- NOTE: this lazy pattern binding fails with a pattern-match error only when
-- arr_dc is demanded and arr_tc does not have exactly one data constructor.
138 [arr_dc] = tyConDataCons arr_tc
-- Representation argument types of each constructor of the vectorised tycon.
140 rep_tys = map dataConRepArgTys $ tyConDataCons vect_tc
-- Pair the plain constructor application with its lifted version. The lifted
-- version is given the field types of this constructor (tys) and of the
-- constructors before (pre) and after (post) it, plus this constructor's tag.
143 mk_data_con con tys pre post
144 = liftM2 (,) (vect_data_con con)
145 (lift_data_con tys pre post (mkDataConTag con))
-- Returns a one-element list when the guard on the missing line holds, and
-- an empty list otherwise. The element is the selReplicate builtin applied to
-- [len, tag]; 'arity' here is bound on a line not shown, not the one in
-- def_worker.
147 sel_replicate len tag
149 rep <- builtin (selReplicate arity)
150 return [rep `mkApps` [len, tag]]
152 | otherwise = return []
-- The constructor applied to the tycon's type arguments.
154 vect_data_con con = return $ mkConApp con ty_args
-- Builds a lambda over the lifting context (len) and one fresh "xs" variable
-- per field, each typed by mkPDataType. The body is wrapped with
-- wrapFamInstBody for arr_tc. It applies an expression on a line not shown to
-- these arguments, in order: the type arguments, the optional selector, empty
-- PData values for the other constructors' fields, and the new variables.
155 lift_data_con tys pre_tys post_tys tag
157 len <- builtin liftingContext
158 args <- mapM (newLocalVar (fsLit "xs"))
159 =<< mapM mkPDataType tys
161 sel <- sel_replicate (Var len) tag
-- Empty PData values for the fields of the constructors before/after this one.
163 pre <- mapM emptyPD (concat pre_tys)
164 post <- mapM emptyPD (concat post_tys)
166 return . mkLams (len : args)
167 . wrapFamInstBody arr_tc var_tys
169 $ ty_args ++ sel ++ pre ++ map Var args ++ post
-- Build the vectorised worker for one data constructor:
--   * Its body is built with buildClosures and abstracted over tyvars.
--   * It is bound to a clone of the original worker Id, renamed with
--     mkVectOcc and typed by the body.
--   * Its unfolding is an inline rule whose arity comes from polyArity.
-- The new worker is registered with defGlobalVar and returned with its body.
171 def_worker data_con arg_tys mk_body
173 arity <- polyArity tyvars
176 . polyAbstract tyvars $ \args ->
177 liftM (mkLams (tyvars ++ args) . vectorised)
178 $ buildClosures tyvars [] arg_tys res_ty mk_body
180 raw_worker <- cloneId mkVectOcc orig_worker (exprType body)
181 let vect_worker = raw_worker `setIdUnfolding`
182 mkInlineRule body (Just arity)
183 defGlobalVar orig_worker vect_worker
184 return (vect_worker, body)
-- The worker Id of the original data constructor.
186 orig_worker = dataConWorkId data_con