auditor

An aarch64 constant-time memory access auditing tool.
git clone git://git.ppad.tech/auditor.git
Log | Files | Refs | README | LICENSE

commit e08c2305e1ff2a5eb0ee18fd17dd9983efdc800d
parent d08c45dad9c962dedb108935e04abc5a05d0a8b1
Author: Jared Tobin <jared@jtobin.io>
Date:   Wed, 11 Feb 2026 19:43:30 +0400

refactor: use SmallArray for register taint/provenance (IMPL12)

Replace Map Reg Taint and Map Reg Provenance with SmallArray for O(1)
register lookups. Add Bottom/ProvBottom as identity elements for joins.

Changes:
- Add regIndex/regCount for register-to-index mapping
- TaintState now uses SmallArray instead of Map for registers
- Add joinSmallArrayWith for element-wise array joins
- Use foldl' in analyzeBlock and analyzeBlockWithSummaries for strictness
- Fix dataflow initialization to only seed entry blocks

Benchmark results (mixed):
- taint/intra-small: 18ms -> 4ms (4.7x faster)
- taint/intra-large: 83ms -> 36ms (2.3x faster)
- taint/inter-small: 2.9ms -> 8ms (2.8x slower)
- taint/inter-large: 25ms -> 66ms (2.7x slower)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Diffstat:
M lib/Audit/AArch64/Check.hs | 19 +++++++++++++------
M lib/Audit/AArch64/Taint.hs | 191 ++++++++++++++++++++++++++++++++++++++++++++++++++++---------------------------
M lib/Audit/AArch64/Types.hs | 80 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-----
3 files changed, 213 insertions(+), 77 deletions(-)

