auditor

An aarch64 constant-time memory access auditing tool.
git clone git://git.ppad.tech/auditor.git
Log | Files | Refs | README | LICENSE

commit fb15d345eb83db9a270c3f937552e2a1c2ef7ba1
parent f66e98a34416cbf8f5a8bc7c8193de5382b38bc7
Author: Jared Tobin <jared@jtobin.io>
Date:   Tue, 10 Feb 2026 14:18:56 +0400

feat: implement def-use provenance tracking (IMPL6)

Add provenance tracking alongside taint analysis to recover public
status when taint becomes Unknown but the derivation chain proves safety.

- Add Provenance type (ProvPublic | ProvUnknown) to Types.hs
- Extend TaintState with tsProv and tsStackProv maps
- Track provenance through all transfer rules:
  - adr/adrp, movz/movn, public roots -> ProvPublic
  - Arithmetic/logical ops -> join provenances
  - Stack stores/loads -> preserve provenance
  - GOT/PLT loads -> ProvPublic
  - Other loads, caller-saved on calls -> ProvUnknown
- Update checkBase/checkIndex to upgrade Unknown to Public when
  provenance is ProvPublic
- Add 4 provenance tests

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Diffstat:
M lib/Audit/AArch64/Check.hs | 10 ++++++++--
M lib/Audit/AArch64/Taint.hs | 351 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++---------------------
M lib/Audit/AArch64/Types.hs | 18 ++++++++++++++++++
M test/Main.hs | 62 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
4 files changed, 345 insertions(+), 96 deletions(-)

