{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module AST.Accum where

import AST.Types
import Data


-- | Paths into a (typed) value. This type is meant to be used at the kind
-- level (via DataKinds) as the path index of 'SAcPrj' and 'AcIdx'.
data AcPrj where
  -- | The empty path: select the whole value.
  APHere :: AcPrj
  -- | Step into the first component of a pair, then follow the rest.
  APFst :: AcPrj -> AcPrj
  -- | Step into the second component of a pair, then follow the rest.
  APSnd :: AcPrj -> AcPrj
  -- | Step into the left side of an either, then follow the rest.
  APLeft :: AcPrj -> AcPrj
  -- | Step into the right side of an either, then follow the rest.
  APRight :: AcPrj -> AcPrj
  -- | Step into the payload of a maybe, then follow the rest.
  APJust :: AcPrj -> AcPrj
  -- | Step into a single array element, then follow the rest.
  APArrIdx :: AcPrj -> AcPrj
  -- | Select a slice of an array. NOTE(review): the 'Nat' presumably counts
  -- the sliced-off leading dimensions. This constructor has no 'SAcPrj'
  -- counterpart yet.
  APArrSlice :: Nat -> AcPrj

-- | Singleton witness for a projection path. A value of type
-- @SAcPrj p whole part@ shows that following path @p@ into a value of type
-- @whole@ reaches a sub-value of type @part@.
data SAcPrj (p :: AcPrj) (whole :: Ty) (part :: Ty) where
  -- | The empty path: the part is the whole.
  SAPHere :: SAcPrj APHere whole whole
  -- | Project a pair type onto its first component.
  SAPFst
    :: SAcPrj p fstTy part
    -> SAcPrj (APFst p) (TPair fstTy sndTy) part
  -- | Project a pair type onto its second component.
  SAPSnd
    :: SAcPrj p sndTy part
    -> SAcPrj (APSnd p) (TPair fstTy sndTy) part
  -- | Project an either type onto its left alternative.
  SAPLeft
    :: SAcPrj p leftTy part
    -> SAcPrj (APLeft p) (TEither leftTy rightTy) part
  -- | Project an either type onto its right alternative.
  SAPRight
    :: SAcPrj p rightTy part
    -> SAcPrj (APRight p) (TEither leftTy rightTy) part
  -- | Project a maybe type onto its payload.
  SAPJust
    :: SAcPrj p elemTy part
    -> SAcPrj (APJust p) (TMaybe elemTy) part
  -- | Project an @n@-dimensional array type onto its element type.
  --
  -- TODO: the 'SNat' field is rather redundant, because an 'STy' is always
  -- available alongside it anyway.
  SAPArrIdx
    :: SAcPrj p elemTy part
    -> SNat n
    -> SAcPrj (APArrIdx p) (TArr n elemTy) part
  -- TODO: array slicing is not supported yet; the intended constructor is
  -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t)
deriving instance Show (SAcPrj p whole part)

-- | The type of the index data needed to address the part of a value of
-- type @t@ that path @p@ selects.
--
-- The empty path needs nothing ('TNil'). Pair, either and maybe steps add
-- nothing of their own. Each array step over a rank-@n@ array adds a pair of
-- (element index, array shape), each being @Tup (Replicate n TIx)@.
type family AcIdx (p :: AcPrj) (t :: Ty) :: Ty where
  AcIdx APHere whole = TNil
  AcIdx (APFst p) (TPair l r) = AcIdx p l
  AcIdx (APSnd p) (TPair l r) = AcIdx p r
  AcIdx (APLeft p) (TEither l r) = AcIdx p l
  AcIdx (APRight p) (TEither l r) = AcIdx p r
  AcIdx (APJust p) (TMaybe e) = AcIdx p e
  AcIdx (APArrIdx p) (TArr n e) =
    -- ((index, array shape), index data for the remainder of the path)
    TPair (TPair (Tup (Replicate n TIx)) (Tup (Replicate n TIx)))
          (AcIdx p e)
  -- TODO: equation for slicing, once 'SAcPrj' supports it:
  -- AcIdx (APArrSlice m) (TArr n e) =
  --   -- (index, array shape)
  --   TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx))

-- | Compute the type of the sub-value that a projection selects, given the
-- type of the whole value. The projection is inspected first. The type
-- singleton is then matched against the shape that the projection step
-- forces (GADT refinement makes each inner match exhaustive).
acPrjTy :: SAcPrj p a b -> STy a -> STy b
acPrjTy projection whole = case projection of
  SAPHere -> whole
  SAPFst rest -> case whole of
    STPair fstTy _ -> acPrjTy rest fstTy
  SAPSnd rest -> case whole of
    STPair _ sndTy -> acPrjTy rest sndTy
  SAPLeft rest -> case whole of
    STEither leftTy _ -> acPrjTy rest leftTy
  SAPRight rest -> case whole of
    STEither _ rightTy -> acPrjTy rest rightTy
  SAPJust rest -> case whole of
    STMaybe elemTy -> acPrjTy rest elemTy
  SAPArrIdx rest _ -> case whole of
    STArr _ elemTy -> acPrjTy rest elemTy