Main.hs (16935B)
{-# LANGUAGE OverloadedStrings #-}

-- | Command-line driver for the AArch64 constant-time auditor.
--
-- Depending on the flags, the program lists call-graph symbols
-- (@--list-symbols@), only parses the input (@--parse@), scans for
-- non-constant-time instructions (@--scan-nct@), or runs the default
-- taint-based audit of memory accesses. Error messages are written to
-- stderr so that stdout carries only results (notably @--json@ output).
module Main where

import Audit.AArch64
  ( RuntimeConfig(..)
  , ghcRuntime, genericRuntime
  , AuditResult(..), Violation(..), ViolationReason(..)
  , TaintConfig(..)
  , NctReason(..), NctFinding(..)
  , nctLine, nctInstr, nctReason
  , LineMap, buildLineMap
  , SymbolScanResult(..), scanNct
  , buildCallGraph, allSymbols
  , reachableSymbols, reachingSymbols
  , symbolExists
  , auditFile, auditFileInterProc
  , auditFileWithConfig
  , auditFileInterProcWithConfig
  , parseFile, regName, loadTaintConfig
  , scanNctFile
  )
import Audit.AArch64.Parser (parseAsm)
import Audit.AArch64.Types (Instr)
import Control.Monad (unless, when)
import Data.Aeson (encode)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy.Char8 as BL
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8')
import qualified Data.Text.IO as TIO
import Options.Applicative
import System.Exit (exitFailure, exitSuccess)
import System.IO (stderr)

-- | Runtime selected with @--runtime@.
data Runtime = Haskell | Generic
  deriving (Eq, Show)

-- | Parsed command-line options.
data Options = Options
  { optInput :: !FilePath
    -- ^ @--input@ / @-i@: assembly file to analyse.
  , optJson :: !Bool
    -- ^ @--json@ / @-j@: emit audit violations as JSON.
  , optQuiet :: !Bool
    -- ^ @--quiet@ / @-q@: suppress the trailing summaries.
  , optInterProc :: !Bool
    -- ^ @--interproc@ / @-p@: use the inter-procedural auditor.
  , optParseOnly :: !Bool
    -- ^ @--parse@: parse the input and report a line count only.
  , optTaintConfig :: !(Maybe FilePath)
    -- ^ @--taint-config@ / @-t@: JSON file with per-function policies.
  , optDisplayUnknown :: !Bool
    -- ^ @--display-unknown@ / @-u@: also show unknown violations.
  , optScanNct :: !Bool
    -- ^ @--scan-nct@: non-constant-time instruction scan mode.
  , optNctDetail :: !Bool
    -- ^ @--nct-detail@: one output line per NCT finding.
  , optShowRuntimePat :: !Bool
    -- ^ @--show-runtime-patterns@: keep (and tag) runtime-pattern findings.
  , optSymbol :: !(Maybe Text)
    -- ^ @--symbol@ / @-s@: root symbol; replaced by the encoded
    -- @--zsymbol@ when that is given.
  , optZSymbol :: !(Maybe Text)
    -- ^ @--zsymbol@ / @-z@: human-readable symbol, encoded by the runtime.
  , optListSymbols :: !Bool
    -- ^ @--list-symbols@ / @-l@: symbol listing mode.
  , optSymbolFilter :: !(Maybe Text)
    -- ^ @--filter@ / @-f@: substring filter for the full symbol listing.
  , optCallers :: !Bool
    -- ^ @--callers@ / @-c@: use symbols reaching the root instead of
    -- symbols reachable from it.
  , optAssumeSecPrivate :: !Bool
    -- ^ @--assume-secondary-private@: treat untracked secondary stack
    -- slots as private; its negation becomes 'tcAssumeStgPublic'.
  , optRuntime :: !Runtime
    -- ^ @--runtime@: runtime configuration to use (default 'Haskell').
  } deriving (Eq, Show)

-- | Read the @--runtime@ argument: @haskell@ or @generic@.
runtimeReader :: ReadM Runtime
runtimeReader = eitherReader $ \s -> case s of
  "haskell" -> Right Haskell
  "generic" -> Right Generic
  _ -> Left $
    "unknown runtime: " ++ s
      ++ " (expected haskell or generic)"

-- | Command-line parser; the applicative chain follows the field order
-- of 'Options'.
optParser :: Parser Options
optParser = Options
  <$> strOption
      ( long "input"
     <> short 'i'
     <> metavar "FILE"
     <> help "Input assembly file (.s)"
      )
  <*> switch
      ( long "json"
     <> short 'j'
     <> help "Output JSON format"
      )
  <*> switch
      ( long "quiet"
     <> short 'q'
     <> help "Suppress summary, only show violations"
      )
  <*> switch
      ( long "interproc"
     <> short 'p'
     <> help "Enable inter-procedural analysis"
      )
  <*> switch
      ( long "parse"
     <> help "Parse only, no analysis"
      )
  <*> optional (strOption
      ( long "taint-config"
     <> short 't'
     <> metavar "FILE"
     <> help "JSON file with per-function taint policies"
      ))
  <*> switch
      ( long "display-unknown"
     <> short 'u'
     <> help "Display unknown violations \
             \(only secret shown by default)"
      )
  <*> switch
      ( long "scan-nct"
     <> help "Scan for non-constant-time instructions \
             \(no taint analysis)"
      )
  <*> switch
      ( long "nct-detail"
     <> help "Show per-instruction details in NCT \
             \scan mode"
      )
  <*> switch
      ( long "show-runtime-patterns"
     <> long "show-ghc-runtime"
     <> help "Show runtime patterns in NCT scan \
             \(hidden by default)"
      )
  <*> optional (strOption
      ( long "symbol"
     <> short 's'
     <> metavar "SYMBOL"
     <> help "Analyze only this symbol and its callees \
             \(NCT scan mode)"
      ))
  <*> optional (strOption
      ( long "zsymbol"
     <> short 'z'
     <> metavar "SYMBOL"
     <> help "Human-readable symbol, auto z-encoded \
             \with _info$def \
             \(e.g., pkg-1.0:Mod.Sub:func)"
      ))
  <*> switch
      ( long "list-symbols"
     <> short 'l'
     <> help "List all function symbols in the \
             \assembly file"
      )
  <*> optional (strOption
      ( long "filter"
     <> short 'f'
     <> metavar "PATTERN"
     <> help "Filter symbols containing PATTERN \
             \(use with --list-symbols)"
      ))
  <*> switch
      ( long "callers"
     <> short 'c'
     <> help "Show callers instead of callees \
             \(use with --symbol)"
      )
  <*> switch
      ( long "assume-secondary-private"
     <> long "assume-stg-private"
     <> help "Treat untracked secondary stack \
             \slots as private (default: public)"
      )
  <*> option runtimeReader
      ( long "runtime"
     <> metavar "RUNTIME"
     <> value Haskell
     <> help "Runtime: haskell (default) or generic"
      )

-- | 'optParser' with @--help@ support and the program description.
optInfo :: ParserInfo Options
optInfo = info (optParser <**> helper)
  ( fullDesc
 <> progDesc "Audit AArch64 assembly for \
             \constant-time memory access"
 <> header "auditor - CT memory access auditor \
           \for AArch64"
  )

-- | Select runtime configuration from CLI option.
selectRuntime :: Runtime -> RuntimeConfig
selectRuntime Haskell = ghcRuntime
selectRuntime Generic = genericRuntime

-- | Print a message to stderr and exit with a failure status. Keeping
-- diagnostics off stdout keeps @--json@ output parseable.
exitWithError :: Text -> IO a
exitWithError msg = do
  TIO.hPutStrLn stderr msg
  exitFailure

-- | Entry point: parse options, resolve the root symbol, then dispatch
-- to listing, parse-only, NCT scan, or the taint audit (in that order
-- of precedence).
main :: IO ()
main = do
  opts <- execParser optInfo
  let rt = selectRuntime (optRuntime opts)
  effSym <- resolveSymbol rt opts
  let opts' = opts { optSymbol = effSym }
  if optListSymbols opts'
    then listSymbols rt opts'
    else if optParseOnly opts'
      then runParseOnly opts'
      else if optScanNct opts'
        then runScanNct rt opts'
        else runAudit rt opts'

-- | Effective root symbol: @--zsymbol@, encoded by the runtime, takes
-- precedence over @--symbol@. Exits with an error when the runtime has
-- no encoder or encoding fails.
resolveSymbol :: RuntimeConfig -> Options -> IO (Maybe Text)
resolveSymbol rt opts = case optZSymbol opts of
  Nothing -> pure (optSymbol opts)
  Just zs -> case rtEncodeSymbol rt of
    Nothing -> exitWithError
      "Error: --zsymbol requires a runtime \
      \with symbol encoding (use --runtime haskell)"
    Just encoder -> case encoder zs of
      Left err -> exitWithError ("Error: " <> err)
      Right encoded -> pure (Just encoded)

-- | @--parse@: parse the input and report how many lines were parsed.
runParseOnly :: Options -> IO ()
runParseOnly opts = do
  result <- parseFile (optInput opts)
  case result of
    Left err -> exitWithError ("Error: " <> err)
    Right n -> do
      TIO.putStrLn $ "Parsed "
        <> T.pack (show n) <> " lines"
      exitSuccess

-- | @--scan-nct@: scan the whole file, or only the symbols related to
-- the root symbol when one is given.
runScanNct :: RuntimeConfig -> Options -> IO ()
runScanNct rt opts = case optSymbol opts of
  Just sym -> do
    result <- scanNctForSymbol rt opts sym
    case result of
      Left err -> exitWithError ("Error: " <> err)
      Right ssr -> outputNctSymbol rt opts ssr
  Nothing -> do
    result <- scanNctFile rt (optInput opts)
    case result of
      Left err -> exitWithError ("Error: " <> err)
      Right (lineMap, findings) ->
        outputNct rt opts lineMap findings

-- | Default mode: load the optional taint config, apply the
-- secondary-stack assumption from the CLI, and run the selected auditor.
runAudit :: RuntimeConfig -> Options -> IO ()
runAudit rt opts = do
  mcfg <- maybe (pure (Right emptyConfig)) loadTaintConfig
            (optTaintConfig opts)
  baseCfg <- case mcfg of
    Left err -> exitWithError ("Error loading taint config: " <> err)
    Right c -> pure c
  -- The CLI flag always decides this field, overriding the base config.
  let cfg = baseCfg
        { tcAssumeStgPublic = not (optAssumeSecPrivate opts) }
  result <- selectAuditor cfg (optInput opts)
  case result of
    Left err -> exitWithError ("Error: " <> err)
    Right ar
      | optJson opts -> outputJson opts ar
      | otherwise -> outputText opts ar
  where
    emptyConfig = TaintConfig Map.empty True

    -- The config-aware auditors are needed only when the config differs
    -- from the defaults (policies present or private secondary slots).
    selectAuditor cfg
      | needsConfig && optInterProc opts =
          auditFileInterProcWithConfig rt cfg
      | needsConfig =
          auditFileWithConfig rt cfg
      | optInterProc opts =
          auditFileInterProc rt
      | otherwise =
          auditFile rt
      where
        needsConfig =
          not (Map.null (tcPolicies cfg))
            || not (tcAssumeStgPublic cfg)

-- | Print the filtered violations as JSON on stdout, then exit with a
-- failure status if any remain (the same exit-code contract as
-- 'outputText').
outputJson :: Options -> AuditResult -> IO ()
outputJson opts ar = do
  let vs = filterViolations opts (arViolations ar)
  BL.putStrLn (encode vs)
  if null vs then exitSuccess else exitFailure

-- | Print the filtered violations, then (unless @--quiet@) a summary.
-- Exits with failure if any shown violation remains.
outputText :: Options -> AuditResult -> IO ()
outputText opts ar = do
  let allVs = arViolations ar
      vs = filterViolations opts allVs
      hiddenCount = length allVs - length vs
  mapM_ printViolation vs
  unless (optQuiet opts) $ do
    TIO.putStrLn ""
    TIO.putStrLn $ "Lines checked: "
      <> T.pack (show (arLinesChecked ar))
    TIO.putStrLn $ "Memory accesses: "
      <> T.pack (show (arMemoryAccesses ar))
    TIO.putStrLn $ "Violations: "
      <> T.pack (show (length vs))
    -- 'filterViolations' only removes unknown-reason violations, so the
    -- length difference is the hidden "unknown" count.
    when (not (optDisplayUnknown opts) && hiddenCount > 0) $
      TIO.putStrLn $ " (hidden): "
        <> T.pack (show hiddenCount)
        <> " unknown (use -u to show)"
  if null vs then exitSuccess else exitFailure

-- | @--list-symbols@: print call-graph symbols, one per line. Without a
-- root symbol, prints every symbol (narrowed by @--filter@ if given).
-- With one, prints the symbols reachable from it, or the symbols that
-- can reach it when @--callers@ is set.
listSymbols :: RuntimeConfig -> Options -> IO ()
listSymbols rt opts = do
  bs <- BS.readFile (optInput opts)
  src <- either (exitWithError . ("Error: " <>) . T.pack . show) pure
           (decodeUtf8' bs)
  -- Only the fact of a parse failure is reported, not its details.
  lns <- either (const (exitWithError "Error: parse failed")) pure
           (parseAsm src)
  let cg = buildCallGraph rt lns
  case optSymbol opts of
    Just sym
      | not (symbolExists sym cg) ->
          exitWithError ("Error: symbol not found: " <> sym)
      | optCallers opts ->
          printSymbols (Set.toAscList (reachingSymbols sym cg))
            (" symbols can reach " <> sym)
      | otherwise ->
          printSymbols (Set.toAscList (reachableSymbols sym cg))
            (" symbols reachable from " <> sym)
    Nothing ->
      let syms = Set.toAscList (allSymbols cg)
          filtered = maybe syms (\pat -> filter (T.isInfixOf pat) syms)
                       (optSymbolFilter opts)
      in printSymbols filtered " symbols"
  where
    -- One name per line, then (unless --quiet) a blank line and
    -- "<count><suffix>".
    printSymbols names suffix = do
      mapM_ TIO.putStrLn names
      unless (optQuiet opts) $
        TIO.putStrLn $ "\n" <> T.pack (show (length names)) <> suffix

-- | NCT scan restricted to the symbols reachable from a root symbol, or
-- to the symbols that can reach it when @--callers@ is set.
scanNctForSymbol
  :: RuntimeConfig -> Options -> Text
  -> IO (Either Text SymbolScanResult)
scanNctForSymbol rt opts rootSym = do
  bs <- BS.readFile (optInput opts)
  -- Only the file read is effectful. Each later step can fail (UTF-8
  -- decode, parse, unknown root symbol), so the steps are chained in the
  -- Either monad and the first failure becomes the result.
  pure $ do
    src <- either (Left . T.pack . show) Right (decodeUtf8' bs)
    lns <- either (const (Left "parse failed")) Right (parseAsm src)
    let cg = buildCallGraph rt lns
    if not (symbolExists rootSym cg)
      then Left ("symbol not found: " <> rootSym)
      else do
        let syms
              | optCallers opts = reachingSymbols rootSym cg
              | otherwise = reachableSymbols rootSym cg
        -- Scan the whole file, then keep only the selected symbols.
        Right $ SymbolScanResult
          { ssrRootSymbol = rootSym
          , ssrReachable = Set.size syms
          , ssrLineMap = buildLineMap lns
          , ssrFindings = Map.restrictKeys (scanNct rt lns) syms
          }

-- | Findings to display, grouped by symbol in ascending key order,
-- returned together with the runtime's pattern predicate for this line
-- map. Findings matching the predicate are dropped unless
-- @--show-runtime-patterns@ is set. Symbols left with no findings are
-- omitted.
visibleNctFindings
  :: RuntimeConfig -> Options -> LineMap
  -> Map.Map Text [NctFinding]
  -> (NctFinding -> Bool, [(Text, [NctFinding])])
visibleNctFindings rt opts lineMap findings = (isRt, groups)
  where
    isRt = rtFilterNct rt lineMap
    keep
      | optShowRuntimePat opts = id
      | otherwise = filter (not . isRt)
    groups =
      [ (sym, kept)
      | (sym, fs) <- Map.toList findings
      , let kept = keep fs
      , not (null kept)
      ]

-- | Print grouped findings. With @--nct-detail@ each finding gets one
-- line; otherwise each symbol gets a count line followed by its
-- findings, indented.
printNctGroups
  :: Options -> (NctFinding -> Bool) -> [(Text, [NctFinding])] -> IO ()
printNctGroups opts isRt
  | optNctDetail opts = mapM_ (printNctDetail showRt isRt)
  | otherwise = mapM_ (printNctSummary showRt isRt)
  where
    showRt = optShowRuntimePat opts

-- | Exit with success when no findings remain, failure otherwise.
exitForNctTotal :: Int -> IO ()
exitForNctTotal total
  | total == 0 = exitSuccess
  | otherwise = exitFailure

-- | Output NCT scan results for a root symbol and the symbols selected
-- with it, then (unless @--quiet@) a summary. Exits with failure if any
-- displayed finding remains.
outputNctSymbol :: RuntimeConfig -> Options
                -> SymbolScanResult -> IO ()
outputNctSymbol rt opts ssr = do
  let (isRt, groups) =
        visibleNctFindings rt opts (ssrLineMap ssr) (ssrFindings ssr)
      total = sum (map (length . snd) groups)
  printNctGroups opts isRt groups
  if optQuiet opts
    then pure ()
    else do
      TIO.putStrLn ""
      TIO.putStrLn $ "Root symbol: "
        <> ssrRootSymbol ssr
      TIO.putStrLn $ "Reachable symbols: "
        <> T.pack (show (ssrReachable ssr))
      TIO.putStrLn $ "With findings: "
        <> T.pack (show (length groups))
      TIO.putStrLn $ "NCT findings: "
        <> T.pack (show total)
  exitForNctTotal total

-- | Output whole-file NCT scan results, then (unless @--quiet@) a
-- summary. Exits with failure if any displayed finding remains.
outputNct :: RuntimeConfig -> Options -> LineMap
          -> Map.Map Text [NctFinding] -> IO ()
outputNct rt opts lineMap findings = do
  let (isRt, groups) = visibleNctFindings rt opts lineMap findings
      total = sum (map (length . snd) groups)
  printNctGroups opts isRt groups
  if optQuiet opts
    then pure ()
    else do
      TIO.putStrLn ""
      -- NOTE(review): this counts only functions that still have findings
      -- after runtime filtering, not every function scanned. Confirm
      -- whether the label or the count is intended.
      TIO.putStrLn $ "Functions scanned: "
        <> T.pack (show (length groups))
      TIO.putStrLn $ "NCT findings: "
        <> T.pack (show total)
  exitForNctTotal total

-- | Print a symbol's finding count, then each finding indented.
printNctSummary :: Bool -> (NctFinding -> Bool)
                -> (Text, [NctFinding]) -> IO ()
printNctSummary showRt isRt (sym, fs) = do
  TIO.putStrLn $ sym <> ": "
    <> T.pack (show (length fs))
  mapM_ (printFindingIndented showRt isRt) fs

-- | Print one finding indented. When runtime patterns are shown, matching
-- findings are tagged with "(runtime)".
printFindingIndented :: Bool -> (NctFinding -> Bool)
                     -> NctFinding -> IO ()
printFindingIndented showRt isRt f =
  let marker = if showRt && isRt f then " (runtime) " else " "
  in TIO.putStrLn (marker <> findingText f)

-- | Print each finding of a symbol on its own @symbol:line: …@ line.
printNctDetail :: Bool -> (NctFinding -> Bool)
               -> (Text, [NctFinding]) -> IO ()
printNctDetail showRt isRt (sym, fs) =
  mapM_ (printFinding showRt isRt sym) fs

-- | Print one finding prefixed with its symbol. When runtime patterns are
-- shown, matching findings are tagged with "(runtime)".
printFinding :: Bool -> (NctFinding -> Bool)
             -> Text -> NctFinding -> IO ()
printFinding showRt isRt sym f =
  let marker = if showRt && isRt f then "(runtime) " else ""
  in TIO.putStrLn (marker <> sym <> ":" <> findingText f)

-- | @<line>: <reason>: <instruction>@ for a single finding.
findingText :: NctFinding -> Text
findingText f =
  T.pack (show (nctLine f)) <> ": "
    <> nctReasonText (nctReason f) <> ": "
    <> instrText (nctInstr f)

-- | Short label for an NCT reason.
nctReasonText :: NctReason -> Text
nctReasonText r = case r of
  CondBranch -> "cond-branch"
  IndirectBranch -> "indirect-branch"
  Div -> "div"
  RegIndexAddr -> "reg-index"

-- | Render an instruction with its 'Show' instance.
instrText :: Instr -> Text
instrText instr = T.pack (show instr)

-- | Filter violations based on options.
-- By default only violations that 'isSecretViolation' accepts are kept;
-- with @--display-unknown@ the list is returned unchanged.
filterViolations :: Options -> [Violation]
                 -> [Violation]
filterViolations opts vs
  | optDisplayUnknown opts = vs
  | otherwise = [v | v <- vs, isSecretViolation (vReason v)]

-- | 'True' for reasons that are always reported: secret base or index
-- registers and non-constant offsets. Unknown base or index registers
-- give 'False', so they are shown only with @--display-unknown@.
isSecretViolation :: ViolationReason -> Bool
isSecretViolation (SecretBase _) = True
isSecretViolation (SecretIndex _) = True
isSecretViolation (UnknownBase _) = False
isSecretViolation (UnknownIndex _) = False
isSecretViolation NonConstOffset = True

-- | Print a violation as @symbol:line: reason@.
printViolation :: Violation -> IO ()
printViolation v =
  TIO.putStrLn $ T.concat
    [ vSymbol v
    , ":"
    , T.pack (show (vLine v))
    , ": "
    , reasonText (vReason v)
    ]

-- | Human-readable description of a violation reason.
reasonText :: ViolationReason -> Text
reasonText r = case r of
    SecretBase reg -> withReg "secret base register " reg
    SecretIndex reg -> withReg "secret index register " reg
    UnknownBase reg -> withReg "unknown base register " reg
    UnknownIndex reg -> withReg "unknown index register " reg
    NonConstOffset -> "non-constant offset without masking"
  where
    -- Description prefix followed by the register's name.
    withReg msgPrefix reg = msgPrefix <> regName reg