auditor

An aarch64 constant-time memory access auditing tool.
git clone git://git.ppad.tech/auditor.git
Log | Files | Refs | README | LICENSE

commit 889aed3a98b5d862ccc80e8b4d66ccf1c3bba710
parent e08c2305e1ff2a5eb0ee18fd17dd9983efdc800d
Author: Jared Tobin <jared@jtobin.io>
Date:   Wed, 11 Feb 2026 19:43:47 +0400

Revert "refactor: use SmallArray for register taint/provenance (IMPL12)"

This reverts commit e08c230.

Reason: Unclear performance win. While intra-procedural analysis improved
(2-5x faster due to O(1) register lookups), inter-procedural analysis
regressed significantly (2-3x slower).

The regression stems from array-based joins iterating over all 161
register slots, whereas Map-based joins only processed populated entries.
For inter-procedural analysis with many join operations across function
summaries, this overhead dominates.

The tradeoff is not clearly favorable without further optimization
(e.g., pointer-equality checks, lazy updates, or hybrid representations).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Diffstat:
M lib/Audit/AArch64/Check.hs | 19 ++++++-------------
M lib/Audit/AArch64/Taint.hs | 191 +++++++++++++++++++++++++++----------------------------------------------------
M lib/Audit/AArch64/Types.hs | 80 +++--------------------------------------------------------------------------
3 files changed, 77 insertions(+), 213 deletions(-)

