commit d08c45dad9c962dedb108935e04abc5a05d0a8b1
parent 2f45e0bc6bda9cab4322ee7355359d6d73992d3d
Author: Jared Tobin <jared@jtobin.io>
Date: Wed, 11 Feb 2026 19:15:03 +0400
perf: refactor CFG to indexed blocks with cached metadata (IMPL10)
- Replace list-based cfgBlocks with Data.Primitive.Array for O(1) indexing
- Add bbLastInstr, bbSuccIdxs, bbHasFallthrough to BasicBlock
- Cache cfgFuncBlocks map during CFG construction
- Update blockSuccessors, functionBlocks, functionLabels to use caches
- Update Taint.hs and Check.hs to use indexBlock instead of list indexing
Benchmark results show dramatic improvements in taint analysis:
- taint/intra-small: 305ms -> 18ms (17x faster)
- taint/intra-large: 4.9s -> 84ms (58x faster)
- taint/inter-small: 303ms -> 2.9ms (104x faster)
- taint/inter-large: 5.1s -> 25ms (204x faster)
CFG construction is slightly slower due to caching overhead, but this
one-time cost is offset by massive analysis speedups.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat:
4 files changed, 136 insertions(+), 124 deletions(-)
diff --git a/lib/Audit/AArch64/CFG.hs b/lib/Audit/AArch64/CFG.hs
@@ -17,13 +17,15 @@ module Audit.AArch64.CFG (
, buildCFG
, blockLabels
, blockSuccessors
+ , cfgBlockCount
+ , indexBlock
-- * Function partitioning
, isFunctionLabel
, functionBlocks
, functionLabels
, callTargets
, buildCallGraph
- -- * Cached analysis structures
+ -- * Cached analysis structures (deprecated)
, buildFunctionBlocksMap
, buildCallerMap
) where
@@ -32,6 +34,8 @@ import Audit.AArch64.Types
import Control.DeepSeq (NFData)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
+import Data.Primitive.Array (Array)
+import qualified Data.Primitive.Array as A
import Data.Set (Set)
import qualified Data.Set as Set
import Data.Text (Text)
@@ -40,39 +44,107 @@ import GHC.Generics (Generic)
-- | A basic block: a sequence of instructions with a single entry point.
data BasicBlock = BasicBlock
- { bbLabel :: !(Maybe Text) -- ^ Optional label at block start
- , bbLines :: ![Line] -- ^ Instructions in the block
- , bbSuccs :: ![Text] -- ^ Successor block labels
+ { bbLabel :: !(Maybe Text) -- ^ Optional label at block start
+ , bbLines :: ![Line] -- ^ Instructions in the block
+ , bbSuccs :: ![Text] -- ^ Successor block labels
+ , bbLastInstr :: !(Maybe Instr) -- ^ Last instruction (cached)
+ , bbSuccIdxs :: ![Int] -- ^ Successor block indices (cached)
+ , bbHasFallthrough :: !Bool -- ^ Block falls through to next
} deriving (Eq, Show, Generic, NFData)
-- | Control flow graph.
data CFG = CFG
- { cfgBlocks :: ![BasicBlock] -- ^ All basic blocks
+ { cfgBlocks :: !(Array BasicBlock) -- ^ All basic blocks (O(1) index)
, cfgLabelMap :: !(Map Text Int) -- ^ Label -> block index
, cfgEntry :: !Int -- ^ Entry block index
+ , cfgFuncBlocks :: !(Map Text [Int]) -- ^ Function label -> block indices
} deriving (Eq, Show, Generic, NFData)
-- | Get all labels defined in the CFG.
blockLabels :: CFG -> Set Text
blockLabels cfg = Set.fromList
- [ lbl | bb <- cfgBlocks cfg, Just lbl <- [bbLabel bb] ]
+ [ lbl | bb <- toList (cfgBlocks cfg), Just lbl <- [bbLabel bb] ]
+ where
+ toList arr = [A.indexArray arr i | i <- [0 .. A.sizeofArray arr - 1]]
+
+-- | Number of basic blocks held in the CFG's block array.
+cfgBlockCount :: CFG -> Int
+cfgBlockCount = A.sizeofArray . cfgBlocks
+
+-- | Index into the CFG's block array.
+--
+-- 'A.indexArray' performs no bounds checking, so the index is validated
+-- here: an index outside @[0, cfgBlockCount cfg)@ raises an 'error' that
+-- names the offending index instead of reading past the array.
+indexBlock :: CFG -> Int -> BasicBlock
+indexBlock cfg i
+  | i >= 0 && i < n = A.indexArray blocks i
+  | otherwise = error $
+      "Audit.AArch64.CFG.indexBlock: index " ++ show i
+        ++ " out of bounds for " ++ show n ++ " blocks"
+  where
+    blocks = cfgBlocks cfg
+    n = A.sizeofArray blocks
-- | Build a CFG from parsed assembly lines.
buildCFG :: [Line] -> CFG
-buildCFG lns = CFG
- { cfgBlocks = blocks
- , cfgLabelMap = labelMap
- , cfgEntry = 0
- }
+buildCFG lns = cfg
where
-- Split into basic blocks at labels and control flow instructions
- blocks = buildBlocks lns
+ rawBlocks = buildBlocks lns
+ nBlocks = length rawBlocks
+
+ -- Build label map first
labelMap = Map.fromList
[ (lbl, idx)
- | (idx, bb) <- zip [0..] blocks
+ | (idx, bb) <- zip [0..] rawBlocks
, Just lbl <- [bbLabel bb]
]
+ -- Compute successor indices and fallthrough for each block
+ annotatedBlocks =
+ [ bb { bbSuccIdxs = succIdxs, bbHasFallthrough = fallthrough }
+ | (idx, bb) <- zip [0..] rawBlocks
+ , let jumpTargets = [ i | lbl <- bbSuccs bb
+ , Just i <- [Map.lookup lbl labelMap] ]
+ fallthrough = hasFallthroughInstr (bbLastInstr bb)
+ fallthroughEdge = if fallthrough && idx + 1 < nBlocks
+ then [idx + 1]
+ else []
+ succIdxs = jumpTargets ++ fallthroughEdge
+ ]
+
+ -- Build the block array
+ blocks = A.arrayFromList annotatedBlocks
+
+ -- Build function blocks map in a single pass
+ funcBlocksMap = buildFuncBlocksOnce annotatedBlocks
+
+ cfg = CFG
+ { cfgBlocks = blocks
+ , cfgLabelMap = labelMap
+ , cfgEntry = 0
+ , cfgFuncBlocks = funcBlocksMap
+ }
+
+-- | Decide whether control continues into the next block after the given
+-- last instruction; 'Nothing' (no terminator) counts as falling through.
+hasFallthroughInstr :: Maybe Instr -> Bool
+hasFallthroughInstr mInstr = case mInstr of
+  Just (B _)   -> False  -- unconditional branch
+  Just (Br _)  -> False  -- indirect branch
+  Just (Ret _) -> False  -- return
+  _            -> True   -- conditional branches, calls, or no instruction
+
+-- | Build the function-label -> block-indices map from the block list.
+--
+-- Every block whose label satisfies 'isFunctionLabel' opens a function.
+-- That function owns the contiguous index range up to (but excluding) the
+-- next such block, or up to the end of the list for the last one. Blocks
+-- preceding the first function label belong to no function. If a label
+-- occurs more than once, its last occurrence wins.
+--
+-- The map is built directly with 'Map.fromList' rather than with a lazy
+-- 'foldl' over a tuple accumulator, which piles up unevaluated thunks.
+buildFuncBlocksOnce :: [BasicBlock] -> Map Text [Int]
+buildFuncBlocksOnce bbs = Map.fromList (zipWith ownedRange entries ends)
+  where
+    -- (label, start index) of every function entry block, in block order
+    entries =
+      [ (lbl, idx)
+      | (idx, bb) <- zip [0 ..] bbs
+      , Just lbl <- [bbLabel bb]
+      , isFunctionLabel lbl
+      ]
+
+    -- Each function ends where the next one starts; the last runs to the end
+    ends = drop 1 (map snd entries) ++ [length bbs]
+
+    ownedRange (lbl, start) end = (lbl, [start .. end - 1])
+
-- | Build basic blocks from lines.
buildBlocks :: [Line] -> [BasicBlock]
buildBlocks [] = []
@@ -98,11 +170,16 @@ buildBlocks lns = go [] Nothing lns
-- Regular instruction
_ -> go (l:acc) mLabel ls
- finishBlock mLabel lns' = BasicBlock
- { bbLabel = mLabel
- , bbLines = lns'
- , bbSuccs = successorLabels (lastInstr lns')
- }
+ finishBlock mLabel lns' =
+ let lastI = lastInstr lns'
+ in BasicBlock
+ { bbLabel = mLabel
+ , bbLines = lns'
+ , bbSuccs = successorLabels lastI
+ , bbLastInstr = lastI
+ , bbSuccIdxs = [] -- Filled in by buildCFG
+ , bbHasFallthrough = True -- Filled in by buildCFG
+ }
lastInstr [] = Nothing
lastInstr xs = lineInstr (last xs)
@@ -141,36 +218,11 @@ successorLabels (Just instr) = case instr of
_ -> []
-- | Get successor block indices for a given block.
--- Includes both jump targets (resolved via label map) and fallthrough edges.
+-- Returns cached successor indices (precomputed during CFG construction).
blockSuccessors :: CFG -> Int -> [Int]
blockSuccessors cfg idx
- | idx < 0 || idx >= length (cfgBlocks cfg) = []
- | otherwise =
- let blocks = cfgBlocks cfg
- bb = blocks !! idx
- lmap = cfgLabelMap cfg
- -- Jump targets from labels
- jumpTargets = [ i | lbl <- bbSuccs bb
- , Just i <- [Map.lookup lbl lmap] ]
- -- Fallthrough if this block doesn't end with unconditional transfer
- fallthrough = if hasFallthrough bb && idx + 1 < length blocks
- then [idx + 1]
- else []
- in jumpTargets ++ fallthrough
-
--- | Check if a block falls through to the next block.
-hasFallthrough :: BasicBlock -> Bool
-hasFallthrough bb = case lastInstrOf bb of
- Nothing -> True -- No terminator, falls through
- Just instr -> case instr of
- B _ -> False -- Unconditional branch
- Br _ -> False -- Indirect branch
- Ret _ -> False -- Return
- _ -> True -- Conditional branches, calls fall through
- where
- lastInstrOf b = case bbLines b of
- [] -> Nothing
- ls -> lineInstr (last ls)
+ | idx < 0 || idx >= cfgBlockCount cfg = []
+ | otherwise = bbSuccIdxs (indexBlock cfg idx)
-- | Check if a label is an NCG-internal label (not a function entry).
-- These start with _L followed by a lowercase letter (e.g. _Lblock_info).
@@ -195,30 +247,15 @@ isFunctionLabel lbl
| otherwise = True -- Likely a function
-- | Get all function entry labels in the CFG.
+-- Returns keys from the cached function blocks map, in ascending label order.
functionLabels :: CFG -> [Text]
-functionLabels cfg =
- [ lbl | bb <- cfgBlocks cfg
- , Just lbl <- [bbLabel bb]
- , isFunctionLabel lbl ]
+functionLabels cfg = Map.keys (cfgFuncBlocks cfg)
-- | Get block indices belonging to a function.
--- A function spans from its entry label to the next function label.
+-- Returns cached block indices (precomputed during CFG construction).
functionBlocks :: CFG -> Text -> [Int]
functionBlocks cfg funcLabel =
- case Map.lookup funcLabel (cfgLabelMap cfg) of
- Nothing -> []
- Just startIdx ->
- let blocks = cfgBlocks cfg
- n = length blocks
- isNextFunc i = case bbLabel (blocks !! i) of
- Nothing -> False
- Just lbl -> lbl /= funcLabel && isFunctionLabel lbl
- endIdx = findEnd (startIdx + 1)
- findEnd i
- | i >= n = n
- | isNextFunc i = i
- | otherwise = findEnd (i + 1)
- in [startIdx .. endIdx - 1]
+ Map.findWithDefault [] funcLabel (cfgFuncBlocks cfg)
-- | Extract call targets from a block's instructions.
callTargets :: BasicBlock -> [Text]
@@ -231,39 +268,18 @@ callTargets bb =
getCallTarget _ = []
-- | Build call graph: maps each function to its callees.
+-- Uses cached function blocks map for efficient lookup.
buildCallGraph :: CFG -> Map Text [Text]
buildCallGraph cfg = Map.fromList
[ (func, callees)
- | func <- functionLabels cfg
- , let blocks = cfgBlocks cfg
- indices = functionBlocks cfg func
- callees = concatMap (callTargets . (blocks !!)) indices
+ | (func, indices) <- Map.toList (cfgFuncBlocks cfg)
+ , let callees = concatMap (callTargets . indexBlock cfg) indices
]
-- | Build map of function labels to their block index ranges.
--- Single O(N) pass over blocks, detecting function boundaries.
+-- Deprecated: use cfgFuncBlocks directly; this returns the cached map.
buildFunctionBlocksMap :: CFG -> Map Text [Int]
-buildFunctionBlocksMap cfg = finalize $ foldl step (Nothing, Map.empty) indexed
- where
- blocks = cfgBlocks cfg
- nBlocks = length blocks
- indexed = zip [0..] blocks
-
- step (mCur, acc) (idx, bb) =
- case bbLabel bb of
- Just lbl | isFunctionLabel lbl ->
- -- New function starts; close previous at idx
- let acc' = closeCurrent idx mCur acc
- in (Just (lbl, idx), acc')
- _ ->
- -- Continue current function
- (mCur, acc)
-
- closeCurrent _ Nothing acc = acc
- closeCurrent endIdx (Just (lbl, start)) acc =
- Map.insert lbl [start .. endIdx - 1] acc
-
- finalize (mCur, acc) = closeCurrent nBlocks mCur acc
+buildFunctionBlocksMap = cfgFuncBlocks
-- | Build caller map: maps each function to its callers.
-- Takes precomputed function blocks map to avoid rescanning.
@@ -274,8 +290,7 @@ buildCallerMap cfg funcBlocksMap = Map.fromListWith (++)
, callee <- callees
]
where
- blocks = cfgBlocks cfg
callGraph = Map.fromList
- [ (func, concatMap (callTargets . (blocks !!)) idxs)
+ [ (func, concatMap (callTargets . indexBlock cfg) idxs)
| (func, idxs) <- Map.toList funcBlocksMap
]
diff --git a/lib/Audit/AArch64/Check.hs b/lib/Audit/AArch64/Check.hs
@@ -24,7 +24,8 @@ module Audit.AArch64.Check (
, AuditResult(..)
) where
-import Audit.AArch64.CFG
+import Audit.AArch64.CFG (BasicBlock(..), CFG(..), cfgBlockCount, indexBlock,
+ functionLabels, functionBlocks)
import Audit.AArch64.Taint
import Audit.AArch64.Types
import Control.DeepSeq (NFData)
@@ -76,11 +77,12 @@ checkBlock sym st0 lns = go mempty st0 lns
checkCFG :: Text -> CFG -> AuditResult
checkCFG sym cfg =
let inStates = runDataflow cfg
- blocks = cfgBlocks cfg
+ nBlocks = cfgBlockCount cfg
in mconcat
[ fst (checkBlock blockSym inState (bbLines bb))
- | (idx, bb) <- zip [0..] blocks
- , let blockSym = maybe sym id (bbLabel bb)
+ | idx <- [0..nBlocks-1]
+ , let bb = indexBlock cfg idx
+ blockSym = maybe sym id (bbLabel bb)
inState = IM.findWithDefault initTaintState idx inStates
]
@@ -169,13 +171,12 @@ checkCFGInterProc sym cfg =
where
checkFunction c s func summs =
let blockIdxs = functionBlocks c func
- bs = cfgBlocks c
inStatesIM = runFunctionBlocks c blockIdxs summs
inStatesMap = IM.foldlWithKey' toMap Map.empty inStatesIM
in mconcat
[ fst (checkBlockWithSummary s summs inState (bbLines bb))
| idx <- blockIdxs
- , let bb = bs !! idx
+ , let bb = indexBlock c idx
inState = Map.findWithDefault initTaintState idx inStatesMap
]
toMap m k v = Map.insert k v m
@@ -201,11 +202,12 @@ checkBlockWithSummary sym summaries st0 lns = go mempty st0 lns
checkCFGWithConfig :: TaintConfig -> Text -> CFG -> AuditResult
checkCFGWithConfig tcfg sym cfg =
let inStates = runDataflowWithConfig tcfg cfg
- blocks = cfgBlocks cfg
+ nBlocks = cfgBlockCount cfg
in mconcat
[ fst (checkBlock blockSym inState (bbLines bb))
- | (idx, bb) <- zip [0..] blocks
- , let blockSym = fromMaybe sym (bbLabel bb)
+ | idx <- [0..nBlocks-1]
+ , let bb = indexBlock cfg idx
+ blockSym = fromMaybe sym (bbLabel bb)
inState = IM.findWithDefault initTaintState idx inStates
]
@@ -221,7 +223,6 @@ checkCFGInterProcWithConfig tcfg sym cfg =
where
checkFunctionWithConfig c tc s func summs =
let blockIdxs = functionBlocks c func
- bs = cfgBlocks c
-- Seed entry state with function policy
baseEntry = initTaintState
entryState = case Map.lookup func (tcPolicies tc) of
@@ -232,7 +233,7 @@ checkCFGInterProcWithConfig tcfg sym cfg =
in mconcat
[ fst (checkBlockWithSummary s summs inState (bbLines bb))
| idx <- blockIdxs
- , let bb = bs !! idx
+ , let bb = indexBlock c idx
inState = Map.findWithDefault entryState idx inStatesMap
]
toMap m k v = Map.insert k v m
diff --git a/lib/Audit/AArch64/Taint.hs b/lib/Audit/AArch64/Taint.hs
@@ -44,6 +44,7 @@ import Audit.AArch64.CFG
( BasicBlock(..), CFG(..)
, blockSuccessors, functionLabels, functionBlocks
, buildFunctionBlocksMap, buildCallerMap
+ , cfgBlockCount, indexBlock
)
import Audit.AArch64.Types
( Reg(..), Instr(..), Line(..), Operand(..), AddrMode(..)
@@ -575,11 +576,10 @@ joinTaintState a b = TaintState
-- Returns the IN taint state for each block (indexed by block number).
runDataflow :: CFG -> IntMap TaintState
runDataflow cfg
- | null (cfgBlocks cfg) = IM.empty
+ | nBlocks == 0 = IM.empty
| otherwise = go initWorklist initIn initOut
where
- blocks = cfgBlocks cfg
- nBlocks = length blocks
+ nBlocks = cfgBlockCount cfg
-- Initial states: all blocks start with public roots (GHC invariant)
initIn = IM.fromList [(i, initTaintState) | i <- [0..nBlocks-1]]
@@ -591,7 +591,7 @@ runDataflow cfg
| IS.null worklist = inStates
| otherwise =
let (idx, worklist') = IS.deleteFindMin worklist
- bb = blocks !! idx
+ bb = indexBlock cfg idx
inState = IM.findWithDefault initTaintState idx inStates
outState = analyzeBlock (bbLines bb) inState
oldOut = IM.lookup idx outStates
@@ -661,13 +661,12 @@ applySummary (FuncSummary summ) st = st
-- Returns the OUT state at return points (joined).
runFunctionDataflow :: CFG -> [Int] -> Map Text FuncSummary -> TaintState
runFunctionDataflow cfg blockIndices summaries =
- let blocks = cfgBlocks cfg
- -- Run dataflow on just these blocks
+ let -- Run dataflow on just these blocks
inStates = runFunctionBlocks cfg blockIndices summaries
-- Collect OUT states at return instructions
returnOuts = [ analyzeBlockWithSummaries bb inState summaries
| i <- blockIndices
- , let bb = blocks !! i
+ , let bb = indexBlock cfg i
inState = IM.findWithDefault initTaintState i inStates
, endsWithRet bb
]
@@ -689,7 +688,6 @@ runFunctionBlocks :: CFG -> [Int] -> Map Text FuncSummary
runFunctionBlocks _ [] _ = IM.empty
runFunctionBlocks cfg (entryIdx:rest) summaries = go initWorklist initIn IM.empty
where
- blocks = cfgBlocks cfg
blockSet = IS.fromList (entryIdx:rest)
initIn = IM.singleton entryIdx initTaintState
@@ -699,7 +697,7 @@ runFunctionBlocks cfg (entryIdx:rest) summaries = go initWorklist initIn IM.empt
| IS.null wl = inStates
| otherwise =
let (idx, wl') = IS.deleteFindMin wl
- bb = blocks !! idx
+ bb = indexBlock cfg idx
inState = IM.findWithDefault initTaintState idx inStates
outState = analyzeBlockWithSummaries bb inState summaries
oldOut = IM.lookup idx outStates
@@ -779,17 +777,16 @@ runInterProc cfg = go initSummaries (Set.fromList funcs)
-- Entry states are seeded according to function-specific policies.
runDataflowWithConfig :: TaintConfig -> CFG -> IntMap TaintState
runDataflowWithConfig tcfg cfg
- | null (cfgBlocks cfg) = IM.empty
+ | nBlocks == 0 = IM.empty
| otherwise = go initWorklist initIn IM.empty
where
- blocks = cfgBlocks cfg
- nBlocks = length blocks
+ nBlocks = cfgBlockCount cfg
-- Build a map from block index to entry taint state
-- Entry blocks of functions get their policy applied
initIn = IM.fromList
- [ (i, entryState i bb)
- | (i, bb) <- zip [0..] blocks
+ [ (i, entryState i (indexBlock cfg i))
+ | i <- [0..nBlocks-1]
]
entryState idx bb =
@@ -811,7 +808,7 @@ runDataflowWithConfig tcfg cfg
| IS.null worklist = inStates
| otherwise =
let (idx, worklist') = IS.deleteFindMin worklist
- bb = blocks !! idx
+ bb = indexBlock cfg idx
inState = IM.findWithDefault initTaintState idx inStates
outState = analyzeBlock (bbLines bb) inState
oldOut = IM.lookup idx outStates
@@ -867,8 +864,7 @@ runInterProcWithConfig tcfg cfg = go initSummaries (Set.fromList funcs)
runFunctionDataflowWithConfig :: TaintConfig -> CFG -> Text -> [Int]
-> Map Text FuncSummary -> TaintState
runFunctionDataflowWithConfig tcfg cfg funcName blockIndices summaries =
- let blocks = cfgBlocks cfg
- -- Seed entry state with function policy
+ let -- Seed entry state with function policy
baseEntry = initTaintState
entryState = case Map.lookup funcName (tcPolicies tcfg) of
Nothing -> baseEntry
@@ -877,7 +873,7 @@ runFunctionDataflowWithConfig tcfg cfg funcName blockIndices summaries =
returnOuts =
[ analyzeBlockWithSummaries bb inState summaries
| i <- blockIndices
- , let bb = blocks !! i
+ , let bb = indexBlock cfg i
inState = IM.findWithDefault entryState i inStates
, endsWithRet bb
]
@@ -892,7 +888,6 @@ runFunctionBlocksWithEntry _ [] _ _ = IM.empty
runFunctionBlocksWithEntry cfg (entryIdx:rest) summaries entryState =
go initWorklist initIn IM.empty
where
- blocks = cfgBlocks cfg
blockSet = IS.fromList (entryIdx:rest)
initIn = IM.singleton entryIdx entryState
@@ -902,7 +897,7 @@ runFunctionBlocksWithEntry cfg (entryIdx:rest) summaries entryState =
| IS.null wl = inStates
| otherwise =
let (idx, wl') = IS.deleteFindMin wl
- bb = blocks !! idx
+ bb = indexBlock cfg idx
inState = IM.findWithDefault entryState idx inStates
outState = analyzeBlockWithSummaries bb inState summaries
oldOut = IM.lookup idx outStates
diff --git a/ppad-auditor.cabal b/ppad-auditor.cabal
@@ -36,6 +36,7 @@ library
, containers >= 0.6 && < 0.8
, deepseq >= 1.4 && < 1.6
, megaparsec >= 9.0 && < 10
+ , primitive >= 0.9 && < 0.10
, text >= 1.2 && < 2.2
, aeson >= 2.0 && < 2.3