commit fb15d345eb83db9a270c3f937552e2a1c2ef7ba1
parent f66e98a34416cbf8f5a8bc7c8193de5382b38bc7
Author: Jared Tobin <jared@jtobin.io>
Date: Tue, 10 Feb 2026 14:18:56 +0400
feat: implement def-use provenance tracking (IMPL6)
Add provenance tracking alongside taint analysis to recover public
status when taint becomes Unknown but derivation chain proves safety.
- Add Provenance type (ProvPublic | ProvUnknown) to Types.hs
- Extend TaintState with tsProv and tsStackProv maps
- Track provenance through all transfer rules:
  - adr/adrp, movz/movn, public roots -> ProvPublic
  - Arithmetic/logical ops -> join provenances
  - Stack stores/loads -> preserve provenance
  - GOT/PLT and PC-relative literal loads -> ProvPublic
  - Other loads, caller-saved on calls -> ProvUnknown
- Update checkBase/checkIndex to upgrade Unknown to Public when
  provenance is ProvPublic
- Add 4 provenance tests
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat:
4 files changed, 345 insertions(+), 96 deletions(-)
diff --git a/lib/Audit/AArch64/Check.hs b/lib/Audit/AArch64/Check.hs
@@ -128,20 +128,26 @@ checkAddrMode sym ln instr addr st = case addr of
[]
-- | Check that base register is public.
+-- If taint is Unknown, check provenance to see if we can upgrade to Public.
checkBase :: Text -> Int -> Instr -> Reg -> TaintState -> [Violation]
checkBase sym ln instr base st =
case getTaint base st of
Public -> []
Secret -> [Violation sym ln instr (SecretBase base)]
- Unknown -> [Violation sym ln instr (UnknownBase base)]
+ Unknown -> case getProvenance base st of
+ ProvPublic -> [] -- Provenance proves public derivation
+ ProvUnknown -> [Violation sym ln instr (UnknownBase base)]
-- | Check that index register is public.
+-- If taint is Unknown, check provenance to see if we can upgrade to Public.
checkIndex :: Text -> Int -> Instr -> Reg -> TaintState -> [Violation]
checkIndex sym ln instr idx st =
case getTaint idx st of
Public -> []
Secret -> [Violation sym ln instr (SecretIndex idx)]
- Unknown -> [Violation sym ln instr (UnknownIndex idx)]
+ Unknown -> case getProvenance idx st of
+ ProvPublic -> [] -- Provenance proves public derivation
+ ProvUnknown -> [Violation sym ln instr (UnknownIndex idx)]
-- | Check entire CFG with inter-procedural analysis.
-- Computes function summaries via fixpoint, then checks each function.
diff --git a/lib/Audit/AArch64/Taint.hs b/lib/Audit/AArch64/Taint.hs
@@ -21,6 +21,7 @@ module Audit.AArch64.Taint (
, analyzeBlock
, getTaint
, setTaint
+ , getProvenance
, publicRoots
, joinTaintState
, runDataflow
@@ -46,11 +47,16 @@ import qualified Data.Set as Set
import Data.Text (Text)
-- | Taint state: maps registers to their publicness, plus stack slots.
+-- Also tracks provenance for upgrading Unknown to Public when provable.
data TaintState = TaintState
- { tsRegs :: !(Map Reg Taint)
+ { tsRegs :: !(Map Reg Taint)
-- ^ Register taints
- , tsStack :: !(IntMap Taint)
+ , tsStack :: !(IntMap Taint)
-- ^ Stack slot taints (keyed by SP offset)
+ , tsProv :: !(Map Reg Provenance)
+ -- ^ Register provenance
+ , tsStackProv :: !(IntMap Provenance)
+ -- ^ Stack slot provenance (keyed by SP offset)
} deriving (Eq, Show)
-- | GHC 9.10.3 AArch64 public root registers.
@@ -81,21 +87,29 @@ publicRoots =
-- | Empty taint state (no known taints).
emptyTaintState :: TaintState
emptyTaintState = TaintState
- { tsRegs = Map.empty
- , tsStack = IM.empty
+ { tsRegs = Map.empty
+ , tsStack = IM.empty
+ , tsProv = Map.empty
+ , tsStackProv = IM.empty
}
-- | Initial taint state with public roots marked.
initTaintState :: TaintState
initTaintState = TaintState
- { tsRegs = Map.fromList [(r, Public) | r <- publicRoots]
- , tsStack = IM.empty
+ { tsRegs = Map.fromList [(r, Public) | r <- publicRoots]
+ , tsStack = IM.empty
+ , tsProv = Map.fromList [(r, ProvPublic) | r <- publicRoots]
+ , tsStackProv = IM.empty
}
-- | Get the taint of a register.
getTaint :: Reg -> TaintState -> Taint
getTaint r st = Map.findWithDefault Unknown r (tsRegs st)
+-- | Get the provenance of a register.
+getProvenance :: Reg -> TaintState -> Provenance
+getProvenance r st = Map.findWithDefault ProvUnknown r (tsProv st)
+
-- | Analyze a single line, updating taint state.
analyzeLine :: Line -> TaintState -> TaintState
analyzeLine l st = case lineInstr l of
@@ -109,59 +123,143 @@ analyzeBlock lns st = foldl (flip analyzeLine) st lns
-- | Transfer function for taint analysis.
--
-- For each instruction, determine how it affects register taints.
+-- Also tracks provenance, used to upgrade Unknown bases and indices.
transfer :: Instr -> TaintState -> TaintState
transfer instr st = case instr of
- -- Move: destination gets source taint
- Mov dst op -> setTaint dst (operandTaint op st) st
- Movz dst _ _ -> setTaint dst Public st -- Immediate is public
+ -- Move: destination gets source taint and provenance
+ Mov dst op ->
+ setTaintProv dst (operandTaint op st) (operandProv op st) st
+ Movz dst _ _ ->
+ setTaintProv dst Public ProvPublic st -- Immediate is public
Movk _ _ _ -> st -- Keeps existing value, modifies bits
- Movn dst _ _ -> setTaint dst Public st -- Immediate is public
+ Movn dst _ _ ->
+ setTaintProv dst Public ProvPublic st -- Immediate is public
- -- Arithmetic: result is join of operand taints
+ -- Arithmetic: result is join of operand taints/provenances
-- Clear stack map if SP is modified (offsets become stale)
- Add dst r1 op -> updateWithSpCheck dst (join2 (getTaint r1 st) (operandTaint op st)) st
- Sub dst r1 op -> updateWithSpCheck dst (join2 (getTaint r1 st) (operandTaint op st)) st
- Adds dst r1 op -> updateWithSpCheck dst (join2 (getTaint r1 st) (operandTaint op st)) st
- Subs dst r1 op -> updateWithSpCheck dst (join2 (getTaint r1 st) (operandTaint op st)) st
- Adc dst r1 r2 -> setTaint dst (join2 (getTaint r1 st) (getTaint r2 st)) st
- Adcs dst r1 r2 -> setTaint dst (join2 (getTaint r1 st) (getTaint r2 st)) st
- Sbc dst r1 r2 -> setTaint dst (join2 (getTaint r1 st) (getTaint r2 st)) st
- Neg dst op -> setTaint dst (operandTaint op st) st
- Negs dst op -> setTaint dst (operandTaint op st) st
- Mul dst r1 r2 -> setTaint dst (join2 (getTaint r1 st) (getTaint r2 st)) st
- Mneg dst r1 r2 -> setTaint dst (join2 (getTaint r1 st) (getTaint r2 st)) st
- Madd dst r1 r2 r3 -> setTaint dst (join3 (getTaint r1 st) (getTaint r2 st)
- (getTaint r3 st)) st
- Msub dst r1 r2 r3 -> setTaint dst (join3 (getTaint r1 st) (getTaint r2 st)
- (getTaint r3 st)) st
- Umulh dst r1 r2 -> setTaint dst (join2 (getTaint r1 st) (getTaint r2 st)) st
- Smulh dst r1 r2 -> setTaint dst (join2 (getTaint r1 st) (getTaint r2 st)) st
- Udiv dst r1 r2 -> setTaint dst (join2 (getTaint r1 st) (getTaint r2 st)) st
- Sdiv dst r1 r2 -> setTaint dst (join2 (getTaint r1 st) (getTaint r2 st)) st
-
- -- Logical: result is join of operand taints
- And dst r1 op -> setTaint dst (join2 (getTaint r1 st) (operandTaint op st)) st
- Orr dst r1 op -> setTaint dst (join2 (getTaint r1 st) (operandTaint op st)) st
- Eor dst r1 op -> setTaint dst (join2 (getTaint r1 st) (operandTaint op st)) st
- Bic dst r1 op -> setTaint dst (join2 (getTaint r1 st) (operandTaint op st)) st
- Mvn dst op -> setTaint dst (operandTaint op st) st
+ Add dst r1 op ->
+ let t = join2 (getTaint r1 st) (operandTaint op st)
+ p = provJoin2 (getProvenance r1 st) (operandProv op st)
+ in updateWithSpCheckProv dst t p st
+ Sub dst r1 op ->
+ let t = join2 (getTaint r1 st) (operandTaint op st)
+ p = provJoin2 (getProvenance r1 st) (operandProv op st)
+ in updateWithSpCheckProv dst t p st
+ Adds dst r1 op ->
+ let t = join2 (getTaint r1 st) (operandTaint op st)
+ p = provJoin2 (getProvenance r1 st) (operandProv op st)
+ in updateWithSpCheckProv dst t p st
+ Subs dst r1 op ->
+ let t = join2 (getTaint r1 st) (operandTaint op st)
+ p = provJoin2 (getProvenance r1 st) (operandProv op st)
+ in updateWithSpCheckProv dst t p st
+ Adc dst r1 r2 ->
+ let t = join2 (getTaint r1 st) (getTaint r2 st)
+ p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
+ in setTaintProv dst t p st
+ Adcs dst r1 r2 ->
+ let t = join2 (getTaint r1 st) (getTaint r2 st)
+ p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
+ in setTaintProv dst t p st
+ Sbc dst r1 r2 ->
+ let t = join2 (getTaint r1 st) (getTaint r2 st)
+ p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
+ in setTaintProv dst t p st
+ Neg dst op ->
+ setTaintProv dst (operandTaint op st) (operandProv op st) st
+ Negs dst op ->
+ setTaintProv dst (operandTaint op st) (operandProv op st) st
+ Mul dst r1 r2 ->
+ let t = join2 (getTaint r1 st) (getTaint r2 st)
+ p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
+ in setTaintProv dst t p st
+ Mneg dst r1 r2 ->
+ let t = join2 (getTaint r1 st) (getTaint r2 st)
+ p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
+ in setTaintProv dst t p st
+ Madd dst r1 r2 r3 ->
+ let t = join3 (getTaint r1 st) (getTaint r2 st) (getTaint r3 st)
+ p = provJoin3 (getProvenance r1 st) (getProvenance r2 st)
+ (getProvenance r3 st)
+ in setTaintProv dst t p st
+ Msub dst r1 r2 r3 ->
+ let t = join3 (getTaint r1 st) (getTaint r2 st) (getTaint r3 st)
+ p = provJoin3 (getProvenance r1 st) (getProvenance r2 st)
+ (getProvenance r3 st)
+ in setTaintProv dst t p st
+ Umulh dst r1 r2 ->
+ let t = join2 (getTaint r1 st) (getTaint r2 st)
+ p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
+ in setTaintProv dst t p st
+ Smulh dst r1 r2 ->
+ let t = join2 (getTaint r1 st) (getTaint r2 st)
+ p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
+ in setTaintProv dst t p st
+ Udiv dst r1 r2 ->
+ let t = join2 (getTaint r1 st) (getTaint r2 st)
+ p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
+ in setTaintProv dst t p st
+ Sdiv dst r1 r2 ->
+ let t = join2 (getTaint r1 st) (getTaint r2 st)
+ p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
+ in setTaintProv dst t p st
+
+ -- Logical: result is join of operand taints/provenances
+ And dst r1 op ->
+ let t = join2 (getTaint r1 st) (operandTaint op st)
+ p = provJoin2 (getProvenance r1 st) (operandProv op st)
+ in setTaintProv dst t p st
+ Orr dst r1 op ->
+ let t = join2 (getTaint r1 st) (operandTaint op st)
+ p = provJoin2 (getProvenance r1 st) (operandProv op st)
+ in setTaintProv dst t p st
+ Eor dst r1 op ->
+ let t = join2 (getTaint r1 st) (operandTaint op st)
+ p = provJoin2 (getProvenance r1 st) (operandProv op st)
+ in setTaintProv dst t p st
+ Bic dst r1 op ->
+ let t = join2 (getTaint r1 st) (operandTaint op st)
+ p = provJoin2 (getProvenance r1 st) (operandProv op st)
+ in setTaintProv dst t p st
+ Mvn dst op ->
+ setTaintProv dst (operandTaint op st) (operandProv op st) st
Tst _ _ -> st -- No destination
- -- Shifts: result is join of operand taints
- Lsl dst r1 op -> setTaint dst (join2 (getTaint r1 st) (operandTaint op st)) st
- Lsr dst r1 op -> setTaint dst (join2 (getTaint r1 st) (operandTaint op st)) st
- Asr dst r1 op -> setTaint dst (join2 (getTaint r1 st) (operandTaint op st)) st
- Ror dst r1 op -> setTaint dst (join2 (getTaint r1 st) (operandTaint op st)) st
-
- -- Bit manipulation
- Ubfx dst r1 _ _ -> setTaint dst (getTaint r1 st) st
- Sbfx dst r1 _ _ -> setTaint dst (getTaint r1 st) st
- Bfi dst r1 _ _ -> setTaint dst (join2 (getTaint dst st) (getTaint r1 st)) st
- Extr dst r1 r2 _ -> setTaint dst (join2 (getTaint r1 st) (getTaint r2 st)) st
+ -- Shifts: result is join of operand taints/provenances
+ Lsl dst r1 op ->
+ let t = join2 (getTaint r1 st) (operandTaint op st)
+ p = provJoin2 (getProvenance r1 st) (operandProv op st)
+ in setTaintProv dst t p st
+ Lsr dst r1 op ->
+ let t = join2 (getTaint r1 st) (operandTaint op st)
+ p = provJoin2 (getProvenance r1 st) (operandProv op st)
+ in setTaintProv dst t p st
+ Asr dst r1 op ->
+ let t = join2 (getTaint r1 st) (operandTaint op st)
+ p = provJoin2 (getProvenance r1 st) (operandProv op st)
+ in setTaintProv dst t p st
+ Ror dst r1 op ->
+ let t = join2 (getTaint r1 st) (operandTaint op st)
+ p = provJoin2 (getProvenance r1 st) (operandProv op st)
+ in setTaintProv dst t p st
+
+ -- Bit manipulation: Ubfx/Sbfx copy source provenance; Bfi/Extr join inputs
+ Ubfx dst r1 _ _ ->
+ setTaintProv dst (getTaint r1 st) (getProvenance r1 st) st
+ Sbfx dst r1 _ _ ->
+ setTaintProv dst (getTaint r1 st) (getProvenance r1 st) st
+ Bfi dst r1 _ _ ->
+ let t = join2 (getTaint dst st) (getTaint r1 st)
+ p = provJoin2 (getProvenance dst st) (getProvenance r1 st)
+ in setTaintProv dst t p st
+ Extr dst r1 r2 _ ->
+ let t = join2 (getTaint r1 st) (getTaint r2 st)
+ p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
+ in setTaintProv dst t p st
-- Address generation: result is public (constant pool / PC-relative)
- Adr dst _ -> setTaint dst Public st
- Adrp dst _ -> setTaint dst Public st
+ Adr dst _ -> setTaintProv dst Public ProvPublic st
+ Adrp dst _ -> setTaintProv dst Public ProvPublic st
-- Loads: restore from stack slots if [sp, #imm], else Unknown
-- Exception: public roots stay public (GHC spills/restores them)
@@ -181,15 +279,29 @@ transfer instr st = case instr of
Stur src addr -> storeToStack src addr st
Stp src1 src2 addr -> storePairToStack src1 src2 addr st
- -- Conditionals: conservative join
+ -- Conditionals: conservative join with provenance
Cmp _ _ -> st
Cmn _ _ -> st
- Csel dst r1 r2 _ -> setTaint dst (join2 (getTaint r1 st) (getTaint r2 st)) st
- Csinc dst r1 r2 _ -> setTaint dst (join2 (getTaint r1 st) (getTaint r2 st)) st
- Csinv dst r1 r2 _ -> setTaint dst (join2 (getTaint r1 st) (getTaint r2 st)) st
- Csneg dst r1 r2 _ -> setTaint dst (join2 (getTaint r1 st) (getTaint r2 st)) st
- Cset dst _ -> setTaint dst Public st -- Condition flag derived
- Cinc dst r1 _ -> setTaint dst (getTaint r1 st) st
+ Csel dst r1 r2 _ ->
+ let t = join2 (getTaint r1 st) (getTaint r2 st)
+ p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
+ in setTaintProv dst t p st
+ Csinc dst r1 r2 _ ->
+ let t = join2 (getTaint r1 st) (getTaint r2 st)
+ p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
+ in setTaintProv dst t p st
+ Csinv dst r1 r2 _ ->
+ let t = join2 (getTaint r1 st) (getTaint r2 st)
+ p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
+ in setTaintProv dst t p st
+ Csneg dst r1 r2 _ ->
+ let t = join2 (getTaint r1 st) (getTaint r2 st)
+ p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
+ in setTaintProv dst t p st
+ Cset dst _ ->
+ setTaintProv dst Public ProvPublic st -- Condition flag derived
+ Cinc dst r1 _ ->
+ setTaintProv dst (getTaint r1 st) (getProvenance r1 st) st
-- Branches: no register change
B _ -> st
@@ -214,52 +326,68 @@ transfer instr st = case instr of
setTaint :: Reg -> Taint -> TaintState -> TaintState
setTaint r t st = st { tsRegs = Map.insert r t (tsRegs st) }
+-- | Set both taint and provenance for a register.
+setTaintProv :: Reg -> Taint -> Provenance -> TaintState -> TaintState
+setTaintProv r t p st = st
+ { tsRegs = Map.insert r t (tsRegs st)
+ , tsProv = Map.insert r p (tsProv st)
+ }
+
-- | Set taint for a loaded value, preserving public roots.
-- Public roots (SP, X19-X21, X28, etc.) stay public even when loaded
-- since GHC spills/restores them from the hardware stack.
setTaintLoad :: Reg -> TaintState -> TaintState
setTaintLoad dst st
- | isPublicRoot dst = setTaint dst Public st
- | otherwise = setTaint dst Unknown st
+ | isPublicRoot dst = setTaintProv dst Public ProvPublic st
+ | otherwise = setTaintProv dst Unknown ProvUnknown st
where
isPublicRoot r = r `elem` publicRoots
-- | Set taint for a loaded value from a known stack slot.
--- If we have tracked taint at this offset, use it; otherwise Unknown.
+-- If we have tracked taint/provenance at this offset, use it; else Unknown.
setTaintLoadStack :: Reg -> Int -> TaintState -> TaintState
setTaintLoadStack dst offset st
- | isPublicRoot dst = setTaint dst Public st
+ | isPublicRoot dst = setTaintProv dst Public ProvPublic st
| otherwise =
let taint = IM.findWithDefault Unknown offset (tsStack st)
- in setTaint dst taint st
+ prov = IM.findWithDefault ProvUnknown offset (tsStackProv st)
+ in setTaintProv dst taint prov st
where
isPublicRoot r = r `elem` publicRoots
--- | Store taint to a stack slot.
-setStackTaint :: Int -> Taint -> TaintState -> TaintState
-setStackTaint offset t st = st { tsStack = IM.insert offset t (tsStack st) }
+-- | Store both taint and provenance to a stack slot.
+setStackTaintProv :: Int -> Taint -> Provenance -> TaintState -> TaintState
+setStackTaintProv offset t p st = st
+ { tsStack = IM.insert offset t (tsStack st)
+ , tsStackProv = IM.insert offset p (tsStackProv st)
+ }
--- | Clear all stack slot taints (when SP is modified).
+-- | Clear all stack slot taints and provenance (when SP is modified).
clearStackMap :: TaintState -> TaintState
-clearStackMap st = st { tsStack = IM.empty }
+clearStackMap st = st { tsStack = IM.empty, tsStackProv = IM.empty }
--- | Update register taint, clearing stack map if dst is SP.
-updateWithSpCheck :: Reg -> Taint -> TaintState -> TaintState
-updateWithSpCheck dst t st
- | dst == SP = clearStackMap (setTaint dst t st)
- | otherwise = setTaint dst t st
+-- | Update register taint and provenance, clearing stack map if dst is SP.
+updateWithSpCheckProv :: Reg -> Taint -> Provenance -> TaintState -> TaintState
+updateWithSpCheckProv dst t p st
+ | dst == SP = clearStackMap (setTaintProv dst t p st)
+ | otherwise = setTaintProv dst t p st
-- | Track a store to stack if address is [sp, #imm].
-- Pre/post-indexed addressing modifies SP, invalidating the stack map.
storeToStack :: Reg -> AddrMode -> TaintState -> TaintState
storeToStack src addr st = case addr of
- BaseImm SP imm -> setStackTaint (fromInteger imm) (getTaint src st) st
+ BaseImm SP imm ->
+ let off = fromInteger imm
+ in setStackTaintProv off (getTaint src st) (getProvenance src st) st
PreIndex SP imm ->
-- Store at sp+imm, then SP changes
- clearStackMap (setStackTaint (fromInteger imm) (getTaint src st) st)
+ let off = fromInteger imm
+ st' = setStackTaintProv off (getTaint src st) (getProvenance src st) st
+ in clearStackMap st'
PostIndex SP _ ->
-- Store at sp, then SP changes (offset 0)
- clearStackMap (setStackTaint 0 (getTaint src st) st)
+ let st' = setStackTaintProv 0 (getTaint src st) (getProvenance src st) st
+ in clearStackMap st'
_ -> st -- Non-SP stores don't affect stack tracking
-- | Track a store pair to stack if address is [sp, #imm].
@@ -269,23 +397,27 @@ storePairToStack :: Reg -> Reg -> AddrMode -> TaintState -> TaintState
storePairToStack src1 src2 addr st = case addr of
BaseImm SP imm ->
let off = fromInteger imm
- in setStackTaint off (getTaint src1 st)
- (setStackTaint (off + 8) (getTaint src2 st) st)
+ st' = setStackTaintProv off (getTaint src1 st) (getProvenance src1 st)
+ (setStackTaintProv (off + 8) (getTaint src2 st)
+ (getProvenance src2 st) st)
+ in st'
PreIndex SP imm ->
-- Store at sp+imm and sp+imm+8, then SP changes
let off = fromInteger imm
- st' = setStackTaint off (getTaint src1 st)
- (setStackTaint (off + 8) (getTaint src2 st) st)
+ st' = setStackTaintProv off (getTaint src1 st) (getProvenance src1 st)
+ (setStackTaintProv (off + 8) (getTaint src2 st)
+ (getProvenance src2 st) st)
in clearStackMap st'
PostIndex SP _ ->
-- Store at sp and sp+8, then SP changes
- let st' = setStackTaint 0 (getTaint src1 st)
- (setStackTaint 8 (getTaint src2 st) st)
+ let st' = setStackTaintProv 0 (getTaint src1 st) (getProvenance src1 st)
+ (setStackTaintProv 8 (getTaint src2 st)
+ (getProvenance src2 st) st)
in clearStackMap st'
_ -> st
-- | Load from memory, handling special cases:
--- - [sp, #imm]: restore tracked stack slot taint
+-- - [sp, #imm]: restore tracked stack slot taint and provenance
-- - [r, symbol@GOTPAGEOFF]: GOT entry load, result is Public (address)
-- - Other: Unknown unless dst is a public root
-- Post-indexed addressing modifies SP, invalidating the stack map.
@@ -295,8 +427,10 @@ loadFromStack dst addr st = case addr of
PostIndex SP imm ->
-- Load first, then clear (SP changes after access)
clearStackMap (setTaintLoadStack dst (fromInteger imm) st)
- BaseSymbol _ _ -> setTaint dst Public st -- GOT/PLT entry -> address
- Literal _ -> setTaint dst Public st -- PC-relative literal -> address
+ BaseSymbol _ _ ->
+ setTaintProv dst Public ProvPublic st -- GOT/PLT entry -> address
+ Literal _ ->
+ setTaintProv dst Public ProvPublic st -- PC-relative literal -> address
_ -> setTaintLoad dst st -- Other loads use default behavior
-- | Load pair from stack if address is [sp, #imm].
@@ -324,6 +458,16 @@ operandTaint op st = case op of
OpLabel _ -> Public
OpAddr addr -> addrBaseTaint addr st
+-- | Get provenance of an operand.
+operandProv :: Operand -> TaintState -> Provenance
+operandProv op st = case op of
+ OpReg r -> getProvenance r st
+ OpImm _ -> ProvPublic
+ OpShiftedReg r _ -> getProvenance r st
+ OpExtendedReg r _ -> getProvenance r st
+ OpLabel _ -> ProvPublic
+ OpAddr _ -> ProvUnknown -- Address operand provenance is complex
+
-- | Get taint of address base register.
addrBaseTaint :: AddrMode -> TaintState -> Taint
addrBaseTaint addr st = case addr of
@@ -344,11 +488,22 @@ join2 = joinTaint
join3 :: Taint -> Taint -> Taint -> Taint
join3 a b c = joinTaint a (joinTaint b c)
+-- | Join two provenances.
+provJoin2 :: Provenance -> Provenance -> Provenance
+provJoin2 = joinProvenance
+
+-- | Join three provenances.
+provJoin3 :: Provenance -> Provenance -> Provenance -> Provenance
+provJoin3 a b c = joinProvenance a (joinProvenance b c)
+
-- | Invalidate caller-saved registers after a call.
-- Per AArch64 ABI, x0-x17 are caller-saved.
+-- Resets both taint and provenance of those registers to unknown.
invalidateCallerSaved :: TaintState -> TaintState
-invalidateCallerSaved st =
- st { tsRegs = foldr (\r -> Map.insert r Unknown) (tsRegs st) callerSaved }
+invalidateCallerSaved st = st
+ { tsRegs = foldr (\r -> Map.insert r Unknown) (tsRegs st) callerSaved
+ , tsProv = foldr (\r -> Map.insert r ProvUnknown) (tsProv st) callerSaved
+ }
where
callerSaved =
[ X0, X1, X2, X3, X4, X5, X6, X7
@@ -358,11 +513,13 @@ invalidateCallerSaved st =
-- | Join two taint states (element-wise join).
-- For registers in both, take the join. For registers in only one, keep.
--- Stack slots are also joined element-wise.
+-- Stack slots and provenance are also joined element-wise.
joinTaintState :: TaintState -> TaintState -> TaintState
joinTaintState a b = TaintState
- { tsRegs = Map.unionWith joinTaint (tsRegs a) (tsRegs b)
- , tsStack = IM.unionWith joinTaint (tsStack a) (tsStack b)
+ { tsRegs = Map.unionWith joinTaint (tsRegs a) (tsRegs b)
+ , tsStack = IM.unionWith joinTaint (tsStack a) (tsStack b)
+ , tsProv = Map.unionWith joinProvenance (tsProv a) (tsProv b)
+ , tsStackProv = IM.unionWith joinProvenance (tsStackProv a) (tsStackProv b)
}
-- | Run forward dataflow analysis over a CFG.
@@ -419,8 +576,10 @@ newtype FuncSummary = FuncSummary { summaryState :: TaintState }
-- | Initial conservative summary: all caller-saved are Unknown.
initSummary :: FuncSummary
initSummary = FuncSummary $ TaintState
- { tsRegs = Map.fromList [ (r, Unknown) | r <- callerSavedRegs ]
- , tsStack = IM.empty
+ { tsRegs = Map.fromList [ (r, Unknown) | r <- callerSavedRegs ]
+ , tsStack = IM.empty
+ , tsProv = Map.fromList [ (r, ProvUnknown) | r <- callerSavedRegs ]
+ , tsStackProv = IM.empty
}
-- | Caller-saved registers per AArch64 ABI.
@@ -437,13 +596,17 @@ joinSummary (FuncSummary a) (FuncSummary b) =
FuncSummary (joinTaintState a b)
-- | Apply a function summary at a call site.
--- Replaces caller-saved register taints with the summary's values.
+-- Replaces caller-saved register taints and provenance with summary values.
applySummary :: FuncSummary -> TaintState -> TaintState
-applySummary (FuncSummary summ) st =
- st { tsRegs = foldr applyReg (tsRegs st) callerSavedRegs }
+applySummary (FuncSummary summ) st = st
+ { tsRegs = foldr applyReg (tsRegs st) callerSavedRegs
+ , tsProv = foldr applyProv (tsProv st) callerSavedRegs
+ }
where
summRegs = tsRegs summ
+ summProv = tsProv summ
applyReg r s = Map.insert r (Map.findWithDefault Unknown r summRegs) s
+ applyProv r s = Map.insert r (Map.findWithDefault ProvUnknown r summProv) s
-- | Run dataflow analysis for a single function (subset of blocks).
-- Returns the OUT state at return points (joined).
diff --git a/lib/Audit/AArch64/Types.hs b/lib/Audit/AArch64/Types.hs
@@ -29,6 +29,10 @@ module Audit.AArch64.Types (
, Taint(..)
, joinTaint
+ -- * Provenance
+ , Provenance(..)
+ , joinProvenance
+
-- * Analysis results
, Violation(..)
, ViolationReason(..)
@@ -239,6 +243,20 @@ joinTaint Secret _ = Secret
joinTaint _ Secret = Secret
joinTaint _ _ = Unknown -- Public+Unknown or Unknown+Unknown
+-- | Provenance: tracks whether a value derives from known-public sources.
+-- Used to upgrade Unknown taint to Public when provenance can prove safety.
+data Provenance
+ = ProvPublic -- ^ Derived from public root or constant
+ | ProvUnknown -- ^ Unknown origin (e.g., loaded from memory)
+ deriving (Eq, Ord, Show, Generic)
+
+instance ToJSON Provenance
+
+-- | Join provenance: only ProvPublic if both are ProvPublic.
+joinProvenance :: Provenance -> Provenance -> Provenance
+joinProvenance ProvPublic ProvPublic = ProvPublic
+joinProvenance _ _ = ProvUnknown
+
-- | Reason for a violation.
data ViolationReason
= SecretBase !Reg -- ^ Base register is secret
diff --git a/test/Main.hs b/test/Main.hs
@@ -16,6 +16,7 @@ main = defaultMain $ testGroup "ppad-auditor" [
, taintTests
, auditTests
, interprocTests
+ , provenanceTests
]
-- Parser tests
@@ -374,3 +375,64 @@ interprocTests = testGroup "InterProc" [
Left e -> assertFailure $ "parse failed: " ++ show e
Right ar -> assertEqual "1 violation" 1 (length (arViolations ar))
]
+
+-- Provenance tests
+
+provenanceTests :: TestTree
+provenanceTests = testGroup "Provenance" [
+ testCase "good: mov from public root preserves provenance" $ do
+ -- mov from public root should preserve public provenance
+ let src = T.unlines
+ [ "foo:"
+ , " mov x8, x20" -- x8 = copy of public root
+ , " ldr x0, [x8]" -- x8 should be public via provenance
+ , " ret"
+ ]
+ case audit "test" src of
+ Left e -> assertFailure $ "parse failed: " ++ show e
+ Right ar -> assertEqual "no violations" 0 (length (arViolations ar))
+
+ , testCase "good: add #imm preserves provenance" $ do
+ -- add with immediate preserves provenance from source
+ let src = T.unlines
+ [ "foo:"
+ , " add x8, x20, #16" -- x8 = public root + offset
+ , " ldr x0, [x8]" -- x8 should be public via provenance
+ , " ret"
+ ]
+ case audit "test" src of
+ Left e -> assertFailure $ "parse failed: " ++ show e
+ Right ar -> assertEqual "no violations" 0 (length (arViolations ar))
+
+ , testCase "bad: load clears provenance" $ do
+ -- Loading from memory should clear provenance
+ let src = T.unlines
+ [ "foo:"
+ , " mov x8, x20" -- x8 = public
+ , " ldr x8, [x8]" -- x8 = unknown (loaded from memory)
+ , " ldr x0, [x8]" -- x8 as base should be violation
+ , " ret"
+ ]
+ case audit "test" src of
+ Left e -> assertFailure $ "parse failed: " ++ show e
+ Right ar -> do
+ assertEqual "one violation" 1 (length (arViolations ar))
+ case arViolations ar of
+ [v] -> case vReason v of
+ UnknownBase X8 -> pure ()
+ other -> assertFailure $ "wrong reason: " ++ show other
+ _ -> assertFailure "expected one violation"
+
+ , testCase "good: orr with xzr preserves provenance" $ do
+ -- orr with zero register should preserve provenance
+ let src = T.unlines
+ [ "foo:"
+ , " mov x8, x20" -- x8 = public
+ , " orr x9, x8, xzr" -- x9 = x8 | 0 = copy with provenance
+ , " ldr x0, [x9]" -- x9 should be public
+ , " ret"
+ ]
+ case audit "test" src of
+ Left e -> assertFailure $ "parse failed: " ++ show e
+ Right ar -> assertEqual "no violations" 0 (length (arViolations ar))
+ ]