Fix vectorisation of recursive types
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Type / PADict.hs
index 7fdc31a..4c786cf 100644 (file)
@@ -2,13 +2,11 @@
 module Vectorise.Type.PADict
        (buildPADict)
 where
-import VectUtils
 import Vectorise.Monad
 import Vectorise.Builtins
 import Vectorise.Type.Repr
 import Vectorise.Type.PRepr
-import Vectorise.Type.PRDict
-import Vectorise.Utils.Hoisting
+import Vectorise.Utils
 
 import BasicTypes
 import CoreSyn
@@ -16,61 +14,98 @@ import CoreUtils
 import CoreUnfold
 import TyCon
 import Type
-import OccName
+import TypeRep
 import Id
 import Var
 import Name
+import FastString
+-- import Outputable
 
+-- debug               = False
+-- dtrace s x  = if debug then pprTrace "Vectoris.Type.PADict" s x else x
 
+-- | Build the PA dictionary for some type and hoist it to top level.
+--   The PA dictionary holds fns that convert values to and from their vectorised representations.
+buildPADict
+       :: TyCon        -- ^ tycon of the type being vectorised.
+       -> TyCon        -- ^ tycon of the type used for the vectorised representation.
+       -> TyCon        -- ^ PRepr instance tycon
+       -> SumRepr      -- ^ representation used for the type being vectorised.
+       -> VM Var       -- ^ name of the top-level dictionary function.
 
-buildPADict :: TyCon -> TyCon -> TyCon -> SumRepr -> VM Var
 buildPADict vect_tc prepr_tc arr_tc repr
-  = polyAbstract tvs $ \args ->
-    do
-      method_ids <- mapM (method args) paMethods
-
-      pa_tc  <- builtin paTyCon
+ = polyAbstract tvs $ \args ->
+   do
+      -- The superclass dictionary is an argument if the tycon is polymorphic
+      let mk_super_ty = do
+                          r <- mkPReprType inst_ty
+                          pr_cls <- builtin prClass
+                          return $ PredTy $ ClassP pr_cls [r]
+      super_tys <- sequence [mk_super_ty | not (null tvs)]
+      super_args <- mapM (newLocalVar (fsLit "pr")) super_tys
+      let args' = super_args ++ args
+
+      -- it is constant otherwise
+      super_consts <- sequence [prDictOfPReprInstTyCon inst_ty prepr_tc []
+                                                | null tvs]
+
+      -- Get ids for each of the methods in the dictionary.
+      method_ids <- mapM (method args') paMethods
+
+      -- Expression to build the dictionary.
       pa_dc  <- builtin paDataCon
-      let dict = mkLams (tvs ++ args)
+      let dict = mkLams (tvs ++ args')
                $ mkConApp pa_dc
-               $ Type inst_ty : map (method_call args) method_ids
-
-          dfun_ty = mkForAllTys tvs
-                  $ mkFunTys (map varType args) (mkTyConApp pa_tc [inst_ty])
-
+               $ Type inst_ty
+               : map Var super_args ++ super_consts
+                                   -- the superclass dictionary is
+                                   -- either lambda-bound or
+                                   -- constant
+                 ++ map (method_call args') method_ids
+
+      -- Build the type of the dictionary function.
+      pa_cls <- builtin paClass
+      let dfun_ty      = mkForAllTys tvs
+                       $ mkFunTys (map varType args')
+                                   (PredTy $ ClassP pa_cls [inst_ty])
+
+      -- Set the unfolding for the inliner.
       raw_dfun <- newExportedVar dfun_name dfun_ty
-      let dfun = raw_dfun `setIdUnfolding` mkDFunUnfolding dfun_ty (map Var method_ids)
+      let dfun_unf = mkDFunUnfolding dfun_ty
+                   $ map (const $ DFunLamArg 0) super_args
+                     ++ map DFunConstArg super_consts
+                     ++ map (DFunPolyArg . Var) method_ids
+          dfun = raw_dfun `setIdUnfolding`  dfun_unf
                           `setInlinePragma` dfunInlinePragma
 
+      -- Add the new binding to the top-level environment.
       hoistBinding dfun dict
       return dfun
   where
-    tvs = tyConTyVars vect_tc
-    arg_tys = mkTyVarTys tvs
-    inst_ty = mkTyConApp vect_tc arg_tys
+    tvs       = tyConTyVars vect_tc
+    arg_tys   = mkTyVarTys tvs
+    inst_ty   = mkTyConApp vect_tc arg_tys
 
     dfun_name = mkPADFunOcc (getOccName vect_tc)
 
     method args (name, build)
       = localV
       $ do
-          expr <- build vect_tc prepr_tc arr_tc repr
+          expr     <- build vect_tc prepr_tc arr_tc repr
           let body = mkLams (tvs ++ args) expr
-          raw_var <- newExportedVar (method_name name) (exprType body)
-          let var = raw_var
-                      `setIdUnfolding` mkInlineRule body (Just (length args))
+          raw_var  <- newExportedVar (method_name name) (exprType body)
+          let var  = raw_var
+                      `setIdUnfolding` mkInlineUnfolding (Just (length args)) body
                       `setInlinePragma` alwaysInlinePragma
           hoistBinding var body
           return var
 
     method_call args id = mkApps (Var id) (map Type arg_tys ++ map Var args)
-
-    method_name name = mkVarOcc $ occNameString dfun_name ++ ('$' : name)
+    method_name name    = mkVarOcc $ occNameString dfun_name ++ ('$' : name)
 
 
 paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
-paMethods = [("dictPRepr",    buildPRDict),
-             ("toPRepr",      buildToPRepr),
+paMethods = [("toPRepr",      buildToPRepr),
              ("fromPRepr",    buildFromPRepr),
              ("toArrPRepr",   buildToArrPRepr),
              ("fromArrPRepr", buildFromArrPRepr)]