diff options
Diffstat (limited to 'src/AST/Accum.hs')
| -rw-r--r-- | src/AST/Accum.hs | 75 | 
1 files changed, 48 insertions, 27 deletions
| diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs index 03369c8..988a450 100644 --- a/src/AST/Accum.hs +++ b/src/AST/Accum.hs @@ -1,14 +1,13 @@  {-# LANGUAGE DataKinds #-}  {-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE RankNTypes #-}  {-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeData #-}  {-# LANGUAGE TypeFamilies #-}  {-# LANGUAGE UndecidableInstances #-}  module AST.Accum where  import AST.Types -import CHAD.Types  import Data @@ -35,21 +34,39 @@ data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where    -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t)  deriving instance Show (SAcPrj p a b) -type family AcIdx p t where -  AcIdx APHere t = TNil -  AcIdx (APFst p) (TPair a b) = TPair (AcIdx p a) (ZeroInfo b) -  AcIdx (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx p b) -  AcIdx (APLeft p) (TLEither a b) = AcIdx p a -  AcIdx (APRight p) (TLEither a b) = AcIdx p b -  AcIdx (APJust p) (TMaybe a) = AcIdx p a -  AcIdx (APArrIdx p) (TArr n a) = -    -- ((index, shapes info), recursive info) +type data AIDense = AID | AIS + +data SAIDense d where +  SAID :: SAIDense AID +  SAIS :: SAIDense AIS +deriving instance Show (SAIDense d) + +type family AcIdx d p t where +  AcIdx d APHere t = TNil +  AcIdx AID (APFst p) (TPair a b) = AcIdx AID p a +  AcIdx AID (APSnd p) (TPair a b) = AcIdx AID p b +  AcIdx AIS (APFst p) (TPair a b) = TPair (AcIdx AIS p a) (ZeroInfo b) +  AcIdx AIS (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AIS p b) +  AcIdx d (APLeft p) (TLEither a b) = AcIdx d p a +  AcIdx d (APRight p) (TLEither a b) = AcIdx d p b +  AcIdx d (APJust p) (TMaybe a) = AcIdx d p a +  AcIdx AID (APArrIdx p) (TArr n a) = +    -- (index, recursive info) +    TPair (Tup (Replicate n TIx)) (AcIdx AID p a) +  AcIdx AIS (APArrIdx p) (TArr n a) = +    -- ((index, shape info), recursive info)      TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a))) -          (AcIdx p a) -  -- AcIdx (APArrSlice m) (TArr n a) = +          (AcIdx AIS p a) +  -- AcIdx AID (APArrSlice m) (TArr n a) = +  --   -- index +  --   Tup (Replicate m TIx) +  -- AcIdx AIS (APArrSlice m) (TArr n a) =    --   -- (index, array shape)    --   TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx)) +type AcIdxD p t = AcIdx AID p t +type AcIdxS p t = AcIdx AIS p t +  acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b  acPrjTy SAPHere t = t  acPrjTy (SAPFst prj) (SMTPair t _) = acPrjTy prj t @@ -75,19 +92,23 @@ tZeroInfo (SMTMaybe _) = STNil  tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t)  tZeroInfo (SMTScal _) = STNil -lemZeroInfoD2 :: STy t -> ZeroInfo (D2 t) :~: TNil -lemZeroInfoD2 STNil = Refl -lemZeroInfoD2 (STPair a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl -lemZeroInfoD2 (STEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl -lemZeroInfoD2 (STLEither a b) | Refl <- lemZeroInfoD2 a, Refl <- lemZeroInfoD2 b = Refl -lemZeroInfoD2 (STMaybe a) | Refl <- lemZeroInfoD2 a = Refl -lemZeroInfoD2 (STArr _ a) | Refl <- lemZeroInfoD2 a = Refl -lemZeroInfoD2 (STScal STI32) = Refl -lemZeroInfoD2 (STScal STI64) = Refl -lemZeroInfoD2 (STScal STF32) = Refl -lemZeroInfoD2 (STScal STF64) = Refl -lemZeroInfoD2 (STScal STBool) = Refl -lemZeroInfoD2 (STAccum _) = error "Accumulators disallowed in source program" +-- | Info needed to create a zero-valued deep accumulator for a monoid type. +-- Should be constructable from a D1. +type family DeepZeroInfo t where +  DeepZeroInfo TNil = TNil +  DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b) +  DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b) +  DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a) +  DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a) +  DeepZeroInfo (TScal t) = TNil + +tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t) +tDeepZeroInfo SMTNil = STNil +tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a) +tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t) +tDeepZeroInfo (SMTScal _) = STNil  -- -- | Additional info needed for accumulation. This is empty unless there is  -- -- sparsity in the monoid. | 
