commit 278bc2d33847f190c2a5b10c1d0955db70428eb5
parent 208019cd069e91b783af8737490251946abe1e48
Author: Jared Tobin <jared@jtobin.io>
Date: Wed, 11 Feb 2026 20:36:59 +0400
fix: set KindScalar for scalar ops and clear kinds at call boundaries
Addresses reviewer comments on IMPL13:
1. Scalar operations (Adc, Neg, Mul, Orr, Lsl, Ubfx, Csel variants, etc.)
now use setTaintProvKind with KindScalar instead of setTaintProv,
preventing stale KindPtr from enabling incorrect provenance upgrades.
2. invalidateCallerSaved now also clears tsKind to KindUnknown for
caller-saved registers (x0-x17).
3. applySummary now also applies kinds from function summaries, not just
taint and provenance.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat:
1 file changed, 41 insertions(+), 36 deletions(-)
diff --git a/lib/Audit/AArch64/Taint.hs b/lib/Audit/AArch64/Taint.hs
@@ -208,53 +208,53 @@ transfer instr st = case instr of
Adc dst r1 r2 ->
let t = join2 (getTaint r1 st) (getTaint r2 st)
p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
- in setTaintProv dst t p st
+ in setTaintProvKind dst t p KindScalar st
Adcs dst r1 r2 ->
let t = join2 (getTaint r1 st) (getTaint r2 st)
p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
- in setTaintProv dst t p st
+ in setTaintProvKind dst t p KindScalar st
Sbc dst r1 r2 ->
let t = join2 (getTaint r1 st) (getTaint r2 st)
p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
- in setTaintProv dst t p st
+ in setTaintProvKind dst t p KindScalar st
Neg dst op ->
- setTaintProv dst (operandTaint op st) (operandProv op st) st
+ setTaintProvKind dst (operandTaint op st) (operandProv op st) KindScalar st
Negs dst op ->
- setTaintProv dst (operandTaint op st) (operandProv op st) st
+ setTaintProvKind dst (operandTaint op st) (operandProv op st) KindScalar st
Mul dst r1 r2 ->
let t = join2 (getTaint r1 st) (getTaint r2 st)
p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
- in setTaintProv dst t p st
+ in setTaintProvKind dst t p KindScalar st
Mneg dst r1 r2 ->
let t = join2 (getTaint r1 st) (getTaint r2 st)
p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
- in setTaintProv dst t p st
+ in setTaintProvKind dst t p KindScalar st
Madd dst r1 r2 r3 ->
let t = join3 (getTaint r1 st) (getTaint r2 st) (getTaint r3 st)
p = provJoin3 (getProvenance r1 st) (getProvenance r2 st)
(getProvenance r3 st)
- in setTaintProv dst t p st
+ in setTaintProvKind dst t p KindScalar st
Msub dst r1 r2 r3 ->
let t = join3 (getTaint r1 st) (getTaint r2 st) (getTaint r3 st)
p = provJoin3 (getProvenance r1 st) (getProvenance r2 st)
(getProvenance r3 st)
- in setTaintProv dst t p st
+ in setTaintProvKind dst t p KindScalar st
Umulh dst r1 r2 ->
let t = join2 (getTaint r1 st) (getTaint r2 st)
p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
- in setTaintProv dst t p st
+ in setTaintProvKind dst t p KindScalar st
Smulh dst r1 r2 ->
let t = join2 (getTaint r1 st) (getTaint r2 st)
p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
- in setTaintProv dst t p st
+ in setTaintProvKind dst t p KindScalar st
Udiv dst r1 r2 ->
let t = join2 (getTaint r1 st) (getTaint r2 st)
p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
- in setTaintProv dst t p st
+ in setTaintProvKind dst t p KindScalar st
Sdiv dst r1 r2 ->
let t = join2 (getTaint r1 st) (getTaint r2 st)
p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
- in setTaintProv dst t p st
+ in setTaintProvKind dst t p KindScalar st
-- Logical: result is join of operand taints/provenances
-- Special case: pointer untagging preserves ProvPublic and KindPtr
@@ -271,50 +271,50 @@ transfer instr st = case instr of
Orr dst r1 op ->
let t = join2 (getTaint r1 st) (operandTaint op st)
p = provJoin2 (getProvenance r1 st) (operandProv op st)
- in setTaintProv dst t p st
+ in setTaintProvKind dst t p KindScalar st
Eor dst r1 op ->
let t = join2 (getTaint r1 st) (operandTaint op st)
p = provJoin2 (getProvenance r1 st) (operandProv op st)
- in setTaintProv dst t p st
+ in setTaintProvKind dst t p KindScalar st
Bic dst r1 op ->
let t = join2 (getTaint r1 st) (operandTaint op st)
p = provJoin2 (getProvenance r1 st) (operandProv op st)
- in setTaintProv dst t p st
+ in setTaintProvKind dst t p KindScalar st
Mvn dst op ->
- setTaintProv dst (operandTaint op st) (operandProv op st) st
+ setTaintProvKind dst (operandTaint op st) (operandProv op st) KindScalar st
Tst _ _ -> st -- No destination
- -- Shifts: result is join of operand taints/provenances
+ -- Shifts: result is scalar
Lsl dst r1 op ->
let t = join2 (getTaint r1 st) (operandTaint op st)
p = provJoin2 (getProvenance r1 st) (operandProv op st)
- in setTaintProv dst t p st
+ in setTaintProvKind dst t p KindScalar st
Lsr dst r1 op ->
let t = join2 (getTaint r1 st) (operandTaint op st)
p = provJoin2 (getProvenance r1 st) (operandProv op st)
- in setTaintProv dst t p st
+ in setTaintProvKind dst t p KindScalar st
Asr dst r1 op ->
let t = join2 (getTaint r1 st) (operandTaint op st)
p = provJoin2 (getProvenance r1 st) (operandProv op st)
- in setTaintProv dst t p st
+ in setTaintProvKind dst t p KindScalar st
Ror dst r1 op ->
let t = join2 (getTaint r1 st) (operandTaint op st)
p = provJoin2 (getProvenance r1 st) (operandProv op st)
- in setTaintProv dst t p st
+ in setTaintProvKind dst t p KindScalar st
- -- Bit manipulation: preserves provenance from source
+ -- Bit manipulation: result is scalar
Ubfx dst r1 _ _ ->
- setTaintProv dst (getTaint r1 st) (getProvenance r1 st) st
+ setTaintProvKind dst (getTaint r1 st) (getProvenance r1 st) KindScalar st
Sbfx dst r1 _ _ ->
- setTaintProv dst (getTaint r1 st) (getProvenance r1 st) st
+ setTaintProvKind dst (getTaint r1 st) (getProvenance r1 st) KindScalar st
Bfi dst r1 _ _ ->
let t = join2 (getTaint dst st) (getTaint r1 st)
p = provJoin2 (getProvenance dst st) (getProvenance r1 st)
- in setTaintProv dst t p st
+ in setTaintProvKind dst t p KindScalar st
Extr dst r1 r2 _ ->
let t = join2 (getTaint r1 st) (getTaint r2 st)
p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
- in setTaintProv dst t p st
+ in setTaintProvKind dst t p KindScalar st
-- Address generation: result is public pointer (constant pool / PC-relative)
Adr dst _ -> setTaintProvKind dst Public ProvPublic KindPtr st
@@ -338,29 +338,30 @@ transfer instr st = case instr of
Stur src addr -> storeToStack src addr st
Stp src1 src2 addr -> storePairToStack src1 src2 addr st
- -- Conditionals: conservative join with provenance
+ -- Conditionals: result is scalar (conservative)
Cmp _ _ -> st
Cmn _ _ -> st
Csel dst r1 r2 _ ->
let t = join2 (getTaint r1 st) (getTaint r2 st)
p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
- in setTaintProv dst t p st
+ k = joinKind (getKind r1 st) (getKind r2 st)
+ in setTaintProvKind dst t p k st
Csinc dst r1 r2 _ ->
let t = join2 (getTaint r1 st) (getTaint r2 st)
p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
- in setTaintProv dst t p st
+ in setTaintProvKind dst t p KindScalar st
Csinv dst r1 r2 _ ->
let t = join2 (getTaint r1 st) (getTaint r2 st)
p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
- in setTaintProv dst t p st
+ in setTaintProvKind dst t p KindScalar st
Csneg dst r1 r2 _ ->
let t = join2 (getTaint r1 st) (getTaint r2 st)
p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
- in setTaintProv dst t p st
+ in setTaintProvKind dst t p KindScalar st
Cset dst _ ->
- setTaintProv dst Public ProvPublic st -- Condition flag derived
+ setTaintProvKind dst Public ProvPublic KindScalar st -- Condition flag
Cinc dst r1 _ ->
- setTaintProv dst (getTaint r1 st) (getProvenance r1 st) st
+ setTaintProvKind dst (getTaint r1 st) (getProvenance r1 st) KindScalar st
-- Branches: no register change
B _ -> st
@@ -598,11 +599,12 @@ provJoin3 a b c = joinProvenance a (joinProvenance b c)
-- | Invalidate caller-saved registers after a call.
-- Per AArch64 ABI, x0-x17 are caller-saved.
--- Clears both taint and provenance.
+-- Clears taint, provenance, and kind.
invalidateCallerSaved :: TaintState -> TaintState
invalidateCallerSaved st = st
{ tsRegs = foldr (\r -> Map.insert r Unknown) (tsRegs st) callerSaved
, tsProv = foldr (\r -> Map.insert r ProvUnknown) (tsProv st) callerSaved
+ , tsKind = foldr (\r -> Map.insert r KindUnknown) (tsKind st) callerSaved
}
where
callerSaved =
@@ -697,17 +699,20 @@ joinSummary (FuncSummary a) (FuncSummary b) =
FuncSummary (joinTaintState a b)
-- | Apply a function summary at a call site.
--- Replaces caller-saved register taints and provenance with summary values.
+-- Replaces caller-saved register taints, provenance, and kinds with summary.
applySummary :: FuncSummary -> TaintState -> TaintState
applySummary (FuncSummary summ) st = st
{ tsRegs = foldr applyReg (tsRegs st) callerSavedRegs
, tsProv = foldr applyProv (tsProv st) callerSavedRegs
+ , tsKind = foldr applyKind (tsKind st) callerSavedRegs
}
where
summRegs = tsRegs summ
summProv = tsProv summ
+ summKind = tsKind summ
applyReg r s = Map.insert r (Map.findWithDefault Unknown r summRegs) s
applyProv r s = Map.insert r (Map.findWithDefault ProvUnknown r summProv) s
+ applyKind r s = Map.insert r (Map.findWithDefault KindUnknown r summKind) s
-- | Run dataflow analysis for a single function (subset of blocks).
-- Returns the OUT state at return points (joined).