summaryrefslogtreecommitdiff
path: root/src/AST/Accum.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST/Accum.hs')
-rw-r--r--src/AST/Accum.hs60
1 files changed, 49 insertions, 11 deletions
diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs
index 1101cc0..988a450 100644
--- a/src/AST/Accum.hs
+++ b/src/AST/Accum.hs
@@ -1,6 +1,8 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
+{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE StandaloneDeriving #-}
+{-# LANGUAGE TypeData #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module AST.Accum where
@@ -32,21 +34,39 @@ data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where
-- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t)
deriving instance Show (SAcPrj p a b)
-type family AcIdx p t where
- AcIdx APHere t = TNil
- AcIdx (APFst p) (TPair a b) = TPair (AcIdx p a) (ZeroInfo b)
- AcIdx (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx p b)
- AcIdx (APLeft p) (TLEither a b) = AcIdx p a
- AcIdx (APRight p) (TLEither a b) = AcIdx p b
- AcIdx (APJust p) (TMaybe a) = AcIdx p a
- AcIdx (APArrIdx p) (TArr n a) =
- -- ((index, shapes info), recursive info)
+type data AIDense = AID | AIS
+
+data SAIDense d where
+ SAID :: SAIDense AID
+ SAIS :: SAIDense AIS
+deriving instance Show (SAIDense d)
+
+type family AcIdx d p t where
+ AcIdx d APHere t = TNil
+ AcIdx AID (APFst p) (TPair a b) = AcIdx AID p a
+ AcIdx AID (APSnd p) (TPair a b) = AcIdx AID p b
+ AcIdx AIS (APFst p) (TPair a b) = TPair (AcIdx AIS p a) (ZeroInfo b)
+ AcIdx AIS (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AIS p b)
+ AcIdx d (APLeft p) (TLEither a b) = AcIdx d p a
+ AcIdx d (APRight p) (TLEither a b) = AcIdx d p b
+ AcIdx d (APJust p) (TMaybe a) = AcIdx d p a
+ AcIdx AID (APArrIdx p) (TArr n a) =
+ -- (index, recursive info)
+ TPair (Tup (Replicate n TIx)) (AcIdx AID p a)
+ AcIdx AIS (APArrIdx p) (TArr n a) =
+ -- ((index, shape info), recursive info)
TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a)))
- (AcIdx p a)
- -- AcIdx (APArrSlice m) (TArr n a) =
+ (AcIdx AIS p a)
+ -- AcIdx AID (APArrSlice m) (TArr n a) =
+ -- -- index
+ -- Tup (Replicate m TIx)
+ -- AcIdx AIS (APArrSlice m) (TArr n a) =
-- -- (index, array shape)
-- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx))
+type AcIdxD p t = AcIdx AID p t
+type AcIdxS p t = AcIdx AIS p t
+
acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b
acPrjTy SAPHere t = t
acPrjTy (SAPFst prj) (SMTPair t _) = acPrjTy prj t
@@ -72,6 +92,24 @@ tZeroInfo (SMTMaybe _) = STNil
tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t)
tZeroInfo (SMTScal _) = STNil
+-- | Info needed to create a zero-valued deep accumulator for a monoid type.
+-- Should be constructable from a D1.
+type family DeepZeroInfo t where
+ DeepZeroInfo TNil = TNil
+ DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b)
+ DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b)
+ DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a)
+ DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a)
+ DeepZeroInfo (TScal t) = TNil
+
+tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t)
+tDeepZeroInfo SMTNil = STNil
+tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b)
+tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b)
+tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a)
+tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t)
+tDeepZeroInfo (SMTScal _) = STNil
+
-- -- | Additional info needed for accumulation. This is empty unless there is
-- -- sparsity in the monoid.
-- type family AccumInfo t where