auditor

An aarch64 constant-time memory access auditing tool.
git clone git://git.ppad.tech/auditor.git
Log | Files | Refs | README | LICENSE

commit b9e0153adc2546f8b29a243149b6f0349c9aef01
parent 50e1c801e8d53f2da47b11b8b858a6978f800104
Author: Jared Tobin <jared@jtobin.io>
Date:   Fri, 13 Feb 2026 19:19:10 +0400

feat: add symbol-focused NCT scan with call graph analysis (IMPL24)

Adds --symbol/-s option to filter NCT scan to a specific symbol and
its transitive callees. Uses Data.Graph for reachability computation.

- New CallGraph module with buildCallGraph, reachableSymbols, symbolExists
- scanNctFileForSymbol returns SymbolScanResult with filtered findings
- CLI shows root symbol, reachable count, and findings summary
- 6 new tests for call graph functionality

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>

Diffstat:
Mapp/Main.hs | 61++++++++++++++++++++++++++++++++++++++++++++++++++++---------
Mlib/Audit/AArch64.hs | 64+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
Alib/Audit/AArch64/CallGraph.hs | 91+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Mppad-auditor.cabal | 1+
Mtest/Main.hs | 91+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
5 files changed, 298 insertions(+), 10 deletions(-)

diff --git a/app/Main.hs b/app/Main.hs @@ -7,10 +7,11 @@ import Audit.AArch64 , TaintConfig(..) , NctReason(..), NctFinding(..), nctLine, nctInstr, nctReason , LineMap, isGhcRuntimeFinding + , SymbolScanResult(..) , auditFile, auditFileInterProc , auditFileWithConfig, auditFileInterProcWithConfig , parseFile, regName, loadTaintConfig - , scanNctFile + , scanNctFile, scanNctFileForSymbol ) import Audit.AArch64.Types (Instr) import Data.Aeson (encode) @@ -33,6 +34,7 @@ data Options = Options , optScanNct :: !Bool , optNctDetail :: !Bool , optShowGhcRuntime :: !Bool + , optSymbol :: !(Maybe Text) } deriving (Eq, Show) optParser :: Parser Options @@ -85,6 +87,12 @@ optParser = Options ( long "show-ghc-runtime" <> help "Show GHC runtime patterns in NCT scan (hidden by default)" ) + <*> optional (strOption + ( long "symbol" + <> short 's' + <> metavar "SYMBOL" + <> help "Analyze only this symbol and its callees (NCT scan mode)" + )) optInfo :: ParserInfo Options optInfo = info (optParser <**> helper) @@ -107,14 +115,23 @@ main = do TIO.putStrLn $ "Parsed " <> T.pack (show n) <> " lines" exitSuccess else if optScanNct opts - then do - result <- scanNctFile (optInput opts) - case result of - Left err -> do - TIO.putStrLn $ "Error: " <> err - exitFailure - Right (lineMap, findings) -> - outputNct opts lineMap findings + then case optSymbol opts of + Just sym -> do + result <- scanNctFileForSymbol sym (optInput opts) + case result of + Left err -> do + TIO.putStrLn $ "Error: " <> err + exitFailure + Right ssr -> + outputNctSymbol opts ssr + Nothing -> do + result <- scanNctFile (optInput opts) + case result of + Left err -> do + TIO.putStrLn $ "Error: " <> err + exitFailure + Right (lineMap, findings) -> + outputNct opts lineMap findings else do -- Load taint config if provided mcfg <- case optTaintConfig opts of @@ -172,6 +189,32 @@ outputText opts ar = do then exitSuccess else exitFailure +-- | Output NCT scan results for a specific symbol and its callees. +outputNctSymbol :: Options -> SymbolScanResult -> IO () +outputNctSymbol opts ssr = do + let lineMap = ssrLineMap ssr + findings = ssrFindings ssr + showGhc = optShowGhcRuntime opts + isReal = not . isGhcRuntimeFinding lineMap + filterFindings = if showGhc then id else filter isReal + syms = [(sym, filterFindings fs) | (sym, fs) <- Map.toList findings] + realSyms = filter (not . null . snd) syms + total = sum (map (length . snd) realSyms) + if optNctDetail opts + then mapM_ (printNctDetail showGhc lineMap) realSyms + else mapM_ (printNctSummary showGhc lineMap) realSyms + if optQuiet opts + then pure () + else do + TIO.putStrLn "" + TIO.putStrLn $ "Root symbol: " <> ssrRootSymbol ssr + TIO.putStrLn $ "Reachable symbols: " <> T.pack (show (ssrReachable ssr)) + TIO.putStrLn $ "With findings: " <> T.pack (show (length realSyms)) + TIO.putStrLn $ "NCT findings: " <> T.pack (show total) + if total == 0 + then exitSuccess + else exitFailure + -- | Output NCT scan results. outputNct :: Options -> LineMap -> Map.Map Text [NctFinding] -> IO () outputNct opts lineMap findings = do diff --git a/lib/Audit/AArch64.hs b/lib/Audit/AArch64.hs @@ -42,6 +42,8 @@ module Audit.AArch64 ( -- * NCT scanner , scanNct , scanNctFile + , scanNctFileForSymbol + , SymbolScanResult(..) , NctReason(..) , NctFinding(..) -- ** GHC runtime classification @@ -49,6 +51,11 @@ module Audit.AArch64 ( , buildLineMap , isGhcRuntimeFinding , filterGhcRuntime + -- ** Call graph + , CallGraph + , buildCallGraph + , reachableSymbols + , symbolExists -- * Results , AuditResult(..) @@ -68,7 +75,14 @@ module Audit.AArch64 ( , showParseError ) where -import Audit.AArch64.CFG +import Audit.AArch64.CFG hiding (buildCallGraph) +import Audit.AArch64.CallGraph + ( CallGraph + , buildCallGraph + , reachableSymbols + , symbolExists + ) +import qualified Audit.AArch64.CallGraph as CG (buildCallGraph) import Audit.AArch64.Check import Audit.AArch64.NCT ( NctReason(..), NctFinding(..), scanNct @@ -82,6 +96,7 @@ import Audit.AArch64.Types import Data.Aeson (eitherDecodeStrict') import qualified Data.ByteString as BS import qualified Data.Map.Strict as Map +import qualified Data.Set as Set import Data.Text (Text) import qualified Data.Text as T import Data.Text.Encoding (decodeUtf8') @@ -179,3 +194,50 @@ scanNctFile path = do let lineMap = buildLineMap lns findings = scanNct lns in pure (Right (lineMap, findings)) + +-- | Result of symbol-focused NCT scan. +data SymbolScanResult = SymbolScanResult + { ssrRootSymbol :: !Text + -- ^ The requested root symbol + , ssrReachable :: !Int + -- ^ Number of reachable symbols (including root) + , ssrLineMap :: !LineMap + -- ^ Line map for GHC runtime classification + , ssrFindings :: !(Map.Map Text [NctFinding]) + -- ^ Findings filtered to reachable symbols only + } + +-- | Scan an assembly file for NCT instructions, focused on a specific symbol. +-- +-- Uses call graph analysis to find all symbols reachable from the given +-- root symbol and returns findings only for those symbols. +-- +-- Returns 'Left' if parsing fails or the symbol doesn't exist. +scanNctFileForSymbol + :: Text -- ^ Root symbol to analyze + -> FilePath -- ^ Assembly file path + -> IO (Either Text SymbolScanResult) +scanNctFileForSymbol rootSym path = do + bs <- BS.readFile path + case decodeUtf8' bs of + Left err -> pure (Left (T.pack (show err))) + Right src -> + case parseAsm src of + Left err -> pure (Left (T.pack (showParseError err))) + Right lns -> do + let callGraph = CG.buildCallGraph lns + if not (symbolExists rootSym callGraph) + then pure (Left ("symbol not found: " <> rootSym)) + else do + let reachable = reachableSymbols rootSym callGraph + lineMap = buildLineMap lns + allFindings = scanNct lns + -- Filter to only reachable symbols + filtered = Map.filterWithKey + (\sym _ -> Set.member sym reachable) allFindings + pure $ Right $ SymbolScanResult + { ssrRootSymbol = rootSym + , ssrReachable = Set.size reachable + , ssrLineMap = lineMap + , ssrFindings = filtered + } diff --git a/lib/Audit/AArch64/CallGraph.hs b/lib/Audit/AArch64/CallGraph.hs @@ -0,0 +1,91 @@ +{-# OPTIONS_HADDOCK prune #-} +{-# LANGUAGE OverloadedStrings #-} + +-- | +-- Module: Audit.AArch64.CallGraph +-- Copyright: (c) 2025 Jared Tobin +-- License: MIT +-- Maintainer: jared@ppad.tech +-- +-- Inter-procedural call graph construction for AArch64 assembly. +-- Used to find all functions reachable from a given symbol. + +module Audit.AArch64.CallGraph ( + -- * Call graph + CallGraph + , buildCallGraph + -- * Reachability + , reachableSymbols + , symbolExists + ) where + +import Audit.AArch64.CFG (isFunctionLabel) +import Audit.AArch64.Types (Instr(..), Line(..)) +import Data.Graph (Graph, Vertex, graphFromEdges, reachable) +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as Map +import Data.Set (Set) +import qualified Data.Set as Set +import Data.Text (Text) + +-- | Call graph: maps symbols to the set of symbols they call. +-- Includes lookup functions for graph traversal. +data CallGraph = CallGraph + { cgGraph :: !Graph + , cgNodeFromV :: !(Vertex -> ((), Text, [Text])) + , cgVertexFromK :: !(Text -> Maybe Vertex) + , cgSymbols :: !(Set Text) + } + +-- | Build a call graph from parsed assembly lines. +-- +-- Extracts function symbols and their call targets (bl instructions). +-- Does not resolve indirect calls (blr). +buildCallGraph :: [Line] -> CallGraph +buildCallGraph lns = CallGraph graph nodeFromV vertexFromK allSyms + where + (graph, nodeFromV, vertexFromK) = graphFromEdges edges + + -- Build map from symbol to its instructions. + -- State: (current symbol, accumulated map) + symInstrs :: Map Text [Instr] + symInstrs = snd $ foldl step ("<unknown>", Map.empty) lns + where + step (curSym, acc) ln = + let -- Update current symbol when we see a function label + sym = case lineLabel ln of + Just lbl | isFunctionLabel lbl -> lbl + _ -> curSym + -- Ensure symbol exists in map (even with no instructions) + acc' = Map.insertWith (++) sym [] acc + in case lineInstr ln of + Nothing -> (sym, acc') + Just i -> (sym, Map.insertWith (++) sym [i] acc') + + -- Extract call targets from instructions + callTargets :: Text -> [Text] + callTargets sym = case Map.lookup sym symInstrs of + Nothing -> [] + Just instrs -> [target | Bl target <- instrs] + + -- All symbols in the assembly + allSyms :: Set Text + allSyms = Map.keysSet symInstrs + + -- Build graph edges: (node data, key, [adjacent keys]) + edges :: [((), Text, [Text])] + edges = [((), sym, callTargets sym) | sym <- Set.toList allSyms] + +-- | Get all symbols reachable from a root symbol (including the root). +-- +-- Returns empty set if the root symbol doesn't exist. +reachableSymbols :: Text -> CallGraph -> Set Text +reachableSymbols root cg = case cgVertexFromK cg root of + Nothing -> Set.empty + Just v -> Set.fromList + [sym | v' <- reachable (cgGraph cg) v + , let (_, sym, _) = cgNodeFromV cg v'] + +-- | Check if a symbol exists in the call graph. +symbolExists :: Text -> CallGraph -> Bool +symbolExists sym cg = Set.member sym (cgSymbols cg) diff --git a/ppad-auditor.cabal b/ppad-auditor.cabal @@ -25,6 +25,7 @@ library -Wall exposed-modules: Audit.AArch64 + Audit.AArch64.CallGraph Audit.AArch64.Types Audit.AArch64.Parser Audit.AArch64.CFG diff --git a/test/Main.hs b/test/Main.hs @@ -3,6 +3,7 @@ module Main where import Audit.AArch64 +import Audit.AArch64.CallGraph import Audit.AArch64.Parser import Audit.AArch64.Taint import Audit.AArch64.Types @@ -23,6 +24,7 @@ main = defaultMain $ testGroup "ppad-auditor" [ , provenanceTests , taintConfigTests , nctTests + , callGraphTests ] -- Parser tests @@ -1282,3 +1284,92 @@ nctTests = testGroup "NCT" [ -- L1 is local, findings should stay with _foo assertEqual "L1 not a key" Nothing (Map.lookup "L1" m) ] + +-- Call graph tests + +callGraphTests :: TestTree +callGraphTests = testGroup "CallGraph" [ + testCase "symbolExists: existing symbol" $ do + let src = T.unlines + [ "_foo:" + , " ret" + ] + case parseAsm src of + Left e -> assertFailure $ "parse failed: " ++ show e + Right lns -> do + let cg = buildCallGraph lns + assertEqual "symbol exists" True (symbolExists "_foo" cg) + + , testCase "symbolExists: missing symbol" $ do + let src = T.unlines + [ "_foo:" + , " ret" + ] + case parseAsm src of + Left e -> assertFailure $ "parse failed: " ++ show e + Right lns -> do + let cg = buildCallGraph lns + assertEqual "symbol missing" False (symbolExists "_bar" cg) + + , testCase "reachableSymbols: direct call" $ do + let src = T.unlines + [ "_foo:" + , " bl _bar" + , " ret" + , "_bar:" + , " ret" + ] + case parseAsm src of + Left e -> assertFailure $ "parse failed: " ++ show e + Right lns -> do + let cg = buildCallGraph lns + reachable = reachableSymbols "_foo" cg + assertEqual "foo reachable" True (Set.member "_foo" reachable) + assertEqual "bar reachable" True (Set.member "_bar" reachable) + assertEqual "count" 2 (Set.size reachable) + + , testCase "reachableSymbols: transitive call" $ do + let src = T.unlines + [ "_foo:" + , " bl _bar" + , " ret" + , "_bar:" + , " bl _baz" + , " ret" + , "_baz:" + , " ret" + ] + case parseAsm src of + Left e -> assertFailure $ "parse failed: " ++ show e + Right lns -> do + let cg = buildCallGraph lns + reachable = reachableSymbols "_foo" cg + assertEqual "baz reachable from foo" True (Set.member "_baz" reachable) + assertEqual "count" 3 (Set.size reachable) + + , testCase "reachableSymbols: no callees" $ do + let src = T.unlines + [ "_foo:" + , " ret" + , "_bar:" + , " ret" + ] + case parseAsm src of + Left e -> assertFailure $ "parse failed: " ++ show e + Right lns -> do + let cg = buildCallGraph lns + reachable = reachableSymbols "_foo" cg + assertEqual "only foo reachable" (Set.singleton "_foo") reachable + + , testCase "reachableSymbols: missing root returns empty" $ do + let src = T.unlines + [ "_foo:" + , " ret" + ] + case parseAsm src of + Left e -> assertFailure $ "parse failed: " ++ show e + Right lns -> do + let cg = buildCallGraph lns + reachable = reachableSymbols "_missing" cg + assertEqual "empty set" Set.empty reachable + ]