{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE FlexibleContexts #-}
module Interpreter (
  interpret,
  interpret',
  Value,
  NoAccum(..),
  unAccum,
) where

import Data.Int (Int64)
import Data.Proxy

import AST
import Data
import Array
import Interpreter.Accum
import Interpreter.Rep
import Control.Monad (foldM)


-- | Evaluate a closed expression whose type contains no accumulators and
-- return its plain 'Rep' value.
--
-- Evaluation happens in 'AcM' and is run with 'runAcM'. The 'NoAccum'
-- evidence is what allows the internal 'Rep'' result to be returned as 'Rep'.
interpret :: NoAccum t => Ex '[] t -> Rep t
interpret expr = runAcM (evalClosed expr)
  where
    -- Quantified over the 'AcM' index @s@ so that 'noAccum' can produce the
    -- equation @Rep' s t ~ Rep t@ at exactly that @s@.
    evalClosed :: forall s t. NoAccum t => Ex '[] t -> AcM s (Rep t)
    evalClosed closed = case noAccum (Proxy @s) (Proxy @t) of
      Refl -> interpret' SNil closed

-- | A runtime value stored in the interpreter environment: the 'Rep'' of a
-- type @t@, indexed by the same @s@ that the surrounding 'AcM' computation uses.
newtype Value s t where
  Value :: Rep' s t -> Value s t

-- | Evaluate an expression in the 'AcM' monad under a runtime environment
-- that holds one 'Value' for each variable in the type-level list @env@.
--
-- Every construct that introduces a binder ('ELet', the branches of 'ECase',
-- the bodies of 'EBuild1' / 'EBuild' / 'EFold1Inner', and 'EWith') extends
-- the environment with 'SCons'. The most recently bound value is therefore
-- at the head of the list.
interpret' :: forall env t s. SList (Value s) env -> Ex env t -> AcM s (Rep' s t)
interpret' env = \case
  -- Variable reference: fetch the stored value via 'slistIdx'.
  EVar _ _ i -> case slistIdx env i of Value x -> return x
  -- Evaluate the bound expression, then evaluate the body with its value
  -- pushed onto the environment.
  ELet _ a b -> do
    x <- interpret' env a
    interpret' (Value x `SCons` env) b
  EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b
  EFst _ e -> fst <$> interpret' env e
  ESnd _ e -> snd <$> interpret' env e
  ENil _ -> return ()
  EInl _ _ e -> Left <$> interpret' env e
  EInr _ _ e -> Right <$> interpret' env e
  -- The payload of the evaluated sum is bound as the newest variable in
  -- whichever branch is selected.
  ECase _ e a b -> interpret' env e >>= \case
                     Left x -> interpret' (Value x `SCons` env) a
                     Right y -> interpret' (Value y `SCons` env) b
  EConstArr _ _ _ v -> return v
  -- One-dimensional build. The length @a@ evaluates to an Int64. The body @b@
  -- sees the current linear index, converted back to Int64, as its newest
  -- variable.
  EBuild1 _ a b -> do
    n <- fromIntegral @Int64 @Int <$> interpret' env a
    arrayGenerateLinM (ShNil `ShCons` n)
                      (\i -> interpret' (Value (fromIntegral @Int @Int64 i) `SCons` env) b)
  -- Multi-dimensional build. The shape is a nested tuple of Int64s, turned
  -- into a 'Shape' by 'unTupRepIdx'. Each index is turned back into tuple
  -- form by 'tupRepIdx' before being bound for the body.
  EBuild _ dim a b -> do
    sh <- unTupRepIdx (Proxy @s) ShNil ShCons dim <$> interpret' env a
    arrayGenerateM sh (\idx -> interpret' (Value (tupRepIdx (Proxy @s) ixUncons dim idx) `SCons` env) b)
  -- Left fold along the innermost dimension with the binary function @a@.
  -- Inside @a@, the accumulator @x@ is the second-newest variable and the
  -- element @y@ is the newest. 'foldl1M' calls 'error' when the innermost
  -- dimension yields no elements.
  EFold1Inner _ a b -> do
    let f = \x y -> interpret' (Value y `SCons` Value x `SCons` env) a
    arr <- interpret' env b
    let sh `ShCons` n = arrayShape arr
    arrayGenerateM sh $ \idx -> foldl1M f [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
  -- Sum along the innermost dimension. The scalar element type is read from
  -- 'typeOf' so that 'numericIsNum' can supply the 'Num' instance.
  ESum1Inner _ e -> do
    arr <- interpret' env e
    let STArr _ (STScal t) = typeOf e
        sh `ShCons` n = arrayShape arr
    numericIsNum t $ arrayGenerateM sh $ \idx -> return $ sum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]]
  -- Wrap a single value in an array with an empty ('ShNil') shape.
  EUnit _ e -> arrayGenerateLinM ShNil (\_ -> interpret' env e)
  -- Add a new innermost dimension of length @a@. Every position along it
  -- holds the element of @arr@ at the remaining index.
  EReplicate1Inner _ a b -> do
    n <- fromIntegral @Int64 @Int <$> interpret' env a
    arr <- interpret' env b
    let sh = arrayShape arr
    arrayGenerateM (sh `ShCons` n) (\(idx `IxCons` _) -> return (arrayIndex arr idx))
  EConst _ _ v -> return v
  -- Read the element at linear index 0 (presumably of a rank-0 array).
  EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e
  -- Index with a single Int64 index, converted to Int for 'arrayIndex1'.
  EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b)
  -- Index with a tuple of Int64s, converted to an 'Index' by 'unTupRepIdx'.
  EIdx _ n a b -> arrayIndex <$> interpret' env a <*> (unTupRepIdx (Proxy @s) IxNil IxCons n <$> interpret' env b)
  -- Return the array's shape as a nested tuple of Int64s (via 'tupRepIdx').
  EShape _ e | STArr n _ <- typeOf e -> tupRepIdx (Proxy @s) shUncons n . arrayShape <$> interpret' env e
  -- Primitive operators are evaluated purely by 'interpretOp'.
  EOp _ op e -> interpretOp (Proxy @s) op <$> interpret' env e
  -- Evaluate the initial value @e1@, then run @e2@ with the accumulator
  -- produced by 'withAccum' bound as its newest variable.
  EWith e1 e2 -> do
    initval <- interpret' env e1
    withAccum (typeOf e1) initval $ \accum ->
      interpret' (Value accum `SCons` env) e2
  -- Evaluate index, value and accumulator in that order, then delegate to
  -- 'accumAdd'. The first field @i@ is passed through as-is.
  EAccum i e1 e2 e3 -> do
    idx <- interpret' env e1
    val <- interpret' env e2
    accum <- interpret' env e3
    accumAdd accum i idx val
  -- Program-level error. It is implemented with Haskell 'error', so it
  -- surfaces as an exception rather than as a value inside 'AcM'.
  EError _ s -> error $ "Interpreter: Program threw error: " ++ s

