diff options
Diffstat (limited to 'src/AST/Accum.hs')
| -rw-r--r-- | src/AST/Accum.hs | 116 |
1 files changed, 0 insertions, 116 deletions
diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs deleted file mode 100644 index 03369c8..0000000 --- a/src/AST/Accum.hs +++ /dev/null @@ -1,116 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE UndecidableInstances #-} -module AST.Accum where - -import AST.Types -import CHAD.Types -import Data - - -data AcPrj - = APHere - | APFst AcPrj - | APSnd AcPrj - | APLeft AcPrj - | APRight AcPrj - | APJust AcPrj - | APArrIdx AcPrj - | APArrSlice Nat - --- | @b@ is a small part of @a@, indicated by the projection @p@. -data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where - SAPHere :: SAcPrj APHere a a - SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair a t) b - SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair t a) b - SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TLEither a t) b - SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TLEither t a) b - SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe a) b - SAPArrIdx :: SAcPrj p a b -> SAcPrj (APArrIdx p) (TArr n a) b - -- TODO: - -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t) -deriving instance Show (SAcPrj p a b) - -type family AcIdx p t where - AcIdx APHere t = TNil - AcIdx (APFst p) (TPair a b) = TPair (AcIdx p a) (ZeroInfo b) - AcIdx (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx p b) - AcIdx (APLeft p) (TLEither a b) = AcIdx p a - AcIdx (APRight p) (TLEither a b) = AcIdx p b - AcIdx (APJust p) (TMaybe a) = AcIdx p a - AcIdx (APArrIdx p) (TArr n a) = - -- ((index, shapes info), recursive info) - TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a))) - (AcIdx p a) - -- AcIdx (APArrSlice m) (TArr n a) = - -- -- (index, array shape) - -- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx)) - -acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b -acPrjTy SAPHere t = t -acPrjTy (SAPFst prj) (SMTPair t _) = acPrjTy prj t -acPrjTy (SAPSnd prj) (SMTPair _ t) = acPrjTy prj t -acPrjTy (SAPLeft prj) (SMTLEither t _) = acPrjTy prj t -acPrjTy (SAPRight prj) (SMTLEither _ t) = acPrjTy prj t -acPrjTy (SAPJust prj) (SMTMaybe t) = acPrjTy prj t -acPrjTy (SAPArrIdx prj) (SMTArr _ t) = acPrjTy prj t - -type family ZeroInfo t where - ZeroInfo TNil = TNil - ZeroInfo (TPair a b) = TPair (ZeroInfo a) (ZeroInfo b) - ZeroInfo (TLEither a b) = TNil - ZeroInfo (TMaybe a) = TNil - ZeroInfo (TArr n t) = TArr n (ZeroInfo t) - ZeroInfo (TScal t) = TNil - -tZeroInfo :: SMTy t -> STy (ZeroInfo t) -tZeroInfo SMTNil = STNil -tZeroInfo (SMTPair a b) = STPair (tZeroInfo a) (tZeroInfo b) -tZeroInfo (SMTLEither _ _) = STNil -tZeroInfo (SMTMaybe _) = STNil -tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t) -tZeroInfo (SMTScal _) = STNil - -lemZeroInfoD2 :: STy t -> ZeroInfo (D2 t) :~: TNil -lemZeroInfoD2 STNil = Refl -lemZeroInfoD2 (STPair a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl -lemZeroInfoD2 (STEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl -lemZeroInfoD2 (STLEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl -lemZeroInfoD2 (STMaybe a) | Refl <- lemZeroInfoD2 a = Refl -lemZeroInfoD2 (STArr _ a) | Refl <- lemZeroInfoD2 a = Refl -lemZeroInfoD2 (STScal STI32) = Refl -lemZeroInfoD2 (STScal STI64) = Refl -lemZeroInfoD2 (STScal STF32) = Refl -lemZeroInfoD2 (STScal STF64) = Refl -lemZeroInfoD2 (STScal STBool) = Refl -lemZeroInfoD2 (STAccum _) = error "Accumulators disallowed in source program" - --- -- | Additional info needed for accumulation. This is empty unless there is --- -- sparsity in the monoid. --- type family AccumInfo t where --- AccumInfo TNil = TNil --- AccumInfo (TPair a b) = TPair (AccumInfo a) (AccumInfo b) --- AccumInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b) --- AccumInfo (TMaybe a) = TMaybe (AccumInfo a) --- AccumInfo (TArr n t) = TArr n (AccumInfo t) --- AccumInfo (TScal t) = TNil - --- type family PrimalInfo t where --- PrimalInfo TNil = TNil --- PrimalInfo (TPair a b) = TPair (PrimalInfo a) (PrimalInfo b) --- PrimalInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b) --- PrimalInfo (TMaybe a) = TMaybe (PrimalInfo a) --- PrimalInfo (TArr n t) = TArr n (PrimalInfo t) --- PrimalInfo (TScal t) = TNil - --- tPrimalInfo :: SMTy t -> STy (PrimalInfo t) --- tPrimalInfo SMTNil = STNil --- tPrimalInfo (SMTPair a b) = STPair (tPrimalInfo a) (tPrimalInfo b) --- tPrimalInfo (SMTLEither a b) = STLEither (tPrimalInfo a) (tPrimalInfo b) --- tPrimalInfo (SMTMaybe a) = STMaybe (tPrimalInfo a) --- tPrimalInfo (SMTArr n t) = STArr n (tPrimalInfo t) --- tPrimalInfo (SMTScal _) = STNil |
