commit b9e0153adc2546f8b29a243149b6f0349c9aef01
parent 50e1c801e8d53f2da47b11b8b858a6978f800104
Author: Jared Tobin <jared@jtobin.io>
Date: Fri, 13 Feb 2026 19:19:10 +0400
feat: add symbol-focused NCT scan with call graph analysis (IMPL24)
Adds --symbol/-s option to filter NCT scan to a specific symbol and
its transitive callees. Uses Data.Graph for reachability computation.
- New CallGraph module with buildCallGraph, reachableSymbols, symbolExists
- scanNctFileForSymbol returns SymbolScanResult with filtered findings
- CLI shows root symbol, reachable count, and findings summary
- 6 new tests for call graph functionality
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat:
5 files changed, 298 insertions(+), 10 deletions(-)
diff --git a/app/Main.hs b/app/Main.hs
@@ -7,10 +7,11 @@ import Audit.AArch64
, TaintConfig(..)
, NctReason(..), NctFinding(..), nctLine, nctInstr, nctReason
, LineMap, isGhcRuntimeFinding
+ , SymbolScanResult(..)
, auditFile, auditFileInterProc
, auditFileWithConfig, auditFileInterProcWithConfig
, parseFile, regName, loadTaintConfig
- , scanNctFile
+ , scanNctFile, scanNctFileForSymbol
)
import Audit.AArch64.Types (Instr)
import Data.Aeson (encode)
@@ -33,6 +34,7 @@ data Options = Options
, optScanNct :: !Bool
, optNctDetail :: !Bool
, optShowGhcRuntime :: !Bool
+ , optSymbol :: !(Maybe Text)
} deriving (Eq, Show)
optParser :: Parser Options
@@ -85,6 +87,12 @@ optParser = Options
( long "show-ghc-runtime"
<> help "Show GHC runtime patterns in NCT scan (hidden by default)"
)
+ <*> optional (strOption
+ ( long "symbol"
+ <> short 's'
+ <> metavar "SYMBOL"
+ <> help "Analyze only this symbol and its callees (NCT scan mode)"
+ ))
optInfo :: ParserInfo Options
optInfo = info (optParser <**> helper)
@@ -107,14 +115,23 @@ main = do
TIO.putStrLn $ "Parsed " <> T.pack (show n) <> " lines"
exitSuccess
else if optScanNct opts
- then do
- result <- scanNctFile (optInput opts)
- case result of
- Left err -> do
- TIO.putStrLn $ "Error: " <> err
- exitFailure
- Right (lineMap, findings) ->
- outputNct opts lineMap findings
+ then case optSymbol opts of
+ Just sym -> do
+ result <- scanNctFileForSymbol sym (optInput opts)
+ case result of
+ Left err -> do
+ TIO.putStrLn $ "Error: " <> err
+ exitFailure
+ Right ssr ->
+ outputNctSymbol opts ssr
+ Nothing -> do
+ result <- scanNctFile (optInput opts)
+ case result of
+ Left err -> do
+ TIO.putStrLn $ "Error: " <> err
+ exitFailure
+ Right (lineMap, findings) ->
+ outputNct opts lineMap findings
else do
-- Load taint config if provided
mcfg <- case optTaintConfig opts of
@@ -172,6 +189,32 @@ outputText opts ar = do
then exitSuccess
else exitFailure
+-- | Output NCT scan results for a specific symbol and its callees.
+outputNctSymbol :: Options -> SymbolScanResult -> IO ()
+outputNctSymbol opts ssr = do
+ let lineMap = ssrLineMap ssr
+ findings = ssrFindings ssr
+ showGhc = optShowGhcRuntime opts
+ isReal = not . isGhcRuntimeFinding lineMap
+ filterFindings = if showGhc then id else filter isReal
+ syms = [(sym, filterFindings fs) | (sym, fs) <- Map.toList findings]
+ realSyms = filter (not . null . snd) syms
+ total = sum (map (length . snd) realSyms)
+ if optNctDetail opts
+ then mapM_ (printNctDetail showGhc lineMap) realSyms
+ else mapM_ (printNctSummary showGhc lineMap) realSyms
+ if optQuiet opts
+ then pure ()
+ else do
+ TIO.putStrLn ""
+ TIO.putStrLn $ "Root symbol: " <> ssrRootSymbol ssr
+ TIO.putStrLn $ "Reachable symbols: " <> T.pack (show (ssrReachable ssr))
+ TIO.putStrLn $ "With findings: " <> T.pack (show (length realSyms))
+ TIO.putStrLn $ "NCT findings: " <> T.pack (show total)
+ if total == 0
+ then exitSuccess
+ else exitFailure
+
-- | Output NCT scan results.
outputNct :: Options -> LineMap -> Map.Map Text [NctFinding] -> IO ()
outputNct opts lineMap findings = do
diff --git a/lib/Audit/AArch64.hs b/lib/Audit/AArch64.hs
@@ -42,6 +42,8 @@ module Audit.AArch64 (
-- * NCT scanner
, scanNct
, scanNctFile
+ , scanNctFileForSymbol
+ , SymbolScanResult(..)
, NctReason(..)
, NctFinding(..)
-- ** GHC runtime classification
@@ -49,6 +51,11 @@ module Audit.AArch64 (
, buildLineMap
, isGhcRuntimeFinding
, filterGhcRuntime
+ -- ** Call graph
+ , CallGraph
+ , buildCallGraph
+ , reachableSymbols
+ , symbolExists
-- * Results
, AuditResult(..)
@@ -68,7 +75,14 @@ module Audit.AArch64 (
, showParseError
) where
-import Audit.AArch64.CFG
+import Audit.AArch64.CFG hiding (buildCallGraph)
+import Audit.AArch64.CallGraph
+ ( CallGraph
+ , buildCallGraph
+ , reachableSymbols
+ , symbolExists
+ )
+import qualified Audit.AArch64.CallGraph as CG (buildCallGraph)
import Audit.AArch64.Check
import Audit.AArch64.NCT
( NctReason(..), NctFinding(..), scanNct
@@ -82,6 +96,7 @@ import Audit.AArch64.Types
import Data.Aeson (eitherDecodeStrict')
import qualified Data.ByteString as BS
import qualified Data.Map.Strict as Map
+import qualified Data.Set as Set
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8')
@@ -179,3 +194,50 @@ scanNctFile path = do
let lineMap = buildLineMap lns
findings = scanNct lns
in pure (Right (lineMap, findings))
+
+-- | Result of symbol-focused NCT scan.
+data SymbolScanResult = SymbolScanResult
+ { ssrRootSymbol :: !Text
+ -- ^ The requested root symbol
+ , ssrReachable :: !Int
+ -- ^ Number of reachable symbols (including root)
+ , ssrLineMap :: !LineMap
+ -- ^ Line map for GHC runtime classification
+ , ssrFindings :: !(Map.Map Text [NctFinding])
+ -- ^ Findings filtered to reachable symbols only
+ }
+
+-- | Scan an assembly file for NCT instructions, focused on a specific symbol.
+--
+-- Uses call graph analysis to find all symbols reachable from the given
+-- root symbol and returns findings only for those symbols.
+--
+-- Returns 'Left' if parsing fails or the symbol doesn't exist.
+scanNctFileForSymbol
+ :: Text -- ^ Root symbol to analyze
+ -> FilePath -- ^ Assembly file path
+ -> IO (Either Text SymbolScanResult)
+scanNctFileForSymbol rootSym path = do
+ bs <- BS.readFile path
+ case decodeUtf8' bs of
+ Left err -> pure (Left (T.pack (show err)))
+ Right src ->
+ case parseAsm src of
+ Left err -> pure (Left (T.pack (showParseError err)))
+ Right lns -> do
+ let callGraph = CG.buildCallGraph lns
+ if not (symbolExists rootSym callGraph)
+ then pure (Left ("symbol not found: " <> rootSym))
+ else do
+ let reachable = reachableSymbols rootSym callGraph
+ lineMap = buildLineMap lns
+ allFindings = scanNct lns
+ -- Filter to only reachable symbols
+ filtered = Map.filterWithKey
+ (\sym _ -> Set.member sym reachable) allFindings
+ pure $ Right $ SymbolScanResult
+ { ssrRootSymbol = rootSym
+ , ssrReachable = Set.size reachable
+ , ssrLineMap = lineMap
+ , ssrFindings = filtered
+ }
diff --git a/lib/Audit/AArch64/CallGraph.hs b/lib/Audit/AArch64/CallGraph.hs
@@ -0,0 +1,91 @@
+{-# OPTIONS_HADDOCK prune #-}
+{-# LANGUAGE OverloadedStrings #-}
+
+-- |
+-- Module: Audit.AArch64.CallGraph
+-- Copyright: (c) 2025 Jared Tobin
+-- License: MIT
+-- Maintainer: jared@ppad.tech
+--
+-- Inter-procedural call graph construction for AArch64 assembly.
+-- Used to find all functions reachable from a given symbol.
+
+module Audit.AArch64.CallGraph (
+ -- * Call graph
+ CallGraph
+ , buildCallGraph
+ -- * Reachability
+ , reachableSymbols
+ , symbolExists
+ ) where
+
+import Audit.AArch64.CFG (isFunctionLabel)
+import Audit.AArch64.Types (Instr(..), Line(..))
+import Data.Graph (Graph, Vertex, graphFromEdges, reachable)
+import Data.Map.Strict (Map)
+import qualified Data.Map.Strict as Map
+import Data.Set (Set)
+import qualified Data.Set as Set
+import Data.Text (Text)
+
+-- | Call graph: maps symbols to the set of symbols they call.
+-- Includes lookup functions for graph traversal.
+data CallGraph = CallGraph
+ { cgGraph :: !Graph
+ , cgNodeFromV :: !(Vertex -> ((), Text, [Text]))
+ , cgVertexFromK :: !(Text -> Maybe Vertex)
+ , cgSymbols :: !(Set Text)
+ }
+
+-- | Build a call graph from parsed assembly lines.
+--
+-- Extracts function symbols and their call targets (bl instructions).
+-- Does not resolve indirect calls (blr).
+buildCallGraph :: [Line] -> CallGraph
+buildCallGraph lns = CallGraph graph nodeFromV vertexFromK allSyms
+ where
+ (graph, nodeFromV, vertexFromK) = graphFromEdges edges
+
+ -- Build map from symbol to its instructions.
+ -- State: (current symbol, accumulated map)
+ symInstrs :: Map Text [Instr]
+ symInstrs = snd $ foldl step ("<unknown>", Map.empty) lns
+ where
+ step (curSym, acc) ln =
+ let -- Update current symbol when we see a function label
+ sym = case lineLabel ln of
+ Just lbl | isFunctionLabel lbl -> lbl
+ _ -> curSym
+ -- Ensure symbol exists in map (even with no instructions)
+ acc' = Map.insertWith (++) sym [] acc
+ in case lineInstr ln of
+ Nothing -> (sym, acc')
+ Just i -> (sym, Map.insertWith (++) sym [i] acc')
+
+ -- Extract call targets from instructions
+ callTargets :: Text -> [Text]
+ callTargets sym = case Map.lookup sym symInstrs of
+ Nothing -> []
+ Just instrs -> [target | Bl target <- instrs]
+
+ -- All symbols in the assembly
+ allSyms :: Set Text
+ allSyms = Map.keysSet symInstrs
+
+ -- Build graph edges: (node data, key, [adjacent keys])
+ edges :: [((), Text, [Text])]
+ edges = [((), sym, callTargets sym) | sym <- Set.toList allSyms]
+
+-- | Get all symbols reachable from a root symbol (including the root).
+--
+-- Returns empty set if the root symbol doesn't exist.
+reachableSymbols :: Text -> CallGraph -> Set Text
+reachableSymbols root cg = case cgVertexFromK cg root of
+ Nothing -> Set.empty
+ Just v -> Set.fromList
+ [sym | v' <- reachable (cgGraph cg) v
+ , let (_, sym, _) = cgNodeFromV cg v']
+
+-- | Check if a symbol exists in the call graph.
+symbolExists :: Text -> CallGraph -> Bool
+symbolExists sym cg = Set.member sym (cgSymbols cg)
diff --git a/ppad-auditor.cabal b/ppad-auditor.cabal
@@ -25,6 +25,7 @@ library
-Wall
exposed-modules:
Audit.AArch64
+ Audit.AArch64.CallGraph
Audit.AArch64.Types
Audit.AArch64.Parser
Audit.AArch64.CFG
diff --git a/test/Main.hs b/test/Main.hs
@@ -3,6 +3,7 @@
module Main where
import Audit.AArch64
+import Audit.AArch64.CallGraph
import Audit.AArch64.Parser
import Audit.AArch64.Taint
import Audit.AArch64.Types
@@ -23,6 +24,7 @@ main = defaultMain $ testGroup "ppad-auditor" [
, provenanceTests
, taintConfigTests
, nctTests
+ , callGraphTests
]
-- Parser tests
@@ -1282,3 +1284,92 @@ nctTests = testGroup "NCT" [
-- L1 is local, findings should stay with _foo
assertEqual "L1 not a key" Nothing (Map.lookup "L1" m)
]
+
+-- Call graph tests
+
+callGraphTests :: TestTree
+callGraphTests = testGroup "CallGraph" [
+ testCase "symbolExists: existing symbol" $ do
+ let src = T.unlines
+ [ "_foo:"
+ , " ret"
+ ]
+ case parseAsm src of
+ Left e -> assertFailure $ "parse failed: " ++ show e
+ Right lns -> do
+ let cg = buildCallGraph lns
+ assertEqual "symbol exists" True (symbolExists "_foo" cg)
+
+ , testCase "symbolExists: missing symbol" $ do
+ let src = T.unlines
+ [ "_foo:"
+ , " ret"
+ ]
+ case parseAsm src of
+ Left e -> assertFailure $ "parse failed: " ++ show e
+ Right lns -> do
+ let cg = buildCallGraph lns
+ assertEqual "symbol missing" False (symbolExists "_bar" cg)
+
+ , testCase "reachableSymbols: direct call" $ do
+ let src = T.unlines
+ [ "_foo:"
+ , " bl _bar"
+ , " ret"
+ , "_bar:"
+ , " ret"
+ ]
+ case parseAsm src of
+ Left e -> assertFailure $ "parse failed: " ++ show e
+ Right lns -> do
+ let cg = buildCallGraph lns
+ reachable = reachableSymbols "_foo" cg
+ assertEqual "foo reachable" True (Set.member "_foo" reachable)
+ assertEqual "bar reachable" True (Set.member "_bar" reachable)
+ assertEqual "count" 2 (Set.size reachable)
+
+ , testCase "reachableSymbols: transitive call" $ do
+ let src = T.unlines
+ [ "_foo:"
+ , " bl _bar"
+ , " ret"
+ , "_bar:"
+ , " bl _baz"
+ , " ret"
+ , "_baz:"
+ , " ret"
+ ]
+ case parseAsm src of
+ Left e -> assertFailure $ "parse failed: " ++ show e
+ Right lns -> do
+ let cg = buildCallGraph lns
+ reachable = reachableSymbols "_foo" cg
+ assertEqual "baz reachable from foo" True (Set.member "_baz" reachable)
+ assertEqual "count" 3 (Set.size reachable)
+
+ , testCase "reachableSymbols: no callees" $ do
+ let src = T.unlines
+ [ "_foo:"
+ , " ret"
+ , "_bar:"
+ , " ret"
+ ]
+ case parseAsm src of
+ Left e -> assertFailure $ "parse failed: " ++ show e
+ Right lns -> do
+ let cg = buildCallGraph lns
+ reachable = reachableSymbols "_foo" cg
+ assertEqual "only foo reachable" (Set.singleton "_foo") reachable
+
+ , testCase "reachableSymbols: missing root returns empty" $ do
+ let src = T.unlines
+ [ "_foo:"
+ , " ret"
+ ]
+ case parseAsm src of
+ Left e -> assertFailure $ "parse failed: " ++ show e
+ Right lns -> do
+ let cg = buildCallGraph lns
+ reachable = reachableSymbols "_missing" cg
+ assertEqual "empty set" Set.empty reachable
+ ]