-- | Evaluate a primitive operator on its already-evaluated argument.
--
-- Binary operators receive their two operands as a pair. 'OIf' converts a
-- boolean into a sum: 'True' becomes @Left ()@ and 'False' becomes @Right ()@.
interpretOp :: Proxy s -> SOp a t -> Rep' s a -> Rep' s t
interpretOp _ (OAdd ty) operands = numericIsNum ty (uncurry (+) operands)
interpretOp _ (OMul ty) operands = numericIsNum ty (uncurry (*) operands)
interpretOp _ (ONeg ty) operand = numericIsNum ty (negate operand)
interpretOp _ (OLt ty) operands = numericIsNum ty (uncurry (<) operands)
interpretOp _ (OLe ty) operands = numericIsNum ty (uncurry (<=) operands)
interpretOp _ (OEq ty) operands = numericIsNum ty (uncurry (==) operands)
interpretOp _ ONot b = not b
interpretOp _ OIf b = case b of
  True -> Left ()
  False -> Right ()

-- | Bring 'Num' and 'Ord' instances for the representation of a numeric
-- scalar type into scope. The scalar type is selected by its singleton.
--
-- Matching on the singleton fixes @st@ to a concrete scalar type, which
-- discharges the constraints required by the continuation @k@.
numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r
numericIsNum sty k = case sty of
  STI32 -> k
  STI64 -> k
  STF32 -> k
  STF64 -> k

