summaryrefslogtreecommitdiff
path: root/src/AST.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST.hs')
-rw-r--r--src/AST.hs84
1 files changed, 65 insertions, 19 deletions
diff --git a/src/AST.hs b/src/AST.hs
index b8d23b4..b2f5ce7 100644
--- a/src/AST.hs
+++ b/src/AST.hs
@@ -16,6 +16,7 @@
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
+{-# LANGUAGE FlexibleInstances #-}
module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where
import Data.Functor.Const
@@ -33,11 +34,9 @@ import Data
-- inner variable / inner array dimension. In pretty printing, the inner
-- variable / inner dimension is printed on the _right_.
--
--- Note that the 'EZero' and 'EPlus' constructs have typing that depend on the
--- type transformation of CHAD. Indeed, these constructors are created _by_
--- CHAD, and are intended to be eliminated after simplification, so that the
--- input program as well as the output program do not contain these
--- constructors.
+-- All the monoid operations are unsupposed as the input to CHAD, and are
+-- intended to be eliminated after simplification, so that the input program as
+-- well as the output program do not contain these constructors.
-- TODO: ensure this by a "stage" type parameter.
type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type
data Expr x env t where
@@ -56,6 +55,10 @@ data Expr x env t where
ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t)
EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t)
EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b
+ ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b)
+ ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b)
+ ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b)
+ ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c
-- array operations
EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t))
@@ -88,13 +91,13 @@ data Expr x env t where
-> Expr x env t
-- accumulation effect on monoids
- EWith :: x (TPair a (D2 t)) -> STy t -> Expr x env (D2 t) -> Expr x (TAccum t : env) a -> Expr x env (TPair a (D2 t))
- EAccum :: x TNil -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (TAccum t) -> Expr x env TNil
+ EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t)
+ EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env a -> Expr x env (TAccum t) -> Expr x env TNil
-- monoidal operations (to be desugared to regular operations after simplification)
- EZero :: x (D2 t) -> STy t -> Expr x env (D2 t)
- EPlus :: x (D2 t) -> STy t -> Expr x env (D2 t) -> Expr x env (D2 t) -> Expr x env (D2 t)
- EOneHot :: x (D2 t) -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (D2 t)
+ EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t
+ EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t
+ EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env a -> Expr x env t
-- partiality
EError :: x a -> STy a -> String -> Expr x env a
@@ -184,6 +187,10 @@ typeOf = \case
ENothing _ t -> STMaybe t
EJust _ e -> STMaybe (typeOf e)
EMaybe _ e _ _ -> typeOf e
+ ELNil _ t1 t2 -> STLEither t1 t2
+ ELInl _ t2 e -> STLEither (typeOf e) t2
+ ELInr _ t1 e -> STLEither t1 (typeOf e)
+ ELCase _ _ a _ _ -> typeOf a
EConstArr _ n t _ -> STArr n (STScal t)
EBuild _ n _ e -> STArr n (typeOf e)
@@ -206,9 +213,9 @@ typeOf = \case
EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1)
EAccum _ _ _ _ _ _ -> STNil
- EZero _ t -> d2 t
- EPlus _ t _ _ -> d2 t
- EOneHot _ t _ _ _ -> d2 t
+ EZero _ t _ -> fromSMTy t
+ EPlus _ t _ _ -> fromSMTy t
+ EOneHot _ t _ _ _ -> fromSMTy t
EError _ t _ -> t
@@ -226,6 +233,10 @@ extOf = \case
ENothing x _ -> x
EJust x _ -> x
EMaybe x _ _ _ -> x
+ ELNil x _ _ -> x
+ ELInl x _ _ -> x
+ ELInr x _ _ -> x
+ ELCase x _ _ _ _ -> x
EConstArr x _ _ _ -> x
EBuild x _ _ _ -> x
EFold1Inner x _ _ _ _ -> x
@@ -243,7 +254,7 @@ extOf = \case
ECustom x _ _ _ _ _ _ _ _ -> x
EWith x _ _ _ -> x
EAccum x _ _ _ _ _ -> x
- EZero x _ -> x
+ EZero x _ _ -> x
EPlus x _ _ _ -> x
EOneHot x _ _ _ _ -> x
EError x _ _ -> x
@@ -262,6 +273,10 @@ mapExt f = \case
ENothing x t -> ENothing (f x) t
EJust x e -> EJust (f x) (mapExt f e)
EMaybe x a b e -> EMaybe (f x) (mapExt f a) (mapExt f b) (mapExt f e)
+ ELNil x t1 t2 -> ELNil (f x) t1 t2
+ ELInl x t e -> ELInl (f x) t (mapExt f e)
+ ELInr x t e -> ELInr (f x) t (mapExt f e)
+ ELCase x e a b c -> ELCase (f x) (mapExt f e) (mapExt f a) (mapExt f b) (mapExt f c)
EConstArr x n t a -> EConstArr (f x) n t a
EBuild x n a b -> EBuild (f x) n (mapExt f a) (mapExt f b)
EFold1Inner x cm a b c -> EFold1Inner (f x) cm (mapExt f a) (mapExt f b) (mapExt f c)
@@ -279,7 +294,7 @@ mapExt f = \case
ECustom x s t p a b c e1 e2 -> ECustom (f x) s t p (mapExt f a) (mapExt f b) (mapExt f c) (mapExt f e1) (mapExt f e2)
EWith x t e1 e2 -> EWith (f x) t (mapExt f e1) (mapExt f e2)
EAccum x t p e1 e2 e3 -> EAccum (f x) t p (mapExt f e1) (mapExt f e2) (mapExt f e3)
- EZero x t -> EZero (f x) t
+ EZero x t e -> EZero (f x) t (mapExt f e)
EPlus x t a b -> EPlus (f x) t (mapExt f a) (mapExt f b)
EOneHot x t p a b -> EOneHot (f x) t p (mapExt f a) (mapExt f b)
EError x t s -> EError (f x) t s
@@ -315,6 +330,10 @@ subst' f w = \case
ENothing x t -> ENothing x t
EJust x e -> EJust x (subst' f w e)
EMaybe x a b e -> EMaybe x (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' f w e)
+ ELNil x t1 t2 -> ELNil x t1 t2
+ ELInl x t e -> ELInl x t (subst' f w e)
+ ELInr x t e -> ELInr x t (subst' f w e)
+ ELCase x e a b c -> ELCase x (subst' f w e) (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' (sinkF f) (WCopy w) c)
EConstArr x n t a -> EConstArr x n t a
EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b)
EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
@@ -332,9 +351,9 @@ subst' f w = \case
ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2)
EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
EAccum x t p e1 e2 e3 -> EAccum x t p (subst' f w e1) (subst' f w e2) (subst' f w e3)
- EZero x t -> EZero x t
+ EZero x t e -> EZero x t (subst' f w e)
EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b)
- EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b)
+ EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b)
EError x t s -> EError x t s
where
sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
@@ -360,7 +379,16 @@ instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEithe
instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy
instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy
instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy
-instance KnownTy t => KnownTy (TAccum t) where knownTy = STAccum knownTy
+instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy
+instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy
+
+class KnownMTy t where knownMTy :: SMTy t
+instance KnownMTy TNil where knownMTy = SMTNil
+instance (KnownMTy s, KnownMTy t) => KnownMTy (TPair s t) where knownMTy = SMTPair knownMTy knownMTy
+instance KnownMTy t => KnownMTy (TMaybe t) where knownMTy = SMTMaybe knownMTy
+instance (KnownMTy s, KnownMTy t) => KnownMTy (TLEither s t) where knownMTy = SMTLEither knownMTy knownMTy
+instance (KnownNat n, KnownMTy t) => KnownMTy (TArr n t) where knownMTy = SMTArr knownNat knownMTy
+instance (KnownScalTy t, ScalIsNumeric t ~ True) => KnownMTy (TScal t) where knownMTy = SMTScal knownScalTy
class KnownEnv env where knownEnv :: SList STy env
instance KnownEnv '[] where knownEnv = SNil
@@ -373,7 +401,16 @@ styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
styKnown (STMaybe t) | Dict <- styKnown t = Dict
styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict
styKnown (STScal t) | Dict <- sscaltyKnown t = Dict
-styKnown (STAccum t) | Dict <- styKnown t = Dict
+styKnown (STAccum t) | Dict <- smtyKnown t = Dict
+styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
+
+smtyKnown :: SMTy t -> Dict (KnownMTy t)
+smtyKnown SMTNil = Dict
+smtyKnown (SMTPair a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict
+smtyKnown (SMTLEither a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict
+smtyKnown (SMTMaybe t) | Dict <- smtyKnown t = Dict
+smtyKnown (SMTArr n t) | Dict <- snatKnown n, Dict <- smtyKnown t = Dict
+smtyKnown (SMTScal t) | Dict <- sscaltyKnown t = Dict
sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t)
sscaltyKnown STI32 = Dict
@@ -451,3 +488,12 @@ eshapeEmpty (SS n) e =
(EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ))
(EConst ext STI64 0)))
(eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ))))
+
+ezeroD2 :: STy t -> Ex env (D2 t)
+ezeroD2 t | Refl <- lemZeroInfoD2 t = EZero ext (d2M t) (ENil ext)
+
+-- eaccumD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (TAccum (D2 t)) -> Ex env TNil
+-- eaccumD2 t p ei ev ea | Refl <- lemZeroInfoD2 t = EAccum ext (d2M t) (ENil ext) p ei ev ea
+
+-- eonehotD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (D2 t)
+-- eonehotD2 t p ei ev | Refl <- lemZeroInfoD2 t = EOneHot ext (d2M t) (ENil ext) p ei ev