Fix vectorisation of recursive types
authorRoman Leshchinskiy <rl@cse.unsw.edu.au>
Wed, 26 Jan 2011 23:18:43 +0000 (23:18 +0000)
committerRoman Leshchinskiy <rl@cse.unsw.edu.au>
Wed, 26 Jan 2011 23:18:43 +0000 (23:18 +0000)
13 files changed:
compiler/ghc.cabal.in
compiler/vectorise/Vectorise.hs
compiler/vectorise/Vectorise/Builtins/Base.hs
compiler/vectorise/Vectorise/Builtins/Initialise.hs
compiler/vectorise/Vectorise/Env.hs
compiler/vectorise/Vectorise/Monad/Base.hs
compiler/vectorise/Vectorise/Type/Env.hs
compiler/vectorise/Vectorise/Type/PADict.hs
compiler/vectorise/Vectorise/Type/PRDict.hs [deleted file]
compiler/vectorise/Vectorise/Type/Repr.hs
compiler/vectorise/Vectorise/Utils.hs
compiler/vectorise/Vectorise/Utils/PADict.hs
compiler/vectorise/Vectorise/Utils/PRDict.hs [deleted file]

index a8c2d54..cc4c562 100644 (file)
@@ -479,7 +479,6 @@ Library
         Vectorise.Utils.Closure
         Vectorise.Utils.Hoisting
         Vectorise.Utils.PADict
-        Vectorise.Utils.PRDict
         Vectorise.Utils.Poly
         Vectorise.Utils
         Vectorise.Type.Env
@@ -487,7 +486,6 @@ Library
         Vectorise.Type.PData
         Vectorise.Type.PRepr
         Vectorise.Type.PADict
-        Vectorise.Type.PRDict
         Vectorise.Type.Type
         Vectorise.Type.TyConDecl
         Vectorise.Type.Classify
index b4b383e..5e45c97 100644 (file)
@@ -18,7 +18,6 @@ import CoreSyn
 import CoreUnfold           ( mkInlineUnfolding )
 import CoreFVs
 import CoreMonad            ( CoreM, getHscEnv )
-import FamInstEnv           ( extendFamInstEnvList )
 import Var
 import Id
 import OccName