diff --git a/lib/Audit/AArch64/Check.hs b/lib/Audit/AArch64/Check.hs @@ -27,6 +27,7 @@ module Audit.AArch64.Check ( import Audit.AArch64.CFG (BasicBlock(..), CFG(..), cfgBlockCount, indexBlock, functionLabels, functionBlocks) import Audit.AArch64.Taint +import Audit.AArch64.Types (Taint(..)) import Audit.AArch64.Types import Control.DeepSeq (NFData) import qualified Data.IntMap.Strict as IM @@ -137,26 +138,32 @@ checkAddrMode sym ln instr addr st = case addr of [] -- | Check that base register is public. --- If taint is Unknown, check provenance to see if we can upgrade to Public. +-- If taint is Unknown/Bottom, check provenance to see if we can upgrade. checkBase :: Text -> Int -> Instr -> Reg -> TaintState -> [Violation] checkBase sym ln instr base st = case getTaint base st of Public -> [] Secret -> [Violation sym ln instr (SecretBase base)] - Unknown -> case getProvenance base st of + Bottom -> checkProv -- No information yet, check provenance + Unknown -> checkProv + where + checkProv = case getProvenance base st of ProvPublic -> [] -- Provenance proves public derivation - ProvUnknown -> [Violation sym ln instr (UnknownBase base)] + _ -> [Violation sym ln instr (UnknownBase base)] -- | Check that index register is public. --- If taint is Unknown, check provenance to see if we can upgrade to Public. +-- If taint is Unknown/Bottom, check provenance to see if we can upgrade. checkIndex :: Text -> Int -> Instr -> Reg -> TaintState -> [Violation] checkIndex sym ln instr idx st = case getTaint idx st of Public -> [] Secret -> [Violation sym ln instr (SecretIndex idx)] - Unknown -> case getProvenance idx st of + Bottom -> checkProv -- No information yet, check provenance + Unknown -> checkProv + where + checkProv = case getProvenance idx st of ProvPublic -> [] -- Provenance proves public derivation - ProvUnknown -> [Violation sym ln instr (UnknownIndex idx)] + _ -> [Violation sym ln instr (UnknownIndex idx)] -- | Check entire CFG with inter-procedural analysis. 
-- Computes function summaries via fixpoint, then checks each function. diff --git a/lib/Audit/AArch64/Taint.hs b/lib/Audit/AArch64/Taint.hs @@ -13,6 +13,8 @@ -- Public registers are those derived from known-safe sources (stack -- pointers, heap pointers, constants). Registers with unknown or -- secret-derived values are flagged when used in memory addressing. +-- +-- Uses SmallArray-backed register storage for efficient O(1) lookups. module Audit.AArch64.Taint ( TaintState @@ -50,6 +52,12 @@ import Audit.AArch64.Types ( Reg(..), Instr(..), Line(..), Operand(..), AddrMode(..) , Taint(..), joinTaint, Provenance(..), joinProvenance , TaintConfig(..), ArgPolicy(..) + , regIndex, regCount + ) +import Control.Monad.ST (runST) +import Data.Primitive.SmallArray + ( SmallArray, indexSmallArray, newSmallArray, writeSmallArray + , freezeSmallArray, thawSmallArray, sizeofSmallArray ) import Data.IntMap.Strict (IntMap) import qualified Data.IntMap.Strict as IM @@ -62,13 +70,14 @@ import Data.Text (Text) -- | Taint state: maps registers to their publicness, plus stack slots. -- Also tracks provenance for upgrading Unknown to Public when provable. +-- Uses SmallArray for O(1) register lookups (indexed by regIndex). data TaintState = TaintState - { tsRegs :: !(Map Reg Taint) - -- ^ Register taints + { tsRegs :: !(SmallArray Taint) + -- ^ Register taints (indexed by regIndex) , tsStack :: !(IntMap Taint) -- ^ Stack slot taints (keyed by SP offset) - , tsProv :: !(Map Reg Provenance) - -- ^ Register provenance + , tsProv :: !(SmallArray Provenance) + -- ^ Register provenance (indexed by regIndex) , tsStackProv :: !(IntMap Provenance) -- ^ Stack slot provenance (keyed by SP offset) } deriving (Eq, Show) @@ -98,23 +107,45 @@ publicRoots = , XZR, WZR -- Zero registers ] --- | Empty taint state (no known taints). +-- | Empty taint state (all Bottom/ProvUnknown). +-- Bottom is the identity for join, representing "no information yet". 
emptyTaintState :: TaintState emptyTaintState = TaintState - { tsRegs = Map.empty + { tsRegs = defaultTaintArray , tsStack = IM.empty - , tsProv = Map.empty + , tsProv = defaultProvArray , tsStackProv = IM.empty } +-- | Default array with all Bottom taints (identity for join). +defaultTaintArray :: SmallArray Taint +defaultTaintArray = runST $ do + arr <- newSmallArray regCount Bottom + freezeSmallArray arr 0 regCount + +-- | Default array with all ProvBottom provenances (identity for join). +defaultProvArray :: SmallArray Provenance +defaultProvArray = runST $ do + arr <- newSmallArray regCount ProvBottom + freezeSmallArray arr 0 regCount + -- | Initial taint state with public roots marked. +-- Non-public-root registers start at Unknown (truly unknown at function entry). initTaintState :: TaintState -initTaintState = TaintState - { tsRegs = Map.fromList [(r, Public) | r <- publicRoots] - , tsStack = IM.empty - , tsProv = Map.fromList [(r, ProvPublic) | r <- publicRoots] - , tsStackProv = IM.empty - } +initTaintState = runST $ do + regs <- newSmallArray regCount Unknown + prov <- newSmallArray regCount ProvUnknown + -- Mark public roots + mapM_ (\r -> writeSmallArray regs (regIndex r) Public) publicRoots + mapM_ (\r -> writeSmallArray prov (regIndex r) ProvPublic) publicRoots + regsArr <- freezeSmallArray regs 0 regCount + provArr <- freezeSmallArray prov 0 regCount + pure $ TaintState + { tsRegs = regsArr + , tsStack = IM.empty + , tsProv = provArr + , tsStackProv = IM.empty + } -- | Seed argument registers according to policy. -- Secret registers are marked Secret with ProvUnknown. 
@@ -127,22 +158,16 @@ seedArgs policy st = st2 = Set.foldr markSecret st1 (apSecret policy) in st2 where - markPublic r s = s - { tsRegs = Map.insert r Public (tsRegs s) - , tsProv = Map.insert r ProvPublic (tsProv s) - } - markSecret r s = s - { tsRegs = Map.insert r Secret (tsRegs s) - , tsProv = Map.insert r ProvUnknown (tsProv s) - } + markPublic r s = setTaintProv r Public ProvPublic s + markSecret r s = setTaintProv r Secret ProvUnknown s -- | Get the taint of a register. getTaint :: Reg -> TaintState -> Taint -getTaint r st = Map.findWithDefault Unknown r (tsRegs st) +getTaint r st = indexSmallArray (tsRegs st) (regIndex r) -- | Get the provenance of a register. getProvenance :: Reg -> TaintState -> Provenance -getProvenance r st = Map.findWithDefault ProvUnknown r (tsProv st) +getProvenance r st = indexSmallArray (tsProv st) (regIndex r) -- | Analyze a single line, updating taint state. analyzeLine :: Line -> TaintState -> TaintState @@ -152,7 +177,7 @@ analyzeLine l st = case lineInstr l of -- | Analyze a basic block, threading taint state through. analyzeBlock :: [Line] -> TaintState -> TaintState -analyzeBlock lns st = foldl (flip analyzeLine) st lns +analyzeBlock lns st = foldl' (flip analyzeLine) st lns -- | Transfer function for taint analysis. -- @@ -364,14 +389,23 @@ transfer instr st = case instr of -- | Set taint for a register. setTaint :: Reg -> Taint -> TaintState -> TaintState -setTaint r t st = st { tsRegs = Map.insert r t (tsRegs st) } +setTaint r t st = st { tsRegs = updateSmallArray (regIndex r) t (tsRegs st) } -- | Set both taint and provenance for a register. setTaintProv :: Reg -> Taint -> Provenance -> TaintState -> TaintState setTaintProv r t p st = st - { tsRegs = Map.insert r t (tsRegs st) - , tsProv = Map.insert r p (tsProv st) + { tsRegs = updateSmallArray idx t (tsRegs st) + , tsProv = updateSmallArray idx p (tsProv st) } + where + idx = regIndex r + +-- | Update a single element in a SmallArray (copy-on-write). 
+updateSmallArray :: Int -> a -> SmallArray a -> SmallArray a +updateSmallArray idx val arr = runST $ do + marr <- thawSmallArray arr 0 (sizeofSmallArray arr) + writeSmallArray marr idx val + freezeSmallArray marr 0 (sizeofSmallArray arr) -- | Set taint for a loaded value, preserving public roots. -- Public roots (SP, X19-X21, X28, etc.) stay public even when loaded @@ -550,10 +584,14 @@ provJoin3 a b c = joinProvenance a (joinProvenance b c) -- Per AArch64 ABI, x0-x17 are caller-saved. -- Clears both taint and provenance. invalidateCallerSaved :: TaintState -> TaintState -invalidateCallerSaved st = st - { tsRegs = foldr (\r -> Map.insert r Unknown) (tsRegs st) callerSaved - , tsProv = foldr (\r -> Map.insert r ProvUnknown) (tsProv st) callerSaved - } +invalidateCallerSaved st = runST $ do + regs <- thawSmallArray (tsRegs st) 0 regCount + prov <- thawSmallArray (tsProv st) 0 regCount + mapM_ (\r -> writeSmallArray regs (regIndex r) Unknown) callerSaved + mapM_ (\r -> writeSmallArray prov (regIndex r) ProvUnknown) callerSaved + regsArr <- freezeSmallArray regs 0 regCount + provArr <- freezeSmallArray prov 0 regCount + pure $ st { tsRegs = regsArr, tsProv = provArr } where callerSaved = [ X0, X1, X2, X3, X4, X5, X6, X7 @@ -562,16 +600,32 @@ invalidateCallerSaved st = st ] -- | Join two taint states (element-wise join). --- For registers in both, take the join. For registers in only one, keep. +-- Register arrays are joined element-wise with joinTaint/joinProvenance. -- Stack slots and provenance are also joined element-wise. 
joinTaintState :: TaintState -> TaintState -> TaintState joinTaintState a b = TaintState - { tsRegs = Map.unionWith joinTaint (tsRegs a) (tsRegs b) + { tsRegs = joinSmallArrayWith joinTaint (tsRegs a) (tsRegs b) , tsStack = IM.unionWith joinTaint (tsStack a) (tsStack b) - , tsProv = Map.unionWith joinProvenance (tsProv a) (tsProv b) + , tsProv = joinSmallArrayWith joinProvenance (tsProv a) (tsProv b) , tsStackProv = IM.unionWith joinProvenance (tsStackProv a) (tsStackProv b) } +-- | Join two SmallArrays element-wise with a combining function. +joinSmallArrayWith :: (a -> a -> a) -> SmallArray a -> SmallArray a + -> SmallArray a +joinSmallArrayWith f arr1 arr2 = runST $ do + let n = sizeofSmallArray arr1 + result <- newSmallArray n (indexSmallArray arr1 0) + let go i + | i >= n = pure () + | otherwise = do + let v1 = indexSmallArray arr1 i + v2 = indexSmallArray arr2 i + writeSmallArray result i (f v1 v2) + go (i + 1) + go 0 + freezeSmallArray result 0 n + -- | Run forward dataflow analysis over a CFG. -- Returns the IN taint state for each block (indexed by block number). runDataflow :: CFG -> IntMap TaintState @@ -581,10 +635,11 @@ runDataflow cfg where nBlocks = cfgBlockCount cfg - -- Initial states: all blocks start with public roots (GHC invariant) - initIn = IM.fromList [(i, initTaintState) | i <- [0..nBlocks-1]] + -- Only entry block starts with initTaintState; others get emptyTaintState + -- (which has Bottom for all registers, acting as identity for join). + initIn = IM.singleton 0 initTaintState initOut = IM.empty - initWorklist = IS.fromList [0..nBlocks-1] + initWorklist = IS.singleton 0 go :: IntSet -> IntMap TaintState -> IntMap TaintState -> IntMap TaintState go worklist inStates outStates @@ -623,13 +678,9 @@ newtype FuncSummary = FuncSummary { summaryState :: TaintState } deriving (Eq, Show) -- | Initial conservative summary: all caller-saved are Unknown. +-- Uses the default empty taint state (all Unknown/ProvUnknown). 
initSummary :: FuncSummary -initSummary = FuncSummary $ TaintState - { tsRegs = Map.fromList [ (r, Unknown) | r <- callerSavedRegs ] - , tsStack = IM.empty - , tsProv = Map.fromList [ (r, ProvUnknown) | r <- callerSavedRegs ] - , tsStackProv = IM.empty - } +initSummary = FuncSummary emptyTaintState -- | Caller-saved registers per AArch64 ABI. callerSavedRegs :: [Reg] @@ -647,15 +698,18 @@ joinSummary (FuncSummary a) (FuncSummary b) = -- | Apply a function summary at a call site. -- Replaces caller-saved register taints and provenance with summary values. applySummary :: FuncSummary -> TaintState -> TaintState -applySummary (FuncSummary summ) st = st - { tsRegs = foldr applyReg (tsRegs st) callerSavedRegs - , tsProv = foldr applyProv (tsProv st) callerSavedRegs - } - where - summRegs = tsRegs summ - summProv = tsProv summ - applyReg r s = Map.insert r (Map.findWithDefault Unknown r summRegs) s - applyProv r s = Map.insert r (Map.findWithDefault ProvUnknown r summProv) s +applySummary (FuncSummary summ) st = runST $ do + regs <- thawSmallArray (tsRegs st) 0 regCount + prov <- thawSmallArray (tsProv st) 0 regCount + let summRegs = tsRegs summ + summProv = tsProv summ + mapM_ (\r -> do + let idx = regIndex r + writeSmallArray regs idx (indexSmallArray summRegs idx) + writeSmallArray prov idx (indexSmallArray summProv idx)) callerSavedRegs + regsArr <- freezeSmallArray regs 0 regCount + provArr <- freezeSmallArray prov 0 regCount + pure $ st { tsRegs = regsArr, tsProv = provArr } -- | Run dataflow analysis for a single function (subset of blocks). -- Returns the OUT state at return points (joined). @@ -721,7 +775,7 @@ runFunctionBlocks cfg (entryIdx:rest) summaries = go initWorklist initIn IM.empt -- | Analyze a block, applying call summaries at bl instructions. 
analyzeBlockWithSummaries :: BasicBlock -> TaintState -> Map Text FuncSummary -> TaintState -analyzeBlockWithSummaries bb st0 summaries = foldl go st0 (bbLines bb) +analyzeBlockWithSummaries bb st0 summaries = foldl' go st0 (bbLines bb) where go st l = case lineInstr l of Nothing -> st @@ -782,27 +836,32 @@ runDataflowWithConfig tcfg cfg where nBlocks = cfgBlockCount cfg - -- Build a map from block index to entry taint state - -- Entry blocks of functions get their policy applied + -- Only initialize function entry blocks with their entry states. + -- Other blocks get emptyTaintState via findWithDefault during propagation. initIn = IM.fromList - [ (i, entryState i (indexBlock cfg i)) - | i <- [0..nBlocks-1] + [ (idx, entryState idx bb) + | idx <- [0..nBlocks-1] + , let bb = indexBlock cfg idx + , isFuncEntry idx bb ] + isFuncEntry idx bb = case bbLabel bb of + Nothing -> idx == 0 -- Block 0 is always an entry if no label + Just lbl -> case functionBlocks cfg lbl of + (entry:_) -> entry == idx + [] -> False + entryState idx bb = let base = initTaintState in case bbLabel bb of Nothing -> base Just lbl -> - -- Check if this block is a function entry - case functionBlocks cfg lbl of - (entry:_) | entry == idx -> - case Map.lookup lbl (tcPolicies tcfg) of - Nothing -> base - Just policy -> seedArgs policy base - _ -> base - - initWorklist = IS.fromList [0..nBlocks-1] + case Map.lookup lbl (tcPolicies tcfg) of + Nothing -> base + Just policy -> seedArgs policy base + + -- Start worklist with all function entry blocks + initWorklist = IS.fromList (IM.keys initIn) go worklist inStates outStates | IS.null worklist = inStates diff --git a/lib/Audit/AArch64/Types.hs b/lib/Audit/AArch64/Types.hs @@ -16,6 +16,8 @@ module Audit.AArch64.Types ( Reg(..) , regName , regFromText + , regIndex + , regCount -- * Operands and addressing , Shift(..) 
@@ -144,6 +146,64 @@ regFromText t = Map.lookup (T.toUpper t) regMap , ("Q28", Q28), ("Q29", Q29), ("Q30", Q30), ("Q31", Q31) ] +-- | Total number of registers (for array sizing). +regCount :: Int +regCount = 161 + +-- | Map a register to its array index. +-- The mapping is: +-- X0-X30: 0-30 +-- W0-W30: 31-61 +-- SP: 62 +-- XZR: 63 +-- WZR: 64 +-- D0-D31: 65-96 +-- S0-S31: 97-128 +-- Q0-Q31: 129-160 +regIndex :: Reg -> Int +regIndex r = case r of + X0 -> 0; X1 -> 1; X2 -> 2; X3 -> 3 + X4 -> 4; X5 -> 5; X6 -> 6; X7 -> 7 + X8 -> 8; X9 -> 9; X10 -> 10; X11 -> 11 + X12 -> 12; X13 -> 13; X14 -> 14; X15 -> 15 + X16 -> 16; X17 -> 17; X18 -> 18; X19 -> 19 + X20 -> 20; X21 -> 21; X22 -> 22; X23 -> 23 + X24 -> 24; X25 -> 25; X26 -> 26; X27 -> 27 + X28 -> 28; X29 -> 29; X30 -> 30 + W0 -> 31; W1 -> 32; W2 -> 33; W3 -> 34 + W4 -> 35; W5 -> 36; W6 -> 37; W7 -> 38 + W8 -> 39; W9 -> 40; W10 -> 41; W11 -> 42 + W12 -> 43; W13 -> 44; W14 -> 45; W15 -> 46 + W16 -> 47; W17 -> 48; W18 -> 49; W19 -> 50 + W20 -> 51; W21 -> 52; W22 -> 53; W23 -> 54 + W24 -> 55; W25 -> 56; W26 -> 57; W27 -> 58 + W28 -> 59; W29 -> 60; W30 -> 61 + SP -> 62; XZR -> 63; WZR -> 64 + D0 -> 65; D1 -> 66; D2 -> 67; D3 -> 68 + D4 -> 69; D5 -> 70; D6 -> 71; D7 -> 72 + D8 -> 73; D9 -> 74; D10 -> 75; D11 -> 76 + D12 -> 77; D13 -> 78; D14 -> 79; D15 -> 80 + D16 -> 81; D17 -> 82; D18 -> 83; D19 -> 84 + D20 -> 85; D21 -> 86; D22 -> 87; D23 -> 88 + D24 -> 89; D25 -> 90; D26 -> 91; D27 -> 92 + D28 -> 93; D29 -> 94; D30 -> 95; D31 -> 96 + S0 -> 97; S1 -> 98; S2 -> 99; S3 -> 100 + S4 -> 101; S5 -> 102; S6 -> 103; S7 -> 104 + S8 -> 105; S9 -> 106; S10 -> 107; S11 -> 108 + S12 -> 109; S13 -> 110; S14 -> 111; S15 -> 112 + S16 -> 113; S17 -> 114; S18 -> 115; S19 -> 116 + S20 -> 117; S21 -> 118; S22 -> 119; S23 -> 120 + S24 -> 121; S25 -> 122; S26 -> 123; S27 -> 124 + S28 -> 125; S29 -> 126; S30 -> 127; S31 -> 128 + Q0 -> 129; Q1 -> 130; Q2 -> 131; Q3 -> 132 + Q4 -> 133; Q5 -> 134; Q6 -> 135; Q7 -> 136 + Q8 -> 137; Q9 -> 138; Q10 
-> 139; Q11 -> 140 + Q12 -> 141; Q13 -> 142; Q14 -> 143; Q15 -> 144 + Q16 -> 145; Q17 -> 146; Q18 -> 147; Q19 -> 148 + Q20 -> 149; Q21 -> 150; Q22 -> 151; Q23 -> 152 + Q24 -> 153; Q25 -> 154; Q26 -> 155; Q27 -> 156 + Q28 -> 157; Q29 -> 158; Q30 -> 159; Q31 -> 160 + -- | Shift operations for indexed addressing. data Shift = LSL !Int -- ^ Logical shift left @@ -287,18 +347,23 @@ data Line = Line instance ToJSON Line -- | Taint lattice for register publicness. +-- Bottom is the identity element for joins (no information yet). data Taint - = Public -- ^ Known to be public (derived from stack/heap pointers) + = Bottom -- ^ No information yet (identity for join) + | Public -- ^ Known to be public (derived from stack/heap pointers) | Secret -- ^ Known or assumed to be secret - | Unknown -- ^ Not yet determined + | Unknown -- ^ Determined but unknown origin deriving (Eq, Ord, Show, Generic, NFData) instance ToJSON Taint -- | Join operation for taint lattice (least upper bound). +-- Bottom is identity: join Bottom x = x. -- For CT safety: Public only if both are Public. --- Order: Public < Unknown < Secret +-- Order: Bottom < Public < Unknown < Secret joinTaint :: Taint -> Taint -> Taint +joinTaint Bottom x = x +joinTaint x Bottom = x joinTaint Public Public = Public joinTaint Secret _ = Secret joinTaint _ Secret = Secret @@ -306,15 +371,20 @@ joinTaint _ _ = Unknown -- Public+Unknown or Unknown+Unknown -- | Provenance: tracks whether a value derives from known-public sources. -- Used to upgrade Unknown taint to Public when provenance can prove safety. +-- ProvBottom is the identity element for joins (no information yet). 
data Provenance - = ProvPublic -- ^ Derived from public root or constant + = ProvBottom -- ^ No information yet (identity for join) + | ProvPublic -- ^ Derived from public root or constant | ProvUnknown -- ^ Unknown origin (e.g., loaded from memory) deriving (Eq, Ord, Show, Generic, NFData) instance ToJSON Provenance --- | Join provenance: only ProvPublic if both are ProvPublic. +-- | Join provenance: ProvBottom is identity. +-- Only ProvPublic if both are ProvPublic. joinProvenance :: Provenance -> Provenance -> Provenance +joinProvenance ProvBottom x = x +joinProvenance x ProvBottom = x joinProvenance ProvPublic ProvPublic = ProvPublic joinProvenance _ _ = ProvUnknown