Minor refactoring of placeHolderPunRhs
[ghc-hetmet.git] / compiler / vectorise / VectBuiltIn.hs
index 16b23ab..11538d5 100644 (file)
@@ -1,6 +1,6 @@
 module VectBuiltIn (
   Builtins(..), sumTyCon, prodTyCon, prodDataCon,
-  selTy, selReplicate, selPick, selElements,
+  selTy, selReplicate, selPick, selTags, selElements,
   combinePDVar, scalarZip, closureCtrFun,
   initBuiltins, initBuiltinVars, initBuiltinTyCons, initBuiltinDataCons,
   initBuiltinPAs, initBuiltinPRs,
@@ -11,11 +11,12 @@ module VectBuiltIn (
 
 import DsMonad
 import IfaceEnv        ( lookupOrig )
+import InstEnv
 
 import Module
 import DataCon         ( DataCon, dataConName, dataConWorkId )
 import TyCon           ( TyCon, tyConName, tyConDataCons )
-import Class           ( Class )
+import Class           ( Class, classTyCon )
 import CoreSyn         ( CoreExpr, Expr(..) )
 import Var             ( Var )
 import Id              ( mkSysLocal )
@@ -26,11 +27,11 @@ import OccName
 import TypeRep         ( funTyCon )
 import Type            ( Type, mkTyConApp )
 import TysPrim
-import TysWiredIn      ( unitTyCon, unitDataCon,
+import TysWiredIn      ( unitDataCon,
                          tupleTyCon, tupleCon,
-                         intTyCon, intTyConName,
-                         doubleTyCon, doubleTyConName,
-                         boolTyCon, boolTyConName, trueDataCon, falseDataCon,
+                         intTyCon,
+                         doubleTyCon,
+                         boolTyCon, trueDataCon, falseDataCon,
                          parrTyConName )
 import PrelNames       ( word8TyConName, gHC_PARR )
 import BasicTypes      ( Boxity(..) )
@@ -92,6 +93,8 @@ dph_Modules pkg = Modules {
   where
     mk = mkModule pkg . mkModuleNameFS
 
+dph_Orphans :: [Modules -> Module]
+dph_Orphans = [dph_Repr, dph_Instances]
 
 data Builtins = Builtins {
                   dphModules       :: Modules
@@ -108,20 +111,21 @@ data Builtins = Builtins {
                 , selTys           :: Array Int Type
                 , selReplicates    :: Array Int CoreExpr
                 , selPicks         :: Array Int CoreExpr
+                , selTagss         :: Array Int CoreExpr
                 , selEls           :: Array (Int, Int) CoreExpr
                 , sumTyCons        :: Array Int TyCon
                 , closureTyCon     :: TyCon
                 , voidVar          :: Var
                 , pvoidVar         :: Var
+                , fromVoidVar      :: Var
                 , punitVar         :: Var
-                , mkPRVar          :: Var
                 , closureVar       :: Var
                 , applyVar         :: Var
                 , liftedClosureVar :: Var
                 , liftedApplyVar   :: Var
                 , replicatePDVar   :: Var
                 , emptyPDVar       :: Var
-                , packPDVar        :: Var
+                , packByTagPDVar   :: Var
                 , combinePDVars    :: Array Int Var
                 , scalarClass      :: Class
                 , scalarZips       :: Array Int Var
@@ -146,6 +150,9 @@ selReplicate = indexBuiltin "selReplicate" selReplicates
 selPick :: Int -> Builtins -> CoreExpr
 selPick = indexBuiltin "selPick" selPicks
 
+selTags :: Int -> Builtins -> CoreExpr
+selTags = indexBuiltin "selTags" selTagss
+
 selElements :: Int -> Int -> Builtins -> CoreExpr
 selElements i j = indexBuiltin "selElements" selEls (i,j)
 
@@ -153,14 +160,14 @@ sumTyCon :: Int -> Builtins -> TyCon
 sumTyCon = indexBuiltin "sumTyCon" sumTyCons
 
 prodTyCon :: Int -> Builtins -> TyCon
-prodTyCon n bi
-  | n == 1                      = wrapTyCon bi
-  | n >= 0 && n <= mAX_DPH_PROD = tupleTyCon Boxed n
+prodTyCon n _
+  | n >= 2 && n <= mAX_DPH_PROD = tupleTyCon Boxed n
   | otherwise = pprPanic "prodTyCon" (ppr n)
 
 prodDataCon :: Int -> Builtins -> DataCon
 prodDataCon n bi = case tyConDataCons (prodTyCon n bi) of
                      [con] -> con
+                     _     -> pprPanic "prodDataCon" (ppr n)
 
 combinePDVar :: Int -> Builtins -> Var
 combinePDVar = indexBuiltin "combinePDVar" combinePDVars
@@ -174,13 +181,14 @@ closureCtrFun = indexBuiltin "closureCtrFun" closureCtrFuns
 initBuiltins :: PackageId -> DsM Builtins
 initBuiltins pkg
   = do
+      mapM_ load dph_Orphans
       parrayTyCon  <- externalTyCon dph_PArray (fsLit "PArray")
       let [parrayDataCon] = tyConDataCons parrayTyCon
       pdataTyCon   <- externalTyCon dph_PArray (fsLit "PData")
-      paTyCon      <- externalTyCon dph_PArray (fsLit "PA")
+      paTyCon      <- externalClassTyCon dph_PArray (fsLit "PA")
       let [paDataCon] = tyConDataCons paTyCon
       preprTyCon   <- externalTyCon dph_PArray (fsLit "PRepr")
-      prTyCon      <- externalTyCon dph_PArray (fsLit "PR")
+      prTyCon      <- externalClassTyCon dph_PArray (fsLit "PR")
       let [prDataCon] = tyConDataCons prTyCon
       closureTyCon <- externalTyCon dph_Closure (fsLit ":->")
 
@@ -192,6 +200,8 @@ initBuiltins pkg
                              (numbered "replicate" 2 mAX_DPH_SUM)
       sel_picks    <- mapM (externalFun dph_Selector)
                            (numbered "pick" 2 mAX_DPH_SUM)
+      sel_tags     <- mapM (externalFun dph_Selector)
+                           (numbered "tagsSel" 2 mAX_DPH_SUM)
       sel_els      <- mapM mk_elements
                            [(i,j) | i <- [2..mAX_DPH_SUM], j <- [0..i-1]]
       sum_tcs      <- mapM (externalTyCon dph_Repr)
@@ -200,27 +210,28 @@ initBuiltins pkg
       let selTys        = listArray (2, mAX_DPH_SUM) sel_tys
           selReplicates = listArray (2, mAX_DPH_SUM) sel_replicates
           selPicks      = listArray (2, mAX_DPH_SUM) sel_picks
+          selTagss      = listArray (2, mAX_DPH_SUM) sel_tags
           selEls        = array ((2,0), (mAX_DPH_SUM, mAX_DPH_SUM)) sel_els
           sumTyCons     = listArray (2, mAX_DPH_SUM) sum_tcs
 
       voidVar          <- externalVar dph_Repr (fsLit "void")
       pvoidVar         <- externalVar dph_Repr (fsLit "pvoid")
+      fromVoidVar      <- externalVar dph_Repr (fsLit "fromVoid")
       punitVar         <- externalVar dph_Repr (fsLit "punit")
-      mkPRVar          <- externalVar dph_PArray (fsLit "mkPR")
       closureVar       <- externalVar dph_Closure (fsLit "closure")
       applyVar         <- externalVar dph_Closure (fsLit "$:")
       liftedClosureVar <- externalVar dph_Closure (fsLit "liftedClosure")
       liftedApplyVar   <- externalVar dph_Closure (fsLit "liftedApply")
       replicatePDVar   <- externalVar dph_PArray (fsLit "replicatePD")
       emptyPDVar       <- externalVar dph_PArray (fsLit "emptyPD")
-      packPDVar        <- externalVar dph_PArray (fsLit "packPD")
+      packByTagPDVar   <- externalVar dph_PArray (fsLit "packByTagPD")
 
       combines <- mapM (externalVar dph_PArray)
                        [mkFastString ("combine" ++ show i ++ "PD")
                           | i <- [2..mAX_DPH_COMBINE]]
       let combinePDVars = listArray (2, mAX_DPH_COMBINE) combines
 
-      scalarClass <- externalClass dph_Scalar (fsLit "Scalar")
+      scalarClass <- externalClass dph_PArray (fsLit "Scalar")
       scalar_map <- externalVar dph_Scalar (fsLit "scalar_map")
       scalar_zip2 <- externalVar dph_Scalar (fsLit "scalar_zipWith")
       scalar_zips <- mapM (externalVar dph_Scalar)
@@ -249,20 +260,21 @@ initBuiltins pkg
                , selTys           = selTys
                , selReplicates    = selReplicates
                , selPicks         = selPicks
+               , selTagss         = selTagss
                , selEls           = selEls
                , sumTyCons        = sumTyCons
                , closureTyCon     = closureTyCon
                , voidVar          = voidVar
                , pvoidVar         = pvoidVar
+               , fromVoidVar      = fromVoidVar
                , punitVar         = punitVar
-               , mkPRVar          = mkPRVar
                , closureVar       = closureVar
                , applyVar         = applyVar
                , liftedClosureVar = liftedClosureVar
                , liftedApplyVar   = liftedApplyVar
                , replicatePDVar   = replicatePDVar
                , emptyPDVar       = emptyPDVar
-               , packPDVar        = packPDVar
+               , packByTagPDVar   = packByTagPDVar
                , combinePDVars    = combinePDVars
                , scalarClass      = scalarClass
                , scalarZips       = scalarZips
@@ -275,11 +287,15 @@ initBuiltins pkg
              , dph_Repr           = dph_Repr
              , dph_Closure        = dph_Closure
              , dph_Selector       = dph_Selector
-             , dph_Unboxed        = dph_Unboxed
              , dph_Scalar         = dph_Scalar
              })
       = dph_Modules pkg
 
+    load get_mod = dsLoadModule doc mod
+      where
+        mod = get_mod modules 
+        doc = ppr mod <+> ptext (sLit "is a DPH module")
+
     numbered :: String -> Int -> Int -> [FastString]
     numbered pfx m n = [mkFastString (pfx ++ show i) | i <- [m..n]]
 
@@ -454,66 +470,19 @@ initBuiltinDataCons _ = [(dataConName dc, dc)| dc <- defaultDataCons]
 defaultDataCons :: [DataCon]
 defaultDataCons = [trueDataCon, falseDataCon, unitDataCon]
 
-initBuiltinDicts :: [(Name, Module, FastString)] -> DsM [(Name, Var)]
-initBuiltinDicts ps
-  = do
-      dicts <- zipWithM externalVar mods fss
-      return $ zip tcs dicts
-  where
-    (tcs, mods, fss) = unzip3 ps
+initBuiltinPAs :: Builtins -> (InstEnv, InstEnv) -> DsM [(Name, Var)]
+initBuiltinPAs (Builtins { dphModules = mods }) insts
+  = liftM (initBuiltinDicts insts) (externalClass (dph_PArray mods) (fsLit "PA"))
 
-initBuiltinPAs :: Builtins -> DsM [(Name, Var)]
-initBuiltinPAs = initBuiltinDicts . builtinPAs
+initBuiltinPRs :: Builtins -> (InstEnv, InstEnv) -> DsM [(Name, Var)]
+initBuiltinPRs (Builtins { dphModules = mods }) insts
+  = liftM (initBuiltinDicts insts) (externalClass (dph_PArray mods) (fsLit "PR"))
 
-builtinPAs :: Builtins -> [(Name, Module, FastString)]
-builtinPAs bi@(Builtins { dphModules = mods })
-  = [
-      mk (tyConName $ closureTyCon bi)  (dph_Closure   mods) (fsLit "dPA_Clo")
-    , mk (tyConName $ voidTyCon bi)     (dph_Repr      mods) (fsLit "dPA_Void")
-    , mk (tyConName $ parrayTyCon bi)   (dph_Instances mods) (fsLit "dPA_PArray")
-    , mk unitTyConName                  (dph_Instances mods) (fsLit "dPA_Unit")
-
-    , mk intTyConName                   (dph_Instances mods) (fsLit "dPA_Int")
-    , mk word8TyConName                 (dph_Instances mods) (fsLit "dPA_Word8")
-    , mk doubleTyConName                (dph_Instances mods) (fsLit "dPA_Double")
-    , mk boolTyConName                  (dph_Instances mods) (fsLit "dPA_Bool")
-    ]
-    ++ tups
+initBuiltinDicts :: (InstEnv, InstEnv) -> Class -> [(Name, Var)]
+initBuiltinDicts insts cls = map find $ classInstances insts cls
   where
-    mk name mod fs = (name, mod, fs)
-
-    tups = map mk_tup [2..mAX_DPH_PROD]
-    mk_tup n = mk (tyConName $ tupleTyCon Boxed n)
-                  (dph_Instances mods)
-                  (mkFastString $ "dPA_" ++ show n)
-
-initBuiltinPRs :: Builtins -> DsM [(Name, Var)]
-initBuiltinPRs = initBuiltinDicts . builtinPRs
-
-builtinPRs :: Builtins -> [(Name, Module, FastString)]
-builtinPRs bi@(Builtins { dphModules = mods }) =
-  [
-    mk (tyConName   unitTyCon)           (dph_Repr mods)    (fsLit "dPR_Unit")
-  , mk (tyConName $ voidTyCon        bi) (dph_Repr mods)    (fsLit "dPR_Void")
-  , mk (tyConName $ wrapTyCon        bi) (dph_Repr mods)    (fsLit "dPR_Wrap")
-  , mk (tyConName $ closureTyCon     bi) (dph_Closure mods) (fsLit "dPR_Clo")
-
-    -- temporary
-  , mk intTyConName          (dph_Instances mods) (fsLit "dPR_Int")
-  , mk word8TyConName        (dph_Instances mods) (fsLit "dPR_Word8")
-  , mk doubleTyConName       (dph_Instances mods) (fsLit "dPR_Double")
-  ]
-
-  ++ map mk_sum  [2..mAX_DPH_SUM]
-  ++ map mk_prod [2..mAX_DPH_PROD]
-  where
-    mk name mod fs = (name, mod, fs)
-
-    mk_sum n = (tyConName $ sumTyCon n bi, dph_Repr mods,
-                mkFastString ("dPR_Sum" ++ show n))
-
-    mk_prod n = (tyConName $ prodTyCon n bi, dph_Repr mods,
-                 mkFastString ("dPR_" ++ show n))
+    find i | [Just tc] <- instanceRoughTcs i = (tc, instanceDFunId i)
+           | otherwise = pprPanic "Invalid DPH instance" (ppr i)
 
 initBuiltinBoxedTyCons :: Builtins -> DsM [(Name, TyCon)]
 initBuiltinBoxedTyCons = return . builtinBoxedTyCons
@@ -621,6 +590,9 @@ externalTyCon :: Module -> FastString -> DsM TyCon
 externalTyCon mod fs
   = dsLookupTyCon =<< lookupOrig mod (mkTcOccFS fs)
 
+externalClassTyCon :: Module -> FastString -> DsM TyCon
+externalClassTyCon mod fs = liftM classTyCon (externalClass mod fs)
+
 externalType :: Module -> FastString -> DsM Type
 externalType mod fs
   = do
@@ -629,11 +601,7 @@ externalType mod fs
 
 externalClass :: Module -> FastString -> DsM Class
 externalClass mod fs
-  = dsLookupClass =<< lookupOrig mod (mkTcOccFS fs)
-
-unitTyConName :: Name
-unitTyConName = tyConName unitTyCon
-
+  = dsLookupClass =<< lookupOrig mod (mkClsOccFS fs)
 
 primMethod :: TyCon -> String -> Builtins -> DsM (Maybe Var)
 primMethod  tycon method (Builtins { dphModules = mods })