commit 208019cd069e91b783af8737490251946abe1e48
parent 08fed64dea0812afa3566a21dfdca511d3dac660
Author: Jared Tobin <jared@jtobin.io>
Date: Wed, 11 Feb 2026 20:08:57 +0400
feat: add register kind tracking to prevent provenance laundering (IMPL13)
Introduces a register kind lattice (Ptr/Scalar/Unknown) and restricts
provenance-based upgrades to pointer kinds only. This prevents secret
scalars from being "laundered" through pointer arithmetic to become
valid base registers.
Changes:
- Add RegKind type and joinKind function to Types.hs
- Extend TaintState with tsKind map for kind tracking
- Propagate kinds in transfer function (adr/adrp → Ptr, loads → Scalar)
- Only allow provenance upgrade in checkBase if kind is KindPtr
- Never upgrade index registers via provenance in checkIndex
- Add test for scalar provenance laundering detection
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat:
6 files changed, 222 insertions(+), 35 deletions(-)
diff --git a/lib/Audit/AArch64/Check.hs b/lib/Audit/AArch64/Check.hs
@@ -28,6 +28,11 @@ import Audit.AArch64.CFG (BasicBlock(..), CFG(..), cfgBlockCount, indexBlock,
functionLabels, functionBlocks)
import Audit.AArch64.Taint
import Audit.AArch64.Types
+ ( Reg(..), Instr(..), Line(..), AddrMode(..)
+ , Taint(..), Provenance(..), RegKind(..)
+ , Violation(..), ViolationReason(..)
+ , TaintConfig(..)
+ )
import Control.DeepSeq (NFData)
import qualified Data.IntMap.Strict as IM
import qualified Data.Map.Strict as Map
@@ -138,25 +143,30 @@ checkAddrMode sym ln instr addr st = case addr of
-- | Check that base register is public.
-- If taint is Unknown, check provenance to see if we can upgrade to Public.
+-- Provenance upgrade is only allowed for pointer-kind registers to prevent
+-- laundering scalar indices through provenance.
+--
+-- Returns an empty list when the base is 'Public', or when it is 'Unknown'
+-- with 'ProvPublic' provenance and 'KindPtr' kind. Otherwise returns a
+-- single violation: 'SecretBase' for 'Secret' taint, 'UnknownBase' for
+-- every remaining 'Unknown' case.
checkBase :: Text -> Int -> Instr -> Reg -> TaintState -> [Violation]
checkBase sym ln instr base st =
case getTaint base st of
Public -> []
Secret -> [Violation sym ln instr (SecretBase base)]
- Unknown -> case getProvenance base st of
- ProvPublic -> [] -- Provenance proves public derivation
- ProvUnknown -> [Violation sym ln instr (UnknownBase base)]
+ Unknown ->
+ -- Only upgrade via provenance if register has pointer kind
+ case (getProvenance base st, getKind base st) of
+ (ProvPublic, KindPtr) -> [] -- Pointer provenance proves safety
+ -- Catch-all: ProvUnknown, or ProvPublic on a KindScalar/KindUnknown
+ -- register, is reported.
+ _ -> [Violation sym ln instr (UnknownBase base)]
-- | Check that index register is public.
--- If taint is Unknown, check provenance to see if we can upgrade to Public.
+-- An index with Unknown taint is always reported as 'UnknownIndex';
+-- provenance is not consulted.
+-- Index registers should never be upgraded via provenance since they are
+-- typically scalars, not pointers.
+--
+-- Returns an empty list only for 'Public' taint; 'Secret' yields a
+-- 'SecretIndex' violation.
checkIndex :: Text -> Int -> Instr -> Reg -> TaintState -> [Violation]
checkIndex sym ln instr idx st =
case getTaint idx st of
Public -> []
Secret -> [Violation sym ln instr (SecretIndex idx)]
- Unknown -> case getProvenance idx st of
- ProvPublic -> [] -- Provenance proves public derivation
- ProvUnknown -> [Violation sym ln instr (UnknownIndex idx)]
+ Unknown -> [Violation sym ln instr (UnknownIndex idx)]
+ -- No provenance upgrade for indices - they are scalar values
-- | Check entire CFG with inter-procedural analysis.
-- Computes function summaries via fixpoint, then checks each function.
diff --git a/lib/Audit/AArch64/Taint.hs b/lib/Audit/AArch64/Taint.hs
@@ -22,6 +22,7 @@ module Audit.AArch64.Taint (
, getTaint
, setTaint
, getProvenance
+ , getKind
, publicRoots
, joinTaintState
, runDataflow
@@ -49,6 +50,7 @@ import Audit.AArch64.CFG
import Audit.AArch64.Types
( Reg(..), Instr(..), Line(..), Operand(..), AddrMode(..)
, Taint(..), joinTaint, Provenance(..), joinProvenance
+ , RegKind(..), joinKind
, TaintConfig(..), ArgPolicy(..)
)
import Data.IntMap.Strict (IntMap)
@@ -58,11 +60,11 @@ import qualified Data.IntSet as IS
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
-import Data.List (foldl')
import Data.Text (Text)
-- | Taint state: maps registers to their publicness, plus stack slots.
--- Also tracks provenance for upgrading Unknown to Public when provable.
+-- Also tracks provenance for upgrading Unknown to Public when provable,
+-- and register kinds (pointer vs scalar) for safe provenance upgrades.
data TaintState = TaintState
{ tsRegs :: !(Map Reg Taint)
-- ^ Register taints
@@ -72,6 +74,8 @@ data TaintState = TaintState
-- ^ Register provenance
, tsStackProv :: !(IntMap Provenance)
-- ^ Stack slot provenance (keyed by SP offset)
+ , tsKind :: !(Map Reg RegKind)
+ -- ^ Register kinds (pointer vs scalar)
} deriving (Eq, Show)
-- | GHC 9.10.3 AArch64 public root registers.
@@ -106,15 +110,18 @@ emptyTaintState = TaintState
, tsStack = IM.empty
, tsProv = Map.empty
, tsStackProv = IM.empty
+ , tsKind = Map.empty
}
-- | Initial taint state with public roots marked.
+-- Public roots are marked as Ptr kind (they are pointers).
+--
+-- Every register in 'publicRoots' starts as 'Public' / 'ProvPublic' /
+-- 'KindPtr'; both stack-slot maps start empty. Registers outside
+-- 'publicRoots' have no entry in any of the three register maps.
initTaintState :: TaintState
initTaintState = TaintState
{ tsRegs = Map.fromList [(r, Public) | r <- publicRoots]
, tsStack = IM.empty
, tsProv = Map.fromList [(r, ProvPublic) | r <- publicRoots]
, tsStackProv = IM.empty
+ , tsKind = Map.fromList [(r, KindPtr) | r <- publicRoots]
}
-- | Seed argument registers according to policy.
@@ -145,6 +152,10 @@ getTaint r st = Map.findWithDefault Unknown r (tsRegs st)
getProvenance :: Reg -> TaintState -> Provenance
getProvenance r st = Map.findWithDefault ProvUnknown r (tsProv st)
+-- | Get the kind of a register.
+-- A register with no entry in 'tsKind' is reported as 'KindUnknown'.
+getKind :: Reg -> TaintState -> RegKind
+getKind r st = Map.findWithDefault KindUnknown r (tsKind st)
+
-- | Analyze a single line, updating taint state.
analyzeLine :: Line -> TaintState -> TaintState
analyzeLine l st = case lineInstr l of
@@ -161,33 +172,39 @@ analyzeBlock lns st = foldl' (flip analyzeLine) st lns
-- Also tracks provenance for upgrading Unknown bases.
transfer :: Instr -> TaintState -> TaintState
transfer instr st = case instr of
- -- Move: destination gets source taint and provenance
+ -- Move: destination gets source taint, provenance, and kind
Mov dst op ->
- setTaintProv dst (operandTaint op st) (operandProv op st) st
+ setTaintProvKind dst (operandTaint op st) (operandProv op st)
+ (operandKind op st) st
Movz dst _ _ ->
- setTaintProv dst Public ProvPublic st -- Immediate is public
+ setTaintProvKind dst Public ProvPublic KindScalar st -- Immediate is scalar
Movk _ _ _ -> st -- Keeps existing value, modifies bits
Movn dst _ _ ->
- setTaintProv dst Public ProvPublic st -- Immediate is public
+ setTaintProvKind dst Public ProvPublic KindScalar st -- Immediate is scalar
-- Arithmetic: result is join of operand taints/provenances
-- Clear stack map if SP is modified (offsets become stale)
+ -- Pointer arithmetic (ptr + imm) preserves pointer kind
Add dst r1 op ->
let t = join2 (getTaint r1 st) (operandTaint op st)
p = provJoin2 (getProvenance r1 st) (operandProv op st)
- in updateWithSpCheckProv dst t p st
+ k = pointerArithKind r1 op st
+ in updateWithSpCheckProvKind dst t p k st
Sub dst r1 op ->
let t = join2 (getTaint r1 st) (operandTaint op st)
p = provJoin2 (getProvenance r1 st) (operandProv op st)
- in updateWithSpCheckProv dst t p st
+ k = pointerArithKind r1 op st
+ in updateWithSpCheckProvKind dst t p k st
Adds dst r1 op ->
let t = join2 (getTaint r1 st) (operandTaint op st)
p = provJoin2 (getProvenance r1 st) (operandProv op st)
- in updateWithSpCheckProv dst t p st
+ k = pointerArithKind r1 op st
+ in updateWithSpCheckProvKind dst t p k st
Subs dst r1 op ->
let t = join2 (getTaint r1 st) (operandTaint op st)
p = provJoin2 (getProvenance r1 st) (operandProv op st)
- in updateWithSpCheckProv dst t p st
+ k = pointerArithKind r1 op st
+ in updateWithSpCheckProvKind dst t p k st
Adc dst r1 r2 ->
let t = join2 (getTaint r1 st) (getTaint r2 st)
p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
@@ -240,16 +257,17 @@ transfer instr st = case instr of
in setTaintProv dst t p st
-- Logical: result is join of operand taints/provenances
- -- Special case: pointer untagging preserves ProvPublic provenance
+ -- Special case: pointer untagging preserves ProvPublic and KindPtr
And dst r1 op ->
let t = join2 (getTaint r1 st) (operandTaint op st)
p = provJoin2 (getProvenance r1 st) (operandProv op st)
srcProv = getProvenance r1 st
- -- Pointer untagging: preserve ProvPublic if source was ProvPublic
- p' = if isPointerUntagMask op && srcProv == ProvPublic
- then ProvPublic
- else p
- in setTaintProv dst t p' st
+ srcKind = getKind r1 st
+ isUntag = isPointerUntagMask op
+ -- Pointer untagging: preserve provenance and kind if source was Ptr
+ p' = if isUntag && srcProv == ProvPublic then ProvPublic else p
+ k = if isUntag && srcKind == KindPtr then KindPtr else KindScalar
+ in setTaintProvKind dst t p' k st
Orr dst r1 op ->
let t = join2 (getTaint r1 st) (operandTaint op st)
p = provJoin2 (getProvenance r1 st) (operandProv op st)
@@ -298,9 +316,9 @@ transfer instr st = case instr of
p = provJoin2 (getProvenance r1 st) (getProvenance r2 st)
in setTaintProv dst t p st
- -- Address generation: result is public (constant pool / PC-relative)
- Adr dst _ -> setTaintProv dst Public ProvPublic st
- Adrp dst _ -> setTaintProv dst Public ProvPublic st
+ -- Address generation: result is public pointer (constant pool / PC-relative)
+ Adr dst _ -> setTaintProvKind dst Public ProvPublic KindPtr st
+ Adrp dst _ -> setTaintProvKind dst Public ProvPublic KindPtr st
-- Loads: restore from stack slots if [sp, #imm], else Unknown
-- Exception: public roots stay public (GHC spills/restores them)
@@ -374,25 +392,36 @@ setTaintProv r t p st = st
, tsProv = Map.insert r p (tsProv st)
}
+-- | Set taint, provenance, and kind for a register.
+-- Overwrites any previous entry for the register in 'tsRegs', 'tsProv'
+-- and 'tsKind'; the stack-slot maps are left untouched.
+--
+-- NOTE(review): 'setTaintProv' writes taint and provenance but not
+-- 'tsKind', so an instruction that still goes through 'setTaintProv'
+-- leaves the destination's previous kind in place. A stale 'KindPtr' on
+-- such a destination could let a scalar pass the base check -- confirm
+-- whether those call sites should set 'KindScalar' via this function.
+setTaintProvKind :: Reg -> Taint -> Provenance -> RegKind -> TaintState
+ -> TaintState
+setTaintProvKind r t p k st = st
+ { tsRegs = Map.insert r t (tsRegs st)
+ , tsProv = Map.insert r p (tsProv st)
+ , tsKind = Map.insert r k (tsKind st)
+ }
+
-- | Set taint for a loaded value, preserving public roots.
-- Public roots (SP, X19-X21, X28, etc.) stay public even when loaded
-- since GHC spills/restores them from the hardware stack.
+-- Loaded values are KindScalar unless they're public roots (which are Ptr).
+--
+-- Result for the destination register:
+--
+-- * member of 'publicRoots': 'Public' / 'ProvPublic' / 'KindPtr'
+-- * anything else: 'Unknown' / 'ProvUnknown' / 'KindScalar'
setTaintLoad :: Reg -> TaintState -> TaintState
setTaintLoad dst st
- | isPublicRoot dst = setTaintProv dst Public ProvPublic st
- | otherwise = setTaintProv dst Unknown ProvUnknown st
+ | isPublicRoot dst = setTaintProvKind dst Public ProvPublic KindPtr st
+ | otherwise = setTaintProvKind dst Unknown ProvUnknown KindScalar st
where
isPublicRoot r = r `elem` publicRoots
-- | Set taint for a loaded value from a known stack slot.
-- If we have tracked taint/provenance at this offset, use it; else Unknown.
+-- Loaded values are KindScalar unless they're public roots (which are Ptr).
+--
+-- Kind is not stored per stack slot: a non-root destination is always
+-- 'KindScalar', even when the slot's provenance is 'ProvPublic', so a
+-- pointer that was spilled and reloaded loses its 'KindPtr'.
setTaintLoadStack :: Reg -> Int -> TaintState -> TaintState
setTaintLoadStack dst offset st
- | isPublicRoot dst = setTaintProv dst Public ProvPublic st
+ | isPublicRoot dst = setTaintProvKind dst Public ProvPublic KindPtr st
| otherwise =
let taint = IM.findWithDefault Unknown offset (tsStack st)
prov = IM.findWithDefault ProvUnknown offset (tsStackProv st)
- in setTaintProv dst taint prov st
+ in setTaintProvKind dst taint prov KindScalar st
where
isPublicRoot r = r `elem` publicRoots
@@ -407,11 +436,12 @@ setStackTaintProv offset t p st = st
clearStackMap :: TaintState -> TaintState
clearStackMap st = st { tsStack = IM.empty, tsStackProv = IM.empty }
--- | Update register taint and provenance, clearing stack map if dst is SP.
-updateWithSpCheckProv :: Reg -> Taint -> Provenance -> TaintState -> TaintState
-updateWithSpCheckProv dst t p st
- | dst == SP = clearStackMap (setTaintProv dst t p st)
- | otherwise = setTaintProv dst t p st
+-- | Update register taint, provenance, and kind, clearing stack map if SP.
+-- The register is written first; when the destination is 'SP' both
+-- stack-slot maps are then emptied, because the tracked offsets were
+-- relative to the old SP value.
+updateWithSpCheckProvKind :: Reg -> Taint -> Provenance -> RegKind
+ -> TaintState -> TaintState
+updateWithSpCheckProvKind dst t p k st
+ | dst == SP = clearStackMap (setTaintProvKind dst t p k st)
+ | otherwise = setTaintProvKind dst t p k st
-- | Track a store to stack if address is [sp, #imm].
-- Pre/post-indexed addressing modifies SP, invalidating the stack map.
@@ -509,6 +539,25 @@ operandProv op st = case op of
OpLabel _ -> ProvPublic
OpAddr _ -> ProvUnknown -- Address operand provenance is complex
+-- | Get kind of an operand.
+-- Register operands report the kind currently recorded for the register
+-- ('KindUnknown' if none). Shifted and extended register operands report
+-- the kind of the underlying register unchanged; the shift/extend itself
+-- is ignored here.
+operandKind :: Operand -> TaintState -> RegKind
+operandKind op st = case op of
+ OpReg r -> getKind r st
+ OpImm _ -> KindScalar -- Immediates are scalar values
+ OpShiftedReg r _ -> getKind r st
+ OpExtendedReg r _ -> getKind r st
+ OpLabel _ -> KindPtr -- Labels are addresses (pointers)
+ OpAddr _ -> KindUnknown -- Address operand kind is complex
+
+-- | Compute kind for pointer arithmetic (add/sub).
+-- If base is a pointer and operand is immediate, result is pointer.
+-- Otherwise result is scalar.
+--
+-- "Otherwise" covers two cases: any non-immediate operand (even when the
+-- base is 'KindPtr'), and an immediate operand whose base is 'KindScalar'
+-- or 'KindUnknown'. The result is never 'KindUnknown'.
+pointerArithKind :: Reg -> Operand -> TaintState -> RegKind
+pointerArithKind base op st =
+ case op of
+ OpImm _ | getKind base st == KindPtr -> KindPtr
+ _ -> KindScalar
+
-- | Check if operand is a GHC pointer-untagging mask.
-- GHC uses low 3 bits for pointer tagging; this mask clears them.
-- Recognizing this pattern allows whitelisting heap traversal.
@@ -564,13 +613,14 @@ invalidateCallerSaved st = st
-- | Join two taint states (element-wise join).
-- For registers in both, take the join. For registers in only one, keep.
--- Stack slots and provenance are also joined element-wise.
+-- Stack slots, provenance, and kinds are also joined element-wise.
+--
+-- NOTE(review): because a one-sided entry is kept as-is, a register that
+-- is 'KindPtr' in one state and absent from the other stays 'KindPtr',
+-- whereas 'getKind' would report the absent side as 'KindUnknown' and
+-- 'joinKind' of the two would be 'KindUnknown'. Presumably an absent
+-- entry means "not yet reached" -- confirm.
joinTaintState :: TaintState -> TaintState -> TaintState
joinTaintState a b = TaintState
{ tsRegs = Map.unionWith joinTaint (tsRegs a) (tsRegs b)
, tsStack = IM.unionWith joinTaint (tsStack a) (tsStack b)
, tsProv = Map.unionWith joinProvenance (tsProv a) (tsProv b)
, tsStackProv = IM.unionWith joinProvenance (tsStackProv a) (tsStackProv b)
+ , tsKind = Map.unionWith joinKind (tsKind a) (tsKind b)
}
-- | Run forward dataflow analysis over a CFG.
@@ -630,6 +680,7 @@ initSummary = FuncSummary $ TaintState
, tsStack = IM.empty
, tsProv = Map.fromList [ (r, ProvUnknown) | r <- callerSavedRegs ]
, tsStackProv = IM.empty
+ , tsKind = Map.fromList [ (r, KindUnknown) | r <- callerSavedRegs ]
}
-- | Caller-saved registers per AArch64 ABI.
diff --git a/lib/Audit/AArch64/Types.hs b/lib/Audit/AArch64/Types.hs
@@ -35,6 +35,10 @@ module Audit.AArch64.Types (
, Provenance(..)
, joinProvenance
+ -- * Register kind
+ , RegKind(..)
+ , joinKind
+
-- * Analysis results
, Violation(..)
, ViolationReason(..)
@@ -318,6 +322,23 @@ joinProvenance :: Provenance -> Provenance -> Provenance
joinProvenance ProvPublic ProvPublic = ProvPublic
joinProvenance _ _ = ProvUnknown
+-- | Register kind: distinguishes pointers from scalars.
+-- Used to restrict provenance upgrades to pointer-kind registers only.
+--
+-- The derived 'Ord' follows constructor order (Ptr < Scalar < Unknown);
+-- it is not the ordering used by 'joinKind'.
+data RegKind
+ = KindPtr -- ^ Known to be a pointer (from adr/adrp or pointer arithmetic)
+ | KindScalar -- ^ Known to be a scalar (from loads or arithmetic)
+ | KindUnknown -- ^ Unknown kind
+ deriving (Eq, Ord, Show, Generic, NFData)
+
+instance ToJSON RegKind
+
+-- | Join register kinds: Ptr only if both are Ptr.
+--
+-- 'KindScalar' absorbs everything (Scalar joined with any kind is Scalar).
+-- With no Scalar involved, any 'KindUnknown' argument gives 'KindUnknown'.
+-- The function is symmetric in its arguments.
+joinKind :: RegKind -> RegKind -> RegKind
+joinKind KindPtr KindPtr = KindPtr
+joinKind KindScalar _ = KindScalar
+joinKind _ KindScalar = KindScalar
+joinKind _ _ = KindUnknown
+
-- | Reason for a violation.
data ViolationReason
= SecretBase !Reg -- ^ Base register is secret
diff --git a/plans/ARCH13.md b/plans/ARCH13.md
@@ -0,0 +1,43 @@
+# ARCH13: Register Kind Tracking (Pointer vs Scalar)
+
+## Goal
+
+Prevent provenance from "laundering" secret-derived scalars by
+introducing a register kind lattice (pointer vs scalar) and only using
+provenance upgrades for pointer kinds.
+
+## Scope
+
+Stage 1 (registers only):
+- Track kind for registers (`Ptr`/`Scalar`/`Unknown`).
+- Apply pointer-kind checks when upgrading via provenance.
+
+Stage 2 (with spills):
+- Extend kind tracking to stack slots to preserve pointer/scalar intent
+ across spills and reloads.
+
+## Rationale
+
+Provenance is safe for pointer bases but unsafe for scalar indices.
+Kind tracking separates these cases and avoids false negatives in
+secret-indexed memory accesses.
+
+## Kind Propagation Rules (Stage 1)
+
+- `adr/adrp` -> `Ptr`.
+- `mov dst, src` -> copy kind.
+- `add/sub dst, src, #imm` -> `Ptr` if src is `Ptr`, else `Scalar`.
+- `and` with pointer-untag mask -> preserve `Ptr`.
+- Loads -> `Scalar` by default.
+- Other arithmetic/logical ops -> `Scalar`.
+
+## Address Checks
+
+- Base registers: allow provenance upgrade only if kind is `Ptr`.
+- Index registers: never upgrade via provenance (or only if kind is
+ `Ptr`, which should be rare for indices).
+
+## Risks
+
+- Misclassifying pointer-preserving ops may increase false positives.
+- Without spill tracking, kind info can be lost across stack stores.
diff --git a/plans/IMPL13.md b/plans/IMPL13.md
@@ -0,0 +1,41 @@
+# IMPL13: Implement Register Kind Tracking (Stage 1)
+
+## Summary
+
+Add register kind tracking and restrict provenance upgrades to pointer
+kinds, improving detection of secret-indexed memory access.
+
+## Steps
+
+1) Add kind type
+- Introduce `RegKind` (Ptr/Scalar/Unknown) in `Types.hs`.
+- Add ToJSON if required.
+
+2) Extend TaintState
+- Add `tsKind :: Map Reg RegKind`.
+- Initialize public roots with `Ptr` (or a subset if needed).
+
+3) Transfer updates
+- For pointer-preserving ops, propagate `Ptr`.
+- For loads and general arithmetic/logic, set `Scalar`.
+- Keep `Unknown` when no information.
+
+4) Provenance upgrade rules
+- In `checkBase`, only treat `ProvPublic` as safe if `RegKind == Ptr`.
+- In `checkIndex`, disallow provenance upgrades (or require Ptr).
+
+5) Tests
+- Add tests that previously upgraded a scalar index via provenance and
+ now emit a violation.
+
+## Files to Touch
+
+- `lib/Audit/AArch64/Types.hs`
+- `lib/Audit/AArch64/Taint.hs`
+- `lib/Audit/AArch64/Check.hs`
+- `test/`
+
+## Validation
+
+- `cabal test`
+- Run auditor on known vulnerable asm to confirm detection.
diff --git a/test/Main.hs b/test/Main.hs
@@ -523,6 +523,27 @@ provenanceTests = testGroup "Provenance" [
case audit "test" src of
Left e -> assertFailure $ "parse failed: " ++ show e
Right ar -> assertEqual "one violation" 1 (length (arViolations ar))
+
+ , testCase "bad: scalar cannot be laundered via provenance" $ do
+ -- A scalar value added to a public pointer should NOT become a valid
+ -- base register, even though the result has public provenance.
+ -- This is the key case that kind tracking prevents.
+ let src = T.unlines
+ [ "foo:"
+ , " ldr x8, [x20]" -- x8 = Unknown, ProvUnknown, KindScalar
+ , " add x9, x20, x8" -- x9 = Unknown, ProvPublic, KindScalar
+ , " ldr x0, [x9]" -- violation: x9 is scalar, not pointer
+ , " ret"
+ ]
+ case audit "test" src of
+ Left e -> assertFailure $ "parse failed: " ++ show e
+ Right ar -> do
+ assertEqual "one violation" 1 (length (arViolations ar))
+ case arViolations ar of
+ [v] -> case vReason v of
+ UnknownBase X9 -> pure ()
+ other -> assertFailure $ "wrong reason: " ++ show other
+ _ -> assertFailure "expected one violation"
]
-- Taint config tests