auditor

An aarch64 constant-time memory access auditing tool.
git clone git://git.ppad.tech/auditor.git
Log | Files | Refs | README | LICENSE

commit 278bc2d33847f190c2a5b10c1d0955db70428eb5
parent 208019cd069e91b783af8737490251946abe1e48
Author: Jared Tobin <jared@jtobin.io>
Date:   Wed, 11 Feb 2026 20:36:59 +0400

fix: set KindScalar for scalar ops and clear kinds at call boundaries

Addresses reviewer comments on IMPL13:

1. Scalar operations (Adc, Neg, Mul, Orr, Lsl, Ubfx, Csel variants, etc.)
   now use setTaintProvKind with KindScalar instead of setTaintProv,
   preventing stale KindPtr from enabling incorrect provenance upgrades.

2. invalidateCallerSaved now also clears tsKind to KindUnknown for
   caller-saved registers (x0-x17).

3. applySummary now also applies kinds from function summaries, not just
   taint and provenance.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Diffstat:
Mlib/Audit/AArch64/Taint.hs | 77+++++++++++++++++++++++++++++++++++++++++------------------------------------
1 file changed, 41 insertions(+), 36 deletions(-)

diff --git a/lib/Audit/AArch64/Taint.hs b/lib/Audit/AArch64/Taint.hs @@ -208,53 +208,53 @@ transfer instr st = case instr of Adc dst r1 r2 -> let t = join2 (getTaint r1 st) (getTaint r2 st) p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) - in setTaintProv dst t p st + in setTaintProvKind dst t p KindScalar st Adcs dst r1 r2 -> let t = join2 (getTaint r1 st) (getTaint r2 st) p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) - in setTaintProv dst t p st + in setTaintProvKind dst t p KindScalar st Sbc dst r1 r2 -> let t = join2 (getTaint r1 st) (getTaint r2 st) p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) - in setTaintProv dst t p st + in setTaintProvKind dst t p KindScalar st Neg dst op -> - setTaintProv dst (operandTaint op st) (operandProv op st) st + setTaintProvKind dst (operandTaint op st) (operandProv op st) KindScalar st Negs dst op -> - setTaintProv dst (operandTaint op st) (operandProv op st) st + setTaintProvKind dst (operandTaint op st) (operandProv op st) KindScalar st Mul dst r1 r2 -> let t = join2 (getTaint r1 st) (getTaint r2 st) p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) - in setTaintProv dst t p st + in setTaintProvKind dst t p KindScalar st Mneg dst r1 r2 -> let t = join2 (getTaint r1 st) (getTaint r2 st) p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) - in setTaintProv dst t p st + in setTaintProvKind dst t p KindScalar st Madd dst r1 r2 r3 -> let t = join3 (getTaint r1 st) (getTaint r2 st) (getTaint r3 st) p = provJoin3 (getProvenance r1 st) (getProvenance r2 st) (getProvenance r3 st) - in setTaintProv dst t p st + in setTaintProvKind dst t p KindScalar st Msub dst r1 r2 r3 -> let t = join3 (getTaint r1 st) (getTaint r2 st) (getTaint r3 st) p = provJoin3 (getProvenance r1 st) (getProvenance r2 st) (getProvenance r3 st) - in setTaintProv dst t p st + in setTaintProvKind dst t p KindScalar st Umulh dst r1 r2 -> let t = join2 (getTaint r1 st) (getTaint r2 st) p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) - in setTaintProv dst t p st + in setTaintProvKind dst t p KindScalar st Smulh dst r1 r2 -> let t = join2 (getTaint r1 st) (getTaint r2 st) p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) - in setTaintProv dst t p st + in setTaintProvKind dst t p KindScalar st Udiv dst r1 r2 -> let t = join2 (getTaint r1 st) (getTaint r2 st) p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) - in setTaintProv dst t p st + in setTaintProvKind dst t p KindScalar st Sdiv dst r1 r2 -> let t = join2 (getTaint r1 st) (getTaint r2 st) p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) - in setTaintProv dst t p st + in setTaintProvKind dst t p KindScalar st -- Logical: result is join of operand taints/provenances -- Special case: pointer untagging preserves ProvPublic and KindPtr @@ -271,50 +271,50 @@ transfer instr st = case instr of Orr dst r1 op -> let t = join2 (getTaint r1 st) (operandTaint op st) p = provJoin2 (getProvenance r1 st) (operandProv op st) - in setTaintProv dst t p st + in setTaintProvKind dst t p KindScalar st Eor dst r1 op -> let t = join2 (getTaint r1 st) (operandTaint op st) p = provJoin2 (getProvenance r1 st) (operandProv op st) - in setTaintProv dst t p st + in setTaintProvKind dst t p KindScalar st Bic dst r1 op -> let t = join2 (getTaint r1 st) (operandTaint op st) p = provJoin2 (getProvenance r1 st) (operandProv op st) - in setTaintProv dst t p st + in setTaintProvKind dst t p KindScalar st Mvn dst op -> - setTaintProv dst (operandTaint op st) (operandProv op st) st + setTaintProvKind dst (operandTaint op st) (operandProv op st) KindScalar st Tst _ _ -> st -- No destination - -- Shifts: result is join of operand taints/provenances + -- Shifts: result is scalar Lsl dst r1 op -> let t = join2 (getTaint r1 st) (operandTaint op st) p = provJoin2 (getProvenance r1 st) (operandProv op st) - in setTaintProv dst t p st + in setTaintProvKind dst t p KindScalar st Lsr dst r1 op -> let t = join2 (getTaint r1 st) (operandTaint op st) p = provJoin2 (getProvenance r1 st) (operandProv op st) - in setTaintProv dst t p st + in setTaintProvKind dst t p KindScalar st Asr dst r1 op -> let t = join2 (getTaint r1 st) (operandTaint op st) p = provJoin2 (getProvenance r1 st) (operandProv op st) - in setTaintProv dst t p st + in setTaintProvKind dst t p KindScalar st Ror dst r1 op -> let t = join2 (getTaint r1 st) (operandTaint op st) p = provJoin2 (getProvenance r1 st) (operandProv op st) - in setTaintProv dst t p st + in setTaintProvKind dst t p KindScalar st - -- Bit manipulation: preserves provenance from source + -- Bit manipulation: result is scalar Ubfx dst r1 _ _ -> - setTaintProv dst (getTaint r1 st) (getProvenance r1 st) st + setTaintProvKind dst (getTaint r1 st) (getProvenance r1 st) KindScalar st Sbfx dst r1 _ _ -> - setTaintProv dst (getTaint r1 st) (getProvenance r1 st) st + setTaintProvKind dst (getTaint r1 st) (getProvenance r1 st) KindScalar st Bfi dst r1 _ _ -> let t = join2 (getTaint dst st) (getTaint r1 st) p = provJoin2 (getProvenance dst st) (getProvenance r1 st) - in setTaintProv dst t p st + in setTaintProvKind dst t p KindScalar st Extr dst r1 r2 _ -> let t = join2 (getTaint r1 st) (getTaint r2 st) p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) - in setTaintProv dst t p st + in setTaintProvKind dst t p KindScalar st -- Address generation: result is public pointer (constant pool / PC-relative) Adr dst _ -> setTaintProvKind dst Public ProvPublic KindPtr st @@ -338,29 +338,30 @@ transfer instr st = case instr of Stur src addr -> storeToStack src addr st Stp src1 src2 addr -> storePairToStack src1 src2 addr st - -- Conditionals: conservative join with provenance + -- Conditionals: result is scalar (conservative) Cmp _ _ -> st Cmn _ _ -> st Csel dst r1 r2 _ -> let t = join2 (getTaint r1 st) (getTaint r2 st) p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) - in setTaintProv dst t p st + k = joinKind (getKind r1 st) (getKind r2 st) + in setTaintProvKind dst t p k st Csinc dst r1 r2 _ -> let t = join2 (getTaint r1 st) (getTaint r2 st) p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) - in setTaintProv dst t p st + in setTaintProvKind dst t p KindScalar st Csinv dst r1 r2 _ -> let t = join2 (getTaint r1 st) (getTaint r2 st) p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) - in setTaintProv dst t p st + in setTaintProvKind dst t p KindScalar st Csneg dst r1 r2 _ -> let t = join2 (getTaint r1 st) (getTaint r2 st) p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) - in setTaintProv dst t p st + in setTaintProvKind dst t p KindScalar st Cset dst _ -> - setTaintProv dst Public ProvPublic st -- Condition flag derived + setTaintProvKind dst Public ProvPublic KindScalar st -- Condition flag Cinc dst r1 _ -> - setTaintProv dst (getTaint r1 st) (getProvenance r1 st) st + setTaintProvKind dst (getTaint r1 st) (getProvenance r1 st) KindScalar st -- Branches: no register change B _ -> st @@ -598,11 +599,12 @@ provJoin3 a b c = joinProvenance a (joinProvenance b c) -- | Invalidate caller-saved registers after a call. -- Per AArch64 ABI, x0-x17 are caller-saved. --- Clears both taint and provenance. +-- Clears taint, provenance, and kind. invalidateCallerSaved :: TaintState -> TaintState invalidateCallerSaved st = st { tsRegs = foldr (\r -> Map.insert r Unknown) (tsRegs st) callerSaved , tsProv = foldr (\r -> Map.insert r ProvUnknown) (tsProv st) callerSaved + , tsKind = foldr (\r -> Map.insert r KindUnknown) (tsKind st) callerSaved } where callerSaved = @@ -697,17 +699,20 @@ joinSummary (FuncSummary a) (FuncSummary b) = FuncSummary (joinTaintState a b) -- | Apply a function summary at a call site. --- Replaces caller-saved register taints and provenance with summary values. +-- Replaces caller-saved register taints, provenance, and kinds with summary. applySummary :: FuncSummary -> TaintState -> TaintState applySummary (FuncSummary summ) st = st { tsRegs = foldr applyReg (tsRegs st) callerSavedRegs , tsProv = foldr applyProv (tsProv st) callerSavedRegs + , tsKind = foldr applyKind (tsKind st) callerSavedRegs } where summRegs = tsRegs summ summProv = tsProv summ + summKind = tsKind summ applyReg r s = Map.insert r (Map.findWithDefault Unknown r summRegs) s applyProv r s = Map.insert r (Map.findWithDefault ProvUnknown r summProv) s + applyKind r s = Map.insert r (Map.findWithDefault KindUnknown r summKind) s -- | Run dataflow analysis for a single function (subset of blocks). -- Returns the OUT state at return points (joined).