diff options
Diffstat (limited to 'src/AST/Accum.hs')
-rw-r--r-- | src/AST/Accum.hs | 127 |
1 files changed, 102 insertions, 25 deletions
diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs index 67c5de7..988a450 100644 --- a/src/AST/Accum.hs +++ b/src/AST/Accum.hs @@ -1,8 +1,8 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE RankNTypes #-} {-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeData #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE UndecidableInstances #-} module AST.Accum where @@ -26,35 +26,112 @@ data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where SAPHere :: SAcPrj APHere a a SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair a t) b SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair t a) b - SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TEither a t) b - SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TEither t a) b + SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TLEither a t) b + SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TLEither t a) b SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe a) b - -- TODO: This SNat is rather useless, you always have an STy around too - SAPArrIdx :: SAcPrj p a b -> SNat n -> SAcPrj (APArrIdx p) (TArr n a) b + SAPArrIdx :: SAcPrj p a b -> SAcPrj (APArrIdx p) (TArr n a) b -- TODO: -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t) deriving instance Show (SAcPrj p a b) -type family AcIdx p t where - AcIdx APHere t = TNil - AcIdx (APFst p) (TPair a b) = AcIdx p a - AcIdx (APSnd p) (TPair a b) = AcIdx p b - AcIdx (APLeft p) (TEither a b) = AcIdx p a - AcIdx (APRight p) (TEither a b) = AcIdx p b - AcIdx (APJust p) (TMaybe a) = AcIdx p a - AcIdx (APArrIdx p) (TArr n a) = - -- ((index, array shape), recursive info) - TPair (TPair (Tup (Replicate n TIx)) (Tup (Replicate n TIx))) - (AcIdx p a) - -- AcIdx (APArrSlice m) (TArr n a) = +type data AIDense = AID | AIS + +data SAIDense d where + SAID :: SAIDense AID + SAIS :: SAIDense AIS +deriving instance Show (SAIDense d) + +type family AcIdx d p t where + AcIdx d APHere t = TNil + AcIdx AID (APFst p) (TPair a b) = AcIdx AID p a + AcIdx AID (APSnd p) (TPair a b) = AcIdx AID p b + AcIdx AIS (APFst p) (TPair a b) = TPair (AcIdx AIS p a) (ZeroInfo b) + AcIdx AIS (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AIS p b) + AcIdx d (APLeft p) (TLEither a b) = AcIdx d p a + AcIdx d (APRight p) (TLEither a b) = AcIdx d p b + AcIdx d (APJust p) (TMaybe a) = AcIdx d p a + AcIdx AID (APArrIdx p) (TArr n a) = + -- (index, recursive info) + TPair (Tup (Replicate n TIx)) (AcIdx AID p a) + AcIdx AIS (APArrIdx p) (TArr n a) = + -- ((index, shape info), recursive info) + TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a))) + (AcIdx AIS p a) + -- AcIdx AID (APArrSlice m) (TArr n a) = + -- -- index + -- Tup (Replicate m TIx) + -- AcIdx AIS (APArrSlice m) (TArr n a) = -- -- (index, array shape) -- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx)) -acPrjTy :: SAcPrj p a b -> STy a -> STy b +type AcIdxD p t = AcIdx AID p t +type AcIdxS p t = AcIdx AIS p t + +acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b acPrjTy SAPHere t = t -acPrjTy (SAPFst prj) (STPair t _) = acPrjTy prj t -acPrjTy (SAPSnd prj) (STPair _ t) = acPrjTy prj t -acPrjTy (SAPLeft prj) (STEither t _) = acPrjTy prj t -acPrjTy (SAPRight prj) (STEither _ t) = acPrjTy prj t -acPrjTy (SAPJust prj) (STMaybe t) = acPrjTy prj t -acPrjTy (SAPArrIdx prj _) (STArr _ t) = acPrjTy prj t +acPrjTy (SAPFst prj) (SMTPair t _) = acPrjTy prj t +acPrjTy (SAPSnd prj) (SMTPair _ t) = acPrjTy prj t +acPrjTy (SAPLeft prj) (SMTLEither t _) = acPrjTy prj t +acPrjTy (SAPRight prj) (SMTLEither _ t) = acPrjTy prj t +acPrjTy (SAPJust prj) (SMTMaybe t) = acPrjTy prj t +acPrjTy (SAPArrIdx prj) (SMTArr _ t) = acPrjTy prj t + +type family ZeroInfo t where + ZeroInfo TNil = TNil + ZeroInfo (TPair a b) = TPair (ZeroInfo a) (ZeroInfo b) + ZeroInfo (TLEither a b) = TNil + ZeroInfo (TMaybe a) = TNil + ZeroInfo (TArr n t) = TArr n (ZeroInfo t) + ZeroInfo (TScal t) = TNil + +tZeroInfo :: SMTy t -> STy (ZeroInfo t) +tZeroInfo SMTNil = STNil +tZeroInfo (SMTPair a b) = STPair (tZeroInfo a) (tZeroInfo b) +tZeroInfo (SMTLEither _ _) = STNil +tZeroInfo (SMTMaybe _) = STNil +tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t) +tZeroInfo (SMTScal _) = STNil + +-- | Info needed to create a zero-valued deep accumulator for a monoid type. +-- Should be constructable from a D1. +type family DeepZeroInfo t where + DeepZeroInfo TNil = TNil + DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b) + DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b) + DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a) + DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a) + DeepZeroInfo (TScal t) = TNil + +tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t) +tDeepZeroInfo SMTNil = STNil +tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a) +tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t) +tDeepZeroInfo (SMTScal _) = STNil + +-- -- | Additional info needed for accumulation. This is empty unless there is +-- -- sparsity in the monoid. +-- type family AccumInfo t where +-- AccumInfo TNil = TNil +-- AccumInfo (TPair a b) = TPair (AccumInfo a) (AccumInfo b) +-- AccumInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b) +-- AccumInfo (TMaybe a) = TMaybe (AccumInfo a) +-- AccumInfo (TArr n t) = TArr n (AccumInfo t) +-- AccumInfo (TScal t) = TNil + +-- type family PrimalInfo t where +-- PrimalInfo TNil = TNil +-- PrimalInfo (TPair a b) = TPair (PrimalInfo a) (PrimalInfo b) +-- PrimalInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b) +-- PrimalInfo (TMaybe a) = TMaybe (PrimalInfo a) +-- PrimalInfo (TArr n t) = TArr n (PrimalInfo t) +-- PrimalInfo (TScal t) = TNil + +-- tPrimalInfo :: SMTy t -> STy (PrimalInfo t) +-- tPrimalInfo SMTNil = STNil +-- tPrimalInfo (SMTPair a b) = STPair (tPrimalInfo a) (tPrimalInfo b) +-- tPrimalInfo (SMTLEither a b) = STLEither (tPrimalInfo a) (tPrimalInfo b) +-- tPrimalInfo (SMTMaybe a) = STMaybe (tPrimalInfo a) +-- tPrimalInfo (SMTArr n t) = STArr n (tPrimalInfo t) +-- tPrimalInfo (SMTScal _) = STNil |