Separate length from data in DPH arrays
[ghc-hetmet.git] / compiler / vectorise / VectType.hs
1 module VectType ( vectTyCon, vectAndLiftType, vectType, vectTypeEnv,
2                   -- arrSumArity, pdataCompTys, pdataCompVars,
3                   buildPADict,
4                   fromVect )
5 where
6
7 import VectMonad
8 import VectUtils
9 import VectCore
10
11 import HscTypes          ( TypeEnv, extendTypeEnvList, typeEnvTyCons )
12 import CoreSyn
13 import CoreUtils
14 import MkCore            ( mkWildCase )
15 import BuildTyCl
16 import DataCon
17 import TyCon
18 import Type
19 import TypeRep
20 import Coercion
21 import FamInstEnv        ( FamInst, mkLocalFamInst )
22 import OccName
23 import MkId
24 import BasicTypes        ( StrictnessMark(..), boolToRecFlag )
25 import Var               ( Var, TyVar )
26 import Name              ( Name, getOccName )
27 import NameEnv
28 import TysWiredIn
29 import TysPrim           ( intPrimTy )
30
31 import Unique
32 import UniqFM
33 import UniqSet
34 import Util
35 import Digraph           ( SCC(..), stronglyConnCompFromEdgedVertices )
36
37 import Outputable
38 import FastString
39 import MonadUtils        ( mapAndUnzip3M )
40
41 import Control.Monad  ( liftM, liftM2, zipWithM, zipWithM_, mapAndUnzipM )
42 import Data.List      ( inits, tails, zipWith4, zipWith5 )
43
44 -- ----------------------------------------------------------------------------
45 -- Types
46
47 vectTyCon :: TyCon -> VM TyCon
48 vectTyCon tc
49   | isFunTyCon tc        = builtin closureTyCon
50   | isBoxedTupleTyCon tc = return tc
51   | isUnLiftedTyCon tc   = return tc
52   | otherwise            = maybeCantVectoriseM "Tycon not vectorised:" (ppr tc)
53                          $ lookupTyCon tc
54
55 vectAndLiftType :: Type -> VM (Type, Type)
56 vectAndLiftType ty | Just ty' <- coreView ty = vectAndLiftType ty'
57 vectAndLiftType ty
58   = do
59       mdicts   <- mapM paDictArgType tyvars
60       let dicts = [dict | Just dict <- mdicts]
61       vmono_ty <- vectType mono_ty
62       lmono_ty <- mkPDataType vmono_ty
63       return (abstractType tyvars dicts vmono_ty,
64               abstractType tyvars dicts lmono_ty)
65   where
66     (tyvars, mono_ty) = splitForAllTys ty
67
68
69 vectType :: Type -> VM Type
70 vectType ty | Just ty' <- coreView ty = vectType ty'
71 vectType (TyVarTy tv) = return $ TyVarTy tv
72 vectType (AppTy ty1 ty2) = liftM2 AppTy (vectType ty1) (vectType ty2)
73 vectType (TyConApp tc tys) = liftM2 TyConApp (vectTyCon tc) (mapM vectType tys)
74 vectType (FunTy ty1 ty2)   = liftM2 TyConApp (builtin closureTyCon)
75                                              (mapM vectAndBoxType [ty1,ty2])
76 vectType ty@(ForAllTy _ _)
77   = do
78       mdicts   <- mapM paDictArgType tyvars
79       mono_ty' <- vectType mono_ty
80       return $ abstractType tyvars [dict | Just dict <- mdicts] mono_ty'
81   where
82     (tyvars, mono_ty) = splitForAllTys ty
83
84 vectType ty = cantVectorise "Can't vectorise type" (ppr ty)
85
86 vectAndBoxType :: Type -> VM Type
87 vectAndBoxType ty = vectType ty >>= boxType
88
89 abstractType :: [TyVar] -> [Type] -> Type -> Type
90 abstractType tyvars dicts = mkForAllTys tyvars . mkFunTys dicts
91
92 -- ----------------------------------------------------------------------------
93 -- Boxing
94
95 boxType :: Type -> VM Type
96 boxType ty
97   | Just (tycon, []) <- splitTyConApp_maybe ty
98   , isUnLiftedTyCon tycon
99   = do
100       r <- lookupBoxedTyCon tycon
101       case r of
102         Just tycon' -> return $ mkTyConApp tycon' []
103         Nothing     -> return ty
104 boxType ty = return ty
105
106 -- ----------------------------------------------------------------------------
107 -- Type definitions
108
109 type TyConGroup = ([TyCon], UniqSet TyCon)
110
111 vectTypeEnv :: TypeEnv -> VM (TypeEnv, [FamInst], [(Var, CoreExpr)])
112 vectTypeEnv env
113   = do
114       cs <- readGEnv $ mk_map . global_tycons
115       let (conv_tcs, keep_tcs) = classifyTyCons cs groups
116           keep_dcs             = concatMap tyConDataCons keep_tcs
117       zipWithM_ defTyCon   keep_tcs keep_tcs
118       zipWithM_ defDataCon keep_dcs keep_dcs
119       new_tcs <- vectTyConDecls conv_tcs
120
121       let orig_tcs = keep_tcs ++ conv_tcs
122           vect_tcs = keep_tcs ++ new_tcs
123
124       repr_tcs  <- zipWithM buildPReprTyCon orig_tcs vect_tcs
125       pdata_tcs <- zipWithM buildPDataTyCon orig_tcs vect_tcs
126       dfuns     <- mapM mkPADFun vect_tcs
127       defTyConPAs (zip vect_tcs dfuns)
128       binds    <- sequence (zipWith5 buildTyConBindings orig_tcs
129                                                         vect_tcs
130                                                         repr_tcs
131                                                         pdata_tcs
132                                                         dfuns)
133
134       let all_new_tcs = new_tcs ++ repr_tcs ++ pdata_tcs
135
136       let new_env = extendTypeEnvList env
137                        (map ATyCon all_new_tcs
138                         ++ [ADataCon dc | tc <- all_new_tcs
139                                         , dc <- tyConDataCons tc])
140
141       return (new_env, map mkLocalFamInst (repr_tcs ++ pdata_tcs), concat binds)
142   where
143     tycons = typeEnvTyCons env
144     groups = tyConGroups tycons
145
146     mk_map env = listToUFM_Directly [(u, getUnique n /= u) | (u,n) <- nameEnvUniqueElts env]
147
148
149 vectTyConDecls :: [TyCon] -> VM [TyCon]
150 vectTyConDecls tcs = fixV $ \tcs' ->
151   do
152     mapM_ (uncurry defTyCon) (zipLazy tcs tcs')
153     mapM vectTyConDecl tcs
154
155 vectTyConDecl :: TyCon -> VM TyCon
156 vectTyConDecl tc
157   = do
158       name' <- cloneName mkVectTyConOcc name
159       rhs'  <- vectAlgTyConRhs tc (algTyConRhs tc)
160
161       liftDs $ buildAlgTyCon name'
162                              tyvars
163                              []           -- no stupid theta
164                              rhs'
165                              rec_flag     -- FIXME: is this ok?
166                              False        -- FIXME: no generics
167                              False        -- not GADT syntax
168                              Nothing      -- not a family instance
169   where
170     name   = tyConName tc
171     tyvars = tyConTyVars tc
172     rec_flag = boolToRecFlag (isRecursiveTyCon tc)
173
174 vectAlgTyConRhs :: TyCon -> AlgTyConRhs -> VM AlgTyConRhs
175 vectAlgTyConRhs _ (DataTyCon { data_cons = data_cons
176                              , is_enum   = is_enum
177                              })
178   = do
179       data_cons' <- mapM vectDataCon data_cons
180       zipWithM_ defDataCon data_cons data_cons'
181       return $ DataTyCon { data_cons = data_cons'
182                          , is_enum   = is_enum
183                          }
184 vectAlgTyConRhs tc _ = cantVectorise "Can't vectorise type definition:" (ppr tc)
185
186 vectDataCon :: DataCon -> VM DataCon
187 vectDataCon dc
188   | not . null $ dataConExTyVars dc
189         = cantVectorise "Can't vectorise constructor (existentials):" (ppr dc)
190   | not . null $ dataConEqSpec   dc
191         = cantVectorise "Can't vectorise constructor (eq spec):" (ppr dc)
192   | otherwise
193   = do
194       name'    <- cloneName mkVectDataConOcc name
195       tycon'   <- vectTyCon tycon
196       arg_tys  <- mapM vectType rep_arg_tys
197
198       liftDs $ buildDataCon name'
199                             False           -- not infix
200                             (map (const NotMarkedStrict) arg_tys)
201                             []              -- no labelled fields
202                             univ_tvs
203                             []              -- no existential tvs for now
204                             []              -- no eq spec for now
205                             []              -- no context
206                             arg_tys 
207                             (mkFamilyTyConApp tycon' (mkTyVarTys univ_tvs))
208                             tycon'
209   where
210     name        = dataConName dc
211     univ_tvs    = dataConUnivTyVars dc
212     rep_arg_tys = dataConRepArgTys dc
213     tycon       = dataConTyCon dc
214
215 mk_fam_inst :: TyCon -> TyCon -> (TyCon, [Type])
216 mk_fam_inst fam_tc arg_tc
217   = (fam_tc, [mkTyConApp arg_tc . mkTyVarTys $ tyConTyVars arg_tc])
218
219 buildPReprTyCon :: TyCon -> TyCon -> VM TyCon
220 buildPReprTyCon orig_tc vect_tc
221   = do
222       name     <- cloneName mkPReprTyConOcc (tyConName orig_tc)
223       rhs_ty   <- buildPReprType vect_tc
224       prepr_tc <- builtin preprTyCon
225       liftDs $ buildSynTyCon name
226                              tyvars
227                              (SynonymTyCon rhs_ty)
228                              (typeKind rhs_ty)
229                              (Just $ mk_fam_inst prepr_tc vect_tc)
230   where
231     tyvars = tyConTyVars vect_tc
232
233 buildPReprType :: TyCon -> VM Type
234 buildPReprType vect_tc = sum_type . map dataConRepArgTys $ tyConDataCons vect_tc
235   where
236     sum_type [] = voidType
237     sum_type [tys] = prod_type tys
238     sum_type tys  = do
239                       (sum_tc, _, _, args) <- reprSumTyCons vect_tc
240                       return $ mkTyConApp sum_tc args
241
242     prod_type []   = voidType
243     prod_type [ty] = return ty
244     prod_type tys  = do
245                        prod_tc <- builtin (prodTyCon (length tys))
246                        return $ mkTyConApp prod_tc tys
247
248 reprSumTyCons :: TyCon -> VM (TyCon, TyCon, Type, [Type])
249 reprSumTyCons vect_tc
250   = do
251       tc   <- builtin (sumTyCon arity)
252       args <- mapM (prod . dataConRepArgTys) cons
253       (pdata_tc, _) <- pdataReprTyCon (mkTyConApp tc args)
254       sel_ty <- builtin (selTy arity)
255       return (tc, pdata_tc, sel_ty, args)
256   where
257     cons = tyConDataCons vect_tc
258     arity = length cons
259
260     prod []   = voidType
261     prod [ty] = return ty
262     prod tys  = do
263                   prod_tc <- builtin (prodTyCon (length tys))
264                   return $ mkTyConApp prod_tc tys
265
266 buildToPRepr :: TyCon -> TyCon -> TyCon -> VM CoreExpr
267 buildToPRepr vect_tc repr_tc _
268   = do
269       let arg_ty = mkTyConApp vect_tc ty_args
270       res_ty <- mkPReprType arg_ty
271       arg    <- newLocalVar (fsLit "x") arg_ty
272       result <- to_sum (Var arg) arg_ty res_ty (tyConDataCons vect_tc)
273       return $ Lam arg result
274   where
275     ty_args = mkTyVarTys (tyConTyVars vect_tc)
276
277     wrap = wrapFamInstBody repr_tc ty_args
278
279     to_sum arg arg_ty res_ty []
280       = do
281           void <- builtin voidVar
282           return $ wrap (Var void)
283
284     to_sum arg arg_ty res_ty [con]
285       = do
286           (prod, vars) <- to_prod (dataConRepArgTys con)
287           return $ mkWildCase arg arg_ty res_ty
288                    [(DataAlt con, vars, wrap prod)]
289
290     to_sum arg arg_ty res_ty cons
291       = do
292           (prods, vars) <- mapAndUnzipM (to_prod . dataConRepArgTys) cons
293           (sum_tc, _, _, sum_ty_args) <- reprSumTyCons vect_tc
294           let sum_cons = [mkConApp con (map Type sum_ty_args)
295                             | con <- tyConDataCons sum_tc]
296           return . mkWildCase arg arg_ty res_ty
297                  $ zipWith4 mk_alt cons vars sum_cons prods
298       where
299         arity = length cons
300
301         mk_alt con vars sum_con expr
302           = (DataAlt con, vars, wrap $ sum_con `App` expr)
303
304     to_prod []
305       = do
306           void <- builtin voidVar
307           return (Var void, [])
308     to_prod [ty]
309       = do
310           var <- newLocalVar (fsLit "x") ty
311           return (Var var, [var])
312     to_prod tys
313       = do
314           prod_con <- builtin (prodDataCon (length tys))
315           vars <- newLocalVars (fsLit "x") tys
316           return (mkConApp prod_con (map Type tys ++ map Var vars), vars)
317       where
318         arity = length tys
319
320
321 buildFromPRepr :: TyCon -> TyCon -> TyCon -> VM CoreExpr
322 buildFromPRepr vect_tc repr_tc _
323   = do
324       arg_ty <- mkPReprType res_ty
325       arg <- newLocalVar (fsLit "x") arg_ty
326
327       result <- from_sum (unwrapFamInstScrut repr_tc ty_args (Var arg))
328                          (tyConDataCons vect_tc)
329       return $ Lam arg result
330   where
331     ty_args = mkTyVarTys (tyConTyVars vect_tc)
332     res_ty  = mkTyConApp vect_tc ty_args
333
334     from_sum expr []    = pprPanic "buildFromPRepr" (ppr vect_tc)
335     from_sum expr [con] = from_prod expr con
336     from_sum expr cons
337       = do
338           (sum_tc, _, _, sum_ty_args) <- reprSumTyCons vect_tc
339           let sum_cons = tyConDataCons sum_tc
340           vars <- newLocalVars (fsLit "x") sum_ty_args
341           prods <- zipWithM from_prod (map Var vars) cons
342           return . mkWildCase expr (exprType expr) res_ty
343                  $ zipWith3 mk_alt sum_cons vars prods
344       where
345         arity = length cons
346
347         mk_alt con var expr = (DataAlt con, [var], expr)
348
349     from_prod expr con
350       = case dataConRepArgTys con of
351           []   -> return $ apply_con []
352           [ty] -> return $ apply_con [expr]
353           tys  -> do
354                     prod_con <- builtin (prodDataCon (length tys))
355                     vars <- newLocalVars (fsLit "y") tys
356                     return $ mkWildCase expr (exprType expr) res_ty
357                              [(DataAlt prod_con, vars, apply_con (map Var vars))]
358       where
359         apply_con exprs = mkConApp con (map Type ty_args) `mkApps` exprs
360
361 buildToArrPRepr :: TyCon -> TyCon -> TyCon -> VM CoreExpr
362 buildToArrPRepr vect_tc prepr_tc pdata_tc
363   = do
364       arg_ty <- mkPDataType el_ty
365       res_ty <- mkPDataType =<< mkPReprType el_ty
366       arg    <- newLocalVar (fsLit "xs") arg_ty
367
368       pdata_co <- mkBuiltinCo pdataTyCon
369       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
370           co           = mkAppCoercion pdata_co
371                        . mkSymCoercion
372                        $ mkTyConApp repr_co ty_args
373
374           scrut   = unwrapFamInstScrut pdata_tc ty_args (Var arg)
375
376       (vars, result) <- to_sum (tyConDataCons vect_tc)
377
378       return . Lam arg
379              $ mkWildCase scrut (mkTyConApp pdata_tc ty_args) res_ty
380                [(DataAlt pdata_dc, vars, mkCoerce co result)]
381   where
382     ty_args = mkTyVarTys $ tyConTyVars vect_tc
383     el_ty   = mkTyConApp vect_tc ty_args
384
385     [pdata_dc] = tyConDataCons pdata_tc
386
387     to_sum []    = do
388                      pvoid <- builtin pvoidVar
389                      return ([], Var pvoid)
390     to_sum [con] = to_prod con
391     to_sum cons  = do
392                      (vars, exprs) <- mapAndUnzipM to_prod cons
393                      (_, pdata_tc, sel_ty, arg_tys) <- reprSumTyCons vect_tc
394                      sel <- newLocalVar (fsLit "sel") sel_ty
395                      let [pdata_con] = tyConDataCons pdata_tc
396                          result = wrapFamInstBody pdata_tc arg_tys
397                                 . mkConApp pdata_con
398                                 $ map Type arg_tys ++ (Var sel : exprs)
399                      return (sel : concat vars, result)
400
401     to_prod con
402       | [] <- tys = do
403                       pvoid <- builtin pvoidVar
404                       return ([], Var pvoid)
405       | [ty] <- tys = do
406                         var <- newLocalVar (fsLit "x") ty
407                         return ([var], Var var)
408       | otherwise
409         = do
410             vars <- newLocalVars (fsLit "x") tys
411             prod_tc <- builtin (prodTyCon (length tys))
412             (pdata_prod_tc, _) <- pdataReprTyCon (mkTyConApp prod_tc tys)
413             let [pdata_prod_con] = tyConDataCons pdata_prod_tc
414                 result = wrapFamInstBody pdata_prod_tc tys
415                        . mkConApp pdata_prod_con
416                        $ map Type tys ++ map Var vars
417             return (vars, result)
418       where
419         tys = dataConRepArgTys con
420
421 buildFromArrPRepr :: TyCon -> TyCon -> TyCon -> VM CoreExpr
422 buildFromArrPRepr vect_tc prepr_tc pdata_tc
423   = do
424       arg_ty <- mkPDataType =<< mkPReprType el_ty
425       res_ty <- mkPDataType el_ty
426       arg    <- newLocalVar (fsLit "xs") arg_ty
427
428       pdata_co <- mkBuiltinCo pdataTyCon
429       let Just repr_co = tyConFamilyCoercion_maybe prepr_tc
430           co           = mkAppCoercion pdata_co
431                        $ mkTyConApp repr_co var_tys
432
433           scrut  = mkCoerce co (Var arg)
434
435       (args, mk) <- from_sum res_ty scrut (tyConDataCons vect_tc)
436       
437       let result = wrapFamInstBody pdata_tc var_tys
438                  . mkConApp pdata_dc
439                  $ map Type var_tys ++ args
440
441       return $ Lam arg (mk result)
442   where
443     var_tys = mkTyVarTys $ tyConTyVars vect_tc
444     el_ty   = mkTyConApp vect_tc var_tys
445
446     [pdata_dc] = tyConDataCons pdata_tc
447
448     from_sum res_ty expr [] = return ([], mk)
449       where
450         mk body = mkWildCase expr (exprType expr) res_ty [(DEFAULT, [], body)]
451     from_sum res_ty expr [con] = from_prod res_ty expr con
452     from_sum res_ty expr cons
453       = do
454           (_, pdata_tc, sel_ty, arg_tys) <- reprSumTyCons vect_tc
455           prod_tys <- mapM mkPDataType arg_tys
456           sel  <- newLocalVar (fsLit "sel") sel_ty
457           vars <- newLocalVars (fsLit "xs") arg_tys
458           rs   <- zipWithM (from_prod res_ty) (map Var vars) cons
459           let (prods, mks) = unzip rs
460               [pdata_con]  = tyConDataCons pdata_tc
461               scrut        = unwrapFamInstScrut pdata_tc arg_tys expr
462
463               mk body = mkWildCase scrut (exprType scrut) res_ty
464                         [(DataAlt pdata_con, sel : vars, foldr ($) body mks)]
465           return (Var sel : concat prods, mk)
466
467
468     from_prod res_ty expr con
469       | []   <- tys = return ([], id)
470       | [ty] <- tys = return ([expr], id)
471       | otherwise
472         = do
473             prod_tc <- builtin (prodTyCon (length tys))
474             (pdata_tc, _) <- pdataReprTyCon (mkTyConApp prod_tc tys)
475             pdata_tys <- mapM mkPDataType tys
476             vars <- newLocalVars (fsLit "ys") pdata_tys
477             let [pdata_con] = tyConDataCons pdata_tc
478                 scrut       = unwrapFamInstScrut pdata_tc tys expr
479
480                 mk body = mkWildCase scrut (exprType scrut) res_ty
481                           [(DataAlt pdata_con, vars, body)]
482
483             return (map Var vars, mk)
484       where
485         tys = dataConRepArgTys con
486
487 buildPRDict :: TyCon -> TyCon -> TyCon -> VM CoreExpr
488 buildPRDict vect_tc prepr_tc _
489   = do
490       dict <- sum_dict (tyConDataCons vect_tc)
491       pr_co <- mkBuiltinCo prTyCon
492       let co = mkAppCoercion pr_co
493              . mkSymCoercion
494              $ mkTyConApp arg_co ty_args
495       return (mkCoerce co dict)
496   where
497     ty_args = mkTyVarTys (tyConTyVars vect_tc)
498     Just arg_co = tyConFamilyCoercion_maybe prepr_tc
499
500     sum_dict []    = prDFunOfTyCon =<< builtin voidTyCon
501     sum_dict [con] = prod_dict con
502     sum_dict cons  = do
503                        dicts <- mapM prod_dict cons
504                        (sum_tc, _, _, sum_ty_args) <- reprSumTyCons vect_tc
505                        dfun <- prDFunOfTyCon sum_tc
506                        return $ dfun `mkTyApps` sum_ty_args `mkApps` dicts
507
508     prod_dict con
509       | []   <- tys = prDFunOfTyCon =<< builtin voidTyCon
510       | [ty] <- tys = mkPR ty
511       | otherwise   = do
512                         dicts <- mapM mkPR tys
513                         prod_tc <- builtin (prodTyCon (length tys))
514                         dfun <- prDFunOfTyCon prod_tc
515                         return $ dfun `mkTyApps` tys `mkApps` dicts
516       where
517         tys = dataConRepArgTys con
518
519 buildPDataTyCon :: TyCon -> TyCon -> VM TyCon
520 buildPDataTyCon orig_tc vect_tc = fixV $ \repr_tc ->
521   do
522     name' <- cloneName mkPDataTyConOcc orig_name
523     rhs   <- buildPDataTyConRhs orig_name vect_tc repr_tc
524     pdata <- builtin pdataTyCon
525
526     liftDs $ buildAlgTyCon name'
527                            tyvars
528                            []          -- no stupid theta
529                            rhs
530                            rec_flag    -- FIXME: is this ok?
531                            False       -- FIXME: no generics
532                            False       -- not GADT syntax
533                            (Just $ mk_fam_inst pdata vect_tc)
534   where
535     orig_name = tyConName orig_tc
536     tyvars = tyConTyVars vect_tc
537     rec_flag = boolToRecFlag (isRecursiveTyCon vect_tc)
538
539
540 buildPDataTyConRhs :: Name -> TyCon -> TyCon -> VM AlgTyConRhs
541 buildPDataTyConRhs orig_name vect_tc repr_tc
542   = do
543       data_con <- buildPDataDataCon orig_name vect_tc repr_tc
544       return $ DataTyCon { data_cons = [data_con], is_enum = False }
545
546 buildPDataDataCon :: Name -> TyCon -> TyCon -> VM DataCon
547 buildPDataDataCon orig_name vect_tc repr_tc
548   = do
549       dc_name  <- cloneName mkPDataDataConOcc orig_name
550       comp_tys <- components
551
552       liftDs $ buildDataCon dc_name
553                             False                  -- not infix
554                             (map (const NotMarkedStrict) comp_tys)
555                             []                     -- no field labels
556                             tvs
557                             []                     -- no existentials
558                             []                     -- no eq spec
559                             []                     -- no context
560                             comp_tys
561                             (mkFamilyTyConApp repr_tc (mkTyVarTys tvs))
562                             repr_tc
563   where
564     tvs   = tyConTyVars vect_tc
565     cons  = tyConDataCons vect_tc
566     arity = length cons
567
568     components
569       | arity > 1 = liftM2 (:) (builtin (selTy arity)) data_components
570       | otherwise = data_components
571
572     data_components = mapM mkPDataType
573                     . concat
574                     $ map dataConRepArgTys cons
575
576 mkPADFun :: TyCon -> VM Var
577 mkPADFun vect_tc
578   = newExportedVar (mkPADFunOcc $ getOccName vect_tc) =<< paDFunType vect_tc
579
580 buildTyConBindings :: TyCon -> TyCon -> TyCon -> TyCon -> Var
581                    -> VM [(Var, CoreExpr)]
582 buildTyConBindings orig_tc vect_tc prepr_tc pdata_tc dfun
583   = do
584       vectDataConWorkers orig_tc vect_tc pdata_tc
585       dict <- buildPADict vect_tc prepr_tc pdata_tc dfun
586       binds <- takeHoisted
587       return $ (dfun, dict) : binds
588
589 vectDataConWorkers :: TyCon -> TyCon -> TyCon -> VM ()
590 vectDataConWorkers orig_tc vect_tc arr_tc
591   = do
592       bs <- sequence
593           . zipWith3 def_worker  (tyConDataCons orig_tc) rep_tys
594           $ zipWith4 mk_data_con (tyConDataCons vect_tc)
595                                  rep_tys
596                                  (inits rep_tys)
597                                  (tail $ tails rep_tys)
598       mapM_ (uncurry hoistBinding) bs
599   where
600     tyvars   = tyConTyVars vect_tc
601     var_tys  = mkTyVarTys tyvars
602     ty_args  = map Type var_tys
603     res_ty   = mkTyConApp vect_tc var_tys
604
605     cons     = tyConDataCons vect_tc
606     arity    = length cons
607     [arr_dc] = tyConDataCons arr_tc
608
609     rep_tys  = map dataConRepArgTys $ tyConDataCons vect_tc
610
611
612     mk_data_con con tys pre post
613       = liftM2 (,) (vect_data_con con)
614                    (lift_data_con tys pre post (mkDataConTag con))
615
616     sel_replicate len tag
617       | arity > 1 = do
618                       rep <- builtin (selReplicate arity)
619                       return [rep `mkApps` [len, tag]]
620
621       | otherwise = return []
622
623     vect_data_con con = return $ mkConApp con ty_args
624     lift_data_con tys pre_tys post_tys tag
625       = do
626           len  <- builtin liftingContext
627           args <- mapM (newLocalVar (fsLit "xs"))
628                   =<< mapM mkPDataType tys
629
630           sel  <- sel_replicate (Var len) tag
631
632           pre   <- mapM emptyPD (concat pre_tys)
633           post  <- mapM emptyPD (concat post_tys)
634
635           return . mkLams (len : args)
636                  . wrapFamInstBody arr_tc var_tys
637                  . mkConApp arr_dc
638                  $ ty_args ++ sel ++ pre ++ map Var args ++ post
639
640     def_worker data_con arg_tys mk_body
641       = do
642           body <- closedV
643                 . inBind orig_worker
644                 . polyAbstract tyvars $ \abstract ->
645                   liftM (abstract . vectorised)
646                 $ buildClosures tyvars [] arg_tys res_ty mk_body
647
648           vect_worker <- cloneId mkVectOcc orig_worker (exprType body)
649           defGlobalVar orig_worker vect_worker
650           return (vect_worker, body)
651       where
652         orig_worker = dataConWorkId data_con
653
654 buildPADict :: TyCon -> TyCon -> TyCon -> Var -> VM CoreExpr
655 buildPADict vect_tc prepr_tc arr_tc _
656   = polyAbstract tvs $ \abstract ->
657     do
658       meth_binds <- mapM mk_method paMethods
659       let meth_exprs = map (Var . fst) meth_binds
660
661       pa_dc <- builtin paDataCon
662       let dict = mkConApp pa_dc (Type (mkTyConApp vect_tc arg_tys) : meth_exprs)
663           body = Let (Rec meth_binds) dict
664       return . mkInlineMe $ abstract body
665   where
666     tvs = tyConTyVars arr_tc
667     arg_tys = mkTyVarTys tvs
668
669     mk_method (name, build)
670       = localV
671       $ do
672           body <- build vect_tc prepr_tc arr_tc
673           var  <- newLocalVar name (exprType body)
674           return (var, mkInlineMe body)
675
676 paMethods :: [(FastString, TyCon -> TyCon -> TyCon -> VM CoreExpr)]
677 paMethods = [(fsLit "toPRepr",      buildToPRepr),
678              (fsLit "fromPRepr",    buildFromPRepr),
679              (fsLit "toArrPRepr",   buildToArrPRepr),
680              (fsLit "fromArrPRepr", buildFromArrPRepr),
681              (fsLit "dictPRepr",    buildPRDict)]
682
683 -- | Split the given tycons into two sets depending on whether they have to be
684 -- converted (first list) or not (second list). The first argument contains
685 -- information about the conversion status of external tycons:
686 --
687 --   * tycons which have converted versions are mapped to True
688 --   * tycons which are not changed by vectorisation are mapped to False
689 --   * tycons which can't be converted are not elements of the map
690 --
691 classifyTyCons :: UniqFM Bool -> [TyConGroup] -> ([TyCon], [TyCon])
692 classifyTyCons = classify [] []
693   where
694     classify conv keep _  [] = (conv, keep)
695     classify conv keep cs ((tcs, ds) : rs)
696       | can_convert && must_convert
697         = classify (tcs ++ conv) keep (cs `addListToUFM` [(tc,True) | tc <- tcs]) rs
698       | can_convert
699         = classify conv (tcs ++ keep) (cs `addListToUFM` [(tc,False) | tc <- tcs]) rs
700       | otherwise
701         = classify conv keep cs rs
702       where
703         refs = ds `delListFromUniqSet` tcs
704
705         can_convert  = isNullUFM (refs `minusUFM` cs) && all convertable tcs
706         must_convert = foldUFM (||) False (intersectUFM_C const cs refs)
707
708         convertable tc = isDataTyCon tc && all isVanillaDataCon (tyConDataCons tc)
709
710 -- | Compute mutually recursive groups of tycons in topological order
711 --
712 tyConGroups :: [TyCon] -> [TyConGroup]
713 tyConGroups tcs = map mk_grp (stronglyConnCompFromEdgedVertices edges)
714   where
715     edges = [((tc, ds), tc, uniqSetToList ds) | tc <- tcs
716                                 , let ds = tyConsOfTyCon tc]
717
718     mk_grp (AcyclicSCC (tc, ds)) = ([tc], ds)
719     mk_grp (CyclicSCC els)       = (tcs, unionManyUniqSets dss)
720       where
721         (tcs, dss) = unzip els
722
723 tyConsOfTyCon :: TyCon -> UniqSet TyCon
724 tyConsOfTyCon
725   = tyConsOfTypes . concatMap dataConRepArgTys . tyConDataCons
726
727 tyConsOfType :: Type -> UniqSet TyCon
728 tyConsOfType ty
729   | Just ty' <- coreView ty    = tyConsOfType ty'
730 tyConsOfType (TyVarTy _)       = emptyUniqSet
731 tyConsOfType (TyConApp tc tys) = extend (tyConsOfTypes tys)
732   where
733     extend | isUnLiftedTyCon tc
734            || isTupleTyCon   tc = id
735
736            | otherwise          = (`addOneToUniqSet` tc)
737
738 tyConsOfType (AppTy a b)       = tyConsOfType a `unionUniqSets` tyConsOfType b
739 tyConsOfType (FunTy a b)       = (tyConsOfType a `unionUniqSets` tyConsOfType b)
740                                  `addOneToUniqSet` funTyCon
741 tyConsOfType (ForAllTy _ ty)   = tyConsOfType ty
742 tyConsOfType other             = pprPanic "ClosureConv.tyConsOfType" $ ppr other
743
744 tyConsOfTypes :: [Type] -> UniqSet TyCon
745 tyConsOfTypes = unionManyUniqSets . map tyConsOfType
746
747
748 -- ----------------------------------------------------------------------------
749 -- Conversions
750
751 fromVect :: Type -> CoreExpr -> VM CoreExpr
752 fromVect ty expr | Just ty' <- coreView ty = fromVect ty' expr
753 fromVect (FunTy arg_ty res_ty) expr
754   = do
755       arg     <- newLocalVar (fsLit "x") arg_ty
756       varg    <- toVect arg_ty (Var arg)
757       varg_ty <- vectType arg_ty
758       vres_ty <- vectType res_ty
759       apply   <- builtin applyVar
760       body    <- fromVect res_ty
761                $ Var apply `mkTyApps` [varg_ty, vres_ty] `mkApps` [expr, varg]
762       return $ Lam arg body
763 fromVect ty expr
764   = identityConv ty >> return expr
765
766 toVect :: Type -> CoreExpr -> VM CoreExpr
767 toVect ty expr = identityConv ty >> return expr
768
769 identityConv :: Type -> VM ()
770 identityConv ty | Just ty' <- coreView ty = identityConv ty'
771 identityConv (TyConApp tycon tys)
772   = do
773       mapM_ identityConv tys
774       identityConvTyCon tycon
775 identityConv _ = noV
776
777 identityConvTyCon :: TyCon -> VM ()
778 identityConvTyCon tc
779   | isBoxedTupleTyCon tc = return ()
780   | isUnLiftedTyCon tc   = return ()
781   | otherwise            = do
782                              tc' <- maybeV (lookupTyCon tc)
783                              if tc == tc' then return () else noV
784