diff options
Diffstat (limited to 'src/AST/Accum.hs')
-rw-r--r-- | src/AST/Accum.hs | 90 |
1 files changed, 73 insertions, 17 deletions
diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs index 67c5de7..e84034b 100644 --- a/src/AST/Accum.hs +++ b/src/AST/Accum.hs @@ -8,6 +8,7 @@ module AST.Accum where import AST.Types +import CHAD.Types import Data @@ -26,35 +27,90 @@ data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where SAPHere :: SAcPrj APHere a a SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair a t) b SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair t a) b - SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TEither a t) b - SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TEither t a) b + SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TLEither a t) b + SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TLEither t a) b SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe a) b - -- TODO: This SNat is rather useless, you always have an STy around too - SAPArrIdx :: SAcPrj p a b -> SNat n -> SAcPrj (APArrIdx p) (TArr n a) b + SAPArrIdx :: SAcPrj p a b -> SAcPrj (APArrIdx p) (TArr n a) b -- TODO: -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t) deriving instance Show (SAcPrj p a b) type family AcIdx p t where AcIdx APHere t = TNil - AcIdx (APFst p) (TPair a b) = AcIdx p a - AcIdx (APSnd p) (TPair a b) = AcIdx p b - AcIdx (APLeft p) (TEither a b) = AcIdx p a - AcIdx (APRight p) (TEither a b) = AcIdx p b + AcIdx (APFst p) (TPair a b) = TPair (AcIdx p a) (ZeroInfo b) + AcIdx (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx p b) + AcIdx (APLeft p) (TLEither a b) = AcIdx p a + AcIdx (APRight p) (TLEither a b) = AcIdx p b AcIdx (APJust p) (TMaybe a) = AcIdx p a AcIdx (APArrIdx p) (TArr n a) = - -- ((index, array shape), recursive info) - TPair (TPair (Tup (Replicate n TIx)) (Tup (Replicate n TIx))) + -- ((index, shapes info), recursive info) + TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a))) (AcIdx p a) -- AcIdx (APArrSlice m) (TArr n a) = -- -- (index, array shape) -- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx)) -acPrjTy :: SAcPrj p a b -> STy a -> STy b +acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b acPrjTy SAPHere t = t -acPrjTy (SAPFst prj) (STPair t _) = acPrjTy prj t -acPrjTy (SAPSnd prj) (STPair _ t) = acPrjTy prj t -acPrjTy (SAPLeft prj) (STEither t _) = acPrjTy prj t -acPrjTy (SAPRight prj) (STEither _ t) = acPrjTy prj t -acPrjTy (SAPJust prj) (STMaybe t) = acPrjTy prj t -acPrjTy (SAPArrIdx prj _) (STArr _ t) = acPrjTy prj t +acPrjTy (SAPFst prj) (SMTPair t _) = acPrjTy prj t +acPrjTy (SAPSnd prj) (SMTPair _ t) = acPrjTy prj t +acPrjTy (SAPLeft prj) (SMTLEither t _) = acPrjTy prj t +acPrjTy (SAPRight prj) (SMTLEither _ t) = acPrjTy prj t +acPrjTy (SAPJust prj) (SMTMaybe t) = acPrjTy prj t +acPrjTy (SAPArrIdx prj) (SMTArr _ t) = acPrjTy prj t + +type family ZeroInfo t where + ZeroInfo TNil = TNil + ZeroInfo (TPair a b) = TPair (ZeroInfo a) (ZeroInfo b) + ZeroInfo (TLEither a b) = TNil + ZeroInfo (TMaybe a) = TNil + ZeroInfo (TArr n t) = TArr n (ZeroInfo t) + ZeroInfo (TScal t) = TNil + +tZeroInfo :: SMTy t -> STy (ZeroInfo t) +tZeroInfo SMTNil = STNil +tZeroInfo (SMTPair a b) = STPair (tZeroInfo a) (tZeroInfo b) +tZeroInfo (SMTLEither _ _) = STNil +tZeroInfo (SMTMaybe _) = STNil +tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t) +tZeroInfo (SMTScal _) = STNil + +lemZeroInfoD2 :: STy t -> ZeroInfo (D2 t) :~: TNil +lemZeroInfoD2 STNil = Refl +lemZeroInfoD2 (STPair a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl +lemZeroInfoD2 (STEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl +lemZeroInfoD2 (STMaybe a) | Refl <- lemZeroInfoD2 a = Refl +lemZeroInfoD2 (STArr _ a) | Refl <- lemZeroInfoD2 a = Refl +lemZeroInfoD2 (STScal STI32) = Refl +lemZeroInfoD2 (STScal STI64) = Refl +lemZeroInfoD2 (STScal STF32) = Refl +lemZeroInfoD2 (STScal STF64) = Refl +lemZeroInfoD2 (STScal STBool) = Refl +lemZeroInfoD2 (STAccum _) = error "Accumulators disallowed in source program" +lemZeroInfoD2 (STLEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl + +-- -- | Additional info needed for accumulation. This is empty unless there is +-- -- sparsity in the monoid. +-- type family AccumInfo t where +-- AccumInfo TNil = TNil +-- AccumInfo (TPair a b) = TPair (AccumInfo a) (AccumInfo b) +-- AccumInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b) +-- AccumInfo (TMaybe a) = TMaybe (AccumInfo a) +-- AccumInfo (TArr n t) = TArr n (AccumInfo t) +-- AccumInfo (TScal t) = TNil + +-- type family PrimalInfo t where +-- PrimalInfo TNil = TNil +-- PrimalInfo (TPair a b) = TPair (PrimalInfo a) (PrimalInfo b) +-- PrimalInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b) +-- PrimalInfo (TMaybe a) = TMaybe (PrimalInfo a) +-- PrimalInfo (TArr n t) = TArr n (PrimalInfo t) +-- PrimalInfo (TScal t) = TNil + +-- tPrimalInfo :: SMTy t -> STy (PrimalInfo t) +-- tPrimalInfo SMTNil = STNil +-- tPrimalInfo (SMTPair a b) = STPair (tPrimalInfo a) (tPrimalInfo b) +-- tPrimalInfo (SMTLEither a b) = STLEither (tPrimalInfo a) (tPrimalInfo b) +-- tPrimalInfo (SMTMaybe a) = STMaybe (tPrimalInfo a) +-- tPrimalInfo (SMTArr n t) = STArr n (tPrimalInfo t) +-- tPrimalInfo (SMTScal _) = STNil |