commit 889aed3a98b5d862ccc80e8b4d66ccf1c3bba710
parent e08c2305e1ff2a5eb0ee18fd17dd9983efdc800d
Author: Jared Tobin <jared@jtobin.io>
Date: Wed, 11 Feb 2026 19:43:47 +0400
Revert "refactor: use SmallArray for register taint/provenance (IMPL12)"
This reverts commit e08c230.
Reason: Unclear performance win. While intra-procedural analysis improved
(2-5x faster due to O(1) register lookups), inter-procedural analysis
regressed significantly (2-3x slower).
The regression stems from array-based joins iterating over all 161
register slots, whereas Map-based joins only process populated entries.
For inter-procedural analysis with many join operations across function
summaries, this overhead dominates.
The tradeoff is not clearly favorable without further optimization
(e.g., pointer-equality checks, lazy updates, or hybrid representations).
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat:
3 files changed, 77 insertions(+), 213 deletions(-)
diff --git a/lib/Audit/AArch64/Check.hs b/lib/Audit/AArch64/Check.hs
@@ -27,7 +27,6 @@ module Audit.AArch64.Check (
import Audit.AArch64.CFG (BasicBlock(..), CFG(..), cfgBlockCount, indexBlock,
functionLabels, functionBlocks)
import Audit.AArch64.Taint
-import Audit.AArch64.Types (Taint(..))
import Audit.AArch64.Types
import Control.DeepSeq (NFData)
import qualified Data.IntMap.Strict as IM
@@ -138,32 +137,26 @@ checkAddrMode sym ln instr addr st = case addr of
[]
-- | Check that base register is public.
--- If taint is Unknown/Bottom, check provenance to see if we can upgrade.
+-- If taint is Unknown, check provenance to see if we can upgrade to Public.
checkBase :: Text -> Int -> Instr -> Reg -> TaintState -> [Violation]
checkBase sym ln instr base st =
case getTaint base st of
Public -> []
Secret -> [Violation sym ln instr (SecretBase base)]
- Bottom -> checkProv -- No information yet, check provenance
- Unknown -> checkProv
- where
- checkProv = case getProvenance base st of
+ Unknown -> case getProvenance base st of
ProvPublic -> [] -- Provenance proves public derivation
- _ -> [Violation sym ln instr (UnknownBase base)]
+ ProvUnknown -> [Violation sym ln instr (UnknownBase base)]
-- | Check that index register is public.
--- If taint is Unknown/Bottom, check provenance to see if we can upgrade.
+-- If taint is Unknown, check provenance to see if we can upgrade to Public.
checkIndex :: Text -> Int -> Instr -> Reg -> TaintState -> [Violation]
checkIndex sym ln instr idx st =
case getTaint idx st of
Public -> []
Secret -> [Violation sym ln instr (SecretIndex idx)]
- Bottom -> checkProv -- No information yet, check provenance
- Unknown -> checkProv
- where
- checkProv = case getProvenance idx st of
+ Unknown -> case getProvenance idx st of
ProvPublic -> [] -- Provenance proves public derivation
- _ -> [Violation sym ln instr (UnknownIndex idx)]
+ ProvUnknown -> [Violation sym ln instr (UnknownIndex idx)]
-- | Check entire CFG with inter-procedural analysis.
-- Computes function summaries via fixpoint, then checks each function.
diff --git a/lib/Audit/AArch64/Taint.hs b/lib/Audit/AArch64/Taint.hs
@@ -13,8 +13,6 @@
-- Public registers are those derived from known-safe sources (stack
-- pointers, heap pointers, constants). Registers with unknown or
-- secret-derived values are flagged when used in memory addressing.
---
--- Uses SmallArray-backed register storage for efficient O(1) lookups.
module Audit.AArch64.Taint (
TaintState
@@ -52,12 +50,6 @@ import Audit.AArch64.Types
( Reg(..), Instr(..), Line(..), Operand(..), AddrMode(..)
, Taint(..), joinTaint, Provenance(..), joinProvenance
, TaintConfig(..), ArgPolicy(..)
- , regIndex, regCount
- )
-import Control.Monad.ST (runST)
-import Data.Primitive.SmallArray
- ( SmallArray, indexSmallArray, newSmallArray, writeSmallArray
- , freezeSmallArray, thawSmallArray, sizeofSmallArray
)
import Data.IntMap.Strict (IntMap)
import qualified Data.IntMap.Strict as IM
@@ -70,14 +62,13 @@ import Data.Text (Text)
-- | Taint state: maps registers to their publicness, plus stack slots.
-- Also tracks provenance for upgrading Unknown to Public when provable.
--- Uses SmallArray for O(1) register lookups (indexed by regIndex).
data TaintState = TaintState
- { tsRegs :: !(SmallArray Taint)
- -- ^ Register taints (indexed by regIndex)
+ { tsRegs :: !(Map Reg Taint)
+ -- ^ Register taints
, tsStack :: !(IntMap Taint)
-- ^ Stack slot taints (keyed by SP offset)
- , tsProv :: !(SmallArray Provenance)
- -- ^ Register provenance (indexed by regIndex)
+ , tsProv :: !(Map Reg Provenance)
+ -- ^ Register provenance
, tsStackProv :: !(IntMap Provenance)
-- ^ Stack slot provenance (keyed by SP offset)
} deriving (Eq, Show)
@@ -107,45 +98,23 @@ publicRoots =
, XZR, WZR -- Zero registers
]
--- | Empty taint state (all Bottom/ProvUnknown).
--- Bottom is the identity for join, representing "no information yet".
+-- | Empty taint state (no known taints).
emptyTaintState :: TaintState
emptyTaintState = TaintState
- { tsRegs = defaultTaintArray
+ { tsRegs = Map.empty
, tsStack = IM.empty
- , tsProv = defaultProvArray
+ , tsProv = Map.empty
, tsStackProv = IM.empty
}
--- | Default array with all Bottom taints (identity for join).
-defaultTaintArray :: SmallArray Taint
-defaultTaintArray = runST $ do
- arr <- newSmallArray regCount Bottom
- freezeSmallArray arr 0 regCount
-
--- | Default array with all ProvBottom provenances (identity for join).
-defaultProvArray :: SmallArray Provenance
-defaultProvArray = runST $ do
- arr <- newSmallArray regCount ProvBottom
- freezeSmallArray arr 0 regCount
-
-- | Initial taint state with public roots marked.
--- Non-public-root registers start at Unknown (truly unknown at function entry).
initTaintState :: TaintState
-initTaintState = runST $ do
- regs <- newSmallArray regCount Unknown
- prov <- newSmallArray regCount ProvUnknown
- -- Mark public roots
- mapM_ (\r -> writeSmallArray regs (regIndex r) Public) publicRoots
- mapM_ (\r -> writeSmallArray prov (regIndex r) ProvPublic) publicRoots
- regsArr <- freezeSmallArray regs 0 regCount
- provArr <- freezeSmallArray prov 0 regCount
- pure $ TaintState
- { tsRegs = regsArr
- , tsStack = IM.empty
- , tsProv = provArr
- , tsStackProv = IM.empty
- }
+initTaintState = TaintState
+ { tsRegs = Map.fromList [(r, Public) | r <- publicRoots]
+ , tsStack = IM.empty
+ , tsProv = Map.fromList [(r, ProvPublic) | r <- publicRoots]
+ , tsStackProv = IM.empty
+ }
-- | Seed argument registers according to policy.
-- Secret registers are marked Secret with ProvUnknown.
@@ -158,16 +127,22 @@ seedArgs policy st =
st2 = Set.foldr markSecret st1 (apSecret policy)
in st2
where
- markPublic r s = setTaintProv r Public ProvPublic s
- markSecret r s = setTaintProv r Secret ProvUnknown s
+ markPublic r s = s
+ { tsRegs = Map.insert r Public (tsRegs s)
+ , tsProv = Map.insert r ProvPublic (tsProv s)
+ }
+ markSecret r s = s
+ { tsRegs = Map.insert r Secret (tsRegs s)
+ , tsProv = Map.insert r ProvUnknown (tsProv s)
+ }
-- | Get the taint of a register.
getTaint :: Reg -> TaintState -> Taint
-getTaint r st = indexSmallArray (tsRegs st) (regIndex r)
+getTaint r st = Map.findWithDefault Unknown r (tsRegs st)
-- | Get the provenance of a register.
getProvenance :: Reg -> TaintState -> Provenance
-getProvenance r st = indexSmallArray (tsProv st) (regIndex r)
+getProvenance r st = Map.findWithDefault ProvUnknown r (tsProv st)
-- | Analyze a single line, updating taint state.
analyzeLine :: Line -> TaintState -> TaintState
@@ -177,7 +152,7 @@ analyzeLine l st = case lineInstr l of
-- | Analyze a basic block, threading taint state through.
analyzeBlock :: [Line] -> TaintState -> TaintState
-analyzeBlock lns st = foldl' (flip analyzeLine) st lns
+analyzeBlock lns st = foldl (flip analyzeLine) st lns
-- | Transfer function for taint analysis.
--
@@ -389,23 +364,14 @@ transfer instr st = case instr of
-- | Set taint for a register.
setTaint :: Reg -> Taint -> TaintState -> TaintState
-setTaint r t st = st { tsRegs = updateSmallArray (regIndex r) t (tsRegs st) }
+setTaint r t st = st { tsRegs = Map.insert r t (tsRegs st) }
-- | Set both taint and provenance for a register.
setTaintProv :: Reg -> Taint -> Provenance -> TaintState -> TaintState
setTaintProv r t p st = st
- { tsRegs = updateSmallArray idx t (tsRegs st)
- , tsProv = updateSmallArray idx p (tsProv st)
+ { tsRegs = Map.insert r t (tsRegs st)
+ , tsProv = Map.insert r p (tsProv st)
}
- where
- idx = regIndex r
-
--- | Update a single element in a SmallArray (copy-on-write).
-updateSmallArray :: Int -> a -> SmallArray a -> SmallArray a
-updateSmallArray idx val arr = runST $ do
- marr <- thawSmallArray arr 0 (sizeofSmallArray arr)
- writeSmallArray marr idx val
- freezeSmallArray marr 0 (sizeofSmallArray arr)
-- | Set taint for a loaded value, preserving public roots.
-- Public roots (SP, X19-X21, X28, etc.) stay public even when loaded
@@ -584,14 +550,10 @@ provJoin3 a b c = joinProvenance a (joinProvenance b c)
-- Per AArch64 ABI, x0-x17 are caller-saved.
-- Clears both taint and provenance.
invalidateCallerSaved :: TaintState -> TaintState
-invalidateCallerSaved st = runST $ do
- regs <- thawSmallArray (tsRegs st) 0 regCount
- prov <- thawSmallArray (tsProv st) 0 regCount
- mapM_ (\r -> writeSmallArray regs (regIndex r) Unknown) callerSaved
- mapM_ (\r -> writeSmallArray prov (regIndex r) ProvUnknown) callerSaved
- regsArr <- freezeSmallArray regs 0 regCount
- provArr <- freezeSmallArray prov 0 regCount
- pure $ st { tsRegs = regsArr, tsProv = provArr }
+invalidateCallerSaved st = st
+ { tsRegs = foldr (\r -> Map.insert r Unknown) (tsRegs st) callerSaved
+ , tsProv = foldr (\r -> Map.insert r ProvUnknown) (tsProv st) callerSaved
+ }
where
callerSaved =
[ X0, X1, X2, X3, X4, X5, X6, X7
@@ -600,32 +562,16 @@ invalidateCallerSaved st = runST $ do
]
-- | Join two taint states (element-wise join).
--- Register arrays are joined element-wise with joinTaint/joinProvenance.
+-- For registers in both, take the join. For registers in only one, keep.
-- Stack slots and provenance are also joined element-wise.
joinTaintState :: TaintState -> TaintState -> TaintState
joinTaintState a b = TaintState
- { tsRegs = joinSmallArrayWith joinTaint (tsRegs a) (tsRegs b)
+ { tsRegs = Map.unionWith joinTaint (tsRegs a) (tsRegs b)
, tsStack = IM.unionWith joinTaint (tsStack a) (tsStack b)
- , tsProv = joinSmallArrayWith joinProvenance (tsProv a) (tsProv b)
+ , tsProv = Map.unionWith joinProvenance (tsProv a) (tsProv b)
, tsStackProv = IM.unionWith joinProvenance (tsStackProv a) (tsStackProv b)
}
--- | Join two SmallArrays element-wise with a combining function.
-joinSmallArrayWith :: (a -> a -> a) -> SmallArray a -> SmallArray a
- -> SmallArray a
-joinSmallArrayWith f arr1 arr2 = runST $ do
- let n = sizeofSmallArray arr1
- result <- newSmallArray n (indexSmallArray arr1 0)
- let go i
- | i >= n = pure ()
- | otherwise = do
- let v1 = indexSmallArray arr1 i
- v2 = indexSmallArray arr2 i
- writeSmallArray result i (f v1 v2)
- go (i + 1)
- go 0
- freezeSmallArray result 0 n
-
-- | Run forward dataflow analysis over a CFG.
-- Returns the IN taint state for each block (indexed by block number).
runDataflow :: CFG -> IntMap TaintState
@@ -635,11 +581,10 @@ runDataflow cfg
where
nBlocks = cfgBlockCount cfg
- -- Only entry block starts with initTaintState; others get emptyTaintState
- -- (which has Bottom for all registers, acting as identity for join).
- initIn = IM.singleton 0 initTaintState
+ -- Initial states: all blocks start with public roots (GHC invariant)
+ initIn = IM.fromList [(i, initTaintState) | i <- [0..nBlocks-1]]
initOut = IM.empty
- initWorklist = IS.singleton 0
+ initWorklist = IS.fromList [0..nBlocks-1]
go :: IntSet -> IntMap TaintState -> IntMap TaintState -> IntMap TaintState
go worklist inStates outStates
@@ -678,9 +623,13 @@ newtype FuncSummary = FuncSummary { summaryState :: TaintState }
deriving (Eq, Show)
-- | Initial conservative summary: all caller-saved are Unknown.
--- Uses the default empty taint state (all Unknown/ProvUnknown).
initSummary :: FuncSummary
-initSummary = FuncSummary emptyTaintState
+initSummary = FuncSummary $ TaintState
+ { tsRegs = Map.fromList [ (r, Unknown) | r <- callerSavedRegs ]
+ , tsStack = IM.empty
+ , tsProv = Map.fromList [ (r, ProvUnknown) | r <- callerSavedRegs ]
+ , tsStackProv = IM.empty
+ }
-- | Caller-saved registers per AArch64 ABI.
callerSavedRegs :: [Reg]
@@ -698,18 +647,15 @@ joinSummary (FuncSummary a) (FuncSummary b) =
-- | Apply a function summary at a call site.
-- Replaces caller-saved register taints and provenance with summary values.
applySummary :: FuncSummary -> TaintState -> TaintState
-applySummary (FuncSummary summ) st = runST $ do
- regs <- thawSmallArray (tsRegs st) 0 regCount
- prov <- thawSmallArray (tsProv st) 0 regCount
- let summRegs = tsRegs summ
- summProv = tsProv summ
- mapM_ (\r -> do
- let idx = regIndex r
- writeSmallArray regs idx (indexSmallArray summRegs idx)
- writeSmallArray prov idx (indexSmallArray summProv idx)) callerSavedRegs
- regsArr <- freezeSmallArray regs 0 regCount
- provArr <- freezeSmallArray prov 0 regCount
- pure $ st { tsRegs = regsArr, tsProv = provArr }
+applySummary (FuncSummary summ) st = st
+ { tsRegs = foldr applyReg (tsRegs st) callerSavedRegs
+ , tsProv = foldr applyProv (tsProv st) callerSavedRegs
+ }
+ where
+ summRegs = tsRegs summ
+ summProv = tsProv summ
+ applyReg r s = Map.insert r (Map.findWithDefault Unknown r summRegs) s
+ applyProv r s = Map.insert r (Map.findWithDefault ProvUnknown r summProv) s
-- | Run dataflow analysis for a single function (subset of blocks).
-- Returns the OUT state at return points (joined).
@@ -775,7 +721,7 @@ runFunctionBlocks cfg (entryIdx:rest) summaries = go initWorklist initIn IM.empt
-- | Analyze a block, applying call summaries at bl instructions.
analyzeBlockWithSummaries :: BasicBlock -> TaintState -> Map Text FuncSummary
-> TaintState
-analyzeBlockWithSummaries bb st0 summaries = foldl' go st0 (bbLines bb)
+analyzeBlockWithSummaries bb st0 summaries = foldl go st0 (bbLines bb)
where
go st l = case lineInstr l of
Nothing -> st
@@ -836,32 +782,27 @@ runDataflowWithConfig tcfg cfg
where
nBlocks = cfgBlockCount cfg
- -- Only initialize function entry blocks with their entry states.
- -- Other blocks get emptyTaintState via findWithDefault during propagation.
+ -- Build a map from block index to entry taint state
+ -- Entry blocks of functions get their policy applied
initIn = IM.fromList
- [ (idx, entryState idx bb)
- | idx <- [0..nBlocks-1]
- , let bb = indexBlock cfg idx
- , isFuncEntry idx bb
+ [ (i, entryState i (indexBlock cfg i))
+ | i <- [0..nBlocks-1]
]
- isFuncEntry idx bb = case bbLabel bb of
- Nothing -> idx == 0 -- Block 0 is always an entry if no label
- Just lbl -> case functionBlocks cfg lbl of
- (entry:_) -> entry == idx
- [] -> False
-
entryState idx bb =
let base = initTaintState
in case bbLabel bb of
Nothing -> base
Just lbl ->
- case Map.lookup lbl (tcPolicies tcfg) of
- Nothing -> base
- Just policy -> seedArgs policy base
-
- -- Start worklist with all function entry blocks
- initWorklist = IS.fromList (IM.keys initIn)
+ -- Check if this block is a function entry
+ case functionBlocks cfg lbl of
+ (entry:_) | entry == idx ->
+ case Map.lookup lbl (tcPolicies tcfg) of
+ Nothing -> base
+ Just policy -> seedArgs policy base
+ _ -> base
+
+ initWorklist = IS.fromList [0..nBlocks-1]
go worklist inStates outStates
| IS.null worklist = inStates
diff --git a/lib/Audit/AArch64/Types.hs b/lib/Audit/AArch64/Types.hs
@@ -16,8 +16,6 @@ module Audit.AArch64.Types (
Reg(..)
, regName
, regFromText
- , regIndex
- , regCount
-- * Operands and addressing
, Shift(..)
@@ -146,64 +144,6 @@ regFromText t = Map.lookup (T.toUpper t) regMap
, ("Q28", Q28), ("Q29", Q29), ("Q30", Q30), ("Q31", Q31)
]
--- | Total number of registers (for array sizing).
-regCount :: Int
-regCount = 161
-
--- | Map a register to its array index.
--- The mapping is:
--- X0-X30: 0-30
--- W0-W30: 31-61
--- SP: 62
--- XZR: 63
--- WZR: 64
--- D0-D31: 65-96
--- S0-S31: 97-128
--- Q0-Q31: 129-160
-regIndex :: Reg -> Int
-regIndex r = case r of
- X0 -> 0; X1 -> 1; X2 -> 2; X3 -> 3
- X4 -> 4; X5 -> 5; X6 -> 6; X7 -> 7
- X8 -> 8; X9 -> 9; X10 -> 10; X11 -> 11
- X12 -> 12; X13 -> 13; X14 -> 14; X15 -> 15
- X16 -> 16; X17 -> 17; X18 -> 18; X19 -> 19
- X20 -> 20; X21 -> 21; X22 -> 22; X23 -> 23
- X24 -> 24; X25 -> 25; X26 -> 26; X27 -> 27
- X28 -> 28; X29 -> 29; X30 -> 30
- W0 -> 31; W1 -> 32; W2 -> 33; W3 -> 34
- W4 -> 35; W5 -> 36; W6 -> 37; W7 -> 38
- W8 -> 39; W9 -> 40; W10 -> 41; W11 -> 42
- W12 -> 43; W13 -> 44; W14 -> 45; W15 -> 46
- W16 -> 47; W17 -> 48; W18 -> 49; W19 -> 50
- W20 -> 51; W21 -> 52; W22 -> 53; W23 -> 54
- W24 -> 55; W25 -> 56; W26 -> 57; W27 -> 58
- W28 -> 59; W29 -> 60; W30 -> 61
- SP -> 62; XZR -> 63; WZR -> 64
- D0 -> 65; D1 -> 66; D2 -> 67; D3 -> 68
- D4 -> 69; D5 -> 70; D6 -> 71; D7 -> 72
- D8 -> 73; D9 -> 74; D10 -> 75; D11 -> 76
- D12 -> 77; D13 -> 78; D14 -> 79; D15 -> 80
- D16 -> 81; D17 -> 82; D18 -> 83; D19 -> 84
- D20 -> 85; D21 -> 86; D22 -> 87; D23 -> 88
- D24 -> 89; D25 -> 90; D26 -> 91; D27 -> 92
- D28 -> 93; D29 -> 94; D30 -> 95; D31 -> 96
- S0 -> 97; S1 -> 98; S2 -> 99; S3 -> 100
- S4 -> 101; S5 -> 102; S6 -> 103; S7 -> 104
- S8 -> 105; S9 -> 106; S10 -> 107; S11 -> 108
- S12 -> 109; S13 -> 110; S14 -> 111; S15 -> 112
- S16 -> 113; S17 -> 114; S18 -> 115; S19 -> 116
- S20 -> 117; S21 -> 118; S22 -> 119; S23 -> 120
- S24 -> 121; S25 -> 122; S26 -> 123; S27 -> 124
- S28 -> 125; S29 -> 126; S30 -> 127; S31 -> 128
- Q0 -> 129; Q1 -> 130; Q2 -> 131; Q3 -> 132
- Q4 -> 133; Q5 -> 134; Q6 -> 135; Q7 -> 136
- Q8 -> 137; Q9 -> 138; Q10 -> 139; Q11 -> 140
- Q12 -> 141; Q13 -> 142; Q14 -> 143; Q15 -> 144
- Q16 -> 145; Q17 -> 146; Q18 -> 147; Q19 -> 148
- Q20 -> 149; Q21 -> 150; Q22 -> 151; Q23 -> 152
- Q24 -> 153; Q25 -> 154; Q26 -> 155; Q27 -> 156
- Q28 -> 157; Q29 -> 158; Q30 -> 159; Q31 -> 160
-
-- | Shift operations for indexed addressing.
data Shift
= LSL !Int -- ^ Logical shift left
@@ -347,23 +287,18 @@ data Line = Line
instance ToJSON Line
-- | Taint lattice for register publicness.
--- Bottom is the identity element for joins (no information yet).
data Taint
- = Bottom -- ^ No information yet (identity for join)
- | Public -- ^ Known to be public (derived from stack/heap pointers)
+ = Public -- ^ Known to be public (derived from stack/heap pointers)
| Secret -- ^ Known or assumed to be secret
- | Unknown -- ^ Determined but unknown origin
+ | Unknown -- ^ Not yet determined
deriving (Eq, Ord, Show, Generic, NFData)
instance ToJSON Taint
-- | Join operation for taint lattice (least upper bound).
--- Bottom is identity: join Bottom x = x.
-- For CT safety: Public only if both are Public.
--- Order: Bottom < Public < Unknown < Secret
+-- Order: Public < Unknown < Secret
joinTaint :: Taint -> Taint -> Taint
-joinTaint Bottom x = x
-joinTaint x Bottom = x
joinTaint Public Public = Public
joinTaint Secret _ = Secret
joinTaint _ Secret = Secret
@@ -371,20 +306,15 @@ joinTaint _ _ = Unknown -- Public+Unknown or Unknown+Unknown
-- | Provenance: tracks whether a value derives from known-public sources.
-- Used to upgrade Unknown taint to Public when provenance can prove safety.
--- ProvBottom is the identity element for joins (no information yet).
data Provenance
- = ProvBottom -- ^ No information yet (identity for join)
- | ProvPublic -- ^ Derived from public root or constant
+ = ProvPublic -- ^ Derived from public root or constant
| ProvUnknown -- ^ Unknown origin (e.g., loaded from memory)
deriving (Eq, Ord, Show, Generic, NFData)
instance ToJSON Provenance
--- | Join provenance: ProvBottom is identity.
--- Only ProvPublic if both are ProvPublic.
+-- | Join provenance: only ProvPublic if both are ProvPublic.
joinProvenance :: Provenance -> Provenance -> Provenance
-joinProvenance ProvBottom x = x
-joinProvenance x ProvBottom = x
joinProvenance ProvPublic ProvPublic = ProvPublic
joinProvenance _ _ = ProvUnknown