auditor

An AArch64 constant-time memory-access auditing tool.
git clone git://git.ppad.tech/auditor.git
Log | Files | Refs | README | LICENSE

AArch64.hs (8428B)


      1 {-# OPTIONS_HADDOCK prune #-}
      2 {-# LANGUAGE OverloadedStrings #-}
      3 
      4 -- |
      5 -- Module: Audit.AArch64
      6 -- Copyright: (c) 2025 Jared Tobin
      7 -- License: MIT
      8 -- Maintainer: jared@ppad.tech
      9 --
     10 -- AArch64 constant-time memory access auditor.
     11 --
     12 -- This module provides static analysis for AArch64 assembly
     13 -- to verify that memory accesses use only public
     14 -- (non-secret-derived) addresses. This helps ensure
     15 -- constant-time properties for cryptographic code.
     16 --
     17 -- Example usage:
     18 --
     19 -- @
     20 -- import Audit.AArch64
     21 -- import qualified Data.Text.IO as TIO
     22 --
     23 -- main = do
     24 --   src <- TIO.readFile "foo.s"
     25 --   case audit ghcRuntime "foo.s" src of
     26 --     Left err -> putStrLn $ "Parse error: " ++ showParseError err
     27 --     Right result -> print result
     28 -- @
     29 
     30 module Audit.AArch64 (
     31     -- * Runtime configuration
     32     RuntimeConfig(..)
     33   , SecondaryStack(..)
     34   , ghcRuntime
     35   , genericRuntime
     36 
     37     -- * Main API
     38   , audit
     39   , auditInterProc
     40   , auditFile
     41   , auditFileInterProc
     42   , auditWithConfig
     43   , auditInterProcWithConfig
     44   , auditFileWithConfig
     45   , auditFileInterProcWithConfig
     46   , parseFile
     47 
     48     -- * NCT scanner
     49   , scanNct
     50   , scanNctFile
     51   , scanNctFileForSymbol
     52   , SymbolScanResult(..)
     53   , NctReason(..)
     54   , NctFinding(..)
     55     -- ** Runtime-aware filtering
     56   , LineMap
     57   , buildLineMap
     58   , filterRuntimePatterns
     59     -- ** GHC runtime classification (re-export)
     60   , isGhcRuntimeFinding
     61     -- ** Call graph
     62   , CallGraph
     63   , buildCallGraph
     64   , allSymbols
     65   , reachableSymbols
     66   , reachingSymbols
     67   , symbolExists
     68 
     69     -- * Results
     70   , AuditResult(..)
     71   , Violation(..)
     72   , ViolationReason(..)
     73   , Reg
     74   , regName
     75 
     76     -- * Taint configuration
     77   , TaintConfig(..)
     78   , ArgPolicy(..)
     79   , emptyArgPolicy
     80   , loadTaintConfig
     81 
     82     -- * Re-exports
     83   , ParseError
     84   , showParseError
     85   ) where
     86 
     87 import Audit.AArch64.CFG hiding (buildCallGraph)
     88 import Audit.AArch64.CallGraph
     89   ( CallGraph
     90   , buildCallGraph
     91   , allSymbols
     92   , reachableSymbols
     93   , reachingSymbols
     94   , symbolExists
     95   )
     96 import qualified Audit.AArch64.CallGraph as CG
     97   (buildCallGraph)
     98 import Audit.AArch64.Check
     99 import Audit.AArch64.NCT
    100   ( scanNct
    101   , buildLineMap, filterRuntimePatterns
    102   )
    103 import Audit.AArch64.Parser
    104 import Audit.AArch64.Runtime
    105   (RuntimeConfig(..), SecondaryStack(..))
    106 import Audit.AArch64.Runtime.GHC
    107   (ghcRuntime, genericRuntime, isGhcRuntimeFinding)
    108 import Audit.AArch64.Types
    109   ( Reg, Violation(..), ViolationReason(..), regName
    110   , TaintConfig(..), ArgPolicy(..), emptyArgPolicy
    111   , NctReason(..), NctFinding(..), LineMap
    112   )
    113 import Data.Aeson (eitherDecodeStrict')
    114 import qualified Data.ByteString as BS
    115 import qualified Data.Map.Strict as Map
    116 import qualified Data.Set as Set
    117 import Data.Text (Text)
    118 import qualified Data.Text as T
    119 import Data.Text.Encoding (decodeUtf8')
    120 
    121 -- | Audit assembly source for memory access violations.
    122 audit :: RuntimeConfig -> Text -> Text
    123       -> Either ParseError AuditResult
    124 audit rt name src = do
    125   lns <- parseAsm src
    126   let cfg = buildCFG rt lns
    127   pure (checkCFG rt name cfg)
    128 
    129 -- | Audit with inter-procedural analysis.
    130 auditInterProc :: RuntimeConfig -> Text -> Text
    131                -> Either ParseError AuditResult
    132 auditInterProc rt name src = do
    133   lns <- parseAsm src
    134   let cfg = buildCFG rt lns
    135   pure (checkCFGInterProc rt name cfg)
    136 
    137 -- | Audit an assembly file.
    138 auditFile :: RuntimeConfig -> FilePath
    139           -> IO (Either Text AuditResult)
    140 auditFile rt = auditFileWith (audit rt)
    141 
    142 -- | Audit an assembly file with inter-procedural analysis.
    143 auditFileInterProc :: RuntimeConfig -> FilePath
    144                    -> IO (Either Text AuditResult)
    145 auditFileInterProc rt = auditFileWith (auditInterProc rt)
    146 
    147 -- | Helper for file auditing.
    148 auditFileWith
    149   :: (Text -> Text -> Either ParseError AuditResult)
    150   -> FilePath -> IO (Either Text AuditResult)
    151 auditFileWith auditor path = do
    152   bs <- BS.readFile path
    153   case decodeUtf8' bs of
    154     Left err -> pure (Left (T.pack (show err)))
    155     Right src ->
    156       case auditor (T.pack path) src of
    157         Left err ->
    158           pure (Left (T.pack (showParseError err)))
    159         Right result -> pure (Right result)
    160 
    161 -- | Parse an assembly file without analysis.
    162 -- Returns line count on success.
    163 parseFile :: FilePath -> IO (Either Text Int)
    164 parseFile path = do
    165   bs <- BS.readFile path
    166   case decodeUtf8' bs of
    167     Left err -> pure (Left (T.pack (show err)))
    168     Right src ->
    169       case parseAsm src of
    170         Left err ->
    171           pure (Left (T.pack (showParseError err)))
    172         Right lns -> pure (Right (length lns))
    173 
    174 -- | Audit assembly source with taint config.
    175 auditWithConfig :: RuntimeConfig -> TaintConfig
    176                 -> Text -> Text
    177                 -> Either ParseError AuditResult
    178 auditWithConfig rt tcfg name src = do
    179   lns <- parseAsm src
    180   let cfg = buildCFG rt lns
    181   pure (checkCFGWithConfig rt tcfg name cfg)
    182 
    183 -- | Audit with inter-procedural analysis and taint config.
    184 auditInterProcWithConfig
    185   :: RuntimeConfig -> TaintConfig -> Text -> Text
    186   -> Either ParseError AuditResult
    187 auditInterProcWithConfig rt tcfg name src = do
    188   lns <- parseAsm src
    189   let cfg = buildCFG rt lns
    190   pure (checkCFGInterProcWithConfig rt tcfg name cfg)
    191 
    192 -- | Audit an assembly file with taint config.
    193 auditFileWithConfig :: RuntimeConfig -> TaintConfig
    194                     -> FilePath
    195                     -> IO (Either Text AuditResult)
    196 auditFileWithConfig rt tcfg =
    197   auditFileWith (auditWithConfig rt tcfg)
    198 
    199 -- | Audit an assembly file with inter-procedural analysis
    200 -- and taint config.
    201 auditFileInterProcWithConfig
    202   :: RuntimeConfig -> TaintConfig -> FilePath
    203   -> IO (Either Text AuditResult)
    204 auditFileInterProcWithConfig rt tcfg =
    205   auditFileWith (auditInterProcWithConfig rt tcfg)
    206 
    207 -- | Load a taint config from a JSON file.
    208 loadTaintConfig :: FilePath -> IO (Either Text TaintConfig)
    209 loadTaintConfig path = do
    210   bs <- BS.readFile path
    211   case eitherDecodeStrict' bs of
    212     Left err -> pure (Left (T.pack err))
    213     Right cfg -> pure (Right cfg)
    214 
    215 -- | Scan an assembly file for non-constant-time
    216 -- instructions. Returns a LineMap (for runtime
    217 -- classification) and the findings.
    218 scanNctFile
    219   :: RuntimeConfig -> FilePath
    220   -> IO (Either Text
    221            (LineMap, Map.Map Text [NctFinding]))
    222 scanNctFile rt path = do
    223   bs <- BS.readFile path
    224   case decodeUtf8' bs of
    225     Left err -> pure (Left (T.pack (show err)))
    226     Right src ->
    227       case parseAsm src of
    228         Left err ->
    229           pure (Left (T.pack (showParseError err)))
    230         Right lns ->
    231           let lineMap = buildLineMap lns
    232               findings = scanNct rt lns
    233           in  pure (Right (lineMap, findings))
    234 
    235 -- | Result of symbol-focused NCT scan.
    236 data SymbolScanResult = SymbolScanResult
    237   { ssrRootSymbol   :: !Text
    238   -- ^ The requested root symbol
    239   , ssrReachable    :: !Int
    240   -- ^ Number of reachable symbols (including root)
    241   , ssrLineMap      :: !LineMap
    242   -- ^ Line map for runtime classification
    243   , ssrFindings     :: !(Map.Map Text [NctFinding])
    244   -- ^ Findings filtered to reachable symbols only
    245   }
    246 
    247 -- | Scan an assembly file for NCT instructions, focused
    248 -- on a specific symbol.
    249 --
    250 -- Uses call graph analysis to find all symbols reachable
    251 -- from the given root symbol and returns findings only for
    252 -- those symbols.
    253 --
    254 -- Returns 'Left' if parsing fails or the symbol doesn't
    255 -- exist.
    256 scanNctFileForSymbol
    257   :: RuntimeConfig
    258   -> Text      -- ^ Root symbol to analyze
    259   -> FilePath  -- ^ Assembly file path
    260   -> IO (Either Text SymbolScanResult)
    261 scanNctFileForSymbol rt rootSym path = do
    262   bs <- BS.readFile path
    263   case decodeUtf8' bs of
    264     Left err -> pure (Left (T.pack (show err)))
    265     Right src ->
    266       case parseAsm src of
    267         Left err ->
    268           pure (Left (T.pack (showParseError err)))
    269         Right lns -> do
    270           let callGraph = CG.buildCallGraph rt lns
    271           if not (symbolExists rootSym callGraph)
    272             then pure (Left
    273               ("symbol not found: " <> rootSym))
    274             else do
    275               let reachable =
    276                     reachableSymbols rootSym callGraph
    277                   lineMap = buildLineMap lns
    278                   allFindings = scanNct rt lns
    279                   filtered = Map.filterWithKey
    280                     (\sym _ -> Set.member sym reachable)
    281                     allFindings
    282               pure $ Right $ SymbolScanResult
    283                 { ssrRootSymbol = rootSym
    284                 , ssrReachable  = Set.size reachable
    285                 , ssrLineMap    = lineMap
    286                 , ssrFindings   = filtered
    287                 }