{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module AST.Accum where

import AST.Types
import CHAD.Types
import Data


-- | Description of a projection into a nested type. It is used promoted to
-- the type level, as the first index of 'SAcPrj'.
--
-- 'APArrSlice' has no corresponding 'SAcPrj' constructor yet (see the TODO
-- in that declaration).
data AcPrj
  = APHere
  | APFst AcPrj
  | APSnd AcPrj
  | APLeft AcPrj
  | APRight AcPrj
  | APJust AcPrj
  | APArrIdx AcPrj
  | APArrSlice Nat

-- | Witness that @part@ is a small part of @whole@, reached by following the
-- projection @prj@.
data SAcPrj (prj :: AcPrj) (whole :: Ty) (part :: Ty) where
  -- The empty projection: the part is the whole.
  SAPHere :: SAcPrj APHere s s
  -- Descend into the first / second component of a pair.
  SAPFst :: SAcPrj q s r -> SAcPrj (APFst q) (TPair s other) r
  SAPSnd :: SAcPrj q s r -> SAcPrj (APSnd q) (TPair other s) r
  -- Descend into the left / right alternative of a 'TLEither'.
  SAPLeft :: SAcPrj q s r -> SAcPrj (APLeft q) (TLEither s other) r
  SAPRight :: SAcPrj q s r -> SAcPrj (APRight q) (TLEither other s) r
  -- Descend into the payload of a 'TMaybe'.
  SAPJust :: SAcPrj q s r -> SAcPrj (APJust q) (TMaybe s) r
  -- Descend into the element type of an array.
  SAPArrIdx :: SAcPrj q s r -> SAcPrj (APArrIdx q) (TArr n s) r
  -- TODO:
  -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t)

deriving instance Show (SAcPrj p a b)

-- | Type computed from a projection @p@ and the type @t@ it projects out of.
-- Judging by the comments on the array equation this is the index data that
-- accompanies the projection -- NOTE(review): confirm against the use sites.
type family AcIdx p t where
  AcIdx APHere t = TNil
  -- For pairs, the component that is not descended into contributes its
  -- 'ZeroInfo'.
  AcIdx (APFst p) (TPair l r) = TPair (AcIdx p l) (ZeroInfo r)
  AcIdx (APSnd p) (TPair l r) = TPair (ZeroInfo l) (AcIdx p r)
  -- Sum-like types only pass on the result for the chosen branch.
  AcIdx (APLeft p) (TLEither l r) = AcIdx p l
  AcIdx (APRight p) (TLEither l r) = AcIdx p r
  AcIdx (APJust p) (TMaybe e) = AcIdx p e
  AcIdx (APArrIdx p) (TArr n e) =
    -- ((index, shapes info), recursive info)
    TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n e))) (AcIdx p e)
  -- AcIdx (APArrSlice m) (TArr n a) =
  --   -- (index, array shape)
  --   TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx))

-- | Follow a projection through the singleton of the outer type @a@ and
-- return the singleton of the projected-to type @b@.
acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b
acPrjTy path ty = case (path, ty) of
  (SAPHere, whole) -> whole
  (SAPFst inner, SMTPair l _) -> acPrjTy inner l
  (SAPSnd inner, SMTPair _ r) -> acPrjTy inner r
  (SAPLeft inner, SMTLEither l _) -> acPrjTy inner l
  (SAPRight inner, SMTLEither _ r) -> acPrjTy inner r
  (SAPJust inner, SMTMaybe e) -> acPrjTy inner e
  (SAPArrIdx inner, SMTArr _ e) -> acPrjTy inner e
-- | Type-level function that is 'TNil' for the nil, 'TLEither', 'TMaybe' and
-- scalar cases and keeps the structure of pairs and arrays, recursing into
-- their components.
-- NOTE(review): presumably the data needed to build a zero of the monoid
-- type @t@ -- confirm against the use sites.
type family ZeroInfo t where
  ZeroInfo TNil = TNil
  ZeroInfo (TPair l r) = TPair (ZeroInfo l) (ZeroInfo r)
  ZeroInfo (TLEither l r) = TNil
  ZeroInfo (TMaybe e) = TNil
  ZeroInfo (TArr n e) = TArr n (ZeroInfo e)
  ZeroInfo (TScal s) = TNil

-- | Value-level counterpart of 'ZeroInfo': turn the singleton of @t@ into the
-- singleton of @ZeroInfo t@, case by case following the type family.
tZeroInfo :: SMTy t -> STy (ZeroInfo t)
tZeroInfo ty = case ty of
  SMTNil -> STNil
  SMTPair l r -> STPair (tZeroInfo l) (tZeroInfo r)
  SMTLEither _ _ -> STNil
  SMTMaybe _ -> STNil
  SMTArr n e -> STArr n (tZeroInfo e)
  SMTScal _ -> STNil

-- | Lemma: for the type described by the given singleton, @ZeroInfo (D2 t)@
-- is 'TNil'.
--
-- The proofs for all component types are forced before 'Refl' is returned.
-- Calls 'error' on 'STAccum'.
lemZeroInfoD2 :: STy t -> ZeroInfo (D2 t) :~: TNil
lemZeroInfoD2 ty = case ty of
  STNil -> Refl
  STPair l r ->
    case (lemZeroInfoD2 l, lemZeroInfoD2 r) of
      (Refl, Refl) -> Refl
  STEither l r ->
    case (lemZeroInfoD2 l, lemZeroInfoD2 r) of
      (Refl, Refl) -> Refl
  STMaybe e ->
    case lemZeroInfoD2 e of
      Refl -> Refl
  STArr _ e ->
    case lemZeroInfoD2 e of
      Refl -> Refl
  -- Each scalar type is matched explicitly so that @D2@ can reduce.
  STScal scal ->
    case scal of
      STI32 -> Refl
      STI64 -> Refl
      STF32 -> Refl
      STF64 -> Refl
      STBool -> Refl
  STAccum _ -> error "Accumulators disallowed in source program"
  STLEither l r ->
    case (lemZeroInfoD2 l, lemZeroInfoD2 r) of
      (Refl, Refl) -> Refl

-- -- | Additional info needed for accumulation. This is empty unless there is
-- -- sparsity in the monoid.
-- type family AccumInfo t where
--   AccumInfo TNil = TNil
--   AccumInfo (TPair a b) = TPair (AccumInfo a) (AccumInfo b)
--   AccumInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b)
--   AccumInfo (TMaybe a) = TMaybe (AccumInfo a)
--   AccumInfo (TArr n t) = TArr n (AccumInfo t)
--   AccumInfo (TScal t) = TNil

-- type family PrimalInfo t where
--   PrimalInfo TNil = TNil
--   PrimalInfo (TPair a b) = TPair (PrimalInfo a) (PrimalInfo b)
--   PrimalInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b)
--   PrimalInfo (TMaybe a) = TMaybe (PrimalInfo a)
--   PrimalInfo (TArr n t) = TArr n (PrimalInfo t)
--   PrimalInfo (TScal t) = TNil

-- tPrimalInfo :: SMTy t -> STy (PrimalInfo t)
-- tPrimalInfo SMTNil = STNil
-- tPrimalInfo (SMTPair a b) = STPair (tPrimalInfo a) (tPrimalInfo b)
-- tPrimalInfo (SMTLEither a b) = STLEither (tPrimalInfo a) (tPrimalInfo b)
-- tPrimalInfo (SMTMaybe a) = STMaybe (tPrimalInfo a)
-- tPrimalInfo (SMTArr n t) = STArr n (tPrimalInfo t)
-- tPrimalInfo (SMTScal _) = STNil