Fix vectorisation of recursive types
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Type / PADict.hs
index 5feeb2a..4c786cf 100644 (file)
@@ -6,7 +6,6 @@ import Vectorise.Monad
 import Vectorise.Builtins
 import Vectorise.Type.Repr
 import Vectorise.Type.PRepr
-import Vectorise.Type.PRDict
 import Vectorise.Utils
 
 import BasicTypes
@@ -15,37 +14,68 @@ import CoreUtils
 import CoreUnfold
 import TyCon
 import Type
+import TypeRep
 import Id
 import Var
 import Name
+import FastString
+-- import Outputable
 
+-- debug               = False
+-- dtrace s x  = if debug then pprTrace "Vectoris.Type.PADict" s x else x
 
 -- | Build the PA dictionary for some type and hoist it to top level.
 --   The PA dictionary holds fns that convert values to and from their vectorised representations.
 buildPADict
        :: TyCon        -- ^ tycon of the type being vectorised.
        -> TyCon        -- ^ tycon of the type used for the vectorised representation.
-       -> TyCon        -- 
+       -> TyCon        -- ^ PRepr instance tycon
        -> SumRepr      -- ^ representation used for the type being vectorised.
        -> VM Var       -- ^ name of the top-level dictionary function.
 
 buildPADict vect_tc prepr_tc arr_tc repr
-  = polyAbstract tvs $ \args ->
-    do
-      method_ids <- mapM (method args) paMethods
+ = polyAbstract tvs $ \args ->
+   do
+      -- The superclass dictionary is an argument if the tycon is polymorphic
+      let mk_super_ty = do
+                          r <- mkPReprType inst_ty
+                          pr_cls <- builtin prClass
+                          return $ PredTy $ ClassP pr_cls [r]
+      super_tys <- sequence [mk_super_ty | not (null tvs)]
+      super_args <- mapM (newLocalVar (fsLit "pr")) super_tys
+      let args' = super_args ++ args
 
-      pa_tc  <- builtin paTyCon
+      -- it is constant otherwise
+      super_consts <- sequence [prDictOfPReprInstTyCon inst_ty prepr_tc []
+                                                | null tvs]
+
+      -- Get ids for each of the methods in the dictionary.
+      method_ids <- mapM (method args') paMethods
+
+      -- Expression to build the dictionary.
       pa_dc  <- builtin paDataCon
-      let dict = mkLams (tvs ++ args)
+      let dict = mkLams (tvs ++ args')
                $ mkConApp pa_dc
-               $ Type inst_ty : map (method_call args) method_ids
+               $ Type inst_ty
+               : map Var super_args ++ super_consts
+                                   -- the superclass dictionary is
+                                   -- either lambda-bound or
+                                   -- constant
+                 ++ map (method_call args') method_ids
 
-          dfun_ty = mkForAllTys tvs
-                  $ mkFunTys (map varType args) (mkTyConApp pa_tc [inst_ty])
+      -- Build the type of the dictionary function.
+      pa_cls <- builtin paClass
+      let dfun_ty      = mkForAllTys tvs
+                       $ mkFunTys (map varType args')
+                                   (PredTy $ ClassP pa_cls [inst_ty])
 
       -- Set the unfolding for the inliner.
       raw_dfun <- newExportedVar dfun_name dfun_ty
-      let dfun = raw_dfun `setIdUnfolding`  mkDFunUnfolding dfun_ty (map Var method_ids)
+      let dfun_unf = mkDFunUnfolding dfun_ty
+                   $ map (const $ DFunLamArg 0) super_args
+                     ++ map DFunConstArg super_consts
+                     ++ map (DFunPolyArg . Var) method_ids
+          dfun = raw_dfun `setIdUnfolding`  dfun_unf
                           `setInlinePragma` dfunInlinePragma
 
       -- Add the new binding to the top-level environment.
@@ -75,8 +105,7 @@ buildPADict vect_tc prepr_tc arr_tc repr
 
 
 paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
-paMethods = [("dictPRepr",    buildPRDict),
-             ("toPRepr",      buildToPRepr),
+paMethods = [("toPRepr",      buildToPRepr),
              ("fromPRepr",    buildFromPRepr),
              ("toArrPRepr",   buildToArrPRepr),
              ("fromArrPRepr", buildFromArrPRepr)]