diff options
Diffstat (limited to 'src/AST.hs')
| -rw-r--r-- | src/AST.hs | 453 |
1 files changed, 0 insertions, 453 deletions
diff --git a/src/AST.hs b/src/AST.hs deleted file mode 100644 index b8d23b4..0000000 --- a/src/AST.hs +++ /dev/null @@ -1,453 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveFoldable #-} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE DeriveTraversable #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where - -import Data.Functor.Const -import Data.Kind (Type) - -import Array -import AST.Accum -import AST.Types -import AST.Weaken -import CHAD.Types -import Data - - --- General assumption: head of the list (whatever way it is associated) is the --- inner variable / inner array dimension. In pretty printing, the inner --- variable / inner dimension is printed on the _right_. --- --- Note that the 'EZero' and 'EPlus' constructs have typing that depend on the --- type transformation of CHAD. Indeed, these constructors are created _by_ --- CHAD, and are intended to be eliminated after simplification, so that the --- input program as well as the output program do not contain these --- constructors. --- TODO: ensure this by a "stage" type parameter. -type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type -data Expr x env t where - -- lambda calculus - EVar :: x t -> STy t -> Idx env t -> Expr x env t - ELet :: x t -> Expr x env a -> Expr x (a : env) t -> Expr x env t - - -- base types - EPair :: x (TPair a b) -> Expr x env a -> Expr x env b -> Expr x env (TPair a b) - EFst :: x a -> Expr x env (TPair a b) -> Expr x env a - ESnd :: x b -> Expr x env (TPair a b) -> Expr x env b - ENil :: x TNil -> Expr x env TNil - EInl :: x (TEither a b) -> STy b -> Expr x env a -> Expr x env (TEither a b) - EInr :: x (TEither a b) -> STy a -> Expr x env b -> Expr x env (TEither a b) - ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c - ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t) - EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t) - EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b - - -- array operations - EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t)) - EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t) - EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) - ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) - EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) - EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) - EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) - EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) - - -- expression operations - EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t) - EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t - EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) - EIdx :: x t -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t - EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) - EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t - - -- custom derivatives - -- 'b' is the part of the input of the operation that derivatives should - -- be backpropagated to; 'a' is the inactive part. The dual field of - -- ECustom does not allow a derivative to be generated for 'a', and hence - -- none is propagated. - ECustom :: x t -> STy a -> STy b -> STy tape - -> Expr x [b, a] t -- ^ regular operation - -> Expr x [D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass - -> Expr x [D2 t, tape] (D2 b) -- ^ CHAD reverse derivative - -> Expr x env a -> Expr x env b - -> Expr x env t - - -- accumulation effect on monoids - EWith :: x (TPair a (D2 t)) -> STy t -> Expr x env (D2 t) -> Expr x (TAccum t : env) a -> Expr x env (TPair a (D2 t)) - EAccum :: x TNil -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (TAccum t) -> Expr x env TNil - - -- monoidal operations (to be desugared to regular operations after simplification) - EZero :: x (D2 t) -> STy t -> Expr x env (D2 t) - EPlus :: x (D2 t) -> STy t -> Expr x env (D2 t) -> Expr x env (D2 t) -> Expr x env (D2 t) - EOneHot :: x (D2 t) -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (D2 t) - - -- partiality - EError :: x a -> STy a -> String -> Expr x env a -deriving instance (forall ty. Show (x ty)) => Show (Expr x env t) - -type Ex = Expr (Const ()) - -ext :: Const () a -ext = Const () - -data Commutative = Commut | Noncommut - deriving (Show) - -type SOp :: Ty -> Ty -> Type -data SOp a t where - OAdd :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) - OMul :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) - ONeg :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TScal a) (TScal a) - OLt :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) - OLe :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) - OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) - ONot :: SOp (TScal TBool) (TScal TBool) - OAnd :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool) - OOr :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool) - OIf :: SOp (TScal TBool) (TEither TNil TNil) -- True is Left, False is Right - ORound64 :: SOp (TScal TF64) (TScal TI64) - OToFl64 :: SOp (TScal TI64) (TScal TF64) - ORecip :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) - OExp :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) - OLog :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) - OIDiv :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) - OMod :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) -deriving instance Show (SOp a t) - -opt1 :: SOp a t -> STy a -opt1 = \case - OAdd t -> STPair (STScal t) (STScal t) - OMul t -> STPair (STScal t) (STScal t) - ONeg t -> STScal t - OLt t -> STPair (STScal t) (STScal t) - OLe t -> STPair (STScal t) (STScal t) - OEq t -> STPair (STScal t) (STScal t) - ONot -> STScal STBool - OAnd -> STPair (STScal STBool) (STScal STBool) - OOr -> STPair (STScal STBool) (STScal STBool) - OIf -> STScal STBool - ORound64 -> STScal STF64 - OToFl64 -> STScal STI64 - ORecip t -> STScal t - OExp t -> STScal t - OLog t -> STScal t - OIDiv t -> STPair (STScal t) (STScal t) - OMod t -> STPair (STScal t) (STScal t) - -opt2 :: SOp a t -> STy t -opt2 = \case - OAdd t -> STScal t - OMul t -> STScal t - ONeg t -> STScal t - OLt _ -> STScal STBool - OLe _ -> STScal STBool - OEq _ -> STScal STBool - ONot -> STScal STBool - OAnd -> STScal STBool - OOr -> STScal STBool - OIf -> STEither STNil STNil - ORound64 -> STScal STI64 - OToFl64 -> STScal STF64 - ORecip t -> STScal t - OExp t -> STScal t - OLog t -> STScal t - OIDiv t -> STScal t - OMod t -> STScal t - -typeOf :: Expr x env t -> STy t -typeOf = \case - EVar _ t _ -> t - ELet _ _ e -> typeOf e - - EPair _ a b -> STPair (typeOf a) (typeOf b) - EFst _ e | STPair t _ <- typeOf e -> t - ESnd _ e | STPair _ t <- typeOf e -> t - ENil _ -> STNil - EInl _ t2 e -> STEither (typeOf e) t2 - EInr _ t1 e -> STEither t1 (typeOf e) - ECase _ _ a _ -> typeOf a - ENothing _ t -> STMaybe t - EJust _ e -> STMaybe (typeOf e) - EMaybe _ e _ _ -> typeOf e - - EConstArr _ n t _ -> STArr n (STScal t) - EBuild _ n _ e -> STArr n (typeOf e) - EFold1Inner _ _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t - ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t - EUnit _ e -> STArr SZ (typeOf e) - EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t - EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t - EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t - - EConst _ t _ -> STScal t - EIdx0 _ e | STArr _ t <- typeOf e -> t - EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t - EIdx _ e _ | STArr _ t <- typeOf e -> t - EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx) - EOp _ op _ -> opt2 op - - ECustom _ _ _ _ e _ _ _ _ -> typeOf e - - EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1) - EAccum _ _ _ _ _ _ -> STNil - - EZero _ t -> d2 t - EPlus _ t _ _ -> d2 t - EOneHot _ t _ _ _ -> d2 t - - EError _ t _ -> t - -extOf :: Expr x env t -> x t -extOf = \case - EVar x _ _ -> x - ELet x _ _ -> x - EPair x _ _ -> x - EFst x _ -> x - ESnd x _ -> x - ENil x -> x - EInl x _ _ -> x - EInr x _ _ -> x - ECase x _ _ _ -> x - ENothing x _ -> x - EJust x _ -> x - EMaybe x _ _ _ -> x - EConstArr x _ _ _ -> x - EBuild x _ _ _ -> x - EFold1Inner x _ _ _ _ -> x - ESum1Inner x _ -> x - EUnit x _ -> x - EReplicate1Inner x _ _ -> x - EMaximum1Inner x _ -> x - EMinimum1Inner x _ -> x - EConst x _ _ -> x - EIdx0 x _ -> x - EIdx1 x _ _ -> x - EIdx x _ _ -> x - EShape x _ -> x - EOp x _ _ -> x - ECustom x _ _ _ _ _ _ _ _ -> x - EWith x _ _ _ -> x - EAccum x _ _ _ _ _ -> x - EZero x _ -> x - EPlus x _ _ _ -> x - EOneHot x _ _ _ _ -> x - EError x _ _ -> x - -mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t -mapExt f = \case - EVar x t i -> EVar (f x) t i - ELet x rhs body -> ELet (f x) (mapExt f rhs) (mapExt f body) - EPair x a b -> EPair (f x) (mapExt f a) (mapExt f b) - EFst x e -> EFst (f x) (mapExt f e) - ESnd x e -> ESnd (f x) (mapExt f e) - ENil x -> ENil (f x) - EInl x t e -> EInl (f x) t (mapExt f e) - EInr x t e -> EInr (f x) t (mapExt f e) - ECase x e a b -> ECase (f x) (mapExt f e) (mapExt f a) (mapExt f b) - ENothing x t -> ENothing (f x) t - EJust x e -> EJust (f x) (mapExt f e) - EMaybe x a b e -> EMaybe (f x) (mapExt f a) (mapExt f b) (mapExt f e) - EConstArr x n t a -> EConstArr (f x) n t a - EBuild x n a b -> EBuild (f x) n (mapExt f a) (mapExt f b) - EFold1Inner x cm a b c -> EFold1Inner (f x) cm (mapExt f a) (mapExt f b) (mapExt f c) - ESum1Inner x e -> ESum1Inner (f x) (mapExt f e) - EUnit x e -> EUnit (f x) (mapExt f e) - EReplicate1Inner x a b -> EReplicate1Inner (f x) (mapExt f a) (mapExt f b) - EMaximum1Inner x e -> EMaximum1Inner (f x) (mapExt f e) - EMinimum1Inner x e -> EMinimum1Inner (f x) (mapExt f e) - EConst x t v -> EConst (f x) t v - EIdx0 x e -> EIdx0 (f x) (mapExt f e) - EIdx1 x a b -> EIdx1 (f x) (mapExt f a) (mapExt f b) - EIdx x e es -> EIdx (f x) (mapExt f e) (mapExt f es) - EShape x e -> EShape (f x) (mapExt f e) - EOp x op e -> EOp (f x) op (mapExt f e) - ECustom x s t p a b c e1 e2 -> ECustom (f x) s t p (mapExt f a) (mapExt f b) (mapExt f c) (mapExt f e1) (mapExt f e2) - EWith x t e1 e2 -> EWith (f x) t (mapExt f e1) (mapExt f e2) - EAccum x t p e1 e2 e3 -> EAccum (f x) t p (mapExt f e1) (mapExt f e2) (mapExt f e3) - EZero x t -> EZero (f x) t - EPlus x t a b -> EPlus (f x) t (mapExt f a) (mapExt f b) - EOneHot x t p a b -> EOneHot (f x) t p (mapExt f a) (mapExt f b) - EError x t s -> EError (f x) t s - -substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t -substInline repl = - subst $ \x t -> \case IZ -> repl - IS i -> EVar x t i - -subst0 :: Ex (b : env) a -> Ex (a : env) t -> Ex (b : env) t -subst0 repl = - subst $ \_ t -> \case IZ -> repl - IS i -> EVar ext t (IS i) - -subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a) - -> Expr x env t -> Expr x env' t -subst f = subst' (\x t w i -> weakenExpr w (f x t i)) WId - -subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a) - -> env' :> envOut - -> Expr x env t - -> Expr x envOut t -subst' f w = \case - EVar x t i -> f x t w i - ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body) - EPair x a b -> EPair x (subst' f w a) (subst' f w b) - EFst x e -> EFst x (subst' f w e) - ESnd x e -> ESnd x (subst' f w e) - ENil x -> ENil x - EInl x t e -> EInl x t (subst' f w e) - EInr x t e -> EInr x t (subst' f w e) - ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b) - ENothing x t -> ENothing x t - EJust x e -> EJust x (subst' f w e) - EMaybe x a b e -> EMaybe x (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' f w e) - EConstArr x n t a -> EConstArr x n t a - EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b) - EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) - ESum1Inner x e -> ESum1Inner x (subst' f w e) - EUnit x e -> EUnit x (subst' f w e) - EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b) - EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e) - EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e) - EConst x t v -> EConst x t v - EIdx0 x e -> EIdx0 x (subst' f w e) - EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) - EIdx x e es -> EIdx x (subst' f w e) (subst' f w es) - EShape x e -> EShape x (subst' f w e) - EOp x op e -> EOp x op (subst' f w e) - ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2) - EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) - EAccum x t p e1 e2 e3 -> EAccum x t p (subst' f w e1) (subst' f w e2) (subst' f w e3) - EZero x t -> EZero x t - EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b) - EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b) - EError x t s -> EError x t s - where - sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) - -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t - sinkF f' x' t w' = \case - IZ -> EVar x' t (w' @> IZ) - IS i -> f' x' t (WPop w') i - -weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t -weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i)) - -class KnownScalTy t where knownScalTy :: SScalTy t -instance KnownScalTy TI32 where knownScalTy = STI32 -instance KnownScalTy TI64 where knownScalTy = STI64 -instance KnownScalTy TF32 where knownScalTy = STF32 -instance KnownScalTy TF64 where knownScalTy = STF64 -instance KnownScalTy TBool where knownScalTy = STBool - -class KnownTy t where knownTy :: STy t -instance KnownTy TNil where knownTy = STNil -instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy -instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy -instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy -instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy -instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy -instance KnownTy t => KnownTy (TAccum t) where knownTy = STAccum knownTy - -class KnownEnv env where knownEnv :: SList STy env -instance KnownEnv '[] where knownEnv = SNil -instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv - -styKnown :: STy t -> Dict (KnownTy t) -styKnown STNil = Dict -styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict -styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict -styKnown (STMaybe t) | Dict <- styKnown t = Dict -styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict -styKnown (STScal t) | Dict <- sscaltyKnown t = Dict -styKnown (STAccum t) | Dict <- styKnown t = Dict - -sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t) -sscaltyKnown STI32 = Dict -sscaltyKnown STI64 = Dict -sscaltyKnown STF32 = Dict -sscaltyKnown STF64 = Dict -sscaltyKnown STBool = Dict - -envKnown :: SList STy env -> Dict (KnownEnv env) -envKnown SNil = Dict -envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict - -eTup :: SList (Ex env) list -> Ex env (Tup list) -eTup = mkTup (ENil ext) (EPair ext) - -ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t) -ebuildUp1 n sh size f = - EBuild ext (SS n) (EPair ext sh size) $ - let arg = EVar ext (tTup (sreplicate (SS n) tIx)) IZ - in EIdx ext (ELet ext (ESnd ext arg) (weakenExpr (WCopy WSink) f)) - (EFst ext arg) - -eidxEq :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool) -eidxEq SZ _ _ = EConst ext STBool True -eidxEq (SS SZ) a b = - EOp ext (OEq STI64) (EPair ext (ESnd ext a) (ESnd ext b)) -eidxEq (SS n) a b - | let ty = tTup (sreplicate (SS n) tIx) - = ELet ext a $ - ELet ext (weakenExpr WSink b) $ - EOp ext OAnd $ EPair ext - (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext ty (IS IZ))) - (ESnd ext (EVar ext ty IZ)))) - (eidxEq n (EFst ext (EVar ext ty (IS IZ))) - (EFst ext (EVar ext ty IZ))) - -emap :: Ex (a : env) b -> Ex env (TArr n a) -> Ex env (TArr n b) -emap f arr = - let STArr n t = typeOf arr - in ELet ext arr $ - EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $ - ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ)) - (EVar ext (tTup (sreplicate n tIx)) IZ)) $ - weakenExpr (WCopy (WSink .> WSink)) f - -ezipWith :: Ex (b : a : env) c -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c) -ezipWith f arr1 arr2 = - let STArr n t1 = typeOf arr1 - STArr _ t2 = typeOf arr2 - in ELet ext arr1 $ - ELet ext (weakenExpr WSink arr2) $ - EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $ - ELet ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) IZ)) $ - ELet ext (EIdx ext (EVar ext (STArr n t2) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) (IS IZ))) $ - weakenExpr (WCopy (WCopy (WSink .> WSink .> WSink))) f - -ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) -ezip arr1 arr2 = - let STArr _ t1 = typeOf arr1 - STArr _ t2 = typeOf arr2 - in ezipWith (EPair ext (EVar ext t1 (IS IZ)) (EVar ext t2 IZ)) arr1 arr2 - -eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a -eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c) - --- | Returns whether the shape is all-zero, but returns False for the zero-dimensional shape (because it is _not_ empty). -eshapeEmpty :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool) -eshapeEmpty SZ _ = EConst ext STBool False -eshapeEmpty (SS SZ) e = EOp ext (OEq STI64) (EPair ext (ESnd ext e) (EConst ext STI64 0)) -eshapeEmpty (SS n) e = - ELet ext e $ - EOp ext OAnd (EPair ext - (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)) - (EConst ext STI64 0))) - (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)))) |
