auditor

An AArch64 constant-time memory-access auditing tool.
git clone git://git.ppad.tech/auditor.git
Log | Files | Refs | README | LICENSE

Main.hs (16935B)


      1 {-# LANGUAGE OverloadedStrings #-}
      2 
      3 module Main where
      4 
import Audit.AArch64
  ( RuntimeConfig(..)
  , ghcRuntime, genericRuntime
  , AuditResult(..), Violation(..), ViolationReason(..)
  , TaintConfig(..)
  , NctReason(..), NctFinding(..)
  , nctLine, nctInstr, nctReason
  , LineMap, buildLineMap
  , SymbolScanResult(..), scanNct
  , buildCallGraph, allSymbols
  , reachableSymbols, reachingSymbols
  , symbolExists
  , auditFile, auditFileInterProc
  , auditFileWithConfig
  , auditFileInterProcWithConfig
  , parseFile, regName, loadTaintConfig
  , scanNctFile
  )
import Audit.AArch64.Parser (parseAsm)
import Audit.AArch64.Types (Instr)
import Data.Aeson (encode)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy.Char8 as BL
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8')
import qualified Data.Text.IO as TIO
import Options.Applicative
import System.Exit (exitFailure, exitSuccess)
import System.IO (stderr)
     36 
     37 data Runtime = Haskell | Generic
     38   deriving (Eq, Show)
     39 
     40 data Options = Options
     41   { optInput          :: !FilePath
     42   , optJson           :: !Bool
     43   , optQuiet          :: !Bool
     44   , optInterProc      :: !Bool
     45   , optParseOnly      :: !Bool
     46   , optTaintConfig    :: !(Maybe FilePath)
     47   , optDisplayUnknown :: !Bool
     48   , optScanNct        :: !Bool
     49   , optNctDetail      :: !Bool
     50   , optShowRuntimePat :: !Bool
     51   , optSymbol         :: !(Maybe Text)
     52   , optZSymbol        :: !(Maybe Text)
     53   , optListSymbols    :: !Bool
     54   , optSymbolFilter   :: !(Maybe Text)
     55   , optCallers          :: !Bool
     56   , optAssumeSecPrivate :: !Bool
     57   , optRuntime          :: !Runtime
     58   } deriving (Eq, Show)
     59 
     60 runtimeReader :: ReadM Runtime
     61 runtimeReader = eitherReader $ \s -> case s of
     62   "haskell" -> Right Haskell
     63   "generic" -> Right Generic
     64   _         -> Left $
     65     "unknown runtime: " ++ s
     66     ++ " (expected haskell or generic)"
     67 
     68 optParser :: Parser Options
     69 optParser = Options
     70   <$> strOption
     71       ( long "input"
     72      <> short 'i'
     73      <> metavar "FILE"
     74      <> help "Input assembly file (.s)"
     75       )
     76   <*> switch
     77       ( long "json"
     78      <> short 'j'
     79      <> help "Output JSON format"
     80       )
     81   <*> switch
     82       ( long "quiet"
     83      <> short 'q'
     84      <> help "Suppress summary, only show violations"
     85       )
     86   <*> switch
     87       ( long "interproc"
     88      <> short 'p'
     89      <> help "Enable inter-procedural analysis"
     90       )
     91   <*> switch
     92       ( long "parse"
     93      <> help "Parse only, no analysis"
     94       )
     95   <*> optional (strOption
     96       ( long "taint-config"
     97      <> short 't'
     98      <> metavar "FILE"
     99      <> help "JSON file with per-function taint policies"
    100       ))
    101   <*> switch
    102       ( long "display-unknown"
    103      <> short 'u'
    104      <> help "Display unknown violations \
    105              \(only secret shown by default)"
    106       )
    107   <*> switch
    108       ( long "scan-nct"
    109      <> help "Scan for non-constant-time instructions \
    110              \(no taint analysis)"
    111       )
    112   <*> switch
    113       ( long "nct-detail"
    114      <> help "Show per-instruction details in NCT \
    115              \scan mode"
    116       )
    117   <*> switch
    118       ( long "show-runtime-patterns"
    119      <> long "show-ghc-runtime"
    120      <> help "Show runtime patterns in NCT scan \
    121              \(hidden by default)"
    122       )
    123   <*> optional (strOption
    124       ( long "symbol"
    125      <> short 's'
    126      <> metavar "SYMBOL"
    127      <> help "Analyze only this symbol and its callees \
    128              \(NCT scan mode)"
    129       ))
    130   <*> optional (strOption
    131       ( long "zsymbol"
    132      <> short 'z'
    133      <> metavar "SYMBOL"
    134      <> help "Human-readable symbol, auto z-encoded \
    135              \with _info$def \
    136              \(e.g., pkg-1.0:Mod.Sub:func)"
    137       ))
    138   <*> switch
    139       ( long "list-symbols"
    140      <> short 'l'
    141      <> help "List all function symbols in the \
    142              \assembly file"
    143       )
    144   <*> optional (strOption
    145       ( long "filter"
    146      <> short 'f'
    147      <> metavar "PATTERN"
    148      <> help "Filter symbols containing PATTERN \
    149              \(use with --list-symbols)"
    150       ))
    151   <*> switch
    152       ( long "callers"
    153      <> short 'c'
    154      <> help "Show callers instead of callees \
    155              \(use with --symbol)"
    156       )
    157   <*> switch
    158       ( long "assume-secondary-private"
    159      <> long "assume-stg-private"
    160      <> help "Treat untracked secondary stack \
    161              \slots as private (default: public)"
    162       )
    163   <*> option runtimeReader
    164       ( long "runtime"
    165      <> metavar "RUNTIME"
    166      <> value Haskell
    167      <> help "Runtime: haskell (default) or generic"
    168       )
    169 
    170 optInfo :: ParserInfo Options
    171 optInfo = info (optParser <**> helper)
    172   ( fullDesc
    173  <> progDesc "Audit AArch64 assembly for \
    174              \constant-time memory access"
    175  <> header "auditor - CT memory access auditor \
    176            \for AArch64"
    177   )
    178 
    179 -- | Select runtime configuration from CLI option.
    180 selectRuntime :: Runtime -> RuntimeConfig
    181 selectRuntime Haskell = ghcRuntime
    182 selectRuntime Generic = genericRuntime
    183 
    184 main :: IO ()
    185 main = do
    186   opts <- execParser optInfo
    187   let rt = selectRuntime (optRuntime opts)
    188   -- Compute effective symbol from --symbol or --zsymbol
    189   effSym <- case optZSymbol opts of
    190     Just zs -> case rtEncodeSymbol rt of
    191       Nothing -> do
    192         TIO.putStrLn
    193           "Error: --zsymbol requires a runtime \
    194           \with symbol encoding (use --runtime haskell)"
    195         exitFailure
    196       Just encoder -> case encoder zs of
    197         Left err -> do
    198           TIO.putStrLn $ "Error: " <> err
    199           exitFailure
    200         Right encoded -> pure (Just encoded)
    201     Nothing -> pure (optSymbol opts)
    202   let opts' = opts { optSymbol = effSym }
    203   if optListSymbols opts'
    204     then listSymbols rt opts'
    205     else if optParseOnly opts'
    206     then do
    207       result <- parseFile (optInput opts')
    208       case result of
    209         Left err -> do
    210           TIO.putStrLn $ "Error: " <> err
    211           exitFailure
    212         Right n -> do
    213           TIO.putStrLn $ "Parsed "
    214             <> T.pack (show n) <> " lines"
    215           exitSuccess
    216     else if optScanNct opts'
    217     then case optSymbol opts' of
    218       Just sym -> do
    219         result <- scanNctForSymbol rt opts' sym
    220         case result of
    221           Left err -> do
    222             TIO.putStrLn $ "Error: " <> err
    223             exitFailure
    224           Right ssr ->
    225             outputNctSymbol rt opts' ssr
    226       Nothing -> do
    227         result <- scanNctFile rt (optInput opts')
    228         case result of
    229           Left err -> do
    230             TIO.putStrLn $ "Error: " <> err
    231             exitFailure
    232           Right (lineMap, findings) ->
    233             outputNct rt opts' lineMap findings
    234     else do
    235       -- Load taint config if provided
    236       mcfg <- case optTaintConfig opts' of
    237         Nothing -> pure (Right emptyConfig)
    238         Just path -> loadTaintConfig path
    239       case mcfg of
    240         Left err -> do
    241           TIO.putStrLn $
    242             "Error loading taint config: " <> err
    243           exitFailure
    244         Right baseCfg -> do
    245           let assumeSec =
    246                 not (optAssumeSecPrivate opts')
    247               cfg = baseCfg
    248                 { tcAssumeStgPublic = assumeSec }
    249               auditor = selectAuditor rt opts' cfg
    250           result <- auditor (optInput opts')
    251           case result of
    252             Left err -> do
    253               TIO.putStrLn $ "Error: " <> err
    254               exitFailure
    255             Right ar ->
    256               if optJson opts'
    257                 then outputJson opts' ar
    258                 else outputText opts' ar
    259   where
    260     emptyConfig = TaintConfig Map.empty True
    261 
    262     selectAuditor rt opts cfg
    263       | needsConfig && optInterProc opts =
    264           auditFileInterProcWithConfig rt cfg
    265       | needsConfig =
    266           auditFileWithConfig rt cfg
    267       | optInterProc opts =
    268           auditFileInterProc rt
    269       | otherwise =
    270           auditFile rt
    271       where
    272         needsConfig =
    273           not (Map.null (tcPolicies cfg))
    274           || not (tcAssumeStgPublic cfg)
    275 
    276 outputJson :: Options -> AuditResult -> IO ()
    277 outputJson opts ar =
    278   let vs = filterViolations opts (arViolations ar)
    279   in  BL.putStrLn (encode vs)
    280 
    281 outputText :: Options -> AuditResult -> IO ()
    282 outputText opts ar = do
    283   let allVs = arViolations ar
    284       vs = filterViolations opts allVs
    285   mapM_ printViolation vs
    286   if optQuiet opts
    287     then pure ()
    288     else do
    289       TIO.putStrLn ""
    290       TIO.putStrLn $ "Lines checked:    "
    291         <> T.pack (show (arLinesChecked ar))
    292       TIO.putStrLn $ "Memory accesses:  "
    293         <> T.pack (show (arMemoryAccesses ar))
    294       TIO.putStrLn $ "Violations:       "
    295         <> T.pack (show (length vs))
    296       if not (optDisplayUnknown opts)
    297            && length vs < length allVs
    298         then TIO.putStrLn $ "  (hidden):       "
    299                <> T.pack (show
    300                     (length allVs - length vs))
    301                <> " unknown (use -u to show)"
    302         else pure ()
    303   if null vs
    304     then exitSuccess
    305     else exitFailure
    306 
    307 -- | List all function symbols in the assembly file.
    308 -- With --symbol, lists callees (or callers with --callers).
    309 listSymbols :: RuntimeConfig -> Options -> IO ()
    310 listSymbols rt opts = do
    311   bs <- BS.readFile (optInput opts)
    312   case decodeUtf8' bs of
    313     Left err -> do
    314       TIO.putStrLn $ "Error: " <> T.pack (show err)
    315       exitFailure
    316     Right src ->
    317       case parseAsm src of
    318         Left _ -> do
    319           TIO.putStrLn "Error: parse failed"
    320           exitFailure
    321         Right lns -> do
    322           let cg = buildCallGraph rt lns
    323           case optSymbol opts of
    324             Just sym
    325               | not (symbolExists sym cg) -> do
    326                   TIO.putStrLn $
    327                     "Error: symbol not found: " <> sym
    328                   exitFailure
    329               | optCallers opts -> do
    330                   let callers = Set.toAscList
    331                         (reachingSymbols sym cg)
    332                   mapM_ TIO.putStrLn callers
    333                   if optQuiet opts
    334                     then pure ()
    335                     else TIO.putStrLn $
    336                       "\n" <> T.pack
    337                         (show (length callers))
    338                       <> " symbols can reach " <> sym
    339               | otherwise -> do
    340                   let callees = Set.toAscList
    341                         (reachableSymbols sym cg)
    342                   mapM_ TIO.putStrLn callees
    343                   if optQuiet opts
    344                     then pure ()
    345                     else TIO.putStrLn $
    346                       "\n" <> T.pack
    347                         (show (length callees))
    348                       <> " symbols reachable from "
    349                       <> sym
    350             Nothing -> do
    351               let syms = Set.toAscList (allSymbols cg)
    352                   filtered = case optSymbolFilter opts of
    353                     Nothing -> syms
    354                     Just pat ->
    355                       filter (T.isInfixOf pat) syms
    356               mapM_ TIO.putStrLn filtered
    357               if optQuiet opts
    358                 then pure ()
    359                 else TIO.putStrLn $
    360                   "\n" <> T.pack
    361                     (show (length filtered))
    362                   <> " symbols"
    363 
    364 -- | NCT scan for a symbol (callees or callers based
    365 -- on options).
    366 scanNctForSymbol
    367   :: RuntimeConfig -> Options -> Text
    368   -> IO (Either Text SymbolScanResult)
    369 scanNctForSymbol rt opts rootSym = do
    370   bs <- BS.readFile (optInput opts)
    371   case decodeUtf8' bs of
    372     Left err -> pure (Left (T.pack (show err)))
    373     Right src ->
    374       case parseAsm src of
    375         Left _ -> pure (Left "parse failed")
    376         Right lns -> do
    377           let cg = buildCallGraph rt lns
    378           if not (symbolExists rootSym cg)
    379             then pure (Left
    380               ("symbol not found: " <> rootSym))
    381             else do
    382               let syms = if optCallers opts
    383                     then reachingSymbols rootSym cg
    384                     else reachableSymbols rootSym cg
    385                   lineMap = buildLineMap lns
    386                   allFindings = scanNct rt lns
    387                   filtered = Map.filterWithKey
    388                     (\sym _ -> Set.member sym syms)
    389                     allFindings
    390               pure $ Right $ SymbolScanResult
    391                 { ssrRootSymbol = rootSym
    392                 , ssrReachable  = Set.size syms
    393                 , ssrLineMap    = lineMap
    394                 , ssrFindings   = filtered
    395                 }
    396 
    397 -- | Output NCT scan results for a specific symbol
    398 -- and its callees.
    399 outputNctSymbol :: RuntimeConfig -> Options
    400                 -> SymbolScanResult -> IO ()
    401 outputNctSymbol rt opts ssr = do
    402   let lineMap = ssrLineMap ssr
    403       findings = ssrFindings ssr
    404       showRt = optShowRuntimePat opts
    405       filterFn = rtFilterNct rt
    406       isRt = filterFn lineMap
    407       filterFindings =
    408         if showRt then id else filter (not . isRt)
    409       syms = [ (sym, filterFindings fs)
    410              | (sym, fs) <- Map.toList findings ]
    411       realSyms = filter (not . null . snd) syms
    412       total = sum (map (length . snd) realSyms)
    413   if optNctDetail opts
    414     then mapM_ (printNctDetail showRt isRt) realSyms
    415     else mapM_ (printNctSummary showRt isRt) realSyms
    416   if optQuiet opts
    417     then pure ()
    418     else do
    419       TIO.putStrLn ""
    420       TIO.putStrLn $ "Root symbol:       "
    421         <> ssrRootSymbol ssr
    422       TIO.putStrLn $ "Reachable symbols: "
    423         <> T.pack (show (ssrReachable ssr))
    424       TIO.putStrLn $ "With findings:     "
    425         <> T.pack (show (length realSyms))
    426       TIO.putStrLn $ "NCT findings:      "
    427         <> T.pack (show total)
    428   if total == 0
    429     then exitSuccess
    430     else exitFailure
    431 
    432 -- | Output NCT scan results.
    433 outputNct :: RuntimeConfig -> Options -> LineMap
    434           -> Map.Map Text [NctFinding] -> IO ()
    435 outputNct rt opts lineMap findings = do
    436   let showRt = optShowRuntimePat opts
    437       filterFn = rtFilterNct rt
    438       isRt = filterFn lineMap
    439       filterFindings =
    440         if showRt then id else filter (not . isRt)
    441       syms = [ (sym, filterFindings fs)
    442              | (sym, fs) <- Map.toList findings ]
    443       realSyms = filter (not . null . snd) syms
    444       total = sum (map (length . snd) realSyms)
    445   if optNctDetail opts
    446     then mapM_ (printNctDetail showRt isRt) realSyms
    447     else mapM_ (printNctSummary showRt isRt) realSyms
    448   if optQuiet opts
    449     then pure ()
    450     else do
    451       TIO.putStrLn ""
    452       TIO.putStrLn $ "Functions scanned: "
    453         <> T.pack (show (length realSyms))
    454       TIO.putStrLn $ "NCT findings:      "
    455         <> T.pack (show total)
    456   if total == 0
    457     then exitSuccess
    458     else exitFailure
    459 
    460 printNctSummary :: Bool -> (NctFinding -> Bool)
    461                 -> (Text, [NctFinding]) -> IO ()
    462 printNctSummary showRt isRt (sym, fs) = do
    463   TIO.putStrLn $ sym <> ": "
    464     <> T.pack (show (length fs))
    465   mapM_ (printFindingIndented showRt isRt) fs
    466 
    467 printFindingIndented :: Bool -> (NctFinding -> Bool)
    468                      -> NctFinding -> IO ()
    469 printFindingIndented showRt isRt f =
    470   let rtMatch = isRt f
    471       content = T.pack (show (nctLine f)) <> ": "
    472         <> nctReasonText (nctReason f) <> ": "
    473         <> instrText (nctInstr f)
    474       line = if showRt && rtMatch
    475              then "  (runtime) " <> content
    476              else "  " <> content
    477   in  TIO.putStrLn line
    478 
    479 printNctDetail :: Bool -> (NctFinding -> Bool)
    480                -> (Text, [NctFinding]) -> IO ()
    481 printNctDetail showRt isRt (sym, fs) =
    482   mapM_ (printFinding showRt isRt sym) fs
    483 
    484 printFinding :: Bool -> (NctFinding -> Bool)
    485              -> Text -> NctFinding -> IO ()
    486 printFinding showRt isRt sym f =
    487   let rtMatch = isRt f
    488       content = sym <> ":"
    489         <> T.pack (show (nctLine f)) <> ": "
    490         <> nctReasonText (nctReason f) <> ": "
    491         <> instrText (nctInstr f)
    492       line = if showRt && rtMatch
    493              then "(runtime) " <> content
    494              else content
    495   in  TIO.putStrLn line
    496 
    497 nctReasonText :: NctReason -> Text
    498 nctReasonText r = case r of
    499   CondBranch     -> "cond-branch"
    500   IndirectBranch -> "indirect-branch"
    501   Div            -> "div"
    502   RegIndexAddr   -> "reg-index"
    503 
    504 instrText :: Instr -> Text
    505 instrText instr = T.pack (show instr)
    506 
    507 -- | Filter violations based on options.
    508 -- By default, only secret violations are shown.
    509 filterViolations :: Options -> [Violation]
    510                  -> [Violation]
    511 filterViolations opts
    512   | optDisplayUnknown opts = id
    513   | otherwise = filter (isSecretViolation . vReason)
    514 
    515 -- | Check if a violation reason is secret (not unknown).
    516 isSecretViolation :: ViolationReason -> Bool
    517 isSecretViolation r = case r of
    518   SecretBase _   -> True
    519   SecretIndex _  -> True
    520   UnknownBase _  -> False
    521   UnknownIndex _ -> False
    522   NonConstOffset -> True
    523 
    524 printViolation :: Violation -> IO ()
    525 printViolation v = TIO.putStrLn $
    526   vSymbol v <> ":" <> T.pack (show (vLine v))
    527   <> ": " <> reasonText (vReason v)
    528 
    529 reasonText :: ViolationReason -> Text
    530 reasonText r = case r of
    531   SecretBase reg   ->
    532     "secret base register " <> regName reg
    533   SecretIndex reg  ->
    534     "secret index register " <> regName reg
    535   UnknownBase reg  ->
    536     "unknown base register " <> regName reg
    537   UnknownIndex reg ->
    538     "unknown index register " <> regName reg
    539   NonConstOffset   ->
    540     "non-constant offset without masking"