auditor

An aarch64 constant-time memory access auditing tool.
git clone git://git.ppad.tech/auditor.git
Log | Files | Refs | README | LICENSE

commit 9a4b2fe34c3c30278e85ed7345299ee4e8e5530a
parent 9bb1c0547562a533bf8e7259310cda9115c33fdf
Author: Jared Tobin <jared@jtobin.io>
Date:   Sat, 28 Feb 2026 13:13:14 +0400

feat: parameterise all analysis over RuntimeConfig

Isolate GHC/STG-specific logic behind a RuntimeConfig record so the
auditor can support multiple runtimes. RuntimeConfig captures public
root registers, secondary stack configuration, local label
classification, pointer untagging masks, NCT pattern filtering, and
symbol encoding as data/closures selected once at CLI parse time.

New modules:
- Audit.AArch64.Runtime: RuntimeConfig and SecondaryStack types
- Audit.AArch64.Runtime.GHC: ghcRuntime, genericRuntime, moved
  isGhcRuntimeFinding helpers, zEncodeSymbol, ghcIsLocalLabel

Key changes:
- Types.hs: NctReason, NctFinding, LineMap moved here from NCT.hs
  to break circular dependency with Runtime.hs
- CFG.hs: isFunctionLabel, buildCFG, callTargets, buildCallGraph
  gain RuntimeConfig parameter; LLVM prefixes stay as base checks,
  NCG prefixes delegated to rtIsLocalLabel
- Taint.hs: all hardcoded X20/STG stack logic generalised to
  secondary stack via ssBaseReg; publicRoots, initTaintState,
  analyzeLine, runDataflow and all variants gain RuntimeConfig;
  isPointerUntagMask checks rtUntagMasks
- Check.hs: all check functions gain RuntimeConfig
- NCT.hs: scanNct gains RuntimeConfig; isGhcRuntimeFinding and
  helpers removed (moved to Runtime/GHC.hs); filterGhcRuntime
  replaced by generic filterRuntimePatterns
- CallGraph.hs: buildCallGraph gains RuntimeConfig
- AArch64.hs: all public API functions gain RuntimeConfig; exports
  RuntimeConfig, SecondaryStack, ghcRuntime, genericRuntime
- Main.hs: --runtime flag (haskell|generic); --show-ghc-runtime
  renamed to --show-runtime-patterns; --assume-stg-private renamed
  to --assume-secondary-private; old flags kept as hidden aliases;
  zEncodeSymbol removed in favour of rtEncodeSymbol from config

All 115 tests pass. No behaviour change when using ghcRuntime
(the default).

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Diffstat:
Mapp/Main.hs | 401+++++++++++++++++++++++++++++++++++++++++++++++--------------------------------
Mbench/Main.hs | 39++++++++++++++++++++++++++-------------
Mbench/Weight.hs | 35++++++++++++++++++++++++-----------
Mlib/Audit/AArch64.hs | 168+++++++++++++++++++++++++++++++++++++++++++++++++------------------------------
Mlib/Audit/AArch64/CFG.hs | 67++++++++++++++++++++++++++++++++-----------------------------------
Mlib/Audit/AArch64/CallGraph.hs | 52++++++++++++++++++++++++++++++----------------------
Mlib/Audit/AArch64/Check.hs | 207+++++++++++++++++++++++++++++++++++++++++++++++++------------------------------
Mlib/Audit/AArch64/NCT.hs | 377++++++-------------------------------------------------------------------------
Alib/Audit/AArch64/Runtime.hs | 58++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Alib/Audit/AArch64/Runtime/GHC.hs | 416+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Mlib/Audit/AArch64/Taint.hs | 1463+++++++++++++++++++++++++++++++++++++++++++------------------------------------
Mlib/Audit/AArch64/Types.hs | 24++++++++++++++++++++++++
Mppad-auditor.cabal | 2++
Mtest/Main.hs | 1674++++++++++++++++++++++++++++++++++++++++++++++---------------------------------
14 files changed, 2880 insertions(+), 2103 deletions(-)

