commit 9945a6a7a5d020dfd44f8ba4e830ec96e877ed4f
parent b9de77630868522fb609abfb0d418b98bae84677
Author: Jared Tobin <jared@jtobin.io>
Date: Fri, 27 Feb 2026 15:02:53 +0400
feat: add inter-procedural tail call taint propagation
STG code uses `b _function_info` (tail calls) and `br xN` (indirect
jumps) for control flow rather than `bl`. This change enables taint
propagation across these boundaries within the same assembly file.
Changes:
- CFG.hs: extend callTargets to include B to function labels as tail
call edges
- Taint.hs: handle B/Br in transferWithSummary, add invalidateStgArgRegs
for conservative handling of unknown targets, add
runDataflowWithConfigAndSummaries for whole-file dataflow
- Check.hs: update checkCFGInterProc and checkCFGInterProcWithConfig to
use whole-file dataflow with summaries
- Main.hs: add tailCallTests group with 10 test cases
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat:
4 files changed, 299 insertions(+), 35 deletions(-)
diff --git a/lib/Audit/AArch64/CFG.hs b/lib/Audit/AArch64/CFG.hs
@@ -203,10 +203,12 @@ isTerminator instr = case instr of
-- | Extract successor labels from a terminating instruction.
-- Note: Bl/Blr are NOT included as successors; calls are interprocedural
-- boundaries. Fallthrough after call is handled by hasFallthrough.
+-- B to function labels (tail calls) ARE included - they create CFG edges
+-- for intra-file dataflow propagation.
successorLabels :: Maybe Instr -> [Text]
successorLabels Nothing = []
successorLabels (Just instr) = case instr of
- B lbl -> [lbl]
+ B lbl -> [lbl] -- Includes both local branches and tail calls
BCond _ lbl -> [lbl] -- Plus fallthrough
Bl _ -> [] -- Call; fallthrough handled separately
Blr _ -> [] -- Indirect call
@@ -259,6 +261,7 @@ functionBlocks cfg funcLabel =
Map.findWithDefault [] funcLabel (cfgFuncBlocks cfg)
-- | Extract call targets from a block's instructions.
+-- Includes both bl (call with link) and b to function labels (tail call).
callTargets :: BasicBlock -> [Text]
callTargets bb =
[ target | l <- bbLines bb
@@ -266,7 +269,9 @@ callTargets bb =
, target <- getCallTarget instr ]
where
getCallTarget (Bl target) = [target]
- getCallTarget _ = []
+ getCallTarget (B target)
+ | isFunctionLabel target = [target] -- Tail call
+ getCallTarget _ = []
-- | Build call graph: maps each function to its callees.
-- Uses cached function blocks map for efficient lookup.
diff --git a/lib/Audit/AArch64/Check.hs b/lib/Audit/AArch64/Check.hs
@@ -201,25 +201,22 @@ checkIndex sym ln instr idx st =
-- No provenance upgrade for indices - they are scalar values
-- | Check entire CFG with inter-procedural analysis.
--- Computes function summaries via fixpoint, then checks each function.
+-- Uses whole-file dataflow with summaries to propagate taint across
+-- tail call boundaries.
checkCFGInterProc :: Text -> CFG -> AuditResult
checkCFGInterProc sym cfg =
let summaries = runInterProc cfg
- funcs = functionLabels cfg
+ -- Use empty config with whole-file dataflow
+ emptyConfig = TaintConfig Map.empty True
+ inStates = runDataflowWithConfigAndSummaries emptyConfig summaries cfg
+ nBlocks = cfgBlockCount cfg
in mconcat
- [ checkFunction cfg sym func summaries
- | func <- funcs
+ [ fst (checkBlockWithSummary blockSym summaries inState (bbLines bb))
+ | idx <- [0..nBlocks-1]
+ , let bb = indexBlock cfg idx
+ blockSym = fromMaybe sym (bbLabel bb)
+ inState = IM.findWithDefault initTaintState idx inStates
]
- where
- checkFunction c s func summs =
- let blockIdxs = functionBlocks c func
- inStates = runFunctionBlocks c blockIdxs summs
- in mconcat
- [ fst (checkBlockWithSummary s summs inState (bbLines bb))
- | idx <- blockIdxs
- , let bb = indexBlock c idx
- inState = IM.findWithDefault initTaintState idx inStates
- ]
-- | Check a block using summaries for calls.
-- Uses strict counters and cons-accumulation for violations.
@@ -254,26 +251,18 @@ checkCFGWithConfig tcfg sym cfg =
]
-- | Check entire CFG with inter-procedural analysis and taint config.
+-- Uses whole-file dataflow with summaries to properly propagate taint
+-- across tail call boundaries.
checkCFGInterProcWithConfig :: TaintConfig -> Text -> CFG -> AuditResult
checkCFGInterProcWithConfig tcfg sym cfg =
let summaries = runInterProcWithConfig tcfg cfg
- funcs = functionLabels cfg
+ -- Run whole-file dataflow with config seeding AND summary application
+ inStates = runDataflowWithConfigAndSummaries tcfg summaries cfg
+ nBlocks = cfgBlockCount cfg
in mconcat
- [ checkFunctionWithConfig cfg tcfg sym func summaries
- | func <- funcs
+ [ fst (checkBlockWithSummary blockSym summaries inState (bbLines bb))
+ | idx <- [0..nBlocks-1]
+ , let bb = indexBlock cfg idx
+ blockSym = fromMaybe sym (bbLabel bb)
+ inState = IM.findWithDefault initTaintState idx inStates
]
- where
- checkFunctionWithConfig c tc s func summs =
- let blockIdxs = functionBlocks c func
- -- Seed entry state with function policy
- baseEntry = initTaintState
- entryState = case Map.lookup func (tcPolicies tc) of
- Nothing -> baseEntry
- Just policy -> seedArgs policy baseEntry
- inStates = runFunctionBlocksWithEntry c blockIdxs summs entryState
- in mconcat
- [ fst (checkBlockWithSummary s summs inState (bbLines bb))
- | idx <- blockIdxs
- , let bb = indexBlock c idx
- inState = IM.findWithDefault entryState idx inStates
- ]
diff --git a/lib/Audit/AArch64/Taint.hs b/lib/Audit/AArch64/Taint.hs
@@ -28,6 +28,7 @@ module Audit.AArch64.Taint (
, joinTaintState
, runDataflow
, runDataflowWithConfig
+ , runDataflowWithConfigAndSummaries
-- * Function summaries
, FuncSummary(..)
, initSummary
@@ -47,6 +48,7 @@ import Audit.AArch64.CFG
, blockSuccessors, functionLabels, functionBlocks
, buildFunctionBlocksMap, buildCallerMap
, cfgBlockCount, indexBlock
+ , isFunctionLabel
)
import Audit.AArch64.Types
( Reg(..), Instr(..), Line(..), Operand(..), AddrMode(..)
@@ -1008,6 +1010,18 @@ invalidateCallerSaved st = st
, X16, X17
]
+-- | Invalidate STG argument registers for unknown tail call targets.
+-- STG calling convention: X22=R1 (closure), X23-X27=R2-R6 (args).
+-- Used when we can't determine the tail call target (br xN or unknown symbol).
+invalidateStgArgRegs :: TaintState -> TaintState
+invalidateStgArgRegs st = st
+ { tsRegs = foldr (\r -> Map.insert r Unknown) (tsRegs st) stgArgRegs
+ , tsProv = foldr (\r -> Map.insert r ProvUnknown) (tsProv st) stgArgRegs
+ , tsKind = foldr (\r -> Map.insert r KindUnknown) (tsKind st) stgArgRegs
+ }
+ where
+ stgArgRegs = [X22, X23, X24, X25, X26, X27]
+
-- | Join two taint states (element-wise join).
-- For registers in both, take the join. For registers in only one, keep.
-- Stack slots (SP and STG), provenance, kinds, and heap bucket are also joined.
@@ -1209,7 +1223,7 @@ analyzeBlockWithSummaries bb st0 summaries = foldl' go st0 (bbLines bb)
Nothing -> st
Just instr -> transferWithSummary instr st summaries
--- | Transfer with summary application for calls.
+-- | Transfer with summary application for calls and tail calls.
transferWithSummary :: Instr -> TaintState -> Map Text FuncSummary -> TaintState
transferWithSummary instr st summaries = case instr of
Bl target ->
@@ -1217,6 +1231,14 @@ transferWithSummary instr st summaries = case instr of
Just summ -> applySummary summ st
Nothing -> invalidateCallerSaved st
Blr _ -> invalidateCallerSaved st
+ -- Tail calls: b to function label
+ B target
+ | isFunctionLabel target ->
+ case Map.lookup target summaries of
+ Just summ -> applySummary summ st
+ Nothing -> invalidateStgArgRegs st
+ -- Indirect jumps: conservative treatment
+ Br _ -> invalidateStgArgRegs st
_ -> transfer instr st
-- | Run inter-procedural fixpoint analysis.
@@ -1313,6 +1335,67 @@ runDataflowWithConfig tcfg cfg
wl' = if oldIn /= newIn then IS.insert s wl else wl
in propagate ss out wl' ins'
+-- | Run forward dataflow with taint config and function summaries.
+-- Combines config-based entry seeding with summary application at calls.
+-- This enables whole-file dataflow that respects both taint configs and
+-- inter-procedural summaries (including tail call propagation).
+runDataflowWithConfigAndSummaries :: TaintConfig -> Map Text FuncSummary
+ -> CFG -> IntMap TaintState
+runDataflowWithConfigAndSummaries tcfg summaries cfg
+ | nBlocks == 0 = IM.empty
+ | otherwise = go initWorklist initIn IM.empty
+ where
+ nBlocks = cfgBlockCount cfg
+ assumeStg = tcAssumeStgPublic tcfg
+ baseState = initTaintStateWith assumeStg
+ emptyState = emptyTaintStateWith assumeStg
+
+ -- Build a map from block index to entry taint state
+ -- Entry blocks of functions get their policy applied
+ initIn = IM.fromList
+ [ (i, entryState i (indexBlock cfg i))
+ | i <- [0..nBlocks-1]
+ ]
+
+ entryState idx bb =
+ case bbLabel bb of
+ Nothing -> baseState
+ Just lbl ->
+ -- Check if this block is a function entry
+ case functionBlocks cfg lbl of
+ (entry:_) | entry == idx ->
+ case Map.lookup lbl (tcPolicies tcfg) of
+ Nothing -> baseState
+ Just policy -> seedArgs policy baseState
+ _ -> baseState
+
+ initWorklist = IS.fromList [0..nBlocks-1]
+
+ go worklist inStates outStates
+ | IS.null worklist = inStates
+ | otherwise =
+ let (idx, worklist') = IS.deleteFindMin worklist
+ bb = indexBlock cfg idx
+ inState = IM.findWithDefault baseState idx inStates
+ -- Use analyzeBlockWithSummaries to apply summaries at calls
+ outState = analyzeBlockWithSummaries bb inState summaries
+ oldOut = IM.lookup idx outStates
+ changed = oldOut /= Just outState
+ outStates' = IM.insert idx outState outStates
+ succs = blockSuccessors cfg idx
+ (worklist'', inStates') = if changed
+ then propagateS succs outState worklist' inStates
+ else (worklist', inStates)
+ in go worklist'' inStates' outStates'
+
+ propagateS [] _ wl ins = (wl, ins)
+ propagateS (s:ss) out wl ins =
+ let oldIn = IM.findWithDefault emptyState s ins
+ newIn = joinTaintState oldIn out
+ ins' = IM.insert s newIn ins
+ wl' = if oldIn /= newIn then IS.insert s wl else wl
+ in propagateS ss out wl' ins'
+
-- | Run inter-procedural analysis with taint config.
-- Precomputes caches for function blocks and callers.
-- Uses tcAssumeStgPublic from config for STG stack assumption.
diff --git a/test/Main.hs b/test/Main.hs
@@ -34,6 +34,7 @@ main = defaultMain $ testGroup "ppad-auditor" [
, taintConfigTests
, nctTests
, callGraphTests
+ , tailCallTests
]
-- Parser tests
@@ -1470,3 +1471,189 @@ callGraphTests = testGroup "CallGraph" [
reachable = reachableSymbols "_missing" cg
assertEqual "empty set" Set.empty reachable
]
+
+-- Tail call inter-procedural tests
+
+tailCallTests :: TestTree
+tailCallTests = testGroup "TailCall" [
+
+ -- Basic tail call taint propagation
+ -- Uses x23 (R2 in STG) since x22 (HpLim) is a public root
+ testCase "tail call propagates unknown taint" $ do
+ let src = T.unlines
+ [ "_caller:"
+ , " ldr x23, [x21]"
+ , " b _callee_info"
+ , "_callee_info:"
+ , " ldr x0, [x20, x23]"
+ , " ret"
+ ]
+ case auditInterProc "test" src of
+ Left e -> assertFailure $ "parse failed: " ++ show e
+ Right ar -> assertEqual "one violation" 1 (length (arViolations ar))
+
+ , testCase "tail call propagates secret taint" $ do
+ let src = T.unlines
+ [ "_caller:"
+ , " mov x22, x0"
+ , " b _callee_info"
+ , "_callee_info:"
+ , " ldr x1, [x20, x22]"
+ , " ret"
+ ]
+ cfg = TaintConfig (Map.singleton "_caller"
+ (ArgPolicy (Set.singleton X0) Set.empty Set.empty
+ Set.empty Set.empty)) True
+ case auditInterProcWithConfig cfg "test" src of
+ Left e -> assertFailure $ "parse failed: " ++ show e
+ Right ar ->
+ assertEqual "secret violation" 1
+ (length $ filter isSecretViolation (map vReason (arViolations ar)))
+
+ -- Local branches still work
+ , testCase "local branch is intra-procedural" $ do
+ let src = T.unlines
+ [ "_foo:"
+ , " mov x0, x1"
+ , " b Llocal"
+ , "Llocal:"
+ , " ldr x2, [x20, x0]"
+ , " ret"
+ ]
+ case auditInterProc "test" src of
+ Left e -> assertFailure $ "parse failed: " ++ show e
+ Right ar -> assertEqual "one unknown violation" 1 (length (arViolations ar))
+
+ , testCase "NCG local branch (Lc prefix) is intra-procedural" $ do
+ let src = T.unlines
+ [ "_foo:"
+ , " mov x22, x0"
+ , " b LcaB"
+ , "LcaB:"
+ , " ldr x1, [x20, x22]"
+ , " ret"
+ ]
+ cfg = TaintConfig (Map.singleton "_foo"
+ (ArgPolicy (Set.singleton X0) Set.empty Set.empty
+ Set.empty Set.empty)) True
+ case auditInterProcWithConfig cfg "test" src of
+ Left e -> assertFailure $ "parse failed: " ++ show e
+ Right ar ->
+ assertEqual "secret violation" 1
+ (length $ filter isSecretViolation (map vReason (arViolations ar)))
+
+ -- Chain of tail calls
+ , testCase "chain of tail calls propagates taint" $ do
+ let src = T.unlines
+ [ "_a:"
+ , " mov x22, x0"
+ , " b _b_info"
+ , "_b_info:"
+ , " mov x23, x22"
+ , " b _c_info"
+ , "_c_info:"
+ , " ldr x1, [x20, x23]"
+ , " ret"
+ ]
+ cfg = TaintConfig (Map.singleton "_a"
+ (ArgPolicy (Set.singleton X0) Set.empty Set.empty
+ Set.empty Set.empty)) True
+ case auditInterProcWithConfig cfg "test" src of
+ Left e -> assertFailure $ "parse failed: " ++ show e
+ Right ar ->
+ assertEqual "secret violation" 1
+ (length $ filter isSecretViolation (map vReason (arViolations ar)))
+
+ -- Indirect jump handling
+ , testCase "indirect jump invalidates STG regs" $ do
+ let src = T.unlines
+ [ "_foo:"
+ , " mov x22, x0"
+ , " ldr x8, [x19]"
+ , " br x8"
+ ]
+ cfg = TaintConfig (Map.singleton "_foo"
+ (ArgPolicy (Set.singleton X0) Set.empty Set.empty
+ Set.empty Set.empty)) True
+ case auditInterProcWithConfig cfg "test" src of
+ Left e -> assertFailure $ "parse failed: " ++ show e
+ Right ar -> assertEqual "no violations" 0 (length (arViolations ar))
+
+ -- Tail call to unknown function
+ , testCase "tail call to unknown function is conservative" $ do
+ let src = T.unlines
+ [ "_foo:"
+ , " mov x22, x0"
+ , " b _unknown_info"
+ ]
+ cfg = TaintConfig (Map.singleton "_foo"
+ (ArgPolicy (Set.singleton X0) Set.empty Set.empty
+ Set.empty Set.empty)) True
+ case auditInterProcWithConfig cfg "test" src of
+ Left e -> assertFailure $ "parse failed: " ++ show e
+ Right ar -> assertEqual "no violations" 0 (length (arViolations ar))
+
+ -- Secret pointee across tail call
+ , testCase "secret_pointee propagates across tail call" $ do
+ let src = T.unlines
+ [ "_caller:"
+ , " mov x22, x0"
+ , " b _callee_info"
+ , "_callee_info:"
+ , " ldr x1, [x22, #0]"
+ , " ldr x2, [x20, x1]"
+ , " ret"
+ ]
+ cfg = TaintConfig (Map.singleton "_caller"
+ (ArgPolicy Set.empty Set.empty (Set.singleton X0)
+ Set.empty Set.empty)) True
+ case auditInterProcWithConfig cfg "test" src of
+ Left e -> assertFailure $ "parse failed: " ++ show e
+ Right ar ->
+ assertEqual "secret violation" 1
+ (length $ filter isSecretViolation (map vReason (arViolations ar)))
+
+ -- Mutual tail calls (fixpoint test)
+ , testCase "mutual tail calls reach fixpoint" $ do
+ let src = T.unlines
+ [ "_a_info:"
+ , " cbnz x0, Lskip"
+ , " b _b_info"
+ , "Lskip:"
+ , " ret"
+ , "_b_info:"
+ , " ldr x1, [x20, x22]"
+ , " b _a_info"
+ ]
+ cfg = TaintConfig (Map.singleton "_a_info"
+ (ArgPolicy (Set.singleton X22) Set.empty Set.empty
+ Set.empty Set.empty)) True
+ case auditInterProcWithConfig cfg "test" src of
+ Left e -> assertFailure $ "parse failed: " ++ show e
+ Right ar ->
+ assertEqual "secret violation in cycle" 1
+ (length $ filter isSecretViolation (map vReason (arViolations ar)))
+
+ -- Mixed bl and b calls
+ , testCase "bl followed by tail call" $ do
+ let src = T.unlines
+ [ "_caller:"
+ , " mov x22, x0"
+ , " bl _helper"
+ , " b _callee_info"
+ , "_helper:"
+ , " ret"
+ , "_callee_info:"
+ , " ldr x1, [x20, x22]"
+ , " ret"
+ ]
+ cfg = TaintConfig (Map.singleton "_caller"
+ (ArgPolicy (Set.singleton X0) Set.empty Set.empty
+ Set.empty Set.empty)) True
+ case auditInterProcWithConfig cfg "test" src of
+ Left e -> assertFailure $ "parse failed: " ++ show e
+ Right ar ->
+ assertEqual "secret violation" 1
+ (length $ filter isSecretViolation (map vReason (arViolations ar)))
+
+ ]