{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Analysis.Identity (
  identityAnalysis,
) where

import Data.Foldable (toList)
import Data.List (intercalate)

import AST
import AST.Pretty (PrettyX(..))
import CHAD.Types (d1, d2)
import Data
import Util.IdGen


-- | Symbolic identity of a value of type @t@.
--
-- Every array, scalar and accumulator has an ID. Trivial values such as
-- Nothing only have the knowledge that they are indeed Nothing. Compound
-- values know which values they consist of.
data ValId t where
  -- | The unit value; it carries no ID.
  VINil :: ValId TNil
  -- | A pair, described by the identities of its two components.
  VIPair :: ValId a -> ValId b -> ValId (TPair a b)
  -- | Known alternative: we know whether this is a Left or a Right, and the
  -- identity of the value inside.
  VIEither :: Either (ValId a) (ValId b) -> ValId (TEither a b)
  -- | Unknown alternative, but known values in each case: the first field is
  -- the identity of the contents if it is a Left, the second if it is a Right.
  VIEither' :: ValId a -> ValId b -> ValId (TEither a b)
  -- | Known to be Nothing, or known to be Just with the given contents.
  VIMaybe :: Maybe (ValId a) -> ValId (TMaybe a)
  -- | Not known whether this is Nothing or Just, but if it's Just, it
  -- contains this value.
  VIMaybe' :: ValId a -> ValId (TMaybe a)
  -- | An array: its own ID, together with a vector of @n@ IDs describing its
  -- shape (these are the IDs that the analysis of an @EShape@ node exposes as
  -- scalar identities).
  VIArr :: Int -> Vec n Int -> ValId (TArr n t)
  -- | A scalar with the given ID.
  VIScal :: Int -> ValId (TScal t)
  -- | An accumulator with the given ID.
  VIAccum :: Int -> ValId (TAccum t)

  -- | We don't know what this consists of, but it's a value, and let's just
  -- give it an ID nevertheless.
  VIThing :: STy t -> Int -> ValId t

-- | Structural equality of value identities: two identities are equal when
-- they are built from the same constructor and all their IDs / sub-identities
-- agree. The type witness stored in 'VIThing' is ignored; only its ID is
-- compared. Identities built from different constructors are never equal.
instance Eq (ValId t) where
  lhs == rhs = case (lhs, rhs) of
    (VINil, VINil) -> True
    (VIPair x y, VIPair x' y') -> x == x' && y == y'
    (VIEither alt, VIEither alt') -> alt == alt'
    (VIEither' l r, VIEither' l' r') -> l == l' && r == r'
    (VIMaybe m, VIMaybe m') -> m == m'
    (VIMaybe' j, VIMaybe' j') -> j == j'
    (VIArr n sh, VIArr n' sh') -> n == n' && sh == sh'
    (VIScal n, VIScal n') -> n == n'
    (VIAccum n, VIAccum n') -> n == n'
    (VIThing _ n, VIThing _ n') -> n == n'
    -- Any remaining combination pairs two different constructors.
    _ -> False

-- | Compact textual rendering of a value identity, used when pretty-printing
-- annotated expressions. Scalars print as their bare ID; arrays as
-- @A\<id\>[\<shape ids\>]@; accumulators as @C\<id\>@; opaque things as
-- @{\<id\>}@; sums and options get a short tag in front of their contents.
instance PrettyX ValId where
  prettyX val = case val of
    VINil -> ""
    VIPair l r -> concat ["(", prettyX l, ",", prettyX r, ")"]
    VIEither alt -> case alt of
      Left l -> concat ["(L", prettyX l, ")"]
      Right r -> concat ["(R", prettyX r, ")"]
    VIEither' l r -> concat ["(", prettyX l, "|", prettyX r, ")"]
    VIMaybe inner -> maybe "N" (\j -> "J" ++ prettyX j) inner
    VIMaybe' j -> "M" ++ prettyX j
    VIArr n sh -> concat ["A", show n, "[", intercalate "," (show <$> toList sh), "]"]
    VIScal n -> show n
    VIAccum n -> "C" ++ show n
    VIThing _ n -> concat ["{", show n, "}"]

-- | Symbolic partial evaluation.
--
-- Every variable in the environment is first given fresh identities (derived
-- from its type by 'genIds'); the term is then traversed by 'idana', and the
-- expression annotated with a 'ValId' at every node is returned. The identity
-- of the term as a whole is discarded here, but it is also stored as the
-- annotation of the root node.
identityAnalysis :: SList STy env -> Expr x env t -> Expr ValId env t
identityAnalysis env term =
  runIdGen 0 (slistMapA genIds env >>= \envIds -> snd <$> idana envIds term)

-- | Worker for 'identityAnalysis'.
--
-- Given the identities of all variables in scope, compute the identity of the
-- value that the expression evaluates to, and rebuild the expression with
-- every node annotated by the identity computed for it.
--
-- Where the analysis cannot relate a result to existing values, fresh IDs are
-- generated for it with 'genIds' / 'genId'. Where a result may come from one
-- of two branches, the branch results are merged with 'unify'.
idana :: SList ValId env -> Expr x env t -> IdGen (ValId t, Expr ValId env t)
idana env expr = case expr of
  -- A variable has whatever identity the environment records for it.
  EVar _ t i -> do
    let v = slistIdx env i
    pure (v, EVar v t i)

  -- The bound value's identity is pushed onto the environment for the body.
  ELet _ e1 e2 -> do
    (v1, e1') <- idana env e1
    (v2, e2') <- idana (v1 `SCons` env) e2
    pure (v2, ELet v2 e1' e2')

  EPair _ e1 e2 -> do
    (v1, e1') <- idana env e1
    (v2, e2') <- idana env e2
    pure (VIPair v1 v2, EPair (VIPair v1 v2) e1' e2')

  -- Projections reuse the component identity when the pair is known
  -- structurally; otherwise (an opaque 'VIThing') the result is fresh.
  EFst _ e -> do
    (v, e') <- idana env e
    v' <- case v of VIPair v1 _ -> pure v1
                    _ -> genIds (typeOf expr)
    pure (v', EFst v' e')

  ESnd _ e -> do
    (v, e') <- idana env e
    v' <- case v of VIPair _ v2 -> pure v2
                    _ -> genIds (typeOf expr)
    pure (v', ESnd v' e')

  ENil _ -> pure (VINil, ENil VINil)

  EInl _ t2 e1 -> do
    (v1, e1') <- idana env e1
    let v = VIEither (Left v1)
    pure (v, EInl v t2 e1')

  EInr _ t1 e2 -> do
    (v2, e2') <- idana env e2
    let v = VIEither (Right v2)
    pure (v, EInr v t1 e2')

  ECase _ e1 e2 e3 -> do
    let STEither t1 t2 = typeOf e1
    (v1, e1') <- idana env e1
    case v1 of
      -- Scrutinee is known to be a Left: the result is that of the left
      -- branch. The right branch is still traversed (with fresh "scrap"
      -- identities for its variable) so that it gets annotated too.
      VIEither (Left v1') -> do
        (v2, e2') <- idana (v1' `SCons` env) e2
        scrap <- genIds t2
        (_, e3') <- idana (scrap `SCons` env) e3
        pure (v2, ECase v2 e1' e2' e3')
      -- Symmetric: scrutinee is known to be a Right.
      VIEither (Right v1') -> do
        scrap <- genIds t1
        (_, e2') <- idana (scrap `SCons` env) e2
        (v3, e3') <- idana (v1' `SCons` env) e3
        pure (v3, ECase v3 e1' e2' e3')
      -- Alternative unknown, but the contents in each case are known. The
      -- result is one of the two branch results, so merge them with 'unify'
      -- (as the 'VIThing' case below and the 'VIMaybe'' case of 'EMaybe' do)
      -- rather than discarding them in favour of fresh IDs.
      VIEither' v1'l v1'r -> do
        (v2, e2') <- idana (v1'l `SCons` env) e2
        (v3, e3') <- idana (v1'r `SCons` env) e3
        res <- unify v2 v3
        pure (res, ECase res e1' e2' e3')
      -- Nothing is known about the scrutinee: both branch variables get
      -- fresh identities, and the branch results are merged.
      VIThing _ _ -> do
        x2 <- genIds t1
        x3 <- genIds t2
        (v2, e2') <- idana (x2 `SCons` env) e2
        (v3, e3') <- idana (x3 `SCons` env) e3
        res <- unify v2 v3
        pure (res, ECase res e1' e2' e3')

  ENothing _ t -> pure (VIMaybe Nothing, ENothing (VIMaybe Nothing) t)

  EJust _ e1 -> do
    (v1, e1') <- idana env e1
    let v = VIMaybe (Just v1)
    pure (v, EJust v e1')

  -- e1 is the Nothing branch, e2 the Just branch (binding one variable),
  -- e3 the scrutinee.
  EMaybe _ e1 e2 e3 -> do
    let STMaybe t1 = typeOf e3
    (v3, e3') <- idana env e3
    case v3 of
      -- Known Nothing: result is that of e1; e2 is annotated with a scrap
      -- variable.
      VIMaybe Nothing -> do
        (v1, e1') <- idana env e1
        scrap <- genIds t1
        (_, e2') <- idana (scrap `SCons` env) e2
        pure (v1, EMaybe v1 e1' e2' e3')
      -- Known Just: result is that of e2 with the known contents bound.
      VIMaybe (Just v3j) -> do
        (v2, e2') <- idana (v3j `SCons` env) e2
        (_, e1') <- idana env e1
        pure (v2, EMaybe v2 e1' e2' e3')
      -- Unknown which, but contents known if Just: merge both results.
      VIMaybe' v3' -> do
        (v2, e2') <- idana (v3' `SCons` env) e2
        (v1, e1') <- idana env e1
        res <- unify v1 v2
        pure (res, EMaybe res e1' e2' e3')
      -- Nothing known: fresh identities for the bound variable, merge results.
      VIThing _ _ -> do
        (v1, e1') <- idana env e1
        scrap <- genIds t1
        (v2, e2') <- idana (scrap `SCons` env) e2
        res <- unify v1 v2
        pure (res, EMaybe res e1' e2' e3')

  -- A constant array is a new array with a fresh ID and fresh shape IDs.
  EConstArr _ dim t arr -> do
    x1 <- VIArr <$> genId <*> vecReplicateA dim genId
    pure (x1, EConstArr x1 dim t arr)

  -- The result is a new array whose shape IDs are taken from the identity of
  -- the shape expression e1; the index variable of e2 gets fresh identities.
  EBuild _ dim e1 e2 -> do
    (shids, e1') <- idana env e1
    x1 <- genIds (tTup (sreplicate dim tIx))
    (_, e2') <- idana (x1 `SCons` env) e2
    res <- VIArr <$> genId <*> shidsToVec dim shids
    pure (res, EBuild res dim e1' e2')

  -- The combination function e1 is analysed with fresh identities for its two
  -- arguments. The result is a new array whose shape IDs are those of the
  -- input array e3 with the head entry dropped.
  EFold1Inner _ e1 e2 e3 -> do
    let t1 = typeOf e1
        STArr dim _ = typeOf e3
    x1 <- genIds t1
    x2 <- genIds t1
    (_, e1') <- idana (x1 `SCons` x2 `SCons` env) e1
    (_, e2') <- idana env e2
    (v3, e3') <- idana env e3
    res <- VIArr <$> genId <*> (vecTail <$> valArrShape dim v3)
    pure (res, EFold1Inner res e1' e2' e3')

  -- New array; shape IDs are those of the input with the head entry dropped.
  ESum1Inner _ e1 -> do
    let STArr dim _ = typeOf e1
    (v1, e1') <- idana env e1
    res <- VIArr <$> genId <*> (vecTail <$> valArrShape dim v1)
    pure (res, ESum1Inner res e1')

  -- New array with an empty shape vector.
  EUnit _ e1 -> do
    (_, e1') <- idana env e1
    res <- VIArr <$> genId <*> pure VNil
    pure (res, EUnit res e1')

  -- New array; the ID of the scalar e1 is prepended to the shape IDs of e2.
  EReplicate1Inner _ e1 e2 -> do
    let STArr dim _ = typeOf e2
    (v1, e1') <- idana env e1
    (v2, e2') <- idana env e2
    res <- VIArr <$> genId <*> ((valScalId v1 :<) <$> valArrShape dim v2)
    pure (res, EReplicate1Inner res e1' e2')

  -- New array; shape IDs are those of the input with the head entry dropped.
  EMaximum1Inner _ e1 -> do
    let STArr dim _ = typeOf e1
    (v1, e1') <- idana env e1
    res <- VIArr <$> genId <*> (vecTail <$> valArrShape dim v1)
    pure (res, EMaximum1Inner res e1')

  EMinimum1Inner _ e1 -> do
    let STArr dim _ = typeOf e1
    (v1, e1') <- idana env e1
    res <- VIArr <$> genId <*> (vecTail <$> valArrShape dim v1)
    pure (res, EMinimum1Inner res e1')

  -- Every constant is a scalar with a fresh ID.
  EConst _ t val -> do
    res <- VIScal <$> genId
    pure (res, EConst res t val)

  -- Array elements are not tracked: the result gets fresh identities.
  EIdx0 _ e1 -> do
    (_, e1') <- idana env e1
    res <- genIds (typeOf expr)
    pure (res, EIdx0 res e1')

  -- New array; shape IDs are those of the input with the head entry dropped.
  EIdx1 _ e1 e2 -> do
    let STArr dim _ = typeOf e1
    (v1, e1') <- idana env e1
    (_, e2') <- idana env e2
    res <- VIArr <$> genId <*> (vecTail <$> valArrShape dim v1)
    pure (res, EIdx1 res e1' e2')

  EIdx _ e1 e2 -> do
    (_, e1') <- idana env e1
    (_, e2') <- idana env e2
    res <- genIds (typeOf expr)
    pure (res, EIdx res e1' e2')

  -- The shape of an array is the tuple of scalars whose IDs are the array's
  -- shape IDs (fresh ones if the array's identity is not a 'VIArr').
  EShape _ e1 -> do
    let STArr dim _ = typeOf e1
    (v1, e1') <- idana env e1
    res <- vecToShids dim <$> valArrShape dim v1
    pure (res, EShape res e1')

  -- Primitive operations are not interpreted: the result is fresh.
  EOp _ (op :: SOp a t) e1 -> do
    (_, e1') <- idana env e1
    res <- genIds (typeOf expr)
    pure (res, EOp res op e1')

  -- e1, e2 and e3 are analysed in their own two-variable environments
  -- (not extending env), each variable getting fresh identities; e4 and e5
  -- are analysed in the ambient environment. The result is fresh.
  ECustom _ t1 t2 t3 e1 e2 e3 e4 e5 -> do
    let t4 = typeOf e1
    x1 <- genIds t2
    x2 <- genIds t1
    (_, e1') <- idana (x1 `SCons` x2 `SCons` SNil) e1
    x3 <- genIds (d1 t2)
    x4 <- genIds (d1 t1)
    (_, e2') <- idana (x3 `SCons` x4 `SCons` SNil) e2
    x5 <- genIds (d2 t4)
    x6 <- genIds t3
    (_, e3') <- idana (x5 `SCons` x6 `SCons` SNil) e3
    (_, e4') <- idana env e4
    (_, e5') <- idana env e5
    res <- genIds t4
    pure (res, ECustom res t1 t2 t3 e1' e2' e3' e4' e5')

  -- The body e2 sees a new accumulator with a fresh ID. The result pairs the
  -- body's identity with fresh identities of the type of e1.
  EWith _ e1 e2 -> do
    let t1 = typeOf e1
    (_, e1') <- idana env e1
    x1 <- VIAccum <$> genId
    (v2, e2') <- idana (x1 `SCons` env) e2
    x2 <- genIds t1
    let res = VIPair v2 x2
    pure (res, EWith res e1' e2')

  -- Accumulation returns unit.
  EAccum _ i e1 e2 e3 -> do
    (_, e1') <- idana env e1
    (_, e2') <- idana env e2
    (_, e3') <- idana env e3
    pure (VINil, EAccum VINil i e1' e2' e3')

  -- The remaining operations all produce values with fresh identities.
  EZero _ t -> do
    res <- genIds (d2 t)
    pure (res, EZero res t)

  EPlus _ t e1 e2 -> do
    (_, e1') <- idana env e1
    (_, e2') <- idana env e2
    res <- genIds (d2 t)
    pure (res, EPlus res t e1' e2')

  EOneHot _ t i e1 e2 -> do
    (_, e1') <- idana env e1
    (_, e2') <- idana env e2
    res <- genIds (d2 t)
    pure (res, EOneHot res t i e1' e2')

  EError _ t s -> do
    res <- genIds t
    pure (res, EError res t s)

-- | This value might be either of the two arguments; we don't know which.
--
-- The result keeps whatever the two identities agree on. Where they carry
-- different IDs a fresh ID is generated; where they disagree on the
-- alternative of a sum or option, the result falls back to the "unknown
-- alternative" form ('VIEither'' / 'VIMaybe''). If either side is an opaque
-- 'VIThing' (and they are not the very same thing), the result is entirely
-- fresh.
unify :: ValId t -> ValId t -> IdGen (ValId t)
unify lhs rhs = case (lhs, rhs) of
  (VINil, VINil) -> pure VINil

  (VIPair a b, VIPair c d) -> do
    ac <- unify a c
    bd <- unify b d
    pure (VIPair ac bd)

  -- Both sides agree on the alternative.
  (VIEither (Left a), VIEither (Left b)) -> fmap (VIEither . Left) (unify a b)
  (VIEither (Right a), VIEither (Right b)) -> fmap (VIEither . Right) (unify a b)
  -- The sides disagree on the alternative: keep one identity per case.
  (VIEither (Left a), VIEither (Right b)) -> pure (VIEither' a b)
  (VIEither (Right a), VIEither (Left b)) -> pure (VIEither' b a)
  -- One side known, the other unknown: merge only the matching case.
  (VIEither (Left a), VIEither' b c) -> do
    ab <- unify a b
    pure (VIEither' ab c)
  (VIEither (Right a), VIEither' b c) -> do
    ac <- unify a c
    pure (VIEither' b ac)
  (VIEither' a b, VIEither (Left c)) -> do
    ac <- unify a c
    pure (VIEither' ac b)
  (VIEither' a b, VIEither (Right c)) -> do
    bc <- unify b c
    pure (VIEither' a bc)
  (VIEither' a b, VIEither' c d) -> do
    ac <- unify a c
    bd <- unify b d
    pure (VIEither' ac bd)

  (VIMaybe Nothing, VIMaybe Nothing) -> pure (VIMaybe Nothing)
  (VIMaybe (Just a), VIMaybe (Just b)) -> fmap (VIMaybe . Just) (unify a b)
  (VIMaybe Nothing, VIMaybe (Just a)) -> pure (VIMaybe' a)
  (VIMaybe (Just a), VIMaybe Nothing) -> pure (VIMaybe' a)
  (VIMaybe Nothing, VIMaybe' a) -> pure (VIMaybe' a)
  (VIMaybe (Just a), VIMaybe' b) -> fmap VIMaybe' (unify a b)
  (VIMaybe' a, VIMaybe Nothing) -> pure (VIMaybe' a)
  (VIMaybe' a, VIMaybe (Just b)) -> fmap VIMaybe' (unify a b)
  (VIMaybe' a, VIMaybe' b) -> fmap VIMaybe' (unify a b)

  -- The array ID is merged first, then the shape IDs pointwise.
  (VIArr i is, VIArr j js) -> do
    arrId <- unifyID i j
    shapeIds <- vecZipWithA unifyID is js
    pure (VIArr arrId shapeIds)
  (VIScal i, VIScal j) -> fmap VIScal (unifyID i j)
  (VIAccum i, VIAccum j) -> fmap VIAccum (unifyID i j)

  -- Opaque values: identical things stay as they are, anything else is fresh.
  (VIThing ty i, VIThing _ j) ->
    if i == j then pure (VIThing ty i) else genIds ty
  (VIThing ty _, _) -> genIds ty
  (_, VIThing ty _) -> genIds ty

-- | Merge two IDs: if they are the same ID it is kept, otherwise a fresh ID
-- is generated to stand for "one of the two".
unifyID :: Int -> Int -> IdGen Int
unifyID i j =
  if i == j
    then pure i
    else genId

-- | Generate a fully fresh identity for a value of the given type, following
-- the structure of the type: pairs get fresh identities for both components,
-- sums and options get the "unknown alternative" form with fresh contents,
-- and arrays, scalars and accumulators get fresh IDs (arrays additionally
-- one fresh ID per entry of their shape vector).
genIds :: STy t -> IdGen (ValId t)
genIds = \case
  STNil -> pure VINil
  STPair l r -> do
    lv <- genIds l
    rv <- genIds r
    pure (VIPair lv rv)
  STEither l r -> do
    lv <- genIds l
    rv <- genIds r
    pure (VIEither' lv rv)
  STMaybe inner -> fmap VIMaybe' (genIds inner)
  STArr rank _ -> do
    arrId <- genId
    shapeIds <- vecReplicateA rank genId
    pure (VIArr arrId shapeIds)
  STScal{} -> fmap VIScal genId
  STAccum{} -> fmap VIAccum genId

-- | Convert the identity of a shape tuple (nested pairs of index scalars)
-- into a vector of IDs, one per tuple component. The last component of the
-- tuple becomes the head of the vector. Scalars contribute their own ID
-- (whether represented as 'VIScal' or 'VIThing'); if the tuple as a whole is
-- an opaque 'VIThing', fresh IDs are generated for all remaining components.
shidsToVec :: SNat n -> ValId (Tup (Replicate n TIx)) -> IdGen (Vec n Int)
shidsToVec SZ _ = pure VNil
shidsToVec n@(SS m) shids = case shids of
  VIPair rest (VIScal i) -> fmap (i :<) (shidsToVec m rest)
  VIPair rest (VIThing _ i) -> fmap (i :<) (shidsToVec m rest)
  VIThing{} -> vecReplicateA n genId

-- | Convert a vector of IDs into the identity of a shape tuple: each ID
-- becomes a scalar ('VIScal'), with the head of the vector ending up as the
-- last component of the nested-pair tuple.
vecToShids :: SNat n -> Vec n Int -> ValId (Tup (Replicate n TIx))
vecToShids rank ids = case (rank, ids) of
  (SZ, VNil) -> VINil
  (SS m, i :< rest) -> VIPair (vecToShids m rest) (VIScal i)

-- | The ID of a scalar value, regardless of whether its identity is a proper
-- 'VIScal' or an opaque 'VIThing'.
valScalId :: ValId (TScal t) -> Int
valScalId = \case
  VIScal i -> i
  VIThing _ i -> i

-- | The shape IDs of an array value. If the identity is a 'VIArr', its stored
-- shape IDs are returned; otherwise nothing is known about the shape and a
-- vector of fresh IDs (of the given length) is generated.
valArrShape :: SNat n -> ValId (TArr n t) -> IdGen (Vec n Int)
valArrShape rank val = case val of
  VIArr _ shapeIds -> pure shapeIds
  _ -> vecReplicateA rank genId