{-|
Copyright        : (c) Galois, Inc 2015
Maintainer       : Simon Winwood <sjw@galois.com>

A strided interval domain x + [a .. b] * c
-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TypeOperators #-}

-- FIXME: take rounding/number of bits/etc. into consideration
-- FIXME: only really useful for unsigned?
module Data.Macaw.AbsDomain.StridedInterval
       ( StridedInterval(..)
         -- Constructors
       , singleton, mkStridedInterval, fromFoldable
         -- Predicates
       , isSingleton, isTop, member, isSubsetOf
         -- Destructors
       , toList, intervalEnd, size
         -- Domain operations
       , lub, lubSingleton, glb
         -- Operations
       , bvadd, bvadc, bvmul, trunc
         -- Debugging
       ) where


import           Control.Exception (assert)
import qualified Data.Foldable as Fold
import           Data.Parameterized.NatRepr
import           GHC.TypeLits (Nat)
import           Prettyprinter
import           Test.QuickCheck

-- import           Data.Macaw.DebugLogging

-- -----------------------------------------------------------------------------
-- Data type decl and instances

-- This is a canonical (and more compact) representation, basically we
-- turn x + [a .. b] * c into (x + a * c) + [0 .. b - a] * c

-- | The set @{ base + k * stride | 0 <= k <= range }@ of @w@-bit values.
-- A negative 'range' denotes the empty set.
data StridedInterval (w :: Nat) =
  StridedInterval
  { typ    :: !(NatRepr w)
    -- ^ Number of bits in the type.
  , base   :: !Integer
    -- ^ First element of the interval.
  , range  :: !Integer
    -- ^ Number of elements in the interval minus one (negative when empty).
  , stride :: !Integer
    -- ^ Distance between consecutive elements.
  }

-- | Equality compares 'base', 'range' and 'stride'; the 'typ' field is
-- not inspected.
instance Eq (StridedInterval tp) where
  StridedInterval _ b1 r1 s1 == StridedInterval _ b2 r2 s2 =
    b1 == b2 && r1 == r2 && s1 == s2

-- | Shows the interval by rendering its 'pretty' document.
instance Show (StridedInterval tp) where
  show si = show (pretty si)

-- | Last element of the interval, @base + range * stride@.
intervalEnd :: StridedInterval tp -> Integer
intervalEnd StridedInterval { base = b, range = r, stride = s } = b + r * s

-- | Number of elements in the interval, @range + 1@ (0 for 'empty').
size :: StridedInterval tp -> Integer
size = (+ 1) . range

-- -----------------------------------------------------------------------------
-- Constructors

-- | Construct a singleton value: the interval containing only @v@.
singleton :: NatRepr w -> Integer -> StridedInterval w
singleton tp v =
  -- fields: typ, base, range (0 = one element), stride
  StridedInterval tp v 0 1

-- | The interval containing no values, encoded with @range = -1@.
empty :: NatRepr w -> StridedInterval w
empty tp =
  -- fields: typ, base, range (negative = empty), stride
  StridedInterval tp 0 (-1) 1

-- | Make an interval given the start, end, and stride.
--
-- When @end - start@ is not a multiple of the stride, the number of steps
-- is rounded up if the 'Bool' argument is 'True' (so the interval may
-- extend past @end@), and rounded down otherwise (for a positive stride,
-- the interval then stops at or before @end@).
--
-- @end < start@ yields 'empty'.  A zero stride, or a range that rounds to
-- zero steps, yields @'singleton' start@.
mkStridedInterval :: NatRepr w -> Bool
                  -> Integer -> Integer -> Integer
                  -> StridedInterval w
mkStridedInterval tp roundUp start end s
  | end < start      = empty tp
  -- `||` short-circuits, so r (which divides by s) is not forced when s == 0.
  | s == 0 || r == 0 = singleton tp start
  | otherwise        =
      StridedInterval { typ = tp, base = start, range = r, stride = s }
  where
    dist = end - start
    -- Number of whole strides from start towards end.
    r = (if roundUp then dist + s - 1 else dist) `div` s

-- | Build an interval covering every value in the container.
--
-- The interval runs from the minimum to the maximum value.  Its stride is
-- the gcd of all offsets from the minimum.  An empty container yields
-- 'empty'.
fromFoldable :: Fold.Foldable t =>
                NatRepr n -> t Integer -> StridedInterval n
fromFoldable sz vs
  | Fold.null vs = empty sz
  | otherwise    = mkStridedInterval sz True lo hi step
  where
    lo   = Fold.minimum vs
    hi   = Fold.maximum vs
    -- gcd 0 v == v, so folding from 0 gives the gcd of all offsets; it
    -- stays 0 only when every value equals the minimum.
    step = Fold.foldl' (\acc v -> gcd acc (v - lo)) 0 vs


-- -----------------------------------------------------------------------------
-- Predicates

-- | True when the interval contains no values (its 'range' is negative).
isEmpty :: StridedInterval w -> Bool
isEmpty = (< 0) . range

-- | @Just v@ when the interval is exactly @{v}@ (range 0), 'Nothing'
-- otherwise, including for the empty interval.
isSingleton :: StridedInterval w -> Maybe Integer
isSingleton si
  | range si == 0 = Just (base si)
  | otherwise     = Nothing

-- | True when the interval equals 'top' for its own bit width.
isTop :: StridedInterval w -> Bool
isTop si = si == widest
  where
    widest = top (typ si)

-- | @member n si@ holds when @n = base si + k * stride si@ for some
-- @0 <= k <= range si@.  The empty interval has no members.
member :: Integer -> StridedInterval w -> Bool
member n si
  | isEmpty si = False
  | otherwise  =
      assert (step /= 0) $
        lo <= n && (n - lo) `mod` step == 0 && n <= lo + step * range si
  where
    lo   = base si
    step = stride si

-- | Is the set represented by @si1@ contained in @si2@?
--
-- For a non-singleton @si1@, both endpoints of @si1@ must lie in @si2@,
-- and @stride si2@ must not exceed @gcd (stride si1) (stride si2)@.  For
-- positive strides, the latter means @stride si2@ divides @stride si1@.
isSubsetOf :: StridedInterval w
       -> StridedInterval w
       -> Bool
isSubsetOf si1 si2
  | isEmpty si1               = True
  | Just v <- isSingleton si1 = member v si2
  | otherwise                 =
      member (base si1) si2
        && member (intervalEnd si1) si2
        && stride si2 <= gcd (stride si1) (stride si2)


-- -----------------------------------------------------------------------------
-- Domain operations

-- | Join: an interval containing every value of both arguments.
--
-- An empty argument is the identity.  A singleton argument is handled by
-- 'lubSingleton'.  Otherwise the result spans from the lowest base to the
-- highest end.  Its stride is the gcd of both strides and the distance
-- between the bases.
lub :: StridedInterval w
       -> StridedInterval w
       -> StridedInterval w
lub si1 si2
  | isEmpty si1 = si2
  | isEmpty si2 = si1
  -- FIXME: make more precise?
  | Just v <- isSingleton si1 = lubSingleton v si2
  | Just v <- isSingleton si2 = lubSingleton v si1
  | otherwise = mkStridedInterval (typ si1) True lo hi step
  where
    lo   = min (base si1) (base si2)
    hi   = max (intervalEnd si1) (intervalEnd si2)
    step = gcd (gcd (stride si1) (stride si2))
               (max (base si1) (base si2) - lo)

-- prop_lub :: StridedInterval (BVType 64)
--             -> StridedInterval (BVType 64)
--             -> Bool
-- prop_lub x y = x `isSubsetOf` (x `lub` y)
--                && y `isSubsetOf` (x `lub` y)

-- | @lubSingleton s si@ is an interval containing @s@ and every value of
-- @si@.
--
-- * If @s@ is already a member of @si@, @si@ is returned unchanged.
-- * If @si@ is empty, the result is @singleton s@; 'lub' likewise treats
--   the empty interval as an identity.
-- * If @si@ is a singleton @{s'}@, the result is @{min s s', max s s'}@.
-- * Otherwise the bounds are widened to include @s@.  The stride becomes
--   the gcd of the old stride and the distance from the new lower bound
--   to the point that must stay reachable (the old base, or @s@).
lubSingleton :: Integer
                -> StridedInterval w
                -> StridedInterval w
lubSingleton s si
  | member s si = si
  -- The widening cases below treat 'base' and 'intervalEnd' as real
  -- elements.  For an empty interval they are not, so handle it first.
  | isEmpty si  = singleton (typ si) s
  | Just s' <- isSingleton si =
      let l = min s s'
          u = max s s'
      in mkStridedInterval (typ si) True l u (u - l)
  | s < base si = go s si_upper (base si)
  | otherwise   = go (base si) (max s si_upper) s
  where
    si_upper = intervalEnd si
    -- Interval from lower to upper whose stride keeps to_contain reachable.
    go lower upper to_contain =
      mkStridedInterval (typ si) True lower upper
                        (gcd (stride si) (to_contain - lower))

-- | Greatest lower bound.  @glb si1 si2@ contains only those values
-- which are in @si1@ and @si2@.  If either argument is empty, the result
-- is 'empty'.
glb :: StridedInterval w
       -> StridedInterval w
       -> StridedInterval w
glb si1 si2
  -- Reject empty arguments explicitly.  Otherwise a negative range would
  -- reach 'solveLinearDiophantine' as a bound, and emptiness would only be
  -- detected indirectly.
  | isEmpty si1 || isEmpty si2 = empty (typ si1)
  | Just s' <- isSingleton si1 =
      if s' `member` si2 then si1 else empty (typ si1)
  | Just s' <- isSingleton si2 =
      if s' `member` si1 then si2 else empty (typ si1)
  | base si1 == base si2 =
      mkStridedInterval (typ si1) False (base si1) upper s
   -- lower is the least value that is greater than both bases, less
   -- than both ends, and in both intervals.  That is,
   --
   -- base1 + n * stride1 = base2 + m * stride2
   --
   -- or
   --
   --    n * stride1 - m * stride2 = base2 - base1
   --
   -- where n, m are integers s.t. the above holds, we want also that
   -- the n, m are in range1, range2, resp.
  | Just (n, _) <- solveLinearDiophantine (stride si1) (stride si2)
                                          (base si2 - base si1)
                                          (range si1) (range si2) =
      mkStridedInterval (typ si1) False (base si1 + n * stride si1) upper s
  | otherwise = empty (typ si1)
  where
    -- The intersection ends no later than either interval.
    upper = min (intervalEnd si1) (intervalEnd si2)
    -- Common elements repeat every lcm of the two strides.
    s     = lcm (stride si1) (stride si2)

-- | @solveLinearDiophantine a b c a_max b_max@ solves
--
-- > a * x - b * y = c        (NOTE the minus sign)
--
-- for the solution with the least @x@ such that @0 <= x <= a_max@ and
-- @0 <= y <= b_max@, or returns 'Nothing' if there is none.  Assumes
-- @a > 0@, @b > 0@ and @c /= 0@.
--
-- With @(g, n, m) = eGCD a (-b)@ we have @a * n - b * m = g@ (with
-- @g > 0@), so every solution has the form
--
-- > x = n * (c / g) + (b / g) * t
-- > y = m * (c / g) + (a / g) * t
--
-- for an integer @t@.  Both grow with @t@, so we want the least @t@ with
--
-- > ceiling (max (- n * c / b, - m * c / a)) <= t      (x >= 0, y >= 0)
--
-- provided that it also satisfies
--
-- > t <= floor (min ((a_max * g - n * c) / b, (b_max * g - m * c) / a))
solveLinearDiophantine :: Integer -> Integer -> Integer
                          -> Integer -> Integer
                          -> Maybe (Integer, Integer)
solveLinearDiophantine a b c a_max b_max
  | c `rem` g /= 0 = Nothing
  | t <= t_upper   = Just ( n * (c `quot` g) + (b `quot` g) * t
                          , m * (c `quot` g) + (a `quot` g) * t )
  | otherwise      = Nothing
  where
    (g, n, m) = eGCD a (-b)

    -- x >= 0 needs t >= -n*c/b, and y >= 0 needs t >= -m*c/a.  (The
    -- denominators were previously swapped, which could reject valid
    -- solutions or accept negative ones.)
    t       = max (ceil_quot (negate (n * c)) b)
                  (ceil_quot (negate (m * c)) a)
    -- x <= a_max and y <= b_max
    t_upper = min (floor_quot (a_max * g - n * c) b)
                  (floor_quot (b_max * g - m * c) a)

-- | Calculates @ceiling (x / y)@ for operands of any sign, using
-- @ceiling q == - floor (- q)@.  Adding 1 to a truncated quotient (the
-- previous approach) is wrong when the quotient is negative, because
-- truncation already rounds those up.  Division by zero raises the same
-- exception as 'div'.
ceil_quot :: Integral a => a -> a -> a
ceil_quot x y = negate (negate x `div` y)

-- | Calculates @floor (x / y)@; a zero divisor is reported with 'error'.
floor_quot :: Integral a => a -> a -> a
floor_quot _ 0 = error "floor_quot div by 0"
floor_quot x y = x `div` y

-- prop_sld :: Positive Integer -> Positive Integer
--             -> NonZero Integer -> Positive Integer -> Positive Integer
--             -> Property
-- prop_sld a b c d e = not (isNothing v) ==> p
--   where
--     p = case v of
--          Just (x, y) -> x >= 0 && y >= 0
--                         && x <= getPositive d
--                         && y <= getPositive e
--                         && (getPositive a) * x - (getPositive b) * y == (getNonZero c)
--          _ -> True
--     v = solveLinearDiophantine (getPositive a) (getPositive b) (getNonZero c)
--                                (getPositive d) (getPositive e)


-- | Returns the gcd, and n and m s.t. n * a + m * b = g.  The gcd is
-- normalised to be non-negative by negating all three results if needed.
-- clagged, fixed, from
--    http://en.wikibooks.org/wiki/Algorithm_Implementation/Mathematics/Extended_Euclidean_algorithm
-- this is presumably going to be slower than the gmp version :(
eGCD :: Integral a => a -> a -> (a,a,a)
eGCD a0 b0 = let (g, m, n) = go a0 b0
             in if g < 0 then (-g, -m, -n) else (g, m, n)
  where
    go a 0 = (a, 1, 0)
    go a b = let (g, x, y) = go b $ rem a b
             in (g, y, x - (a `quot` b) * y)

-- prop_eGCD :: Integer -> Integer -> Bool
-- prop_eGCD x y = let (g, a, b) = eGCD x y in x * a + y * b == g

-- -----------------------------------------------------------------------------
-- Operations

-- These operations probably only really make sense for constants or
-- constant ranges.  We can always just make the stride 1, but this
-- loses information.

-- We have x + [0..a] * b + y + [0 .. d] * e
-- = (x + y) + { i * b + j * e | 0 <= i <= a, 0 <= j <= d }
-- `subsetOf` let m = gcd(b, e)
--            in (x + y) + [0 .. (a * b / m) + (d * e / m) ] * m

-- | The interval @0 + [0 .. maxUnsigned sz] * 1@, i.e. every unsigned
-- value representable in @sz@ bits.
top :: NatRepr u -> StridedInterval u
top sz =
  -- fields: typ, base, range, stride
  StridedInterval sz 0 (maxUnsigned sz) 1

-- | Apply 'trunc' to an interval at the given width, with the width
-- argument first.
clamp :: NatRepr u -> StridedInterval u -> StridedInterval u
clamp = flip trunc

-- | Abstract addition of two intervals, reduced to @sz@ bits with 'clamp'.
--
-- If either operand is a single value, the other interval is shifted by
-- that value.  Otherwise the sum is over-approximated by an interval
-- whose base is the sum of the bases and whose stride is the gcd of the
-- two strides.
bvadd :: NatRepr u
      -> StridedInterval u
      -> StridedInterval u
      -> StridedInterval u
--bvadd _ EmptyInterval{} _ = EmptyInterval
--bvadd _ _ EmptyInterval{} = EmptyInterval
bvadd sz si1 si2 = clamp sz summed
  where
    summed
      | Just s <- isSingleton si1 = shiftBy s si2
      | Just s <- isSingleton si2 = shiftBy s si1
      | otherwise =
          StridedInterval { typ    = typ si1
                          , base   = base si1 + base si2
                          , range  = combinedRange
                          , stride = g
                          }
    -- Add a constant offset to every member of an interval.
    shiftBy off si = si { base = base si + off }
    -- Stride shared by both operands.
    g = gcd (stride si1) (stride si2)
    -- Number of g-sized steps spanned by one operand.
    stepsOf si = range si * (stride si `div` g)
    combinedRange
      | g == 0    = error "bvadd given 0 stride"
      | otherwise = stepsOf si1 + stepsOf si2

-- | Abstract add-with-carry of two intervals, reduced to @sz@ bits with
-- 'clamp'.  The carry is @Just True@ (add one), @Just False@ (add zero) or
-- @Nothing@ (unknown: zero or one).
--
-- With a known carry and a single-valued operand, the other interval is
-- shifted by that value plus the carry.  Otherwise the sum is
-- over-approximated.  The stride is the gcd of the operand strides, or 1
-- when the carry is unknown, because the two possible results differ by
-- one.  An unknown carry also extends the range by one element.
bvadc :: NatRepr u
      -> StridedInterval u
      -> StridedInterval u
      -> Maybe Bool
      -> StridedInterval u
bvadc sz si1 si2 (Just b)
  | Just s <- isSingleton si1 =
      clamp sz $ si2 { base = base si2 + s + carry }
  | Just s <- isSingleton si2 =
      clamp sz $ si1 { base = base si1 + s + carry }
  where
    carry = if b then 1 else 0
bvadc sz si1 si2 b =
  clamp sz $ StridedInterval { typ    = typ si1
                             , base   = base si1 + base si2 + knownCarry
                             , range  = r
                             , stride = m }
  where
    -- Only a definite carry contributes to the base; an unknown carry is
    -- accounted for by widening the range instead (see b_off).
    knownCarry :: Integer
    knownCarry = if b == Just True then 1 else 0

    m :: Integer
    m = case b of
          Nothing -> 1
          Just _  -> gcd (stride si1) (stride si2)

    -- The amount to increase the range by given b.
    -- We add 1 when b == Nothing due to the uncertainty
    b_off :: Integer
    b_off = case b of
              Nothing -> 1
              Just _  -> 0

    -- m can only be 0 when the carry is known and both strides are 0.
    -- The message names this function (it previously said "bvadd").
    r :: Integer
    r | m == 0    = error "bvadc given 0 stride"
      | otherwise = range si1 * (stride si1 `div` m)
                    + range si2 * (stride si2 `div` m)
                    + b_off

-- | Abstract multiplication of two intervals.
--
-- Write the operands as @b1 + [0 .. r1] * s1@ and @b2 + [0 .. r2] * s2@.
-- Each product @(b1 + i*s1) * (b2 + j*s2)@ then expands to
--
-- > b1*b2 + i*(s1*b2) + j*(s2*b1) + (i*j)*(s1*s2)
--
-- Each group of terms becomes its own interval, and the three intervals
-- are summed with 'bvadd'.  The @i*j@ factor is over-approximated by
-- @[0 .. r1*r2]@.
bvmul :: NatRepr u
      -> StridedInterval u
      -> StridedInterval u
      -> StridedInterval u
--bvmul _ EmptyInterval{} _ = EmptyInterval
--bvmul _ _ EmptyInterval{} = EmptyInterval
-- bvmul sz si1 si2 = top sz -- FIXME: this blows up with the trunc, unfortunately
bvmul sz si1 si2 = bvadd sz (bvadd sz firstTerm secondTerm) crossTerm
  where
    b1 = base si1
    r1 = range si1
    s1 = stride si1
    b2 = base si2
    r2 = range si2
    s2 = stride si2

    firstTerm  = interval (b1 * b2) r1 (s1 * b2)
    secondTerm = interval 0 r2 (s2 * b1)
    crossTerm  = interval 0 (r1 * r2) (s1 * s2)

    -- A zero stride means every element equals the base, so build a
    -- singleton in that case.
    interval lo n step
      | step == 0 = singleton (typ si1) lo
      | otherwise = StridedInterval { typ    = typ si1
                                    , base   = lo
                                    , range  = n
                                    , stride = step }

-- prop_bvmul ::  StridedInterval (BVType 64)
--             -> StridedInterval (BVType 64)
--             -> Bool
-- prop_bvmul = mk_prop (*) bvmul

-- filterLeq :: NatRepr w -> StridedInterval w -> Integer -> StridedInterval w
-- filterLeq tp@(BVTypeRepr _) si x = glb si (mkStridedInterval tp False 0 x 1)

-- filterGeq :: NatRepr w -> StridedInterval w -> Integer -> StridedInterval w
-- filterGeq tp@(BVTypeRepr _) si x = glb si (mkStridedInterval tp False x u 1)
--   where
--     u = case tp of BVTypeRepr n -> maxUnsigned n

-- | Returns the least b' s.t. exists i < n. b' = (b + i * q) mod m
-- This is a little tricky.  We want to find { b + i * q | i <- {0 .. n} } mod M
-- Now, if we know the values in {0..q} we can figure out the whole sequence.
-- In particular, we are looking for the new b --- the stride is gcd q M (assuming wrap)
--
-- Let q_w be k * q mod M s.t. 0 <= q_w < q.  Alternatively, q_w = q - (M mod q)
--
-- Assume wlog b0 = b < q.  Take
-- b1 = b0 + q_w
-- b2 = b1 + q_w
-- b3 = ...
--
-- i.e. bs = { b + i * q_w | i <- {0 .. n `div` ceil(m / q)} }
--
-- Now, we are interested in these value mod q, so take
-- bs = { b0 + i * q_w } mod q
--
-- Thus, we get a recursion of (M, q) (q, q_w) (q_w, ...)
--
-- = (M, q) (q, q - (M mod q)) (q - (M mod q), (q - (M mod q)) - (q mod (q - (M mod q))))
--
-- Base cases:
--  Firstly, when q divides M, we can stop, or n == 0
--  Otherwise, for small M we can search for some i, that is
--  the least 0 <= k < M s.t.
--
--  EX i. (b + i * q) mod M = k
--

-- Assumes b < q (?)
-- currently broken :(
-- leastMod :: Integer -> Integer -> Integer -> Integer -> Integer
-- leastMod _ 0 _ _ = error "leastMod given m = 0"
-- leastMod _ _ 0 _ = error "leastMod given q = 0"
-- leastMod b m q n
--   | b + n * q < m  = b -- no wrap
--   | m `mod` q == 0 = b -- assumes q <= m
--   | otherwise =
--       debug DAbsInt (show ((b, m, q, n), (next_b, m', q', next_n, next_n `div` m_div_q))) $
--       leastMod next_b m' q'
--                 -- FIXME: we sometimes miss a +1 here, we do this to
--                 -- be conservative (overapprox.)
--       (next_n `div` m_div_q)
--   where
--     m_div_q | q == 0 = error "leastMod given q == 0"
--             | r == 0 = error "leastMod given m `div` q == 0"
--             | otherwise = r
--       where r = m `div` q
--     m' = q
--     q' | q == 0 = error "leastMod given q == 0"
--        | otherwise = q - m `mod` q
--     (next_b, next_n)
--       | b < q'    = (b, n)
--       | otherwise = let i = m_div_q + 1
--                     in (((b + i * q) `mod` m) `mod` q, n - i)

-- | Truncate an interval to @sz@ bits.  The result over-approximates the
-- set of its members reduced modulo @2 ^ sz@.
--
-- OPT: this could be made much more efficient I think.
trunc :: StridedInterval u
      -> NatRepr v
      -> StridedInterval v
trunc si sz
  | isTop si = topV
  -- No change/complete wrap case --- happens when we add
  -- (unsigned int) -1, for example.
  | reduced `isSubsetOf` topV = reduced
  -- When the stride divides 2 ^ sz, wrapping preserves each member's
  -- residue modulo the stride.  So we pick the smallest non-negative
  -- value with that residue as the new base, and over-approximate by
  -- every value in [0, 2 ^ sz) with that residue.
  | modulus `mod` step == 0 =
      reduced { base  = wrappedBase
              , range = modulus `ceilDiv` step - 1 }
  -- We wrap at least once
  | otherwise = topV
  where
    step = stride si
    modulus :: Integer
    modulus = 2 ^ natValue sz
    topV = top sz
    -- The interval retyped to the new width, with only its base reduced.
    reduced = si { typ = typ topV, base = toUnsigned sz (base si) }
    lowBase = base reduced
    wrappedBase =
      (lowBase + step * ((modulus - lowBase) `ceilDiv` step)) `mod` modulus
    -- positive only
    ceilDiv :: Integer -> Integer -> Integer
    ceilDiv _ 0 = error "SI.trunc given 0 stride."
    ceilDiv x y = (x + y - 1) `div` y

-- -----------------------------------------------------------------------------
-- Testing

-- | Enumerate every member of the interval: @base + stride * i@ for each
-- step index @i@ from 0 to @range@.
toList :: StridedInterval w -> [Integer]
toList StridedInterval { base = b, range = r, stride = s } =
  [ b + s * i | i <- [0 .. r] ]

-- | The rendering depends on the interval's shape:
--
-- * @[]@ when it is empty;
-- * @[v]@ when it holds a single value;
-- * @[first, second .. last]@ otherwise, in Haskell enumeration style.
instance Pretty (StridedInterval w) where
  pretty si
    | isEmpty si               = "[]"
    | Just v <- isSingleton si = brackets (pretty v)
    | otherwise                =
        brackets (pretty lo <> comma <+> pretty next <+> ".." <+> pretty hi)
    where
      lo   = base si
      next = lo + stride si
      hi   = lo + range si * stride si

-- | QuickCheck generator for 64-bit intervals.  With weight 1 in 10 it
-- yields the empty interval.  Otherwise the bounds and the stride are
-- drawn from the current size parameter.
instance Arbitrary (StridedInterval 64) where
  arbitrary = frequency [ (1, return (empty knownNat))
                        , (9, si) ]
    where
      si = sized $ \n -> do
        let hi = toInteger n
        lower <- choose (0, hi)
        upper <- choose (lower, hi)
        -- The stride must be positive.  At size 0 the range (1, 0) would be
        -- inverted and could yield a zero stride, so keep the upper bound
        -- at least 1.
        s     <- choose (1, max 1 hi)
        return $ mkStridedInterval knownNat True lower upper s