summaryrefslogtreecommitdiff
path: root/src/AST/Accum.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST/Accum.hs')
-rw-r--r--src/AST/Accum.hs90
1 files changed, 73 insertions, 17 deletions
diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs
index 67c5de7..e84034b 100644
--- a/src/AST/Accum.hs
+++ b/src/AST/Accum.hs
@@ -8,6 +8,7 @@
module AST.Accum where
import AST.Types
+import CHAD.Types
import Data
@@ -26,35 +27,90 @@ data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where
SAPHere :: SAcPrj APHere a a
SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair a t) b
SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair t a) b
- SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TEither a t) b
- SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TEither t a) b
+ SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TLEither a t) b
+ SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TLEither t a) b
SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe a) b
- -- TODO: This SNat is rather useless, you always have an STy around too
- SAPArrIdx :: SAcPrj p a b -> SNat n -> SAcPrj (APArrIdx p) (TArr n a) b
+ SAPArrIdx :: SAcPrj p a b -> SAcPrj (APArrIdx p) (TArr n a) b
-- TODO:
-- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t)
deriving instance Show (SAcPrj p a b)
type family AcIdx p t where
AcIdx APHere t = TNil
- AcIdx (APFst p) (TPair a b) = AcIdx p a
- AcIdx (APSnd p) (TPair a b) = AcIdx p b
- AcIdx (APLeft p) (TEither a b) = AcIdx p a
- AcIdx (APRight p) (TEither a b) = AcIdx p b
+ AcIdx (APFst p) (TPair a b) = TPair (AcIdx p a) (ZeroInfo b)
+ AcIdx (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx p b)
+ AcIdx (APLeft p) (TLEither a b) = AcIdx p a
+ AcIdx (APRight p) (TLEither a b) = AcIdx p b
AcIdx (APJust p) (TMaybe a) = AcIdx p a
AcIdx (APArrIdx p) (TArr n a) =
- -- ((index, array shape), recursive info)
- TPair (TPair (Tup (Replicate n TIx)) (Tup (Replicate n TIx)))
+ -- ((index, shapes info), recursive info)
+ TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a)))
(AcIdx p a)
-- AcIdx (APArrSlice m) (TArr n a) =
-- -- (index, array shape)
-- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx))
-acPrjTy :: SAcPrj p a b -> STy a -> STy b
+acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b
acPrjTy SAPHere t = t
-acPrjTy (SAPFst prj) (STPair t _) = acPrjTy prj t
-acPrjTy (SAPSnd prj) (STPair _ t) = acPrjTy prj t
-acPrjTy (SAPLeft prj) (STEither t _) = acPrjTy prj t
-acPrjTy (SAPRight prj) (STEither _ t) = acPrjTy prj t
-acPrjTy (SAPJust prj) (STMaybe t) = acPrjTy prj t
-acPrjTy (SAPArrIdx prj _) (STArr _ t) = acPrjTy prj t
+acPrjTy (SAPFst prj) (SMTPair t _) = acPrjTy prj t
+acPrjTy (SAPSnd prj) (SMTPair _ t) = acPrjTy prj t
+acPrjTy (SAPLeft prj) (SMTLEither t _) = acPrjTy prj t
+acPrjTy (SAPRight prj) (SMTLEither _ t) = acPrjTy prj t
+acPrjTy (SAPJust prj) (SMTMaybe t) = acPrjTy prj t
+acPrjTy (SAPArrIdx prj) (SMTArr _ t) = acPrjTy prj t
+
+type family ZeroInfo t where
+ ZeroInfo TNil = TNil
+ ZeroInfo (TPair a b) = TPair (ZeroInfo a) (ZeroInfo b)
+ ZeroInfo (TLEither a b) = TNil
+ ZeroInfo (TMaybe a) = TNil
+ ZeroInfo (TArr n t) = TArr n (ZeroInfo t)
+ ZeroInfo (TScal t) = TNil
+
+tZeroInfo :: SMTy t -> STy (ZeroInfo t)
+tZeroInfo SMTNil = STNil
+tZeroInfo (SMTPair a b) = STPair (tZeroInfo a) (tZeroInfo b)
+tZeroInfo (SMTLEither _ _) = STNil
+tZeroInfo (SMTMaybe _) = STNil
+tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t)
+tZeroInfo (SMTScal _) = STNil
+
+lemZeroInfoD2 :: STy t -> ZeroInfo (D2 t) :~: TNil
+lemZeroInfoD2 STNil = Refl
+lemZeroInfoD2 (STPair a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl
+lemZeroInfoD2 (STEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl
+lemZeroInfoD2 (STMaybe a) | Refl <- lemZeroInfoD2 a = Refl
+lemZeroInfoD2 (STArr _ a) | Refl <- lemZeroInfoD2 a = Refl
+lemZeroInfoD2 (STScal STI32) = Refl
+lemZeroInfoD2 (STScal STI64) = Refl
+lemZeroInfoD2 (STScal STF32) = Refl
+lemZeroInfoD2 (STScal STF64) = Refl
+lemZeroInfoD2 (STScal STBool) = Refl
+lemZeroInfoD2 (STAccum _) = error "Accumulators disallowed in source program"
+lemZeroInfoD2 (STLEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl
+
+-- -- | Additional info needed for accumulation. This is empty unless there is
+-- -- sparsity in the monoid.
+-- type family AccumInfo t where
+-- AccumInfo TNil = TNil
+-- AccumInfo (TPair a b) = TPair (AccumInfo a) (AccumInfo b)
+-- AccumInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b)
+-- AccumInfo (TMaybe a) = TMaybe (AccumInfo a)
+-- AccumInfo (TArr n t) = TArr n (AccumInfo t)
+-- AccumInfo (TScal t) = TNil
+
+-- type family PrimalInfo t where
+-- PrimalInfo TNil = TNil
+-- PrimalInfo (TPair a b) = TPair (PrimalInfo a) (PrimalInfo b)
+-- PrimalInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b)
+-- PrimalInfo (TMaybe a) = TMaybe (PrimalInfo a)
+-- PrimalInfo (TArr n t) = TArr n (PrimalInfo t)
+-- PrimalInfo (TScal t) = TNil
+
+-- tPrimalInfo :: SMTy t -> STy (PrimalInfo t)
+-- tPrimalInfo SMTNil = STNil
+-- tPrimalInfo (SMTPair a b) = STPair (tPrimalInfo a) (tPrimalInfo b)
+-- tPrimalInfo (SMTLEither a b) = STLEither (tPrimalInfo a) (tPrimalInfo b)
+-- tPrimalInfo (SMTMaybe a) = STMaybe (tPrimalInfo a)
+-- tPrimalInfo (SMTArr n t) = STArr n (tPrimalInfo t)
+-- tPrimalInfo (SMTScal _) = STNil