diff --git a/app/Main.hs b/app/Main.hs @@ -3,15 +3,20 @@ module Main where import Audit.AArch64 - ( AuditResult(..), Violation(..), ViolationReason(..) + ( RuntimeConfig(..) + , ghcRuntime, genericRuntime + , AuditResult(..), Violation(..), ViolationReason(..) , TaintConfig(..) - , NctReason(..), NctFinding(..), nctLine, nctInstr, nctReason - , LineMap, buildLineMap, isGhcRuntimeFinding + , NctReason(..), NctFinding(..) + , nctLine, nctInstr, nctReason + , LineMap, buildLineMap , SymbolScanResult(..), scanNct - , buildCallGraph, allSymbols, reachableSymbols, reachingSymbols + , buildCallGraph, allSymbols + , reachableSymbols, reachingSymbols , symbolExists , auditFile, auditFileInterProc - , auditFileWithConfig, auditFileInterProcWithConfig + , auditFileWithConfig + , auditFileInterProcWithConfig , parseFile, regName, loadTaintConfig , scanNctFile ) @@ -29,6 +34,9 @@ import qualified Data.Text.IO as TIO import Options.Applicative import System.Exit (exitFailure, exitSuccess) +data Runtime = Haskell | Generic + deriving (Eq, Show) + data Options = Options { optInput :: !FilePath , optJson :: !Bool @@ -39,15 +47,24 @@ data Options = Options , optDisplayUnknown :: !Bool , optScanNct :: !Bool , optNctDetail :: !Bool - , optShowGhcRuntime :: !Bool + , optShowRuntimePat :: !Bool , optSymbol :: !(Maybe Text) , optZSymbol :: !(Maybe Text) , optListSymbols :: !Bool , optSymbolFilter :: !(Maybe Text) , optCallers :: !Bool - , optAssumeStgPrivate :: !Bool + , optAssumeSecPrivate :: !Bool + , optRuntime :: !Runtime } deriving (Eq, Show) +runtimeReader :: ReadM Runtime +runtimeReader = eitherReader $ \s -> case s of + "haskell" -> Right Haskell + "generic" -> Right Generic + _ -> Left $ + "unknown runtime: " ++ s + ++ " (expected haskell or generic)" + optParser :: Parser Options optParser = Options <$> strOption @@ -84,75 +101,115 @@ optParser = Options <*> switch ( long "display-unknown" <> short 'u' - <> help "Display unknown violations (only secret shown by default)" + <> help "Display unknown violations \ + \(only secret shown by default)" ) <*> switch ( long "scan-nct" - <> help "Scan for non-constant-time instructions (no taint analysis)" + <> help "Scan for non-constant-time instructions \ + \(no taint analysis)" ) <*> switch ( long "nct-detail" - <> help "Show per-instruction details in NCT scan mode" + <> help "Show per-instruction details in NCT \ + \scan mode" ) - <*> switch - ( long "show-ghc-runtime" - <> help "Show GHC runtime patterns in NCT scan (hidden by default)" + <*> ( switch + ( long "show-runtime-patterns" + <> help "Show runtime patterns in NCT scan \ + \(hidden by default)" + ) + <|> switch + ( long "show-ghc-runtime" + <> internal + ) ) <*> optional (strOption ( long "symbol" <> short 's' <> metavar "SYMBOL" - <> help "Analyze only this symbol and its callees (NCT scan mode)" + <> help "Analyze only this symbol and its callees \ + \(NCT scan mode)" )) <*> optional (strOption ( long "zsymbol" <> short 'z' <> metavar "SYMBOL" - <> help "Human-readable symbol, auto z-encoded with _info$def \ + <> help "Human-readable symbol, auto z-encoded \ + \with _info$def \ \(e.g., pkg-1.0:Mod.Sub:func)" )) <*> switch ( long "list-symbols" <> short 'l' - <> help "List all function symbols in the assembly file" + <> help "List all function symbols in the \ + \assembly file" ) <*> optional (strOption ( long "filter" <> short 'f' <> metavar "PATTERN" - <> help "Filter symbols containing PATTERN (use with --list-symbols)" + <> help "Filter symbols containing PATTERN \ + \(use with --list-symbols)" )) <*> switch ( long "callers" <> short 'c' - <> help "Show callers instead of callees (use with --symbol)" + <> help "Show callers instead of callees \ + \(use with --symbol)" ) - <*> switch - ( long "assume-stg-private" - <> help "Treat untracked STG stack slots as private (default: public)" + <*> ( switch + ( long "assume-secondary-private" + <> help "Treat untracked secondary stack \ + \slots as private (default: public)" + ) + <|> switch + ( long "assume-stg-private" + <> internal + ) + ) + <*> option runtimeReader + ( long "runtime" + <> metavar "RUNTIME" + <> value Haskell + <> help "Runtime: haskell (default) or generic" ) optInfo :: ParserInfo Options optInfo = info (optParser <**> helper) ( fullDesc - <> progDesc "Audit AArch64 assembly for constant-time memory access" - <> header "auditor - CT memory access auditor for AArch64" + <> progDesc "Audit AArch64 assembly for \ + \constant-time memory access" + <> header "auditor - CT memory access auditor \ + \for AArch64" ) +-- | Select runtime configuration from CLI option. +selectRuntime :: Runtime -> RuntimeConfig +selectRuntime Haskell = ghcRuntime +selectRuntime Generic = genericRuntime + main :: IO () main = do opts <- execParser optInfo + let rt = selectRuntime (optRuntime opts) -- Compute effective symbol from --symbol or --zsymbol effSym <- case optZSymbol opts of - Just zs -> case zEncodeSymbol zs of - Left err -> do - TIO.putStrLn $ "Error: " <> err + Just zs -> case rtEncodeSymbol rt of + Nothing -> do + TIO.putStrLn + "Error: --zsymbol requires a runtime \ + \with symbol encoding (use --runtime haskell)" exitFailure - Right encoded -> pure (Just encoded) + Just encoder -> case encoder zs of + Left err -> do + TIO.putStrLn $ "Error: " <> err + exitFailure + Right encoded -> pure (Just encoded) Nothing -> pure (optSymbol opts) let opts' = opts { optSymbol = effSym } if optListSymbols opts' - then listSymbols opts' + then listSymbols rt opts' else if optParseOnly opts' then do result <- parseFile (optInput opts') @@ -161,26 +218,27 @@ main = do TIO.putStrLn $ "Error: " <> err exitFailure Right n -> do - TIO.putStrLn $ "Parsed " <> T.pack (show n) <> " lines" + TIO.putStrLn $ "Parsed " + <> T.pack (show n) <> " lines" exitSuccess else if optScanNct opts' then case optSymbol opts' of Just sym -> do - result <- scanNctForSymbol opts' sym + result <- scanNctForSymbol rt opts' sym case result of Left err -> do TIO.putStrLn $ "Error: " <> err exitFailure Right ssr -> - outputNctSymbol opts' ssr + outputNctSymbol rt opts' ssr Nothing -> do - result <- scanNctFile (optInput opts') + result <- scanNctFile rt (optInput opts') case result of Left err -> do TIO.putStrLn $ "Error: " <> err exitFailure Right (lineMap, findings) -> - outputNct opts' lineMap findings + outputNct rt opts' lineMap findings else do -- Load taint config if provided mcfg <- case optTaintConfig opts' of @@ -188,13 +246,15 @@ main = do Just path -> loadTaintConfig path case mcfg of Left err -> do - TIO.putStrLn $ "Error loading taint config: " <> err + TIO.putStrLn $ + "Error loading taint config: " <> err exitFailure Right baseCfg -> do - -- Apply CLI flag for STG assumption (--assume-stg-private overrides) - let assumeStg = not (optAssumeStgPrivate opts') - cfg = baseCfg { tcAssumeStgPublic = assumeStg } - auditor = selectAuditor opts' cfg + let assumeSec = + not (optAssumeSecPrivate opts') + cfg = baseCfg + { tcAssumeStgPublic = assumeSec } + auditor = selectAuditor rt opts' cfg result <- auditor (optInput opts') case result of Left err -> do @@ -207,15 +267,19 @@ main = do where emptyConfig = TaintConfig Map.empty True - selectAuditor opts cfg - | needsConfig && optInterProc opts = auditFileInterProcWithConfig cfg - | needsConfig = auditFileWithConfig cfg - | optInterProc opts = auditFileInterProc - | otherwise = auditFile + selectAuditor rt opts cfg + | needsConfig && optInterProc opts = + auditFileInterProcWithConfig rt cfg + | needsConfig = + auditFileWithConfig rt cfg + | optInterProc opts = + auditFileInterProc rt + | otherwise = + auditFile rt where - -- Use config path if we have policies OR non-default STG assumption - needsConfig = not (Map.null (tcPolicies cfg)) - || not (tcAssumeStgPublic cfg) + needsConfig = + not (Map.null (tcPolicies cfg)) + || not (tcAssumeStgPublic cfg) outputJson :: Options -> AuditResult -> IO () outputJson opts ar = @@ -231,12 +295,17 @@ outputText opts ar = do then pure () else do TIO.putStrLn "" - TIO.putStrLn $ "Lines checked: " <> T.pack (show (arLinesChecked ar)) - TIO.putStrLn $ "Memory accesses: " <> T.pack (show (arMemoryAccesses ar)) - TIO.putStrLn $ "Violations: " <> T.pack (show (length vs)) - if not (optDisplayUnknown opts) && length vs < length allVs + TIO.putStrLn $ "Lines checked: " + <> T.pack (show (arLinesChecked ar)) + TIO.putStrLn $ "Memory accesses: " + <> T.pack (show (arMemoryAccesses ar)) + TIO.putStrLn $ "Violations: " + <> T.pack (show (length vs)) + if not (optDisplayUnknown opts) + && length vs < length allVs then TIO.putStrLn $ " (hidden): " - <> T.pack (show (length allVs - length vs)) + <> T.pack (show + (length allVs - length vs)) <> " unknown (use -u to show)" else pure () if null vs @@ -245,8 +314,8 @@ outputText opts ar = do -- | List all function symbols in the assembly file. -- With --symbol, lists callees (or callers with --callers). -listSymbols :: Options -> IO () -listSymbols opts = do +listSymbols :: RuntimeConfig -> Options -> IO () +listSymbols rt opts = do bs <- BS.readFile (optInput opts) case decodeUtf8' bs of Left err -> do @@ -258,42 +327,54 @@ listSymbols opts = do TIO.putStrLn "Error: parse failed" exitFailure Right lns -> do - let cg = buildCallGraph lns + let cg = buildCallGraph rt lns case optSymbol opts of Just sym | not (symbolExists sym cg) -> do - TIO.putStrLn $ "Error: symbol not found: " <> sym + TIO.putStrLn $ + "Error: symbol not found: " <> sym exitFailure | optCallers opts -> do - -- Show symbols that can reach this symbol - let callers = Set.toAscList (reachingSymbols sym cg) + let callers = Set.toAscList + (reachingSymbols sym cg) mapM_ TIO.putStrLn callers if optQuiet opts then pure () - else TIO.putStrLn $ "\n" <> T.pack (show (length callers)) - <> " symbols can reach " <> sym + else TIO.putStrLn $ + "\n" <> T.pack + (show (length callers)) + <> " symbols can reach " <> sym | otherwise -> do - -- Show symbols reachable from this symbol - let callees = Set.toAscList (reachableSymbols sym cg) + let callees = Set.toAscList + (reachableSymbols sym cg) mapM_ TIO.putStrLn callees if optQuiet opts then pure () - else TIO.putStrLn $ "\n" <> T.pack (show (length callees)) - <> " symbols reachable from " <> sym + else TIO.putStrLn $ + "\n" <> T.pack + (show (length callees)) + <> " symbols reachable from " + <> sym Nothing -> do let syms = Set.toAscList (allSymbols cg) filtered = case optSymbolFilter opts of Nothing -> syms - Just pat -> filter (T.isInfixOf pat) syms + Just pat -> + filter (T.isInfixOf pat) syms mapM_ TIO.putStrLn filtered if optQuiet opts then pure () - else TIO.putStrLn $ "\n" <> T.pack (show (length filtered)) - <> " symbols" + else TIO.putStrLn $ + "\n" <> T.pack + (show (length filtered)) + <> " symbols" --- | NCT scan for a symbol (callees or callers based on options). -scanNctForSymbol :: Options -> Text -> IO (Either Text SymbolScanResult) -scanNctForSymbol opts rootSym = do +-- | NCT scan for a symbol (callees or callers based +-- on options). +scanNctForSymbol + :: RuntimeConfig -> Options -> Text + -> IO (Either Text SymbolScanResult) +scanNctForSymbol rt opts rootSym = do bs <- BS.readFile (optInput opts) case decodeUtf8' bs of Left err -> pure (Left (T.pack (show err))) @@ -301,17 +382,19 @@ scanNctForSymbol opts rootSym = do case parseAsm src of Left _ -> pure (Left "parse failed") Right lns -> do - let cg = buildCallGraph lns + let cg = buildCallGraph rt lns if not (symbolExists rootSym cg) - then pure (Left ("symbol not found: " <> rootSym)) + then pure (Left + ("symbol not found: " <> rootSym)) else do let syms = if optCallers opts - then reachingSymbols rootSym cg - else reachableSymbols rootSym cg + then reachingSymbols rootSym cg + else reachableSymbols rootSym cg lineMap = buildLineMap lns - allFindings = scanNct lns + allFindings = scanNct rt lns filtered = Map.filterWithKey - (\sym _ -> Set.member sym syms) allFindings + (\sym _ -> Set.member sym syms) + allFindings pure $ Right $ SymbolScanResult { ssrRootSymbol = rootSym , ssrReachable = Set.size syms @@ -319,82 +402,103 @@ scanNctForSymbol opts rootSym = do , ssrFindings = filtered } --- | Output NCT scan results for a specific symbol and its callees. -outputNctSymbol :: Options -> SymbolScanResult -> IO () -outputNctSymbol opts ssr = do +-- | Output NCT scan results for a specific symbol +-- and its callees. +outputNctSymbol :: RuntimeConfig -> Options + -> SymbolScanResult -> IO () +outputNctSymbol rt opts ssr = do let lineMap = ssrLineMap ssr findings = ssrFindings ssr - showGhc = optShowGhcRuntime opts - isReal = not . isGhcRuntimeFinding lineMap - filterFindings = if showGhc then id else filter isReal - syms = [(sym, filterFindings fs) | (sym, fs) <- Map.toList findings] + showRt = optShowRuntimePat opts + filterFn = rtFilterNct rt + isRt = filterFn lineMap + filterFindings = + if showRt then id else filter (not . isRt) + syms = [ (sym, filterFindings fs) + | (sym, fs) <- Map.toList findings ] realSyms = filter (not . null . snd) syms total = sum (map (length . snd) realSyms) if optNctDetail opts - then mapM_ (printNctDetail showGhc lineMap) realSyms - else mapM_ (printNctSummary showGhc lineMap) realSyms + then mapM_ (printNctDetail showRt isRt) realSyms + else mapM_ (printNctSummary showRt isRt) realSyms if optQuiet opts then pure () else do TIO.putStrLn "" - TIO.putStrLn $ "Root symbol: " <> ssrRootSymbol ssr - TIO.putStrLn $ "Reachable symbols: " <> T.pack (show (ssrReachable ssr)) - TIO.putStrLn $ "With findings: " <> T.pack (show (length realSyms)) - TIO.putStrLn $ "NCT findings: " <> T.pack (show total) + TIO.putStrLn $ "Root symbol: " + <> ssrRootSymbol ssr + TIO.putStrLn $ "Reachable symbols: " + <> T.pack (show (ssrReachable ssr)) + TIO.putStrLn $ "With findings: " + <> T.pack (show (length realSyms)) + TIO.putStrLn $ "NCT findings: " + <> T.pack (show total) if total == 0 then exitSuccess else exitFailure -- | Output NCT scan results. -outputNct :: Options -> LineMap -> Map.Map Text [NctFinding] -> IO () -outputNct opts lineMap findings = do - let showGhc = optShowGhcRuntime opts - isReal = not . isGhcRuntimeFinding lineMap - -- Filter findings per symbol - filterFindings = if showGhc then id else filter isReal - syms = [(sym, filterFindings fs) | (sym, fs) <- Map.toList findings] - -- Filter to symbols with at least one finding +outputNct :: RuntimeConfig -> Options -> LineMap + -> Map.Map Text [NctFinding] -> IO () +outputNct rt opts lineMap findings = do + let showRt = optShowRuntimePat opts + filterFn = rtFilterNct rt + isRt = filterFn lineMap + filterFindings = + if showRt then id else filter (not . isRt) + syms = [ (sym, filterFindings fs) + | (sym, fs) <- Map.toList findings ] realSyms = filter (not . null . snd) syms total = sum (map (length . snd) realSyms) if optNctDetail opts - then mapM_ (printNctDetail showGhc lineMap) realSyms - else mapM_ (printNctSummary showGhc lineMap) realSyms + then mapM_ (printNctDetail showRt isRt) realSyms + else mapM_ (printNctSummary showRt isRt) realSyms if optQuiet opts then pure () else do TIO.putStrLn "" - TIO.putStrLn $ "Functions scanned: " <> T.pack (show (length realSyms)) - TIO.putStrLn $ "NCT findings: " <> T.pack (show total) + TIO.putStrLn $ "Functions scanned: " + <> T.pack (show (length realSyms)) + TIO.putStrLn $ "NCT findings: " + <> T.pack (show total) if total == 0 then exitSuccess else exitFailure -printNctSummary :: Bool -> LineMap -> (Text, [NctFinding]) -> IO () -printNctSummary showGhc lineMap (sym, fs) = do - TIO.putStrLn $ sym <> ": " <> T.pack (show (length fs)) - mapM_ (printFindingIndented showGhc lineMap) fs +printNctSummary :: Bool -> (NctFinding -> Bool) + -> (Text, [NctFinding]) -> IO () +printNctSummary showRt isRt (sym, fs) = do + TIO.putStrLn $ sym <> ": " + <> T.pack (show (length fs)) + mapM_ (printFindingIndented showRt isRt) fs -printFindingIndented :: Bool -> LineMap -> NctFinding -> IO () -printFindingIndented showGhc lineMap f = - let isGhc = isGhcRuntimeFinding lineMap f +printFindingIndented :: Bool -> (NctFinding -> Bool) + -> NctFinding -> IO () +printFindingIndented showRt isRt f = + let rtMatch = isRt f content = T.pack (show (nctLine f)) <> ": " - <> nctReasonText (nctReason f) <> ": " <> instrText (nctInstr f) - line = if showGhc && isGhc - then " (ghc runtime) " <> content + <> nctReasonText (nctReason f) <> ": " + <> instrText (nctInstr f) + line = if showRt && rtMatch + then " (runtime) " <> content else " " <> content in TIO.putStrLn line -printNctDetail :: Bool -> LineMap -> (Text, [NctFinding]) -> IO () -printNctDetail showGhc lineMap (sym, fs) = - mapM_ (printFinding showGhc lineMap sym) fs +printNctDetail :: Bool -> (NctFinding -> Bool) + -> (Text, [NctFinding]) -> IO () +printNctDetail showRt isRt (sym, fs) = + mapM_ (printFinding showRt isRt sym) fs -printFinding :: Bool -> LineMap -> Text -> NctFinding -> IO () -printFinding showGhc lineMap sym f = - let isGhc = isGhcRuntimeFinding lineMap f - content = sym <> ":" <> T.pack (show (nctLine f)) <> ": " - <> nctReasonText (nctReason f) <> ": " <> instrText (nctInstr f) - line = if showGhc && isGhc - then "(ghc runtime) " <> content +printFinding :: Bool -> (NctFinding -> Bool) + -> Text -> NctFinding -> IO () +printFinding showRt isRt sym f = + let rtMatch = isRt f + content = sym <> ":" + <> T.pack (show (nctLine f)) <> ": " + <> nctReasonText (nctReason f) <> ": " + <> instrText (nctInstr f) + line = if showRt && rtMatch + then "(runtime) " <> content else content in TIO.putStrLn line @@ -410,10 +514,11 @@ instrText instr = T.pack (show instr) -- | Filter violations based on options. -- By default, only secret violations are shown. -filterViolations :: Options -> [Violation] -> [Violation] +filterViolations :: Options -> [Violation] + -> [Violation] filterViolations opts | optDisplayUnknown opts = id - | otherwise = filter (isSecretViolation . vReason) + | otherwise = filter (isSecretViolation . vReason) -- | Check if a violation reason is secret (not unknown). isSecretViolation :: ViolationReason -> Bool @@ -422,48 +527,22 @@ isSecretViolation r = case r of SecretIndex _ -> True UnknownBase _ -> False UnknownIndex _ -> False - NonConstOffset -> True -- Treat as secret-level concern + NonConstOffset -> True printViolation :: Violation -> IO () printViolation v = TIO.putStrLn $ - vSymbol v <> ":" <> T.pack (show (vLine v)) <> ": " <> reasonText (vReason v) + vSymbol v <> ":" <> T.pack (show (vLine v)) + <> ": " <> reasonText (vReason v) reasonText :: ViolationReason -> Text reasonText r = case r of - SecretBase reg -> "secret base register " <> regName reg - SecretIndex reg -> "secret index register " <> regName reg - UnknownBase reg -> "unknown base register " <> regName reg - UnknownIndex reg -> "unknown index register " <> regName reg - NonConstOffset -> "non-constant offset without masking" - --- | Z-encode a human-readable Haskell symbol for GHC assembly lookup. --- --- Input format: @\<package\>:\<Module.Path\>:\<identifier\>@ --- --- Output: @_\<z-pkg\>_\<z-mod\>_\<z-id\>_info$def@ -zEncodeSymbol :: Text -> Either Text Text -zEncodeSymbol input = - case T.splitOn ":" input of - [pkg, modPath, ident] -> - let encoded = T.intercalate "_" - [zEncodePart pkg, zEncodePart modPath, zEncodePart ident] - in Right ("_" <> encoded <> "_info$def") - parts -> - Left $ "Invalid symbol format: expected <package>:<Module.Path>:<id>, \ - \got " <> T.pack (show (length parts)) <> " parts" - --- | Z-encode a single component (package, module path, or identifier). --- See GHC's compiler/GHC/Utils/Encoding.hs for the full encoding table. -zEncodePart :: Text -> Text -zEncodePart = T.concatMap encodeChar - where - encodeChar c = case c of - '-' -> "zm" - '.' -> "zi" - '_' -> "zu" - 'z' -> "zz" - 'Z' -> "ZZ" - '$' -> "zd" - '\'' -> "zq" - '#' -> "zh" - _ -> T.singleton c + SecretBase reg -> + "secret base register " <> regName reg + SecretIndex reg -> + "secret index register " <> regName reg + UnknownBase reg -> + "unknown base register " <> regName reg + UnknownIndex reg -> + "unknown index register " <> regName reg + NonConstOffset -> + "non-constant offset without masking" diff --git a/bench/Main.hs b/bench/Main.hs @@ -11,8 +11,10 @@ module Main where import Audit.AArch64.CFG (buildCFG) -import Audit.AArch64.Check (checkCFG, checkCFGInterProc) +import Audit.AArch64.Check + (checkCFG, checkCFGInterProc) import Audit.AArch64.Parser (parseAsm) +import Audit.AArch64.Runtime.GHC (ghcRuntime) import Criterion.Main import qualified Data.ByteString as BS import Data.Text (Text) @@ -27,27 +29,38 @@ main = do -- Pre-parse for CFG and analysis benchmarks smallLines <- case parseAsm smallSrc of Right lns -> pure lns - Left e -> error $ "Failed to parse CurveNCG.s: " ++ show e + Left e -> error $ + "Failed to parse CurveNCG.s: " ++ show e largeLines <- case parseAsm largeSrc of Right lns -> pure lns - Left e -> error $ "Failed to parse secp256k1NCG.s: " ++ show e - let smallCFG = buildCFG smallLines - largeCFG = buildCFG largeLines + Left e -> error $ + "Failed to parse secp256k1NCG.s: " ++ show e + let rt = ghcRuntime + smallCFG = buildCFG rt smallLines + largeCFG = buildCFG rt largeLines defaultMain [ bgroup "parse" - [ bench "small (CurveNCG)" $ nf parseAsm smallSrc - , bench "large (secp256k1NCG)" $ nf parseAsm largeSrc + [ bench "small (CurveNCG)" $ + nf parseAsm smallSrc + , bench "large (secp256k1NCG)" $ + nf parseAsm largeSrc ] , bgroup "cfg" - [ bench "small (CurveNCG)" $ nf buildCFG smallLines - , bench "large (secp256k1NCG)" $ nf buildCFG largeLines + [ bench "small (CurveNCG)" $ + nf (buildCFG rt) smallLines + , bench "large (secp256k1NCG)" $ + nf (buildCFG rt) largeLines ] , bgroup "taint" - [ bench "intra-small" $ nf (checkCFG "bench") smallCFG - , bench "intra-large" $ nf (checkCFG "bench") largeCFG - , bench "inter-small" $ nf (checkCFGInterProc "bench") smallCFG - , bench "inter-large" $ nf (checkCFGInterProc "bench") largeCFG + [ bench "intra-small" $ + nf (checkCFG rt "bench") smallCFG + , bench "intra-large" $ + nf (checkCFG rt "bench") largeCFG + , bench "inter-small" $ + nf (checkCFGInterProc rt "bench") smallCFG + , bench "inter-large" $ + nf (checkCFGInterProc rt "bench") largeCFG ] ] diff --git a/bench/Weight.hs b/bench/Weight.hs @@ -11,8 +11,11 @@ module Main where import Audit.AArch64.CFG (buildCFG, CFG) -import Audit.AArch64.Check (checkCFG, checkCFGInterProc) +import Audit.AArch64.Check + (checkCFG, checkCFGInterProc) import Audit.AArch64.Parser (parseAsm) +import Audit.AArch64.Runtime (RuntimeConfig) +import Audit.AArch64.Runtime.GHC (ghcRuntime) import Audit.AArch64.Types (Line) import qualified Data.ByteString as BS import Data.Text (Text) @@ -20,6 +23,9 @@ import Data.Text.Encoding (decodeUtf8) import System.IO.Unsafe (unsafePerformIO) import Weigh +rt :: RuntimeConfig +rt = ghcRuntime + main :: IO () main = mainWith $ do setColumns [Case, Allocated, GCs] @@ -31,26 +37,33 @@ main = mainWith $ do -- CFG benchmarks wgroup "cfg" $ do - func "small (CurveNCG)" buildCFG smallLines - func "large (secp256k1NCG)" buildCFG largeLines + func "small (CurveNCG)" (buildCFG rt) smallLines + func "large (secp256k1NCG)" + (buildCFG rt) largeLines -- Taint benchmarks wgroup "taint-intra" $ do - func "small (CurveNCG)" (checkCFG "bench") smallCFG - func "large (secp256k1NCG)" (checkCFG "bench") largeCFG + func "small (CurveNCG)" + (checkCFG rt "bench") smallCFG + func "large (secp256k1NCG)" + (checkCFG rt "bench") largeCFG wgroup "taint-inter" $ do - func "small (CurveNCG)" (checkCFGInterProc "bench") smallCFG - func "large (secp256k1NCG)" (checkCFGInterProc "bench") largeCFG + func "small (CurveNCG)" + (checkCFGInterProc rt "bench") smallCFG + func "large (secp256k1NCG)" + (checkCFGInterProc rt "bench") largeCFG -- Fixtures loaded at top-level (evaluated once) {-# NOINLINE smallSrc #-} smallSrc :: Text -smallSrc = unsafePerformIO $ loadFixture "etc/CurveNCG.s" +smallSrc = unsafePerformIO $ + loadFixture "etc/CurveNCG.s" {-# NOINLINE largeSrc #-} largeSrc :: Text -largeSrc = unsafePerformIO $ loadFixture "etc/secp256k1NCG.s" +largeSrc = unsafePerformIO $ + loadFixture "etc/secp256k1NCG.s" {-# NOINLINE smallLines #-} smallLines :: [Line] @@ -66,11 +79,11 @@ largeLines = case parseAsm largeSrc of {-# NOINLINE smallCFG #-} smallCFG :: CFG -smallCFG = buildCFG smallLines +smallCFG = buildCFG rt smallLines {-# NOINLINE largeCFG #-} largeCFG :: CFG -largeCFG = buildCFG largeLines +largeCFG = buildCFG rt largeLines loadFixture :: FilePath -> IO Text loadFixture path = decodeUtf8 <$> BS.readFile path diff --git a/lib/Audit/AArch64.hs b/lib/Audit/AArch64.hs @@ -9,10 +9,10 @@ -- -- AArch64 constant-time memory access auditor. -- --- This module provides static analysis for AArch64 GHC+LLVM assembly --- to verify that memory accesses use only public (non-secret-derived) --- addresses. This helps ensure constant-time properties for --- cryptographic code. +-- This module provides static analysis for AArch64 assembly +-- to verify that memory accesses use only public +-- (non-secret-derived) addresses. This helps ensure +-- constant-time properties for cryptographic code. -- -- Example usage: -- @@ -22,14 +22,20 @@ -- -- main = do -- src <- TIO.readFile "foo.s" --- case audit "foo.s" src of +-- case audit ghcRuntime "foo.s" src of -- Left err -> putStrLn $ "Parse error: " ++ show err -- Right result -> print result -- @ module Audit.AArch64 ( + -- * Runtime configuration + RuntimeConfig(..) + , SecondaryStack(..) + , ghcRuntime + , genericRuntime + -- * Main API - audit + , audit , auditInterProc , auditFile , auditFileInterProc @@ -46,11 +52,12 @@ module Audit.AArch64 ( , SymbolScanResult(..) , NctReason(..) , NctFinding(..) - -- ** GHC runtime classification + -- ** Runtime-aware filtering , LineMap , buildLineMap + , filterRuntimePatterns + -- ** GHC runtime classification (re-export) , isGhcRuntimeFinding - , filterGhcRuntime -- ** Call graph , CallGraph , buildCallGraph @@ -86,16 +93,22 @@ import Audit.AArch64.CallGraph , reachingSymbols , symbolExists ) -import qualified Audit.AArch64.CallGraph as CG (buildCallGraph) +import qualified Audit.AArch64.CallGraph as CG + (buildCallGraph) import Audit.AArch64.Check import Audit.AArch64.NCT - ( NctReason(..), NctFinding(..), scanNct - , LineMap, buildLineMap, isGhcRuntimeFinding, filterGhcRuntime + ( scanNct + , buildLineMap, filterRuntimePatterns ) import Audit.AArch64.Parser +import Audit.AArch64.Runtime + (RuntimeConfig(..), SecondaryStack(..)) +import Audit.AArch64.Runtime.GHC + (ghcRuntime, genericRuntime, isGhcRuntimeFinding) import Audit.AArch64.Types ( Reg, Violation(..), ViolationReason(..), regName , TaintConfig(..), ArgPolicy(..), emptyArgPolicy + , NctReason(..), NctFinding(..), LineMap ) import Data.Aeson (eitherDecodeStrict') import qualified Data.ByteString as BS @@ -106,40 +119,47 @@ import qualified Data.Text as T import Data.Text.Encoding (decodeUtf8') -- | Audit assembly source for memory access violations. -audit :: Text -> Text -> Either ParseError AuditResult -audit name src = do +audit :: RuntimeConfig -> Text -> Text + -> Either ParseError AuditResult +audit rt name src = do lns <- parseAsm src - let cfg = buildCFG lns - pure (checkCFG name cfg) + let cfg = buildCFG rt lns + pure (checkCFG rt name cfg) -- | Audit with inter-procedural analysis. -auditInterProc :: Text -> Text -> Either ParseError AuditResult -auditInterProc name src = do +auditInterProc :: RuntimeConfig -> Text -> Text + -> Either ParseError AuditResult +auditInterProc rt name src = do lns <- parseAsm src - let cfg = buildCFG lns - pure (checkCFGInterProc name cfg) + let cfg = buildCFG rt lns + pure (checkCFGInterProc rt name cfg) -- | Audit an assembly file. -auditFile :: FilePath -> IO (Either Text AuditResult) -auditFile = auditFileWith audit +auditFile :: RuntimeConfig -> FilePath + -> IO (Either Text AuditResult) +auditFile rt = auditFileWith (audit rt) -- | Audit an assembly file with inter-procedural analysis. -auditFileInterProc :: FilePath -> IO (Either Text AuditResult) -auditFileInterProc = auditFileWith auditInterProc +auditFileInterProc :: RuntimeConfig -> FilePath + -> IO (Either Text AuditResult) +auditFileInterProc rt = auditFileWith (auditInterProc rt) -- | Helper for file auditing. -auditFileWith :: (Text -> Text -> Either ParseError AuditResult) - -> FilePath -> IO (Either Text AuditResult) +auditFileWith + :: (Text -> Text -> Either ParseError AuditResult) + -> FilePath -> IO (Either Text AuditResult) auditFileWith auditor path = do bs <- BS.readFile path case decodeUtf8' bs of Left err -> pure (Left (T.pack (show err))) Right src -> case auditor (T.pack path) src of - Left err -> pure (Left (T.pack (showParseError err))) + Left err -> + pure (Left (T.pack (showParseError err))) Right result -> pure (Right result) --- | Parse an assembly file without analysis. Returns line count on success. +-- | Parse an assembly file without analysis. +-- Returns line count on success. parseFile :: FilePath -> IO (Either Text Int) parseFile path = do bs <- BS.readFile path @@ -147,32 +167,42 @@ parseFile path = do Left err -> pure (Left (T.pack (show err))) Right src -> case parseAsm src of - Left err -> pure (Left (T.pack (showParseError err))) + Left err -> + pure (Left (T.pack (showParseError err))) Right lns -> pure (Right (length lns)) -- | Audit assembly source with taint config. -auditWithConfig :: TaintConfig -> Text -> Text -> Either ParseError AuditResult -auditWithConfig tcfg name src = do +auditWithConfig :: RuntimeConfig -> TaintConfig + -> Text -> Text + -> Either ParseError AuditResult +auditWithConfig rt tcfg name src = do lns <- parseAsm src - let cfg = buildCFG lns - pure (checkCFGWithConfig tcfg name cfg) + let cfg = buildCFG rt lns + pure (checkCFGWithConfig rt tcfg name cfg) -- | Audit with inter-procedural analysis and taint config. -auditInterProcWithConfig :: TaintConfig -> Text -> Text - -> Either ParseError AuditResult -auditInterProcWithConfig tcfg name src = do +auditInterProcWithConfig + :: RuntimeConfig -> TaintConfig -> Text -> Text + -> Either ParseError AuditResult +auditInterProcWithConfig rt tcfg name src = do lns <- parseAsm src - let cfg = buildCFG lns - pure (checkCFGInterProcWithConfig tcfg name cfg) + let cfg = buildCFG rt lns + pure (checkCFGInterProcWithConfig rt tcfg name cfg) -- | Audit an assembly file with taint config. -auditFileWithConfig :: TaintConfig -> FilePath -> IO (Either Text AuditResult) -auditFileWithConfig tcfg = auditFileWith (auditWithConfig tcfg) +auditFileWithConfig :: RuntimeConfig -> TaintConfig + -> FilePath + -> IO (Either Text AuditResult) +auditFileWithConfig rt tcfg = + auditFileWith (auditWithConfig rt tcfg) --- | Audit an assembly file with inter-procedural analysis and taint config. -auditFileInterProcWithConfig :: TaintConfig -> FilePath - -> IO (Either Text AuditResult) -auditFileInterProcWithConfig tcfg = auditFileWith (auditInterProcWithConfig tcfg) +-- | Audit an assembly file with inter-procedural analysis +-- and taint config. +auditFileInterProcWithConfig + :: RuntimeConfig -> TaintConfig -> FilePath + -> IO (Either Text AuditResult) +auditFileInterProcWithConfig rt tcfg = + auditFileWith (auditInterProcWithConfig rt tcfg) -- | Load a taint config from a JSON file. loadTaintConfig :: FilePath -> IO (Either Text TaintConfig) @@ -182,21 +212,24 @@ loadTaintConfig path = do Left err -> pure (Left (T.pack err)) Right cfg -> pure (Right cfg) --- | Scan an assembly file for non-constant-time instructions. --- Returns a LineMap (for GHC runtime classification) and the findings. +-- | Scan an assembly file for non-constant-time +-- instructions. Returns a LineMap (for runtime +-- classification) and the findings. scanNctFile - :: FilePath - -> IO (Either Text (LineMap, Map.Map Text [NctFinding])) -scanNctFile path = do + :: RuntimeConfig -> FilePath + -> IO (Either Text + (LineMap, Map.Map Text [NctFinding])) +scanNctFile rt path = do bs <- BS.readFile path case decodeUtf8' bs of Left err -> pure (Left (T.pack (show err))) Right src -> case parseAsm src of - Left err -> pure (Left (T.pack (showParseError err))) + Left err -> + pure (Left (T.pack (showParseError err))) Right lns -> let lineMap = buildLineMap lns - findings = scanNct lns + findings = scanNct rt lns in pure (Right (lineMap, findings)) -- | Result of symbol-focused NCT scan. @@ -206,39 +239,46 @@ data SymbolScanResult = SymbolScanResult , ssrReachable :: !Int -- ^ Number of reachable symbols (including root) , ssrLineMap :: !LineMap - -- ^ Line map for GHC runtime classification + -- ^ Line map for runtime classification , ssrFindings :: !(Map.Map Text [NctFinding]) -- ^ Findings filtered to reachable symbols only } --- | Scan an assembly file for NCT instructions, focused on a specific symbol. +-- | Scan an assembly file for NCT instructions, focused +-- on a specific symbol. -- --- Uses call graph analysis to find all symbols reachable from the given --- root symbol and returns findings only for those symbols. +-- Uses call graph analysis to find all symbols reachable +-- from the given root symbol and returns findings only for +-- those symbols. -- --- Returns 'Left' if parsing fails or the symbol doesn't exist. +-- Returns 'Left' if parsing fails or the symbol doesn't +-- exist. scanNctFileForSymbol - :: Text -- ^ Root symbol to analyze + :: RuntimeConfig + -> Text -- ^ Root symbol to analyze -> FilePath -- ^ Assembly file path -> IO (Either Text SymbolScanResult) -scanNctFileForSymbol rootSym path = do +scanNctFileForSymbol rt rootSym path = do bs <- BS.readFile path case decodeUtf8' bs of Left err -> pure (Left (T.pack (show err))) Right src -> case parseAsm src of - Left err -> pure (Left (T.pack (showParseError err))) + Left err -> + pure (Left (T.pack (showParseError err))) Right lns -> do - let callGraph = CG.buildCallGraph lns + let callGraph = CG.buildCallGraph rt lns if not (symbolExists rootSym callGraph) - then pure (Left ("symbol not found: " <> rootSym)) + then pure (Left + ("symbol not found: " <> rootSym)) else do - let reachable = reachableSymbols rootSym callGraph + let reachable = + reachableSymbols rootSym callGraph lineMap = buildLineMap lns - allFindings = scanNct lns - -- Filter to only reachable symbols + allFindings = scanNct rt lns filtered = Map.filterWithKey - (\sym _ -> Set.member sym reachable) allFindings + (\sym _ -> Set.member sym reachable) + allFindings pure $ Right $ SymbolScanResult { ssrRootSymbol = rootSym , ssrReachable = Set.size reachable diff --git a/lib/Audit/AArch64/CFG.hs b/lib/Audit/AArch64/CFG.hs @@ -31,9 +31,9 @@ module Audit.AArch64.CFG ( , buildCallerMap ) where +import Audit.AArch64.Runtime (RuntimeConfig(..)) import Audit.AArch64.Types import Control.DeepSeq (NFData) -import Data.List (foldl') import Data.Map.Strict (Map) import Data.Maybe (listToMaybe) import qualified Data.Map.Strict as Map @@ -86,8 +86,8 @@ blockFunction cfg idx = , idx `elem` idxs ] -- | Build a CFG from parsed assembly lines. -buildCFG :: [Line] -> CFG -buildCFG lns = cfg +buildCFG :: RuntimeConfig -> [Line] -> CFG +buildCFG rt lns = cfg where -- Split into basic blocks at labels and control flow instructions rawBlocks = buildBlocks lns @@ -117,7 +117,7 @@ buildCFG lns = cfg blocks = A.arrayFromList annotatedBlocks -- Build function blocks map in a single pass - funcBlocksMap = buildFuncBlocksOnce annotatedBlocks + funcBlocksMap = buildFuncBlocksOnce rt annotatedBlocks cfg = CFG { cfgBlocks = blocks @@ -136,15 +136,17 @@ hasFallthroughInstr (Just instr) = case instr of _ -> True -- Conditional branches, calls fall through -- | Build function blocks map in a single pass over blocks. -buildFuncBlocksOnce :: [BasicBlock] -> Map Text [Int] -buildFuncBlocksOnce bbs = finalize $ foldl' step (Nothing, Map.empty) indexed +buildFuncBlocksOnce :: RuntimeConfig -> [BasicBlock] + -> Map Text [Int] +buildFuncBlocksOnce rt bbs = + finalize $ foldl' step (Nothing, Map.empty) indexed where nBlocks = length bbs indexed = zip [0..] bbs step (mCur, acc) (idx, bb) = case bbLabel bb of - Just lbl | isFunctionLabel lbl -> + Just lbl | isFunctionLabel rt lbl -> let acc' = closeCurrent idx mCur acc in (Just (lbl, idx), acc') _ -> (mCur, acc) @@ -236,27 +238,18 @@ blockSuccessors cfg idx | idx < 0 || idx >= cfgBlockCount cfg = [] | otherwise = bbSuccIdxs (indexBlock cfg idx) --- | Check if a label is an NCG-internal label (not a function entry). --- These start with _L followed by a lowercase letter (e.g. _Lblock_info). -isNCGInternal :: Text -> Bool -isNCGInternal lbl = case T.unpack lbl of - '_':'L':c:_ -> c >= 'a' && c <= 'z' - _ -> False - -- | Check if a label is a function entry (not a local label). --- Function labels start with _ or don't have LBB/Lloh/Lc prefixes. -isFunctionLabel :: Text -> Bool -isFunctionLabel lbl +-- LLVM local label checks are base; runtime-specific checks +-- are delegated to 'rtIsLocalLabel'. +isFunctionLabel :: RuntimeConfig -> Text -> Bool +isFunctionLabel rt lbl | T.isPrefixOf "LBB" lbl = False -- LLVM basic block | T.isPrefixOf "Lloh" lbl = False -- LLVM linker hint | T.isPrefixOf "LCPI" lbl = False -- LLVM constant pool | T.isPrefixOf "ltmp" lbl = False -- LLVM temporary | T.isPrefixOf "l_" lbl = False -- Local label - | T.isPrefixOf "Lc" lbl = False -- NCG local label (GHC) - | T.isPrefixOf "Ls" lbl = False -- NCG local label (GHC) - | T.isPrefixOf "Lu" lbl = False -- NCG local label (GHC) - | isNCGInternal lbl = False -- NCG internal label (GHC) - | otherwise = True -- Likely a function + | rtIsLocalLabel rt lbl = False -- Runtime-specific + | otherwise = True -- Likely a function -- | Get all function entry labels in the CFG. -- Returns keys from the cached function blocks map. @@ -271,24 +264,25 @@ functionBlocks cfg funcLabel = -- | Extract call targets from a block's instructions. -- Includes both bl (call with link) and b to function labels (tail call). -callTargets :: BasicBlock -> [Text] -callTargets bb = +callTargets :: RuntimeConfig -> BasicBlock -> [Text] +callTargets rt bb = [ target | l <- bbLines bb , Just instr <- [lineInstr l] , target <- getCallTarget instr ] where getCallTarget (Bl target) = [target] getCallTarget (B target) - | isFunctionLabel target = [target] -- Tail call + | isFunctionLabel rt target = [target] -- Tail call getCallTarget _ = [] -- | Build call graph: maps each function to its callees. -- Uses cached function blocks map for efficient lookup. -buildCallGraph :: CFG -> Map Text [Text] -buildCallGraph cfg = Map.fromList +buildCallGraph :: RuntimeConfig -> CFG -> Map Text [Text] +buildCallGraph rt cfg = Map.fromList [ (func, callees) | (func, indices) <- Map.toList (cfgFuncBlocks cfg) - , let callees = concatMap (callTargets . indexBlock cfg) indices + , let callees = concatMap + (callTargets rt . indexBlock cfg) indices ] -- | Build map of function labels to their block index ranges. @@ -298,14 +292,17 @@ buildFunctionBlocksMap = cfgFuncBlocks -- | Build caller map: maps each function to its callers. -- Takes precomputed function blocks map to avoid rescanning. -buildCallerMap :: CFG -> Map Text [Int] -> Map Text [Text] -buildCallerMap cfg funcBlocksMap = Map.fromListWith (++) - [ (callee, [caller]) - | (caller, callees) <- Map.toList callGraph - , callee <- callees - ] +buildCallerMap :: RuntimeConfig -> CFG -> Map Text [Int] + -> Map Text [Text] +buildCallerMap rt cfg funcBlocksMap = + Map.fromListWith (++) + [ (callee, [caller]) + | (caller, callees) <- Map.toList callGraph + , callee <- callees + ] where callGraph = Map.fromList - [ (func, concatMap (callTargets . indexBlock cfg) idxs) + [ (func, concatMap + (callTargets rt . indexBlock cfg) idxs) | (func, idxs) <- Map.toList funcBlocksMap ] diff --git a/lib/Audit/AArch64/CallGraph.hs b/lib/Audit/AArch64/CallGraph.hs @@ -22,9 +22,10 @@ module Audit.AArch64.CallGraph ( ) where import Audit.AArch64.CFG (isFunctionLabel) +import Audit.AArch64.Runtime (RuntimeConfig) import Audit.AArch64.Types (Instr(..), Line(..)) -import Data.Graph (Graph, Vertex, graphFromEdges, reachable, transposeG) -import Data.List (foldl') +import Data.Graph + (Graph, Vertex, graphFromEdges, reachable, transposeG) import Data.Map.Strict (Map) import qualified Data.Map.Strict as Map import Data.Set (Set) @@ -42,12 +43,14 @@ data CallGraph = CallGraph -- | Build a call graph from parsed assembly lines. -- --- Extracts function symbols and their call targets (bl instructions). --- Does not resolve indirect calls (blr). -buildCallGraph :: [Line] -> CallGraph -buildCallGraph lns = CallGraph graph nodeFromV vertexFromK allSyms +-- Extracts function symbols and their call targets +-- (bl instructions). Does not resolve indirect calls (blr). +buildCallGraph :: RuntimeConfig -> [Line] -> CallGraph +buildCallGraph rt lns = + CallGraph graph nodeFromV vertexFromK allSyms where - (graph, nodeFromV, vertexFromK) = graphFromEdges edges + (graph, nodeFromV, vertexFromK) = + graphFromEdges edges -- Build map from symbol to its instructions. -- State: (current symbol, accumulated map) @@ -55,15 +58,15 @@ buildCallGraph lns = CallGraph graph nodeFromV vertexFromK allSyms symInstrs = snd $ foldl' step ("<unknown>", Map.empty) lns where step (curSym, acc) ln = - let -- Update current symbol when we see a function label - sym = case lineLabel ln of - Just lbl | isFunctionLabel lbl -> lbl + let sym = case lineLabel ln of + Just lbl + | isFunctionLabel rt lbl -> lbl _ -> curSym - -- Ensure symbol exists in map (even with no instructions) acc' = Map.insertWith (++) sym [] acc in case lineInstr ln of Nothing -> (sym, acc') - Just i -> (sym, Map.insertWith (++) sym [i] acc') + Just i -> + (sym, Map.insertWith (++) sym [i] acc') -- Extract call targets from instructions callTargets :: Text -> [Text] @@ -77,9 +80,11 @@ buildCallGraph lns = CallGraph graph nodeFromV vertexFromK allSyms -- Build graph edges: (node data, key, [adjacent keys]) edges :: [((), Text, [Text])] - edges = [((), sym, callTargets sym) | sym <- Set.toList allSyms] + edges = + [((), sym, callTargets sym) | sym <- Set.toList allSyms] --- | Get all symbols reachable from a root symbol (including the root). +-- | Get all symbols reachable from a root symbol +-- (including the root). -- -- Returns empty set if the root symbol doesn't exist. reachableSymbols :: Text -> CallGraph -> Set Text @@ -89,16 +94,19 @@ reachableSymbols root cg = case cgVertexFromK cg root of [sym | v' <- reachable (cgGraph cg) v , let (_, sym, _) = cgNodeFromV cg v'] --- | Get all symbols that can reach a target symbol (including the target). +-- | Get all symbols that can reach a target symbol +-- (including the target). -- --- This is the reverse of 'reachableSymbols': finds all potential callers. --- Returns empty set if the target symbol doesn't exist. +-- This is the reverse of 'reachableSymbols': finds all +-- potential callers. Returns empty set if the target +-- symbol doesn't exist. reachingSymbols :: Text -> CallGraph -> Set Text -reachingSymbols target cg = case cgVertexFromK cg target of - Nothing -> Set.empty - Just v -> Set.fromList - [sym | v' <- reachable (transposeG (cgGraph cg)) v - , let (_, sym, _) = cgNodeFromV cg v'] +reachingSymbols target cg = + case cgVertexFromK cg target of + Nothing -> Set.empty + Just v -> Set.fromList + [sym | v' <- reachable (transposeG (cgGraph cg)) v + , let (_, sym, _) = cgNodeFromV cg v'] -- | Get all symbols in the call graph. allSymbols :: CallGraph -> Set Text diff --git a/lib/Audit/AArch64/Check.hs b/lib/Audit/AArch64/Check.hs @@ -12,8 +12,8 @@ -- -- Memory access validation for constant-time properties. -- --- Checks that all memory accesses use public base registers and --- constant (or properly masked) offsets. +-- Checks that all memory accesses use public base registers +-- and constant (or properly masked) offsets. module Audit.AArch64.Check ( checkLine @@ -25,8 +25,11 @@ module Audit.AArch64.Check ( , AuditResult(..) ) where -import Audit.AArch64.CFG (BasicBlock(..), CFG(..), cfgBlockCount, indexBlock, - blockFunction) +import Audit.AArch64.CFG + ( BasicBlock(..), CFG(..) + , cfgBlockCount, indexBlock, blockFunction + ) +import Audit.AArch64.Runtime (RuntimeConfig) import Audit.AArch64.Taint import Audit.AArch64.Types ( Reg(..), Instr(..), Line(..), AddrMode(..) @@ -50,9 +53,12 @@ data AuditResult = AuditResult instance Semigroup AuditResult where a <> b = AuditResult - { arViolations = arViolations a ++ arViolations b - , arLinesChecked = arLinesChecked a + arLinesChecked b - , arMemoryAccesses = arMemoryAccesses a + arMemoryAccesses b + { arViolations = + arViolations a ++ arViolations b + , arLinesChecked = + arLinesChecked a + arLinesChecked b + , arMemoryAccesses = + arMemoryAccesses a + arMemoryAccesses b } instance Monoid AuditResult where @@ -65,41 +71,52 @@ checkLine sym st l = case lineInstr l of Just instr -> case getMemoryAccess instr of Nothing -> AuditResult [] 1 0 Just addr -> - let violations = checkAddrMode sym (lineNum l) instr addr st + let violations = + checkAddrMode sym (lineNum l) instr addr st in AuditResult violations 1 1 -- | Check a basic block, threading taint state. --- Uses strict counters and cons-accumulation for violations. -checkBlock :: Text -> TaintState -> [Line] -> (AuditResult, TaintState) -checkBlock sym st0 lns = go [] 0 0 st0 lns +-- Uses strict counters and cons-accumulation for +-- violations. +checkBlock :: RuntimeConfig -> Text -> TaintState + -> [Line] -> (AuditResult, TaintState) +checkBlock rt sym st0 lns = go [] 0 0 st0 lns where go !vs !lc !ma st [] = (AuditResult (reverse vs) lc ma, st) go !vs !lc !ma st (l:ls) = let (vs', ma') = checkLineStrict sym st l - st' = analyzeLine l st - in go (foldl' (flip (:)) vs vs') (lc + 1) (ma + ma') st' ls + st' = analyzeLine rt l st + in go (foldl' (flip (:)) vs vs') + (lc + 1) (ma + ma') st' ls --- | Check a single line, returning violations and memory access count. -checkLineStrict :: Text -> TaintState -> Line -> ([Violation], Int) +-- | Check a single line, returning violations and +-- memory access count. +checkLineStrict :: Text -> TaintState -> Line + -> ([Violation], Int) checkLineStrict sym st l = case lineInstr l of Nothing -> ([], 0) Just instr -> case getMemoryAccess instr of Nothing -> ([], 0) - Just addr -> (checkAddrMode sym (lineNum l) instr addr st, 1) + Just addr -> + (checkAddrMode sym (lineNum l) instr addr st, 1) -- | Check entire CFG with inter-block dataflow. -- Runs fixpoint dataflow to propagate taint across blocks. -checkCFG :: Text -> CFG -> AuditResult -checkCFG sym cfg = - let inStates = runDataflow cfg +checkCFG :: RuntimeConfig -> Text -> CFG -> AuditResult +checkCFG rt sym cfg = + let inStates = runDataflow rt cfg nBlocks = cfgBlockCount cfg + baseState = initTaintState rt in mconcat - [ fst (checkBlock blockSym inState (bbLines bb)) + [ fst (checkBlock rt blockSym inState + (bbLines bb)) | idx <- [0..nBlocks-1] , let bb = indexBlock cfg idx blockSym = maybe sym id (bbLabel bb) - inState = IM.findWithDefault initTaintState idx inStates + inState = + IM.findWithDefault baseState + idx inStates ] -- | Extract memory address from a load/store instruction. @@ -142,7 +159,8 @@ getMemoryAccess instr = case instr of _ -> Nothing -- | Check an addressing mode for violations. -checkAddrMode :: Text -> Int -> Instr -> AddrMode -> TaintState -> [Violation] +checkAddrMode :: Text -> Int -> Instr -> AddrMode + -> TaintState -> [Violation] checkAddrMode sym ln instr addr st = case addr of BaseImm base _ -> checkBase sym ln instr base st @@ -160,7 +178,8 @@ checkAddrMode sym ln instr addr st = case addr of checkIndex sym ln instr idx st BaseSymbol base _ -> - -- Symbol offset is constant (compile-time), only check base + -- Symbol offset is constant (compile-time), + -- only check base checkBase sym ln instr base st PreIndex base _ -> @@ -174,56 +193,75 @@ checkAddrMode sym ln instr addr st = case addr of [] -- | Check that base register is public. --- If taint is Unknown, check provenance to see if we can upgrade to Public. --- Provenance upgrade is only allowed for pointer-kind registers to prevent --- laundering scalar indices through provenance. -checkBase :: Text -> Int -> Instr -> Reg -> TaintState -> [Violation] +-- If taint is Unknown, check provenance to see if we can +-- upgrade to Public. Provenance upgrade is only allowed +-- for pointer-kind registers to prevent laundering scalar +-- indices through provenance. +checkBase :: Text -> Int -> Instr -> Reg + -> TaintState -> [Violation] checkBase sym ln instr base st = case getTaint base st of Public -> [] - Secret -> [Violation sym ln instr (SecretBase base)] + Secret -> + [Violation sym ln instr (SecretBase base)] Unknown -> - -- Only upgrade via provenance if register has pointer kind + -- Only upgrade via provenance if register has + -- pointer kind case (getProvenance base st, getKind base st) of - (ProvPublic, KindPtr) -> [] -- Pointer provenance proves safety - _ -> [Violation sym ln instr (UnknownBase base)] + (ProvPublic, KindPtr) -> [] + _ -> + [Violation sym ln instr (UnknownBase base)] -- | Check that index register is public. --- If taint is Unknown, check provenance to see if we can upgrade to Public. --- Index registers should never be upgraded via provenance since they are --- typically scalars, not pointers. -checkIndex :: Text -> Int -> Instr -> Reg -> TaintState -> [Violation] +-- If taint is Unknown, check provenance to see if we can +-- upgrade to Public. Index registers should never be +-- upgraded via provenance since they are typically scalars, +-- not pointers. +checkIndex :: Text -> Int -> Instr -> Reg + -> TaintState -> [Violation] checkIndex sym ln instr idx st = case getTaint idx st of Public -> [] - Secret -> [Violation sym ln instr (SecretIndex idx)] - Unknown -> [Violation sym ln instr (UnknownIndex idx)] - -- No provenance upgrade for indices - they are scalar values + Secret -> + [Violation sym ln instr (SecretIndex idx)] + Unknown -> + [Violation sym ln instr (UnknownIndex idx)] + -- No provenance upgrade for indices -- | Check entire CFG with inter-procedural analysis. --- Uses whole-file dataflow with summaries to propagate taint across --- tail call boundaries. -checkCFGInterProc :: Text -> CFG -> AuditResult -checkCFGInterProc sym cfg = - let summaries = runInterProc cfg - -- Use empty config with whole-file dataflow +-- Uses whole-file dataflow with summaries to propagate +-- taint across tail call boundaries. +checkCFGInterProc :: RuntimeConfig -> Text -> CFG + -> AuditResult +checkCFGInterProc rt sym cfg = + let summaries = runInterProc rt cfg emptyConfig = TaintConfig Map.empty True - inStates = runDataflowWithConfigAndSummaries emptyConfig summaries cfg + inStates = + runDataflowWithConfigAndSummaries rt emptyConfig + summaries cfg nBlocks = cfgBlockCount cfg + baseState = initTaintState rt in mconcat - [ fst (checkBlockWithSummary blockSym summaries inState (bbLines bb)) + [ fst (checkBlockWithSummary rt blockSym + summaries inState (bbLines bb)) | idx <- [0..nBlocks-1] , let bb = indexBlock cfg idx - -- Use enclosing function label, not block label - blockSym = fromMaybe sym (blockFunction cfg idx) - inState = IM.findWithDefault initTaintState idx inStates + blockSym = + fromMaybe sym (blockFunction cfg idx) + inState = + IM.findWithDefault baseState + idx inStates ] -- | Check a block using summaries for calls. --- Uses strict counters and cons-accumulation for violations. -checkBlockWithSummary :: Text -> Map.Map Text FuncSummary -> TaintState - -> [Line] -> (AuditResult, TaintState) -checkBlockWithSummary sym summaries st0 lns = go [] 0 0 st0 lns +-- Uses strict counters and cons-accumulation for +-- violations. +checkBlockWithSummary + :: RuntimeConfig -> Text -> Map.Map Text FuncSummary + -> TaintState -> [Line] + -> (AuditResult, TaintState) +checkBlockWithSummary rt sym summaries st0 lns = + go [] 0 0 st0 lns where go !vs !lc !ma st [] = (AuditResult (reverse vs) lc ma, st) @@ -232,40 +270,55 @@ checkBlockWithSummary sym summaries st0 lns = go [] 0 0 st0 lns st' = case lineInstr l of Nothing -> st Just instr -> case instr of - Bl target -> case Map.lookup target summaries of - Just summ -> applySummary summ st - Nothing -> analyzeLine l st - _ -> analyzeLine l st - in go (foldl' (flip (:)) vs vs') (lc + 1) (ma + ma') st' ls + Bl target -> + case Map.lookup target summaries of + Just summ -> applySummary summ st + Nothing -> analyzeLine rt l st + _ -> analyzeLine rt l st + in go (foldl' (flip (:)) vs vs') + (lc + 1) (ma + ma') st' ls -- | Check entire CFG with taint config. -checkCFGWithConfig :: TaintConfig -> Text -> CFG -> AuditResult -checkCFGWithConfig tcfg sym cfg = - let inStates = runDataflowWithConfig tcfg cfg +checkCFGWithConfig :: RuntimeConfig -> TaintConfig + -> Text -> CFG -> AuditResult +checkCFGWithConfig rt tcfg sym cfg = + let inStates = runDataflowWithConfig rt tcfg cfg nBlocks = cfgBlockCount cfg + baseState = initTaintState rt in mconcat - [ fst (checkBlock blockSym inState (bbLines bb)) + [ fst (checkBlock rt blockSym inState + (bbLines bb)) | idx <- [0..nBlocks-1] , let bb = indexBlock cfg idx - -- Use enclosing function label, not block label - blockSym = fromMaybe sym (blockFunction cfg idx) - inState = IM.findWithDefault initTaintState idx inStates + blockSym = + fromMaybe sym (blockFunction cfg idx) + inState = + IM.findWithDefault baseState + idx inStates ] --- | Check entire CFG with inter-procedural analysis and taint config. --- Uses whole-file dataflow with summaries to properly propagate taint --- across tail call boundaries. -checkCFGInterProcWithConfig :: TaintConfig -> Text -> CFG -> AuditResult -checkCFGInterProcWithConfig tcfg sym cfg = - let summaries = runInterProcWithConfig tcfg cfg - -- Run whole-file dataflow with config seeding AND summary application - inStates = runDataflowWithConfigAndSummaries tcfg summaries cfg +-- | Check entire CFG with inter-procedural analysis +-- and taint config. Uses whole-file dataflow with +-- summaries to properly propagate taint across tail +-- call boundaries. +checkCFGInterProcWithConfig + :: RuntimeConfig -> TaintConfig -> Text -> CFG + -> AuditResult +checkCFGInterProcWithConfig rt tcfg sym cfg = + let summaries = runInterProcWithConfig rt tcfg cfg + inStates = + runDataflowWithConfigAndSummaries rt tcfg + summaries cfg nBlocks = cfgBlockCount cfg + baseState = initTaintState rt in mconcat - [ fst (checkBlockWithSummary blockSym summaries inState (bbLines bb)) + [ fst (checkBlockWithSummary rt blockSym + summaries inState (bbLines bb)) | idx <- [0..nBlocks-1] , let bb = indexBlock cfg idx - -- Use enclosing function label, not block label - blockSym = fromMaybe sym (blockFunction cfg idx) - inState = IM.findWithDefault initTaintState idx inStates + blockSym = + fromMaybe sym (blockFunction cfg idx) + inState = + IM.findWithDefault baseState + idx inStates ] diff --git a/lib/Audit/AArch64/NCT.hs b/lib/Audit/AArch64/NCT.hs @@ -1,6 +1,4 @@ {-# OPTIONS_HADDOCK prune #-} -{-# LANGUAGE DeriveAnyClass #-} -{-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE OverloadedStrings #-} -- | @@ -9,58 +7,37 @@ -- License: MIT -- Maintainer: jared@ppad.tech -- --- Static non-constant-time instruction scanner for AArch64 assembly. --- Flags instructions that typically introduce timing variability. +-- Static non-constant-time instruction scanner for AArch64 +-- assembly. Flags instructions that typically introduce +-- timing variability. module Audit.AArch64.NCT ( - -- * Types - NctReason(..) - , NctFinding(..) -- * Scanner - , scanNct - -- * GHC runtime classification - , LineMap + scanNct + -- * Runtime-aware filtering , buildLineMap - , isGhcRuntimeFinding - , filterGhcRuntime + , filterRuntimePatterns ) where import Audit.AArch64.CFG (isFunctionLabel) +import Audit.AArch64.Runtime (RuntimeConfig(..)) import Audit.AArch64.Types -import Control.DeepSeq (NFData) -import Data.IntMap.Strict (IntMap) -import Data.List (foldl') import qualified Data.IntMap.Strict as IntMap import Data.Map.Strict (Map) import qualified Data.Map.Strict as Map -import Data.Text (Text, isInfixOf, isPrefixOf, isSuffixOf) -import GHC.Generics (Generic) - --- | Reason for flagging an instruction as non-constant-time. -data NctReason - = CondBranch -- ^ Conditional branch (b.<cond>, cbz, cbnz, tbz, tbnz) - | IndirectBranch -- ^ Indirect branch (br, blr) - | Div -- ^ Division (udiv, sdiv) - | RegIndexAddr -- ^ Register-indexed memory access - deriving (Eq, Ord, Show, Generic, NFData) - --- | A non-constant-time finding. -data NctFinding = NctFinding - { nctLine :: !Int -- ^ Source line number - , nctInstr :: !Instr -- ^ The flagged instruction - , nctReason :: !NctReason -- ^ Why it was flagged - } deriving (Eq, Show, Generic, NFData) +import Data.Text (Text) -- | Scan parsed lines for non-constant-time instructions. -- Returns findings grouped by function symbol. -scanNct :: [Line] -> Map Text [NctFinding] -scanNct = finalize . foldl' step (unknownSym, Map.empty) +scanNct :: RuntimeConfig -> [Line] + -> Map Text [NctFinding] +scanNct rt = finalize . foldl' step (unknownSym, Map.empty) where unknownSym = "<unknown>" step (curSym, acc) ln = let sym' = case lineLabel ln of - Just lbl | isFunctionLabel lbl -> lbl + Just lbl | isFunctionLabel rt lbl -> lbl _ -> curSym findings = classifyLine ln acc' = case findings of @@ -76,7 +53,8 @@ classifyLine ln = case lineInstr ln of Nothing -> [] Just instr -> case classifyInstr instr of Nothing -> [] - Just reason -> [NctFinding (lineNum ln) instr reason] + Just reason -> + [NctFinding (lineNum ln) instr reason] -- | Classify an instruction for NCT concerns. classifyInstr :: Instr -> Maybe NctReason @@ -127,7 +105,7 @@ classifyInstr instr = case instr of Stxrb _ _ addr -> checkRegIndexAddr addr Stxrh _ _ addr -> checkRegIndexAddr addr - -- Acquire-exclusive loads and release-exclusive stores + -- Acquire-exclusive and release-exclusive Ldaxr _ addr -> checkRegIndexAddr addr Ldaxrb _ addr -> checkRegIndexAddr addr Ldaxrh _ addr -> checkRegIndexAddr addr @@ -140,327 +118,27 @@ classifyInstr instr = case instr of -- | Check if address mode uses register indexing. checkRegIndexAddr :: AddrMode -> Maybe NctReason checkRegIndexAddr addr = case addr of - BaseReg _ _ -> Just RegIndexAddr + BaseReg _ _ -> Just RegIndexAddr BaseRegShift _ _ _ -> Just RegIndexAddr BaseRegExtend _ _ _ -> Just RegIndexAddr _ -> Nothing --- | Line number to Line map for efficient lookup. -type LineMap = IntMap Line - -- | Build a line map from parsed lines for O(1) lookup. buildLineMap :: [Line] -> LineMap -buildLineMap lns = IntMap.fromList [(lineNum l, l) | l <- lns] - --- | Check if a finding is a GHC runtime pattern. --- --- Recognizes: --- --- * Heap checks: @cmp <reg>, x28@ followed by conditional branch --- (x28 is GHC's heap limit register) --- --- * Nursery checks: @cmp <reg>, [x19, #offset]@ followed by conditional --- branch (x19 is BaseReg, offset 856 is typical nursery limit) --- --- * Tag checks: @tst <reg>, #0x7@ followed by conditional branch --- (pointer tagging for evaluation state) --- --- * CAF checks: @cbz x0@ after @bl _newCAF@ (checking CAF init result) --- --- * Closure entry: @ldr/ldur <reg>, [...]@ followed by @blr <reg>@ --- (entering closure via info pointer) --- --- * Typeclass dictionary calls: @ldr <reg>, [<base>, #offset]@ followed --- by @blr <reg>@ (calling through dictionary) --- --- * RTS calls: @adrp <reg>, _stg_*@ (GHC runtime system symbols) --- followed eventually by @blr <reg>@ -isGhcRuntimeFinding :: LineMap -> NctFinding -> Bool -isGhcRuntimeFinding lineMap f = case nctReason f of - CondBranch -> isHeapCheck || isNurseryCheck || isTagCheck - || isCafCheck || isArityCheck - IndirectBranch -> isClosureEntry || isDictCall || isRtsCall - RegIndexAddr -> isClosureTableLookup - _ -> False - where - -- Check if conditional branch is a heap check (prev: cmp <r>, x28) - isHeapCheck :: Bool - isHeapCheck = case prevInstr of - Just (Cmp _ (OpReg X28)) -> True - _ -> False - - -- Check if conditional branch is a nursery check - -- (prev: cmp <r>, <r'> where one was loaded from BaseReg x19) - isNurseryCheck :: Bool - isNurseryCheck = case prevInstr of - Just (Cmp _ (OpReg _)) -> - -- Look for recent ldr from [x19, #offset] (BaseReg) - any isBaseRegLoad recentInstrs5 - _ -> False - - isBaseRegLoad :: Instr -> Bool - isBaseRegLoad instr = case instr of - Ldr _ (BaseImm X19 _) -> True - Ldur _ (BaseImm X19 _) -> True - _ -> False - - -- Check if conditional branch is a tag check - -- Pattern 1: tst <r>, #7 followed by branch (LLVM) - -- Pattern 2: and <r>, <r2>, #7; cmp <r>, #N followed by branch (LLVM) - -- Pattern 3: and <r>, <r2>, #7; cbnz/cbz <r> (NCG) - -- Pattern 4: and <r>, <r2>, #7; movz <r3>, #N; cmp <r>, <r3> (NCG) - isTagCheck :: Bool - isTagCheck = case prevInstr of - Just (Tst _ (OpImm 7)) -> True - Just (Cmp _ (OpImm _)) -> any isTagMask recentInstrs5 - Just (Cmp r1 (OpReg r2)) -> - -- NCG: cmp rA, rB where one was set by and #7 - any (isTagMaskFor r1) recentInstrs5 - || any (isTagMaskFor r2) recentInstrs5 - _ -> case nctInstr f of - -- NCG: cbnz/cbz on register set by and #7 - Cbnz reg _ -> any (isTagMaskFor reg) recentInstrs5 - Cbz reg _ -> any (isTagMaskFor reg) recentInstrs5 - _ -> False - - isTagMask :: Instr -> Bool - isTagMask instr = case instr of - And _ _ (OpImm 7) -> True - _ -> False - - isTagMaskFor :: Reg -> Instr -> Bool - isTagMaskFor reg instr = case instr of - And r _ (OpImm 7) -> r == reg - _ -> False - - -- Check if conditional branch is an arity/closure-type check - -- Pattern 1: ldur <r>, [<r2>, #-N]; cmp <r>, #M (LLVM) - -- Pattern 2: ldr <r>, [<r2>, #-N]; cmp <r>, <r3> (NCG) - -- Info tables are at negative offsets from info pointer - isArityCheck :: Bool - isArityCheck = case prevInstr of - Just (Cmp _ (OpImm _)) -> any isInfoTableLoad recentInstrs5 - Just (Cmp r1 (OpReg r2)) -> - -- NCG: cmp rA, rB where one was loaded from info table - any (isInfoTableLoadFor r1) recentInstrs5 - || any (isInfoTableLoadFor r2) recentInstrs5 - _ -> False - - isInfoTableLoad :: Instr -> Bool - isInfoTableLoad instr = case instr of - -- Negative offset indicates info table field access - Ldur _ (BaseImm _ off) -> off < 0 - Ldr _ (BaseImm _ off) -> off < 0 -- NCG uses ldr, not ldur - _ -> False - - isInfoTableLoadFor :: Reg -> Instr -> Bool - isInfoTableLoadFor reg instr = case instr of - -- Use samePhysicalReg since ldr w10 and cmp x10 are same register - Ldur r (BaseImm _ off) -> samePhysicalReg r reg && off < 0 - Ldr r (BaseImm _ off) -> samePhysicalReg r reg && off < 0 - _ -> False - - -- Check if cbz/cbnz is a CAF check (recent: bl _newCAF or similar) - -- LLVM: cbz x0 after bl _newCAF - -- NCG: mov xN, x0; cbz xN after bl _newCAF - isCafCheck :: Bool - isCafCheck = case nctInstr f of - Cbz X0 _ -> any isCafCall recentInstrs5 - Cbnz X0 _ -> any isCafCall recentInstrs5 - Cbz reg _ -> any isCafCall recentInstrs5 - && any (isCafResultMove reg) recentInstrs5 - Cbnz reg _ -> any isCafCall recentInstrs5 - && any (isCafResultMove reg) recentInstrs5 - _ -> False - - isCafCall :: Instr -> Bool - isCafCall instr = case instr of - Bl label -> "_newCAF" `isSuffixOf` label - _ -> False - - -- NCG moves x0 to another register before testing - isCafResultMove :: Reg -> Instr -> Bool - isCafResultMove reg instr = case instr of - Mov r (OpReg X0) -> r == reg - _ -> False - - -- Check if blr/br is closure entry (ldr/ldur to same reg in BB) - -- LLVM uses blr, NCG uses br - isClosureEntry :: Bool - isClosureEntry = case nctInstr f of - Blr reg -> any (isRegLoad reg) bbInstrs - Br reg -> any (isRegLoad reg) bbInstrs - _ -> False - - -- Check if blr/br is a dictionary/vtable call - -- (ldr from [base, #offset] pattern - calling through object field) - isDictCall :: Bool - isDictCall = case nctInstr f of - Blr reg -> any (isDictLoad reg) bbInstrs - Br reg -> any (isDictLoad reg) bbInstrs - _ -> False - - isRegLoad :: Reg -> Instr -> Bool - isRegLoad reg instr = case instr of - Ldr r (BaseImm _ _) -> r == reg - Ldr r (PreIndex _ _) -> r == reg -- [xN, #imm]! - Ldr r (PostIndex _ _) -> r == reg -- [xN], #imm - Ldur r (BaseImm _ _) -> r == reg - Ldp r1 r2 (BaseImm _ _) -> r1 == reg || r2 == reg - _ -> False - - isDictLoad :: Reg -> Instr -> Bool - isDictLoad reg instr = case instr of - -- ldr reg, [base, #offset] where offset > 0 (field access) - Ldr r (BaseImm _ off) -> r == reg && off > 0 - Ldr r (PreIndex _ off) -> r == reg && off > 0 - Ldr r (PostIndex _ off) -> r == reg && off > 0 - _ -> False - - -- Check if blr/br is calling a GHC RTS function - -- Scan backwards within basic block for adrp <reg>, _stg_* pattern - -- Also check for return continuation via STG stack (x20) or BaseReg (x19) - isRtsCall :: Bool - isRtsCall = case nctInstr f of - Blr reg -> any (isRtsSymbolLoad reg) bbInstrs - || any (isStgStackLoad reg) bbInstrs - Br reg -> any (isRtsSymbolLoad reg) bbInstrs - || any (isStgStackLoad reg) bbInstrs - _ -> False - - isRtsSymbolLoad :: Reg -> Instr -> Bool - isRtsSymbolLoad reg instr = case instr of - -- adrp to _stg_* symbol (GHC RTS) - Adrp r label -> r == reg && "_stg_" `isInfixOf` label - _ -> False - - -- Check if register was loaded from STG stack (x20) or BaseReg (x19) - -- These are return continuations / RTS dispatch - isStgStackLoad :: Reg -> Instr -> Bool - isStgStackLoad reg instr = case instr of - Ldr r (BaseImm X20 _) -> r == reg -- STG stack return - Ldr r (BaseImm X19 _) -> r == reg -- BaseReg dispatch - Ldur r (BaseImm X20 _) -> r == reg - Ldur r (BaseImm X19 _) -> r == reg - _ -> False - - -- Check if register-indexed access is a GHC closure table lookup - -- Pattern: adrp <base>, *_closure_tbl*; ldr <base>, [...]; ldr [base, idx] - -- Used for boxing Bool, Maybe constructors, etc. - isClosureTableLookup :: Bool - isClosureTableLookup = case getBaseReg (nctInstr f) of - Nothing -> False - Just reg -> any (isClosureTableLoad reg) bbInstrs - - -- Extract base register from a reg-indexed load/store - getBaseReg :: Instr -> Maybe Reg - getBaseReg instr = case instr of - Ldr _ addr -> baseOfRegIndex addr - Ldrb _ addr -> baseOfRegIndex addr - Ldrh _ addr -> baseOfRegIndex addr - Str _ addr -> baseOfRegIndex addr - Strb _ addr -> baseOfRegIndex addr - Strh _ addr -> baseOfRegIndex addr - _ -> Nothing - - baseOfRegIndex :: AddrMode -> Maybe Reg - baseOfRegIndex addr = case addr of - BaseReg base _ -> Just base - BaseRegShift base _ _ -> Just base - BaseRegExtend base _ _ -> Just base - _ -> Nothing - - -- Check if instruction loads closure table address into register - isClosureTableLoad :: Reg -> Instr -> Bool - isClosureTableLoad reg instr = case instr of - Adrp r label -> samePhysicalReg r reg && "_closure_tbl" `isInfixOf` label - _ -> False - - -- Get instruction from preceding line - prevInstr :: Maybe Instr - prevInstr = do - prevLine <- IntMap.lookup (nctLine f - 1) lineMap - lineInstr prevLine - - -- Get up to 5 recent instructions before finding - recentInstrs5 :: [Instr] - recentInstrs5 = recentInstrs 5 - - recentInstrs :: Int -> [Instr] - recentInstrs n = - [ instr - | offset <- [1..n] - , Just ln <- [IntMap.lookup (nctLine f - offset) lineMap] - , Just instr <- [lineInstr ln] - ] - - -- Get all instructions in current basic block before finding - -- Scans backwards until hitting a real BB label (LBB*) - -- Ignores linker hints (Lloh*, ltmp*) which aren't BB boundaries - bbInstrs :: [Instr] - bbInstrs = go (nctLine f - 1) - where - go lineNum - | lineNum <= 0 = [] - | otherwise = case IntMap.lookup lineNum lineMap of - Nothing -> [] - Just ln -> case lineLabel ln of - Just lbl | isBBLabel lbl -> [] -- Hit real BB, stop - _ -> case lineInstr ln of - Just instr -> instr : go (lineNum - 1) - Nothing -> go (lineNum - 1) -- Skip non-instruction lines - - -- Real BB labels start with "LBB", linker hints are Lloh*, ltmp*, etc. - isBBLabel :: Text -> Bool - isBBLabel lbl = "LBB" `isPrefixOf` lbl - --- | Filter out GHC runtime patterns from NCT findings. -filterGhcRuntime :: [Line] -> Map Text [NctFinding] -> Map Text [NctFinding] -filterGhcRuntime lns findings = Map.mapMaybe filterFindings findings +buildLineMap lns = + IntMap.fromList [(lineNum l, l) | l <- lns] + +-- | Filter out runtime-specific patterns from NCT findings. +filterRuntimePatterns + :: RuntimeConfig -> [Line] + -> Map Text [NctFinding] -> Map Text [NctFinding] +filterRuntimePatterns rt lns findings = + Map.mapMaybe filterFindings findings where lineMap = buildLineMap lns + filterFn = rtFilterNct rt filterFindings :: [NctFinding] -> Maybe [NctFinding] filterFindings fs = - let fs' = filter (not . isGhcRuntimeFinding lineMap) fs + let fs' = filter (not . filterFn lineMap) fs in if null fs' then Nothing else Just fs' - --- | Check if two registers refer to the same physical register. --- W and X variants of the same number are the same physical register. -samePhysicalReg :: Reg -> Reg -> Bool -samePhysicalReg r1 r2 = r1 == r2 || physNum r1 == physNum r2 - where - physNum :: Reg -> Maybe Int - physNum r = case r of - X0 -> Just 0; W0 -> Just 0 - X1 -> Just 1; W1 -> Just 1 - X2 -> Just 2; W2 -> Just 2 - X3 -> Just 3; W3 -> Just 3 - X4 -> Just 4; W4 -> Just 4 - X5 -> Just 5; W5 -> Just 5 - X6 -> Just 6; W6 -> Just 6 - X7 -> Just 7; W7 -> Just 7 - X8 -> Just 8; W8 -> Just 8 - X9 -> Just 9; W9 -> Just 9 - X10 -> Just 10; W10 -> Just 10 - X11 -> Just 11; W11 -> Just 11 - X12 -> Just 12; W12 -> Just 12 - X13 -> Just 13; W13 -> Just 13 - X14 -> Just 14; W14 -> Just 14 - X15 -> Just 15; W15 -> Just 15 - X16 -> Just 16; W16 -> Just 16 - X17 -> Just 17; W17 -> Just 17 - X18 -> Just 18; W18 -> Just 18 - X19 -> Just 19; W19 -> Just 19 - X20 -> Just 20; W20 -> Just 20 - X21 -> Just 21; W21 -> Just 21 - X22 -> Just 22; W22 -> Just 22 - X23 -> Just 23; W23 -> Just 23 - X24 -> Just 24; W24 -> Just 24 - X25 -> Just 25; W25 -> Just 25 - X26 -> Just 26; W26 -> Just 26 - X27 -> Just 27; W27 -> Just 27 - X28 -> Just 28; W28 -> Just 28 - X29 -> Just 29; W29 -> Just 29 - X30 -> Just 30; W30 -> Just 30 - _ -> Nothing -- SP, XZR, WZR, SIMD regs don't match -\ No newline at end of file diff --git a/lib/Audit/AArch64/Runtime.hs b/lib/Audit/AArch64/Runtime.hs @@ -0,0 +1,58 @@ +{-# OPTIONS_HADDOCK prune #-} + +-- | +-- Module: Audit.AArch64.Runtime +-- Copyright: (c) 2025 Jared Tobin +-- License: MIT +-- Maintainer: jared@ppad.tech +-- +-- Runtime configuration for AArch64 constant-time auditing. +-- +-- Parameterises GHC/STG-specific analysis logic so the auditor +-- can support multiple runtimes (GHC, Rust, Go, C). + +module Audit.AArch64.Runtime ( + RuntimeConfig(..) + , SecondaryStack(..) + ) where + +import Audit.AArch64.Types + (Reg, LineMap, NctFinding) +import Data.Text (Text) + +-- | Runtime-specific configuration for the auditor. +-- +-- Selected once at CLI parse time and threaded through +-- analysis. All runtime-varying behaviour is captured here. +data RuntimeConfig = RuntimeConfig + { rtPublicRoots :: ![Reg] + -- ^ Registers assumed public at function entry + , rtSecondaryStack :: !(Maybe SecondaryStack) + -- ^ Secondary stack configuration (e.g. GHC's STG + -- stack via X20). Nothing for runtimes without one. + , rtIsLocalLabel :: !(Text -> Bool) + -- ^ Runtime-specific local label predicate (e.g. + -- GHC NCG prefixes Lc, Ls, Lu) + , rtUntagMasks :: ![Integer] + -- ^ Pointer untagging masks to whitelist (e.g. + -- GHC's low-3-bit tag clearing) + , rtFilterNct :: !(LineMap -> NctFinding -> Bool) + -- ^ Predicate for runtime-specific NCT patterns + -- that should be filtered out + , rtEncodeSymbol + :: !(Maybe (Text -> Either Text Text)) + -- ^ Optional symbol encoder (e.g. GHC z-encoding) + } + +-- | Secondary stack configuration. +-- +-- Some runtimes maintain a separate stack (e.g. GHC's STG +-- stack pointed to by X20). This record captures the base +-- register and default assumption for untracked slots. +data SecondaryStack = SecondaryStack + { ssBaseReg :: !Reg + -- ^ Register holding the secondary stack pointer + , ssAssumePublic :: !Bool + -- ^ Default assumption for untracked slots (True + -- for GHC: STG stack holds closure pointers) + } diff --git a/lib/Audit/AArch64/Runtime/GHC.hs b/lib/Audit/AArch64/Runtime/GHC.hs @@ -0,0 +1,416 @@ +{-# OPTIONS_HADDOCK prune #-} +{-# LANGUAGE OverloadedStrings #-} + +-- | +-- Module: Audit.AArch64.Runtime.GHC +-- Copyright: (c) 2025 Jared Tobin +-- License: MIT +-- Maintainer: jared@ppad.tech +-- +-- GHC runtime configuration for the AArch64 auditor. +-- +-- Provides 'ghcRuntime' which reproduces all current +-- GHC-specific behaviour, and 'genericRuntime' as a +-- baseline for C/Rust/Go runtimes. + +module Audit.AArch64.Runtime.GHC ( + -- * Runtime configurations + ghcRuntime + , genericRuntime + -- * GHC-specific utilities + , ghcIsLocalLabel + , isGhcRuntimeFinding + , zEncodeSymbol + , zEncodePart + ) where + +import Audit.AArch64.Runtime + (RuntimeConfig(..), SecondaryStack(..)) +import Audit.AArch64.Types +import qualified Data.IntMap.Strict as IntMap +import Data.Text (Text) +import qualified Data.Text as T + +-- | GHC runtime configuration. +-- +-- Reproduces all current GHC/STG-specific behaviour: +-- +-- * Public roots include GHC's STG registers +-- * Secondary stack via X20 (STG Sp) +-- * NCG local label filtering +-- * Pointer untagging masks +-- * GHC runtime NCT pattern filtering +-- * Z-encoding for symbol lookup +ghcRuntime :: RuntimeConfig +ghcRuntime = RuntimeConfig + { rtPublicRoots = ghcPublicRoots + , rtSecondaryStack = + Just (SecondaryStack X20 True) + , rtIsLocalLabel = ghcIsLocalLabel + , rtUntagMasks = + [ 0xfffffffffffffff8 -- 64-bit mask + , 0xfffffff8 -- 32-bit mask + , -8 -- signed representation + ] + , rtFilterNct = isGhcRuntimeFinding + , rtEncodeSymbol = Just zEncodeSymbol + } + +-- | Generic runtime configuration. +-- +-- Baseline with no secondary stack, no untagging, no NCT +-- filtering. Useful for C/Rust/Go as a starting point. +genericRuntime :: RuntimeConfig +genericRuntime = RuntimeConfig + { rtPublicRoots = + [ SP, X29 -- Hardware stack/frame pointers + , XZR, WZR -- Zero registers + ] + , rtSecondaryStack = Nothing + , rtIsLocalLabel = const False + , rtUntagMasks = [] + , rtFilterNct = \_ _ -> False + , rtEncodeSymbol = Nothing + } + +-- | GHC 9.10.3 AArch64 public root registers. +ghcPublicRoots :: [Reg] +ghcPublicRoots = + [ SP, X29 -- Hardware stack/frame pointers + , X19 -- GHC BaseReg (capabilities/TSO pointer) + , X20 -- GHC Sp (STG stack pointer) + , X21 -- GHC Hp (heap pointer) + , X22 -- GHC HpLim (heap limit) + , X28 -- GHC SpLim (stack limit) + , X18 -- TLS (Darwin platform register) + , XZR, WZR -- Zero registers + ] + +-- | Check if a label is a GHC NCG-specific local label. +-- +-- NCG local labels start with Lc, Ls, Lu, or _L followed +-- by a lowercase letter. +ghcIsLocalLabel :: Text -> Bool +ghcIsLocalLabel lbl + | T.isPrefixOf "Lc" lbl = True + | T.isPrefixOf "Ls" lbl = True + | T.isPrefixOf "Lu" lbl = True + | isNCGInternal lbl = True + | otherwise = False + +-- | Check if a label is an NCG-internal label. +-- These start with _L followed by a lowercase letter. +isNCGInternal :: Text -> Bool +isNCGInternal lbl = case T.unpack lbl of + '_':'L':c:_ -> c >= 'a' && c <= 'z' + _ -> False + +-- | Check if a finding is a GHC runtime pattern. +-- +-- Recognizes heap checks, tag checks, CAF checks, +-- closure entry, dictionary calls, RTS calls, and +-- closure table lookups. +isGhcRuntimeFinding :: LineMap -> NctFinding -> Bool +isGhcRuntimeFinding lineMap f = case nctReason f of + CondBranch -> isHeapCheck || isNurseryCheck + || isTagCheck || isCafCheck + || isArityCheck + IndirectBranch -> isClosureEntry || isDictCall + || isRtsCall + RegIndexAddr -> isClosureTableLookup + _ -> False + where + -- Heap check: cmp <r>, x28 + isHeapCheck :: Bool + isHeapCheck = case prevInstr of + Just (Cmp _ (OpReg X28)) -> True + _ -> False + + -- Nursery check: cmp + recent ldr from [x19, #offset] + isNurseryCheck :: Bool + isNurseryCheck = case prevInstr of + Just (Cmp _ (OpReg _)) -> + any isBaseRegLoad recentInstrs5 + _ -> False + + isBaseRegLoad :: Instr -> Bool + isBaseRegLoad instr = case instr of + Ldr _ (BaseImm X19 _) -> True + Ldur _ (BaseImm X19 _) -> True + _ -> False + + -- Tag check patterns + isTagCheck :: Bool + isTagCheck = case prevInstr of + Just (Tst _ (OpImm 7)) -> True + Just (Cmp _ (OpImm _)) -> + any isTagMask recentInstrs5 + Just (Cmp r1 (OpReg r2)) -> + any (isTagMaskFor r1) recentInstrs5 + || any (isTagMaskFor r2) recentInstrs5 + _ -> case nctInstr f of + Cbnz reg _ -> + any (isTagMaskFor reg) recentInstrs5 + Cbz reg _ -> + any (isTagMaskFor reg) recentInstrs5 + _ -> False + + isTagMask :: Instr -> Bool + isTagMask instr = case instr of + And _ _ (OpImm 7) -> True + _ -> False + + isTagMaskFor :: Reg -> Instr -> Bool + isTagMaskFor reg instr = case instr of + And r _ (OpImm 7) -> r == reg + _ -> False + + -- Arity/closure-type check + isArityCheck :: Bool + isArityCheck = case prevInstr of + Just (Cmp _ (OpImm _)) -> + any isInfoTableLoad recentInstrs5 + Just (Cmp r1 (OpReg r2)) -> + any (isInfoTableLoadFor r1) recentInstrs5 + || any (isInfoTableLoadFor r2) recentInstrs5 + _ -> False + + isInfoTableLoad :: Instr -> Bool + isInfoTableLoad instr = case instr of + Ldur _ (BaseImm _ off) -> off < 0 + Ldr _ (BaseImm _ off) -> off < 0 + _ -> False + + isInfoTableLoadFor :: Reg -> Instr -> Bool + isInfoTableLoadFor reg instr = case instr of + Ldur r (BaseImm _ off) -> + samePhysicalReg r reg && off < 0 + Ldr r (BaseImm _ off) -> + samePhysicalReg r reg && off < 0 + _ -> False + + -- CAF check: cbz/cbnz x0 after bl _newCAF + isCafCheck :: Bool + isCafCheck = case nctInstr f of + Cbz X0 _ -> any isCafCall recentInstrs5 + Cbnz X0 _ -> any isCafCall recentInstrs5 + Cbz reg _ -> any isCafCall recentInstrs5 + && any (isCafResultMove reg) recentInstrs5 + Cbnz reg _ -> any isCafCall recentInstrs5 + && any (isCafResultMove reg) recentInstrs5 + _ -> False + + isCafCall :: Instr -> Bool + isCafCall instr = case instr of + Bl label -> "_newCAF" `T.isSuffixOf` label + _ -> False + + isCafResultMove :: Reg -> Instr -> Bool + isCafResultMove reg instr = case instr of + Mov r (OpReg X0) -> r == reg + _ -> False + + -- Closure entry: ldr/ldur to reg, then blr/br reg + isClosureEntry :: Bool + isClosureEntry = case nctInstr f of + Blr reg -> any (isRegLoad reg) bbInstrs + Br reg -> any (isRegLoad reg) bbInstrs + _ -> False + + -- Dictionary/vtable call + isDictCall :: Bool + isDictCall = case nctInstr f of + Blr reg -> any (isDictLoad reg) bbInstrs + Br reg -> any (isDictLoad reg) bbInstrs + _ -> False + + isRegLoad :: Reg -> Instr -> Bool + isRegLoad reg instr = case instr of + Ldr r (BaseImm _ _) -> r == reg + Ldr r (PreIndex _ _) -> r == reg + Ldr r (PostIndex _ _) -> r == reg + Ldur r (BaseImm _ _) -> r == reg + Ldp r1 r2 (BaseImm _ _) -> + r1 == reg || r2 == reg + _ -> False + + isDictLoad :: Reg -> Instr -> Bool + isDictLoad reg instr = case instr of + Ldr r (BaseImm _ off) -> + r == reg && off > 0 + Ldr r (PreIndex _ off) -> + r == reg && off > 0 + Ldr r (PostIndex _ off) -> + r == reg && off > 0 + _ -> False + + -- RTS call: adrp <reg>, _stg_* or STG stack load + isRtsCall :: Bool + isRtsCall = case nctInstr f of + Blr reg -> any (isRtsSymbolLoad reg) bbInstrs + || any (isStgStackLoad reg) bbInstrs + Br reg -> any (isRtsSymbolLoad reg) bbInstrs + || any (isStgStackLoad reg) bbInstrs + _ -> False + + isRtsSymbolLoad :: Reg -> Instr -> Bool + isRtsSymbolLoad reg instr = case instr of + Adrp r label -> + r == reg && "_stg_" `T.isInfixOf` label + _ -> False + + isStgStackLoad :: Reg -> Instr -> Bool + isStgStackLoad reg instr = case instr of + Ldr r (BaseImm X20 _) -> r == reg + Ldr r (BaseImm X19 _) -> r == reg + Ldur r (BaseImm X20 _) -> r == reg + Ldur r (BaseImm X19 _) -> r == reg + _ -> False + + -- Closure table lookup + isClosureTableLookup :: Bool + isClosureTableLookup = case getBaseReg (nctInstr f) of + Nothing -> False + Just reg -> + any (isClosureTableLoad reg) bbInstrs + + getBaseReg :: Instr -> Maybe Reg + getBaseReg instr = case instr of + Ldr _ addr -> baseOfRegIndex addr + Ldrb _ addr -> baseOfRegIndex addr + Ldrh _ addr -> baseOfRegIndex addr + Str _ addr -> baseOfRegIndex addr + Strb _ addr -> baseOfRegIndex addr + Strh _ addr -> baseOfRegIndex addr + _ -> Nothing + + baseOfRegIndex :: AddrMode -> Maybe Reg + baseOfRegIndex addr = case addr of + BaseReg base _ -> Just base + BaseRegShift base _ _ -> Just base + BaseRegExtend base _ _ -> Just base + _ -> Nothing + + isClosureTableLoad :: Reg -> Instr -> Bool + isClosureTableLoad reg instr = case instr of + Adrp r label -> + samePhysicalReg r reg + && "_closure_tbl" `T.isInfixOf` label + _ -> False + + -- Helpers for line context + prevInstr :: Maybe Instr + prevInstr = do + prevLine <- IntMap.lookup + (nctLine f - 1) lineMap + lineInstr prevLine + + recentInstrs5 :: [Instr] + recentInstrs5 = recentInstrs 5 + + recentInstrs :: Int -> [Instr] + recentInstrs n = + [ instr + | offset <- [1..n] + , Just ln <- [IntMap.lookup + (nctLine f - offset) lineMap] + , Just instr <- [lineInstr ln] + ] + + -- All instructions in current basic block + bbInstrs :: [Instr] + bbInstrs = go (nctLine f - 1) + where + go lnum + | lnum <= 0 = [] + | otherwise = + case IntMap.lookup lnum lineMap of + Nothing -> [] + Just ln -> case lineLabel ln of + Just lbl | isBBLabel lbl -> [] + _ -> case lineInstr ln of + Just instr -> + instr : go (lnum - 1) + Nothing -> go (lnum - 1) + + isBBLabel :: Text -> Bool + isBBLabel lbl = "LBB" `T.isPrefixOf` lbl + +-- | Check if two registers refer to the same physical +-- register. W and X variants of the same number are the +-- same physical register. +samePhysicalReg :: Reg -> Reg -> Bool +samePhysicalReg r1 r2 = + r1 == r2 || physNum r1 == physNum r2 + where + physNum :: Reg -> Maybe Int + physNum r = case r of + X0 -> Just 0; W0 -> Just 0 + X1 -> Just 1; W1 -> Just 1 + X2 -> Just 2; W2 -> Just 2 + X3 -> Just 3; W3 -> Just 3 + X4 -> Just 4; W4 -> Just 4 + X5 -> Just 5; W5 -> Just 5 + X6 -> Just 6; W6 -> Just 6 + X7 -> Just 7; W7 -> Just 7 + X8 -> Just 8; W8 -> Just 8 + X9 -> Just 9; W9 -> Just 9 + X10 -> Just 10; W10 -> Just 10 + X11 -> Just 11; W11 -> Just 11 + X12 -> Just 12; W12 -> Just 12 + X13 -> Just 13; W13 -> Just 13 + X14 -> Just 14; W14 -> Just 14 + X15 -> Just 15; W15 -> Just 15 + X16 -> Just 16; W16 -> Just 16 + X17 -> Just 17; W17 -> Just 17 + X18 -> Just 18; W18 -> Just 18 + X19 -> Just 19; W19 -> Just 19 + X20 -> Just 20; W20 -> Just 20 + X21 -> Just 21; W21 -> Just 21 + X22 -> Just 22; W22 -> Just 22 + X23 -> Just 23; W23 -> Just 23 + X24 -> Just 24; W24 -> Just 24 + X25 -> Just 25; W25 -> Just 25 + X26 -> Just 26; W26 -> Just 26 + X27 -> Just 27; W27 -> Just 27 + X28 -> Just 28; W28 -> Just 28 + X29 -> Just 29; W29 -> Just 29 + X30 -> Just 30; W30 -> Just 30 + _ -> Nothing + +-- | Z-encode a human-readable Haskell symbol for GHC +-- assembly lookup. +-- +-- Input format: @\<package\>:\<Module.Path\>:\<identifier\>@ +-- +-- Output: @_\<z-pkg\>_\<z-mod\>_\<z-id\>_info$def@ +zEncodeSymbol :: Text -> Either Text Text +zEncodeSymbol input = + case T.splitOn ":" input of + [pkg, modPath, ident] -> + let encoded = T.intercalate "_" + [ zEncodePart pkg + , zEncodePart modPath + , zEncodePart ident + ] + in Right ("_" <> encoded <> "_info$def") + parts -> + Left $ "Invalid symbol format: expected " + <> "<package>:<Module.Path>:<id>, got " + <> T.pack (show (length parts)) + <> " parts" + +-- | Z-encode a single component. +zEncodePart :: Text -> Text +zEncodePart = T.concatMap encodeChar + where + encodeChar c = case c of + '-' -> "zm" + '.' -> "zi" + '_' -> "zu" + 'z' -> "zz" + 'Z' -> "ZZ" + '$' -> "zd" + '\'' -> "zq" + '#' -> "zh" + _ -> T.singleton c diff --git a/lib/Audit/AArch64/Taint.hs b/lib/Audit/AArch64/Taint.hs @@ -24,7 +24,6 @@ module Audit.AArch64.Taint ( , setTaint , getProvenance , getKind - , publicRoots , joinTaintState , runDataflow , runDataflowWithConfig @@ -50,6 +49,8 @@ import Audit.AArch64.CFG , cfgBlockCount, indexBlock , isFunctionLabel ) +import Audit.AArch64.Runtime + (RuntimeConfig(..), SecondaryStack(..)) import Audit.AArch64.Types ( Reg(..), Instr(..), Line(..), Operand(..), AddrMode(..) , Taint(..), joinTaint, Provenance(..), joinProvenance @@ -68,7 +69,7 @@ import Data.Text (Text) -- | Taint state: maps registers to their publicness, plus stack slots. -- Also tracks provenance for upgrading Unknown to Public when provable, -- and register kinds (pointer vs scalar) for safe provenance upgrades. --- Tracks both hardware stack (SP) and STG stack (X20) slots separately. +-- Tracks both hardware stack (SP) and secondary stack slots separately. -- A coarse heap bucket captures taint for non-stack memory accesses. -- A refined heap slot map tracks [base, #imm] accesses for public pointers. data TaintState = TaintState @@ -85,11 +86,11 @@ data TaintState = TaintState , tsStackKind :: !(IntMap RegKind) -- ^ Stack slot kinds (keyed by SP offset) , tsStgStack :: !(IntMap Taint) - -- ^ STG stack slot taints (keyed by X20 offset) + -- ^ Secondary stack slot taints (keyed by base reg offset) , tsStgStackProv :: !(IntMap Provenance) - -- ^ STG stack slot provenance (keyed by X20 offset) + -- ^ Secondary stack slot provenance , tsStgStackKind :: !(IntMap RegKind) - -- ^ STG stack slot kinds (keyed by X20 offset) + -- ^ Secondary stack slot kinds , tsHeapTaint :: !Taint -- ^ Coarse heap taint bucket (joined from all non-stack stores) , tsHeapProv :: !Provenance @@ -99,40 +100,20 @@ data TaintState = TaintState , tsHeapSlots :: !(Map (Reg, Int) (Taint, Provenance, RegKind)) -- ^ Refined heap slots keyed by (base register, offset) , tsAssumeStgPublic :: !Bool - -- ^ Assume untracked STG stack slots are public pointers + -- ^ Assume untracked secondary stack slots are public pointers } deriving (Eq, Show) --- | GHC 9.10.3 AArch64 public root registers. --- --- These are known to be public (derived from stack/heap pointers or --- constants): --- --- - SP: hardware stack pointer --- - X29/FP: frame pointer --- - X19: GHC Base register --- - X20: GHC Sp (STG stack pointer) --- - X21: GHC Hp (heap pointer) --- - X28: GHC SpLim (stack limit) --- - X18: TLS/platform register (Darwin) --- - XZR/WZR: zero registers -publicRoots :: [Reg] -publicRoots = - [ SP, X29 -- Hardware stack/frame pointers - , X19 -- GHC BaseReg (capabilities/TSO pointer) - , X20 -- GHC Sp (STG stack pointer) - , X21 -- GHC Hp (heap pointer) - , X22 -- GHC HpLim (heap limit) - , X28 -- GHC SpLim (stack limit) - , X18 -- TLS (Darwin platform register) - , XZR, WZR -- Zero registers - ] +-- | Check if a register is a secondary stack base register. +isSecStackBase :: RuntimeConfig -> Reg -> Bool +isSecStackBase rt r = case rtSecondaryStack rt of + Nothing -> False + Just ss -> ssBaseReg ss == r --- | Empty taint state (no known taints). --- Assumes STG stack slots are public by default. -emptyTaintState :: TaintState -emptyTaintState = emptyTaintStateWith True +-- | Check if a register is in the public roots. +isPublicRoot :: RuntimeConfig -> Reg -> Bool +isPublicRoot rt r = r `elem` rtPublicRoots rt --- | Empty taint state with configurable STG assumption. +-- | Empty taint state with configurable secondary stack assumption. emptyTaintStateWith :: Bool -> TaintState emptyTaintStateWith assumeStgPublic = TaintState { tsRegs = Map.empty @@ -154,44 +135,56 @@ emptyTaintStateWith assumeStgPublic = TaintState -- | Initial taint state with public roots marked. -- Public roots are marked as Ptr kind (they are pointers). -- Heap bucket starts Unknown (conservative for untracked memory). --- Assumes STG stack slots are public by default. -initTaintState :: TaintState -initTaintState = initTaintStateWith True - --- | Initial taint state with configurable STG assumption. -initTaintStateWith :: Bool -> TaintState -initTaintStateWith assumeStgPublic = TaintState - { tsRegs = Map.fromList [(r, Public) | r <- publicRoots] - , tsStack = IM.empty - , tsProv = Map.fromList [(r, ProvPublic) | r <- publicRoots] - , tsStackProv = IM.empty - , tsKind = Map.fromList [(r, KindPtr) | r <- publicRoots] - , tsStackKind = IM.empty - , tsStgStack = IM.empty - , tsStgStackProv = IM.empty - , tsStgStackKind = IM.empty - , tsHeapTaint = Unknown - , tsHeapProv = ProvUnknown - , tsHeapKind = KindUnknown - , tsHeapSlots = Map.empty - , tsAssumeStgPublic = assumeStgPublic - } - --- | Seed argument registers and STG stack slots according to policy. --- Secret registers are marked Secret with ProvUnknown. --- Public registers are marked Public with ProvPublic. --- Secret pointee registers are marked Public/ProvSecretData/KindPtr. --- If a register/slot appears in multiple lists, secret takes precedence. -seedArgs :: ArgPolicy -> TaintState -> TaintState -seedArgs policy st = - let -- Apply public first, then secret pointee, then secret (so secret wins) +initTaintState :: RuntimeConfig -> TaintState +initTaintState rt = initTaintStateWith rt assumeStg + where + assumeStg = case rtSecondaryStack rt of + Nothing -> False + Just ss -> ssAssumePublic ss + +-- | Initial taint state with configurable secondary stack assumption. +initTaintStateWith :: RuntimeConfig -> Bool -> TaintState +initTaintStateWith rt assumeStgPublic = + let roots = rtPublicRoots rt + in TaintState + { tsRegs = Map.fromList + [(r, Public) | r <- roots] + , tsStack = IM.empty + , tsProv = Map.fromList + [(r, ProvPublic) | r <- roots] + , tsStackProv = IM.empty + , tsKind = Map.fromList + [(r, KindPtr) | r <- roots] + , tsStackKind = IM.empty + , tsStgStack = IM.empty + , tsStgStackProv = IM.empty + , tsStgStackKind = IM.empty + , tsHeapTaint = Unknown + , tsHeapProv = ProvUnknown + , tsHeapKind = KindUnknown + , tsHeapSlots = Map.empty + , tsAssumeStgPublic = assumeStgPublic + } + +-- | Seed argument registers and secondary stack slots +-- according to policy. +seedArgs :: RuntimeConfig -> ArgPolicy -> TaintState + -> TaintState +seedArgs rt policy st = + let -- Apply public first, then secret pointee, then secret st1 = Set.foldr markPublic st (apPublic policy) - st2 = Set.foldr markSecretPointee st1 (apSecretPointee policy) + st2 = Set.foldr markSecretPointee st1 + (apSecretPointee policy) st3 = Set.foldr markSecret st2 (apSecret policy) - -- Seed STG stack slots - st4 = Set.foldr markStgPublic st3 (apStgPublic policy) - st5 = Set.foldr markStgSecret st4 (apStgSecret policy) - in st5 + -- Seed secondary stack slots (conditional) + st4 = case rtSecondaryStack rt of + Nothing -> st3 + Just _ -> + let st4' = Set.foldr markStgPublic st3 + (apStgPublic policy) + in Set.foldr markStgSecret st4' + (apStgSecret policy) + in st4 where markPublic r s = s { tsRegs = Map.insert r Public (tsRegs s) @@ -207,14 +200,20 @@ seedArgs policy st = , tsProv = Map.insert r ProvUnknown (tsProv s) } markStgPublic off s = s - { tsStgStack = IM.insert off Public (tsStgStack s) - , tsStgStackProv = IM.insert off ProvPublic (tsStgStackProv s) - , tsStgStackKind = IM.insert off KindScalar (tsStgStackKind s) + { tsStgStack = + IM.insert off Public (tsStgStack s) + , tsStgStackProv = + IM.insert off ProvPublic (tsStgStackProv s) + , tsStgStackKind = + IM.insert off KindScalar (tsStgStackKind s) } markStgSecret off s = s - { tsStgStack = IM.insert off Secret (tsStgStack s) - , tsStgStackProv = IM.insert off ProvUnknown (tsStgStackProv s) - , tsStgStackKind = IM.insert off KindScalar (tsStgStackKind s) + { tsStgStack = + IM.insert off Secret (tsStgStack s) + , tsStgStackProv = + IM.insert off ProvUnknown (tsStgStackProv s) + , tsStgStackKind = + IM.insert off KindScalar (tsStgStackKind s) } -- | Get the taint of a register. @@ -223,266 +222,312 @@ getTaint r st = Map.findWithDefault Unknown r (tsRegs st) -- | Get the provenance of a register. getProvenance :: Reg -> TaintState -> Provenance -getProvenance r st = Map.findWithDefault ProvUnknown r (tsProv st) +getProvenance r st = + Map.findWithDefault ProvUnknown r (tsProv st) -- | Get the kind of a register. getKind :: Reg -> TaintState -> RegKind -getKind r st = Map.findWithDefault KindUnknown r (tsKind st) +getKind r st = + Map.findWithDefault KindUnknown r (tsKind st) -- | Analyze a single line, updating taint state. -analyzeLine :: Line -> TaintState -> TaintState -analyzeLine l st = case lineInstr l of +analyzeLine :: RuntimeConfig -> Line -> TaintState + -> TaintState +analyzeLine rt l st = case lineInstr l of Nothing -> st - Just instr -> transfer instr st + Just instr -> transfer rt instr st -- | Analyze a basic block, threading taint state through. -analyzeBlock :: [Line] -> TaintState -> TaintState -analyzeBlock lns st = foldl' (flip analyzeLine) st lns +analyzeBlock :: RuntimeConfig -> [Line] -> TaintState + -> TaintState +analyzeBlock rt lns st = + foldl' (flip (analyzeLine rt)) st lns -- | Transfer function for taint analysis. --- --- For each instruction, determine how it affects register taints. --- Also tracks provenance for upgrading Unknown bases. -transfer :: Instr -> TaintState -> TaintState -transfer instr st = case instr of +transfer :: RuntimeConfig -> Instr -> TaintState + -> TaintState +transfer rt instr st = case instr of -- Move: destination gets source taint, provenance, and kind Mov dst op -> - setTaintProvKind dst (operandTaint op st) (operandProv op st) - (operandKind op st) st + setTaintProvKind dst (operandTaint op st) + (operandProv op st) (operandKind op st) st Movz dst _ _ -> - setTaintProvKind dst Public ProvPublic KindScalar st -- Immediate is scalar - Movk _ _ _ -> st -- Keeps existing value, modifies bits + setTaintProvKind dst Public ProvPublic KindScalar st + Movk _ _ _ -> st Movn dst _ _ -> - setTaintProvKind dst Public ProvPublic KindScalar st -- Immediate is scalar + setTaintProvKind dst Public ProvPublic KindScalar st - -- Arithmetic: result is join of operand taints/provenances - -- Clear stack map if SP is modified (offsets become stale) - -- For X20 with immediate operand, shift STG stack map instead of clearing - -- Pointer arithmetic (ptr + imm) preserves pointer kind + -- Arithmetic Add dst r1 op -> let t = join2 (getTaint r1 st) (operandTaint op st) - p = provJoin2 (getProvenance r1 st) (operandProv op st) + p = provJoin2 (getProvenance r1 st) + (operandProv op st) k = pointerArithKind r1 op st - in updateWithX20ShiftAdd dst r1 op t p k st + in updateWithSecStackShiftAdd rt dst r1 op t p k st Sub dst r1 op -> let t = join2 (getTaint r1 st) (operandTaint op st) - p = provJoin2 (getProvenance r1 st) (operandProv op st) + p = provJoin2 (getProvenance r1 st) + (operandProv op st) k = pointerArithKind r1 op st - in updateWithX20ShiftSub dst r1 op t p k st + in updateWithSecStackShiftSub rt dst r1 op t p k st Adds dst r1 op -> let t = join2 (getTaint r1 st) (operandTaint op st) - p = provJoin2 (getProvenance r1 st) (operandProv op st) + p = provJoin2 (getProvenance r1 st) + (operandProv op st) k = pointerArithKind r1 op st - in updateWithX20ShiftAdd dst r1 op t p k st + in updateWithSecStackShiftAdd rt dst r1 op t p k st Subs dst r1 op -> let t = join2 (getTaint r1 st) (operandTaint op st) - p = provJoin2 (getProvenance r1 st) (operandProv op st) + p = provJoin2 (getProvenance r1 st) + (operandProv op st) k = pointerArithKind r1 op st - in updateWithX20ShiftSub dst r1 op t p k st + in updateWithSecStackShiftSub rt dst r1 op t p k st Adc dst r1 r2 -> let t = join2 (getTaint r1 st) (getTaint r2 st) - p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + p = provJoin2 (getProvenance r1 st) + (getProvenance r2 st) in setTaintProvKind dst t p KindScalar st Adcs dst r1 r2 -> let t = join2 (getTaint r1 st) (getTaint r2 st) - p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + p = provJoin2 (getProvenance r1 st) + (getProvenance r2 st) in setTaintProvKind dst t p KindScalar st Sbc dst r1 r2 -> let t = join2 (getTaint r1 st) (getTaint r2 st) - p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + p = provJoin2 (getProvenance r1 st) + (getProvenance r2 st) in setTaintProvKind dst t p KindScalar st Neg dst op -> - setTaintProvKind dst (operandTaint op st) (operandProv op st) KindScalar st + setTaintProvKind dst (operandTaint op st) + (operandProv op st) KindScalar st Negs dst op -> - setTaintProvKind dst (operandTaint op st) (operandProv op st) KindScalar st + setTaintProvKind dst (operandTaint op st) + (operandProv op st) KindScalar st Mul dst r1 r2 -> let t = join2 (getTaint r1 st) (getTaint r2 st) - p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + p = provJoin2 (getProvenance r1 st) + (getProvenance r2 st) in setTaintProvKind dst t p KindScalar st Mneg dst r1 r2 -> let t = join2 (getTaint r1 st) (getTaint r2 st) - p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + p = provJoin2 (getProvenance r1 st) + (getProvenance r2 st) in setTaintProvKind dst t p KindScalar st Madd dst r1 r2 r3 -> - let t = join3 (getTaint r1 st) (getTaint r2 st) (getTaint r3 st) - p = provJoin3 (getProvenance r1 st) (getProvenance r2 st) - (getProvenance r3 st) + let t = join3 (getTaint r1 st) (getTaint r2 st) + (getTaint r3 st) + p = provJoin3 (getProvenance r1 st) + (getProvenance r2 st) + (getProvenance r3 st) in setTaintProvKind dst t p KindScalar st Msub dst r1 r2 r3 -> - let t = join3 (getTaint r1 st) (getTaint r2 st) (getTaint r3 st) - p = provJoin3 (getProvenance r1 st) (getProvenance r2 st) - (getProvenance r3 st) + let t = join3 (getTaint r1 st) (getTaint r2 st) + (getTaint r3 st) + p = provJoin3 (getProvenance r1 st) + (getProvenance r2 st) + (getProvenance r3 st) in setTaintProvKind dst t p KindScalar st Umulh dst r1 r2 -> let t = join2 (getTaint r1 st) (getTaint r2 st) - p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + p = provJoin2 (getProvenance r1 st) + (getProvenance r2 st) in setTaintProvKind dst t p KindScalar st Smulh dst r1 r2 -> let t = join2 (getTaint r1 st) (getTaint r2 st) - p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + p = provJoin2 (getProvenance r1 st) + (getProvenance r2 st) in setTaintProvKind dst t p KindScalar st Udiv dst r1 r2 -> let t = join2 (getTaint r1 st) (getTaint r2 st) - p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + p = provJoin2 (getProvenance r1 st) + (getProvenance r2 st) in setTaintProvKind dst t p KindScalar st Sdiv dst r1 r2 -> let t = join2 (getTaint r1 st) (getTaint r2 st) - p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + p = provJoin2 (getProvenance r1 st) + (getProvenance r2 st) in setTaintProvKind dst t p KindScalar st - -- Logical: result is join of operand taints/provenances - -- Special case: pointer untagging preserves ProvPublic and KindPtr + -- Logical: special case for pointer untagging And dst r1 op -> let t = join2 (getTaint r1 st) (operandTaint op st) - p = provJoin2 (getProvenance r1 st) (operandProv op st) + p = provJoin2 (getProvenance r1 st) + (operandProv op st) srcProv = getProvenance r1 st srcKind = getKind r1 st - isUntag = isPointerUntagMask op - -- Pointer untagging: preserve provenance and kind if source was Ptr - p' = if isUntag && srcProv == ProvPublic then ProvPublic else p - k = if isUntag && srcKind == KindPtr then KindPtr else KindScalar + isUntag = isPointerUntagMask rt op + p' = if isUntag && srcProv == ProvPublic + then ProvPublic else p + k = if isUntag && srcKind == KindPtr + then KindPtr else KindScalar in setTaintProvKind dst t p' k st Orr dst r1 op -> let t = join2 (getTaint r1 st) (operandTaint op st) - p = provJoin2 (getProvenance r1 st) (operandProv op st) + p = provJoin2 (getProvenance r1 st) + (operandProv op st) in setTaintProvKind dst t p KindScalar st Eor dst r1 op -> let t = join2 (getTaint r1 st) (operandTaint op st) - p = provJoin2 (getProvenance r1 st) (operandProv op st) + p = provJoin2 (getProvenance r1 st) + (operandProv op st) in setTaintProvKind dst t p KindScalar st Bic dst r1 op -> let t = join2 (getTaint r1 st) (operandTaint op st) - p = provJoin2 (getProvenance r1 st) (operandProv op st) + p = provJoin2 (getProvenance r1 st) + (operandProv op st) in setTaintProvKind dst t p KindScalar st Mvn dst op -> - setTaintProvKind dst (operandTaint op st) (operandProv op st) KindScalar st - Tst _ _ -> st -- No destination + setTaintProvKind dst (operandTaint op st) + (operandProv op st) KindScalar st + Tst _ _ -> st -- Shifts: result is scalar Lsl dst r1 op -> let t = join2 (getTaint r1 st) (operandTaint op st) - p = provJoin2 (getProvenance r1 st) (operandProv op st) + p = provJoin2 (getProvenance r1 st) + (operandProv op st) in setTaintProvKind dst t p KindScalar st Lsr dst r1 op -> let t = join2 (getTaint r1 st) (operandTaint op st) - p = provJoin2 (getProvenance r1 st) (operandProv op st) + p = provJoin2 (getProvenance r1 st) + (operandProv op st) in setTaintProvKind dst t p KindScalar st Asr dst r1 op -> let t = join2 (getTaint r1 st) (operandTaint op st) - p = provJoin2 (getProvenance r1 st) (operandProv op st) + p = provJoin2 (getProvenance r1 st) + (operandProv op st) in setTaintProvKind dst t p KindScalar st Ror dst r1 op -> let t = join2 (getTaint r1 st) (operandTaint op st) - p = provJoin2 (getProvenance r1 st) (operandProv op st) + p = provJoin2 (getProvenance r1 st) + (operandProv op st) in setTaintProvKind dst t p KindScalar st -- Bit manipulation: result is scalar Ubfx dst r1 _ _ -> - setTaintProvKind dst (getTaint r1 st) (getProvenance r1 st) KindScalar st + setTaintProvKind dst (getTaint r1 st) + (getProvenance r1 st) KindScalar st Sbfx dst r1 _ _ -> - setTaintProvKind dst (getTaint r1 st) (getProvenance r1 st) KindScalar st + setTaintProvKind dst (getTaint r1 st) + (getProvenance r1 st) KindScalar st Bfi dst r1 _ _ -> let t = join2 (getTaint dst st) (getTaint r1 st) - p = provJoin2 (getProvenance dst st) (getProvenance r1 st) + p = provJoin2 (getProvenance dst st) + (getProvenance r1 st) in setTaintProvKind dst t p KindScalar st Extr dst r1 r2 _ -> let t = join2 (getTaint r1 st) (getTaint r2 st) - p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + p = provJoin2 (getProvenance r1 st) + (getProvenance r2 st) in setTaintProvKind dst t p KindScalar st - -- Address generation: result is public pointer (constant pool / PC-relative) - Adr dst _ -> setTaintProvKind dst Public ProvPublic KindPtr st - Adrp dst _ -> setTaintProvKind dst Public ProvPublic KindPtr st - - -- Loads: restore from stack slots if [sp, #imm], else Unknown - -- Exception: public roots stay public (GHC spills/restores them) - Ldr dst addr -> loadFromStack dst addr st - Ldrb dst addr -> loadFromStack dst addr st - Ldrh dst addr -> loadFromStack dst addr st - Ldrsb dst addr -> loadFromStack dst addr st - Ldrsh dst addr -> loadFromStack dst addr st - Ldrsw dst addr -> loadFromStack dst addr st - Ldur dst addr -> loadFromStack dst addr st - Ldp dst1 dst2 addr -> loadPairFromStack dst1 dst2 addr st - - -- Stores: track stack slot taints for [sp, #imm] patterns - Str src addr -> storeToStack src addr st - Strb src addr -> storeToStack src addr st - Strh src addr -> storeToStack src addr st - Stur src addr -> storeToStack src addr st - Stp src1 src2 addr -> storePairToStack src1 src2 addr st - - -- Acquire/release loads (same as regular loads) - Ldar dst addr -> loadFromStack dst addr st - Ldarb dst addr -> loadFromStack dst addr st - Ldarh dst addr -> loadFromStack dst addr st - - -- Release stores (same as regular stores) - Stlr src addr -> storeToStack src addr st - Stlrb src addr -> storeToStack src addr st - Stlrh src addr -> storeToStack src addr st - - -- Exclusive loads (same as regular loads) - Ldxr dst addr -> loadFromStack dst addr st - Ldxrb dst addr -> loadFromStack dst addr st - Ldxrh dst addr -> loadFromStack dst addr st - - -- Exclusive stores: status reg is Public scalar, src participates in store + -- Address generation + Adr dst _ -> + setTaintProvKind dst Public ProvPublic KindPtr st + Adrp dst _ -> + setTaintProvKind dst Public ProvPublic KindPtr st + + -- Loads + Ldr dst addr -> loadFromStack rt dst addr st + Ldrb dst addr -> loadFromStack rt dst addr st + Ldrh dst addr -> loadFromStack rt dst addr st + Ldrsb dst addr -> loadFromStack rt dst addr st + Ldrsh dst addr -> loadFromStack rt dst addr st + Ldrsw dst addr -> loadFromStack rt dst addr st + Ldur dst addr -> loadFromStack rt dst addr st + Ldp dst1 dst2 addr -> + loadPairFromStack rt dst1 dst2 addr st + + -- Stores + Str src addr -> storeToStack rt src addr st + Strb src addr -> storeToStack rt src addr st + Strh src addr -> storeToStack rt src addr st + Stur src addr -> storeToStack rt src addr st + Stp src1 src2 addr -> + storePairToStack rt src1 src2 addr st + + -- Acquire/release loads + Ldar dst addr -> loadFromStack rt dst addr st + Ldarb dst addr -> loadFromStack rt dst addr st + Ldarh dst addr -> loadFromStack rt dst addr st + + -- Release stores + Stlr src addr -> storeToStack rt src addr st + Stlrb src addr -> storeToStack rt src addr st + Stlrh src addr -> storeToStack rt src addr st + + -- Exclusive loads + Ldxr dst addr -> loadFromStack rt dst addr st + Ldxrb dst addr -> loadFromStack rt dst addr st + Ldxrh dst addr -> loadFromStack rt dst addr st + + -- Exclusive stores Stxr status src addr -> - let st' = storeToStack src addr st - in setTaintProvKind status Public ProvPublic KindScalar st' + let st' = storeToStack rt src addr st + in setTaintProvKind status Public ProvPublic + KindScalar st' Stxrb status src addr -> - let st' = storeToStack src addr st - in setTaintProvKind status Public ProvPublic KindScalar st' + let st' = storeToStack rt src addr st + in setTaintProvKind status Public ProvPublic + KindScalar st' Stxrh status src addr -> - let st' = storeToStack src addr st - in setTaintProvKind status Public ProvPublic KindScalar st' + let st' = storeToStack rt src addr st + in setTaintProvKind status Public ProvPublic + KindScalar st' - -- Acquire-exclusive loads (same as regular loads) - Ldaxr dst addr -> loadFromStack dst addr st - Ldaxrb dst addr -> loadFromStack dst addr st - Ldaxrh dst addr -> loadFromStack dst addr st + -- Acquire-exclusive loads + Ldaxr dst addr -> loadFromStack rt dst addr st + Ldaxrb dst addr -> loadFromStack rt dst addr st + Ldaxrh dst addr -> loadFromStack rt dst addr st - -- Release-exclusive stores: status reg is Public scalar + -- Release-exclusive stores Stlxr status src addr -> - let st' = storeToStack src addr st - in setTaintProvKind status Public ProvPublic KindScalar st' + let st' = storeToStack rt src addr st + in setTaintProvKind status Public ProvPublic + KindScalar st' Stlxrb status src addr -> - let st' = storeToStack src addr st - in setTaintProvKind status Public ProvPublic KindScalar st' + let st' = storeToStack rt src addr st + in setTaintProvKind status Public ProvPublic + KindScalar st' Stlxrh status src addr -> - let st' = storeToStack src addr st - in setTaintProvKind status Public ProvPublic KindScalar st' + let st' = storeToStack rt src addr st + in setTaintProvKind status Public ProvPublic + KindScalar st' - -- Conditionals: result is scalar (conservative) + -- Conditionals Cmp _ _ -> st Cmn _ _ -> st Csel dst r1 r2 _ -> let t = join2 (getTaint r1 st) (getTaint r2 st) - p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + p = provJoin2 (getProvenance r1 st) + (getProvenance r2 st) k = joinKind (getKind r1 st) (getKind r2 st) in setTaintProvKind dst t p k st Csinc dst r1 r2 _ -> let t = join2 (getTaint r1 st) (getTaint r2 st) - p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + p = provJoin2 (getProvenance r1 st) + (getProvenance r2 st) in setTaintProvKind dst t p KindScalar st Csinv dst r1 r2 _ -> let t = join2 (getTaint r1 st) (getTaint r2 st) - p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + p = provJoin2 (getProvenance r1 st) + (getProvenance r2 st) in setTaintProvKind dst t p KindScalar st Csneg dst r1 r2 _ -> let t = join2 (getTaint r1 st) (getTaint r2 st) - p = provJoin2 (getProvenance r1 st) (getProvenance r2 st) + p = provJoin2 (getProvenance r1 st) + (getProvenance r2 st) in setTaintProvKind dst t p KindScalar st Cset dst _ -> - setTaintProvKind dst Public ProvPublic KindScalar st -- Condition flag + setTaintProvKind dst Public ProvPublic KindScalar st Cinc dst r1 _ -> - setTaintProvKind dst (getTaint r1 st) (getProvenance r1 st) KindScalar st + setTaintProvKind dst (getTaint r1 st) + (getProvenance r1 st) KindScalar st - -- Branches: no register change + -- Branches B _ -> st BCond _ _ -> st - Bl _ -> invalidateCallerSaved st -- Call may clobber + Bl _ -> invalidateCallerSaved st Blr _ -> invalidateCallerSaved st Br _ -> st Ret _ -> st @@ -495,7 +540,7 @@ transfer instr st = case instr of Nop -> st Svc _ -> invalidateCallerSaved st - -- Unknown instruction: conservative + -- Unknown instruction Other _ _ -> st -- | Set taint for a register. @@ -503,97 +548,100 @@ setTaint :: Reg -> Taint -> TaintState -> TaintState setTaint r t st = st { tsRegs = Map.insert r t (tsRegs st) } -- | Set both taint and provenance for a register. -setTaintProv :: Reg -> Taint -> Provenance -> TaintState -> TaintState +setTaintProv :: Reg -> Taint -> Provenance -> TaintState + -> TaintState setTaintProv r t p st = st { tsRegs = Map.insert r t (tsRegs st) , tsProv = Map.insert r p (tsProv st) } -- | Set taint, provenance, and kind for a register. -setTaintProvKind :: Reg -> Taint -> Provenance -> RegKind -> TaintState - -> TaintState +setTaintProvKind :: Reg -> Taint -> Provenance -> RegKind + -> TaintState -> TaintState setTaintProvKind r t p k st = st { tsRegs = Map.insert r t (tsRegs st) , tsProv = Map.insert r p (tsProv st) , tsKind = Map.insert r k (tsKind st) - -- Clear heap slots keyed by this register (they refer to old value) - , tsHeapSlots = Map.filterWithKey (\(base, _) _ -> base /= r) (tsHeapSlots st) + , tsHeapSlots = Map.filterWithKey + (\(base, _) _ -> base /= r) (tsHeapSlots st) } -- | Set taint for a loaded value from the heap bucket. --- Public roots always get Public/ProvPublic/KindPtr. --- Other destinations get the coarse heap taint. -setTaintLoadHeap :: Reg -> TaintState -> TaintState -setTaintLoadHeap dst st - | isPublicRoot dst = setTaintProvKind dst Public ProvPublic KindPtr st - | otherwise = setTaintProvKind dst (tsHeapTaint st) (tsHeapProv st) - (tsHeapKind st) st - where - isPublicRoot r = r `elem` publicRoots - --- | Set taint for a loaded value from a refined heap slot if available. --- If the slot (base, offset) exists in the refined map, use it. --- Otherwise fall back to the coarse heap bucket. --- Public roots always get Public/ProvPublic/KindPtr. -setTaintLoadHeapSlot :: Reg -> Reg -> Int -> TaintState -> TaintState -setTaintLoadHeapSlot dst base off st - | isPublicRoot dst = setTaintProvKind dst Public ProvPublic KindPtr st +setTaintLoadHeap :: RuntimeConfig -> Reg -> TaintState + -> TaintState +setTaintLoadHeap rt dst st + | isPublicRoot rt dst = + setTaintProvKind dst Public ProvPublic KindPtr st + | otherwise = + setTaintProvKind dst (tsHeapTaint st) (tsHeapProv st) + (tsHeapKind st) st + +-- | Set taint for a loaded value from a refined heap slot. +setTaintLoadHeapSlot :: RuntimeConfig -> Reg -> Reg -> Int + -> TaintState -> TaintState +setTaintLoadHeapSlot rt dst base off st + | isPublicRoot rt dst = + setTaintProvKind dst Public ProvPublic KindPtr st | otherwise = case getHeapSlot base off st of Just (t, p, k) -> setTaintProvKind dst t p k st - Nothing -> setTaintProvKind dst (tsHeapTaint st) (tsHeapProv st) - (tsHeapKind st) st - where - isPublicRoot r = r `elem` publicRoots + Nothing -> setTaintProvKind dst (tsHeapTaint st) + (tsHeapProv st) (tsHeapKind st) st -- | Set taint for a loaded value from a known stack slot. --- If we have tracked taint/provenance/kind at this offset, use it; else Unknown. --- Public roots always get Public/ProvPublic/KindPtr. -setTaintLoadStack :: Reg -> Int -> TaintState -> TaintState -setTaintLoadStack dst offset st - | isPublicRoot dst = setTaintProvKind dst Public ProvPublic KindPtr st +setTaintLoadStack :: RuntimeConfig -> Reg -> Int + -> TaintState -> TaintState +setTaintLoadStack rt dst offset st + | isPublicRoot rt dst = + setTaintProvKind dst Public ProvPublic KindPtr st | otherwise = - let taint = IM.findWithDefault Unknown offset (tsStack st) - prov = IM.findWithDefault ProvUnknown offset (tsStackProv st) - kind = IM.findWithDefault KindUnknown offset (tsStackKind st) + let taint = IM.findWithDefault Unknown offset + (tsStack st) + prov = IM.findWithDefault ProvUnknown offset + (tsStackProv st) + kind = IM.findWithDefault KindUnknown offset + (tsStackKind st) in setTaintProvKind dst taint prov kind st - where - isPublicRoot r = r `elem` publicRoots -- | Store taint, provenance, and kind to a stack slot. -setStackTaintProvKind :: Int -> Taint -> Provenance -> RegKind -> TaintState - -> TaintState +setStackTaintProvKind :: Int -> Taint -> Provenance + -> RegKind -> TaintState -> TaintState setStackTaintProvKind offset t p k st = st { tsStack = IM.insert offset t (tsStack st) , tsStackProv = IM.insert offset p (tsStackProv st) , tsStackKind = IM.insert offset k (tsStackKind st) } --- | Clear all stack slot taints, provenance, and kinds (when SP is modified). +-- | Clear all stack slot taints (when SP is modified). clearStackMap :: TaintState -> TaintState clearStackMap st = st - { tsStack = IM.empty, tsStackProv = IM.empty, tsStackKind = IM.empty } + { tsStack = IM.empty + , tsStackProv = IM.empty + , tsStackKind = IM.empty + } --- | Store taint, provenance, and kind to an STG stack slot. -setStgStackTaintProvKind :: Int -> Taint -> Provenance -> RegKind -> TaintState +-- | Store taint to a secondary stack slot. +setSecStackTaintProvKind :: Int -> Taint -> Provenance + -> RegKind -> TaintState -> TaintState -setStgStackTaintProvKind offset t p k st = st +setSecStackTaintProvKind offset t p k st = st { tsStgStack = IM.insert offset t (tsStgStack st) - , tsStgStackProv = IM.insert offset p (tsStgStackProv st) - , tsStgStackKind = IM.insert offset k (tsStgStackKind st) + , tsStgStackProv = IM.insert offset p + (tsStgStackProv st) + , tsStgStackKind = IM.insert offset k + (tsStgStackKind st) } --- | Clear all STG stack slot taints, provenance, and kinds (when X20 changes). -clearStgStackMap :: TaintState -> TaintState -clearStgStackMap st = st - { tsStgStack = IM.empty, tsStgStackProv = IM.empty, tsStgStackKind = IM.empty } - --- | Shift all STG stack slot offsets by a signed delta. --- Used when x20 is adjusted by a constant: after `sub x20, x20, #16`, the old --- slot at offset 8 is now at offset 24 relative to the new x20. --- For `add x20, x20, #imm`, delta should be `-imm`. --- For `sub x20, x20, #imm`, delta should be `+imm`. -shiftStgStackMap :: Int -> TaintState -> TaintState -shiftStgStackMap delta st +-- | Clear all secondary stack slot taints. +clearSecStackMap :: TaintState -> TaintState +clearSecStackMap st = st + { tsStgStack = IM.empty + , tsStgStackProv = IM.empty + , tsStgStackKind = IM.empty + } + +-- | Shift all secondary stack slot offsets by a signed delta. +shiftSecStackMap :: Int -> TaintState -> TaintState +shiftSecStackMap delta st | delta == 0 = st | otherwise = st { tsStgStack = shiftIM (tsStgStack st) @@ -601,27 +649,24 @@ shiftStgStackMap delta st , tsStgStackKind = shiftIM (tsStgStackKind st) } where - shiftIM im = IM.fromList [(k + delta, v) | (k, v) <- IM.toList im] + shiftIM im = + IM.fromList [(k + delta, v) | (k, v) <- IM.toList im] --- | Join taint, provenance, and kind into the heap bucket. --- Used for non-stack stores to propagate secrets through heap memory. -joinHeapBucket :: Taint -> Provenance -> RegKind -> TaintState -> TaintState +-- | Join taint into the heap bucket. +joinHeapBucket :: Taint -> Provenance -> RegKind + -> TaintState -> TaintState joinHeapBucket t p k st = st { tsHeapTaint = joinTaint (tsHeapTaint st) t , tsHeapProv = joinProvenance (tsHeapProv st) p , tsHeapKind = joinKind (tsHeapKind st) k } --- | Check if a register qualifies as a public pointer for refined heap tracking. --- Must be KindPtr with ProvPublic provenance. +-- | Check if a register is a public pointer for refined heap. isPublicPointer :: Reg -> TaintState -> Bool isPublicPointer r st = getKind r st == KindPtr && getProvenance r st == ProvPublic -- | Check if a register is a pointer to secret data. --- Loads through such pointers produce Secret values. --- Requires the register to be a Public pointer with ProvSecretData provenance. --- If arithmetic mutates it to a scalar, it's no longer a valid secret pointer. isSecretDataPointer :: Reg -> TaintState -> Bool isSecretDataPointer r st = getTaint r st == Public && @@ -629,63 +674,75 @@ isSecretDataPointer r st = getProvenance r st == ProvSecretData -- | Set a heap slot for a (base, offset) pair. -setHeapSlot :: Reg -> Int -> Taint -> Provenance -> RegKind -> TaintState - -> TaintState +setHeapSlot :: Reg -> Int -> Taint -> Provenance + -> RegKind -> TaintState -> TaintState setHeapSlot base off t p k st = st - { tsHeapSlots = Map.insert (base, off) (t, p, k) (tsHeapSlots st) } - --- | Get a heap slot for a (base, offset) pair, if it exists. -getHeapSlot :: Reg -> Int -> TaintState -> Maybe (Taint, Provenance, RegKind) -getHeapSlot base off st = Map.lookup (base, off) (tsHeapSlots st) - --- | Set taint for a loaded value from a known STG stack slot. --- If we have tracked taint/provenance/kind at this offset, use it. --- Otherwise, use the STG assumption: if tsAssumeStgPublic is True, --- untracked slots default to Public/ProvPublic/KindPtr (closure pointers); --- if False, they default to Unknown/ProvUnknown/KindUnknown. --- Public roots always get Public/ProvPublic/KindPtr. -setTaintLoadStgStack :: Reg -> Int -> TaintState -> TaintState -setTaintLoadStgStack dst offset st - | isPublicRoot dst = setTaintProvKind dst Public ProvPublic KindPtr st + { tsHeapSlots = Map.insert (base, off) (t, p, k) + (tsHeapSlots st) } + +-- | Get a heap slot for a (base, offset) pair. +getHeapSlot :: Reg -> Int -> TaintState + -> Maybe (Taint, Provenance, RegKind) +getHeapSlot base off st = + Map.lookup (base, off) (tsHeapSlots st) + +-- | Set taint for a loaded value from a secondary stack slot. +setTaintLoadSecStack :: RuntimeConfig -> Reg -> Int + -> TaintState -> TaintState +setTaintLoadSecStack rt dst offset st + | isPublicRoot rt dst = + setTaintProvKind dst Public ProvPublic KindPtr st | otherwise = let (defT, defP, defK) = if tsAssumeStgPublic st then (Public, ProvPublic, KindPtr) else (Unknown, ProvUnknown, KindUnknown) - taint = IM.findWithDefault defT offset (tsStgStack st) - prov = IM.findWithDefault defP offset (tsStgStackProv st) - kind = IM.findWithDefault defK offset (tsStgStackKind st) + taint = IM.findWithDefault defT offset + (tsStgStack st) + prov = IM.findWithDefault defP offset + (tsStgStackProv st) + kind = IM.findWithDefault defK offset + (tsStgStackKind st) in setTaintProvKind dst taint prov kind st - where - isPublicRoot r = r `elem` publicRoots --- | Update for Add instruction, shifting STG stack map when x20 += imm. --- After `add x20, x20, #imm`, old offset O maps to new offset O - imm. -updateWithX20ShiftAdd :: Reg -> Reg -> Operand -> Taint -> Provenance - -> RegKind -> TaintState -> TaintState -updateWithX20ShiftAdd dst r1 op t p k st - | dst == SP = clearStackMap (setTaintProvKind dst t p k st) - | dst == X20, r1 == X20, OpImm imm <- op = +-- | Update for Add, shifting secondary stack map when +-- base += imm. +updateWithSecStackShiftAdd + :: RuntimeConfig -> Reg -> Reg -> Operand + -> Taint -> Provenance -> RegKind -> TaintState + -> TaintState +updateWithSecStackShiftAdd rt dst r1 op t p k st + | dst == SP = + clearStackMap (setTaintProvKind dst t p k st) + | isSecStackBase rt dst, r1 == dst, OpImm imm <- op = let delta = negate (fromInteger imm) - in shiftStgStackMap delta (setTaintProvKind dst t p k st) - | dst == X20 = clearStgStackMap (setTaintProvKind dst t p k st) - | otherwise = setTaintProvKind dst t p k st - --- | Update for Sub instruction, shifting STG stack map when x20 -= imm. --- After `sub x20, x20, #imm`, old offset O maps to new offset O + imm. -updateWithX20ShiftSub :: Reg -> Reg -> Operand -> Taint -> Provenance - -> RegKind -> TaintState -> TaintState -updateWithX20ShiftSub dst r1 op t p k st - | dst == SP = clearStackMap (setTaintProvKind dst t p k st) - | dst == X20, r1 == X20, OpImm imm <- op = + in shiftSecStackMap delta + (setTaintProvKind dst t p k st) + | isSecStackBase rt dst = + clearSecStackMap (setTaintProvKind dst t p k st) + | otherwise = setTaintProvKind dst t p k st + +-- | Update for Sub, shifting secondary stack map when +-- base -= imm. +updateWithSecStackShiftSub + :: RuntimeConfig -> Reg -> Reg -> Operand + -> Taint -> Provenance -> RegKind -> TaintState + -> TaintState +updateWithSecStackShiftSub rt dst r1 op t p k st + | dst == SP = + clearStackMap (setTaintProvKind dst t p k st) + | isSecStackBase rt dst, r1 == dst, OpImm imm <- op = let delta = fromInteger imm - in shiftStgStackMap delta (setTaintProvKind dst t p k st) - | dst == X20 = clearStgStackMap (setTaintProvKind dst t p k st) - | otherwise = setTaintProvKind dst t p k st - --- | Track a store to stack if address is [sp, #imm] or [x20, #imm]. --- Pre/post-indexed addressing modifies the base, invalidating the stack map. -storeToStack :: Reg -> AddrMode -> TaintState -> TaintState -storeToStack src addr st = case addr of + in shiftSecStackMap delta + (setTaintProvKind dst t p k st) + | isSecStackBase rt dst = + clearSecStackMap (setTaintProvKind dst t p k st) + | otherwise = setTaintProvKind dst t p k st + +-- | Track a store to stack if address matches SP or +-- secondary stack base. +storeToStack :: RuntimeConfig -> Reg -> AddrMode + -> TaintState -> TaintState +storeToStack rt src addr st = case addr of -- Hardware stack (SP) BaseImm SP imm -> let off = fromInteger imm @@ -694,7 +751,6 @@ storeToStack src addr st = case addr of k = getKind src st in setStackTaintProvKind off t p k st PreIndex SP imm -> - -- Store at sp+imm, then SP changes let off = fromInteger imm t = getTaint src st p = getProvenance src st @@ -702,220 +758,221 @@ storeToStack src addr st = case addr of st' = setStackTaintProvKind off t p k st in clearStackMap st' PostIndex SP _ -> - -- Store at sp, then SP changes (offset 0) let t = getTaint src st p = getProvenance src st k = getKind src st st' = setStackTaintProvKind 0 t p k st in clearStackMap st' - -- STG stack (X20) - BaseImm X20 imm -> + -- Secondary stack + BaseImm base imm | isSecStackBase rt base -> let off = fromInteger imm t = getTaint src st p = getProvenance src st k = getKind src st - in setStgStackTaintProvKind off t p k st - PreIndex X20 imm -> - -- Store at x20+imm, then x20 = x20 + imm + in setSecStackTaintProvKind off t p k st + PreIndex base imm | isSecStackBase rt base -> let off = fromInteger imm t = getTaint src st p = getProvenance src st k = getKind src st delta = negate (fromInteger imm) - st' = setStgStackTaintProvKind off t p k st - in shiftStgStackMap delta st' - PostIndex X20 imm -> - -- Store at x20, then x20 = x20 + imm + st' = setSecStackTaintProvKind off t p k st + in shiftSecStackMap delta st' + PostIndex base imm | isSecStackBase rt base -> let t = getTaint src st p = getProvenance src st k = getKind src st delta = negate (fromInteger imm) - st' = setStgStackTaintProvKind 0 t p k st - in shiftStgStackMap delta st' - -- Refined heap slot: [base, #imm] where base is a public pointer - BaseImm base imm - | isPublicPointer base st -> - let off = fromInteger imm - t = getTaint src st - p = getProvenance src st - k = getKind src st - -- Update both refined slot and coarse bucket for soundness - in setHeapSlot base off t p k (joinHeapBucket t p k st) - -- Other non-stack stores: join source taint into heap bucket only + st' = setSecStackTaintProvKind 0 t p k st + in shiftSecStackMap delta st' + -- Refined heap slot + BaseImm base imm | isPublicPointer base st -> + let off = fromInteger imm + t = getTaint src st + p = getProvenance src st + k = getKind src st + in setHeapSlot base off t p k + (joinHeapBucket t p k st) + -- Other non-stack stores _ -> let t = getTaint src st p = getProvenance src st k = getKind src st in joinHeapBucket t p k st --- | Track a store pair to stack if address is [sp, #imm] or [x20, #imm]. --- Stores src1 at offset and src2 at offset+8. --- Pre/post-indexed addressing modifies the base, invalidating the stack map. -storePairToStack :: Reg -> Reg -> AddrMode -> TaintState -> TaintState -storePairToStack src1 src2 addr st = case addr of +-- | Track a store pair to stack. +storePairToStack :: RuntimeConfig -> Reg -> Reg -> AddrMode + -> TaintState -> TaintState +storePairToStack rt src1 src2 addr st = case addr of -- Hardware stack (SP) BaseImm SP imm -> let off = fromInteger imm - t1 = getTaint src1 st; p1 = getProvenance src1 st; k1 = getKind src1 st - t2 = getTaint src2 st; p2 = getProvenance src2 st; k2 = getKind src2 st + t1 = getTaint src1 st; p1 = getProvenance src1 st + k1 = getKind src1 st + t2 = getTaint src2 st; p2 = getProvenance src2 st + k2 = getKind src2 st in setStackTaintProvKind off t1 p1 k1 (setStackTaintProvKind (off + 8) t2 p2 k2 st) PreIndex SP imm -> - -- Store at sp+imm and sp+imm+8, then SP changes let off = fromInteger imm - t1 = getTaint src1 st; p1 = getProvenance src1 st; k1 = getKind src1 st - t2 = getTaint src2 st; p2 = getProvenance src2 st; k2 = getKind src2 st + t1 = getTaint src1 st; p1 = getProvenance src1 st + k1 = getKind src1 st + t2 = getTaint src2 st; p2 = getProvenance src2 st + k2 = getKind src2 st st' = setStackTaintProvKind off t1 p1 k1 - (setStackTaintProvKind (off + 8) t2 p2 k2 st) + (setStackTaintProvKind (off+8) t2 p2 k2 st) in clearStackMap st' PostIndex SP _ -> - -- Store at sp and sp+8, then SP changes - let t1 = getTaint src1 st; p1 = getProvenance src1 st; k1 = getKind src1 st - t2 = getTaint src2 st; p2 = getProvenance src2 st; k2 = getKind src2 st + let t1 = getTaint src1 st; p1 = getProvenance src1 st + k1 = getKind src1 st + t2 = getTaint src2 st; p2 = getProvenance src2 st + k2 = getKind src2 st st' = setStackTaintProvKind 0 t1 p1 k1 (setStackTaintProvKind 8 t2 p2 k2 st) in clearStackMap st' - -- STG stack (X20) - BaseImm X20 imm -> + -- Secondary stack + BaseImm base imm | isSecStackBase rt base -> let off = fromInteger imm - t1 = getTaint src1 st; p1 = getProvenance src1 st; k1 = getKind src1 st - t2 = getTaint src2 st; p2 = getProvenance src2 st; k2 = getKind src2 st - in setStgStackTaintProvKind off t1 p1 k1 - (setStgStackTaintProvKind (off + 8) t2 p2 k2 st) - PreIndex X20 imm -> - -- Store at x20+imm and x20+imm+8, then x20 = x20 + imm + t1 = getTaint src1 st; p1 = getProvenance src1 st + k1 = getKind src1 st + t2 = getTaint src2 st; p2 = getProvenance src2 st + k2 = getKind src2 st + in setSecStackTaintProvKind off t1 p1 k1 + (setSecStackTaintProvKind (off+8) t2 p2 k2 st) + PreIndex base imm | isSecStackBase rt base -> let off = fromInteger imm - t1 = getTaint src1 st; p1 = getProvenance src1 st; k1 = getKind src1 st - t2 = getTaint src2 st; p2 = getProvenance src2 st; k2 = getKind src2 st + t1 = getTaint src1 st; p1 = getProvenance src1 st + k1 = getKind src1 st + t2 = getTaint src2 st; p2 = getProvenance src2 st + k2 = getKind src2 st delta = negate (fromInteger imm) - st' = setStgStackTaintProvKind off t1 p1 k1 - (setStgStackTaintProvKind (off + 8) t2 p2 k2 st) - in shiftStgStackMap delta st' - PostIndex X20 imm -> - -- Store at x20 and x20+8, then x20 = x20 + imm - let t1 = getTaint src1 st; p1 = getProvenance src1 st; k1 = getKind src1 st - t2 = getTaint src2 st; p2 = getProvenance src2 st; k2 = getKind src2 st + st' = setSecStackTaintProvKind off t1 p1 k1 + (setSecStackTaintProvKind (off+8) t2 p2 k2 st) + in shiftSecStackMap delta st' + PostIndex base imm | isSecStackBase rt base -> + let t1 = getTaint src1 st; p1 = getProvenance src1 st + k1 = getKind src1 st + t2 = getTaint src2 st; p2 = getProvenance src2 st + k2 = getKind src2 st delta = negate (fromInteger imm) - st' = setStgStackTaintProvKind 0 t1 p1 k1 - (setStgStackTaintProvKind 8 t2 p2 k2 st) - in shiftStgStackMap delta st' - -- Refined heap slot: [base, #imm] where base is a public pointer - BaseImm base imm - | isPublicPointer base st -> - let off = fromInteger imm - t1 = getTaint src1 st; p1 = getProvenance src1 st; k1 = getKind src1 st - t2 = getTaint src2 st; p2 = getProvenance src2 st; k2 = getKind src2 st - -- Update refined slots and coarse bucket - st1 = setHeapSlot base off t1 p1 k1 st - st2 = setHeapSlot base (off + 8) t2 p2 k2 st1 - in joinHeapBucket t1 p1 k1 (joinHeapBucket t2 p2 k2 st2) - -- Other non-stack stores: join both sources into heap bucket only - _ -> let t1 = getTaint src1 st; p1 = getProvenance src1 st; k1 = getKind src1 st - t2 = getTaint src2 st; p2 = getProvenance src2 st; k2 = getKind src2 st - in joinHeapBucket t1 p1 k1 (joinHeapBucket t2 p2 k2 st) - --- | Load from memory, handling special cases: --- - [sp, #imm]: restore tracked hardware stack slot taint and provenance --- - [x20, #imm]: restore tracked STG stack slot taint and provenance --- - [r, symbol@GOTPAGEOFF]: GOT entry load, result is Public (address) --- - Other: Unknown unless dst is a public root --- Pre/post-indexed addressing modifies the base, invalidating the stack map. -loadFromStack :: Reg -> AddrMode -> TaintState -> TaintState -loadFromStack dst addr st = case addr of + st' = setSecStackTaintProvKind 0 t1 p1 k1 + (setSecStackTaintProvKind 8 t2 p2 k2 st) + in shiftSecStackMap delta st' + -- Refined heap slot + BaseImm base imm | isPublicPointer base st -> + let off = fromInteger imm + t1 = getTaint src1 st; p1 = getProvenance src1 st + k1 = getKind src1 st + t2 = getTaint src2 st; p2 = getProvenance src2 st + k2 = getKind src2 st + st1 = setHeapSlot base off t1 p1 k1 st + st2 = setHeapSlot base (off+8) t2 p2 k2 st1 + in joinHeapBucket t1 p1 k1 + (joinHeapBucket t2 p2 k2 st2) + -- Other non-stack stores + _ -> let t1 = getTaint src1 st; p1 = getProvenance src1 st + k1 = getKind src1 st + t2 = getTaint src2 st; p2 = getProvenance src2 st + k2 = getKind src2 st + in joinHeapBucket t1 p1 k1 + (joinHeapBucket t2 p2 k2 st) + +-- | Load from memory, handling SP, secondary stack, +-- GOT entries, secret pointers, and heap. +loadFromStack :: RuntimeConfig -> Reg -> AddrMode + -> TaintState -> TaintState +loadFromStack rt dst addr st = case addr of -- Hardware stack (SP) - BaseImm SP imm -> setTaintLoadStack dst (fromInteger imm) st + BaseImm SP imm -> + setTaintLoadStack rt dst (fromInteger imm) st PreIndex SP imm -> - -- Load from sp+imm, then sp changes - clearStackMap (setTaintLoadStack dst (fromInteger imm) st) + clearStackMap + (setTaintLoadStack rt dst (fromInteger imm) st) PostIndex SP imm -> - -- Load first, then clear (SP changes after access) - clearStackMap (setTaintLoadStack dst (fromInteger imm) st) - -- STG stack (X20) - BaseImm X20 imm -> setTaintLoadStgStack dst (fromInteger imm) st - PreIndex X20 imm -> - -- Load from x20+imm, then x20 = x20 + imm + clearStackMap + (setTaintLoadStack rt dst (fromInteger imm) st) + -- Secondary stack + BaseImm base imm | isSecStackBase rt base -> + setTaintLoadSecStack rt dst (fromInteger imm) st + PreIndex base imm | isSecStackBase rt base -> let delta = negate (fromInteger imm) - in shiftStgStackMap delta (setTaintLoadStgStack dst (fromInteger imm) st) - PostIndex X20 imm -> - -- Load first, then x20 = x20 + imm + in shiftSecStackMap delta + (setTaintLoadSecStack rt dst (fromInteger imm) st) + PostIndex base imm | isSecStackBase rt base -> let delta = negate (fromInteger imm) - in shiftStgStackMap delta (setTaintLoadStgStack dst (fromInteger imm) st) + in shiftSecStackMap delta + (setTaintLoadSecStack rt dst (fromInteger imm) st) -- Symbol/literal loads BaseSymbol _ _ -> - setTaintProv dst Public ProvPublic st -- GOT/PLT entry -> address + setTaintProv dst Public ProvPublic st Literal _ -> - setTaintProv dst Public ProvPublic st -- PC-relative literal -> address - -- Secret data pointer: [base, #imm] where base points to secret data - BaseImm base _imm - | isSecretDataPointer base st -> - setTaintProvKind dst Secret ProvUnknown KindScalar st - -- Refined heap slot: [base, #imm] where base is a public pointer - BaseImm base imm - | isPublicPointer base st -> - setTaintLoadHeapSlot dst base (fromInteger imm) st - -- Secret data pointer with register offset: [base, xM] variants - -- Note: Register-offset modes don't support refined heap slots because - -- the offset is dynamic - we can't determine which slot is being accessed. - -- These fall through to coarse heap handling if not secret data pointers. - BaseReg base _idx - | isSecretDataPointer base st -> - setTaintProvKind dst Secret ProvUnknown KindScalar st + setTaintProv dst Public ProvPublic st + -- Secret data pointer: [base, #imm] + BaseImm base _imm | isSecretDataPointer base st -> + setTaintProvKind dst Secret ProvUnknown KindScalar st + -- Refined heap slot + BaseImm base imm | isPublicPointer base st -> + setTaintLoadHeapSlot rt dst base (fromInteger imm) st + -- Secret data pointer with register offset + BaseReg base _idx | isSecretDataPointer base st -> + setTaintProvKind dst Secret ProvUnknown KindScalar st BaseRegShift base _idx _shift | isSecretDataPointer base st -> - setTaintProvKind dst Secret ProvUnknown KindScalar st + setTaintProvKind dst Secret ProvUnknown KindScalar st BaseRegExtend base _idx _ext | isSecretDataPointer base st -> - setTaintProvKind dst Secret ProvUnknown KindScalar st - -- Other loads: read from coarse heap bucket - _ -> setTaintLoadHeap dst st - --- | Load pair from stack if address is [sp, #imm] or [x20, #imm]. --- Loads dst1 from offset and dst2 from offset+8. --- Pre/post-indexed addressing modifies the base, invalidating the stack map. -loadPairFromStack :: Reg -> Reg -> AddrMode -> TaintState -> TaintState -loadPairFromStack dst1 dst2 addr st = case addr of + setTaintProvKind dst Secret ProvUnknown KindScalar st + -- Other loads: coarse heap bucket + _ -> setTaintLoadHeap rt dst st + +-- | Load pair from stack. +loadPairFromStack :: RuntimeConfig -> Reg -> Reg + -> AddrMode -> TaintState -> TaintState +loadPairFromStack rt dst1 dst2 addr st = case addr of -- Hardware stack (SP) BaseImm SP imm -> let off = fromInteger imm - in setTaintLoadStack dst1 off (setTaintLoadStack dst2 (off + 8) st) + in setTaintLoadStack rt dst1 off + (setTaintLoadStack rt dst2 (off + 8) st) PreIndex SP imm -> - -- Load from sp+imm and sp+imm+8, then sp changes let off = fromInteger imm - st' = setTaintLoadStack dst1 off (setTaintLoadStack dst2 (off + 8) st) + st' = setTaintLoadStack rt dst1 off + (setTaintLoadStack rt dst2 (off + 8) st) in clearStackMap st' PostIndex SP imm -> - -- Load first, then clear (SP changes after access) let off = fromInteger imm - st' = setTaintLoadStack dst1 off (setTaintLoadStack dst2 (off + 8) st) + st' = setTaintLoadStack rt dst1 off + (setTaintLoadStack rt dst2 (off + 8) st) in clearStackMap st' - -- STG stack (X20) - BaseImm X20 imm -> + -- Secondary stack + BaseImm base imm | isSecStackBase rt base -> let off = fromInteger imm - in setTaintLoadStgStack dst1 off (setTaintLoadStgStack dst2 (off + 8) st) - PreIndex X20 imm -> - -- Load from x20+imm and x20+imm+8, then x20 = x20 + imm + in setTaintLoadSecStack rt dst1 off + (setTaintLoadSecStack rt dst2 (off + 8) st) + PreIndex base imm | isSecStackBase rt base -> let off = fromInteger imm delta = negate (fromInteger imm) - st' = setTaintLoadStgStack dst1 off (setTaintLoadStgStack dst2 (off+8) st) - in shiftStgStackMap delta st' - PostIndex X20 imm -> - -- Load first, then x20 = x20 + imm + st' = setTaintLoadSecStack rt dst1 off + (setTaintLoadSecStack rt dst2 (off+8) st) + in shiftSecStackMap delta st' + PostIndex base imm | isSecStackBase rt base -> let off = fromInteger imm delta = negate (fromInteger imm) - st' = setTaintLoadStgStack dst1 off (setTaintLoadStgStack dst2 (off+8) st) - in shiftStgStackMap delta st' - -- Secret data pointer: [base, #imm] where base points to secret data - BaseImm base _imm - | isSecretDataPointer base st -> - setTaintProvKind dst1 Secret ProvUnknown KindScalar - (setTaintProvKind dst2 Secret ProvUnknown KindScalar st) - -- Refined heap slot: [base, #imm] where base is a public pointer - BaseImm base imm - | isPublicPointer base st -> - let off = fromInteger imm - in setTaintLoadHeapSlot dst1 base off - (setTaintLoadHeapSlot dst2 base (off + 8) st) - -- Other loads: read from coarse heap bucket - _ -> setTaintLoadHeap dst1 (setTaintLoadHeap dst2 st) + st' = setTaintLoadSecStack rt dst1 off + (setTaintLoadSecStack rt dst2 (off+8) st) + in shiftSecStackMap delta st' + -- Secret data pointer + BaseImm base _imm | isSecretDataPointer base st -> + setTaintProvKind dst1 Secret ProvUnknown KindScalar + (setTaintProvKind dst2 Secret ProvUnknown + KindScalar st) + -- Refined heap slot + BaseImm base imm | isPublicPointer base st -> + let off = fromInteger imm + in setTaintLoadHeapSlot rt dst1 base off + (setTaintLoadHeapSlot rt dst2 base (off+8) st) + -- Other loads: coarse heap bucket + _ -> setTaintLoadHeap rt dst1 + (setTaintLoadHeap rt dst2 st) -- | Get taint of an operand. operandTaint :: Operand -> TaintState -> Taint @@ -935,36 +992,30 @@ operandProv op st = case op of OpShiftedReg r _ -> getProvenance r st OpExtendedReg r _ -> getProvenance r st OpLabel _ -> ProvPublic - OpAddr _ -> ProvUnknown -- Address operand provenance is complex + OpAddr _ -> ProvUnknown -- | Get kind of an operand. operandKind :: Operand -> TaintState -> RegKind operandKind op st = case op of OpReg r -> getKind r st - OpImm _ -> KindScalar -- Immediates are scalar values + OpImm _ -> KindScalar OpShiftedReg r _ -> getKind r st OpExtendedReg r _ -> getKind r st - OpLabel _ -> KindPtr -- Labels are addresses (pointers) - OpAddr _ -> KindUnknown -- Address operand kind is complex + OpLabel _ -> KindPtr + OpAddr _ -> KindUnknown --- | Compute kind for pointer arithmetic (add/sub). --- If base is a pointer and operand is immediate, result is pointer. --- Otherwise result is scalar. +-- | Compute kind for pointer arithmetic. pointerArithKind :: Reg -> Operand -> TaintState -> RegKind pointerArithKind base op st = case op of OpImm _ | getKind base st == KindPtr -> KindPtr _ -> KindScalar --- | Check if operand is a GHC pointer-untagging mask. --- GHC uses low 3 bits for pointer tagging; this mask clears them. --- Recognizing this pattern allows whitelisting heap traversal. -isPointerUntagMask :: Operand -> Bool -isPointerUntagMask (OpImm imm) = - imm == 0xfffffffffffffff8 || -- 64-bit mask - imm == 0xfffffff8 || -- 32-bit mask - imm == -8 -- signed representation -isPointerUntagMask _ = False +-- | Check if operand is a pointer-untagging mask. +isPointerUntagMask :: RuntimeConfig -> Operand -> Bool +isPointerUntagMask rt (OpImm imm) = + imm `elem` rtUntagMasks rt +isPointerUntagMask _ _ = False -- | Get taint of address base register. addrBaseTaint :: AddrMode -> TaintState -> Taint @@ -991,17 +1042,19 @@ provJoin2 :: Provenance -> Provenance -> Provenance provJoin2 = joinProvenance -- | Join three provenances. -provJoin3 :: Provenance -> Provenance -> Provenance -> Provenance +provJoin3 :: Provenance -> Provenance -> Provenance + -> Provenance provJoin3 a b c = joinProvenance a (joinProvenance b c) -- | Invalidate caller-saved registers after a call. --- Per AArch64 ABI, x0-x17 are caller-saved. --- Clears taint, provenance, and kind. invalidateCallerSaved :: TaintState -> TaintState invalidateCallerSaved st = st - { tsRegs = foldr (\r -> Map.insert r Unknown) (tsRegs st) callerSaved - , tsProv = foldr (\r -> Map.insert r ProvUnknown) (tsProv st) callerSaved - , tsKind = foldr (\r -> Map.insert r KindUnknown) (tsKind st) callerSaved + { tsRegs = foldr (\r -> Map.insert r Unknown) + (tsRegs st) callerSaved + , tsProv = foldr (\r -> Map.insert r ProvUnknown) + (tsProv st) callerSaved + , tsKind = foldr (\r -> Map.insert r KindUnknown) + (tsKind st) callerSaved } where callerSaved = @@ -1011,96 +1064,112 @@ invalidateCallerSaved st = st ] -- | Join two taint states (element-wise join). --- For registers in both, take the join. For registers in only one, keep. --- Stack slots (SP and STG), provenance, kinds, and heap bucket are also joined. joinTaintState :: TaintState -> TaintState -> TaintState joinTaintState a b = TaintState - { tsRegs = Map.unionWith joinTaint (tsRegs a) (tsRegs b) - , tsStack = IM.unionWith joinTaint (tsStack a) (tsStack b) - , tsProv = Map.unionWith joinProvenance (tsProv a) (tsProv b) - , tsStackProv = IM.unionWith joinProvenance (tsStackProv a) (tsStackProv b) - , tsKind = Map.unionWith joinKind (tsKind a) (tsKind b) - , tsStackKind = IM.unionWith joinKind (tsStackKind a) (tsStackKind b) - , tsStgStack = IM.unionWith joinTaint (tsStgStack a) (tsStgStack b) - , tsStgStackProv = IM.unionWith joinProvenance (tsStgStackProv a) - (tsStgStackProv b) - , tsStgStackKind = IM.unionWith joinKind (tsStgStackKind a) (tsStgStackKind b) - , tsHeapTaint = joinTaint (tsHeapTaint a) (tsHeapTaint b) - , tsHeapProv = joinProvenance (tsHeapProv a) (tsHeapProv b) - , tsHeapKind = joinKind (tsHeapKind a) (tsHeapKind b) - , tsHeapSlots = Map.unionWith joinSlot (tsHeapSlots a) (tsHeapSlots b) + { tsRegs = Map.unionWith joinTaint + (tsRegs a) (tsRegs b) + , tsStack = IM.unionWith joinTaint + (tsStack a) (tsStack b) + , tsProv = Map.unionWith joinProvenance + (tsProv a) (tsProv b) + , tsStackProv = IM.unionWith joinProvenance + (tsStackProv a) (tsStackProv b) + , tsKind = Map.unionWith joinKind + (tsKind a) (tsKind b) + , tsStackKind = IM.unionWith joinKind + (tsStackKind a) (tsStackKind b) + , tsStgStack = IM.unionWith joinTaint + (tsStgStack a) (tsStgStack b) + , tsStgStackProv = IM.unionWith joinProvenance + (tsStgStackProv a) (tsStgStackProv b) + , tsStgStackKind = IM.unionWith joinKind + (tsStgStackKind a) (tsStgStackKind b) + , tsHeapTaint = joinTaint + (tsHeapTaint a) (tsHeapTaint b) + , tsHeapProv = joinProvenance + (tsHeapProv a) (tsHeapProv b) + , tsHeapKind = joinKind + (tsHeapKind a) (tsHeapKind b) + , tsHeapSlots = Map.unionWith joinSlot + (tsHeapSlots a) (tsHeapSlots b) , tsAssumeStgPublic = let l = tsAssumeStgPublic a r = tsAssumeStgPublic b in if l == r then l - else error "joinTaintState: mismatched tsAssumeStgPublic flags" + else error + "joinTaintState: mismatched tsAssumeStgPublic" } where joinSlot (t1, p1, k1) (t2, p2, k2) = - (joinTaint t1 t2, joinProvenance p1 p2, joinKind k1 k2) + (joinTaint t1 t2, joinProvenance p1 p2, + joinKind k1 k2) -- | Run forward dataflow analysis over a CFG. --- Returns the IN taint state for each block (indexed by block number). -runDataflow :: CFG -> IntMap TaintState -runDataflow cfg +runDataflow :: RuntimeConfig -> CFG -> IntMap TaintState +runDataflow rt cfg | nBlocks == 0 = IM.empty | otherwise = go initWorklist initIn initOut where nBlocks = cfgBlockCount cfg + baseState = initTaintState rt + emptyState = emptyTaintStateWith + (tsAssumeStgPublic baseState) - -- Initial states: all blocks start with public roots (GHC invariant) - initIn = IM.fromList [(i, initTaintState) | i <- [0..nBlocks-1]] + initIn = IM.fromList + [(i, baseState) | i <- [0..nBlocks-1]] initOut = IM.empty initWorklist = IS.fromList [0..nBlocks-1] - go :: IntSet -> IntMap TaintState -> IntMap TaintState -> IntMap TaintState + go :: IntSet -> IntMap TaintState + -> IntMap TaintState -> IntMap TaintState go worklist inStates outStates | IS.null worklist = inStates | otherwise = let (idx, worklist') = IS.deleteFindMin worklist bb = indexBlock cfg idx - inState = IM.findWithDefault initTaintState idx inStates - outState = analyzeBlock (bbLines bb) inState + inState = IM.findWithDefault baseState + idx inStates + outState = analyzeBlock rt (bbLines bb) + inState oldOut = IM.lookup idx outStates changed = oldOut /= Just outState - outStates' = IM.insert idx outState outStates + outStates' = IM.insert idx outState + outStates succs = blockSuccessors cfg idx (worklist'', inStates') = if changed - then propagateToSuccs succs outState worklist' inStates + then propagateToSuccs succs outState + worklist' inStates else (worklist', inStates) in go worklist'' inStates' outStates' - propagateToSuccs :: [Int] -> TaintState -> IntSet -> IntMap TaintState - -> (IntSet, IntMap TaintState) propagateToSuccs [] _ wl ins = (wl, ins) propagateToSuccs (s:ss) out wl ins = - let oldIn = IM.findWithDefault emptyTaintState s ins + let oldIn = IM.findWithDefault emptyState s ins newIn = joinTaintState oldIn out changed = oldIn /= newIn ins' = IM.insert s newIn ins wl' = if changed then IS.insert s wl else wl in propagateToSuccs ss out wl' ins' --- --------------------------------------------------------------------- +-- ------------------------------------------------------------- -- Function summaries for inter-procedural analysis --- | Function summary: taint state of caller-saved registers at return. --- This captures what a function leaves in registers when it returns. -newtype FuncSummary = FuncSummary { summaryState :: TaintState } +-- | Function summary: taint state at return. +newtype FuncSummary = FuncSummary + { summaryState :: TaintState } deriving (Eq, Show) --- | Initial conservative summary: all caller-saved are Unknown. -initSummary :: FuncSummary -initSummary = initSummaryWith True - --- | Initial summary with configurable STG assumption. +-- | Initial summary with configurable secondary stack assumption. initSummaryWith :: Bool -> FuncSummary initSummaryWith assumeStgPublic = FuncSummary $ TaintState - { tsRegs = Map.fromList [ (r, Unknown) | r <- callerSavedRegs ] + { tsRegs = Map.fromList + [(r, Unknown) | r <- callerSavedRegs] , tsStack = IM.empty - , tsProv = Map.fromList [ (r, ProvUnknown) | r <- callerSavedRegs ] + , tsProv = Map.fromList + [(r, ProvUnknown) | r <- callerSavedRegs] , tsStackProv = IM.empty - , tsKind = Map.fromList [ (r, KindUnknown) | r <- callerSavedRegs ] + , tsKind = Map.fromList + [(r, KindUnknown) | r <- callerSavedRegs] , tsStackKind = IM.empty , tsStgStack = IM.empty , tsStgStackProv = IM.empty @@ -1112,6 +1181,14 @@ initSummaryWith assumeStgPublic = FuncSummary $ TaintState , tsAssumeStgPublic = assumeStgPublic } +-- | Initial conservative summary. +initSummary :: RuntimeConfig -> FuncSummary +initSummary rt = initSummaryWith assumeStg + where + assumeStg = case rtSecondaryStack rt of + Nothing -> False + Just ss -> ssAssumePublic ss + -- | Caller-saved registers per AArch64 ABI. callerSavedRegs :: [Reg] callerSavedRegs = @@ -1120,13 +1197,12 @@ callerSavedRegs = , X16, X17 ] --- | Join two summaries (element-wise join of taint states). +-- | Join two summaries. joinSummary :: FuncSummary -> FuncSummary -> FuncSummary joinSummary (FuncSummary a) (FuncSummary b) = FuncSummary (joinTaintState a b) -- | Apply a function summary at a call site. --- Replaces caller-saved register taints, provenance, and kinds with summary. applySummary :: FuncSummary -> TaintState -> TaintState applySummary (FuncSummary summ) st = st { tsRegs = foldr applyReg (tsRegs st) callerSavedRegs @@ -1137,28 +1213,37 @@ applySummary (FuncSummary summ) st = st summRegs = tsRegs summ summProv = tsProv summ summKind = tsKind summ - applyReg r s = Map.insert r (Map.findWithDefault Unknown r summRegs) s - applyProv r s = Map.insert r (Map.findWithDefault ProvUnknown r summProv) s - applyKind r s = Map.insert r (Map.findWithDefault KindUnknown r summKind) s - --- | Run dataflow analysis for a single function (subset of blocks). --- Returns the OUT state at return points (joined). -runFunctionDataflow :: CFG -> [Int] -> Map Text FuncSummary -> TaintState -runFunctionDataflow cfg blockIndices summaries = - let -- Run dataflow on just these blocks - inStates = runFunctionBlocks cfg blockIndices summaries - -- Collect OUT states at return instructions - returnOuts = [ analyzeBlockWithSummaries bb inState summaries - | i <- blockIndices - , let bb = indexBlock cfg i - inState = IM.findWithDefault initTaintState i inStates - , endsWithRet bb - ] + applyReg r s = + Map.insert r + (Map.findWithDefault Unknown r summRegs) s + applyProv r s = + Map.insert r + (Map.findWithDefault ProvUnknown r summProv) s + applyKind r s = + Map.insert r + (Map.findWithDefault KindUnknown r summKind) s + +-- | Run dataflow for a single function. +runFunctionDataflow :: RuntimeConfig -> CFG -> [Int] + -> Map Text FuncSummary -> TaintState +runFunctionDataflow rt cfg blockIndices summaries = + let inStates = runFunctionBlocks rt cfg blockIndices + summaries + baseState = initTaintState rt + returnOuts = + [ analyzeBlockWithSummaries rt bb inState + summaries + | i <- blockIndices + , let bb = indexBlock cfg i + inState = IM.findWithDefault baseState i + inStates + , endsWithRet bb + ] in case returnOuts of - [] -> initTaintState -- No return found, use init + [] -> baseState (o:os) -> foldl' joinTaintState o os --- | Check if block ends with a return instruction. +-- | Check if block ends with a return. endsWithRet :: BasicBlock -> Bool endsWithRet bb = case bbLines bb of [] -> False @@ -1167,14 +1252,19 @@ endsWithRet bb = case bbLines bb of _ -> False -- | Run dataflow on a subset of blocks (one function). -runFunctionBlocks :: CFG -> [Int] -> Map Text FuncSummary +runFunctionBlocks :: RuntimeConfig -> CFG -> [Int] + -> Map Text FuncSummary -> IntMap TaintState -runFunctionBlocks _ [] _ = IM.empty -runFunctionBlocks cfg (entryIdx:rest) summaries = go initWorklist initIn IM.empty +runFunctionBlocks _ _ [] _ = IM.empty +runFunctionBlocks rt cfg (entryIdx:rest) summaries = + go initWorklist initIn IM.empty where blockSet = IS.fromList (entryIdx:rest) + baseState = initTaintState rt + emptyState = emptyTaintStateWith + (tsAssumeStgPublic baseState) - initIn = IM.singleton entryIdx initTaintState + initIn = IM.singleton entryIdx baseState initWorklist = IS.singleton entryIdx go wl inStates outStates @@ -1182,13 +1272,16 @@ runFunctionBlocks cfg (entryIdx:rest) summaries = go initWorklist initIn IM.empt | otherwise = let (idx, wl') = IS.deleteFindMin wl bb = indexBlock cfg idx - inState = IM.findWithDefault initTaintState idx inStates - outState = analyzeBlockWithSummaries bb inState summaries + inState = IM.findWithDefault baseState + idx inStates + outState = analyzeBlockWithSummaries rt bb + inState summaries oldOut = IM.lookup idx outStates changed = oldOut /= Just outState - outStates' = IM.insert idx outState outStates - -- Only propagate to successors within this function - succs = filter (`IS.member` blockSet) (blockSuccessors cfg idx) + outStates' = IM.insert idx outState + outStates + succs = filter (`IS.member` blockSet) + (blockSuccessors cfg idx) (wl'', inStates') = if changed then propagate succs outState wl' inStates else (wl', inStates) @@ -1196,92 +1289,96 @@ runFunctionBlocks cfg (entryIdx:rest) summaries = go initWorklist initIn IM.empt propagate [] _ wl ins = (wl, ins) propagate (s:ss) out wl ins = - let oldIn = IM.findWithDefault emptyTaintState s ins + let oldIn = IM.findWithDefault emptyState s ins newIn = joinTaintState oldIn out ins' = IM.insert s newIn ins - wl' = if oldIn /= newIn then IS.insert s wl else wl + wl' = if oldIn /= newIn + then IS.insert s wl else wl in propagate ss out wl' ins' --- | Analyze a block, applying call summaries at bl instructions. -analyzeBlockWithSummaries :: BasicBlock -> TaintState -> Map Text FuncSummary +-- | Analyze a block with call summaries. +analyzeBlockWithSummaries :: RuntimeConfig -> BasicBlock -> TaintState -analyzeBlockWithSummaries bb st0 summaries = foldl' go st0 (bbLines bb) + -> Map Text FuncSummary + -> TaintState +analyzeBlockWithSummaries rt bb st0 summaries = + foldl' go st0 (bbLines bb) where go st l = case lineInstr l of Nothing -> st - Just instr -> transferWithSummary instr st summaries + Just instr -> + transferWithSummary rt instr st summaries --- | Transfer with summary application for calls and tail calls. --- For in-file tail calls, state flows via CFG edges - no summary application. --- For external calls/jumps, we preserve callee-saved registers per ABI. -transferWithSummary :: Instr -> TaintState -> Map Text FuncSummary -> TaintState -transferWithSummary instr st summaries = case instr of +-- | Transfer with summary application for calls. +transferWithSummary :: RuntimeConfig -> Instr -> TaintState + -> Map Text FuncSummary -> TaintState +transferWithSummary rt instr st summaries = case instr of Bl target -> case Map.lookup target summaries of Just summ -> applySummary summ st Nothing -> invalidateCallerSaved st Blr _ -> invalidateCallerSaved st - -- Tail calls: b to function label B target - | isFunctionLabel target -> + | isFunctionLabel rt target -> case Map.lookup target summaries of - Just _ -> st -- In-file: CFG edge propagates state - Nothing -> invalidateCallerSaved st -- External: ABI preserves X19-X28 - -- Indirect jumps (STG closure evaluation): ABI preserves callee-saved + Just _ -> st -- In-file: CFG edge propagates + Nothing -> invalidateCallerSaved st Br _ -> invalidateCallerSaved st - _ -> transfer instr st + _ -> transfer rt instr st -- | Run inter-procedural fixpoint analysis. --- Returns summaries for all functions. --- Precomputes caches for function blocks and callers. -runInterProc :: CFG -> Map Text FuncSummary -runInterProc cfg = go initSummaries (Set.fromList funcs) +runInterProc :: RuntimeConfig -> CFG + -> Map Text FuncSummary +runInterProc rt cfg = go initSummaries (Set.fromList funcs) where funcs = functionLabels cfg - initSummaries = Map.fromList [(f, initSummary) | f <- funcs] + baseSummary = initSummary rt + initSummaries = Map.fromList + [(f, baseSummary) | f <- funcs] - -- Precompute caches once funcBlocksCache = buildFunctionBlocksMap cfg - callerMapCache = buildCallerMap cfg funcBlocksCache + callerMapCache = buildCallerMap rt cfg funcBlocksCache - lookupBlocks f = Map.findWithDefault [] f funcBlocksCache - lookupCallers f = Map.findWithDefault [] f callerMapCache + lookupBlocks f = + Map.findWithDefault [] f funcBlocksCache + lookupCallers f = + Map.findWithDefault [] f callerMapCache go summaries worklist | Set.null worklist = summaries | otherwise = - let (func, worklist') = Set.deleteFindMin worklist + let (func, worklist') = + Set.deleteFindMin worklist blockIdxs = lookupBlocks func - outState = runFunctionDataflow cfg blockIdxs summaries + outState = runFunctionDataflow rt cfg + blockIdxs summaries newSumm = FuncSummary outState - oldSumm = Map.findWithDefault initSummary func summaries + oldSumm = Map.findWithDefault baseSummary + func summaries changed = newSumm /= oldSumm - summaries' = Map.insert func newSumm summaries - -- If changed, re-analyze callers + summaries' = Map.insert func newSumm + summaries callers = lookupCallers func worklist'' = if changed - then foldr Set.insert worklist' callers - else worklist' + then foldr Set.insert worklist' callers + else worklist' in go summaries' worklist'' --- --------------------------------------------------------------------- +-- ------------------------------------------------------------- -- Config-aware dataflow analysis -- | Run forward dataflow with taint config. --- Entry states are seeded according to function-specific policies. --- Uses tcAssumeStgPublic from config for STG stack assumption. -runDataflowWithConfig :: TaintConfig -> CFG -> IntMap TaintState -runDataflowWithConfig tcfg cfg +runDataflowWithConfig :: RuntimeConfig -> TaintConfig + -> CFG -> IntMap TaintState +runDataflowWithConfig rt tcfg cfg | nBlocks == 0 = IM.empty | otherwise = go initWorklist initIn IM.empty where nBlocks = cfgBlockCount cfg assumeStg = tcAssumeStgPublic tcfg - baseState = initTaintStateWith assumeStg + baseState = initTaintStateWith rt assumeStg emptyState = emptyTaintStateWith assumeStg - -- Build a map from block index to entry taint state - -- Entry blocks of functions get their policy applied initIn = IM.fromList [ (i, entryState i (indexBlock cfg i)) | i <- [0..nBlocks-1] @@ -1291,12 +1388,12 @@ runDataflowWithConfig tcfg cfg case bbLabel bb of Nothing -> baseState Just lbl -> - -- Check if this block is a function entry case functionBlocks cfg lbl of (entry:_) | entry == idx -> case Map.lookup lbl (tcPolicies tcfg) of Nothing -> baseState - Just policy -> seedArgs policy baseState + Just policy -> + seedArgs rt policy baseState _ -> baseState initWorklist = IS.fromList [0..nBlocks-1] @@ -1306,14 +1403,18 @@ runDataflowWithConfig tcfg cfg | otherwise = let (idx, worklist') = IS.deleteFindMin worklist bb = indexBlock cfg idx - inState = IM.findWithDefault baseState idx inStates - outState = analyzeBlock (bbLines bb) inState + inState = IM.findWithDefault baseState + idx inStates + outState = analyzeBlock rt (bbLines bb) + inState oldOut = IM.lookup idx outStates changed = oldOut /= Just outState - outStates' = IM.insert idx outState outStates + outStates' = IM.insert idx outState + outStates succs = blockSuccessors cfg idx (worklist'', inStates') = if changed - then propagate succs outState worklist' inStates + then propagate succs outState + worklist' inStates else (worklist', inStates) in go worklist'' inStates' outStates' @@ -1322,26 +1423,23 @@ runDataflowWithConfig tcfg cfg let oldIn = IM.findWithDefault emptyState s ins newIn = joinTaintState oldIn out ins' = IM.insert s newIn ins - wl' = if oldIn /= newIn then IS.insert s wl else wl + wl' = if oldIn /= newIn + then IS.insert s wl else wl in propagate ss out wl' ins' --- | Run forward dataflow with taint config and function summaries. --- Combines config-based entry seeding with summary application at calls. --- This enables whole-file dataflow that respects both taint configs and --- inter-procedural summaries (including tail call propagation). -runDataflowWithConfigAndSummaries :: TaintConfig -> Map Text FuncSummary - -> CFG -> IntMap TaintState -runDataflowWithConfigAndSummaries tcfg summaries cfg +-- | Run forward dataflow with taint config and summaries. +runDataflowWithConfigAndSummaries + :: RuntimeConfig -> TaintConfig + -> Map Text FuncSummary -> CFG -> IntMap TaintState +runDataflowWithConfigAndSummaries rt tcfg summaries cfg | nBlocks == 0 = IM.empty | otherwise = go initWorklist initIn IM.empty where nBlocks = cfgBlockCount cfg assumeStg = tcAssumeStgPublic tcfg - baseState = initTaintStateWith assumeStg + baseState = initTaintStateWith rt assumeStg emptyState = emptyTaintStateWith assumeStg - -- Build a map from block index to entry taint state - -- Entry blocks of functions get their policy applied initIn = IM.fromList [ (i, entryState i (indexBlock cfg i)) | i <- [0..nBlocks-1] @@ -1351,12 +1449,12 @@ runDataflowWithConfigAndSummaries tcfg summaries cfg case bbLabel bb of Nothing -> baseState Just lbl -> - -- Check if this block is a function entry case functionBlocks cfg lbl of (entry:_) | entry == idx -> case Map.lookup lbl (tcPolicies tcfg) of Nothing -> baseState - Just policy -> seedArgs policy baseState + Just policy -> + seedArgs rt policy baseState _ -> baseState initWorklist = IS.fromList [0..nBlocks-1] @@ -1366,15 +1464,18 @@ runDataflowWithConfigAndSummaries tcfg summaries cfg | otherwise = let (idx, worklist') = IS.deleteFindMin worklist bb = indexBlock cfg idx - inState = IM.findWithDefault baseState idx inStates - -- Use analyzeBlockWithSummaries to apply summaries at calls - outState = analyzeBlockWithSummaries bb inState summaries + inState = IM.findWithDefault baseState + idx inStates + outState = analyzeBlockWithSummaries rt bb + inState summaries oldOut = IM.lookup idx outStates changed = oldOut /= Just outState - outStates' = IM.insert idx outState outStates + outStates' = IM.insert idx outState + outStates succs = blockSuccessors cfg idx (worklist'', inStates') = if changed - then propagateS succs outState worklist' inStates + then propagateS succs outState + worklist' inStates else (worklist', inStates) in go worklist'' inStates' outStates' @@ -1383,61 +1484,73 @@ runDataflowWithConfigAndSummaries tcfg summaries cfg let oldIn = IM.findWithDefault emptyState s ins newIn = joinTaintState oldIn out ins' = IM.insert s newIn ins - wl' = if oldIn /= newIn then IS.insert s wl else wl + wl' = if oldIn /= newIn + then IS.insert s wl else wl in propagateS ss out wl' ins' -- | Run inter-procedural analysis with taint config. --- Precomputes caches for function blocks and callers. --- Uses tcAssumeStgPublic from config for STG stack assumption. -runInterProcWithConfig :: TaintConfig -> CFG -> Map Text FuncSummary -runInterProcWithConfig tcfg cfg = go initSummaries (Set.fromList funcs) +runInterProcWithConfig :: RuntimeConfig -> TaintConfig + -> CFG -> Map Text FuncSummary +runInterProcWithConfig rt tcfg cfg = + go initSummaries (Set.fromList funcs) where funcs = functionLabels cfg assumeStg = tcAssumeStgPublic tcfg baseSummary = initSummaryWith assumeStg - initSummaries = Map.fromList [(f, baseSummary) | f <- funcs] + initSummaries = Map.fromList + [(f, baseSummary) | f <- funcs] - -- Precompute caches once funcBlocksCache = buildFunctionBlocksMap cfg - callerMapCache = buildCallerMap cfg funcBlocksCache + callerMapCache = + buildCallerMap rt cfg funcBlocksCache - lookupBlocks f = Map.findWithDefault [] f funcBlocksCache - lookupCallers f = Map.findWithDefault [] f callerMapCache + lookupBlocks f = + Map.findWithDefault [] f funcBlocksCache + lookupCallers f = + Map.findWithDefault [] f callerMapCache go summaries worklist | Set.null worklist = summaries | otherwise = - let (func, worklist') = Set.deleteFindMin worklist + let (func, worklist') = + Set.deleteFindMin worklist blockIdxs = lookupBlocks func - outState = runFunctionDataflowWithConfig tcfg cfg func - blockIdxs summaries + outState = + runFunctionDataflowWithConfig rt tcfg cfg + func blockIdxs summaries newSumm = FuncSummary outState - oldSumm = Map.findWithDefault baseSummary func summaries + oldSumm = Map.findWithDefault baseSummary + func summaries changed = newSumm /= oldSumm - summaries' = Map.insert func newSumm summaries + summaries' = Map.insert func newSumm + summaries callers = lookupCallers func worklist'' = if changed - then foldr Set.insert worklist' callers - else worklist' + then foldr Set.insert worklist' callers + else worklist' in go summaries' worklist'' -- | Run function dataflow with config-based entry seeding. --- Uses tcAssumeStgPublic from config for STG stack assumption. -runFunctionDataflowWithConfig :: TaintConfig -> CFG -> Text -> [Int] - -> Map Text FuncSummary -> TaintState -runFunctionDataflowWithConfig tcfg cfg funcName blockIndices summaries = +runFunctionDataflowWithConfig + :: RuntimeConfig -> TaintConfig -> CFG -> Text -> [Int] + -> Map Text FuncSummary -> TaintState +runFunctionDataflowWithConfig rt tcfg cfg funcName + blockIndices summaries = let assumeStg = tcAssumeStgPublic tcfg - -- Seed entry state with function policy - baseEntry = initTaintStateWith assumeStg - entryState = case Map.lookup funcName (tcPolicies tcfg) of - Nothing -> baseEntry - Just policy -> seedArgs policy baseEntry - inStates = runFunctionBlocksWithEntry cfg blockIndices summaries entryState + baseEntry = initTaintStateWith rt assumeStg + entryState = + case Map.lookup funcName (tcPolicies tcfg) of + Nothing -> baseEntry + Just policy -> seedArgs rt policy baseEntry + inStates = runFunctionBlocksWithEntry rt cfg + blockIndices summaries entryState returnOuts = - [ analyzeBlockWithSummaries bb inState summaries + [ analyzeBlockWithSummaries rt bb inState + summaries | i <- blockIndices , let bb = indexBlock cfg i - inState = IM.findWithDefault entryState i inStates + inState = IM.findWithDefault entryState i + inStates , endsWithRet bb ] in case returnOuts of @@ -1445,15 +1558,18 @@ runFunctionDataflowWithConfig tcfg cfg funcName blockIndices summaries = (o:os) -> foldl' joinTaintState o os -- | Run function blocks with a custom entry state. --- Inherits STG assumption from entryState. -runFunctionBlocksWithEntry :: CFG -> [Int] -> Map Text FuncSummary - -> TaintState -> IntMap TaintState -runFunctionBlocksWithEntry _ [] _ _ = IM.empty -runFunctionBlocksWithEntry cfg (entryIdx:rest) summaries entryState = +runFunctionBlocksWithEntry :: RuntimeConfig -> CFG -> [Int] + -> Map Text FuncSummary + -> TaintState + -> IntMap TaintState +runFunctionBlocksWithEntry _ _ [] _ _ = IM.empty +runFunctionBlocksWithEntry rt cfg (entryIdx:rest) + summaries entryState = go initWorklist initIn IM.empty where blockSet = IS.fromList (entryIdx:rest) - emptyState = emptyTaintStateWith (tsAssumeStgPublic entryState) + emptyState = emptyTaintStateWith + (tsAssumeStgPublic entryState) initIn = IM.singleton entryIdx entryState initWorklist = IS.singleton entryIdx @@ -1463,12 +1579,16 @@ runFunctionBlocksWithEntry cfg (entryIdx:rest) summaries entryState = | otherwise = let (idx, wl') = IS.deleteFindMin wl bb = indexBlock cfg idx - inState = IM.findWithDefault entryState idx inStates - outState = analyzeBlockWithSummaries bb inState summaries + inState = IM.findWithDefault entryState + idx inStates + outState = analyzeBlockWithSummaries rt bb + inState summaries oldOut = IM.lookup idx outStates changed = oldOut /= Just outState - outStates' = IM.insert idx outState outStates - succs = filter (`IS.member` blockSet) (blockSuccessors cfg idx) + outStates' = IM.insert idx outState + outStates + succs = filter (`IS.member` blockSet) + (blockSuccessors cfg idx) (wl'', inStates') = if changed then propagate succs outState wl' inStates else (wl', inStates) @@ -1479,5 +1599,6 @@ runFunctionBlocksWithEntry cfg (entryIdx:rest) summaries entryState = let oldIn = IM.findWithDefault emptyState s ins newIn = joinTaintState oldIn out ins' = IM.insert s newIn ins - wl' = if oldIn /= newIn then IS.insert s wl else wl + wl' = if oldIn /= newIn + then IS.insert s wl else wl in propagate ss out wl' ins' diff --git a/lib/Audit/AArch64/Types.hs b/lib/Audit/AArch64/Types.hs @@ -43,6 +43,11 @@ module Audit.AArch64.Types ( , Violation(..) , ViolationReason(..) + -- * Non-constant-time findings + , NctReason(..) + , NctFinding(..) + , LineMap + -- * Taint configuration , TaintConfig(..) , defaultTaintConfig @@ -52,6 +57,7 @@ module Audit.AArch64.Types ( import Control.DeepSeq (NFData) import Data.Aeson (ToJSON(..), FromJSON(..), (.=), (.:?), object, withObject) +import Data.IntMap.Strict (IntMap) import Data.Map.Strict (Map) import qualified Data.Map.Strict as Map import Data.Set (Set) @@ -404,6 +410,24 @@ instance ToJSON Violation where , "reason" .= vReason v ] +-- | Reason for flagging an instruction as non-constant-time. +data NctReason + = CondBranch -- ^ Conditional branch + | IndirectBranch -- ^ Indirect branch (br, blr) + | Div -- ^ Division (udiv, sdiv) + | RegIndexAddr -- ^ Register-indexed memory access + deriving (Eq, Ord, Show, Generic, NFData) + +-- | A non-constant-time finding. +data NctFinding = NctFinding + { nctLine :: !Int -- ^ Source line number + , nctInstr :: !Instr -- ^ The flagged instruction + , nctReason :: !NctReason -- ^ Why it was flagged + } deriving (Eq, Show, Generic, NFData) + +-- | Line number to Line map for efficient lookup. +type LineMap = IntMap Line + -- | Per-function argument taint policy. -- Specifies which registers and STG stack slots should be seeded. data ArgPolicy = ArgPolicy diff --git a/ppad-auditor.cabal b/ppad-auditor.cabal @@ -29,6 +29,8 @@ library Audit.AArch64.Types Audit.AArch64.Parser Audit.AArch64.CFG + Audit.AArch64.Runtime + Audit.AArch64.Runtime.GHC Audit.AArch64.Taint Audit.AArch64.Check Audit.AArch64.NCT diff --git a/test/Main.hs b/test/Main.hs @@ -5,6 +5,7 @@ module Main where import Audit.AArch64 import Audit.AArch64.CallGraph import Audit.AArch64.Parser +import Audit.AArch64.Runtime.GHC (ghcRuntime) import Audit.AArch64.Taint import Audit.AArch64.Types import Data.Aeson (eitherDecodeStrict') @@ -15,14 +16,19 @@ import qualified Data.Text as T import Test.Tasty import Test.Tasty.HUnit --- | Check if a violation is secret-derived (not unknown or structural). +-- | Runtime config used throughout tests. +rt :: RuntimeConfig +rt = ghcRuntime + +-- | Check if a violation is secret-derived +-- (not unknown or structural). isSecretViolation :: ViolationReason -> Bool isSecretViolation r = case r of SecretBase _ -> True SecretIndex _ -> True UnknownBase _ -> False UnknownIndex _ -> False - NonConstOffset -> False -- Structural violation, not secret-derived + NonConstOffset -> False main :: IO () main = defaultMain $ testGroup "ppad-auditor" [ @@ -47,7 +53,8 @@ parserTests = testGroup "Parser" [ Left _ -> assertFailure "parse failed" Right lns -> do assertEqual "line count" 1 (length lns) - assertEqual "label" (Just "foo") (lineLabel (head lns)) + assertEqual "label" (Just "foo") + (lineLabel (head lns)) , testCase "parse ldr base+imm" $ do let src = "ldr x0, [x20, #8]\n" @@ -55,7 +62,8 @@ parserTests = testGroup "Parser" [ Left _ -> assertFailure "parse failed" Right lns -> case lineInstr (head lns) of Just (Ldr X0 (BaseImm X20 8)) -> pure () - other -> assertFailure $ "unexpected: " ++ show other + other -> assertFailure $ + "unexpected: " ++ show other , testCase "parse str base+reg" $ do let src = "str x1, [x21, x2]\n" @@ -63,7 +71,8 @@ parserTests = testGroup "Parser" [ Left _ -> assertFailure "parse failed" Right lns -> case lineInstr (head lns) of Just (Str X1 (BaseReg X21 X2)) -> pure () - other -> assertFailure $ "unexpected: " ++ show other + other -> assertFailure $ + "unexpected: " ++ show other , testCase "parse add" $ do let src = "add x0, x1, x2\n" @@ -71,7 +80,8 @@ parserTests = testGroup "Parser" [ Left _ -> assertFailure "parse failed" Right lns -> case lineInstr (head lns) of Just (Add X0 X1 (OpReg X2)) -> pure () - other -> assertFailure $ "unexpected: " ++ show other + other -> assertFailure $ + "unexpected: " ++ show other , testCase "parse adrp" $ do let src = "adrp x5, _symbol@PAGE\n" @@ -79,48 +89,67 @@ parserTests = testGroup "Parser" [ Left _ -> assertFailure "parse failed" Right lns -> case lineInstr (head lns) of Just (Adrp X5 _) -> pure () - other -> assertFailure $ "unexpected: " ++ show other + other -> assertFailure $ + "unexpected: " ++ show other , testCase "skip directives" $ do - let src = ".section __TEXT\n.globl _foo\n_foo:\n ret\n" + let src = ".section __TEXT\n\ + \.globl _foo\n_foo:\n ret\n" case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e + Left e -> assertFailure $ + "parse failed: " ++ show e Right lns -> do - -- Should have 4 lines, first two are directives (no instr) - let instrs = filter ((/= Nothing) . lineInstr) lns - assertEqual "instruction count" 1 (length instrs) + let instrs = filter + ((/= Nothing) . lineInstr) lns + assertEqual "instruction count" + 1 (length instrs) , testCase "parse ldr base+symbol" $ do let src = "ldr x8, [x8, _foo@GOTPAGEOFF]\n" case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right lns -> case lineInstr (safeHead lns) of - Just (Ldr X8 (BaseSymbol X8 "_foo@GOTPAGEOFF")) -> pure () - other -> assertFailure $ "unexpected: " ++ show other + Left e -> assertFailure $ + "parse failed: " ++ show e + Right lns -> + case lineInstr (safeHead lns) of + Just (Ldr X8 + (BaseSymbol X8 "_foo@GOTPAGEOFF")) + -> pure () + other -> assertFailure $ + "unexpected: " ++ show other , testCase "parse shifted register operand" $ do let src = "add x0, x1, x2, lsl #3\n" case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right lns -> case lineInstr (safeHead lns) of - Just (Add X0 X1 (OpShiftedReg X2 (LSL 3))) -> pure () - other -> assertFailure $ "unexpected: " ++ show other + Left e -> assertFailure $ + "parse failed: " ++ show e + Right lns -> + case lineInstr (safeHead lns) of + Just (Add X0 X1 + (OpShiftedReg X2 (LSL 3))) -> pure () + other -> assertFailure $ + "unexpected: " ++ show other , testCase "parse plain register operand" $ do let src = "add x0, x1, x2\n" case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right lns -> case lineInstr (safeHead lns) of - Just (Add X0 X1 (OpReg X2)) -> pure () - other -> assertFailure $ "unexpected: " ++ show other + Left e -> assertFailure $ + "parse failed: " ++ show e + Right lns -> + case lineInstr (safeHead lns) of + Just (Add X0 X1 (OpReg X2)) -> pure () + other -> assertFailure $ + "unexpected: " ++ show other , testCase "parse literal address" $ do let src = "ldr x0, =foo\n" case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right lns -> case lineInstr (safeHead lns) of - Just (Ldr X0 (Literal "foo")) -> pure () - other -> assertFailure $ "unexpected: " ++ show other + Left e -> assertFailure $ + "parse failed: " ++ show e + Right lns -> + case lineInstr (safeHead lns) of + Just (Ldr X0 (Literal "foo")) -> pure () + other -> assertFailure $ + "unexpected: " ++ show other , testCase "parse all register types" $ do let src = T.unlines @@ -136,65 +165,94 @@ parserTests = testGroup "Parser" [ , "mov x0, lr" ] case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right lns -> assertEqual "parsed all lines" 10 (length lns) + Left e -> assertFailure $ + "parse failed: " ++ show e + Right lns -> + assertEqual "parsed all lines" + 10 (length lns) , testCase "parse label with instruction" $ do let src = "foo: mov x0, x1\n" case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e + Left e -> assertFailure $ + "parse failed: " ++ show e Right lns -> case safeHead lns of - Line _ (Just "foo") (Just (Mov X0 (OpReg X1))) -> pure () - other -> assertFailure $ "unexpected: " ++ show other + Line _ (Just "foo") + (Just (Mov X0 (OpReg X1))) -> pure () + other -> assertFailure $ + "unexpected: " ++ show other , testCase "parse instruction without label" $ do let src = " mov x0, x1\n" case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e + Left e -> assertFailure $ + "parse failed: " ++ show e Right lns -> case safeHead lns of - Line _ Nothing (Just (Mov X0 (OpReg X1))) -> pure () - other -> assertFailure $ "unexpected: " ++ show other + Line _ Nothing + (Just (Mov X0 (OpReg X1))) -> pure () + other -> assertFailure $ + "unexpected: " ++ show other - -- Acquire/release and exclusive memory ops (IMPL20) + -- Acquire/release and exclusive memory ops , testCase "parse ldar" $ do let src = "ldar x0, [x1]\n" case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right lns -> case lineInstr (safeHead lns) of - Just (Ldar X0 (BaseImm X1 0)) -> pure () - other -> assertFailure $ "unexpected: " ++ show other + Left e -> assertFailure $ + "parse failed: " ++ show e + Right lns -> + case lineInstr (safeHead lns) of + Just (Ldar X0 (BaseImm X1 0)) -> + pure () + other -> assertFailure $ + "unexpected: " ++ show other , testCase "parse stlr" $ do let src = "stlr x0, [x1]\n" case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right lns -> case lineInstr (safeHead lns) of - Just (Stlr X0 (BaseImm X1 0)) -> pure () - other -> assertFailure $ "unexpected: " ++ show other + Left e -> assertFailure $ + "parse failed: " ++ show e + Right lns -> + case lineInstr (safeHead lns) of + Just (Stlr X0 (BaseImm X1 0)) -> + pure () + other -> assertFailure $ + "unexpected: " ++ show other , testCase "parse stxr" $ do let src = "stxr w0, x1, [x2]\n" case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right lns -> case lineInstr (safeHead lns) of - Just (Stxr W0 X1 (BaseImm X2 0)) -> pure () - other -> assertFailure $ "unexpected: " ++ show other + Left e -> assertFailure $ + "parse failed: " ++ show e + Right lns -> + case lineInstr (safeHead lns) of + Just (Stxr W0 X1 (BaseImm X2 0)) -> + pure () + other -> assertFailure $ + "unexpected: " ++ show other , testCase "parse ldaxr" $ do let src = "ldaxr x0, [x1]\n" case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right lns -> case lineInstr (safeHead lns) of - Just (Ldaxr X0 (BaseImm X1 0)) -> pure () - other -> assertFailure $ "unexpected: " ++ show other + Left e -> assertFailure $ + "parse failed: " ++ show e + Right lns -> + case lineInstr (safeHead lns) of + Just (Ldaxr X0 (BaseImm X1 0)) -> + pure () + other -> assertFailure $ + "unexpected: " ++ show other , testCase "parse stlxr" $ do let src = "stlxr w0, x1, [x2]\n" case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right lns -> case lineInstr (safeHead lns) of - Just (Stlxr W0 X1 (BaseImm X2 0)) -> pure () - other -> assertFailure $ "unexpected: " ++ show other + Left e -> assertFailure $ + "parse failed: " ++ show e + Right lns -> + case lineInstr (safeHead lns) of + Just (Stlxr W0 X1 (BaseImm X2 0)) -> + pure () + other -> assertFailure $ + "unexpected: " ++ show other ] safeHead :: [a] -> a @@ -206,34 +264,42 @@ safeHead [] = error "safeHead: empty list" taintTests :: TestTree taintTests = testGroup "Taint" [ testCase "public roots" $ do - let st = initTaintState - assertEqual "X20 is public" Public (getTaint X20 st) - assertEqual "X21 is public" Public (getTaint X21 st) - assertEqual "SP is public" Public (getTaint SP st) + let st = initTaintState rt + assertEqual "X20 is public" + Public (getTaint X20 st) + assertEqual "X21 is public" + Public (getTaint X21 st) + assertEqual "SP is public" + Public (getTaint SP st) , testCase "unknown register" $ do - let st = initTaintState - assertEqual "X0 is unknown" Unknown (getTaint X0 st) + let st = initTaintState rt + assertEqual "X0 is unknown" + Unknown (getTaint X0 st) , testCase "mov propagates taint" $ do - let st = initTaintState - l = Line 1 Nothing (Just (Mov X0 (OpReg X20))) - st' = analyzeLine l st - assertEqual "X0 becomes public" Public (getTaint X0 st') + let st = initTaintState rt + l = Line 1 Nothing + (Just (Mov X0 (OpReg X20))) + st' = analyzeLine rt l st + assertEqual "X0 becomes public" + Public (getTaint X0 st') , testCase "add joins taints" $ do - let st = initTaintState - -- X0 = X20 (public) + X1 (unknown) -> unknown - l = Line 1 Nothing (Just (Add X0 X20 (OpReg X1))) - st' = analyzeLine l st - assertEqual "result is unknown" Unknown (getTaint X0 st') + let st = initTaintState rt + l = Line 1 Nothing + (Just (Add X0 X20 (OpReg X1))) + st' = analyzeLine rt l st + assertEqual "result is unknown" + Unknown (getTaint X0 st') , testCase "load makes unknown" $ do - -- Use initTaintStateWith False to test without STG-public assumption - let st = initTaintStateWith False - l = Line 1 Nothing (Just (Ldr X0 (BaseImm X20 0))) - st' = analyzeLine l st - assertEqual "loaded value is unknown" Unknown (getTaint X0 st') + let st = initTaintStateWith rt False + l = Line 1 Nothing + (Just (Ldr X0 (BaseImm X20 0))) + st' = analyzeLine rt l st + assertEqual "loaded value is unknown" + Unknown (getTaint X0 st') ] -- Audit tests @@ -247,9 +313,11 @@ auditTests = testGroup "Audit" [ , " str x0, [x21]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "no violations" 0 (length (arViolations ar)) + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "no violations" + 0 (length (arViolations ar)) , testCase "bad: unknown base" $ do let src = T.unlines @@ -257,32 +325,40 @@ auditTests = testGroup "Audit" [ , " ldr x0, [x0]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> do - assertEqual "one violation" 1 (length (arViolations ar)) + assertEqual "one violation" + 1 (length (arViolations ar)) case arViolations ar of [v] -> case vReason v of UnknownBase X0 -> pure () - other -> assertFailure $ "wrong reason: " ++ show other - _ -> assertFailure "expected one violation" + other -> assertFailure $ + "wrong reason: " ++ show other + _ -> assertFailure + "expected one violation" , testCase "bad: secret-derived index" $ do let src = T.unlines [ "foo:" - , " ldr x0, [x21]" -- x0 = unknown (loaded from heap) - , " ldr x1, [x20, x0]" -- x0 as index -> violation + , " ldr x0, [x21]" + , " ldr x1, [x20, x0]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> do - assertEqual "one violation" 1 (length (arViolations ar)) + assertEqual "one violation" + 1 (length (arViolations ar)) case arViolations ar of [v] -> case vReason v of UnknownIndex X0 -> pure () - other -> assertFailure $ "wrong reason: " ++ show other - _ -> assertFailure "expected one violation" + other -> assertFailure $ + "wrong reason: " ++ show other + _ -> assertFailure + "expected one violation" , testCase "good: adrp+add pattern" $ do let src = T.unlines @@ -292,417 +368,464 @@ auditTests = testGroup "Audit" [ , " ldr x0, [x8]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "no violations" 0 (length (arViolations ar)) + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "no violations" + 0 (length (arViolations ar)) , testCase "good: cross-block public taint" $ do - -- x8 is made public in block A, used in block B let src = T.unlines [ "blockA:" , " adrp x8, _const@PAGE" , " cbz x0, blockB" , "blockB:" - , " ldr x1, [x8]" -- x8 should be known public from block A + , " ldr x1, [x8]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "no violations" 0 (length (arViolations ar)) + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "no violations" + 0 (length (arViolations ar)) , testCase "bad: cross-block unknown taint" $ do - -- x8 is loaded from heap (unknown) in block A, used as base in block B let src = T.unlines [ "blockA:" - , " ldr x8, [x21]" -- x8 becomes unknown (from heap) + , " ldr x8, [x21]" , " cbz x0, blockB" , "blockB:" - , " ldr x1, [x8]" -- x8 as base should trigger violation + , " ldr x1, [x8]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> do - assertEqual "one violation" 1 (length (arViolations ar)) + assertEqual "one violation" + 1 (length (arViolations ar)) case arViolations ar of [v] -> case vReason v of UnknownBase X8 -> pure () - other -> assertFailure $ "wrong reason: " ++ show other - _ -> assertFailure "expected one violation" + other -> assertFailure $ + "wrong reason: " ++ show other + _ -> assertFailure + "expected one violation" , testCase "call: no taint propagation to callee" $ do - -- Taint set before bl should NOT flow into the callee block - -- The callee starts fresh with public roots only let src = T.unlines [ "caller:" - , " adrp x8, _const@PAGE" -- x8 = public + , " adrp x8, _const@PAGE" , " bl callee" , " ret" , "callee:" - , " ldr x0, [x8]" -- x8 unknown here (fresh block) + , " ldr x0, [x8]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> do - -- Should have 1 violation: x8 unknown in callee - assertEqual "one violation" 1 (length (arViolations ar)) + assertEqual "one violation" + 1 (length (arViolations ar)) case arViolations ar of [v] -> do - assertEqual "violation in callee" "callee" (vSymbol v) + assertEqual "violation in callee" + "callee" (vSymbol v) case vReason v of UnknownBase X8 -> pure () - other -> assertFailure $ "wrong reason: " ++ show other - _ -> assertFailure "expected one violation" + other -> assertFailure $ + "wrong reason: " ++ show other + _ -> assertFailure + "expected one violation" , testCase "call: caller-saved invalidation" $ do - -- x0 is public before call, unknown after (caller-saved) let src = T.unlines [ "foo:" - , " adrp x0, _const@PAGE" -- x0 = public + , " adrp x0, _const@PAGE" , " bl bar" - , " ldr x1, [x0]" -- x0 unknown after call (caller-saved) + , " ldr x1, [x0]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> do - assertEqual "one violation" 1 (length (arViolations ar)) + assertEqual "one violation" + 1 (length (arViolations ar)) case arViolations ar of [v] -> case vReason v of UnknownBase X0 -> pure () - other -> assertFailure $ "wrong reason: " ++ show other - _ -> assertFailure "expected one violation" + other -> assertFailure $ + "wrong reason: " ++ show other + _ -> assertFailure + "expected one violation" , testCase "call: callee-saved preserved" $ do - -- x19 is callee-saved, stays public across call let src = T.unlines [ "foo:" , " bl bar" - , " ldr x0, [x19]" -- x19 public (callee-saved) + , " ldr x0, [x19]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "no violations" 0 (length (arViolations ar)) + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "no violations" + 0 (length (arViolations ar)) , testCase "good: stack spill/reload preserves public" $ do - -- Public value stored to stack, then reloaded, should stay public let src = T.unlines [ "foo:" - , " stp x20, x21, [sp]" -- Store public roots to stack - , " ldr x8, [sp]" -- Reload x20's value into x8 - , " ldr x9, [sp, #8]" -- Reload x21's value into x9 - , " ldr x0, [x8]" -- Use x8 as base (should be public) - , " ldr x1, [x9]" -- Use x9 as base (should be public) + , " stp x20, x21, [sp]" + , " ldr x8, [sp]" + , " ldr x9, [sp, #8]" + , " ldr x0, [x8]" + , " ldr x1, [x9]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "no violations" 0 (length (arViolations ar)) + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "no violations" + 0 (length (arViolations ar)) , testCase "bad: stack spill/reload preserves unknown" $ do - -- Unknown value stored to stack, then reloaded, should stay unknown let src = T.unlines [ "foo:" - , " ldr x8, [x21]" -- x8 = unknown (loaded from heap) - , " str x8, [sp, #16]" -- Store unknown to stack slot - , " ldr x9, [sp, #16]" -- Reload the unknown value - , " ldr x0, [x9]" -- Use as base - violation! + , " ldr x8, [x21]" + , " str x8, [sp, #16]" + , " ldr x9, [sp, #16]" + , " ldr x0, [x9]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> do - assertEqual "one violation" 1 (length (arViolations ar)) + assertEqual "one violation" + 1 (length (arViolations ar)) case arViolations ar of [v] -> case vReason v of UnknownBase X9 -> pure () - other -> assertFailure $ "wrong reason: " ++ show other - _ -> assertFailure "expected one violation" + other -> assertFailure $ + "wrong reason: " ++ show other + _ -> assertFailure + "expected one violation" , testCase "good: pointer spill/reload preserves kind" $ do - -- A pointer (from adrp) spilled and reloaded should stay KindPtr - -- and be usable as a base via provenance upgrade let src = T.unlines [ "foo:" - , " adrp x8, _const@PAGE" -- x8 = Public, ProvPublic, KindPtr - , " str x8, [sp, #16]" -- Spill pointer to stack - , " ldr x9, [sp, #16]" -- Reload: Unknown, ProvPublic, KindPtr - , " ldr x0, [x9]" -- Use as base - allowed via provenance + , " adrp x8, _const@PAGE" + , " str x8, [sp, #16]" + , " ldr x9, [sp, #16]" + , " ldr x0, [x9]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "no violations" 0 (length (arViolations ar)) + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "no violations" + 0 (length (arViolations ar)) , testCase "bad: scalar spill/reload doesn't upgrade kind" $ do - -- A scalar (loaded value) spilled and reloaded should stay KindScalar - -- even with public provenance, should not be usable as base let src = T.unlines [ "foo:" - , " ldr x8, [x21]" -- x8 = Unknown, ProvUnknown (from heap) - , " add x8, x20, x8" -- x8 = Unknown, ProvPublic, KindScalar - , " str x8, [sp, #16]" -- Spill scalar to stack - , " ldr x9, [sp, #16]" -- Reload: Unknown, ProvPublic, KindScalar - , " ldr x0, [x9]" -- Use as base - violation! + , " ldr x8, [x21]" + , " add x8, x20, x8" + , " str x8, [sp, #16]" + , " ldr x9, [sp, #16]" + , " ldr x0, [x9]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> do - assertEqual "one violation" 1 (length (arViolations ar)) + assertEqual "one violation" + 1 (length (arViolations ar)) case arViolations ar of [v] -> case vReason v of UnknownBase X9 -> pure () - other -> assertFailure $ "wrong reason: " ++ show other - _ -> assertFailure "expected one violation" + other -> assertFailure $ + "wrong reason: " ++ show other + _ -> assertFailure + "expected one violation" -- STG stack (x20-relative) spill/reload tests , testCase "good: STG stack spill/reload preserves public" $ do - -- Public value stored to STG stack (x20), then reloaded, stays public let src = T.unlines [ "foo:" - , " adrp x8, _const@PAGE" -- x8 = Public, ProvPublic, KindPtr - , " str x8, [x20, #8]" -- Store to STG stack slot - , " ldr x9, [x20, #8]" -- Reload from STG stack - , " ldr x0, [x9]" -- Use as base - should be public + , " adrp x8, _const@PAGE" + , " str x8, [x20, #8]" + , " ldr x9, [x20, #8]" + , " ldr x0, [x9]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "no violations" 0 (length (arViolations ar)) + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "no violations" + 0 (length (arViolations ar)) , testCase "bad: STG stack spill/reload preserves unknown" $ do - -- Unknown value stored to STG stack, then reloaded, stays unknown let src = T.unlines [ "foo:" - , " ldr x8, [x21]" -- x8 = unknown (loaded from heap) - , " str x8, [x20, #16]" -- Store unknown to STG stack slot - , " ldr x9, [x20, #16]" -- Reload the unknown value - , " ldr x0, [x9]" -- Use as base - violation! + , " ldr x8, [x21]" + , " str x8, [x20, #16]" + , " ldr x9, [x20, #16]" + , " ldr x0, [x9]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> do - assertEqual "one violation" 1 (length (arViolations ar)) + assertEqual "one violation" + 1 (length (arViolations ar)) case arViolations ar of [v] -> case vReason v of UnknownBase X9 -> pure () - other -> assertFailure $ "wrong reason: " ++ show other - _ -> assertFailure "expected one violation" + other -> assertFailure $ + "wrong reason: " ++ show other + _ -> assertFailure + "expected one violation" , testCase "good: STG stack pointer spill preserves kind" $ do - -- A pointer spilled to STG stack and reloaded should stay KindPtr let src = T.unlines [ "foo:" - , " adrp x8, _const@PAGE" -- x8 = Public, ProvPublic, KindPtr - , " str x8, [x20, #24]" -- Spill pointer to STG stack - , " ldr x9, [x20, #24]" -- Reload: Unknown, ProvPublic, KindPtr - , " ldr x0, [x9]" -- Use as base - allowed via provenance + , " adrp x8, _const@PAGE" + , " str x8, [x20, #24]" + , " ldr x9, [x20, #24]" + , " ldr x0, [x9]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "no violations" 0 (length (arViolations ar)) + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "no violations" + 0 (length (arViolations ar)) , testCase "good: STG stack pre-indexed load preserves taint" $ do - -- Pre-indexed load from STG stack should restore tracked taint - -- ldr x8, [x20, #16]! loads from x20+16, then updates x20 let src = T.unlines [ "foo:" - , " adrp x8, _const@PAGE" -- x8 = Public, ProvPublic, KindPtr - , " str x8, [x20, #16]" -- Store pointer to STG stack slot 16 - , " ldr x9, [x20, #16]!" -- Pre-indexed load from slot 16 - , " ldr x0, [x9]" -- Use as base - should be public + , " adrp x8, _const@PAGE" + , " str x8, [x20, #16]" + , " ldr x9, [x20, #16]!" + , " ldr x0, [x9]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "no violations" 0 (length (arViolations ar)) + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "no violations" + 0 (length (arViolations ar)) , testCase "good: GOT/PLT pattern" $ do - -- GOTPAGEOFF pattern: adrp + ldr from GOT should produce public let src = T.unlines [ "foo:" , " adrp x8, _sym@GOTPAGE" - , " ldr x8, [x8, _sym@GOTPAGEOFF]" -- GOT entry -> address - , " ldr x0, [x8]" -- Use GOT result as base + , " ldr x8, [x8, _sym@GOTPAGEOFF]" + , " ldr x0, [x8]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "no violations" 0 (length (arViolations ar)) + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "no violations" + 0 (length (arViolations ar)) -- Heap taint propagation tests , testCase "bad: heap store propagates secret to later load" $ do - -- Secret stored to heap should taint subsequent heap loads - -- Using public roots (x21 = Hp) as base addresses to avoid base violations let src = T.unlines [ "_foo:" - , " str x0, [x21]" -- Store secret x0 to heap (via Hp) - , " ldr x9, [x21, #8]" -- Load from heap -> gets heap taint - , " ldr x1, [x20, x9]" -- Use loaded value as index -> violation + , " str x0, [x21]" + , " ldr x9, [x21, #8]" + , " ldr x1, [x20, x9]" , " ret" ] cfg = TaintConfig (Map.singleton "_foo" - (ArgPolicy (Set.singleton X0) Set.empty Set.empty Set.empty Set.empty)) True - case auditWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + (ArgPolicy (Set.singleton X0) + Set.empty Set.empty + Set.empty Set.empty)) True + case auditWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> do - assertEqual "one violation" 1 (length (arViolations ar)) + assertEqual "one violation" + 1 (length (arViolations ar)) case arViolations ar of [v] -> case vReason v of SecretIndex X9 -> pure () - other -> assertFailure $ "wrong reason: " ++ show other - _ -> assertFailure "expected one violation" + other -> assertFailure $ + "wrong reason: " ++ show other + _ -> assertFailure + "expected one violation" , testCase "bad: heap taints join on multiple stores" $ do - -- Multiple stores to heap join; one secret contaminates all loads - -- Using public roots (x21 = Hp) as base addresses let src = T.unlines [ "_foo:" - , " adrp x8, _const@PAGE" -- x8 = public pointer - , " str x8, [x21]" -- Store public to heap - , " str x0, [x21, #8]" -- Store secret x0 to heap - , " ldr x9, [x21, #16]" -- Load from heap -> secret (joined) - , " ldr x1, [x20, x9]" -- Use as index -> violation + , " adrp x8, _const@PAGE" + , " str x8, [x21]" + , " str x0, [x21, #8]" + , " ldr x9, [x21, #16]" + , " ldr x1, [x20, x9]" , " ret" ] cfg = TaintConfig (Map.singleton "_foo" - (ArgPolicy (Set.singleton X0) Set.empty Set.empty Set.empty Set.empty)) True - case auditWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + (ArgPolicy (Set.singleton X0) + Set.empty Set.empty + Set.empty Set.empty)) True + case auditWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> do - assertEqual "one violation" 1 (length (arViolations ar)) + assertEqual "one violation" + 1 (length (arViolations ar)) case arViolations ar of [v] -> case vReason v of SecretIndex X9 -> pure () - other -> assertFailure $ "wrong reason: " ++ show other - _ -> assertFailure "expected one violation" + other -> assertFailure $ + "wrong reason: " ++ show other + _ -> assertFailure + "expected one violation" - -- Refined heap slot tests (Stage 2) + -- Refined heap slot tests , testCase "good: refined heap slot preserves public at different offset" $ do - -- Store secret to offset 0, public to offset 8 - -- Load from offset 8 should get public (refined tracking) let src = T.unlines [ "_foo:" - , " adrp x8, _const@PAGE" -- x8 = public pointer - , " str x0, [x21]" -- Store secret to offset 0 - , " str x8, [x21, #8]" -- Store public pointer to offset 8 - , " ldr x9, [x21, #8]" -- Load from offset 8 -> should be public - , " ldr x10, [x9]" -- Use as base - should be safe + , " adrp x8, _const@PAGE" + , " str x0, [x21]" + , " str x8, [x21, #8]" + , " ldr x9, [x21, #8]" + , " ldr x10, [x9]" , " ret" ] cfg = TaintConfig (Map.singleton "_foo" - (ArgPolicy (Set.singleton X0) Set.empty Set.empty Set.empty Set.empty)) True - case auditWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "no violations" 0 (length (arViolations ar)) + (ArgPolicy (Set.singleton X0) + Set.empty Set.empty + Set.empty Set.empty)) True + case auditWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "no violations" + 0 (length (arViolations ar)) , testCase "bad: refined heap slot preserves secret at its offset" $ do - -- Store public to offset 0, secret to offset 8 - -- Load from offset 8 should get secret (refined tracking) let src = T.unlines [ "_foo:" - , " adrp x8, _const@PAGE" -- x8 = public pointer - , " str x8, [x21]" -- Store public to offset 0 - , " str x0, [x21, #8]" -- Store secret to offset 8 - , " ldr x9, [x21, #8]" -- Load from offset 8 -> should be secret - , " ldr x10, [x20, x9]" -- Use as index -> violation + , " adrp x8, _const@PAGE" + , " str x8, [x21]" + , " str x0, [x21, #8]" + , " ldr x9, [x21, #8]" + , " ldr x10, [x20, x9]" , " ret" ] cfg = TaintConfig (Map.singleton "_foo" - (ArgPolicy (Set.singleton X0) Set.empty Set.empty Set.empty Set.empty)) True - case auditWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + (ArgPolicy (Set.singleton X0) + Set.empty Set.empty + Set.empty Set.empty)) True + case auditWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> do - assertEqual "one violation" 1 (length (arViolations ar)) + assertEqual "one violation" + 1 (length (arViolations ar)) case arViolations ar of [v] -> case vReason v of SecretIndex X9 -> pure () - other -> assertFailure $ "wrong reason: " ++ show other - _ -> assertFailure "expected one violation" + other -> assertFailure $ + "wrong reason: " ++ show other + _ -> assertFailure + "expected one violation" , testCase "bad: refined heap slot falls back for unknown offset" $ do - -- Store secret at known offset, load from unknown offset - -- Load should get coarse heap taint (secret) let src = T.unlines [ "_foo:" - , " str x0, [x21]" -- Store secret to offset 0 - , " ldr x9, [x21, #16]" -- Load from offset 16 (no slot) -> coarse - , " ldr x10, [x20, x9]" -- Use as index -> violation + , " str x0, [x21]" + , " ldr x9, [x21, #16]" + , " ldr x10, [x20, x9]" , " ret" ] cfg = TaintConfig (Map.singleton "_foo" - (ArgPolicy (Set.singleton X0) Set.empty Set.empty Set.empty Set.empty)) True - case auditWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + (ArgPolicy (Set.singleton X0) + Set.empty Set.empty + Set.empty Set.empty)) True + case auditWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> do - assertEqual "one violation" 1 (length (arViolations ar)) + assertEqual "one violation" + 1 (length (arViolations ar)) case arViolations ar of [v] -> case vReason v of SecretIndex X9 -> pure () - other -> assertFailure $ "wrong reason: " ++ show other - _ -> assertFailure "expected one violation" + other -> assertFailure $ + "wrong reason: " ++ show other + _ -> assertFailure + "expected one violation" - -- Acquire/release and exclusive memory ops (IMPL20) + -- Acquire/release and exclusive memory ops , testCase "bad: ldar from unknown base" $ do - -- ldar should be checked like ldr let src = T.unlines [ "_foo:" - , " ldar x8, [x0]" -- x0 is unknown -> violation + , " ldar x8, [x0]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> do - assertEqual "one violation" 1 (length (arViolations ar)) + assertEqual "one violation" + 1 (length (arViolations ar)) case arViolations ar of [v] -> case vReason v of UnknownBase X0 -> pure () - other -> assertFailure $ "wrong reason: " ++ show other - _ -> assertFailure "expected one violation" + other -> assertFailure $ + "wrong reason: " ++ show other + _ -> assertFailure + "expected one violation" , testCase "good: stlr to stack preserves taint" $ do - -- stlr should track stack slots like str let src = T.unlines [ "_foo:" - , " adrp x8, _const@PAGE" -- x8 = public - , " stlr x8, [sp]" -- Store public to stack - , " ldr x9, [sp]" -- Reload -> public - , " ldr x10, [x9]" -- Use as base -> safe + , " adrp x8, _const@PAGE" + , " stlr x8, [sp]" + , " ldr x9, [sp]" + , " ldr x10, [x9]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "no violations" 0 (length (arViolations ar)) + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "no violations" + 0 (length (arViolations ar)) , testCase "good: stxr status register is public" $ do - -- stxr status register should be Public, allowing use as index - -- Note: w0 and x0 are separate in our model, so use w0 directly let src = T.unlines [ "_foo:" - , " stxr w0, x1, [x21]" -- w0 = status (public scalar) - , " ldr x8, [x20, w0]" -- Use w0 as index -> safe + , " stxr w0, x1, [x21]" + , " ldr x8, [x20, w0]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "no violations" 0 (length (arViolations ar)) + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "no violations" + 0 (length (arViolations ar)) , testCase "good: stlxr status register is public" $ do - -- stlxr status register should be Public let src = T.unlines [ "_foo:" - , " stlxr w0, x1, [x21]" -- w0 = status (public scalar) - , " ldr x8, [x20, w0]" -- Use w0 as index -> safe + , " stlxr w0, x1, [x21]" + , " ldr x8, [x20, w0]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "no violations" 0 (length (arViolations ar)) + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "no violations" + 0 (length (arViolations ar)) ] -- Inter-procedural tests @@ -710,57 +833,60 @@ auditTests = testGroup "Audit" [ interprocTests :: TestTree interprocTests = testGroup "InterProc" [ testCase "interproc: callee sets x0 public" $ do - -- callee sets x0 to a public value; caller uses x0 after call - -- With interproc, the summary should show x0 is public let src = T.unlines [ "_callee:" - , " adrp x0, _const@PAGE" -- x0 = public + , " adrp x0, _const@PAGE" , " ret" , "_caller:" , " bl _callee" - , " ldr x1, [x0]" -- x0 should be public via summary + , " ldr x1, [x0]" , " ret" ] - -- Default mode: x0 unknown after call (caller-saved invalidated) - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "default: 1 violation" 1 (length (arViolations ar)) - -- Interproc mode: x0 public via summary - case auditInterProc "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "interproc: 0 violations" 0 (length (arViolations ar)) + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "default: 1 violation" + 1 (length (arViolations ar)) + case auditInterProc rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "interproc: 0 violations" + 0 (length (arViolations ar)) , testCase "interproc: callee leaves x0 unknown" $ do - -- callee loads into x0; caller uses x0 after call let src = T.unlines [ "_callee:" - , " ldr x0, [x21]" -- x0 = unknown (loaded from heap) + , " ldr x0, [x21]" , " ret" , "_caller:" , " bl _callee" - , " ldr x1, [x0]" -- x0 unknown in both modes + , " ldr x1, [x0]" , " ret" ] - -- Both modes should report violation - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "default: 1 violation" 1 (length (arViolations ar)) - case auditInterProc "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "interproc: 1 violation" 1 (length (arViolations ar)) + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "default: 1 violation" + 1 (length (arViolations ar)) + case auditInterProc rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "interproc: 1 violation" + 1 (length (arViolations ar)) , testCase "interproc: default mode unchanged" $ do - -- Verify default mode behavior is unchanged let src = T.unlines [ "foo:" , " adrp x8, _const@PAGE" , " bl bar" - , " ldr x0, [x8]" -- x8 unknown after call (caller-saved) + , " ldr x0, [x8]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "1 violation" 1 (length (arViolations ar)) + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "1 violation" + 1 (length (arViolations ar)) ] -- Provenance tests @@ -768,107 +894,116 @@ interprocTests = testGroup "InterProc" [ provenanceTests :: TestTree provenanceTests = testGroup "Provenance" [ testCase "good: mov from public root preserves provenance" $ do - -- mov from public root should preserve public provenance let src = T.unlines [ "foo:" - , " mov x8, x20" -- x8 = copy of public root - , " ldr x0, [x8]" -- x8 should be public via provenance + , " mov x8, x20" + , " ldr x0, [x8]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "no violations" 0 (length (arViolations ar)) + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "no violations" + 0 (length (arViolations ar)) , testCase "good: add #imm preserves provenance" $ do - -- add with immediate preserves provenance from source let src = T.unlines [ "foo:" - , " add x8, x20, #16" -- x8 = public root + offset - , " ldr x0, [x8]" -- x8 should be public via provenance + , " add x8, x20, #16" + , " ldr x0, [x8]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "no violations" 0 (length (arViolations ar)) + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "no violations" + 0 (length (arViolations ar)) , testCase "bad: load clears provenance" $ do - -- Loading from memory should clear provenance let src = T.unlines [ "foo:" - , " mov x8, x20" -- x8 = public - , " ldr x8, [x8]" -- x8 = unknown (loaded from memory) - , " ldr x0, [x8]" -- x8 as base should be violation + , " mov x8, x20" + , " ldr x8, [x8]" + , " ldr x0, [x8]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> do - assertEqual "one violation" 1 (length (arViolations ar)) + assertEqual "one violation" + 1 (length (arViolations ar)) case arViolations ar of [v] -> case vReason v of UnknownBase X8 -> pure () - other -> assertFailure $ "wrong reason: " ++ show other - _ -> assertFailure "expected one violation" + other -> assertFailure $ + "wrong reason: " ++ show other + _ -> assertFailure + "expected one violation" , testCase "good: orr with xzr preserves provenance" $ do - -- orr with zero register should preserve provenance let src = T.unlines [ "foo:" - , " mov x8, x20" -- x8 = public - , " orr x9, x8, xzr" -- x9 = x8 | 0 = copy with provenance - , " ldr x0, [x9]" -- x9 should be public + , " mov x8, x20" + , " orr x9, x8, xzr" + , " ldr x0, [x9]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "no violations" 0 (length (arViolations ar)) + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "no violations" + 0 (length (arViolations ar)) , testCase "good: pointer untagging preserves public provenance" $ do - -- Untagging a register with ProvPublic preserves provenance let src = T.unlines [ "foo:" - , " mov x8, x20" -- x8 = Public, ProvPublic - , " and x8, x8, #0xfffffffffffffff8" -- untag preserves ProvPublic - , " ldr x0, [x8]" -- should be safe + , " mov x8, x20" + , " and x8, x8, #0xfffffffffffffff8" + , " ldr x0, [x8]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "no violations" 0 (length (arViolations ar)) + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "no violations" + 0 (length (arViolations ar)) , testCase "bad: pointer untagging doesn't upgrade unknown provenance" $ do - -- Untagging a register with ProvUnknown doesn't make it safe let src = T.unlines [ "foo:" - , " ldr x8, [x21]" -- x8 = Unknown, ProvUnknown (from heap) - , " and x8, x8, #-8" -- untag doesn't upgrade ProvUnknown - , " ldr x0, [x8]" -- violation: x8 is Unknown, ProvUnknown + , " ldr x8, [x21]" + , " and x8, x8, #-8" + , " ldr x0, [x8]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "one violation" 1 (length (arViolations ar)) + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "one violation" + 1 (length (arViolations ar)) , testCase "bad: scalar cannot be laundered via provenance" $ do - -- A scalar value added to a public pointer should NOT become a valid - -- base register, even though the result has public provenance. - -- This is the key case that kind tracking prevents. let src = T.unlines [ "foo:" - , " ldr x8, [x21]" -- x8 = Unknown, ProvUnknown (from heap) - , " add x9, x20, x8" -- x9 = Unknown, ProvPublic, KindScalar - , " ldr x0, [x9]" -- violation: x9 is scalar, not pointer + , " ldr x8, [x21]" + , " add x9, x20, x8" + , " ldr x0, [x9]" , " ret" ] - case audit "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + case audit rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> do - assertEqual "one violation" 1 (length (arViolations ar)) + assertEqual "one violation" + 1 (length (arViolations ar)) case arViolations ar of [v] -> case vReason v of UnknownBase X9 -> pure () - other -> assertFailure $ "wrong reason: " ++ show other - _ -> assertFailure "expected one violation" + other -> assertFailure $ + "wrong reason: " ++ show other + _ -> assertFailure + "expected one violation" ] -- Taint config tests @@ -876,166 +1011,207 @@ provenanceTests = testGroup "Provenance" [ taintConfigTests :: TestTree taintConfigTests = testGroup "TaintConfig" [ testCase "parse valid config" $ do - let json = B8.pack "{\"foo\": {\"secret\": [\"X0\"], \"public\": [\"X1\"]}}" - case eitherDecodeStrict' json :: Either String TaintConfig of - Left e -> assertFailure $ "parse failed: " ++ e + let json = B8.pack + "{\"foo\": {\"secret\": [\"X0\"], \ + \\"public\": [\"X1\"]}}" + case eitherDecodeStrict' json + :: Either String TaintConfig of + Left e -> assertFailure $ + "parse failed: " ++ e Right cfg -> do - let policy = Map.lookup "foo" (tcPolicies cfg) + let policy = + Map.lookup "foo" (tcPolicies cfg) case policy of - Nothing -> assertFailure "missing policy for 'foo'" + Nothing -> assertFailure + "missing policy for 'foo'" Just p -> do - assertEqual "secret regs" (Set.singleton X0) (apSecret p) - assertEqual "public regs" (Set.singleton X1) (apPublic p) + assertEqual "secret regs" + (Set.singleton X0) (apSecret p) + assertEqual "public regs" + (Set.singleton X1) (apPublic p) , testCase "parse invalid register" $ do - let json = B8.pack "{\"foo\": {\"secret\": [\"INVALID\"]}}" - case eitherDecodeStrict' json :: Either String TaintConfig of - Left _ -> pure () -- Expected to fail - Right _ -> assertFailure "should have failed on invalid register" + let json = B8.pack + "{\"foo\": {\"secret\": [\"INVALID\"]}}" + case eitherDecodeStrict' json + :: Either String TaintConfig of + Left _ -> pure () + Right _ -> assertFailure + "should have failed on invalid register" , testCase "seedArgs marks secret" $ do - let policy = ArgPolicy (Set.singleton X0) Set.empty Set.empty Set.empty Set.empty - st = seedArgs policy initTaintState - assertEqual "X0 is secret" Secret (getTaint X0 st) + let policy = ArgPolicy (Set.singleton X0) + Set.empty Set.empty + Set.empty Set.empty + st = seedArgs rt policy (initTaintState rt) + assertEqual "X0 is secret" + Secret (getTaint X0 st) , testCase "seedArgs marks public" $ do - let policy = ArgPolicy Set.empty (Set.singleton X0) Set.empty Set.empty Set.empty - st = seedArgs policy initTaintState - assertEqual "X0 is public" Public (getTaint X0 st) + let policy = ArgPolicy Set.empty + (Set.singleton X0) Set.empty + Set.empty Set.empty + st = seedArgs rt policy (initTaintState rt) + assertEqual "X0 is public" + Public (getTaint X0 st) , testCase "secret takes precedence over public" $ do - let policy = ArgPolicy (Set.singleton X0) (Set.singleton X0) Set.empty Set.empty Set.empty - st = seedArgs policy initTaintState - assertEqual "X0 is secret" Secret (getTaint X0 st) + let policy = ArgPolicy (Set.singleton X0) + (Set.singleton X0) Set.empty + Set.empty Set.empty + st = seedArgs rt policy (initTaintState rt) + assertEqual "X0 is secret" + Secret (getTaint X0 st) , testCase "secret_pointee load with imm offset produces Secret" $ do - -- x0 is pointer to secret data; ldr through it produces Secret let src = T.unlines [ "_foo:" - , " ldr x1, [x0, #8]" -- Load through secret_pointee - , " ldr x2, [x21, x1]" -- Use as index -> violation + , " ldr x1, [x0, #8]" + , " ldr x2, [x21, x1]" , " ret" ] cfg = TaintConfig (Map.singleton "_foo" - (ArgPolicy Set.empty Set.empty (Set.singleton X0) - Set.empty Set.empty)) True - case auditWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + (ArgPolicy Set.empty Set.empty + (Set.singleton X0) + Set.empty Set.empty)) True + case auditWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> do - assertEqual "one violation" 1 (length (arViolations ar)) + assertEqual "one violation" + 1 (length (arViolations ar)) case arViolations ar of - [v] -> assertEqual "secret index" (SecretIndex X1) (vReason v) - _ -> assertFailure "expected exactly one violation" + [v] -> assertEqual "secret index" + (SecretIndex X1) (vReason v) + _ -> assertFailure + "expected exactly one violation" , testCase "secret_pointee load with reg offset produces Secret" $ do - -- x0 is pointer to secret data; ldr [x0, x1] produces Secret - -- x1 is unknown, so we get: unknown index x1 + secret index x2 let src = T.unlines [ "_foo:" - , " ldr x2, [x0, x1]" -- Load through secret_pointee with reg idx - , " ldr x3, [x21, x2]" -- Use as index -> violation + , " ldr x2, [x0, x1]" + , " ldr x3, [x21, x2]" , " ret" ] cfg = TaintConfig (Map.singleton "_foo" - (ArgPolicy Set.empty Set.empty (Set.singleton X0) - Set.empty Set.empty)) True - case auditWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + (ArgPolicy Set.empty Set.empty + (Set.singleton X0) + Set.empty Set.empty)) True + case auditWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> do - -- Two violations: unknown x1 index, secret x2 index - assertEqual "two violations" 2 (length (arViolations ar)) + assertEqual "two violations" + 2 (length (arViolations ar)) let reasons = map vReason (arViolations ar) - assertBool "has secret index x2" (SecretIndex X2 `elem` reasons) + assertBool "has secret index x2" + (SecretIndex X2 `elem` reasons) , testCase "secret takes precedence over secret_pointee" $ do - -- If both secret and secret_pointee, secret wins (scalar, not pointer) - let policy = ArgPolicy (Set.singleton X0) Set.empty (Set.singleton X0) - Set.empty Set.empty - st = seedArgs policy initTaintState - assertEqual "X0 is secret" Secret (getTaint X0 st) - -- Secret scalar is not a secret_pointee (provenance is ProvUnknown) - assertEqual "X0 provenance" ProvUnknown (getProvenance X0 st) + let policy = ArgPolicy (Set.singleton X0) + Set.empty (Set.singleton X0) + Set.empty Set.empty + st = seedArgs rt policy (initTaintState rt) + assertEqual "X0 is secret" + Secret (getTaint X0 st) + assertEqual "X0 provenance" + ProvUnknown (getProvenance X0 st) , testCase "non-pointer arithmetic invalidates secret_pointee" $ do - -- mul produces KindScalar, so x0 is no longer a valid pointer - -- Loads through it go to coarse heap bucket (Unknown), not Secret let src = T.unlines [ "_foo:" - , " mul x0, x0, x1" -- Multiply -> KindScalar - , " ldr x2, [x0, #0]" -- Load from heap bucket (not secret) - , " ldr x3, [x21, x2]" -- Use as index -> unknown, not secret + , " mul x0, x0, x1" + , " ldr x2, [x0, #0]" + , " ldr x3, [x21, x2]" , " ret" ] cfg = TaintConfig (Map.singleton "_foo" - (ArgPolicy Set.empty Set.empty (Set.singleton X0) - Set.empty Set.empty)) True - case auditWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + (ArgPolicy Set.empty Set.empty + (Set.singleton X0) + Set.empty Set.empty)) True + case auditWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> - -- No secret violations (x2 is Unknown from heap, not Secret) assertEqual "no secret violations" 0 - (length $ filter isSecretViolation (map vReason (arViolations ar))) + (length $ filter isSecretViolation + (map vReason (arViolations ar))) , testCase "pointer arithmetic preserves secret_pointee" $ do - -- add with immediate keeps KindPtr, so x0 is still secret_pointee let src = T.unlines [ "_foo:" - , " add x0, x0, #8" -- Pointer arithmetic -> KindPtr - , " ldr x1, [x0, #0]" -- Load through secret_pointee -> Secret - , " ldr x2, [x21, x1]" -- Use as index -> secret violation + , " add x0, x0, #8" + , " ldr x1, [x0, #0]" + , " ldr x2, [x21, x1]" , " ret" ] cfg = TaintConfig (Map.singleton "_foo" - (ArgPolicy Set.empty Set.empty (Set.singleton X0) - Set.empty Set.empty)) True - case auditWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> do + (ArgPolicy Set.empty Set.empty + (Set.singleton X0) + Set.empty Set.empty)) True + case auditWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "one secret violation" 1 - (length $ filter isSecretViolation (map vReason (arViolations ar))) + (length $ filter isSecretViolation + (map vReason (arViolations ar))) , testCase "secret arg causes violation on indexed access" $ do let src = T.unlines [ "_mul_wnaf:" - , " ldr x8, [x20, x0]" -- x0 as index + , " ldr x8, [x20, x0]" , " ret" ] cfg = TaintConfig (Map.singleton "_mul_wnaf" - (ArgPolicy (Set.singleton X0) Set.empty Set.empty Set.empty Set.empty)) True - case auditWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + (ArgPolicy (Set.singleton X0) + Set.empty Set.empty + Set.empty Set.empty)) True + case auditWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> do - assertEqual "one violation" 1 (length (arViolations ar)) + assertEqual "one violation" + 1 (length (arViolations ar)) case arViolations ar of [v] -> case vReason v of SecretIndex X0 -> pure () - other -> assertFailure $ "wrong reason: " ++ show other - _ -> assertFailure "expected one violation" + other -> assertFailure $ + "wrong reason: " ++ show other + _ -> assertFailure + "expected one violation" , testCase "public arg allows indexed access" $ do let src = T.unlines [ "_foo:" - , " ldr x8, [x20, x0]" -- x0 as index + , " ldr x8, [x20, x0]" , " ret" ] cfg = TaintConfig (Map.singleton "_foo" - (ArgPolicy Set.empty (Set.singleton X0) Set.empty Set.empty Set.empty)) True - case auditWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "no violations" 0 (length (arViolations ar)) + (ArgPolicy Set.empty (Set.singleton X0) + Set.empty Set.empty Set.empty)) True + case auditWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "no violations" + 0 (length (arViolations ar)) , testCase "interproc with config" $ do let src = T.unlines [ "_mul_wnaf:" - , " ldr x8, [x20, x0]" -- x0 as secret index -> violation + , " ldr x8, [x20, x0]" , " ret" ] cfg = TaintConfig (Map.singleton "_mul_wnaf" - (ArgPolicy (Set.singleton X0) Set.empty Set.empty Set.empty Set.empty)) True - case auditInterProcWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> do - assertEqual "one violation" 1 (length (arViolations ar)) + (ArgPolicy (Set.singleton X0) + Set.empty Set.empty + Set.empty Set.empty)) True + case auditInterProcWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> + assertEqual "one violation" + 1 (length (arViolations ar)) -- STG stack slot seeding tests , testCase "parse stg_secret/stg_public in config" $ do @@ -1048,135 +1224,157 @@ taintConfigTests = testGroup "TaintConfig" [ , " }" , "}" ] - case eitherDecodeStrict' json :: Either String TaintConfig of - Left e -> assertFailure $ "parse failed: " ++ e + case eitherDecodeStrict' json + :: Either String TaintConfig of + Left e -> assertFailure $ + "parse failed: " ++ e Right cfg -> do - let policy = Map.lookup "_foo" (tcPolicies cfg) + let policy = + Map.lookup "_foo" (tcPolicies cfg) case policy of - Nothing -> assertFailure "missing policy for '_foo'" + Nothing -> assertFailure + "missing policy for '_foo'" Just p -> do - assertEqual "stg_secret" (Set.fromList [8, 152]) (apStgSecret p) - assertEqual "stg_public" (Set.singleton 24) (apStgPublic p) + assertEqual "stg_secret" + (Set.fromList [8, 152]) + (apStgSecret p) + assertEqual "stg_public" + (Set.singleton 24) + (apStgPublic p) , testCase "stg_secret seeds STG stack slot as secret" $ do - -- Load from seeded secret STG stack slot should yield secret let src = T.unlines [ "_foo:" - , " ldr x8, [x20, #8]" -- Load from secret STG slot - , " ldr x9, [x21, x8]" -- Use as index -> violation + , " ldr x8, [x20, #8]" + , " ldr x9, [x21, x8]" , " ret" ] cfg = TaintConfig (Map.singleton "_foo" - (ArgPolicy Set.empty Set.empty Set.empty (Set.singleton 8) Set.empty)) True - case auditWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + (ArgPolicy Set.empty Set.empty Set.empty + (Set.singleton 8) Set.empty)) True + case auditWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> do - assertEqual "one violation" 1 (length (arViolations ar)) + assertEqual "one violation" + 1 (length (arViolations ar)) case arViolations ar of [v] -> case vReason v of SecretIndex X8 -> pure () - other -> assertFailure $ "wrong reason: " ++ show other - _ -> assertFailure "expected one violation" + other -> assertFailure $ + "wrong reason: " ++ show other + _ -> assertFailure + "expected one violation" , testCase "stg_public seeds STG stack slot as public" $ do - -- Load from seeded public STG stack slot should yield public let src = T.unlines [ "_foo:" - , " ldr x8, [x20, #16]" -- Load from public STG slot - , " ldr x9, [x21, x8]" -- Use as index -> should be safe + , " ldr x8, [x20, #16]" + , " ldr x9, [x21, x8]" , " ret" ] cfg = TaintConfig (Map.singleton "_foo" - (ArgPolicy Set.empty Set.empty Set.empty Set.empty (Set.singleton 16))) True - case auditWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "no violations" 0 (length (arViolations ar)) + (ArgPolicy Set.empty Set.empty Set.empty + Set.empty (Set.singleton 16))) True + case auditWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "no violations" + 0 (length (arViolations ar)) , testCase "stg_secret takes precedence over stg_public" $ do - -- If same offset in both lists, secret wins let src = T.unlines [ "_foo:" - , " ldr x8, [x20, #8]" -- Load from slot in both lists - , " ldr x9, [x21, x8]" -- Use as index -> violation (secret wins) + , " ldr x8, [x20, #8]" + , " ldr x9, [x21, x8]" , " ret" ] cfg = TaintConfig (Map.singleton "_foo" - (ArgPolicy Set.empty Set.empty Set.empty - (Set.singleton 8) (Set.singleton 8))) True - case auditWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + (ArgPolicy Set.empty Set.empty Set.empty + (Set.singleton 8) + (Set.singleton 8))) True + case auditWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> do - assertEqual "one violation" 1 (length (arViolations ar)) + assertEqual "one violation" + 1 (length (arViolations ar)) case arViolations ar of [v] -> case vReason v of SecretIndex X8 -> pure () - other -> assertFailure $ "wrong reason: " ++ show other - _ -> assertFailure "expected one violation" + other -> assertFailure $ + "wrong reason: " ++ show other + _ -> assertFailure + "expected one violation" - -- STG stack delta tracking tests (IMPL19) + -- STG stack delta tracking tests , testCase "stg_secret preserved after sub x20" $ do - -- Seed secret at offset 8, then sub x20, x20, #16 - -- After sub, old offset 8 maps to new offset 24 let src = T.unlines [ "_foo:" - , " sub x20, x20, #16" -- x20 -= 16, shift slots by +16 - , " ldr x8, [x20, #24]" -- Load from new offset 24 (was 8) - , " ldr x9, [x21, x8]" -- Use as index -> violation + , " sub x20, x20, #16" + , " ldr x8, [x20, #24]" + , " ldr x9, [x21, x8]" , " ret" ] cfg = TaintConfig (Map.singleton "_foo" - (ArgPolicy Set.empty Set.empty Set.empty (Set.singleton 8) Set.empty)) True - case auditWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + (ArgPolicy Set.empty Set.empty Set.empty + (Set.singleton 8) Set.empty)) True + case auditWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> do - assertEqual "one violation" 1 (length (arViolations ar)) + assertEqual "one violation" + 1 (length (arViolations ar)) case arViolations ar of [v] -> case vReason v of SecretIndex X8 -> pure () - other -> assertFailure $ "wrong reason: " ++ show other - _ -> assertFailure "expected one violation" + other -> assertFailure $ + "wrong reason: " ++ show other + _ -> assertFailure + "expected one violation" , testCase "stg_secret preserved after add x20" $ do - -- Seed secret at offset 24, then add x20, x20, #16 - -- After add, old offset 24 maps to new offset 8 let src = T.unlines [ "_foo:" - , " add x20, x20, #16" -- x20 += 16, shift slots by -16 - , " ldr x8, [x20, #8]" -- Load from new offset 8 (was 24) - , " ldr x9, [x21, x8]" -- Use as index -> violation + , " add x20, x20, #16" + , " ldr x8, [x20, #8]" + , " ldr x9, [x21, x8]" , " ret" ] cfg = TaintConfig (Map.singleton "_foo" - (ArgPolicy Set.empty Set.empty Set.empty (Set.singleton 24) Set.empty)) True - case auditWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + (ArgPolicy Set.empty Set.empty Set.empty + (Set.singleton 24) Set.empty)) True + case auditWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> do - assertEqual "one violation" 1 (length (arViolations ar)) + assertEqual "one violation" + 1 (length (arViolations ar)) case arViolations ar of [v] -> case vReason v of SecretIndex X8 -> pure () - other -> assertFailure $ "wrong reason: " ++ show other - _ -> assertFailure "expected one violation" + other -> assertFailure $ + "wrong reason: " ++ show other + _ -> assertFailure + "expected one violation" , testCase "stg slot cleared on non-constant x20 update" $ do - -- Secret at offset 8, then x20 updated non-constantly - -- Should clear slot map. With STG-public assumption, load yields public. - -- Note: x20 itself becomes unknown, causing a base violation let src = T.unlines [ "_foo:" - , " add x20, x20, x1" -- Non-constant update clears map, x20 unknown - , " ldr x8, [x20, #8]" -- Unknown base x20, load yields public (STG) - , " ldr x9, [x21, x8]" -- x8 is public, no index violation + , " add x20, x20, x1" + , " ldr x8, [x20, #8]" + , " ldr x9, [x21, x8]" , " ret" ] cfg = TaintConfig (Map.singleton "_foo" - (ArgPolicy Set.empty Set.empty Set.empty (Set.singleton 8) Set.empty)) True - case auditWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> do - -- One violation: unknown base x20 - -- (x8 is public due to STG-public assumption after map clear) - assertEqual "one violation" 1 (length (arViolations ar)) + (ArgPolicy Set.empty Set.empty Set.empty + (Set.singleton 8) Set.empty)) True + case auditWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> + assertEqual "one violation" + 1 (length (arViolations ar)) ] -- NCT scanner tests @@ -1192,175 +1390,207 @@ nctTests = testGroup "NCT" [ , " ret" ] case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e + Left e -> assertFailure $ + "parse failed: " ++ show e Right lns -> do - let m = scanNct lns + let m = scanNct rt lns case Map.lookup "_foo" m of - Nothing -> assertFailure "no findings for _foo" + Nothing -> assertFailure + "no findings for _foo" Just fs -> do - assertEqual "one finding" 1 (length fs) + assertEqual "one finding" + 1 (length fs) case fs of - [f] -> assertEqual "reason" CondBranch (nctReason f) - _ -> assertFailure "expected one finding" + [f] -> assertEqual "reason" + CondBranch (nctReason f) + _ -> assertFailure + "expected one finding" , testCase "cond branch: cbz" $ do - let src = "foo:\n cbz x0, target\ntarget:\n ret\n" + let src = "foo:\n cbz x0, target\n\ + \target:\n ret\n" case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e + Left e -> assertFailure $ + "parse failed: " ++ show e Right lns -> do - let m = scanNct lns + let m = scanNct rt lns case Map.lookup "foo" m of Nothing -> assertFailure "no findings" Just fs -> case fs of - [f] -> assertEqual "reason" CondBranch (nctReason f) - _ -> assertFailure $ "expected 1 finding, got " ++ show (length fs) + [f] -> assertEqual "reason" + CondBranch (nctReason f) + _ -> assertFailure $ + "expected 1 finding, got " + ++ show (length fs) , testCase "cond branch: tbz" $ do - let src = "foo:\n tbz x0, #5, target\ntarget:\n ret\n" + let src = "foo:\n tbz x0, #5, target\n\ + \target:\n ret\n" case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e + Left e -> assertFailure $ + "parse failed: " ++ show e Right lns -> do - let m = scanNct lns + let m = scanNct rt lns case Map.lookup "foo" m of Nothing -> assertFailure "no findings" Just fs -> case fs of - [f] -> assertEqual "reason" CondBranch (nctReason f) - _ -> assertFailure $ "expected 1 finding" + [f] -> assertEqual "reason" + CondBranch (nctReason f) + _ -> assertFailure "expected 1 finding" , testCase "indirect branch: br" $ do let src = "foo:\n br x8\n" case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e + Left e -> assertFailure $ + "parse failed: " ++ show e Right lns -> do - let m = scanNct lns + let m = scanNct rt lns case Map.lookup "foo" m of Nothing -> assertFailure "no findings" Just fs -> case fs of - [f] -> assertEqual "reason" IndirectBranch (nctReason f) + [f] -> assertEqual "reason" + IndirectBranch (nctReason f) _ -> assertFailure "expected 1 finding" , testCase "indirect branch: blr" $ do let src = "foo:\n blr x8\n" case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e + Left e -> assertFailure $ + "parse failed: " ++ show e Right lns -> do - let m = scanNct lns + let m = scanNct rt lns case Map.lookup "foo" m of Nothing -> assertFailure "no findings" Just fs -> case fs of - [f] -> assertEqual "reason" IndirectBranch (nctReason f) + [f] -> assertEqual "reason" + IndirectBranch (nctReason f) _ -> assertFailure "expected 1 finding" , testCase "div: udiv" $ do let src = "foo:\n udiv x0, x1, x2\n" case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e + Left e -> assertFailure $ + "parse failed: " ++ show e Right lns -> do - let m = scanNct lns + let m = scanNct rt lns case Map.lookup "foo" m of Nothing -> assertFailure "no findings" Just fs -> case fs of - [f] -> assertEqual "reason" Div (nctReason f) + [f] -> assertEqual "reason" + Div (nctReason f) _ -> assertFailure "expected 1 finding" , testCase "div: sdiv" $ do let src = "foo:\n sdiv x0, x1, x2\n" case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e + Left e -> assertFailure $ + "parse failed: " ++ show e Right lns -> do - let m = scanNct lns + let m = scanNct rt lns case Map.lookup "foo" m of Nothing -> assertFailure "no findings" Just fs -> case fs of - [f] -> assertEqual "reason" Div (nctReason f) + [f] -> assertEqual "reason" + Div (nctReason f) _ -> assertFailure "expected 1 finding" , testCase "no finding: mul (DIT)" $ do - -- mul is DIT per ARM docs let src = "foo:\n mul x0, x1, x2\n" case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e + Left e -> assertFailure $ + "parse failed: " ++ show e Right lns -> do - let m = scanNct lns - assertEqual "no findings" Nothing (Map.lookup "foo" m) + let m = scanNct rt lns + assertEqual "no findings" Nothing + (Map.lookup "foo" m) , testCase "no finding: madd (DIT)" $ do - -- madd is DIT per ARM docs let src = "foo:\n madd x0, x1, x2, x3\n" case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e + Left e -> assertFailure $ + "parse failed: " ++ show e Right lns -> do - let m = scanNct lns - assertEqual "no findings" Nothing (Map.lookup "foo" m) + let m = scanNct rt lns + assertEqual "no findings" Nothing + (Map.lookup "foo" m) , testCase "no finding: umulh (DIT)" $ do - -- umulh is DIT per ARM docs let src = "foo:\n umulh x0, x1, x2\n" case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e + Left e -> assertFailure $ + "parse failed: " ++ show e Right lns -> do - let m = scanNct lns - assertEqual "no findings" Nothing (Map.lookup "foo" m) + let m = scanNct rt lns + assertEqual "no findings" Nothing + (Map.lookup "foo" m) , testCase "no finding: lsl with reg (DIT)" $ do - -- variable shifts are DIT per ARM docs let src = "foo:\n lsl x0, x1, x2\n" case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e + Left e -> assertFailure $ + "parse failed: " ++ show e Right lns -> do - let m = scanNct lns - assertEqual "no findings" Nothing (Map.lookup "foo" m) + let m = scanNct rt lns + assertEqual "no findings" Nothing + (Map.lookup "foo" m) , testCase "no finding: lsr with reg (DIT)" $ do - -- variable shifts are DIT per ARM docs let src = "foo:\n lsr x0, x1, x2\n" case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e + Left e -> assertFailure $ + "parse failed: " ++ show e Right lns -> do - let m = scanNct lns - assertEqual "no findings" Nothing (Map.lookup "foo" m) + let m = scanNct rt lns + assertEqual "no findings" Nothing + (Map.lookup "foo" m) , testCase "no finding: lsl with imm" $ do - -- lsl x0, x1, #5 should NOT flag (immediate shift) let src = "foo:\n lsl x0, x1, #5\n" case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e + Left e -> assertFailure $ + "parse failed: " ++ show e Right lns -> do - let m = scanNct lns - assertEqual "no findings" Nothing (Map.lookup "foo" m) + let m = scanNct rt lns + assertEqual "no findings" Nothing + (Map.lookup "foo" m) , testCase "reg-index: ldr [xN, xM]" $ do let src = "foo:\n ldr x0, [x1, x2]\n" case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e + Left e -> assertFailure $ + "parse failed: " ++ show e Right lns -> do - let m = scanNct lns + let m = scanNct rt lns case Map.lookup "foo" m of Nothing -> assertFailure "no findings" Just fs -> case fs of - [f] -> assertEqual "reason" RegIndexAddr (nctReason f) + [f] -> assertEqual "reason" + RegIndexAddr (nctReason f) _ -> assertFailure "expected 1 finding" , testCase "reg-index: str [xN, xM, lsl #3]" $ do let src = "foo:\n str x0, [x1, x2, lsl #3]\n" case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e + Left e -> assertFailure $ + "parse failed: " ++ show e Right lns -> do - let m = scanNct lns + let m = scanNct rt lns case Map.lookup "foo" m of Nothing -> assertFailure "no findings" Just fs -> case fs of - [f] -> assertEqual "reason" RegIndexAddr (nctReason f) + [f] -> assertEqual "reason" + RegIndexAddr (nctReason f) _ -> assertFailure "expected 1 finding" , testCase "no finding: ldr [xN, #imm]" $ do - -- Immediate offset is safe let src = "foo:\n ldr x0, [x1, #8]\n" case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e + Left e -> assertFailure $ + "parse failed: " ++ show e Right lns -> do - let m = scanNct lns - assertEqual "no findings" Nothing (Map.lookup "foo" m) + let m = scanNct rt lns + assertEqual "no findings" Nothing + (Map.lookup "foo" m) , testCase "grouping by function symbol" $ do let src = T.unlines @@ -1374,13 +1604,16 @@ nctTests = testGroup "NCT" [ , " ret" ] case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e + Left e -> assertFailure $ + "parse failed: " ++ show e Right lns -> do - let m = scanNct lns - assertEqual "foo: 1 finding" (Just 1) (fmap length (Map.lookup "_foo" m)) - assertEqual "bar: 2 findings" (Just 2) (fmap length (Map.lookup "_bar" m)) - -- L1 is local, findings should stay with _foo - assertEqual "L1 not a key" Nothing (Map.lookup "L1" m) + let m = scanNct rt lns + assertEqual "foo: 1 finding" (Just 1) + (fmap length (Map.lookup "_foo" m)) + assertEqual "bar: 2 findings" (Just 2) + (fmap length (Map.lookup "_bar" m)) + assertEqual "L1 not a key" Nothing + (Map.lookup "L1" m) ] -- Call graph tests @@ -1393,10 +1626,12 @@ callGraphTests = testGroup "CallGraph" [ , " ret" ] case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e + Left e -> assertFailure $ + "parse failed: " ++ show e Right lns -> do - let cg = buildCallGraph lns - assertEqual "symbol exists" True (symbolExists "_foo" cg) + let cg = buildCallGraph rt lns + assertEqual "symbol exists" + True (symbolExists "_foo" cg) , testCase "symbolExists: missing symbol" $ do let src = T.unlines @@ -1404,10 +1639,12 @@ callGraphTests = testGroup "CallGraph" [ , " ret" ] case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e + Left e -> assertFailure $ + "parse failed: " ++ show e Right lns -> do - let cg = buildCallGraph lns - assertEqual "symbol missing" False (symbolExists "_bar" cg) + let cg = buildCallGraph rt lns + assertEqual "symbol missing" + False (symbolExists "_bar" cg) , testCase "reachableSymbols: direct call" $ do let src = T.unlines @@ -1418,13 +1655,18 @@ callGraphTests = testGroup "CallGraph" [ , " ret" ] case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e + Left e -> assertFailure $ + "parse failed: " ++ show e Right lns -> do - let cg = buildCallGraph lns - reachable = reachableSymbols "_foo" cg - assertEqual "foo reachable" True (Set.member "_foo" reachable) - assertEqual "bar reachable" True (Set.member "_bar" reachable) - assertEqual "count" 2 (Set.size reachable) + let cg = buildCallGraph rt lns + reachable = + reachableSymbols "_foo" cg + assertEqual "foo reachable" True + (Set.member "_foo" reachable) + assertEqual "bar reachable" True + (Set.member "_bar" reachable) + assertEqual "count" 2 + (Set.size reachable) , testCase "reachableSymbols: transitive call" $ do let src = T.unlines @@ -1438,12 +1680,16 @@ callGraphTests = testGroup "CallGraph" [ , " ret" ] case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e + Left e -> assertFailure $ + "parse failed: " ++ show e Right lns -> do - let cg = buildCallGraph lns - reachable = reachableSymbols "_foo" cg - assertEqual "baz reachable from foo" True (Set.member "_baz" reachable) - assertEqual "count" 3 (Set.size reachable) + let cg = buildCallGraph rt lns + reachable = + reachableSymbols "_foo" cg + assertEqual "baz reachable from foo" True + (Set.member "_baz" reachable) + assertEqual "count" 3 + (Set.size reachable) , testCase "reachableSymbols: no callees" $ do let src = T.unlines @@ -1453,11 +1699,14 @@ callGraphTests = testGroup "CallGraph" [ , " ret" ] case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e + Left e -> assertFailure $ + "parse failed: " ++ show e Right lns -> do - let cg = buildCallGraph lns - reachable = reachableSymbols "_foo" cg - assertEqual "only foo reachable" (Set.singleton "_foo") reachable + let cg = buildCallGraph rt lns + reachable = + reachableSymbols "_foo" cg + assertEqual "only foo reachable" + (Set.singleton "_foo") reachable , testCase "reachableSymbols: missing root returns empty" $ do let src = T.unlines @@ -1465,11 +1714,14 @@ callGraphTests = testGroup "CallGraph" [ , " ret" ] case parseAsm src of - Left e -> assertFailure $ "parse failed: " ++ show e + Left e -> assertFailure $ + "parse failed: " ++ show e Right lns -> do - let cg = buildCallGraph lns - reachable = reachableSymbols "_missing" cg - assertEqual "empty set" Set.empty reachable + let cg = buildCallGraph rt lns + reachable = + reachableSymbols "_missing" cg + assertEqual "empty set" + Set.empty reachable ] -- Tail call inter-procedural tests @@ -1477,8 +1729,6 @@ callGraphTests = testGroup "CallGraph" [ tailCallTests :: TestTree tailCallTests = testGroup "TailCall" [ - -- Basic tail call taint propagation - -- Uses x23 (R2 in STG) since x22 (HpLim) is a public root testCase "tail call propagates unknown taint" $ do let src = T.unlines [ "_caller:" @@ -1488,9 +1738,11 @@ tailCallTests = testGroup "TailCall" [ , " ldr x0, [x20, x23]" , " ret" ] - case auditInterProc "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "one violation" 1 (length (arViolations ar)) + case auditInterProc rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "one violation" + 1 (length (arViolations ar)) , testCase "tail call propagates secret taint" $ do let src = T.unlines @@ -1502,15 +1754,17 @@ tailCallTests = testGroup "TailCall" [ , " ret" ] cfg = TaintConfig (Map.singleton "_caller" - (ArgPolicy (Set.singleton X0) Set.empty Set.empty - Set.empty Set.empty)) True - case auditInterProcWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + (ArgPolicy (Set.singleton X0) + Set.empty Set.empty + Set.empty Set.empty)) True + case auditInterProcWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> assertEqual "secret violation" 1 - (length $ filter isSecretViolation (map vReason (arViolations ar))) + (length $ filter isSecretViolation + (map vReason (arViolations ar))) - -- Local branches still work , testCase "local branch is intra-procedural" $ do let src = T.unlines [ "_foo:" @@ -1520,9 +1774,11 @@ tailCallTests = testGroup "TailCall" [ , " ldr x2, [x20, x0]" , " ret" ] - case auditInterProc "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "one unknown violation" 1 (length (arViolations ar)) + case auditInterProc rt "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "one unknown violation" + 1 (length (arViolations ar)) , testCase "NCG local branch (Lc prefix) is intra-procedural" $ do let src = T.unlines @@ -1534,15 +1790,17 @@ tailCallTests = testGroup "TailCall" [ , " ret" ] cfg = TaintConfig (Map.singleton "_foo" - (ArgPolicy (Set.singleton X0) Set.empty Set.empty - Set.empty Set.empty)) True - case auditInterProcWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + (ArgPolicy (Set.singleton X0) + Set.empty Set.empty + Set.empty Set.empty)) True + case auditInterProcWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> assertEqual "secret violation" 1 - (length $ filter isSecretViolation (map vReason (arViolations ar))) + (length $ filter isSecretViolation + (map vReason (arViolations ar))) - -- Chain of tail calls , testCase "chain of tail calls propagates taint" $ do let src = T.unlines [ "_a:" @@ -1556,15 +1814,17 @@ tailCallTests = testGroup "TailCall" [ , " ret" ] cfg = TaintConfig (Map.singleton "_a" - (ArgPolicy (Set.singleton X0) Set.empty Set.empty - Set.empty Set.empty)) True - case auditInterProcWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + (ArgPolicy (Set.singleton X0) + Set.empty Set.empty + Set.empty Set.empty)) True + case auditInterProcWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> assertEqual "secret violation" 1 - (length $ filter isSecretViolation (map vReason (arViolations ar))) + (length $ filter isSecretViolation + (map vReason (arViolations ar))) - -- Indirect call preserves callee-saved (ABI requirement) , testCase "indirect call preserves callee-saved taint" $ do let src = T.unlines [ "_foo:" @@ -1575,16 +1835,17 @@ tailCallTests = testGroup "TailCall" [ , " ret" ] cfg = TaintConfig (Map.singleton "_foo" - (ArgPolicy (Set.singleton X0) Set.empty Set.empty - Set.empty Set.empty)) True - case auditInterProcWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + (ArgPolicy (Set.singleton X0) + Set.empty Set.empty + Set.empty Set.empty)) True + case auditInterProcWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> - -- x23 is callee-saved, so taint preserved across blr assertEqual "x23 secret preserved" 1 - (length $ filter isSecretViolation (map vReason (arViolations ar))) + (length $ filter isSecretViolation + (map vReason (arViolations ar))) - -- Tail call to unknown function , testCase "tail call to unknown function is conservative" $ do let src = T.unlines [ "_foo:" @@ -1592,13 +1853,15 @@ tailCallTests = testGroup "TailCall" [ , " b _unknown_info" ] cfg = TaintConfig (Map.singleton "_foo" - (ArgPolicy (Set.singleton X0) Set.empty Set.empty - Set.empty Set.empty)) True - case auditInterProcWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e - Right ar -> assertEqual "no violations" 0 (length (arViolations ar)) + (ArgPolicy (Set.singleton X0) + Set.empty Set.empty + Set.empty Set.empty)) True + case auditInterProcWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e + Right ar -> assertEqual "no violations" + 0 (length (arViolations ar)) - -- Secret pointee across tail call , testCase "secret_pointee propagates across tail call" $ do let src = T.unlines [ "_caller:" @@ -1610,15 +1873,17 @@ tailCallTests = testGroup "TailCall" [ , " ret" ] cfg = TaintConfig (Map.singleton "_caller" - (ArgPolicy Set.empty Set.empty (Set.singleton X0) - Set.empty Set.empty)) True - case auditInterProcWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + (ArgPolicy Set.empty Set.empty + (Set.singleton X0) + Set.empty Set.empty)) True + case auditInterProcWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> assertEqual "secret violation" 1 - (length $ filter isSecretViolation (map vReason (arViolations ar))) + (length $ filter isSecretViolation + (map vReason (arViolations ar))) - -- Mutual tail calls (fixpoint test) , testCase "mutual tail calls reach fixpoint" $ do let src = T.unlines [ "_a_info:" @@ -1631,15 +1896,17 @@ tailCallTests = testGroup "TailCall" [ , " b _a_info" ] cfg = TaintConfig (Map.singleton "_a_info" - (ArgPolicy (Set.singleton X22) Set.empty Set.empty - Set.empty Set.empty)) True - case auditInterProcWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + (ArgPolicy (Set.singleton X22) + Set.empty Set.empty + Set.empty Set.empty)) True + case auditInterProcWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> assertEqual "secret violation in cycle" 1 - (length $ filter isSecretViolation (map vReason (arViolations ar))) + (length $ filter isSecretViolation + (map vReason (arViolations ar))) - -- Mixed bl and b calls , testCase "bl followed by tail call" $ do let src = T.unlines [ "_caller:" @@ -1653,15 +1920,17 @@ tailCallTests = testGroup "TailCall" [ , " ret" ] cfg = TaintConfig (Map.singleton "_caller" - (ArgPolicy (Set.singleton X0) Set.empty Set.empty - Set.empty Set.empty)) True - case auditInterProcWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + (ArgPolicy (Set.singleton X0) + Set.empty Set.empty + Set.empty Set.empty)) True + case auditInterProcWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> assertEqual "secret violation" 1 - (length $ filter isSecretViolation (map vReason (arViolations ar))) + (length $ filter isSecretViolation + (map vReason (arViolations ar))) - -- ABI argument registers preserved across in-file tail calls , testCase "tail call preserves x0 argument taint" $ do let src = T.unlines [ "_caller:" @@ -1671,15 +1940,17 @@ tailCallTests = testGroup "TailCall" [ , " ret" ] cfg = TaintConfig (Map.singleton "_caller" - (ArgPolicy (Set.singleton X0) Set.empty Set.empty - Set.empty Set.empty)) True - case auditInterProcWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + (ArgPolicy (Set.singleton X0) + Set.empty Set.empty + Set.empty Set.empty)) True + case auditInterProcWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> assertEqual "x0 secret preserved" 1 - (length $ filter isSecretViolation (map vReason (arViolations ar))) + (length $ filter isSecretViolation + (map vReason (arViolations ar))) - -- Callee-saved registers preserved across external bl calls , testCase "external bl preserves callee-saved taint" $ do let src = T.unlines [ "_foo:" @@ -1689,12 +1960,17 @@ tailCallTests = testGroup "TailCall" [ , " ret" ] cfg = TaintConfig (Map.singleton "_foo" - (ArgPolicy (Set.singleton X0) Set.empty Set.empty - Set.empty Set.empty)) True - case auditInterProcWithConfig cfg "test" src of - Left e -> assertFailure $ "parse failed: " ++ show e + (ArgPolicy (Set.singleton X0) + Set.empty Set.empty + Set.empty Set.empty)) True + case auditInterProcWithConfig rt cfg "test" src of + Left e -> assertFailure $ + "parse failed: " ++ show e Right ar -> - assertEqual "x23 secret preserved across external call" 1 - (length $ filter isSecretViolation (map vReason (arViolations ar))) + assertEqual + "x23 secret preserved across external call" + 1 + (length $ filter isSecretViolation + (map vReason (arViolations ar))) ]