Fix vectorisation of recursive types
[ghc-hetmet.git] / compiler / vectorise / Vectorise / Type / PADict.hs
index ed6264a..4c786cf 100644 (file)
@@ -6,7 +6,6 @@ import Vectorise.Monad
 import Vectorise.Builtins
 import Vectorise.Type.Repr
 import Vectorise.Type.PRepr
-import Vectorise.Type.PRDict
 import Vectorise.Utils
 
 import BasicTypes
@@ -19,7 +18,8 @@ import TypeRep
 import Id
 import Var
 import Name
-import Outputable
+import FastString
+-- import Outputable
 
 -- debug               = False
 -- dtrace s x  = if debug then pprTrace "Vectoris.Type.PADict" s x else x
@@ -29,38 +29,52 @@ import Outputable
 buildPADict
        :: TyCon        -- ^ tycon of the type being vectorised.
        -> TyCon        -- ^ tycon of the type used for the vectorised representation.
-       -> TyCon        -- 
+       -> TyCon        -- ^ PRepr instance tycon
        -> SumRepr      -- ^ representation used for the type being vectorised.
        -> VM Var       -- ^ name of the top-level dictionary function.
 
 buildPADict vect_tc prepr_tc arr_tc repr
  = polyAbstract tvs $ \args ->
- case args of
-  (_:_) -> pprPanic "Vectorise.Type.PADict.buildPADict" (text "why do we need superclass dicts?")
-  [] -> do
-      -- TODO: I'm forcing args to [] because I'm not sure why we need them.
-      --       class PA has superclass (PR (PRepr a)) but we're not using
-      --       the superclass dictionary to build the PA dictionary.
+   do
+      -- The superclass dictionary is an argument if the tycon is polymorphic
+      let mk_super_ty = do
+                          r <- mkPReprType inst_ty
+                          pr_cls <- builtin prClass
+                          return $ PredTy $ ClassP pr_cls [r]
+      super_tys <- sequence [mk_super_ty | not (null tvs)]
+      super_args <- mapM (newLocalVar (fsLit "pr")) super_tys
+      let args' = super_args ++ args
+
+      -- it is constant otherwise
+      super_consts <- sequence [prDictOfPReprInstTyCon inst_ty prepr_tc []
+                                                | null tvs]
 
       -- Get ids for each of the methods in the dictionary.
-      method_ids <- mapM (method args) paMethods
+      method_ids <- mapM (method args') paMethods
 
       -- Expression to build the dictionary.
       pa_dc  <- builtin paDataCon
-      let dict = mkLams (tvs ++ args)
+      let dict = mkLams (tvs ++ args')
                $ mkConApp pa_dc
-               $ Type inst_ty : map (method_call args) method_ids
+               $ Type inst_ty
+               : map Var super_args ++ super_consts
+                                   -- the superclass dictionary is
+                                   -- either lambda-bound or
+                                   -- constant
+                 ++ map (method_call args') method_ids
 
       -- Build the type of the dictionary function.
-      pa_tc            <- builtin paTyCon
-      let Just pa_cls  = tyConClass_maybe pa_tc
-
+      pa_cls <- builtin paClass
       let dfun_ty      = mkForAllTys tvs
-                       $ mkFunTys (map varType args) (PredTy $ ClassP pa_cls [inst_ty])
+                       $ mkFunTys (map varType args')
+                                   (PredTy $ ClassP pa_cls [inst_ty])
 
       -- Set the unfolding for the inliner.
       raw_dfun <- newExportedVar dfun_name dfun_ty
-      let dfun_unf = mkDFunUnfolding dfun_ty (map (DFunPolyArg . Var) method_ids)
+      let dfun_unf = mkDFunUnfolding dfun_ty
+                   $ map (const $ DFunLamArg 0) super_args
+                     ++ map DFunConstArg super_consts
+                     ++ map (DFunPolyArg . Var) method_ids
           dfun = raw_dfun `setIdUnfolding`  dfun_unf
                           `setInlinePragma` dfunInlinePragma
 
@@ -91,8 +105,7 @@ buildPADict vect_tc prepr_tc arr_tc repr
 
 
 paMethods :: [(String, TyCon -> TyCon -> TyCon -> SumRepr -> VM CoreExpr)]
-paMethods = [("dictPRepr",    buildPRDict),
-             ("toPRepr",      buildToPRepr),
+paMethods = [("toPRepr",      buildToPRepr),
              ("fromPRepr",    buildFromPRepr),
              ("toArrPRepr",   buildToArrPRepr),
              ("fromArrPRepr", buildFromArrPRepr)]