aboutsummaryrefslogtreecommitdiff
path: root/src/AST.hs
diff options
context:
space:
mode:
Diffstat (limited to 'src/AST.hs')
-rw-r--r--src/AST.hs453
1 files changed, 0 insertions, 453 deletions
diff --git a/src/AST.hs b/src/AST.hs
deleted file mode 100644
index b8d23b4..0000000
--- a/src/AST.hs
+++ /dev/null
@@ -1,453 +0,0 @@
-{-# LANGUAGE DataKinds #-}
-{-# LANGUAGE DeriveFoldable #-}
-{-# LANGUAGE DeriveFunctor #-}
-{-# LANGUAGE DeriveTraversable #-}
-{-# LANGUAGE EmptyCase #-}
-{-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE GADTs #-}
-{-# LANGUAGE LambdaCase #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE QuantifiedConstraints #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
-{-# LANGUAGE StandaloneKindSignatures #-}
-{-# LANGUAGE TypeApplications #-}
-{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE TypeOperators #-}
-{-# LANGUAGE UndecidableInstances #-}
-module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where
-
-import Data.Functor.Const
-import Data.Kind (Type)
-
-import Array
-import AST.Accum
-import AST.Types
-import AST.Weaken
-import CHAD.Types
-import Data
-
-
--- General assumption: head of the list (whatever way it is associated) is the
--- inner variable / inner array dimension. In pretty printing, the inner
--- variable / inner dimension is printed on the _right_.
---
--- Note that the 'EZero' and 'EPlus' constructs have typing that depend on the
--- type transformation of CHAD. Indeed, these constructors are created _by_
--- CHAD, and are intended to be eliminated after simplification, so that the
--- input program as well as the output program do not contain these
--- constructors.
--- TODO: ensure this by a "stage" type parameter.
-type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type
-data Expr x env t where
- -- lambda calculus
- EVar :: x t -> STy t -> Idx env t -> Expr x env t
- ELet :: x t -> Expr x env a -> Expr x (a : env) t -> Expr x env t
-
- -- base types
- EPair :: x (TPair a b) -> Expr x env a -> Expr x env b -> Expr x env (TPair a b)
- EFst :: x a -> Expr x env (TPair a b) -> Expr x env a
- ESnd :: x b -> Expr x env (TPair a b) -> Expr x env b
- ENil :: x TNil -> Expr x env TNil
- EInl :: x (TEither a b) -> STy b -> Expr x env a -> Expr x env (TEither a b)
- EInr :: x (TEither a b) -> STy a -> Expr x env b -> Expr x env (TEither a b)
- ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c
- ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t)
- EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t)
- EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b
-
- -- array operations
- EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t))
- EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t)
- EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (t : t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t)
- ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
- EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t)
- EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t)
- EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
- EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t))
-
- -- expression operations
- EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t)
- EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t
- EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t)
- EIdx :: x t -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t
- EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx))
- EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t
-
- -- custom derivatives
- -- 'b' is the part of the input of the operation that derivatives should
- -- be backpropagated to; 'a' is the inactive part. The dual field of
- -- ECustom does not allow a derivative to be generated for 'a', and hence
- -- none is propagated.
- ECustom :: x t -> STy a -> STy b -> STy tape
- -> Expr x [b, a] t -- ^ regular operation
- -> Expr x [D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass
- -> Expr x [D2 t, tape] (D2 b) -- ^ CHAD reverse derivative
- -> Expr x env a -> Expr x env b
- -> Expr x env t
-
- -- accumulation effect on monoids
- EWith :: x (TPair a (D2 t)) -> STy t -> Expr x env (D2 t) -> Expr x (TAccum t : env) a -> Expr x env (TPair a (D2 t))
- EAccum :: x TNil -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (TAccum t) -> Expr x env TNil
-
- -- monoidal operations (to be desugared to regular operations after simplification)
- EZero :: x (D2 t) -> STy t -> Expr x env (D2 t)
- EPlus :: x (D2 t) -> STy t -> Expr x env (D2 t) -> Expr x env (D2 t) -> Expr x env (D2 t)
- EOneHot :: x (D2 t) -> STy t -> SAcPrj p t a -> Expr x env (AcIdx p t) -> Expr x env (D2 a) -> Expr x env (D2 t)
-
- -- partiality
- EError :: x a -> STy a -> String -> Expr x env a
-deriving instance (forall ty. Show (x ty)) => Show (Expr x env t)
-
-type Ex = Expr (Const ())
-
-ext :: Const () a
-ext = Const ()
-
-data Commutative = Commut | Noncommut
- deriving (Show)
-
-type SOp :: Ty -> Ty -> Type
-data SOp a t where
- OAdd :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
- OMul :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
- ONeg :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
- OLt :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
- OLe :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
- OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool)
- ONot :: SOp (TScal TBool) (TScal TBool)
- OAnd :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool)
- OOr :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool)
- OIf :: SOp (TScal TBool) (TEither TNil TNil) -- True is Left, False is Right
- ORound64 :: SOp (TScal TF64) (TScal TI64)
- OToFl64 :: SOp (TScal TI64) (TScal TF64)
- ORecip :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
- OExp :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
- OLog :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a)
- OIDiv :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
- OMod :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a)
-deriving instance Show (SOp a t)
-
-opt1 :: SOp a t -> STy a
-opt1 = \case
- OAdd t -> STPair (STScal t) (STScal t)
- OMul t -> STPair (STScal t) (STScal t)
- ONeg t -> STScal t
- OLt t -> STPair (STScal t) (STScal t)
- OLe t -> STPair (STScal t) (STScal t)
- OEq t -> STPair (STScal t) (STScal t)
- ONot -> STScal STBool
- OAnd -> STPair (STScal STBool) (STScal STBool)
- OOr -> STPair (STScal STBool) (STScal STBool)
- OIf -> STScal STBool
- ORound64 -> STScal STF64
- OToFl64 -> STScal STI64
- ORecip t -> STScal t
- OExp t -> STScal t
- OLog t -> STScal t
- OIDiv t -> STPair (STScal t) (STScal t)
- OMod t -> STPair (STScal t) (STScal t)
-
-opt2 :: SOp a t -> STy t
-opt2 = \case
- OAdd t -> STScal t
- OMul t -> STScal t
- ONeg t -> STScal t
- OLt _ -> STScal STBool
- OLe _ -> STScal STBool
- OEq _ -> STScal STBool
- ONot -> STScal STBool
- OAnd -> STScal STBool
- OOr -> STScal STBool
- OIf -> STEither STNil STNil
- ORound64 -> STScal STI64
- OToFl64 -> STScal STF64
- ORecip t -> STScal t
- OExp t -> STScal t
- OLog t -> STScal t
- OIDiv t -> STScal t
- OMod t -> STScal t
-
-typeOf :: Expr x env t -> STy t
-typeOf = \case
- EVar _ t _ -> t
- ELet _ _ e -> typeOf e
-
- EPair _ a b -> STPair (typeOf a) (typeOf b)
- EFst _ e | STPair t _ <- typeOf e -> t
- ESnd _ e | STPair _ t <- typeOf e -> t
- ENil _ -> STNil
- EInl _ t2 e -> STEither (typeOf e) t2
- EInr _ t1 e -> STEither t1 (typeOf e)
- ECase _ _ a _ -> typeOf a
- ENothing _ t -> STMaybe t
- EJust _ e -> STMaybe (typeOf e)
- EMaybe _ e _ _ -> typeOf e
-
- EConstArr _ n t _ -> STArr n (STScal t)
- EBuild _ n _ e -> STArr n (typeOf e)
- EFold1Inner _ _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t
- ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
- EUnit _ e -> STArr SZ (typeOf e)
- EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t
- EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
- EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t
-
- EConst _ t _ -> STScal t
- EIdx0 _ e | STArr _ t <- typeOf e -> t
- EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t
- EIdx _ e _ | STArr _ t <- typeOf e -> t
- EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx)
- EOp _ op _ -> opt2 op
-
- ECustom _ _ _ _ e _ _ _ _ -> typeOf e
-
- EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1)
- EAccum _ _ _ _ _ _ -> STNil
-
- EZero _ t -> d2 t
- EPlus _ t _ _ -> d2 t
- EOneHot _ t _ _ _ -> d2 t
-
- EError _ t _ -> t
-
-extOf :: Expr x env t -> x t
-extOf = \case
- EVar x _ _ -> x
- ELet x _ _ -> x
- EPair x _ _ -> x
- EFst x _ -> x
- ESnd x _ -> x
- ENil x -> x
- EInl x _ _ -> x
- EInr x _ _ -> x
- ECase x _ _ _ -> x
- ENothing x _ -> x
- EJust x _ -> x
- EMaybe x _ _ _ -> x
- EConstArr x _ _ _ -> x
- EBuild x _ _ _ -> x
- EFold1Inner x _ _ _ _ -> x
- ESum1Inner x _ -> x
- EUnit x _ -> x
- EReplicate1Inner x _ _ -> x
- EMaximum1Inner x _ -> x
- EMinimum1Inner x _ -> x
- EConst x _ _ -> x
- EIdx0 x _ -> x
- EIdx1 x _ _ -> x
- EIdx x _ _ -> x
- EShape x _ -> x
- EOp x _ _ -> x
- ECustom x _ _ _ _ _ _ _ _ -> x
- EWith x _ _ _ -> x
- EAccum x _ _ _ _ _ -> x
- EZero x _ -> x
- EPlus x _ _ _ -> x
- EOneHot x _ _ _ _ -> x
- EError x _ _ -> x
-
-mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t
-mapExt f = \case
- EVar x t i -> EVar (f x) t i
- ELet x rhs body -> ELet (f x) (mapExt f rhs) (mapExt f body)
- EPair x a b -> EPair (f x) (mapExt f a) (mapExt f b)
- EFst x e -> EFst (f x) (mapExt f e)
- ESnd x e -> ESnd (f x) (mapExt f e)
- ENil x -> ENil (f x)
- EInl x t e -> EInl (f x) t (mapExt f e)
- EInr x t e -> EInr (f x) t (mapExt f e)
- ECase x e a b -> ECase (f x) (mapExt f e) (mapExt f a) (mapExt f b)
- ENothing x t -> ENothing (f x) t
- EJust x e -> EJust (f x) (mapExt f e)
- EMaybe x a b e -> EMaybe (f x) (mapExt f a) (mapExt f b) (mapExt f e)
- EConstArr x n t a -> EConstArr (f x) n t a
- EBuild x n a b -> EBuild (f x) n (mapExt f a) (mapExt f b)
- EFold1Inner x cm a b c -> EFold1Inner (f x) cm (mapExt f a) (mapExt f b) (mapExt f c)
- ESum1Inner x e -> ESum1Inner (f x) (mapExt f e)
- EUnit x e -> EUnit (f x) (mapExt f e)
- EReplicate1Inner x a b -> EReplicate1Inner (f x) (mapExt f a) (mapExt f b)
- EMaximum1Inner x e -> EMaximum1Inner (f x) (mapExt f e)
- EMinimum1Inner x e -> EMinimum1Inner (f x) (mapExt f e)
- EConst x t v -> EConst (f x) t v
- EIdx0 x e -> EIdx0 (f x) (mapExt f e)
- EIdx1 x a b -> EIdx1 (f x) (mapExt f a) (mapExt f b)
- EIdx x e es -> EIdx (f x) (mapExt f e) (mapExt f es)
- EShape x e -> EShape (f x) (mapExt f e)
- EOp x op e -> EOp (f x) op (mapExt f e)
- ECustom x s t p a b c e1 e2 -> ECustom (f x) s t p (mapExt f a) (mapExt f b) (mapExt f c) (mapExt f e1) (mapExt f e2)
- EWith x t e1 e2 -> EWith (f x) t (mapExt f e1) (mapExt f e2)
- EAccum x t p e1 e2 e3 -> EAccum (f x) t p (mapExt f e1) (mapExt f e2) (mapExt f e3)
- EZero x t -> EZero (f x) t
- EPlus x t a b -> EPlus (f x) t (mapExt f a) (mapExt f b)
- EOneHot x t p a b -> EOneHot (f x) t p (mapExt f a) (mapExt f b)
- EError x t s -> EError (f x) t s
-
-substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t
-substInline repl =
- subst $ \x t -> \case IZ -> repl
- IS i -> EVar x t i
-
-subst0 :: Ex (b : env) a -> Ex (a : env) t -> Ex (b : env) t
-subst0 repl =
- subst $ \_ t -> \case IZ -> repl
- IS i -> EVar ext t (IS i)
-
-subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a)
- -> Expr x env t -> Expr x env' t
-subst f = subst' (\x t w i -> weakenExpr w (f x t i)) WId
-
-subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a)
- -> env' :> envOut
- -> Expr x env t
- -> Expr x envOut t
-subst' f w = \case
- EVar x t i -> f x t w i
- ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body)
- EPair x a b -> EPair x (subst' f w a) (subst' f w b)
- EFst x e -> EFst x (subst' f w e)
- ESnd x e -> ESnd x (subst' f w e)
- ENil x -> ENil x
- EInl x t e -> EInl x t (subst' f w e)
- EInr x t e -> EInr x t (subst' f w e)
- ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b)
- ENothing x t -> ENothing x t
- EJust x e -> EJust x (subst' f w e)
- EMaybe x a b e -> EMaybe x (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' f w e)
- EConstArr x n t a -> EConstArr x n t a
- EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b)
- EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c)
- ESum1Inner x e -> ESum1Inner x (subst' f w e)
- EUnit x e -> EUnit x (subst' f w e)
- EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b)
- EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e)
- EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e)
- EConst x t v -> EConst x t v
- EIdx0 x e -> EIdx0 x (subst' f w e)
- EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b)
- EIdx x e es -> EIdx x (subst' f w e) (subst' f w es)
- EShape x e -> EShape x (subst' f w e)
- EOp x op e -> EOp x op (subst' f w e)
- ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2)
- EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2)
- EAccum x t p e1 e2 e3 -> EAccum x t p (subst' f w e1) (subst' f w e2) (subst' f w e3)
- EZero x t -> EZero x t
- EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b)
- EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b)
- EError x t s -> EError x t s
- where
- sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a)
- -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t
- sinkF f' x' t w' = \case
- IZ -> EVar x' t (w' @> IZ)
- IS i -> f' x' t (WPop w') i
-
-weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t
-weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i))
-
-class KnownScalTy t where knownScalTy :: SScalTy t
-instance KnownScalTy TI32 where knownScalTy = STI32
-instance KnownScalTy TI64 where knownScalTy = STI64
-instance KnownScalTy TF32 where knownScalTy = STF32
-instance KnownScalTy TF64 where knownScalTy = STF64
-instance KnownScalTy TBool where knownScalTy = STBool
-
-class KnownTy t where knownTy :: STy t
-instance KnownTy TNil where knownTy = STNil
-instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy
-instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy
-instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy
-instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy
-instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy
-instance KnownTy t => KnownTy (TAccum t) where knownTy = STAccum knownTy
-
-class KnownEnv env where knownEnv :: SList STy env
-instance KnownEnv '[] where knownEnv = SNil
-instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv
-
-styKnown :: STy t -> Dict (KnownTy t)
-styKnown STNil = Dict
-styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
-styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict
-styKnown (STMaybe t) | Dict <- styKnown t = Dict
-styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict
-styKnown (STScal t) | Dict <- sscaltyKnown t = Dict
-styKnown (STAccum t) | Dict <- styKnown t = Dict
-
-sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t)
-sscaltyKnown STI32 = Dict
-sscaltyKnown STI64 = Dict
-sscaltyKnown STF32 = Dict
-sscaltyKnown STF64 = Dict
-sscaltyKnown STBool = Dict
-
-envKnown :: SList STy env -> Dict (KnownEnv env)
-envKnown SNil = Dict
-envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict
-
-eTup :: SList (Ex env) list -> Ex env (Tup list)
-eTup = mkTup (ENil ext) (EPair ext)
-
-ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t)
-ebuildUp1 n sh size f =
- EBuild ext (SS n) (EPair ext sh size) $
- let arg = EVar ext (tTup (sreplicate (SS n) tIx)) IZ
- in EIdx ext (ELet ext (ESnd ext arg) (weakenExpr (WCopy WSink) f))
- (EFst ext arg)
-
-eidxEq :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool)
-eidxEq SZ _ _ = EConst ext STBool True
-eidxEq (SS SZ) a b =
- EOp ext (OEq STI64) (EPair ext (ESnd ext a) (ESnd ext b))
-eidxEq (SS n) a b
- | let ty = tTup (sreplicate (SS n) tIx)
- = ELet ext a $
- ELet ext (weakenExpr WSink b) $
- EOp ext OAnd $ EPair ext
- (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext ty (IS IZ)))
- (ESnd ext (EVar ext ty IZ))))
- (eidxEq n (EFst ext (EVar ext ty (IS IZ)))
- (EFst ext (EVar ext ty IZ)))
-
-emap :: Ex (a : env) b -> Ex env (TArr n a) -> Ex env (TArr n b)
-emap f arr =
- let STArr n t = typeOf arr
- in ELet ext arr $
- EBuild ext n (EShape ext (EVar ext (STArr n t) IZ)) $
- ELet ext (EIdx ext (EVar ext (STArr n t) (IS IZ))
- (EVar ext (tTup (sreplicate n tIx)) IZ)) $
- weakenExpr (WCopy (WSink .> WSink)) f
-
-ezipWith :: Ex (b : a : env) c -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c)
-ezipWith f arr1 arr2 =
- let STArr n t1 = typeOf arr1
- STArr _ t2 = typeOf arr2
- in ELet ext arr1 $
- ELet ext (weakenExpr WSink arr2) $
- EBuild ext n (EShape ext (EVar ext (STArr n t1) (IS IZ))) $
- ELet ext (EIdx ext (EVar ext (STArr n t1) (IS (IS IZ)))
- (EVar ext (tTup (sreplicate n tIx)) IZ)) $
- ELet ext (EIdx ext (EVar ext (STArr n t2) (IS (IS IZ)))
- (EVar ext (tTup (sreplicate n tIx)) (IS IZ))) $
- weakenExpr (WCopy (WCopy (WSink .> WSink .> WSink))) f
-
-ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b))
-ezip arr1 arr2 =
- let STArr _ t1 = typeOf arr1
- STArr _ t2 = typeOf arr2
- in ezipWith (EPair ext (EVar ext t1 (IS IZ)) (EVar ext t2 IZ)) arr1 arr2
-
-eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a
-eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c)
-
--- | Returns whether the shape is all-zero, but returns False for the zero-dimensional shape (because it is _not_ empty).
-eshapeEmpty :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool)
-eshapeEmpty SZ _ = EConst ext STBool False
-eshapeEmpty (SS SZ) e = EOp ext (OEq STI64) (EPair ext (ESnd ext e) (EConst ext STI64 0))
-eshapeEmpty (SS n) e =
- ELet ext e $
- EOp ext OAnd (EPair ext
- (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ))
- (EConst ext STI64 0)))
- (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ))))