Fix vectorisation of recursive types
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Type / PADict.hs
index 677a7bf..4c786cf 100644 (file)
@@ -6,7 +6,6 @@ import Vectorise.Monad
 import Vectorise.Builtins
 import Vectorise.Type.Repr
 import Vectorise.Type.PRepr
-import Vectorise.Type.PRDict
 import Vectorise.Utils
 
 import BasicTypes
@@ -19,59 +18,64 @@ import TypeRep
 import Id
 import Var
 import Name
-import Outputable
-import Class
+import FastString
+-- import Outputable
 
-debug          = False
-dtrace s x     = if debug then pprTrace "Vectoris.Type.PADict" s x else x
+-- debug               = False
+-- dtrace s x  = if debug then pprTrace "Vectoris.Type.PADict" s x else x
 
 -- | Build the PA dictionary for some type and hoist it to top level.
 --   The PA dictionary holds fns that convert values to and from their vectorised representations.
 buildPADict
        :: TyCon        -- ^ tycon of the type being vectorised.
        -> TyCon        -- ^ tycon of the type used for the vectorised representation.
-       -> TyCon        -- 
+       -> TyCon        -- ^ PRepr instance tycon
        -> SumRepr      -- ^ representation used for the type being vectorised.
        -> VM Var       -- ^ name of the top-level dictionary function.
 
 buildPADict vect_tc prepr_tc arr_tc repr
- = dtrace (text "buildPADict" <+> ppr vect_tc <+> ppr prepr_tc <+> ppr arr_tc)
- $ polyAbstract tvs $ \args@[] ->
- do
-      -- TODO: I'm forcing args to [] because I'm not sure why we need them.
-      --       class PA has superclass (PR (PRepr a)) but we're not using
-      --       the superclass dictionary to build the PA dictionary.
+ = polyAbstract tvs $ \args ->
+   do
+      -- The superclass dictionary is an argument if the tycon is polymorphic
+      let mk_super_ty = do
+                          r <- mkPReprType inst_ty
+                          pr_cls <- builtin prClass
+                          return $ PredTy $ ClassP pr_cls [r]
+      super_tys <- sequence [mk_super_ty | not (null tvs)]
+      super_args <- mapM (newLocalVar (fsLit "pr")) super_tys
+      let args' = super_args ++ args
+
+      -- it is constant otherwise
+      super_consts <- sequence [prDictOfPReprInstTyCon inst_ty prepr_tc []
+                                                | null tvs]
 
       -- Get ids for each of the methods in the dictionary.
-      method_ids <- mapM (method args) paMethods
+      method_ids <- mapM (method args') paMethods
 
       -- Expression to build the dictionary.
       pa_dc  <- builtin paDataCon
-      let dict = mkLams (tvs ++ args)
+      let dict = mkLams (tvs ++ args')
                $ mkConApp pa_dc
-               $ Type inst_ty : map (method_call args) method_ids
-
-      dtrace (text "dict    = " <+> ppr dict) $ return ()
+               $ Type inst_ty
+               : map Var super_args ++ super_consts
+                                   -- the superclass dictionary is
+                                   -- either lambda-bound or
+                                   -- constant
+                 ++ map (method_call args') method_ids
 
       -- Build the type of the dictionary function.
-      pa_tc          <- builtin paTyCon
-      let pa_opitems = [(id, NoDefMeth) | id <- method_ids]
-      let pa_cls     = mkClass 
-                       (tyConName pa_tc)
-                       tvs             -- tyvars of class
-                       []              -- fundeps
-                       []              -- superclass predicates
-                       []              -- superclass dict selectors
-                       []              -- associated type families
-                       pa_opitems      -- class op items
-                       pa_tc           -- dictionary type constructor
-                       
-      let dfun_ty = mkForAllTys tvs
-                  $ mkFunTys (map varType args) (PredTy $ ClassP pa_cls [inst_ty])
+      pa_cls <- builtin paClass
+      let dfun_ty      = mkForAllTys tvs
+                       $ mkFunTys (map varType args')
+                                   (PredTy $ ClassP pa_cls [inst_ty])
 
       -- Set the unfolding for the inliner.
       raw_dfun <- newExportedVar dfun_name dfun_ty
-      let dfun = raw_dfun `setIdUnfolding`  mkDFunUnfolding dfun_ty (map Var method_ids)
+      let dfun_unf = mkDFunUnfolding dfun_ty
+                   $ map (const $ DFunLamArg 0) super_args
+                     ++ map DFunConstArg super_consts
+                     ++ map (DFunPolyArg . Var) method_ids
+          dfun = raw_dfun `setIdUnfolding`  dfun_unf
                           `setInlinePragma` dfunInlinePragma
 
       -- Add the new binding to the top-level environment.
@@ -101,8 +105,7 @@ buildPADict vect_tc prepr_tc arr_tc repr
 
 
 paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
-paMethods = [("dictPRepr",    buildPRDict),
-             ("toPRepr",      buildToPRepr),
+paMethods = [("toPRepr",      buildToPRepr),
              ("fromPRepr",    buildFromPRepr),
              ("toArrPRepr",   buildToArrPRepr),
              ("fromArrPRepr", buildFromArrPRepr)]