{-# LANGUAGE OverloadedStrings #-}

-- | The Oughta Lua API
module Oughta.LuaApi
  ( check
  ) where

import Control.Exception qualified as X
import Control.Monad.IO.Class (liftIO)
import Data.ByteString (ByteString)
import Data.ByteString qualified as BS
import Data.IORef (IORef)
import Data.IORef qualified as IORef
import Data.Text (Text)
import Data.Text qualified as Text
import Data.Text.Encoding qualified as Text
import Oughta.Exception (Exception)
import Oughta.Exception qualified as OE
import Oughta.Extract (LuaProgram, SourceMap, lookupSourceMap, programText, sourceMap, sourceMapFile)
import Oughta.Hooks qualified as OH
import Oughta.Lua qualified as OL
import Oughta.Pos qualified as OP
import Oughta.Result (Progress, Result)
import Oughta.Result qualified as OR
import Oughta.Traceback qualified as OT
import HsLua qualified as Lua

-- | Name of the @text@ Lua global variable. 'setText' writes it, and
-- 'withProgress' keeps it in sync with the remainder stored in the 'Progress'.
-- Not exported.
text :: Lua.Name
text = Lua.Name "text"

-- | Set the @text@ Lua global to the given bytes. Not exported.
setText :: ByteString -> Lua.LuaE Exception ()
setText txt = do
  Lua.pushstring txt
  Lua.setglobal text

-- | Run a state transition on the shared 'Progress'. Helper, not exported.
--
-- Reads the current 'Progress' from the 'IORef' and applies @f@ to it. It then
-- publishes the new remainder to the Lua @text@ global and finally stores the
-- new state. Both writes happen only after @f@ returns, so if @f@ throws,
-- neither the @text@ global nor the 'IORef' is modified.
withProgress ::
  IORef Progress ->
  (Progress -> Lua.LuaE Exception Progress) ->
  Lua.LuaE Exception ()
withProgress stateRef f = do
  p <- liftIO (IORef.readIORef stateRef)
  p' <- f p
  setText (OR.progressRemainder p')
  liftIO (IORef.writeIORef stateRef p')

-- | Implementation of the Lua @col_no@ function, which returns the column of
-- the current position in the checked text. Not exported.
col :: IORef Progress -> Lua.LuaE Exception Int
col stateRef = do
  p <- liftIO (IORef.readIORef stateRef)
  pure (OP.col (OP.pos (OR.progressLoc p)))

-- | Implementation of the Lua @fail@ function. Not exported.
--
-- Aborts the check by throwing an 'OR.Failure' that records the current
-- 'Progress' and the Lua traceback (mapped through the 'SourceMap'). Because
-- the inner action always throws, 'withProgress' never writes a new state.
fail_ :: SourceMap -> IORef Progress -> Lua.LuaE Exception ()
fail_ sm stateRef =
  withProgress stateRef $ \p -> do
    tb <- OT.getTraceback sm
    OE.throwNoMatch (OR.Failure p tb)

-- | Implementation of the Lua @file@ function, which returns the file name
-- recorded in the 'SourceMap'. Not exported.
file :: SourceMap -> Lua.LuaE Exception Text
file sm = pure (sourceMapFile sm)

-- | Implementation of the Lua @line@ function, which returns the line of the
-- current position in the checked text. Not exported.
line :: IORef Progress -> Lua.LuaE Exception Int
line stateRef = do
  p <- liftIO (IORef.readIORef stateRef)
  pure (OP.line (OP.pos (OR.progressLoc p)))

-- | Implementation of the Lua @match@ function. Not exported.
--
-- Consumes the next @n@ /bytes/ of the remaining text (via 'BS.splitAt') and
-- records them as an 'OR.Match'. The match carries its source span and the
-- current Lua traceback (mapped through the 'SourceMap').
match :: SourceMap -> IORef Progress -> Int -> Lua.LuaE Exception ()
match sm stateRef n =
  withProgress stateRef $ \p -> do
    tb <- OT.getTraceback sm
    let loc = OR.progressLoc p
        start = OP.pos loc
        (matched, remainder) = BS.splitAt n (OR.progressRemainder p)
        -- The span end is computed from the UTF-8-decoded matched bytes,
        -- while @n@ itself counts bytes.
        end = OP.incPos start (Text.decodeUtf8Lenient matched)
        m =
          OR.Match
            { OR.matchRemainder = remainder
            , OR.matchSpan = OP.Span (OP.path loc) start end
            , OR.matchText = matched
            , OR.matchTraceback = tb
            }
    pure (OR.updateProgress m p)

-- | Implementation of the Lua @reset@ function. Not exported.
--
-- Sets the @text@ global to @txt@. It then replaces the stored 'Progress'
-- with a fresh one built by 'OR.newProgress' from @name@ and @txt@, so any
-- previously stored progress is discarded.
reset :: IORef Progress -> String -> ByteString -> Lua.LuaE Exception ()
reset stateRef name txt = do
  setText txt
  liftIO (IORef.writeIORef stateRef (OR.newProgress name txt))

-- | Implementation of the Lua @seek@ function. Not exported.
--
-- Skips the next @chars@ entries of the remaining text without recording a
-- match, and advances the current position accordingly. Note that @chars@ is
-- passed to 'BS.splitAt', so it counts bytes, not decoded characters.
seek :: IORef Progress -> Int -> Lua.LuaE Exception ()
seek stateRef chars =
  withProgress stateRef $ \p -> do
    let loc = OR.progressLoc p
        (before, after) = BS.splitAt chars (OR.progressRemainder p)
        pos' = OP.incPos (OP.pos loc) (Text.decodeUtf8Lenient before)
        p' =
          p
            { OR.progressLoc = loc {OP.pos = pos'}
            , OR.progressRemainder = after
            }
    pure p'

-- | Implementation of the Lua @src_line@ function. Not exported.
--
-- Calls @debug.getinfo@ on the frame @level@ levels above the user code that
-- called @src_line@. It then maps that frame's chunk name and current line
-- back to a source line via 'lookupSourceMap'.
srcLine :: SourceMap -> Int -> Lua.LuaE Exception Int
srcLine sm level = do
  Lua.getglobal' "debug.getinfo"
  -- Empirically, there are 3 levels of functions on the Lua stack between this
  -- function and user Lua code.
  Lua.pushinteger (Lua.Integer (fromIntegral level + 3))
  Lua.pushstring "lnS"
  Lua.call 2 1

  _ty <- Lua.getfield Lua.top "currentline"
  l0 <- Lua.peek @Int Lua.top
  Lua.pop 1

  -- Read the full chunk name (@source@) rather than @short_src@. For string
  -- chunks, Lua renders @short_src@ as @[string "..."]@ and truncates it to
  -- fit LUA_IDSIZE (about 45 bytes of name), so long file paths would not
  -- round-trip to 'lookupSourceMap'. Chunks are loaded with the bare file
  -- name, so @source@ is exactly that name.
  _ty <- Lua.getfield Lua.top "source"
  src <- Lua.peek @Text Lua.top
  -- Pop the field and the table returned by debug.getinfo.
  Lua.pop 2

  pure (lookupSourceMap src l0 sm)

-- | Load user and Oughta Lua code. Helper, not exported.
--
-- Steps, in order:
--
-- 1. Opens the standard libraries and initializes the @text@ global.
-- 2. Registers the Haskell implementations of the API as Lua globals.
-- 3. Runs the bundled @oughta.lua@.
-- 4. Runs 'OH.preHook', the user program, then 'OH.postHook'.
--
-- A chunk that fails to load (e.g. a Lua syntax error) raises Lua's own error
-- message, rather than a confusing failure when the non-function is called.
luaSetup ::
  OH.Hooks ->
  IORef Progress ->
  -- | User code
  LuaProgram ->
  -- | Initial content of @text@ global
  ByteString ->
  Lua.LuaE Exception ()
luaSetup hooks stateRef prog txt = do
  Lua.openlibs
  setText txt

  let sm = sourceMap prog

  register (Lua.Name "col_no") (col stateRef)
  register (Lua.Name "fail") (fail_ sm stateRef)
  register (Lua.Name "file") (file sm)
  register (Lua.Name "line") (line stateRef)
  register (Lua.Name "match") (match sm stateRef)
  register (Lua.Name "reset") (reset stateRef)
  register (Lua.Name "seek") (seek stateRef)
  register (Lua.Name "src_line") (srcLine sm)

  loadChunk OL.luaCode (Lua.Name "oughta.lua")
  Lua.call 0 0

  Lua.changeErrorType (OH.preHook hooks)

  -- The chunk is named after the source file so that 'srcLine' can map
  -- stack frames back through the 'SourceMap'.
  let nm = Lua.Name (Text.encodeUtf8 (sourceMapFile sm))
  loadChunk (Text.encodeUtf8 (programText prog)) nm
  Lua.call 0 0

  Lua.changeErrorType (OH.postHook hooks)
  where
    -- Expose a Haskell action to Lua as the global named @globalName@.
    register globalName action = do
      Lua.pushHaskellFunction (Lua.toHaskellFunction action)
      Lua.setglobal globalName

    -- Load a chunk onto the stack. On failure, 'Lua.loadbuffer' pushes an
    -- error message instead of a function. Rethrow that message rather than
    -- letting the following 'Lua.call' fail with "attempt to call a string
    -- value".
    loadChunk code chunkName = do
      status <- Lua.loadbuffer code chunkName
      case status of
        Lua.OK -> pure ()
        _ -> Lua.throwErrorAsException

-- | Check some text against a Lua program using the API.
--
-- Errors surfacing as 'OE.Failure' become a failing 'Result'. Other
-- 'OE.LuaException's are rethrown as Haskell exceptions. If the program runs
-- to completion, the final 'Progress' is converted into a successful 'Result'.
check ::
  OH.Hooks ->
  LuaProgram ->
  -- | Text to check
  ByteString ->
  IO Result
check hooks prog txt = do
  stateRef <- IORef.newIORef (OR.newProgress "<out>" txt)
  result <- Lua.run (Lua.try (luaSetup hooks stateRef prog txt))
  case result of
    Left (OE.LuaException e) -> X.throwIO e
    Left (OE.Failure noMatch) -> OR.Result . Left <$> OE.noMatch noMatch
    Right () -> do
      state <- IORef.readIORef stateRef
      pure (OR.Result (Right (OR.progressToSuccess state)))