LLVM: Stop llvm saving stg caller-save regs across C calls
[ghc-hetmet.git] / compiler / llvmGen / LlvmCodeGen / CodeGen.hs
index 8259716..1313b68 100644 (file)
@@ -11,7 +11,7 @@ import LlvmCodeGen.Base
 import LlvmCodeGen.Regs
 
 import BlockId
-import CgUtils ( activeStgRegs )
+import CgUtils ( activeStgRegs, callerSaves )
 import CLabel
 import Cmm
 import qualified PprCmm
@@ -287,23 +287,44 @@ genCall env target res args ret = do
                 | ret == CmmNeverReturns = unitOL $ Unreachable
                 | otherwise              = nilOL
 
+    {- In LLVM we pass the STG registers around everywhere in function calls.
+       So this means LLVM considers them live across the entire function, when
+       in reality they usually aren't. For Caller save registers across C calls
+       the saving and restoring of them is done by the Cmm code generator,
+       using cmm local vars. So to stop LLVM saving them as well (and saving
+       all of them since it thinks they're always live, we trash them just
+       before the call by assigning the 'undef' value to them. The ones we
+       need are restored from the Cmm local var and the ones we don't need
+       are fine to be trashed.
+    -}
+    let trashStmts = concatOL $ map trashReg activeStgRegs
+            where trashReg r =
+                    let reg   = lmGlobalRegVar r
+                        ty    = (pLower . getVarType) reg
+                        trash = unitOL $ Store (LMLitVar $ LMUndefLit ty) reg
+                    in case callerSaves r of
+                              True  -> trash
+                              False -> nilOL
+
+    let stmts = stmts1 `appOL` stmts2 `appOL` trashStmts
+
     -- make the actual call
     case retTy of
         LMVoid -> do
             let s1 = Expr $ Call ccTy fptr argVars fnAttrs
-            let allStmts = stmts1 `appOL` stmts2 `snocOL` s1 `appOL` retStmt
+            let allStmts = stmts `snocOL` s1 `appOL` retStmt
             return (env2, allStmts, top1 ++ top2)
 
         _ -> do
+            (v1, s1) <- doExpr retTy $ Call ccTy fptr argVars fnAttrs
             let (creg, _) = ret_reg res
             let (env3, vreg, stmts3, top3) = getCmmReg env2 (CmmLocal creg)
-            let allStmts = stmts1 `appOL` stmts2 `appOL` stmts3
-            (v1, s1) <- doExpr retTy $ Call ccTy fptr argVars fnAttrs
+            let allStmts = stmts `snocOL` s1 `appOL` stmts3
             if retTy == pLower (getVarType vreg)
                 then do
                     let s2 = Store v1 vreg
-                    return (env3, allStmts `snocOL` s1 `snocOL` s2
-                            `appOL` retStmt, top1 ++ top2 ++ top3)
+                    return (env3, allStmts `snocOL` s2 `appOL` retStmt,
+                                top1 ++ top2 ++ top3)
                 else do
                     let ty = pLower $ getVarType vreg
                     let op = case ty of
@@ -315,8 +336,8 @@ genCall env target res args ret = do
 
                     (v2, s2) <- doExpr ty $ Cast op v1 ty
                     let s3 = Store v2 vreg
-                    return (env3, allStmts `snocOL` s1 `snocOL` s2 `snocOL` s3
-                            `appOL` retStmt, top1 ++ top2 ++ top3)
+                    return (env3, allStmts `snocOL` s2 `snocOL` s3
+                                `appOL` retStmt, top1 ++ top2 ++ top3)
 
 
 -- | Conversion of call arguments.