Finished support for foreign calls in the CPS pass
authorMichael D. Adams <t-madams@microsoft.com>
Tue, 3 Jul 2007 09:13:20 +0000 (09:13 +0000)
committerMichael D. Adams <t-madams@microsoft.com>
Tue, 3 Jul 2007 09:13:20 +0000 (09:13 +0000)
compiler/cmm/CmmBrokenBlock.hs
compiler/cmm/CmmCPS.hs
compiler/cmm/CmmUtils.hs

index 14259c6..f3c9928 100644 (file)
@@ -14,6 +14,7 @@ module CmmBrokenBlock (
 #include "HsVersions.h"
 
 import Cmm
 #include "HsVersions.h"
 
 import Cmm
+import CmmUtils
 import CLabel
 import MachOp (MachHint(..))
 
 import CLabel
 import MachOp (MachHint(..))
 
@@ -26,6 +27,12 @@ import UniqSupply
 import Unique
 import UniqFM
 
 import Unique
 import UniqFM
 
+import MachRegs (callerSaveVolatileRegs)
+  -- HACK: this is part of the NCG so we shouldn't use this, but we need
+  -- it for now to eliminate the need for saved regs to be in CmmCall.
+  -- The long term solution is to factor callerSaveVolatileRegs
+  -- from nativeGen into codeGen
+
 -- This module takes a 'CmmBasicBlock' which might have 'CmmCall'
 -- statements in it with 'CmmSafe' set and breaks it up at each such call.
 -- It also collects information about the block for later use
 -- This module takes a 'CmmBasicBlock' which might have 'CmmCall'
 -- statements in it with 'CmmSafe' set and breaks it up at each such call.
 -- It also collects information about the block for later use
@@ -230,7 +237,7 @@ breakBlock gc_block_idents uniques (BasicBlock ident stmts) entry =
                                 target results arguments srt
 
             -- Break the block on safe calls (the main job of this function)
                                 target results arguments srt
 
             -- Break the block on safe calls (the main job of this function)
-            (CmmCall target results arguments (CmmSafe srt):stmts) ->
+            (CmmCall target results arguments (CmmSafe srt) : stmts) ->
                 (cont_info : cont_infos, block : blocks)
                 where
                   next_id = BlockId $ head uniques
                 (cont_info : cont_infos, block : blocks)
                 where
                   next_id = BlockId $ head uniques
@@ -242,9 +249,24 @@ breakBlock gc_block_idents uniques (BasicBlock ident stmts) entry =
                   (cont_infos, blocks) = breakBlock' (tail uniques) next_id
                                          ControlEntry [] [] stmts
 
                   (cont_infos, blocks) = breakBlock' (tail uniques) next_id
                                          ControlEntry [] [] stmts
 
+            -- Unsafe calls don't need a continuation
+            -- but they do need to be expanded
+            (CmmCall target results arguments CmmUnsafe : stmts) ->
+                breakBlock' remaining_uniques current_id entry exits
+                            (accum_stmts ++
+                             arg_stmts ++
+                             caller_save ++
+                             [CmmCall target results new_args CmmUnsafe] ++
+                             caller_load)
+                            stmts
+                where
+                  (remaining_uniques, arg_stmts, new_args) =
+                      loadArgsIntoTemps uniques arguments
+                  (caller_save, caller_load) = callerSaveVolatileRegs (Just [])
+
             -- Default case.  Just keep accumulating statements
             -- and branch targets.
             -- Default case.  Just keep accumulating statements
             -- and branch targets.
-            (s:stmts) ->
+            (s : stmts) ->
                 breakBlock' uniques current_id entry
                             (cond_branch_target s++exits)
                             (accum_stmts++[s])
                 breakBlock' uniques current_id entry
                             (cond_branch_target s++exits)
                             (accum_stmts++[s])
index e6d70d4..4394fb5 100644 (file)
@@ -17,6 +17,8 @@ import CmmCallConv
 import CmmInfo
 import CmmUtils
 
 import CmmInfo
 import CmmUtils
 
+import CgProf (curCCS, curCCSAddr)
+import CgUtils (cmmOffsetW)
 import Bitmap
 import ClosureInfo
 import MachOp
 import Bitmap
 import ClosureInfo
 import MachOp
@@ -25,6 +27,7 @@ import CLabel
 import SMRep
 import Constants
 
 import SMRep
 import Constants
 
+import StaticFlags
 import DynFlags
 import ErrUtils
 import Maybes
 import DynFlags
 import ErrUtils
 import Maybes
@@ -38,6 +41,12 @@ import Monad
 import IO
 import Data.List
 
 import IO
 import Data.List
 
+import MachRegs (callerSaveVolatileRegs)
+  -- HACK: this is part of the NCG so we shouldn't use this, but we need
+  -- it for now to eliminate the need for saved regs to be in CmmCall.
+  -- The long term solution is to factor callerSaveVolatileRegs
+  -- from nativeGen into CPS
+
 -----------------------------------------------------------------------------
 -- |Top level driver for the CPS pass
 -----------------------------------------------------------------------------
 -----------------------------------------------------------------------------
 -- |Top level driver for the CPS pass
 -----------------------------------------------------------------------------
@@ -120,7 +129,7 @@ cpsProc uniqSupply (CmmProc info ident params blocks) = info_procs
       uniques :: [[Unique]]
       uniques = map uniqsFromSupply $ listSplitUniqSupply uniqSupply1
       (gc_unique:stack_use_unique:info_uniques):adaptor_uniques:block_uniques = uniques
       uniques :: [[Unique]]
       uniques = map uniqsFromSupply $ listSplitUniqSupply uniqSupply1
       (gc_unique:stack_use_unique:info_uniques):adaptor_uniques:block_uniques = uniques
-      proc_uniques = map uniqsFromSupply $ listSplitUniqSupply uniqSupply2
+      proc_uniques = map (map uniqsFromSupply . listSplitUniqSupply) $ listSplitUniqSupply uniqSupply2
 
       stack_use = CmmLocal (LocalReg stack_use_unique (cmmRegRep spReg) KindPtr)
 
 
       stack_use = CmmLocal (LocalReg stack_use_unique (cmmRegRep spReg) KindPtr)
 
@@ -334,10 +343,13 @@ selectContinuationFormat :: BlockEnv CmmLive
 selectContinuationFormat live continuations =
     map (\c -> (continuationLabel c, selectContinuationFormat' c)) continuations
     where
 selectContinuationFormat live continuations =
     map (\c -> (continuationLabel c, selectContinuationFormat' c)) continuations
     where
+      -- User written continuations
       selectContinuationFormat' (Continuation
                           (Right (CmmInfo _ _ _ (ContInfo format srt)))
                           label formals _ _) =
           (formals, Just label, format)
       selectContinuationFormat' (Continuation
                           (Right (CmmInfo _ _ _ (ContInfo format srt)))
                           label formals _ _) =
           (formals, Just label, format)
+      -- Either user written non-continuation code
+      -- or CPS generated proc-points
       selectContinuationFormat' (Continuation (Right _) _ formals _ _) =
           (formals, Nothing, [])
       -- CPS generated continuations
       selectContinuationFormat' (Continuation (Right _) _ formals _ _) =
           (formals, Nothing, [])
       -- CPS generated continuations
@@ -435,7 +447,7 @@ applyContinuationFormat formats (Continuation
       format = continuation_stack $ maybe unknown_block id $ lookup label formats
       unknown_block = panic "unknown BlockId in applyContinuationFormat"
 
       format = continuation_stack $ maybe unknown_block id $ lookup label formats
       unknown_block = panic "unknown BlockId in applyContinuationFormat"
 
--- User written non-continuation code
+-- Either user written non-continuation code or CPS generated proc-point
 applyContinuationFormat formats (Continuation
                           (Right info) label formals is_gc blocks) =
     Continuation info label formals is_gc blocks
 applyContinuationFormat formats (Continuation
                           (Right info) label formals is_gc blocks) =
     Continuation info label formals is_gc blocks
@@ -457,7 +469,7 @@ applyContinuationFormat formats (Continuation
 -----------------------------------------------------------------------------
 continuationToProc :: (WordOff, [(CLabel, ContinuationFormat)])
                    -> CmmReg
 -----------------------------------------------------------------------------
 continuationToProc :: (WordOff, [(CLabel, ContinuationFormat)])
                    -> CmmReg
-                   -> [Unique]
+                   -> [[Unique]]
                    -> Continuation CmmInfo
                    -> CmmTop
 continuationToProc (max_stack, formats) stack_use uniques
                    -> Continuation CmmInfo
                    -> CmmTop
 continuationToProc (max_stack, formats) stack_use uniques
@@ -484,15 +496,49 @@ continuationToProc (max_stack, formats) stack_use uniques
             CmmNonInfo Nothing ->
                 panic "continuationToProc: missing non-info GC block"
 
             CmmNonInfo Nothing ->
                 panic "continuationToProc: missing non-info GC block"
 
-      continuationToProc' :: Unique -> BrokenBlock -> Bool -> [CmmBasicBlock]
-      continuationToProc' unique (BrokenBlock ident entry stmts _ exit) is_entry =
-          case gc_prefix ++ param_prefix of
-            [] -> [main_block]
-            stmts -> [BasicBlock prefix_id (gc_prefix ++ param_prefix ++ [CmmBranch ident]),
-                      main_block]
+-- At present neither the Cmm parser nor the code generator
+-- produce code that will allow the target of a CmmCondBranch
+-- or a CmmSwitch to become a continuation or a proc-point.
+-- If future revisions, might allow these to happen
+-- then special care will have to be take to allow for that case.
+      continuationToProc' :: [Unique]
+                          -> BrokenBlock
+                          -> Bool
+                          -> [CmmBasicBlock]
+      continuationToProc' uniques (BrokenBlock ident entry stmts _ exit) is_entry =
+          prefix_blocks ++ [main_block]
           where
           where
-            main_block = BasicBlock ident (stmts ++ postfix)
-            prefix_id = BlockId unique
+            prefix_blocks =
+                case gc_prefix ++ param_prefix of
+                  [] -> []
+                  entry_stmts -> [BasicBlock prefix_id
+                                  (entry_stmts ++ [CmmBranch ident])]
+
+            prefix_unique : call_uniques = uniques
+            toCLabel = mkReturnPtLabel . getUnique
+
+            block_for_branch unique next
+                | (Just cont_format) <- lookup (toCLabel next) formats
+                = let
+                    new_next = BlockId unique
+                    cont_stack = continuation_frame_size cont_format
+                    arguments = map formal_to_actual (continuation_formals cont_format)
+                  in (new_next,
+                     [BasicBlock new_next $
+                      pack_continuation False curr_format cont_format ++
+                      tail_call (curr_stack - cont_stack)
+                              (CmmLit $ CmmLabel $ toCLabel next)
+                              arguments])
+                | otherwise
+                = (next, [])
+
+            block_for_branch' :: Unique -> Maybe BlockId -> (Maybe BlockId, [CmmBasicBlock])
+            block_for_branch' _ Nothing = (Nothing, [])
+            block_for_branch' unique (Just next) = (Just new_next, new_blocks)
+              where (new_next, new_blocks) = block_for_branch unique next
+
+            main_block = BasicBlock ident (stmts ++ postfix_stmts)
+            prefix_id = BlockId prefix_unique
             gc_prefix = case entry of
                        FunctionEntry _ _ _ -> gc_stmts
                        ControlEntry -> []
             gc_prefix = case entry of
                        FunctionEntry _ _ _ -> gc_stmts
                        ControlEntry -> []
@@ -500,7 +546,7 @@ continuationToProc (max_stack, formats) stack_use uniques
             param_prefix = if is_entry
                            then param_stmts
                            else []
             param_prefix = if is_entry
                            then param_stmts
                            else []
-            postfix = case exit of
+            postfix_stmts = case exit of
                         FinalBranch next ->
                             if (mkReturnPtLabel $ getUnique next) == label
                             then [CmmBranch next]
                         FinalBranch next ->
                             if (mkReturnPtLabel $ getUnique next) == label
                             then [CmmBranch next]
@@ -514,7 +560,6 @@ continuationToProc (max_stack, formats) stack_use uniques
                                 where
                                   cont_stack = continuation_frame_size cont_format
                                   arguments = map formal_to_actual (continuation_formals cont_format)
                                 where
                                   cont_stack = continuation_frame_size cont_format
                                   arguments = map formal_to_actual (continuation_formals cont_format)
-                                  formal_to_actual reg = (CmmReg (CmmLocal reg), NoHint)
                         FinalSwitch expr targets -> [CmmSwitch expr targets]
                         FinalReturn arguments ->
                             tail_call curr_stack
                         FinalSwitch expr targets -> [CmmSwitch expr targets]
                         FinalReturn arguments ->
                             tail_call curr_stack
@@ -522,6 +567,8 @@ continuationToProc (max_stack, formats) stack_use uniques
                                 arguments
                         FinalJump target arguments ->
                             tail_call curr_stack target arguments
                                 arguments
                         FinalJump target arguments ->
                             tail_call curr_stack target arguments
+
+                        -- A regular Cmm function call
                         FinalCall next (CmmForeignCall target CmmCallConv)
                             results arguments _ _ ->
                                 pack_continuation True curr_format cont_format ++
                         FinalCall next (CmmForeignCall target CmmCallConv)
                             results arguments _ _ ->
                                 pack_continuation True curr_format cont_format ++
@@ -531,7 +578,145 @@ continuationToProc (max_stack, formats) stack_use uniques
                               cont_format = maybe unknown_block id $
                                             lookup (mkReturnPtLabel $ getUnique next) formats
                               cont_stack = continuation_frame_size cont_format
                               cont_format = maybe unknown_block id $
                                             lookup (mkReturnPtLabel $ getUnique next) formats
                               cont_stack = continuation_frame_size cont_format
-                        FinalCall next _ results arguments _ _ -> panic "unimplemented CmmCall"
+
+                        -- A safe foreign call
+                        FinalCall next (CmmForeignCall target conv)
+                            results arguments _ _ ->
+                                target_stmts ++
+                                foreignCall call_uniques' (CmmForeignCall new_target conv)
+                                            results arguments
+                            where
+                              (call_uniques', target_stmts, new_target) =
+                                  maybeAssignTemp call_uniques target
+
+                        -- A safe prim call
+                        FinalCall next (CmmPrim target)
+                            results arguments _ _ ->
+                                foreignCall call_uniques (CmmPrim target)
+                                            results arguments
+
+formal_to_actual reg = (CmmReg (CmmLocal reg), NoHint)
+
+foreignCall :: [Unique] -> CmmCallTarget -> CmmHintFormals -> CmmActuals -> [CmmStmt]
+foreignCall uniques call results arguments =
+    arg_stmts ++
+    saveThreadState ++
+    caller_save ++
+    [CmmCall (CmmForeignCall suspendThread CCallConv)
+                [ (id,PtrHint) ]
+                [ (CmmReg (CmmGlobal BaseReg), PtrHint) ]
+                CmmUnsafe,
+     CmmCall call results new_args CmmUnsafe,
+     CmmCall (CmmForeignCall resumeThread CCallConv)
+                 [ (new_base, PtrHint) ]
+                [ (CmmReg (CmmLocal id), PtrHint) ]
+                CmmUnsafe,
+     -- Assign the result to BaseReg: we
+     -- might now have a different Capability!
+     CmmAssign (CmmGlobal BaseReg) (CmmReg (CmmLocal new_base))] ++
+    caller_load ++
+    loadThreadState tso_unique ++
+    [CmmJump (CmmReg spReg) (map (formal_to_actual . fst) results)]
+    where
+      (_, arg_stmts, new_args) =
+          loadArgsIntoTemps argument_uniques arguments
+      (caller_save, caller_load) =
+          callerSaveVolatileRegs (Just [{-only system regs-}])
+      new_base = LocalReg base_unique (cmmRegRep (CmmGlobal BaseReg)) KindNonPtr
+      id = LocalReg id_unique wordRep KindNonPtr
+      tso_unique : base_unique : id_unique : argument_uniques = uniques
+
+-- -----------------------------------------------------------------------------
+-- Save/restore the thread state in the TSO
+
+suspendThread = CmmLit (CmmLabel (mkRtsCodeLabel SLIT("suspendThread")))
+resumeThread  = CmmLit (CmmLabel (mkRtsCodeLabel SLIT("resumeThread")))
+
+-- This stuff can't be done in suspendThread/resumeThread, because it
+-- refers to global registers which aren't available in the C world.
+
+saveThreadState =
+  -- CurrentTSO->sp = Sp;
+  [CmmStore (cmmOffset stgCurrentTSO tso_SP) stgSp,
+  closeNursery] ++
+  -- and save the current cost centre stack in the TSO when profiling:
+  if opt_SccProfilingOn
+  then [CmmStore (cmmOffset stgCurrentTSO tso_CCCS) curCCS]
+  else []
+
+   -- CurrentNursery->free = Hp+1;
+closeNursery = CmmStore nursery_bdescr_free (cmmOffsetW stgHp 1)
+
+loadThreadState tso_unique =
+  [
+       -- tso = CurrentTSO;
+       CmmAssign (CmmLocal tso) stgCurrentTSO,
+       -- Sp = tso->sp;
+       CmmAssign sp (CmmLoad (cmmOffset (CmmReg (CmmLocal tso)) tso_SP)
+                             wordRep),
+       -- SpLim = tso->stack + RESERVED_STACK_WORDS;
+       CmmAssign spLim (cmmOffsetW (cmmOffset (CmmReg (CmmLocal tso)) tso_STACK)
+                                   rESERVED_STACK_WORDS)
+  ] ++
+  openNursery ++
+  -- and load the current cost centre stack from the TSO when profiling:
+  if opt_SccProfilingOn 
+  then [CmmStore curCCSAddr 
+       (CmmLoad (cmmOffset (CmmReg (CmmLocal tso)) tso_CCCS) wordRep)]
+  else []
+  where tso = LocalReg tso_unique wordRep KindNonPtr -- TODO FIXME NOW
+
+
+openNursery = [
+        -- Hp = CurrentNursery->free - 1;
+       CmmAssign hp (cmmOffsetW (CmmLoad nursery_bdescr_free wordRep) (-1)),
+
+        -- HpLim = CurrentNursery->start + 
+       --              CurrentNursery->blocks*BLOCK_SIZE_W - 1;
+       CmmAssign hpLim
+           (cmmOffsetExpr
+               (CmmLoad nursery_bdescr_start wordRep)
+               (cmmOffset
+                 (CmmMachOp mo_wordMul [
+                   CmmMachOp (MO_S_Conv I32 wordRep)
+                     [CmmLoad nursery_bdescr_blocks I32],
+                   CmmLit (mkIntCLit bLOCK_SIZE)
+                  ])
+                 (-1)
+               )
+           )
+   ]
+
+
+nursery_bdescr_free   = cmmOffset stgCurrentNursery oFFSET_bdescr_free
+nursery_bdescr_start  = cmmOffset stgCurrentNursery oFFSET_bdescr_start
+nursery_bdescr_blocks = cmmOffset stgCurrentNursery oFFSET_bdescr_blocks
+
+tso_SP    = tsoFieldB     oFFSET_StgTSO_sp
+tso_STACK = tsoFieldB     oFFSET_StgTSO_stack
+tso_CCCS  = tsoProfFieldB oFFSET_StgTSO_CCCS
+
+-- The TSO struct has a variable header, and an optional StgTSOProfInfo in
+-- the middle.  The fields we're interested in are after the StgTSOProfInfo.
+tsoFieldB :: ByteOff -> ByteOff
+tsoFieldB off
+  | opt_SccProfilingOn = off + sIZEOF_StgTSOProfInfo + fixedHdrSize * wORD_SIZE
+  | otherwise          = off + fixedHdrSize * wORD_SIZE
+
+tsoProfFieldB :: ByteOff -> ByteOff
+tsoProfFieldB off = off + fixedHdrSize * wORD_SIZE
+
+stgSp            = CmmReg sp
+stgHp            = CmmReg hp
+stgCurrentTSO    = CmmReg currentTSO
+stgCurrentNursery = CmmReg currentNursery
+
+sp               = CmmGlobal Sp
+spLim            = CmmGlobal SpLim
+hp               = CmmGlobal Hp
+hpLim            = CmmGlobal HpLim
+currentTSO       = CmmGlobal CurrentTSO
+currentNursery           = CmmGlobal CurrentNursery
 
 -----------------------------------------------------------------------------
 -- Functions that generate CmmStmt sequences
 
 -----------------------------------------------------------------------------
 -- Functions that generate CmmStmt sequences
@@ -573,7 +758,16 @@ gc_stack_check gc_block max_frame_size
 
 -- TODO: fix branches to proc point
 -- (we have to insert a new block to marshel the continuation)
 
 -- TODO: fix branches to proc point
 -- (we have to insert a new block to marshel the continuation)
-pack_continuation :: Bool -> ContinuationFormat -> ContinuationFormat -> [CmmStmt]
+
+
+pack_continuation :: Bool               -- ^ Whether to set the top/header
+                                        -- of the stack.  We only need to
+                                        -- set it if we are calling down
+                                        -- as opposed to continuation
+                                        -- adaptors.
+                  -> ContinuationFormat -- ^ The current format
+                  -> ContinuationFormat -- ^ The return point format
+                  -> [CmmStmt]
 pack_continuation allow_header_set
                       (ContinuationFormat _ curr_id curr_frame_size _)
                       (ContinuationFormat _ cont_id cont_frame_size live_regs)
 pack_continuation allow_header_set
                       (ContinuationFormat _ curr_id curr_frame_size _)
                       (ContinuationFormat _ cont_id cont_frame_size live_regs)
index 0c5ab0f..a2a2711 100644 (file)
@@ -18,6 +18,8 @@ module CmmUtils(
        mkIntCLit, zeroCLit,
 
        mkLblExpr,
        mkIntCLit, zeroCLit,
 
        mkLblExpr,
+
+        loadArgsIntoTemps, maybeAssignTemp,
   ) where
 
 #include "HsVersions.h"
   ) where
 
 #include "HsVersions.h"
@@ -27,6 +29,7 @@ import Cmm
 import MachOp
 import OrdList
 import Outputable
 import MachOp
 import OrdList
 import Outputable
+import Unique
 
 ---------------------------------------------------
 --
 
 ---------------------------------------------------
 --
@@ -175,3 +178,28 @@ zeroCLit = CmmInt 0 wordRep
 
 mkLblExpr :: CLabel -> CmmExpr
 mkLblExpr lbl = CmmLit (CmmLabel lbl)
 
 mkLblExpr :: CLabel -> CmmExpr
 mkLblExpr lbl = CmmLit (CmmLabel lbl)
+
+---------------------------------------------------
+--
+--     Helpers for foreign call arguments
+--
+---------------------------------------------------
+
+loadArgsIntoTemps :: [Unique]
+                  -> CmmActuals
+                  -> ([Unique], [CmmStmt], CmmActuals)
+loadArgsIntoTemps uniques [] = (uniques, [], [])
+loadArgsIntoTemps uniques ((e, hint):args) =
+    (uniques'',
+     new_stmts ++ remaining_stmts,
+     (new_e, hint) : remaining_e)
+    where
+      (uniques', new_stmts, new_e) = maybeAssignTemp uniques e
+      (uniques'', remaining_stmts, remaining_e) =
+          loadArgsIntoTemps uniques' args
+
+maybeAssignTemp :: [Unique] -> CmmExpr -> ([Unique], [CmmStmt], CmmExpr)
+maybeAssignTemp uniques e
+    | hasNoGlobalRegs e = (uniques, [], e)
+    | otherwise         = (tail uniques, [CmmAssign local e], CmmReg local)
+    where local = CmmLocal (LocalReg (head uniques) (cmmExprRep e) KindNonPtr)