Vectorise polymorphic let bindings
authorRoman Leshchinskiy <rl@cse.unsw.edu.au>
Sun, 4 May 2008 05:40:06 +0000 (05:40 +0000)
committerRoman Leshchinskiy <rl@cse.unsw.edu.au>
Sun, 4 May 2008 05:40:06 +0000 (05:40 +0000)
compiler/vectorise/VectType.hs
compiler/vectorise/Vectorise.hs

index 79e37fc..0e942ca 100644 (file)
@@ -5,7 +5,7 @@
 --     http://hackage.haskell.org/trac/ghc/wiki/Commentary/CodingStyle#Warnings
 -- for details
 
-module VectType ( vectTyCon, vectType, vectTypeEnv,
+module VectType ( vectTyCon, vectAndLiftType, vectType, vectTypeEnv,
                   mkRepr, arrShapeTys, arrShapeVars, arrSelector,
                   PAInstance, buildPADict,
                   fromVect )
@@ -29,7 +29,7 @@ import InstEnv           ( Instance, mkLocalInstance, instanceDFunId )
 import OccName
 import MkId
 import BasicTypes        ( StrictnessMark(..), OverlapFlag(..), boolToRecFlag )
-import Var               ( Var )
+import Var               ( Var, TyVar )
 import Id                ( mkWildId )
 import Name              ( Name, getOccName )
 import NameEnv
@@ -64,6 +64,20 @@ vectTyCon tc
                     -- FIXME: just for now
                     Nothing  -> pprTrace "ccTyCon:" (ppr tc) $ return tc
 
+vectAndLiftType :: Type -> VM (Type, Type)
+vectAndLiftType ty | Just ty' <- coreView ty = vectAndLiftType ty'
+vectAndLiftType ty
+  = do
+      mdicts   <- mapM paDictArgType tyvars
+      let dicts = [dict | Just dict <- mdicts]
+      vmono_ty <- vectType mono_ty
+      lmono_ty <- mkPArrayType vmono_ty
+      return (abstractType tyvars dicts vmono_ty,
+              abstractType tyvars dicts lmono_ty)
+  where
+    (tyvars, mono_ty) = splitForAllTys ty
+
+
 vectType :: Type -> VM Type
 vectType ty | Just ty' <- coreView ty = vectType ty'
 vectType (TyVarTy tv) = return $ TyVarTy tv
@@ -75,7 +89,7 @@ vectType ty@(ForAllTy _ _)
   = do
       mdicts   <- mapM paDictArgType tyvars
       mono_ty' <- vectType mono_ty
-      return $ tyvars `mkForAllTys` ([dict | Just dict <- mdicts] `mkFunTys` mono_ty')
+      return $ abstractType tyvars [dict | Just dict <- mdicts] mono_ty'
   where
     (tyvars, mono_ty) = splitForAllTys ty
 
@@ -84,6 +98,9 @@ vectType ty = pprPanic "vectType:" (ppr ty)
 vectAndBoxType :: Type -> VM Type
 vectAndBoxType ty = vectType ty >>= boxType
 
+abstractType :: [TyVar] -> [Type] -> Type -> Type
+abstractType tyvars dicts = mkForAllTys tyvars . mkFunTys dicts
+
 -- ----------------------------------------------------------------------------
 -- Boxing
 
index 1c185bd..024ae45 100644 (file)
@@ -129,8 +129,7 @@ tryConvert var vect_var rhs
 vectBndr :: Var -> VM VVar
 vectBndr v
   = do
-      vty <- vectType (idType v)
-      lty <- mkPArrayType vty
+      (vty, lty) <- vectAndLiftType (idType v)
       let vv = v `Id.setIdType` vty
           lv = v `Id.setIdType` lty
       updLEnv (mapTo vv lv)
@@ -342,26 +341,23 @@ type CoreAltWithFVs = AnnAlt Id VarSet
 -- FIXME: this is too lazy
 vectAlgCase tycon ty_args scrut bndr ty [(DEFAULT, [], body)]
   = do
-      vscrut <- vectExpr scrut
-      vty    <- vectType ty
-      lty    <- mkPArrayType vty
+      vscrut         <- vectExpr scrut
+      (vty, lty)     <- vectAndLiftType ty
       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
       return $ vCaseDEFAULT vscrut vbndr vty lty vbody
 
 vectAlgCase tycon ty_args scrut bndr ty [(DataAlt dc, [], body)]
   = do
-      vscrut <- vectExpr scrut
-      vty    <- vectType ty
-      lty    <- mkPArrayType vty
+      vscrut         <- vectExpr scrut
+      (vty, lty)     <- vectAndLiftType ty
       (vbndr, vbody) <- vectBndrIn bndr (vectExpr body)
       return $ vCaseDEFAULT vscrut vbndr vty lty vbody
 
 vectAlgCase tycon ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
   = do
-      vect_tc <- maybeV (lookupTyCon tycon)
-      vty <- vectType ty
-      lty <- mkPArrayType vty
-      vexpr <- vectExpr scrut
+      vect_tc    <- maybeV (lookupTyCon tycon)
+      (vty, lty) <- vectAndLiftType ty
+      vexpr      <- vectExpr scrut
       (vbndr, (vbndrs, vbody)) <- vect_scrut_bndr
                                 . vectBndrsIn bndrs
                                 $ vectExpr body
@@ -379,10 +375,8 @@ vectAlgCase tycon ty_args scrut bndr ty [(DataAlt dc, bndrs, body)]
 
 vectAlgCase tycon ty_args scrut bndr ty alts
   = do
-      vect_tc <- maybeV (lookupTyCon tycon)
-      vty               <- vectType ty
-      lty               <- mkPArrayType vty
-
+      vect_tc     <- maybeV (lookupTyCon tycon)
+      (vty, lty)  <- vectAndLiftType ty
       repr        <- mkRepr vect_tc
       shape_bndrs <- arrShapeVars repr
       (len, sel, indices) <- arrSelector repr (map Var shape_bndrs)