-- | Convert a rank-@n@ index or shape, given as a left-nested tuple of Int64s,
-- into a structure @f n@ (e.g. 'Index' or 'Shape').
--
-- The structure is built from @base@ (rank 0) and a snoc-style constructor.
-- The second component of each tuple level is appended last, i.e. it becomes
-- the innermost dimension.
unTupRepIdx :: Proxy s -> f Z -> (forall m. f m -> Int -> f (S m))
            -> SNat n -> Rep' s (Tup (Replicate n TIx)) -> f n
unTupRepIdx _ base _ SZ _ = base
unTupRepIdx proxy base snoc (SS k) (prefix, lastIx) =
  snoc (unTupRepIdx proxy base snoc k prefix) (fromIntegral @Int64 @Int lastIx)

-- | Inverse direction of 'unTupRepIdx': turn a rank-@n@ structure @f n@ into a
-- left-nested tuple of Int64s, using @split@ to detach the innermost component.
--
-- The split is bound lazily, so the outer pair is produced before @split@ runs.
tupRepIdx :: Proxy s -> (forall m. f (S m) -> (f m, Int))
          -> SNat n -> f n -> Rep' s (Tup (Replicate n TIx))
tupRepIdx _ _ SZ _ = ()
tupRepIdx proxy split (SS k) xs = (tupRepIdx proxy split k prefix, fromIntegral @Int @Int64 lastIx)
  where
    (prefix, lastIx) = split xs

-- | Split a non-empty 'Index' into its outer part and its innermost component.
ixUncons :: Index (S n) -> (Index n, Int)
ixUncons ix = case ix of
  IxCons outer inner -> (outer, inner)

-- | Split a non-empty 'Shape' into its outer part and its innermost dimension.
shUncons :: Shape (S n) -> (Shape n, Int)
shUncons sh = case sh of
  ShCons outer inner -> (outer, inner)

-- | Types whose @s@-indexed interpreter representation 'Rep'' coincides with
-- the plain representation 'Rep', for every @s@.
--
-- The instances here cover 'TNil', pairs, sums, arrays and scalars.
class NoAccum t where
  -- | Proof that @Rep' s t@ equals @Rep t@.
  noAccum :: Proxy s -> Proxy t -> Rep' s t :~: Rep t

instance NoAccum TNil where
  noAccum _ _ = Refl

instance (NoAccum a, NoAccum b) => NoAccum (TPair a b) where
  -- Both component proofs are forced (left first) before concluding.
  noAccum p _ = case (noAccum p (Proxy @a), noAccum p (Proxy @b)) of
    (Refl, Refl) -> Refl

instance (NoAccum a, NoAccum b) => NoAccum (TEither a b) where
  noAccum p _ = case (noAccum p (Proxy @a), noAccum p (Proxy @b)) of
    (Refl, Refl) -> Refl

instance NoAccum t => NoAccum (TArr n t) where
  noAccum p _ = case noAccum p (Proxy @t) of
    Refl -> Refl

instance NoAccum (TScal t) where
  noAccum _ _ = Refl

-- | Decide, from a type singleton, whether a type is accumulator-free.
--
-- Returns 'NoAccum' evidence when it is. Returns 'Nothing' if an 'STAccum'
-- occurs anywhere in the type. Components are checked left to right, and the
-- check stops at the first failure. The 'Proxy' argument is only passed along
-- to recursive calls.
unAccum :: Proxy s -> STy t -> Maybe (Dict (NoAccum t))
unAccum _ STNil = Just Dict
unAccum _ STScal{} = Just Dict
unAccum _ STAccum{} = Nothing
unAccum p (STPair t1 t2) = case (unAccum p t1, unAccum p t2) of
  (Just Dict, Just Dict) -> Just Dict
  _ -> Nothing
unAccum p (STEither t1 t2) = case (unAccum p t1, unAccum p t2) of
  (Just Dict, Just Dict) -> Just Dict
  _ -> Nothing
unAccum p (STArr _ t) = case unAccum p t of
  Just Dict -> Just Dict
  Nothing -> Nothing

-- | Monadic left fold that seeds the accumulator with the first element.
--
-- Partial: calls 'error' on an empty list.
foldl1M :: Monad m => (a -> a -> m a) -> [a] -> m a
foldl1M step xs = case xs of
  [] -> error "foldl1M: empty list"
  (seed : rest) -> foldM step seed rest