{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE StandaloneKindSignatures #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
module AST.Types where

import Data.Int (Int32, Int64)
import Data.GADT.Show
import Data.Kind (Type)
import Data.Some
import Data.Type.Equality

import Data


-- | Value-level description of the language's types. With DataKinds its
-- constructors are also promoted, and the promoted forms index 'STy'.
data Ty
  = -- | The empty / unit type (also the empty tuple, see 'Tup').
    TNil
  | -- | Pair of two types.
    TPair Ty Ty
  | -- | Either of two types.
    TEither Ty Ty
  | -- | Maybe of a type.
    TMaybe Ty
  | -- | Array type: rank, then element type.
    TArr Nat Ty
  | -- | A scalar type.
    TScal ScalTy
  | -- | Accumulator type; the accumulator contains D2 of this type.
    TAccum Ty
  deriving (Show, Eq, Ord)

-- | Scalar types: 32/64-bit integers, 32/64-bit floats and booleans. Their
-- Haskell representations are given by 'ScalRep'.
data ScalTy
  = TI32
  | TI64
  | TF32
  | TF64
  | TBool
  deriving (Show, Eq, Ord)

-- | Singleton for 'Ty': each constructor of @'STy' t@ mirrors, one for one,
-- the promoted 'Ty' constructor at the head of @t@.
type STy :: Ty -> Type
data STy t where
  STNil :: STy 'TNil
  STPair :: STy a -> STy b -> STy ('TPair a b)
  STEither :: STy a -> STy b -> STy ('TEither a b)
  STMaybe :: STy a -> STy ('TMaybe a)
  STArr :: SNat n -> STy t -> STy ('TArr n t)
  STScal :: SScalTy t -> STy ('TScal t)
  STAccum :: STy t -> STy ('TAccum t)
deriving instance Show (STy t)

-- | Decide whether two type singletons describe the same 'Ty'. On success it
-- returns a proof of type equality, built by recursively comparing the
-- components.
instance TestEquality STy where
  testEquality STNil STNil = Just Refl
  testEquality (STPair a b) (STPair a' b') = do
    Refl <- testEquality a a'
    Refl <- testEquality b b'
    pure Refl
  testEquality (STEither a b) (STEither a' b') = do
    Refl <- testEquality a a'
    Refl <- testEquality b b'
    pure Refl
  testEquality (STMaybe a) (STMaybe a') = do
    Refl <- testEquality a a'
    pure Refl
  testEquality (STArr n t) (STArr n' t') = do
    Refl <- testEquality n n'
    Refl <- testEquality t t'
    pure Refl
  testEquality (STScal s) (STScal s') = do
    Refl <- testEquality s s'
    pure Refl
  testEquality (STAccum a) (STAccum a') = do
    Refl <- testEquality a a'
    pure Refl
  -- Mismatched head constructors. There is one clause per constructor rather
  -- than a single catch-all, so the match stays checkable for exhaustiveness.
  testEquality STNil _ = Nothing
  testEquality STPair{} _ = Nothing
  testEquality STEither{} _ = Nothing
  testEquality STMaybe{} _ = Nothing
  testEquality STArr{} _ = Nothing
  testEquality STScal{} _ = Nothing
  testEquality STAccum{} _ = Nothing

-- | Show a type singleton at any index by reusing its derived 'Show' instance.
instance GShow STy where
  gshowsPrec = showsPrec

-- | Singleton for 'ScalTy': one constructor per scalar type, indexed by the
-- corresponding promoted 'ScalTy' constructor.
--
-- The standalone kind signature matches the one given for 'STy'.
type SScalTy :: ScalTy -> Type
data SScalTy t where
  STI32 :: SScalTy 'TI32
  STI64 :: SScalTy 'TI64
  STF32 :: SScalTy 'TF32
  STF64 :: SScalTy 'TF64
  STBool :: SScalTy 'TBool
deriving instance Show (SScalTy t)

-- | Decide whether two scalar-type singletons are the same scalar type,
-- returning a proof of type equality when they are.
--
-- Each constructor has its own mismatch clause instead of one @_ _@
-- catch-all. With @-Wincomplete-patterns@, adding a constructor to 'SScalTy'
-- then produces a warning here, rather than the new constructor silently
-- comparing unequal to itself. This mirrors the 'STy' instance.
instance TestEquality SScalTy where
  testEquality STI32 STI32 = Just Refl
  testEquality STI32 _ = Nothing
  testEquality STI64 STI64 = Just Refl
  testEquality STI64 _ = Nothing
  testEquality STF32 STF32 = Just Refl
  testEquality STF32 _ = Nothing
  testEquality STF64 STF64 = Just Refl
  testEquality STF64 _ = Nothing
  testEquality STBool STBool = Just Refl
  testEquality STBool _ = Nothing

-- | Show a scalar-type singleton at any index via its derived 'Show' instance.
instance GShow SScalTy where
  gshowsPrec = showsPrec

-- | Evidence, as a 'Dict', that the Haskell representation ('ScalRep') of any
-- scalar type has a 'Show' instance. Matching on the singleton fixes the
-- concrete representation type, so each branch can supply its instance.
scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t))
scalRepIsShow = \case
  STI32 -> Dict
  STI64 -> Dict
  STF32 -> Dict
  STF64 -> Dict
  STBool -> Dict

-- | Name for the 64-bit integer scalar type.
type TIx = 'TScal 'TI64

-- | The singleton for 'TIx'.
tIx :: STy TIx
tIx = STScal STI64

-- | Forget the type index: convert a singleton into the value-level 'Ty' it
-- describes, constructor by constructor. Array ranks go through 'unSNat'.
unSTy :: STy t -> Ty
unSTy STNil = TNil
unSTy (STPair l r) = TPair (unSTy l) (unSTy r)
unSTy (STEither l r) = TEither (unSTy l) (unSTy r)
unSTy (STMaybe inner) = TMaybe (unSTy inner)
unSTy (STArr rank elt) = TArr (unSNat rank) (unSTy elt)
unSTy (STScal s) = TScal (unSScalTy s)
unSTy (STAccum inner) = TAccum (unSTy inner)

-- | Convert a list of type singletons into a plain list of 'Ty', applying
-- 'unSTy' to each element and preserving order.
unSEnv :: SList STy env -> [Ty]
unSEnv = \case
  SNil -> []
  SCons ty rest -> unSTy ty : unSEnv rest

-- | Forget the index of a scalar-type singleton, giving the 'ScalTy' value it
-- stands for.
unSScalTy :: SScalTy t -> ScalTy
unSScalTy STI32 = TI32
unSScalTy STI64 = TI64
unSScalTy STF32 = TF32
unSScalTy STF64 = TF64
unSScalTy STBool = TBool

-- | Reflect a value-level 'Ty' as a singleton whose type index is hidden in
-- 'Some'. This is the constructor-by-constructor counterpart of 'unSTy';
-- array ranks are reflected with 'reSNat'.
reSTy :: Ty -> Some STy
reSTy TNil = Some STNil
reSTy (TPair l r) = case (reSTy l, reSTy r) of
  (Some sl, Some sr) -> Some (STPair sl sr)
reSTy (TEither l r) = case (reSTy l, reSTy r) of
  (Some sl, Some sr) -> Some (STEither sl sr)
reSTy (TMaybe inner) = case reSTy inner of
  Some sInner -> Some (STMaybe sInner)
reSTy (TArr rank elt) = case (reSNat rank, reSTy elt) of
  (Some sRank, Some sElt) -> Some (STArr sRank sElt)
reSTy (TScal s) = case reSScalTy s of
  Some sS -> Some (STScal sS)
reSTy (TAccum inner) = case reSTy inner of
  Some sInner -> Some (STAccum sInner)

-- | Reflect a list of value-level types as a singleton list with a hidden
-- type-level index, element by element via 'reSTy', preserving order.
reSEnv :: [Ty] -> Some (SList STy)
reSEnv = \case
  [] -> Some SNil
  ty : rest -> case (reSTy ty, reSEnv rest) of
    (Some sTy, Some sRest) -> Some (SCons sTy sRest)

-- | Reflect a value-level 'ScalTy' as its singleton, hiding the index in
-- 'Some'. Counterpart of 'unSScalTy'.
reSScalTy :: ScalTy -> Some SScalTy
reSScalTy TI32 = Some STI32
reSScalTy TI64 = Some STI64
reSScalTy TF32 = Some STF32
reSScalTy TF64 = Some STF64
reSScalTy TBool = Some STBool

-- | The Haskell type used to represent values of each scalar type.
--
-- The standalone kind signature states the kind explicitly, as is done for
-- 'STy', instead of leaving it to inference from the equations.
type ScalRep :: ScalTy -> Type
type family ScalRep t where
  ScalRep 'TI32 = Int32
  ScalRep 'TI64 = Int64
  ScalRep 'TF32 = Float
  ScalRep 'TF64 = Double
  ScalRep 'TBool = Bool

-- | Type-level predicate: 'True for the integer and floating-point scalar
-- types, 'False for 'TBool'.
type ScalIsNumeric :: ScalTy -> Bool
type family ScalIsNumeric t where
  ScalIsNumeric 'TI32 = 'True
  ScalIsNumeric 'TI64 = 'True
  ScalIsNumeric 'TF32 = 'True
  ScalIsNumeric 'TF64 = 'True
  ScalIsNumeric 'TBool = 'False

-- | Type-level predicate: 'True exactly for the floating-point scalar types
-- 'TF32' and 'TF64'.
type ScalIsFloating :: ScalTy -> Bool
type family ScalIsFloating t where
  ScalIsFloating 'TI32 = 'False
  ScalIsFloating 'TI64 = 'False
  ScalIsFloating 'TF32 = 'True
  ScalIsFloating 'TF64 = 'True
  ScalIsFloating 'TBool = 'False

-- | Type-level predicate: 'True exactly for the integer scalar types 'TI32'
-- and 'TI64'.
type ScalIsIntegral :: ScalTy -> Bool
type family ScalIsIntegral t where
  ScalIsIntegral 'TI32 = 'True
  ScalIsIntegral 'TI64 = 'True
  ScalIsIntegral 'TF32 = 'False
  ScalIsIntegral 'TF64 = 'False
  ScalIsIntegral 'TBool = 'False

-- | Whether the given type contains an array /or/ an accumulator anywhere
-- inside it, looking through pairs, 'STEither' and 'STMaybe'. 'STNil' and
-- scalars contain neither.
hasArrays :: STy t -> Bool
hasArrays = \case
  STArr{} -> True
  STAccum{} -> True
  STNil -> False
  STScal{} -> False
  STMaybe inner -> hasArrays inner
  STPair l r -> hasArrays l || hasArrays r
  STEither l r -> hasArrays l || hasArrays r

-- | Pack a type-level list into nested 'TPair's terminated by 'TNil'. The
-- head of the list becomes the outermost right component, e.g.
--
-- > Tup '[a, b, c] = 'TPair ('TPair ('TPair 'TNil c) b) a
type Tup :: [Ty] -> Ty
type family Tup env where
  Tup '[] = 'TNil
  Tup (t ': ts) = 'TPair (Tup ts) t

-- | Build a value indexed by @'Tup' list@ from a list of elements. @nil@
-- supplies the innermost value. At each step, @pair@ combines the tuple built
-- from the remaining elements with the current head element.
mkTup
  :: f 'TNil                                     -- ^ innermost (empty) tuple
  -> (forall a b. f a -> f b -> f ('TPair a b))  -- ^ pairing function
  -> SList f list                                -- ^ elements, head outermost
  -> f (Tup list)
mkTup nil pair = \case
  SNil -> nil
  SCons hd rest -> pair (mkTup nil pair rest) hd

-- | The singleton for @'Tup' env@, built with 'mkTup' from the element
-- singletons, using 'STNil' and 'STPair'.
tTup :: SList STy env -> STy (Tup env)
tTup env = mkTup STNil STPair env

-- | Split a value indexed by @'Tup' list@ back into a list of its components.
-- At each step, @unpack@ takes one 'TPair' layer apart. The @SList f list@
-- argument only supplies the list's spine; its elements are ignored.
unTup
  :: (forall a b. c ('TPair a b) -> (c a, c b))  -- ^ split one pair layer
  -> SList f list                                -- ^ spine of the result
  -> c (Tup list)                                -- ^ the packed tuple
  -> SList c list
unTup _ SNil _ = SNil
unTup unpack (SCons _ rest) tup =
  -- Kept as a lazy @let@ pattern binding, as in the original definition.
  let (inner, lastElem) = unpack tup
  in SCons lastElem (unTup unpack rest inner)

-- | Wrap @core@ in successive 'TPair' layers, one per list element, head
-- first, so the last element ends up outermost, e.g.
--
-- > InvTup c '[a, b] = 'TPair ('TPair c a) b
type InvTup :: Ty -> [Ty] -> Ty
type family InvTup core env where
  InvTup core '[] = core
  InvTup core (t ': ts) = InvTup ('TPair core t) ts