commit abdc5295496088db413cccd84db7cf35fb64d07c
parent 11d6acaa0021242ebee68efaa5b7204d086f9b7e
Author: Jared Tobin <jared@jtobin.io>
Date: Tue, 10 Feb 2026 13:39:06 +0400
feat: implement inter-procedural analysis (IMPL4)
Adds opt-in inter-procedural analysis via --interproc/-p flag. Computes
function summaries by fixpoint iteration over the call graph, then
applies summaries at call sites.
Changes:
- CFG.hs: add function partitioning (isFunctionLabel, functionBlocks,
functionLabels) and call graph construction (callTargets, buildCallGraph)
- Taint.hs: add FuncSummary type, applySummary, runFunctionDataflow,
runInterProc for fixpoint computation
- Check.hs: add checkCFGInterProc using function summaries
- Audit.AArch64: export auditInterProc, auditFileInterProc
- app/Main.hs: add -p/--interproc flag
- test/Main.hs: add 3 inter-procedural tests
The inter-proc mode suppresses false positives when callee sets a
register to a known-public value that the caller then uses.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat:
7 files changed, 364 insertions(+), 9 deletions(-)
diff --git a/README.md b/README.md
@@ -36,7 +36,8 @@ Memory accesses: 37
Violations: 2
```
-Use `-q` for quiet mode (violations only) or `-j` for JSON output.
+Use `-q` for quiet mode (violations only), `-j` for JSON output, or
+`-p` for inter-procedural analysis (computes function summaries).
## Limitations
diff --git a/app/Main.hs b/app/Main.hs
@@ -2,7 +2,8 @@
module Main where
-import Audit.AArch64
+import Audit.AArch64 (AuditResult(..), Violation(..), ViolationReason(..),
+ auditFile, auditFileInterProc, regName)
import Data.Aeson (encode)
import qualified Data.ByteString.Lazy.Char8 as BL
import Data.Text (Text)
@@ -12,9 +13,10 @@ import Options.Applicative
import System.Exit (exitFailure, exitSuccess)
data Options = Options
- { optInput :: !FilePath
- , optJson :: !Bool
- , optQuiet :: !Bool
+ { optInput :: !FilePath
+ , optJson :: !Bool
+ , optQuiet :: !Bool
+ , optInterProc :: !Bool
} deriving (Eq, Show)
optParser :: Parser Options
@@ -35,6 +37,11 @@ optParser = Options
<> short 'q'
<> help "Suppress summary, only show violations"
)
+ <*> switch
+ ( long "interproc"
+ <> short 'p'
+ <> help "Enable inter-procedural analysis"
+ )
optInfo :: ParserInfo Options
optInfo = info (optParser <**> helper)
@@ -46,7 +53,8 @@ optInfo = info (optParser <**> helper)
main :: IO ()
main = do
opts <- execParser optInfo
- result <- auditFile (optInput opts)
+ let auditor = if optInterProc opts then auditFileInterProc else auditFile
+ result <- auditor (optInput opts)
case result of
Left err -> do
TIO.putStrLn $ "Error: " <> err
diff --git a/lib/Audit/AArch64.hs b/lib/Audit/AArch64.hs
@@ -30,7 +30,9 @@
module Audit.AArch64 (
-- * Main API
audit
+ , auditInterProc
, auditFile
+ , auditFileInterProc
-- * Results
, AuditResult(..)
@@ -59,13 +61,29 @@ audit name src = do
let cfg = buildCFG lns
pure (checkCFG name cfg)
+-- | Audit with inter-procedural analysis.
+-- Parses the source, builds its CFG and checks it using function
+-- summaries; a parse failure is returned unchanged as 'Left'.
+auditInterProc :: Text -> Text -> Either ParseError AuditResult
+auditInterProc name src =
+  checkCFGInterProc name . buildCFG <$> parseAsm src
+
-- | Audit an assembly file.
auditFile :: FilePath -> IO (Either Text AuditResult)
-auditFile path = do
+auditFile = auditFileWith audit
+
+-- | Audit an assembly file with inter-procedural analysis.
+auditFileInterProc :: FilePath -> IO (Either Text AuditResult)
+auditFileInterProc = auditFileWith auditInterProc
+
+-- | Helper for file auditing.
+auditFileWith :: (Text -> Text -> Either ParseError AuditResult)
+ -> FilePath -> IO (Either Text AuditResult)
+auditFileWith auditor path = do
bs <- BS.readFile path
case decodeUtf8' bs of
Left err -> pure (Left (T.pack (show err)))
Right src ->
- case audit (T.pack path) src of
+ case auditor (T.pack path) src of
Left err -> pure (Left (T.pack (show err)))
Right result -> pure (Right result)
diff --git a/lib/Audit/AArch64/CFG.hs b/lib/Audit/AArch64/CFG.hs
@@ -15,6 +15,12 @@ module Audit.AArch64.CFG (
, buildCFG
, blockLabels
, blockSuccessors
+ -- * Function partitioning
+ , isFunctionLabel
+ , functionBlocks
+ , functionLabels
+ , callTargets
+ , buildCallGraph
) where
import Audit.AArch64.Types
@@ -23,6 +29,7 @@ import qualified Data.Map.Strict as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Text (Text)
+import qualified Data.Text as T
-- | A basic block: a sequence of instructions with a single entry point.
data BasicBlock = BasicBlock
@@ -157,3 +164,60 @@ hasFallthrough bb = case lastInstrOf bb of
lastInstrOf b = case bbLines b of
[] -> Nothing
ls -> lineInstr (last ls)
+
+-- | Check if a label is a function entry (not a local LLVM label).
+-- Any label lacking the LBB, Lloh, LCPI, ltmp and l_ prefixes qualifies.
+isFunctionLabel :: Text -> Bool
+isFunctionLabel lbl
+ | T.isPrefixOf "LBB" lbl = False -- LLVM basic block
+ | T.isPrefixOf "Lloh" lbl = False -- LLVM linker hint
+ | T.isPrefixOf "LCPI" lbl = False -- LLVM constant pool
+ | T.isPrefixOf "ltmp" lbl = False -- LLVM temporary
+ | T.isPrefixOf "l_" lbl = False -- Local label
+ | otherwise = True -- Likely a function
+
+-- | Get all function entry labels in the CFG.
+-- Labels come back in block order; blocks without a label are skipped.
+functionLabels :: CFG -> [Text]
+functionLabels cfg = filter isFunctionLabel blockNames
+  where
+    blockNames = [ lbl | Just lbl <- map bbLabel (cfgBlocks cfg) ]
+
+-- | Get block indices belonging to a function.
+-- A function spans from its entry label to the next function label.
+-- Yields the empty list when the label is not in the CFG's label map.
+functionBlocks :: CFG -> Text -> [Int]
+functionBlocks cfg funcLabel =
+  case Map.lookup funcLabel (cfgLabelMap cfg) of
+    Nothing -> []
+    Just startIdx ->
+      -- One pass over the indexed block list; no repeated (!!) lookups.
+      case drop startIdx (zip [0 ..] (cfgBlocks cfg)) of
+        [] -> []
+        (entry : rest) ->
+          map fst (entry : takeWhile (not . startsNextFunc . snd) rest)
+  where
+    -- True for a block labelled as the start of a different function.
+    startsNextFunc bb = case bbLabel bb of
+      Nothing -> False
+      Just lbl -> lbl /= funcLabel && isFunctionLabel lbl
+
+-- | Extract call targets from a block's instructions.
+-- Only direct @bl@ instructions contribute a target; the targets are
+-- returned in instruction order.
+callTargets :: BasicBlock -> [Text]
+callTargets bb = concatMap (directCall . lineInstr) (bbLines bb)
+  where
+    -- A line without an instruction, or any non-@bl@, yields nothing.
+    directCall (Just (Bl target)) = [target]
+    directCall _ = []
+
+-- | Build call graph: maps each function to its callees.
+-- Callees are listed in instruction order and may repeat.
+buildCallGraph :: CFG -> Map Text [Text]
+buildCallGraph cfg = Map.fromList (map entry (functionLabels cfg))
+  where
+    blocks = cfgBlocks cfg
+    -- Pair a function with the call targets found in all of its blocks.
+    entry func = (func, concatMap calleesAt (functionBlocks cfg func))
+    calleesAt idx = callTargets (blocks !! idx)
diff --git a/lib/Audit/AArch64/Check.hs b/lib/Audit/AArch64/Check.hs
@@ -16,6 +16,7 @@ module Audit.AArch64.Check (
checkLine
, checkBlock
, checkCFG
+ , checkCFGInterProc
, AuditResult(..)
) where
@@ -23,6 +24,7 @@ import Audit.AArch64.CFG
import Audit.AArch64.Taint
import Audit.AArch64.Types
import qualified Data.IntMap.Strict as IM
+import qualified Data.Map.Strict as Map
import Data.Text (Text)
-- | Result of auditing assembly.
@@ -140,3 +142,46 @@ checkIndex sym ln instr idx st =
Public -> []
Secret -> [Violation sym ln instr (SecretIndex idx)]
Unknown -> [Violation sym ln instr (UnknownIndex idx)]
+
+-- | Check entire CFG with inter-procedural analysis.
+-- Computes function summaries via fixpoint, then checks each function
+-- and combines the per-function results.
+checkCFGInterProc :: Text -> CFG -> AuditResult
+checkCFGInterProc sym cfg =
+  mconcat [ checkFunction func | func <- functionLabels cfg ]
+  where
+    summaries = runInterProc cfg
+    blocks = cfgBlocks cfg
+
+    -- Check every block of one function, starting each block from the
+    -- IN state computed by the intra-function dataflow.
+    checkFunction func =
+      let blockIdxs = functionBlocks cfg func
+          -- Solved once per function and queried directly as an IntMap,
+          -- rather than being recomputed for every block.
+          inStates = runFunctionBlocks cfg blockIdxs summaries
+      in mconcat
+           [ fst (checkBlockWithSummary sym summaries inState (bbLines bb))
+           | idx <- blockIdxs
+           , let bb = blocks !! idx
+                 -- Blocks without a recorded IN state use the initial
+                 -- taint state.
+                 inState = IM.findWithDefault initTaintState idx inStates
+           ]
+
+-- | Check a block using summaries for calls.
+-- Returns the accumulated findings together with the state after the
+-- last line of the block.
+checkBlockWithSummary :: Text -> Map.Map Text FuncSummary -> TaintState
+                      -> [Line] -> (AuditResult, TaintState)
+checkBlockWithSummary sym summaries st0 lns = go mempty st0 lns
+  where
+    go acc st [] = (acc, st)
+    go acc st (l:ls) = go (acc <> checkLine sym st l) (step st l) ls
+    -- Advance the taint state over one line; a @bl@ with a known
+    -- summary applies it, everything else goes through 'analyzeLine'.
+    step st l = case lineInstr l of
+      Nothing -> st
+      Just (Bl target)
+        | Just summ <- Map.lookup target summaries -> applySummary summ st
+      Just _ -> analyzeLine l st
diff --git a/lib/Audit/AArch64/Taint.hs b/lib/Audit/AArch64/Taint.hs
@@ -20,9 +20,18 @@ module Audit.AArch64.Taint (
, analyzeLine
, analyzeBlock
, getTaint
+ , setTaint
, publicRoots
, joinTaintState
, runDataflow
+ -- * Function summaries
+ , FuncSummary(..)
+ , initSummary
+ , joinSummary
+ , applySummary
+ , runFunctionDataflow
+ , runFunctionBlocks
+ , runInterProc
) where
import Audit.AArch64.CFG
@@ -33,6 +42,8 @@ import Data.IntSet (IntSet)
import qualified Data.IntSet as IS
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
+import qualified Data.Set as Set
+import Data.Text (Text)
-- | Taint state: maps registers to their publicness.
type TaintState = Map Reg Taint
@@ -285,3 +296,153 @@ runDataflow cfg
ins' = IM.insert s newIn ins
wl' = if changed then IS.insert s wl else wl
in propagateToSuccs ss out wl' ins'
+
+-- ---------------------------------------------------------------------
+-- Function summaries for inter-procedural analysis
+
+-- | Function summary: taint state of caller-saved registers at return.
+-- 'applySummary' reads only the caller-saved entries of 'summaryState'.
+newtype FuncSummary = FuncSummary { summaryState :: TaintState }
+ deriving (Eq, Show)
+
+-- | Starting summary for the fixpoint: every caller-saved is Unknown.
+initSummary :: FuncSummary
+initSummary =
+  FuncSummary (Map.fromList (zip callerSavedRegs (repeat Unknown)))
+
+-- | Caller-saved general-purpose registers (x0-x17) per AArch64 ABI.
+callerSavedRegs :: [Reg]
+callerSavedRegs =
+ [ X0, X1, X2, X3, X4, X5, X6, X7
+ , X8, X9, X10, X11, X12, X13, X14, X15
+ , X16, X17
+ ]
+
+-- | Combine two summaries by joining their underlying taint states.
+joinSummary :: FuncSummary -> FuncSummary -> FuncSummary
+joinSummary lhs rhs =
+  FuncSummary (joinTaintState (summaryState lhs) (summaryState rhs))
+
+-- | Apply a function summary at a call site: caller-saved registers take
+-- the summary's taint (Unknown when absent); all others are kept.
+applySummary :: FuncSummary -> TaintState -> TaintState
+applySummary (FuncSummary summ) st = Map.union fromCallee st
+  where
+    fromCallee = Map.fromList
+      [ (r, Map.findWithDefault Unknown r summ) | r <- callerSavedRegs ]
+
+-- | Run dataflow analysis for a single function (subset of blocks).
+-- Returns the join of the OUT states of every block ending in a return.
+runFunctionDataflow :: CFG -> [Int] -> Map Text FuncSummary -> TaintState
+runFunctionDataflow cfg blockIndices summaries =
+  let blocks = cfgBlocks cfg
+      -- IN state of each analyzed block of this function
+      inStates = runFunctionBlocks cfg blockIndices summaries
+      -- OUT at returns; calls in a returning block use callee summaries
+      returnOuts = [ analyzeBlockWithSummaries bb inState summaries
+                   | i <- blockIndices
+                   , let bb = blocks !! i
+                         inState = IM.findWithDefault initTaintState i inStates
+                   , endsWithRet bb
+                   ]
+  in case returnOuts of
+       [] -> initTaintState  -- No return found, use init
+       (o:os) -> foldl joinTaintState o os
+
+-- | Check if block ends with a return instruction.
+-- An empty block never counts as returning.
+endsWithRet :: BasicBlock -> Bool
+endsWithRet bb = case reverse (bbLines bb) of
+  (final : _)
+    | Just (Ret _) <- lineInstr final -> True
+  _ -> False
+
+-- | Run dataflow on a subset of blocks (one function).
+-- Returns the IN state of every block reached from the entry block.
+runFunctionBlocks :: CFG -> [Int] -> Map Text FuncSummary -> IntMap TaintState
+runFunctionBlocks _ [] _ = IM.empty
+runFunctionBlocks cfg (entryIdx:rest) summaries = go initWorklist initIn IM.empty
+  where
+    blocks = cfgBlocks cfg
+    blockSet = IS.fromList (entryIdx:rest)
+    initIn = IM.singleton entryIdx initTaintState
+    initWorklist = IS.singleton entryIdx
+
+    go wl inStates outStates
+      | IS.null wl = inStates
+      | otherwise =
+          let (idx, wl') = IS.deleteFindMin wl
+              bb = blocks !! idx
+              inState = IM.findWithDefault initTaintState idx inStates
+              outState = analyzeBlockWithSummaries bb inState summaries
+              oldOut = IM.lookup idx outStates
+              changed = oldOut /= Just outState
+              outStates' = IM.insert idx outState outStates
+              -- Only propagate to successors within this function
+              succs = filter (`IS.member` blockSet) (blockSuccessors cfg idx)
+              (wl'', inStates') = if changed
+                then propagate succs outState wl' inStates
+                else (wl', inStates)
+          in go wl'' inStates' outStates'
+
+    propagate [] _ wl ins = (wl, ins)
+    propagate (s:ss) out wl ins =
+      let oldIn = IM.findWithDefault initTaintState s ins
+          newIn = joinTaintState oldIn out
+          ins' = IM.insert s newIn ins
+          -- Always queue a first-visited successor, or it is never analyzed.
+          wl' = if IM.notMember s ins || oldIn /= newIn then IS.insert s wl else wl
+      in propagate ss out wl' ins'
+
+-- | Analyze a block, applying call summaries at bl instructions.
+-- Lines without an instruction leave the state untouched.
+analyzeBlockWithSummaries :: BasicBlock -> TaintState -> Map Text FuncSummary
+                          -> TaintState
+analyzeBlockWithSummaries bb st0 summaries = foldl step st0 (bbLines bb)
+  where
+    step st l =
+      maybe st (\instr -> transferWithSummary instr st summaries) (lineInstr l)
+
+-- | Transfer with summary application for calls.
+-- A @bl@ applies the callee's summary when one is known; a @bl@ without
+-- one and every @blr@ fall back to 'invalidateCallerSaved'.
+transferWithSummary :: Instr -> TaintState -> Map Text FuncSummary -> TaintState
+transferWithSummary (Bl target) st summaries =
+  maybe (invalidateCallerSaved st) (`applySummary` st)
+        (Map.lookup target summaries)
+transferWithSummary (Blr _) st _ = invalidateCallerSaved st
+transferWithSummary instr st _ = transfer instr st
+
+-- | Run inter-procedural fixpoint analysis.
+-- Seeds every function with 'initSummary'; returns summaries for all.
+runInterProc :: CFG -> Map Text FuncSummary
+runInterProc cfg = go initSummaries (Set.fromList funcs)
+ where
+ funcs = functionLabels cfg
+ initSummaries = Map.fromList [(f, initSummary) | f <- funcs]
+ -- Worklist loop: take the smallest pending label and recompute it.
+ go summaries worklist
+ | Set.null worklist = summaries
+ | otherwise =
+ let (func, worklist') = Set.deleteFindMin worklist
+ blockIdxs = functionBlocks cfg func
+ outState = runFunctionDataflow cfg blockIdxs summaries
+ newSumm = FuncSummary outState
+ oldSumm = Map.findWithDefault initSummary func summaries
+ changed = newSumm /= oldSumm
+ summaries' = Map.insert func newSumm summaries
+ -- A changed summary can alter its callers' results: re-queue them
+ callers = findCallers cfg func
+ worklist'' = if changed
+ then foldr Set.insert worklist' callers
+ else worklist'
+ in go summaries' worklist''
+
+-- | Find functions that call the given target.
+-- A caller is any function whose call-graph entry mentions the target.
+findCallers :: CFG -> Text -> [Text]
+findCallers cfg target = filter callsTarget (functionLabels cfg)
+  where
+    callGraph = buildCallGraph cfg
+    callsTarget caller =
+      target `elem` Map.findWithDefault [] caller callGraph
diff --git a/test/Main.hs b/test/Main.hs
@@ -6,7 +6,6 @@ import Audit.AArch64
import Audit.AArch64.Parser
import Audit.AArch64.Taint
import Audit.AArch64.Types
-import Data.Text (Text)
import qualified Data.Text as T
import Test.Tasty
import Test.Tasty.HUnit
@@ -16,6 +15,7 @@ main = defaultMain $ testGroup "ppad-auditor" [
parserTests
, taintTests
, auditTests
+ , interprocTests
]
-- Parser tests
@@ -268,3 +268,61 @@ auditTests = testGroup "Audit" [
Left e -> assertFailure $ "parse failed: " ++ show e
Right ar -> assertEqual "no violations" 0 (length (arViolations ar))
]
+
+-- Inter-procedural tests: compare 'audit' with 'auditInterProc'
+
+interprocTests :: TestTree
+interprocTests = testGroup "InterProc" [
+ testCase "interproc: callee sets x0 public" $ do
+ -- callee sets x0 to a public value; caller uses x0 as an address
+ -- With interproc, the summary should show x0 is public
+ let src = T.unlines
+ [ "_callee:"
+ , " adrp x0, _const@PAGE" -- x0 = public
+ , " ret"
+ , "_caller:"
+ , " bl _callee"
+ , " ldr x1, [x0]" -- x0 should be public via summary
+ , " ret"
+ ]
+ -- Default mode: x0 unknown after call (caller-saved invalidated)
+ case audit "test" src of
+ Left e -> assertFailure $ "parse failed: " ++ show e
+ Right ar -> assertEqual "default: 1 violation" 1 (length (arViolations ar))
+ -- Interproc mode: x0 public via summary
+ case auditInterProc "test" src of
+ Left e -> assertFailure $ "parse failed: " ++ show e
+ Right ar -> assertEqual "interproc: 0 violations" 0 (length (arViolations ar))
+
+ , testCase "interproc: callee leaves x0 unknown" $ do
+ -- callee loads x0 from memory; caller uses x0 as an address after call
+ let src = T.unlines
+ [ "_callee:"
+ , " ldr x0, [x20]" -- x0 = unknown (loaded)
+ , " ret"
+ , "_caller:"
+ , " bl _callee"
+ , " ldr x1, [x0]" -- x0 unknown in both modes
+ , " ret"
+ ]
+ -- Both modes should report the same single violation count
+ case audit "test" src of
+ Left e -> assertFailure $ "parse failed: " ++ show e
+ Right ar -> assertEqual "default: 1 violation" 1 (length (arViolations ar))
+ case auditInterProc "test" src of
+ Left e -> assertFailure $ "parse failed: " ++ show e
+ Right ar -> assertEqual "interproc: 1 violation" 1 (length (arViolations ar))
+
+ , testCase "interproc: default mode unchanged" $ do
+ -- Default mode only: a call to bar (not defined in src) is audited
+ let src = T.unlines
+ [ "foo:"
+ , " adrp x8, _const@PAGE"
+ , " bl bar"
+ , " ldr x0, [x8]" -- x8 unknown after call (caller-saved)
+ , " ret"
+ ]
+ case audit "test" src of
+ Left e -> assertFailure $ "parse failed: " ++ show e
+ Right ar -> assertEqual "1 violation" 1 (length (arViolations ar))
+ ]