{-# LANGUAGE FlexibleContexts #-}
module Data.Macaw.Discovery.Classifier.PLT (
  pltStubClassifier
  ) where

import           Control.Lens ( (^.) )
import           Control.Monad ( when, unless )
import qualified Control.Monad.Reader as CMR
import qualified Data.Foldable as F
import           Data.Monoid ( Any(..) )
import           Data.Parameterized.Classes
import qualified Data.Parameterized.Map as MapF
import           Data.Parameterized.Some
import           Data.Parameterized.TraversableF
import           Data.Sequence (Seq)
import qualified Data.Sequence as Seq
import qualified Data.Set as Set

import qualified Data.Macaw.Architecture.Info as Info
import           Data.Macaw.CFG
import qualified Data.Macaw.Discovery.ParsedContents as Parsed

-- | @stripPLTRead readId next rest@ scans @next@ from its end looking
-- for the assignment whose identifier is @readId@.  When it is found,
-- the result is the statements of @next@ that precede that assignment
-- followed by the statements that were skipped over (accumulated in
-- @rest@), with later 'ArchState' register updates that mention the
-- removed assignment pruned.
--
-- The read may be followed by 'InstructionStart', 'ArchState' and
-- 'Comment' statements, and by 'EvalApp' / 'SetUndefined' assignments
-- that do not reference @readId@.  Any other statement after the read,
-- or reaching the start of @next@ without finding the read, yields
-- 'Nothing'.
stripPLTRead :: forall arch ids tp
               . ArchConstraints arch
              => AssignId ids tp -- ^ Identifier of the read to remove
              -> Seq (Stmt arch ids)
                 -- ^ Statements still to be searched (scanned from the end)
              -> Seq (Stmt arch ids)
                 -- ^ Statements already skipped over, in program order
              -> Maybe (Seq (Stmt arch ids))
stripPLTRead readId next rest =
  case Seq.viewr next of
    -- Ran out of statements without ever seeing the read.
    Seq.EmptyR -> Nothing
    prev Seq.:> lastStmt -> do
      -- Skip over @lastStmt@: keep it (at the front of @rest@) and keep
      -- searching the statements before it.
      let cont = stripPLTRead readId prev (lastStmt Seq.<| rest)
      case lastStmt of
        AssignStmt (Assignment stmtId rhs)
          -- Found the read: drop it and prune references to it from the
          -- statements that followed.
          | Just Refl <- testEquality readId stmtId ->
              Just (prev Seq.>< fmap (dropRefsTo stmtId) rest)
            -- Fail if the read to delete is used in later computations
          | Set.member (Some readId) (foldMapFC refsInValue rhs) ->
              Nothing
          | otherwise ->
            -- Only these two assignment forms may sit between the read
            -- and the end of the block.
            case rhs of
              EvalApp{} -> cont
              SetUndefined{} -> cont
              _ -> Nothing
        InstructionStart{} -> cont
        ArchState{} -> cont
        Comment{} -> cont
        _ -> Nothing
  where
    -- It is possible for later ArchState updates to reference the AssignId of
    -- the AssignStmt that is dropped, so make sure to prune such updates to
    -- avoid referencing the now out-of-scope AssignId.
    dropRefsTo :: AssignId ids tp -> Stmt arch ids -> Stmt arch ids
    dropRefsTo stmtId stmt =
      case stmt of
        ArchState addr updates ->
          ArchState addr $
          MapF.filter (\v -> Some stmtId `Set.notMember` refsInValue v) updates

        -- These Stmts don't contain any Values.
        InstructionStart{} -> stmt
        Comment{}          -> stmt

        -- stripPLTRead will bail out if it encounters any of these forms of
        -- Stmt, so we don't need to consider them.
        AssignStmt{}   -> stmt
        ExecArchStmt{} -> stmt
        CondWriteMem{} -> stmt
        WriteMem{}     -> stmt

-- | Restrict the final register state to the registers worth
-- recording: the instruction pointer ('ip_reg') is always dropped, and
-- so is any register whose final value is equal (per 'testEquality')
-- to its value in the initial register state.
removeUnassignedRegs :: forall arch ids
                     .  RegisterInfo (ArchReg arch)
                     => RegState (ArchReg arch) (Value arch ids)
                        -- ^ Initial register values
                     -> RegState (ArchReg arch) (Value arch ids)
                        -- ^ Final register values
                     -> MapF.MapF (ArchReg arch) (Value arch ids)
removeUnassignedRegs initRegs finalRegs =
  let keepReg :: forall tp . ArchReg arch tp -> Value arch ids tp -> Bool
      keepReg r finalVal
           -- The instruction pointer is never kept.
         | Just Refl <- testEquality r ip_reg = False
           -- Unchanged registers carry no information.
         | Just Refl <- testEquality initVal finalVal = False
         | otherwise = True
        where initVal = initRegs^.boundValue r
   in MapF.filterWithKey keepReg (regStateMap finalRegs)

-- | Return true if any value in the structure references the given
-- assignment identifier, i.e. the identifier is a member of
-- 'refsInValue' for at least one of the contained values.
containsAssignId :: forall t arch ids itp
                 .  FoldableF t
                 => AssignId ids itp
                    -- ^ Forbidden assignment -- may not appear in terms.
                 -> t (Value arch ids)
                    -- ^ Structure whose values are inspected
                 -> Bool
containsAssignId droppedAssign =
  let hasId :: forall tp . Value arch ids tp -> Any
      hasId v = Any (Set.member (Some droppedAssign) (refsInValue v))
      -- 'Any' is the disjunction monoid, so the fold is true as soon as
      -- one value mentions the identifier.
   in getAny . foldMapF hasId

-- | A classifier that attempts to recognize PLT stubs.
--
-- A block is accepted when its final instruction pointer is a memory
-- read from an address whose contents begin with a relocation that
-- names a symbol and satisfies the checks below.  On success the
-- block is given a 'Parsed.PLTStub' terminator; any failed pattern
-- match or explicit 'fail' rejects the block for this classifier.
pltStubClassifier :: Info.BlockClassifier arch ids
pltStubClassifier = Info.classifierName "PLT stub" $ do
  bcc <- CMR.ask
  let ctx = Info.classifierParseContext bcc
  let ainfo = Info.pctxArchInfo ctx
  let mem = Info.pctxMemory ctx
  Info.withArchConstraints ainfo $ do

    -- The IP should jump to an address in the .got, so try to compute that.
    AssignedValue (Assignment valId v) <- pure $ Info.classifierFinalRegState bcc ^. boundValue ip_reg
    ReadMem gotVal _repr <- pure $ v
    Just gotSegOff <- pure $ valueAsSegmentOff mem gotVal
    -- The .got contents should point to a relocation to the function
    -- that we will jump to.
    Right chunks <- pure $ segoffContentsAfter gotSegOff
    RelocationRegion r:_ <- pure $ chunks
    -- Check the relocation satisfies all the constraints we expect on PLT stub
    SymbolRelocation sym symVer <- pure $ relocationSym r
    unless (relocationOffset r == 0) $ fail "PLT stub requires 0 offset."
    when (relocationIsRel r) $ fail "PLT stub requires absolute relocation."
    when (toInteger (relocationSize r) /= toInteger (addrWidthReprByteCount (Info.archAddrWidth ainfo))) $ do
      fail $ "PLT stub relocations must match address size."
    -- Signed relocations are rejected, so the message must say so.
    when (relocationIsSigned r) $ do
      fail $ "PLT stub relocations must not be signed."
    when (relocationEndianness r /= Info.archEndianness ainfo) $ do
      fail $ "PLT relocation endianness must match architecture."
    unless (relocationJumpSlot r) $ do
      fail $ "PLT relocations must be jump slots."
    -- The PLTStub terminator will implicitly read the GOT address, so we remove
    -- it from the list of statements.
    Just strippedStmts <- pure $ stripPLTRead valId (Info.classifierStmts bcc) Seq.empty
    let strippedRegs = removeUnassignedRegs (Info.classifierInitRegState bcc) (Info.classifierFinalRegState bcc)
    -- Reject the block if any remaining register value still refers to
    -- the read that was just removed from the statements.
    when (containsAssignId valId strippedRegs) $ do
      fail $ "PLT IP must be assigned."
    pure $ Parsed.ParsedContents { Parsed.parsedNonterm = F.toList strippedStmts
                              , Parsed.parsedTerm  = Parsed.PLTStub strippedRegs gotSegOff (VerSym sym symVer)
                              , Parsed.writtenCodeAddrs = Info.classifierWrittenAddrs bcc
                              , Parsed.intraJumpTargets = []
                              , Parsed.newFunctionAddrs = []
                              }