diff --git a/lib/Audit/AArch64/Check.hs b/lib/Audit/AArch64/Check.hs @@ -128,20 +128,26 @@ checkAddrMode sym ln instr addr st = case addr of [] -- | Check that base register is public. +-- If taint is Unknown, check provenance to see if we can upgrade to Public. checkBase :: Text -> Int -> Instr -> Reg -> TaintState -> [Violation] checkBase sym ln instr base st = case getTaint base st of Public -> [] Secret -> [Violation sym ln instr (SecretBase base)] - Unknown -> [Violation sym ln instr (UnknownBase base)] + Unknown -> case getProvenance base st of + ProvPublic -> [] -- Provenance proves public derivation + ProvUnknown -> [Violation sym ln instr (UnknownBase base)] -- | Check that index register is public. +-- If taint is Unknown, check provenance to see if we can upgrade to Public. checkIndex :: Text -> Int -> Instr -> Reg -> TaintState -> [Violation] checkIndex sym ln instr idx st = case getTaint idx st of Public -> [] Secret -> [Violation sym ln instr (SecretIndex idx)] - Unknown -> [Violation sym ln instr (UnknownIndex idx)] + Unknown -> case getProvenance idx st of + ProvPublic -> [] -- Provenance proves public derivation + ProvUnknown -> [Violation sym ln instr (UnknownIndex idx)] -- | Check entire CFG with inter-procedural analysis. -- Computes function summaries via fixpoint, then checks each function. diff --git a/lib/Audit/AArch64/Taint.hs b/lib/Audit/AArch64/Taint.hs @@ -21,6 +21,7 @@ module Audit.AArch64.Taint ( , analyzeBlock , getTaint , setTaint + , getProvenance , publicRoots , joinTaintState , runDataflow @@ -46,11 +47,16 @@ import qualified Data.Set as Set import Data.Text (Text) -- | Taint state: maps registers to their publicness, plus stack slots. +-- Also tracks provenance for upgrading Unknown to Public when provable. 
data TaintState = TaintState - { tsRegs :: !(Map Reg Taint) + { tsRegs :: !(Map Reg Taint) -- ^ Register taints - , tsStack :: !(IntMap Taint) + , tsStack :: !(IntMap Taint) -- ^ Stack slot taints (keyed by SP offset) + , tsProv :: !(Map Reg Provenance) + -- ^ Register provenance + , tsStackProv :: !(IntMap Provenance) + -- ^ Stack slot provenance (keyed by SP offset) } deriving (Eq, Show) -- | GHC 9.10.3 AArch64 public root registers. @@ -81,21 +87,29 @@ publicRoots = -- | Empty taint state (no known taints). emptyTaintState :: TaintState emptyTaintState = TaintState - { tsRegs = Map.empty - , tsStack = IM.empty + { tsRegs = Map.empty + , tsStack = IM.empty + , tsProv = Map.empty + , tsStackProv = IM.empty } -- | Initial taint state with public roots marked. initTaintState :: TaintState initTaintState = TaintState - { tsRegs = Map.fromList [(r, Public) | r <- publicRoots] - , tsStack = IM.empty + { tsRegs = Map.fromList [(r, Public) | r <- publicRoots] + , tsStack = IM.empty + , tsProv = Map.fromList [(r, ProvPublic) | r <- publicRoots] + , tsStackProv = IM.empty } -- | Get the taint of a register. getTaint :: Reg -> TaintState -> Taint getTaint r st = Map.findWithDefault Unknown r (tsRegs st) +-- | Get the provenance of a register. +getProvenance :: Reg -> TaintState -> Provenance +getProvenance r st = Map.findWithDefault ProvUnknown r (tsProv st) + -- | Analyze a single line, updating taint state. analyzeLine :: Line -> TaintState -> TaintState analyzeLine l st = case lineInstr l of @@ -109,59 +123,143 @@ analyzeBlock lns st = foldl (flip analyzeLine) st lns -- | Transfer function for taint analysis. -- -- For each instruction, determine how it affects register taints. +-- Also tracks provenance for upgrading Unknown bases. 
transfer :: Instr -> TaintState -> TaintState transfer instr st = case instr of - -- Move: destination gets source taint - Mov dst op -> setTaint dst (operandTaint op st) st - Movz dst _ _ -> setTaint dst Public st -- Immediate is public + -- Move: destination gets source taint and provenance + Mov dst op -> + setTaintProv dst (operandTaint op st) (operandProv op st) st + Movz dst _ _ -> + setTaintProv dst Public ProvPublic st -- Immediate is public Movk _ _ _ -> st -- Keeps existing value, modifies bits - Movn dst _ _ -> setTaint dst Public st -- Immediate is public + Movn dst _ _ -> + setTaintProv dst Public ProvPublic st -- Immediate is public - -- Arithmetic: result is join of operand taints + -- Arithmetic: result is join of operand taints/provenances -- Clear stack map if SP is modified (offsets become stale) - Add dst r1 op -> updateWithSpCheck dst (join2 (getTaint r1 st) (operandTaint op st)) st - Sub dst r1 op -> updateWithSpCheck dst (join2 (getTaint r1 st) (operandTaint op st)) st - Adds dst r1 op -> updateWithSpCheck dst (join2 (getTaint r1 st) (operandTaint op st)) st - Subs dst r1 op -> updateWithSpCheck dst (join2 (getTaint r1 st) (operandTaint op st)) st - Adc dst r1 r2 -> setTaint dst (join2 (getTaint r1 st) (getTaint r2 st)) st - Adcs dst r1 r2 -> setTaint dst (join2 (getTaint r1 st) (getTaint r2 st)) st - Sbc dst r1 r2 -> setTaint dst (join2 (getTaint r1 st) (getTaint r2 st)) st - Neg dst op -> setTaint dst (operandTaint op st) st - Negs dst op -> setTaint dst (operandTaint op st) st - Mul dst r1 r2 -> setTaint dst (join2 (getTaint r1 st) (getTaint r2 st)) st - Mneg dst r1 r2 -> setTaint dst (join2 (getTaint r1 st) (getTaint r2 st)) st - Madd dst r1 r2 r3 -> setTaint dst (join3 (getTaint r1 st) (getTaint r2 st) - (getTaint r3 st)) st - Msub dst r1 r2 r3 -> setTaint dst (join3 (getTaint r1 st) (getTaint r2 st) - (getTaint r3 st)) st - Umulh dst r1 r2 -> setTaint dst (join2 (getTaint r1 st) (getTaint r2 st)) st - Smulh dst r1 r2 -> setTaint dst 
(join2 (getTaint r1 st) (getTaint r2 st)) st - Udiv dst r1 r2 -> setTaint dst (join2 (getTaint r1 st) (getTaint r2 st)) st - Sdiv dst r1 r2 -> setTaint dst (join2 (getTaint r1 st) (getTaint r2 st)) st - - -- Logical: result is join of operand taints - And dst r1 op -> setTaint dst (join2 (getTaint r1 st) (operandTaint op st)) st - Orr dst r1 op -> setTaint dst (join2 (getTaint r1 st) (operandTaint op st)) st - Eor dst r1 op -> setTaint dst (join2 (getTaint r1 st) (operandTaint op st)) st - Bic dst r1 op -> setTaint dst (join2 (getTaint r1 st) (operandTaint op st)) st - Mvn dst op -> setTaint dst (operandTaint op st) st + Add dst r1 op -> + let t = join2 (getTaint r1 st) (operandTaint op st) + p = provJoin2 (getProvenance r1 st) (operandProv op st) + in updateWithSpCheckProv dst t p st + Sub dst r1 op -> + let t = join2 (getTaint r1 st) (operandTaint op st) + p = provJoin2 (getProvenance r1 st) (operandProv op st) + in updateWithSpCheckProv dst t p st + Adds dst r1 op -> + let t = join2 (getTaint r1 st) (operandTaint op st) + p = provJoin2 (getProvenance r1 st) (operandProv op st) + in updateWithSpCheckProv dst t p st + Subs dst r1 op -> + let t = join2 (getTaint r1 st) (operandTaint op st) + p = provJoin2 (getProvenance r1 st) (operandProv op st) + in updateWithSpCheckProv dst t p st + Adc dst r1 r2 -> + let t = join2 (getTaint r1 st) (getTaint r2 st) + p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + in setTaintProv dst t p st + Adcs dst r1 r2 -> + let t = join2 (getTaint r1 st) (getTaint r2 st) + p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + in setTaintProv dst t p st + Sbc dst r1 r2 -> + let t = join2 (getTaint r1 st) (getTaint r2 st) + p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + in setTaintProv dst t p st + Neg dst op -> + setTaintProv dst (operandTaint op st) (operandProv op st) st + Negs dst op -> + setTaintProv dst (operandTaint op st) (operandProv op st) st + Mul dst r1 r2 -> + let t = join2 (getTaint r1 st) 
(getTaint r2 st) + p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + in setTaintProv dst t p st + Mneg dst r1 r2 -> + let t = join2 (getTaint r1 st) (getTaint r2 st) + p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + in setTaintProv dst t p st + Madd dst r1 r2 r3 -> + let t = join3 (getTaint r1 st) (getTaint r2 st) (getTaint r3 st) + p = provJoin3 (getProvenance r1 st) (getProvenance r2 st) + (getProvenance r3 st) + in setTaintProv dst t p st + Msub dst r1 r2 r3 -> + let t = join3 (getTaint r1 st) (getTaint r2 st) (getTaint r3 st) + p = provJoin3 (getProvenance r1 st) (getProvenance r2 st) + (getProvenance r3 st) + in setTaintProv dst t p st + Umulh dst r1 r2 -> + let t = join2 (getTaint r1 st) (getTaint r2 st) + p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + in setTaintProv dst t p st + Smulh dst r1 r2 -> + let t = join2 (getTaint r1 st) (getTaint r2 st) + p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + in setTaintProv dst t p st + Udiv dst r1 r2 -> + let t = join2 (getTaint r1 st) (getTaint r2 st) + p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + in setTaintProv dst t p st + Sdiv dst r1 r2 -> + let t = join2 (getTaint r1 st) (getTaint r2 st) + p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + in setTaintProv dst t p st + + -- Logical: result is join of operand taints/provenances + And dst r1 op -> + let t = join2 (getTaint r1 st) (operandTaint op st) + p = provJoin2 (getProvenance r1 st) (operandProv op st) + in setTaintProv dst t p st + Orr dst r1 op -> + let t = join2 (getTaint r1 st) (operandTaint op st) + p = provJoin2 (getProvenance r1 st) (operandProv op st) + in setTaintProv dst t p st + Eor dst r1 op -> + let t = join2 (getTaint r1 st) (operandTaint op st) + p = provJoin2 (getProvenance r1 st) (operandProv op st) + in setTaintProv dst t p st + Bic dst r1 op -> + let t = join2 (getTaint r1 st) (operandTaint op st) + p = provJoin2 (getProvenance r1 st) (operandProv op st) + in setTaintProv 
dst t p st + Mvn dst op -> + setTaintProv dst (operandTaint op st) (operandProv op st) st Tst _ _ -> st -- No destination - -- Shifts: result is join of operand taints - Lsl dst r1 op -> setTaint dst (join2 (getTaint r1 st) (operandTaint op st)) st - Lsr dst r1 op -> setTaint dst (join2 (getTaint r1 st) (operandTaint op st)) st - Asr dst r1 op -> setTaint dst (join2 (getTaint r1 st) (operandTaint op st)) st - Ror dst r1 op -> setTaint dst (join2 (getTaint r1 st) (operandTaint op st)) st - - -- Bit manipulation - Ubfx dst r1 _ _ -> setTaint dst (getTaint r1 st) st - Sbfx dst r1 _ _ -> setTaint dst (getTaint r1 st) st - Bfi dst r1 _ _ -> setTaint dst (join2 (getTaint dst st) (getTaint r1 st)) st - Extr dst r1 r2 _ -> setTaint dst (join2 (getTaint r1 st) (getTaint r2 st)) st + -- Shifts: result is join of operand taints/provenances + Lsl dst r1 op -> + let t = join2 (getTaint r1 st) (operandTaint op st) + p = provJoin2 (getProvenance r1 st) (operandProv op st) + in setTaintProv dst t p st + Lsr dst r1 op -> + let t = join2 (getTaint r1 st) (operandTaint op st) + p = provJoin2 (getProvenance r1 st) (operandProv op st) + in setTaintProv dst t p st + Asr dst r1 op -> + let t = join2 (getTaint r1 st) (operandTaint op st) + p = provJoin2 (getProvenance r1 st) (operandProv op st) + in setTaintProv dst t p st + Ror dst r1 op -> + let t = join2 (getTaint r1 st) (operandTaint op st) + p = provJoin2 (getProvenance r1 st) (operandProv op st) + in setTaintProv dst t p st + + -- Bit manipulation: preserves provenance from source + Ubfx dst r1 _ _ -> + setTaintProv dst (getTaint r1 st) (getProvenance r1 st) st + Sbfx dst r1 _ _ -> + setTaintProv dst (getTaint r1 st) (getProvenance r1 st) st + Bfi dst r1 _ _ -> + let t = join2 (getTaint dst st) (getTaint r1 st) + p = provJoin2 (getProvenance dst st) (getProvenance r1 st) + in setTaintProv dst t p st + Extr dst r1 r2 _ -> + let t = join2 (getTaint r1 st) (getTaint r2 st) + p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + 
in setTaintProv dst t p st -- Address generation: result is public (constant pool / PC-relative) - Adr dst _ -> setTaint dst Public st - Adrp dst _ -> setTaint dst Public st + Adr dst _ -> setTaintProv dst Public ProvPublic st + Adrp dst _ -> setTaintProv dst Public ProvPublic st -- Loads: restore from stack slots if [sp, #imm], else Unknown -- Exception: public roots stay public (GHC spills/restores them) @@ -181,15 +279,29 @@ transfer instr st = case instr of Stur src addr -> storeToStack src addr st Stp src1 src2 addr -> storePairToStack src1 src2 addr st - -- Conditionals: conservative join + -- Conditionals: conservative join with provenance Cmp _ _ -> st Cmn _ _ -> st - Csel dst r1 r2 _ -> setTaint dst (join2 (getTaint r1 st) (getTaint r2 st)) st - Csinc dst r1 r2 _ -> setTaint dst (join2 (getTaint r1 st) (getTaint r2 st)) st - Csinv dst r1 r2 _ -> setTaint dst (join2 (getTaint r1 st) (getTaint r2 st)) st - Csneg dst r1 r2 _ -> setTaint dst (join2 (getTaint r1 st) (getTaint r2 st)) st - Cset dst _ -> setTaint dst Public st -- Condition flag derived - Cinc dst r1 _ -> setTaint dst (getTaint r1 st) st + Csel dst r1 r2 _ -> + let t = join2 (getTaint r1 st) (getTaint r2 st) + p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + in setTaintProv dst t p st + Csinc dst r1 r2 _ -> + let t = join2 (getTaint r1 st) (getTaint r2 st) + p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + in setTaintProv dst t p st + Csinv dst r1 r2 _ -> + let t = join2 (getTaint r1 st) (getTaint r2 st) + p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + in setTaintProv dst t p st + Csneg dst r1 r2 _ -> + let t = join2 (getTaint r1 st) (getTaint r2 st) + p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + in setTaintProv dst t p st + Cset dst _ -> + setTaintProv dst Public ProvPublic st -- Condition flag derived + Cinc dst r1 _ -> + setTaintProv dst (getTaint r1 st) (getProvenance r1 st) st -- Branches: no register change B _ -> st @@ -214,52 +326,68 @@ 
transfer instr st = case instr of setTaint :: Reg -> Taint -> TaintState -> TaintState setTaint r t st = st { tsRegs = Map.insert r t (tsRegs st) } +-- | Set both taint and provenance for a register. +setTaintProv :: Reg -> Taint -> Provenance -> TaintState -> TaintState +setTaintProv r t p st = st + { tsRegs = Map.insert r t (tsRegs st) + , tsProv = Map.insert r p (tsProv st) + } + -- | Set taint for a loaded value, preserving public roots. -- Public roots (SP, X19-X21, X28, etc.) stay public even when loaded -- since GHC spills/restores them from the hardware stack. setTaintLoad :: Reg -> TaintState -> TaintState setTaintLoad dst st - | isPublicRoot dst = setTaint dst Public st - | otherwise = setTaint dst Unknown st + | isPublicRoot dst = setTaintProv dst Public ProvPublic st + | otherwise = setTaintProv dst Unknown ProvUnknown st where isPublicRoot r = r `elem` publicRoots -- | Set taint for a loaded value from a known stack slot. --- If we have tracked taint at this offset, use it; otherwise Unknown. +-- If we have tracked taint/provenance at this offset, use it; else Unknown. setTaintLoadStack :: Reg -> Int -> TaintState -> TaintState setTaintLoadStack dst offset st - | isPublicRoot dst = setTaint dst Public st + | isPublicRoot dst = setTaintProv dst Public ProvPublic st | otherwise = let taint = IM.findWithDefault Unknown offset (tsStack st) - in setTaint dst taint st + prov = IM.findWithDefault ProvUnknown offset (tsStackProv st) + in setTaintProv dst taint prov st where isPublicRoot r = r `elem` publicRoots --- | Store taint to a stack slot. -setStackTaint :: Int -> Taint -> TaintState -> TaintState -setStackTaint offset t st = st { tsStack = IM.insert offset t (tsStack st) } +-- | Store both taint and provenance to a stack slot. 
+setStackTaintProv :: Int -> Taint -> Provenance -> TaintState -> TaintState +setStackTaintProv offset t p st = st + { tsStack = IM.insert offset t (tsStack st) + , tsStackProv = IM.insert offset p (tsStackProv st) + } --- | Clear all stack slot taints (when SP is modified). +-- | Clear all stack slot taints and provenance (when SP is modified). clearStackMap :: TaintState -> TaintState -clearStackMap st = st { tsStack = IM.empty } +clearStackMap st = st { tsStack = IM.empty, tsStackProv = IM.empty } --- | Update register taint, clearing stack map if dst is SP. -updateWithSpCheck :: Reg -> Taint -> TaintState -> TaintState -updateWithSpCheck dst t st - | dst == SP = clearStackMap (setTaint dst t st) - | otherwise = setTaint dst t st +-- | Update register taint and provenance, clearing stack map if dst is SP. +updateWithSpCheckProv :: Reg -> Taint -> Provenance -> TaintState -> TaintState +updateWithSpCheckProv dst t p st + | dst == SP = clearStackMap (setTaintProv dst t p st) + | otherwise = setTaintProv dst t p st -- | Track a store to stack if address is [sp, #imm]. -- Pre/post-indexed addressing modifies SP, invalidating the stack map. 
storeToStack :: Reg -> AddrMode -> TaintState -> TaintState storeToStack src addr st = case addr of - BaseImm SP imm -> setStackTaint (fromInteger imm) (getTaint src st) st + BaseImm SP imm -> + let off = fromInteger imm + in setStackTaintProv off (getTaint src st) (getProvenance src st) st PreIndex SP imm -> -- Store at sp+imm, then SP changes - clearStackMap (setStackTaint (fromInteger imm) (getTaint src st) st) + let off = fromInteger imm + st' = setStackTaintProv off (getTaint src st) (getProvenance src st) st + in clearStackMap st' PostIndex SP _ -> -- Store at sp, then SP changes (offset 0) - clearStackMap (setStackTaint 0 (getTaint src st) st) + let st' = setStackTaintProv 0 (getTaint src st) (getProvenance src st) st + in clearStackMap st' _ -> st -- Non-SP stores don't affect stack tracking -- | Track a store pair to stack if address is [sp, #imm]. @@ -269,23 +397,27 @@ storePairToStack :: Reg -> Reg -> AddrMode -> TaintState -> TaintState storePairToStack src1 src2 addr st = case addr of BaseImm SP imm -> let off = fromInteger imm - in setStackTaint off (getTaint src1 st) - (setStackTaint (off + 8) (getTaint src2 st) st) + st' = setStackTaintProv off (getTaint src1 st) (getProvenance src1 st) + (setStackTaintProv (off + 8) (getTaint src2 st) + (getProvenance src2 st) st) + in st' PreIndex SP imm -> -- Store at sp+imm and sp+imm+8, then SP changes let off = fromInteger imm - st' = setStackTaint off (getTaint src1 st) - (setStackTaint (off + 8) (getTaint src2 st) st) + st' = setStackTaintProv off (getTaint src1 st) (getProvenance src1 st) + (setStackTaintProv (off + 8) (getTaint src2 st) + (getProvenance src2 st) st) in clearStackMap st' PostIndex SP _ -> -- Store at sp and sp+8, then SP changes - let st' = setStackTaint 0 (getTaint src1 st) - (setStackTaint 8 (getTaint src2 st) st) + let st' = setStackTaintProv 0 (getTaint src1 st) (getProvenance src1 st) + (setStackTaintProv 8 (getTaint src2 st) + (getProvenance src2 st) st) in clearStackMap st' _ -> st 
-- | Load from memory, handling special cases: --- - [sp, #imm]: restore tracked stack slot taint +-- - [sp, #imm]: restore tracked stack slot taint and provenance -- - [r, symbol@GOTPAGEOFF]: GOT entry load, result is Public (address) -- - Other: Unknown unless dst is a public root -- Post-indexed addressing modifies SP, invalidating the stack map. @@ -295,8 +427,10 @@ loadFromStack dst addr st = case addr of PostIndex SP imm -> -- Load first, then clear (SP changes after access) clearStackMap (setTaintLoadStack dst (fromInteger imm) st) - BaseSymbol _ _ -> setTaint dst Public st -- GOT/PLT entry -> address - Literal _ -> setTaint dst Public st -- PC-relative literal -> address + BaseSymbol _ _ -> + setTaintProv dst Public ProvPublic st -- GOT/PLT entry -> address + Literal _ -> + setTaintProv dst Public ProvPublic st -- PC-relative literal -> address _ -> setTaintLoad dst st -- Other loads use default behavior -- | Load pair from stack if address is [sp, #imm]. @@ -324,6 +458,16 @@ operandTaint op st = case op of OpLabel _ -> Public OpAddr addr -> addrBaseTaint addr st +-- | Get provenance of an operand. +operandProv :: Operand -> TaintState -> Provenance +operandProv op st = case op of + OpReg r -> getProvenance r st + OpImm _ -> ProvPublic + OpShiftedReg r _ -> getProvenance r st + OpExtendedReg r _ -> getProvenance r st + OpLabel _ -> ProvPublic + OpAddr _ -> ProvUnknown -- Address operand provenance is complex + -- | Get taint of address base register. addrBaseTaint :: AddrMode -> TaintState -> Taint addrBaseTaint addr st = case addr of @@ -344,11 +488,22 @@ join2 = joinTaint join3 :: Taint -> Taint -> Taint -> Taint join3 a b c = joinTaint a (joinTaint b c) +-- | Join two provenances. +provJoin2 :: Provenance -> Provenance -> Provenance +provJoin2 = joinProvenance + +-- | Join three provenances. 
+provJoin3 :: Provenance -> Provenance -> Provenance -> Provenance +provJoin3 a b c = joinProvenance a (joinProvenance b c) + -- | Invalidate caller-saved registers after a call. -- Per AArch64 ABI, x0-x17 are caller-saved. +-- Clears both taint and provenance. invalidateCallerSaved :: TaintState -> TaintState -invalidateCallerSaved st = - st { tsRegs = foldr (\r -> Map.insert r Unknown) (tsRegs st) callerSaved } +invalidateCallerSaved st = st + { tsRegs = foldr (\r -> Map.insert r Unknown) (tsRegs st) callerSaved + , tsProv = foldr (\r -> Map.insert r ProvUnknown) (tsProv st) callerSaved + } where callerSaved = [ X0, X1, X2, X3, X4, X5, X6, X7 @@ -358,11 +513,13 @@ invalidateCallerSaved st = -- | Join two taint states (element-wise join). -- For registers in both, take the join. For registers in only one, keep. --- Stack slots are also joined element-wise. +-- Stack slots and provenance are also joined element-wise. joinTaintState :: TaintState -> TaintState -> TaintState joinTaintState a b = TaintState - { tsRegs = Map.unionWith joinTaint (tsRegs a) (tsRegs b) - , tsStack = IM.unionWith joinTaint (tsStack a) (tsStack b) + { tsRegs = Map.unionWith joinTaint (tsRegs a) (tsRegs b) + , tsStack = IM.unionWith joinTaint (tsStack a) (tsStack b) + , tsProv = Map.unionWith joinProvenance (tsProv a) (tsProv b) + , tsStackProv = IM.unionWith joinProvenance (tsStackProv a) (tsStackProv b) } -- | Run forward dataflow analysis over a CFG. @@ -419,8 +576,10 @@ newtype FuncSummary = FuncSummary { summaryState :: TaintState } -- | Initial conservative summary: all caller-saved are Unknown. initSummary :: FuncSummary initSummary = FuncSummary $ TaintState - { tsRegs = Map.fromList [ (r, Unknown) | r <- callerSavedRegs ] - , tsStack = IM.empty + { tsRegs = Map.fromList [ (r, Unknown) | r <- callerSavedRegs ] + , tsStack = IM.empty + , tsProv = Map.fromList [ (r, ProvUnknown) | r <- callerSavedRegs ] + , tsStackProv = IM.empty } -- | Caller-saved registers per AArch64 ABI. 
@@ -437,13 +596,17 @@ joinSummary (FuncSummary a) (FuncSummary b) = FuncSummary (joinTaintState a b) -- | Apply a function summary at a call site. --- Replaces caller-saved register taints with the summary's values. +-- Replaces caller-saved register taints and provenance with summary values. applySummary :: FuncSummary -> TaintState -> TaintState -applySummary (FuncSummary summ) st = - st { tsRegs = foldr applyReg (tsRegs st) callerSavedRegs } +applySummary (FuncSummary summ) st = st + { tsRegs = foldr applyReg (tsRegs st) callerSavedRegs + , tsProv = foldr applyProv (tsProv st) callerSavedRegs + } where summRegs = tsRegs summ + summProv = tsProv summ applyReg r s = Map.insert r (Map.findWithDefault Unknown r summRegs) s + applyProv r s = Map.insert r (Map.findWithDefault ProvUnknown r summProv) s -- | Run dataflow analysis for a single function (subset of blocks). -- Returns the OUT state at return points (joined). diff --git a/lib/Audit/AArch64/Types.hs b/lib/Audit/AArch64/Types.hs @@ -29,6 +29,10 @@ module Audit.AArch64.Types ( , Taint(..) , joinTaint + -- * Provenance + , Provenance(..) + , joinProvenance + -- * Analysis results , Violation(..) , ViolationReason(..) @@ -239,6 +243,20 @@ joinTaint Secret _ = Secret joinTaint _ Secret = Secret joinTaint _ _ = Unknown -- Public+Unknown or Unknown+Unknown +-- | Provenance: tracks whether a value derives from known-public sources. +-- Used to upgrade Unknown taint to Public when provenance can prove safety. +data Provenance + = ProvPublic -- ^ Derived from public root or constant + | ProvUnknown -- ^ Unknown origin (e.g., loaded from memory) + deriving (Eq, Ord, Show, Generic) + +instance ToJSON Provenance + +-- | Join provenance: only ProvPublic if both are ProvPublic. +joinProvenance :: Provenance -> Provenance -> Provenance +joinProvenance ProvPublic ProvPublic = ProvPublic +joinProvenance _ _ = ProvUnknown + -- | Reason for a violation. 
data ViolationReason = SecretBase !Reg -- ^ Base register is secret diff --git a/test/Main.hs b/test/Main.hs @@ -16,6 +16,7 @@ main = defaultMain $ testGroup "ppad-auditor" [ , taintTests , auditTests , interprocTests + , provenanceTests ] -- Parser tests @@ -374,3 +375,64 @@ interprocTests = testGroup "InterProc" [ Left e -> assertFailure $ "parse failed: " ++ show e Right ar -> assertEqual "1 violation" 1 (length (arViolations ar)) ] + +-- Provenance tests + +provenanceTests :: TestTree +provenanceTests = testGroup "Provenance" [ + testCase "good: mov from public root preserves provenance" $ do + -- mov from public root should preserve public provenance + let src = T.unlines + [ "foo:" + , " mov x8, x20" -- x8 = copy of public root + , " ldr x0, [x8]" -- x8 should be public via provenance + , " ret" + ] + case audit "test" src of + Left e -> assertFailure $ "parse failed: " ++ show e + Right ar -> assertEqual "no violations" 0 (length (arViolations ar)) + + , testCase "good: add #imm preserves provenance" $ do + -- add with immediate preserves provenance from source + let src = T.unlines + [ "foo:" + , " add x8, x20, #16" -- x8 = public root + offset + , " ldr x0, [x8]" -- x8 should be public via provenance + , " ret" + ] + case audit "test" src of + Left e -> assertFailure $ "parse failed: " ++ show e + Right ar -> assertEqual "no violations" 0 (length (arViolations ar)) + + , testCase "bad: load clears provenance" $ do + -- Loading from memory should clear provenance + let src = T.unlines + [ "foo:" + , " mov x8, x20" -- x8 = public + , " ldr x8, [x8]" -- x8 = unknown (loaded from memory) + , " ldr x0, [x8]" -- x8 as base should be violation + , " ret" + ] + case audit "test" src of + Left e -> assertFailure $ "parse failed: " ++ show e + Right ar -> do + assertEqual "one violation" 1 (length (arViolations ar)) + case arViolations ar of + [v] -> case vReason v of + UnknownBase X8 -> pure () + other -> assertFailure $ "wrong reason: " ++ show other + _ -> 
assertFailure "expected one violation" + + , testCase "good: orr with xzr preserves provenance" $ do + -- orr with zero register should preserve provenance + let src = T.unlines + [ "foo:" + , " mov x8, x20" -- x8 = public + , " orr x9, x8, xzr" -- x9 = x8 | 0 = copy with provenance + , " ldr x0, [x9]" -- x9 should be public + , " ret" + ] + case audit "test" src of + Left e -> assertFailure $ "parse failed: " ++ show e + Right ar -> assertEqual "no violations" 0 (length (arViolations ar)) + ]