@@ -62,9 +61,7 @@ vectModule guts
       -- TODO: What new binds do we get back here?
       (types', fam_insts, tc_binds) <- vectTypeEnv (mg_types guts)
 
-      -- TODO: What is this?
-      let fam_inst_env' = extendFamInstEnvList (mg_fam_inst_env guts) fam_insts
-      updGEnv (setFamInstEnv fam_inst_env')
+      (_, fam_inst_env) <- readGEnv global_fam_inst_env
 
       -- dicts   <- mapM buildPADict pa_insts
       -- workers <- mapM vectDataConWorkers pa_insts
@@ -74,7 +71,7 @@ vectModule guts
 
       return $ guts { mg_types        = types'
                     , mg_binds        = Rec tc_binds : binds'
-                    , mg_fam_inst_env = fam_inst_env'
+                    , mg_fam_inst_env = fam_inst_env
                     , mg_fam_insts    = mg_fam_insts guts ++ fam_insts
                     }
 
index 5e4d47d..69ae84f 100644 (file)
@@ -61,10 +61,12 @@ data Builtins
         , parrayTyCon      :: TyCon                    -- ^ PArray
         , parrayDataCon    :: DataCon                  -- ^ PArray
         , pdataTyCon       :: TyCon                    -- ^ PData
+        , paClass          :: Class                     -- ^ PA
         , paTyCon          :: TyCon                    -- ^ PA
         , paDataCon        :: DataCon                  -- ^ PA
         , paPRSel          :: Var                       -- ^ PA
         , preprTyCon       :: TyCon                    -- ^ PRepr
+        , prClass          :: Class                     -- ^ PR
         , prTyCon          :: TyCon                    -- ^ PR
         , prDataCon        :: DataCon                  -- ^ PR
         , replicatePDVar   :: Var                      -- ^ replicatePD
index d9a1f0d..9e78f11 100644 (file)
@@ -46,14 +46,15 @@ initBuiltins pkg
       let [parrayDataCon] = tyConDataCons parrayTyCon
 
       pdataTyCon       <- externalTyCon        dph_PArray      (fsLit "PData")
-      pa                <- externalClass        dph_PArray      (fsLit "PA")
-      let paTyCon     = classTyCon pa
+      paClass           <- externalClass        dph_PArray      (fsLit "PA")
+      let paTyCon     = classTyCon paClass
           [paDataCon] = tyConDataCons paTyCon
-          paPRSel     = classSCSelId pa 0
+          paPRSel     = classSCSelId paClass 0
 
       preprTyCon       <- externalTyCon        dph_PArray      (fsLit "PRepr")
-      prTyCon          <- externalClassTyCon   dph_PArray      (fsLit "PR")
-      let [prDataCon]  = tyConDataCons prTyCon
+      prClass           <- externalClass        dph_PArray      (fsLit "PR")
+      let prTyCon     = classTyCon prClass
+          [prDataCon] = tyConDataCons prTyCon
 
       closureTyCon     <- externalTyCon dph_Closure            (fsLit ":->")
 
@@ -127,10 +128,12 @@ initBuiltins pkg
                , parrayTyCon      = parrayTyCon
                , parrayDataCon    = parrayDataCon
                , pdataTyCon       = pdataTyCon
+               , paClass          = paClass
                , paTyCon          = paTyCon
                , paDataCon        = paDataCon
                , paPRSel          = paPRSel
                , preprTyCon       = preprTyCon
+               , prClass          = prClass
                , prTyCon          = prTyCon
                , prDataCon        = prDataCon
                , voidTyCon        = voidTyCon
@@ -308,9 +311,3 @@ externalClass :: Module -> FastString -> DsM Class
 externalClass mod fs
   = dsLookupClass =<< lookupOrig mod (mkClsOccFS fs)
 
-
--- | Like `externalClass`, but get the TyCon of of the class.
-externalClassTyCon :: Module -> FastString -> DsM TyCon
-externalClassTyCon mod fs = liftM classTyCon (externalClass mod fs)
-
-
index 30f259b..70ed8c4 100644 (file)
@@ -11,7 +11,8 @@ module Vectorise.Env (
        initGlobalEnv,
        extendImportedVarsEnv,
        extendScalars,
-       setFamInstEnv,
+       setFamEnv,
+        extendFamEnv,
        extendTyConsEnv,
        extendDataConsEnv,
        extendPAFunsEnv,
@@ -142,11 +143,16 @@ extendScalars vs genv
 
 
 -- | Set the list of type family instances in an environment.
-setFamInstEnv :: FamInstEnv -> GlobalEnv -> GlobalEnv
-setFamInstEnv l_fam_inst genv
+setFamEnv :: FamInstEnv -> GlobalEnv -> GlobalEnv
+setFamEnv l_fam_inst genv
   = genv { global_fam_inst_env = (g_fam_inst, l_fam_inst) }
   where (g_fam_inst, _) = global_fam_inst_env genv
 
+extendFamEnv :: [FamInst] -> GlobalEnv -> GlobalEnv
+extendFamEnv new genv
+  = genv { global_fam_inst_env = (g_fam_inst, extendFamInstEnvList l_fam_inst new) }
+  where (g_fam_inst, l_fam_inst) = global_fam_inst_env genv
+
 
 -- | Extend the list of type constructors in an environment.
 extendTyConsEnv :: [(Name, TyCon)] -> GlobalEnv -> GlobalEnv
index 98da3fe..c2c314f 100644 (file)
@@ -77,7 +77,6 @@ maybeCantVectoriseM s d p
         Just x  -> return x
         Nothing -> cantVectorise s d
 
-
 -- Control --------------------------------------------------------------------
 -- | Return some result saying we've failed.
 noV :: VM a
index 99c1746..61a52bc 100644 (file)
@@ -82,6 +82,13 @@ vectTypeEnv env
       let vect_tcs  = filter (not . isClassTyCon) 
                     $ keep_tcs ++ new_tcs
 
+      reprs <- mapM tyConRepr vect_tcs
+      repr_tcs  <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
+      pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
+      updGEnv $ extendFamEnv
+              $ map mkLocalFamInst
+              $ repr_tcs ++ pdata_tcs
+
       -- Create PRepr and PData instances for the vectorised types.
       -- We get back the binds for the instance functions, 
       -- and some new type constructors for the representation types.
@@ -89,8 +96,6 @@ vectTypeEnv env
         do
           defTyConPAs (zipLazy vect_tcs dfuns')
           reprs     <- mapM tyConRepr vect_tcs
-          repr_tcs  <- zipWith3M buildPReprTyCon orig_tcs vect_tcs reprs
-          pdata_tcs <- zipWith3M buildPDataTyCon orig_tcs vect_tcs reprs
 
           dfuns     <- sequence 
                     $  zipWith5 buildTyConBindings
index ed6264a..4c786cf 100644 (file)
@@ -6,7 +6,6 @@ import Vectorise.Monad
 import Vectorise.Builtins
 import Vectorise.Type.Repr
 import Vectorise.Type.PRepr
-import Vectorise.Type.PRDict
 import Vectorise.Utils
 
 import BasicTypes
@@ -19,7 +18,8 @@ import TypeRep
 import Id
 import Var
 import Name
-import Outputable
+import FastString
+-- import Outputable
 
 -- debug               = False
 -- dtrace s x  = if debug then pprTrace "Vectoris.Type.PADict" s x else x
@@ -29,38 +29,52 @@ import Outputable
 buildPADict
        :: TyCon        -- ^ tycon of the type being vectorised.
        -> TyCon        -- ^ tycon of the type used for the vectorised representation.
-       -> TyCon        -- 
+       -> TyCon        -- ^ PRepr instance tycon
        -> SumRepr      -- ^ representation used for the type being vectorised.
        -> VM Var       -- ^ name of the top-level dictionary function.
 
 buildPADict vect_tc prepr_tc arr_tc repr
  = polyAbstract tvs $ \args ->
- case args of
-  (_:_) -> pprPanic "Vectorise.Type.PADict.buildPADict" (text "why do we need superclass dicts?")
-  [] -> do
-      -- TODO: I'm forcing args to [] because I'm not sure why we need them.
-      --       class PA has superclass (PR (PRepr a)) but we're not using
-      --       the superclass dictionary to build the PA dictionary.
+   do
+      -- The superclass dictionary is an argument if the tycon is polymorphic
+      let mk_super_ty = do
+                          r <- mkPReprType inst_ty
+                          pr_cls <- builtin prClass
+                          return $ PredTy $ ClassP pr_cls [r]
+      super_tys <- sequence [mk_super_ty | not (null tvs)]
+      super_args <- mapM (newLocalVar (fsLit "pr")) super_tys
+      let args' = super_args ++ args
+
+      -- it is constant otherwise
+      super_consts <- sequence [prDictOfPReprInstTyCon inst_ty prepr_tc []
+                                                | null tvs]
 
       -- Get ids for each of the methods in the dictionary.
-      method_ids <- mapM (method args) paMethods
+      method_ids <- mapM (method args') paMethods
 
       -- Expression to build the dictionary.
       pa_dc  <- builtin paDataCon
-      let dict = mkLams (tvs ++ args)
+      let dict = mkLams (tvs ++ args')
                $ mkConApp pa_dc
-               $ Type inst_ty : map (method_call args) method_ids
+               $ Type inst_ty
+               : map Var super_args ++ super_consts
+                                   -- the superclass dictionary is
+                                   -- either lambda-bound or
+                                   -- constant
+                 ++ map (method_call args') method_ids
 
       -- Build the type of the dictionary function.
-      pa_tc            <- builtin paTyCon
-      let Just pa_cls  = tyConClass_maybe pa_tc
-
+      pa_cls <- builtin paClass
       let dfun_ty      = mkForAllTys tvs
-                       $ mkFunTys (map varType args) (PredTy $ ClassP pa_cls [inst_ty])
+                       $ mkFunTys (map varType args')
+                                   (PredTy $ ClassP pa_cls [inst_ty])
 
       -- Set the unfolding for the inliner.
       raw_dfun <- newExportedVar dfun_name dfun_ty
-      let dfun_unf = mkDFunUnfolding dfun_ty (map (DFunPolyArg . Var) method_ids)
+      let dfun_unf = mkDFunUnfolding dfun_ty
+                   $ map (const $ DFunLamArg 0) super_args
+                     ++ map DFunConstArg super_consts
+                     ++ map (DFunPolyArg . Var) method_ids
           dfun = raw_dfun `setIdUnfolding`  dfun_unf
                           `setInlinePragma` dfunInlinePragma
 
@@ -91,8 +105,7 @@ buildPADict vect_tc prepr_tc arr_tc repr
 
 
 paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
-paMethods = [("dictPRepr",    buildPRDict),
-             ("toPRepr",      buildToPRepr),
+paMethods = [("toPRepr",      buildToPRepr),
              ("fromPRepr",    buildFromPRepr),
              ("toArrPRepr",   buildToArrPRepr),
              ("fromArrPRepr", buildFromArrPRepr)]
diff --git a/compiler/vectorise/Vectorise/Type/PRDict.hs b/compiler/vectorise/Vectorise/Type/PRDict.hs
deleted file mode 100644 (file)
index 1a55116..0000000
+++ /dev/null
@@ -1,56 +0,0 @@
-
-module Vectorise.Type.PRDict 
-       (buildPRDict)
-where
-import Vectorise.Utils
-import Vectorise.Monad
-import Vectorise.Builtins
-import Vectorise.Type.Repr
-import CoreSyn
-import CoreUtils
-import TyCon
-import Type
-import Coercion
-
-
-
-buildPRDict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr
-buildPRDict vect_tc prepr_tc _ r
-  = do
-      dict <- sum_dict r
-      pr_co <- mkBuiltinCo prTyCon
-      let co = mkAppCoercion pr_co
-             . mkSymCoercion
-             $ mkTyConApp arg_co ty_args
-      return (mkCoerce co dict)
-  where
-    ty_args = mkTyVarTys (tyConTyVars vect_tc)
-    Just arg_co = tyConFamilyCoercion_maybe prepr_tc
-
-    sum_dict EmptySum = prDFunOfTyCon =<< builtin voidTyCon
-    sum_dict (UnarySum r) = con_dict r
-    sum_dict (Sum { repr_sum_tc  = sum_tc
-                  , repr_con_tys = tys
-                  , repr_cons    = cons
-                  })
-      = do
-          dicts <- mapM con_dict cons
-          dfun  <- prDFunOfTyCon sum_tc
-          return $ dfun `mkTyApps` tys `mkApps` dicts
-
-    con_dict (ConRepr _ r) = prod_dict r
-
-    prod_dict EmptyProd = prDFunOfTyCon =<< builtin voidTyCon
-    prod_dict (UnaryProd r) = comp_dict r
-    prod_dict (Prod { repr_tup_tc   = tup_tc
-                    , repr_comp_tys = tys
-                    , repr_comps    = comps })
-      = do
-          dicts <- mapM comp_dict comps
-          dfun <- prDFunOfTyCon tup_tc
-          return $ dfun `mkTyApps` tys `mkApps` dicts
-
-    comp_dict (Keep _ pr) = return pr
-    comp_dict (Wrap ty)   = wrapPR ty
-
-
index 40242ae..bb300ca 100644 (file)
@@ -82,7 +82,7 @@ tyConRepr tc = sum_repr (tyConDataCons tc)
       where
         arity = length tys
     
-    comp_repr ty = liftM (Keep ty) (prDictOfType ty)
+    comp_repr ty = liftM (Keep ty) (prDictOfReprType ty)
                    `orElseV` return (Wrap ty)
 
 sumReprType :: SumRepr -> VM Type
index e701383..1a099e3 100644 (file)
@@ -4,7 +4,6 @@ module Vectorise.Utils (
   module Vectorise.Utils.Closure,
   module Vectorise.Utils.Hoisting,
   module Vectorise.Utils.PADict,
-  module Vectorise.Utils.PRDict,
   module Vectorise.Utils.Poly,
 
   -- * Annotated Exprs
@@ -28,7 +27,6 @@ import Vectorise.Utils.Base
 import Vectorise.Utils.Closure
 import Vectorise.Utils.Hoisting
 import Vectorise.Utils.PADict
-import Vectorise.Utils.PRDict
 import Vectorise.Utils.Poly
 import Vectorise.Monad
 import Vectorise.Builtins
index 93f2297..329cb63 100644 (file)
@@ -2,7 +2,9 @@
 module Vectorise.Utils.PADict (
        paDictArgType,
        paDictOfType,
-       paMethod        
+       paMethod,
+        prDictOfReprType,
+        prDictOfPReprInstTyCon
 )
 where
 import Vectorise.Monad
@@ -42,7 +44,9 @@ paDictArgType tv = go (TyVarTy tv) (tyVarKind tv)
 
     go ty k
       | isLiftedTypeKind k
-      = liftM Just (mkBuiltinTyConApp paTyCon [ty])
+      = do
+          pa_cls <- builtin paClass
+          return $ Just $ PredTy $ ClassP pa_cls [ty]
 
     go _ _ = return Nothing
 
@@ -108,17 +112,36 @@ prDictOfPReprInst :: Type -> VM CoreExpr
 prDictOfPReprInst ty
   = do
       (prepr_tc, prepr_args) <- preprSynTyCon ty
-      case coreView (mkTyConApp prepr_tc prepr_args) of
-        Just rhs -> do
-                      dict <- prDictOfReprType rhs
-                      pr_co <- mkBuiltinCo prTyCon
-                      let Just arg_co = tyConFamilyCoercion_maybe prepr_tc
-                      let co = mkAppCoercion pr_co
-                             $ mkSymCoercion
-                             $ mkTyConApp arg_co prepr_args
-                      return $ mkCoerce co dict
-        Nothing  -> cantVectorise "Invalid PRepr type instance"
-                                  $ ppr ty
+      prDictOfPReprInstTyCon ty prepr_tc prepr_args
+
+-- | Given a type @ty@, its PRepr synonym tycon and its type arguments,
+-- return the PR @PRepr ty@. Suppose we have:
+--
+-- > type instance PRepr (T a1 ... an) = t
+--
+-- which is internally translated into
+--
+-- > type :R:PRepr a1 ... an = t
+--
+-- and the corresponding coercion. Then,
+--
+-- > prDictOfPReprInstTyCon (T a1 ... an) :R:PRepr u1 ... un = PR (T u1 ... un)
+--
+-- Note that @ty@ is only used for error messages
+--
+prDictOfPReprInstTyCon :: Type -> TyCon -> [Type] -> VM CoreExpr
+prDictOfPReprInstTyCon ty prepr_tc prepr_args
+  | Just rhs <- coreView (mkTyConApp prepr_tc prepr_args)
+  = do
+      dict <- prDictOfReprType' rhs
+      pr_co <- mkBuiltinCo prTyCon
+      let Just arg_co = tyConFamilyCoercion_maybe prepr_tc
+      let co = mkAppCoercion pr_co
+             $ mkSymCoercion
+             $ mkTyConApp arg_co prepr_args
+      return $ mkCoerce co dict
+
+  | otherwise = cantVectorise "Invalid PRepr type instance" (ppr ty)
 
 -- | Get the PR dictionary for a type. The argument must be a representation
 -- type.
@@ -129,14 +152,13 @@ prDictOfReprType ty
         prepr <- builtin preprTyCon
         if tycon == prepr
           then do
-                 [ty'] <- return tyargs
-                 prDictOfPReprInst ty'
+                 let [ty'] = tyargs
+                 pa <- paDictOfType ty'
+                 sel <- builtin paPRSel
+                 return $ Var sel `App` Type ty' `App` pa
           else do 
                  -- a representation tycon must have a PR instance
-                 dfun <- maybeCantVectoriseM
-                           "No PR dictionary for type constructor"
-                           (ppr tycon <+> text "in" <+> ppr ty)
-                       $ lookupTyConPR tycon
+                 dfun <- maybeV $ lookupTyConPR tycon
                  prDFunApply dfun tyargs
 
   | otherwise
@@ -153,6 +175,11 @@ prDictOfReprType ty
         prsel <- builtin paPRSel
         return $ Var prsel `mkApps` [Type ty, pa]
 
+prDictOfReprType' :: Type -> VM CoreExpr
+prDictOfReprType' ty = prDictOfReprType ty `orElseV`
+                       cantVectorise "No PR dictionary for representation type"
+                                     (ppr ty)
+
 -- | Apply a tycon's PR dfun to dictionary arguments (PR or PA) corresponding
 -- to the argument types.
 prDFunApply :: Var -> [Type] -> VM CoreExpr
diff --git a/compiler/vectorise/Vectorise/Utils/PRDict.hs b/compiler/vectorise/Vectorise/Utils/PRDict.hs
deleted file mode 100644 (file)
index a5d09df..0000000
+++ /dev/null
@@ -1,43 +0,0 @@
-
-module Vectorise.Utils.PRDict (
-       prDictOfType,
-       wrapPR
-)
-where
-import Vectorise.Monad
-import Vectorise.Builtins
-import Vectorise.Utils.Base
-import Vectorise.Utils.PADict
-
-import CoreSyn
-import Type
-import TypeRep
-import Control.Monad
-
-
-prDictOfType :: Type -> VM CoreExpr
-prDictOfType ty = prDictOfTyApp ty_fn ty_args
-  where
-    (ty_fn, ty_args) = splitAppTys ty
-
-prDictOfTyApp :: Type -> [Type] -> VM CoreExpr
-prDictOfTyApp ty_fn ty_args
-  | Just ty_fn' <- coreView ty_fn = prDictOfTyApp ty_fn' ty_args
-prDictOfTyApp (TyConApp tc _) ty_args
-  = do
-      dfun <- liftM Var $ maybeV (lookupTyConPR tc)
-      prDFunApply dfun ty_args
-prDictOfTyApp _ _ = noV
-
-prDFunApply :: CoreExpr -> [Type] -> VM CoreExpr
-prDFunApply dfun tys
-  = do
-      dicts <- mapM prDictOfType tys
-      return $ mkApps (mkTyApps dfun tys) dicts
-
-wrapPR :: Type -> VM CoreExpr
-wrapPR ty
-  = do
-      pa_dict <- paDictOfType ty
-      pr_dfun <- prDFunOfTyCon =<< builtin wrapTyCon
-      return $ mkApps pr_dfun [Type ty, pa_dict]