diff --git a/lib/Audit/AArch64/Check.hs b/lib/Audit/AArch64/Check.hs @@ -27,7 +27,6 @@ module Audit.AArch64.Check ( import Audit.AArch64.CFG (BasicBlock(..), CFG(..), cfgBlockCount, indexBlock, functionLabels, functionBlocks) import Audit.AArch64.Taint -import Audit.AArch64.Types (Taint(..)) import Audit.AArch64.Types import Control.DeepSeq (NFData) import qualified Data.IntMap.Strict as IM @@ -138,32 +137,26 @@ checkAddrMode sym ln instr addr st = case addr of [] -- | Check that base register is public. --- If taint is Unknown/Bottom, check provenance to see if we can upgrade. +-- If taint is Unknown, check provenance to see if we can upgrade to Public. checkBase :: Text -> Int -> Instr -> Reg -> TaintState -> [Violation] checkBase sym ln instr base st = case getTaint base st of Public -> [] Secret -> [Violation sym ln instr (SecretBase base)] - Bottom -> checkProv -- No information yet, check provenance - Unknown -> checkProv - where - checkProv = case getProvenance base st of + Unknown -> case getProvenance base st of ProvPublic -> [] -- Provenance proves public derivation - _ -> [Violation sym ln instr (UnknownBase base)] + ProvUnknown -> [Violation sym ln instr (UnknownBase base)] -- | Check that index register is public. --- If taint is Unknown/Bottom, check provenance to see if we can upgrade. +-- If taint is Unknown, check provenance to see if we can upgrade to Public. checkIndex :: Text -> Int -> Instr -> Reg -> TaintState -> [Violation] checkIndex sym ln instr idx st = case getTaint idx st of Public -> [] Secret -> [Violation sym ln instr (SecretIndex idx)] - Bottom -> checkProv -- No information yet, check provenance - Unknown -> checkProv - where - checkProv = case getProvenance idx st of + Unknown -> case getProvenance idx st of ProvPublic -> [] -- Provenance proves public derivation - _ -> [Violation sym ln instr (UnknownIndex idx)] + ProvUnknown -> [Violation sym ln instr (UnknownIndex idx)] -- | Check entire CFG with inter-procedural analysis. 
-- Computes function summaries via fixpoint, then checks each function. diff --git a/lib/Audit/AArch64/Taint.hs b/lib/Audit/AArch64/Taint.hs @@ -13,8 +13,6 @@ -- Public registers are those derived from known-safe sources (stack -- pointers, heap pointers, constants). Registers with unknown or -- secret-derived values are flagged when used in memory addressing. --- --- Uses SmallArray-backed register storage for efficient O(1) lookups. module Audit.AArch64.Taint ( TaintState @@ -52,12 +50,6 @@ import Audit.AArch64.Types ( Reg(..), Instr(..), Line(..), Operand(..), AddrMode(..) , Taint(..), joinTaint, Provenance(..), joinProvenance , TaintConfig(..), ArgPolicy(..) - , regIndex, regCount - ) -import Control.Monad.ST (runST) -import Data.Primitive.SmallArray - ( SmallArray, indexSmallArray, newSmallArray, writeSmallArray - , freezeSmallArray, thawSmallArray, sizeofSmallArray ) import Data.IntMap.Strict (IntMap) import qualified Data.IntMap.Strict as IM @@ -70,14 +62,13 @@ import Data.Text (Text) -- | Taint state: maps registers to their publicness, plus stack slots. -- Also tracks provenance for upgrading Unknown to Public when provable. --- Uses SmallArray for O(1) register lookups (indexed by regIndex). data TaintState = TaintState - { tsRegs :: !(SmallArray Taint) - -- ^ Register taints (indexed by regIndex) + { tsRegs :: !(Map Reg Taint) + -- ^ Register taints , tsStack :: !(IntMap Taint) -- ^ Stack slot taints (keyed by SP offset) - , tsProv :: !(SmallArray Provenance) - -- ^ Register provenance (indexed by regIndex) + , tsProv :: !(Map Reg Provenance) + -- ^ Register provenance , tsStackProv :: !(IntMap Provenance) -- ^ Stack slot provenance (keyed by SP offset) } deriving (Eq, Show) @@ -107,45 +98,23 @@ publicRoots = , XZR, WZR -- Zero registers ] --- | Empty taint state (all Bottom/ProvUnknown). --- Bottom is the identity for join, representing "no information yet". +-- | Empty taint state (no known taints). 
emptyTaintState :: TaintState emptyTaintState = TaintState - { tsRegs = defaultTaintArray + { tsRegs = Map.empty , tsStack = IM.empty - , tsProv = defaultProvArray + , tsProv = Map.empty , tsStackProv = IM.empty } --- | Default array with all Bottom taints (identity for join). -defaultTaintArray :: SmallArray Taint -defaultTaintArray = runST $ do - arr <- newSmallArray regCount Bottom - freezeSmallArray arr 0 regCount - --- | Default array with all ProvBottom provenances (identity for join). -defaultProvArray :: SmallArray Provenance -defaultProvArray = runST $ do - arr <- newSmallArray regCount ProvBottom - freezeSmallArray arr 0 regCount - -- | Initial taint state with public roots marked. --- Non-public-root registers start at Unknown (truly unknown at function entry). initTaintState :: TaintState -initTaintState = runST $ do - regs <- newSmallArray regCount Unknown - prov <- newSmallArray regCount ProvUnknown - -- Mark public roots - mapM_ (\r -> writeSmallArray regs (regIndex r) Public) publicRoots - mapM_ (\r -> writeSmallArray prov (regIndex r) ProvPublic) publicRoots - regsArr <- freezeSmallArray regs 0 regCount - provArr <- freezeSmallArray prov 0 regCount - pure $ TaintState - { tsRegs = regsArr - , tsStack = IM.empty - , tsProv = provArr - , tsStackProv = IM.empty - } +initTaintState = TaintState + { tsRegs = Map.fromList [(r, Public) | r <- publicRoots] + , tsStack = IM.empty + , tsProv = Map.fromList [(r, ProvPublic) | r <- publicRoots] + , tsStackProv = IM.empty + } -- | Seed argument registers according to policy. -- Secret registers are marked Secret with ProvUnknown. 
@@ -158,16 +127,22 @@ seedArgs policy st = st2 = Set.foldr markSecret st1 (apSecret policy) in st2 where - markPublic r s = setTaintProv r Public ProvPublic s - markSecret r s = setTaintProv r Secret ProvUnknown s + markPublic r s = s + { tsRegs = Map.insert r Public (tsRegs s) + , tsProv = Map.insert r ProvPublic (tsProv s) + } + markSecret r s = s + { tsRegs = Map.insert r Secret (tsRegs s) + , tsProv = Map.insert r ProvUnknown (tsProv s) + } -- | Get the taint of a register. getTaint :: Reg -> TaintState -> Taint -getTaint r st = indexSmallArray (tsRegs st) (regIndex r) +getTaint r st = Map.findWithDefault Unknown r (tsRegs st) -- | Get the provenance of a register. getProvenance :: Reg -> TaintState -> Provenance -getProvenance r st = indexSmallArray (tsProv st) (regIndex r) +getProvenance r st = Map.findWithDefault ProvUnknown r (tsProv st) -- | Analyze a single line, updating taint state. analyzeLine :: Line -> TaintState -> TaintState @@ -177,7 +152,7 @@ analyzeLine l st = case lineInstr l of -- | Analyze a basic block, threading taint state through. analyzeBlock :: [Line] -> TaintState -> TaintState -analyzeBlock lns st = foldl' (flip analyzeLine) st lns +analyzeBlock lns st = foldl (flip analyzeLine) st lns -- | Transfer function for taint analysis. -- @@ -389,23 +364,14 @@ transfer instr st = case instr of -- | Set taint for a register. setTaint :: Reg -> Taint -> TaintState -> TaintState -setTaint r t st = st { tsRegs = updateSmallArray (regIndex r) t (tsRegs st) } +setTaint r t st = st { tsRegs = Map.insert r t (tsRegs st) } -- | Set both taint and provenance for a register. setTaintProv :: Reg -> Taint -> Provenance -> TaintState -> TaintState setTaintProv r t p st = st - { tsRegs = updateSmallArray idx t (tsRegs st) - , tsProv = updateSmallArray idx p (tsProv st) + { tsRegs = Map.insert r t (tsRegs st) + , tsProv = Map.insert r p (tsProv st) } - where - idx = regIndex r - --- | Update a single element in a SmallArray (copy-on-write). 
-updateSmallArray :: Int -> a -> SmallArray a -> SmallArray a -updateSmallArray idx val arr = runST $ do - marr <- thawSmallArray arr 0 (sizeofSmallArray arr) - writeSmallArray marr idx val - freezeSmallArray marr 0 (sizeofSmallArray arr) -- | Set taint for a loaded value, preserving public roots. -- Public roots (SP, X19-X21, X28, etc.) stay public even when loaded @@ -584,14 +550,10 @@ provJoin3 a b c = joinProvenance a (joinProvenance b c) -- Per AArch64 ABI, x0-x17 are caller-saved. -- Clears both taint and provenance. invalidateCallerSaved :: TaintState -> TaintState -invalidateCallerSaved st = runST $ do - regs <- thawSmallArray (tsRegs st) 0 regCount - prov <- thawSmallArray (tsProv st) 0 regCount - mapM_ (\r -> writeSmallArray regs (regIndex r) Unknown) callerSaved - mapM_ (\r -> writeSmallArray prov (regIndex r) ProvUnknown) callerSaved - regsArr <- freezeSmallArray regs 0 regCount - provArr <- freezeSmallArray prov 0 regCount - pure $ st { tsRegs = regsArr, tsProv = provArr } +invalidateCallerSaved st = st + { tsRegs = foldr (\r -> Map.insert r Unknown) (tsRegs st) callerSaved + , tsProv = foldr (\r -> Map.insert r ProvUnknown) (tsProv st) callerSaved + } where callerSaved = [ X0, X1, X2, X3, X4, X5, X6, X7 @@ -600,32 +562,16 @@ invalidateCallerSaved st = runST $ do ] -- | Join two taint states (element-wise join). --- Register arrays are joined element-wise with joinTaint/joinProvenance. +-- For registers in both, take the join. For registers in only one, keep. -- Stack slots and provenance are also joined element-wise. 
joinTaintState :: TaintState -> TaintState -> TaintState joinTaintState a b = TaintState - { tsRegs = joinSmallArrayWith joinTaint (tsRegs a) (tsRegs b) + { tsRegs = Map.unionWith joinTaint (tsRegs a) (tsRegs b) , tsStack = IM.unionWith joinTaint (tsStack a) (tsStack b) - , tsProv = joinSmallArrayWith joinProvenance (tsProv a) (tsProv b) + , tsProv = Map.unionWith joinProvenance (tsProv a) (tsProv b) , tsStackProv = IM.unionWith joinProvenance (tsStackProv a) (tsStackProv b) } --- | Join two SmallArrays element-wise with a combining function. -joinSmallArrayWith :: (a -> a -> a) -> SmallArray a -> SmallArray a - -> SmallArray a -joinSmallArrayWith f arr1 arr2 = runST $ do - let n = sizeofSmallArray arr1 - result <- newSmallArray n (indexSmallArray arr1 0) - let go i - | i >= n = pure () - | otherwise = do - let v1 = indexSmallArray arr1 i - v2 = indexSmallArray arr2 i - writeSmallArray result i (f v1 v2) - go (i + 1) - go 0 - freezeSmallArray result 0 n - -- | Run forward dataflow analysis over a CFG. -- Returns the IN taint state for each block (indexed by block number). runDataflow :: CFG -> IntMap TaintState @@ -635,11 +581,10 @@ runDataflow cfg where nBlocks = cfgBlockCount cfg - -- Only entry block starts with initTaintState; others get emptyTaintState - -- (which has Bottom for all registers, acting as identity for join). - initIn = IM.singleton 0 initTaintState + -- Initial states: all blocks start with public roots (GHC invariant) + initIn = IM.fromList [(i, initTaintState) | i <- [0..nBlocks-1]] initOut = IM.empty - initWorklist = IS.singleton 0 + initWorklist = IS.fromList [0..nBlocks-1] go :: IntSet -> IntMap TaintState -> IntMap TaintState -> IntMap TaintState go worklist inStates outStates @@ -678,9 +623,13 @@ newtype FuncSummary = FuncSummary { summaryState :: TaintState } deriving (Eq, Show) -- | Initial conservative summary: all caller-saved are Unknown. --- Uses the default empty taint state (all Unknown/ProvUnknown). 
initSummary :: FuncSummary -initSummary = FuncSummary emptyTaintState +initSummary = FuncSummary $ TaintState + { tsRegs = Map.fromList [ (r, Unknown) | r <- callerSavedRegs ] + , tsStack = IM.empty + , tsProv = Map.fromList [ (r, ProvUnknown) | r <- callerSavedRegs ] + , tsStackProv = IM.empty + } -- | Caller-saved registers per AArch64 ABI. callerSavedRegs :: [Reg] @@ -698,18 +647,15 @@ joinSummary (FuncSummary a) (FuncSummary b) = -- | Apply a function summary at a call site. -- Replaces caller-saved register taints and provenance with summary values. applySummary :: FuncSummary -> TaintState -> TaintState -applySummary (FuncSummary summ) st = runST $ do - regs <- thawSmallArray (tsRegs st) 0 regCount - prov <- thawSmallArray (tsProv st) 0 regCount - let summRegs = tsRegs summ - summProv = tsProv summ - mapM_ (\r -> do - let idx = regIndex r - writeSmallArray regs idx (indexSmallArray summRegs idx) - writeSmallArray prov idx (indexSmallArray summProv idx)) callerSavedRegs - regsArr <- freezeSmallArray regs 0 regCount - provArr <- freezeSmallArray prov 0 regCount - pure $ st { tsRegs = regsArr, tsProv = provArr } +applySummary (FuncSummary summ) st = st + { tsRegs = foldr applyReg (tsRegs st) callerSavedRegs + , tsProv = foldr applyProv (tsProv st) callerSavedRegs + } + where + summRegs = tsRegs summ + summProv = tsProv summ + applyReg r s = Map.insert r (Map.findWithDefault Unknown r summRegs) s + applyProv r s = Map.insert r (Map.findWithDefault ProvUnknown r summProv) s -- | Run dataflow analysis for a single function (subset of blocks). -- Returns the OUT state at return points (joined). @@ -775,7 +721,7 @@ runFunctionBlocks cfg (entryIdx:rest) summaries = go initWorklist initIn IM.empt -- | Analyze a block, applying call summaries at bl instructions. 
analyzeBlockWithSummaries :: BasicBlock -> TaintState -> Map Text FuncSummary -> TaintState -analyzeBlockWithSummaries bb st0 summaries = foldl' go st0 (bbLines bb) +analyzeBlockWithSummaries bb st0 summaries = foldl go st0 (bbLines bb) where go st l = case lineInstr l of Nothing -> st @@ -836,32 +782,27 @@ runDataflowWithConfig tcfg cfg where nBlocks = cfgBlockCount cfg - -- Only initialize function entry blocks with their entry states. - -- Other blocks get emptyTaintState via findWithDefault during propagation. + -- Build a map from block index to entry taint state + -- Entry blocks of functions get their policy applied initIn = IM.fromList - [ (idx, entryState idx bb) - | idx <- [0..nBlocks-1] - , let bb = indexBlock cfg idx - , isFuncEntry idx bb + [ (i, entryState i (indexBlock cfg i)) + | i <- [0..nBlocks-1] ] - isFuncEntry idx bb = case bbLabel bb of - Nothing -> idx == 0 -- Block 0 is always an entry if no label - Just lbl -> case functionBlocks cfg lbl of - (entry:_) -> entry == idx - [] -> False - entryState idx bb = let base = initTaintState in case bbLabel bb of Nothing -> base Just lbl -> - case Map.lookup lbl (tcPolicies tcfg) of - Nothing -> base - Just policy -> seedArgs policy base - - -- Start worklist with all function entry blocks - initWorklist = IS.fromList (IM.keys initIn) + -- Check if this block is a function entry + case functionBlocks cfg lbl of + (entry:_) | entry == idx -> + case Map.lookup lbl (tcPolicies tcfg) of + Nothing -> base + Just policy -> seedArgs policy base + _ -> base + + initWorklist = IS.fromList [0..nBlocks-1] go worklist inStates outStates | IS.null worklist = inStates diff --git a/lib/Audit/AArch64/Types.hs b/lib/Audit/AArch64/Types.hs @@ -16,8 +16,6 @@ module Audit.AArch64.Types ( Reg(..) , regName , regFromText - , regIndex - , regCount -- * Operands and addressing , Shift(..) 
@@ -146,64 +144,6 @@ regFromText t = Map.lookup (T.toUpper t) regMap , ("Q28", Q28), ("Q29", Q29), ("Q30", Q30), ("Q31", Q31) ] --- | Total number of registers (for array sizing). -regCount :: Int -regCount = 161 - --- | Map a register to its array index. --- The mapping is: --- X0-X30: 0-30 --- W0-W30: 31-61 --- SP: 62 --- XZR: 63 --- WZR: 64 --- D0-D31: 65-96 --- S0-S31: 97-128 --- Q0-Q31: 129-160 -regIndex :: Reg -> Int -regIndex r = case r of - X0 -> 0; X1 -> 1; X2 -> 2; X3 -> 3 - X4 -> 4; X5 -> 5; X6 -> 6; X7 -> 7 - X8 -> 8; X9 -> 9; X10 -> 10; X11 -> 11 - X12 -> 12; X13 -> 13; X14 -> 14; X15 -> 15 - X16 -> 16; X17 -> 17; X18 -> 18; X19 -> 19 - X20 -> 20; X21 -> 21; X22 -> 22; X23 -> 23 - X24 -> 24; X25 -> 25; X26 -> 26; X27 -> 27 - X28 -> 28; X29 -> 29; X30 -> 30 - W0 -> 31; W1 -> 32; W2 -> 33; W3 -> 34 - W4 -> 35; W5 -> 36; W6 -> 37; W7 -> 38 - W8 -> 39; W9 -> 40; W10 -> 41; W11 -> 42 - W12 -> 43; W13 -> 44; W14 -> 45; W15 -> 46 - W16 -> 47; W17 -> 48; W18 -> 49; W19 -> 50 - W20 -> 51; W21 -> 52; W22 -> 53; W23 -> 54 - W24 -> 55; W25 -> 56; W26 -> 57; W27 -> 58 - W28 -> 59; W29 -> 60; W30 -> 61 - SP -> 62; XZR -> 63; WZR -> 64 - D0 -> 65; D1 -> 66; D2 -> 67; D3 -> 68 - D4 -> 69; D5 -> 70; D6 -> 71; D7 -> 72 - D8 -> 73; D9 -> 74; D10 -> 75; D11 -> 76 - D12 -> 77; D13 -> 78; D14 -> 79; D15 -> 80 - D16 -> 81; D17 -> 82; D18 -> 83; D19 -> 84 - D20 -> 85; D21 -> 86; D22 -> 87; D23 -> 88 - D24 -> 89; D25 -> 90; D26 -> 91; D27 -> 92 - D28 -> 93; D29 -> 94; D30 -> 95; D31 -> 96 - S0 -> 97; S1 -> 98; S2 -> 99; S3 -> 100 - S4 -> 101; S5 -> 102; S6 -> 103; S7 -> 104 - S8 -> 105; S9 -> 106; S10 -> 107; S11 -> 108 - S12 -> 109; S13 -> 110; S14 -> 111; S15 -> 112 - S16 -> 113; S17 -> 114; S18 -> 115; S19 -> 116 - S20 -> 117; S21 -> 118; S22 -> 119; S23 -> 120 - S24 -> 121; S25 -> 122; S26 -> 123; S27 -> 124 - S28 -> 125; S29 -> 126; S30 -> 127; S31 -> 128 - Q0 -> 129; Q1 -> 130; Q2 -> 131; Q3 -> 132 - Q4 -> 133; Q5 -> 134; Q6 -> 135; Q7 -> 136 - Q8 -> 137; Q9 -> 138; Q10 
-> 139; Q11 -> 140 - Q12 -> 141; Q13 -> 142; Q14 -> 143; Q15 -> 144 - Q16 -> 145; Q17 -> 146; Q18 -> 147; Q19 -> 148 - Q20 -> 149; Q21 -> 150; Q22 -> 151; Q23 -> 152 - Q24 -> 153; Q25 -> 154; Q26 -> 155; Q27 -> 156 - Q28 -> 157; Q29 -> 158; Q30 -> 159; Q31 -> 160 - -- | Shift operations for indexed addressing. data Shift = LSL !Int -- ^ Logical shift left @@ -347,23 +287,18 @@ data Line = Line instance ToJSON Line -- | Taint lattice for register publicness. --- Bottom is the identity element for joins (no information yet). data Taint - = Bottom -- ^ No information yet (identity for join) - | Public -- ^ Known to be public (derived from stack/heap pointers) + = Public -- ^ Known to be public (derived from stack/heap pointers) | Secret -- ^ Known or assumed to be secret - | Unknown -- ^ Determined but unknown origin + | Unknown -- ^ Not yet determined deriving (Eq, Ord, Show, Generic, NFData) instance ToJSON Taint -- | Join operation for taint lattice (least upper bound). --- Bottom is identity: join Bottom x = x. -- For CT safety: Public only if both are Public. --- Order: Bottom < Public < Unknown < Secret +-- Order: Public < Unknown < Secret joinTaint :: Taint -> Taint -> Taint -joinTaint Bottom x = x -joinTaint x Bottom = x joinTaint Public Public = Public joinTaint Secret _ = Secret joinTaint _ Secret = Secret @@ -371,20 +306,15 @@ joinTaint _ _ = Unknown -- Public+Unknown or Unknown+Unknown -- | Provenance: tracks whether a value derives from known-public sources. -- Used to upgrade Unknown taint to Public when provenance can prove safety. --- ProvBottom is the identity element for joins (no information yet). 
data Provenance - = ProvBottom -- ^ No information yet (identity for join) - | ProvPublic -- ^ Derived from public root or constant + = ProvPublic -- ^ Derived from public root or constant | ProvUnknown -- ^ Unknown origin (e.g., loaded from memory) deriving (Eq, Ord, Show, Generic, NFData) instance ToJSON Provenance --- | Join provenance: ProvBottom is identity. --- Only ProvPublic if both are ProvPublic. +-- | Join provenance: only ProvPublic if both are ProvPublic. joinProvenance :: Provenance -> Provenance -> Provenance -joinProvenance ProvBottom x = x -joinProvenance x ProvBottom = x joinProvenance ProvPublic ProvPublic = ProvPublic joinProvenance _ _ = ProvUnknown