From 174af2ba568de66e0d890825b8bda930b8e7bb96 Mon Sep 17 00:00:00 2001 From: Tom Smeding Date: Mon, 10 Nov 2025 21:49:45 +0100 Subject: Move module hierarchy under CHAD. --- src/AST.hs | 705 ------------ src/AST/Accum.hs | 137 --- src/AST/Bindings.hs | 84 -- src/AST/Count.hs | 930 ---------------- src/AST/Env.hs | 95 -- src/AST/Pretty.hs | 525 --------- src/AST/Sparse.hs | 287 ----- src/AST/Sparse/Types.hs | 107 -- src/AST/SplitLets.hs | 191 ---- src/AST/Types.hs | 215 ---- src/AST/UnMonoid.hs | 255 ----- src/AST/Weaken.hs | 138 --- src/AST/Weaken/Auto.hs | 192 ---- src/Analysis/Identity.hs | 436 -------- src/Array.hs | 131 --- src/CHAD.hs | 1583 --------------------------- src/CHAD/AST.hs | 705 ++++++++++++ src/CHAD/AST/Accum.hs | 137 +++ src/CHAD/AST/Bindings.hs | 84 ++ src/CHAD/AST/Count.hs | 930 ++++++++++++++++ src/CHAD/AST/Env.hs | 95 ++ src/CHAD/AST/Pretty.hs | 525 +++++++++ src/CHAD/AST/Sparse.hs | 287 +++++ src/CHAD/AST/Sparse/Types.hs | 107 ++ src/CHAD/AST/SplitLets.hs | 191 ++++ src/CHAD/AST/Types.hs | 215 ++++ src/CHAD/AST/UnMonoid.hs | 255 +++++ src/CHAD/AST/Weaken.hs | 138 +++ src/CHAD/AST/Weaken/Auto.hs | 192 ++++ src/CHAD/Accum.hs | 72 -- src/CHAD/Analysis/Identity.hs | 436 ++++++++ src/CHAD/Array.hs | 131 +++ src/CHAD/Compile.hs | 1796 +++++++++++++++++++++++++++++++ src/CHAD/Compile/Exec.hs | 99 ++ src/CHAD/Data.hs | 192 ++++ src/CHAD/Data/VarMap.hs | 119 ++ src/CHAD/Drev.hs | 1583 +++++++++++++++++++++++++++ src/CHAD/Drev/Accum.hs | 72 ++ src/CHAD/Drev/EnvDescr.hs | 96 ++ src/CHAD/Drev/Top.hs | 96 ++ src/CHAD/Drev/Types.hs | 153 +++ src/CHAD/Drev/Types/ToTan.hs | 43 + src/CHAD/EnvDescr.hs | 96 -- src/CHAD/Example.hs | 197 ++++ src/CHAD/Example/GMM.hs | 124 +++ src/CHAD/Example/Types.hs | 11 + src/CHAD/ForwardAD.hs | 270 +++++ src/CHAD/ForwardAD/DualNumbers.hs | 231 ++++ src/CHAD/ForwardAD/DualNumbers/Types.hs | 48 + src/CHAD/Interpreter.hs | 471 ++++++++ src/CHAD/Interpreter/Accum.hs | 366 +++++++ src/CHAD/Interpreter/AccumOld.hs | 366 +++++++ src/CHAD/Interpreter/Rep.hs | 105 ++ src/CHAD/Language.hs | 266 +++++ src/CHAD/Language/AST.hs | 300 ++++++ src/CHAD/Lemmas.hs | 21 + src/CHAD/Simplify.hs | 619 +++++++++++ src/CHAD/Simplify/TH.hs | 80 ++ src/CHAD/Top.hs | 96 -- src/CHAD/Types.hs | 153 --- src/CHAD/Types/ToTan.hs | 43 - src/CHAD/Util/IdGen.hs | 19 + src/Compile.hs | 1796 ------------------------------- src/Compile/Exec.hs | 99 -- src/Data.hs | 192 ---- src/Data/VarMap.hs | 119 -- src/Example.hs | 196 ---- src/Example/GMM.hs | 123 --- src/Example/Types.hs | 11 - src/ForwardAD.hs | 270 ----- src/ForwardAD/DualNumbers.hs | 231 ---- src/ForwardAD/DualNumbers/Types.hs | 48 - src/Interpreter.hs | 471 -------- src/Interpreter/Accum.hs | 366 ------- src/Interpreter/AccumOld.hs | 366 ------- src/Interpreter/Rep.hs | 105 -- src/Language.hs | 267 ----- src/Language/AST.hs | 300 ------ src/Lemmas.hs | 21 - src/Simplify.hs | 619 ----------- src/Simplify/TH.hs | 80 -- src/Util/IdGen.hs | 19 - 82 files changed, 12171 insertions(+), 12170 deletions(-) delete mode 100644 src/AST.hs delete mode 100644 src/AST/Accum.hs delete mode 100644 src/AST/Bindings.hs delete mode 100644 src/AST/Count.hs delete mode 100644 src/AST/Env.hs delete mode 100644 src/AST/Pretty.hs delete mode 100644 src/AST/Sparse.hs delete mode 100644 src/AST/Sparse/Types.hs delete mode 100644 src/AST/SplitLets.hs delete mode 100644 src/AST/Types.hs delete mode 100644 src/AST/UnMonoid.hs delete mode 100644 src/AST/Weaken.hs delete mode 100644 src/AST/Weaken/Auto.hs delete mode 100644 src/Analysis/Identity.hs delete mode 100644 src/Array.hs delete mode 100644 src/CHAD.hs create mode 100644 src/CHAD/AST.hs create mode 100644 src/CHAD/AST/Accum.hs create mode 100644 src/CHAD/AST/Bindings.hs create mode 100644 src/CHAD/AST/Count.hs create mode 100644 src/CHAD/AST/Env.hs create mode 100644 src/CHAD/AST/Pretty.hs create mode 100644 src/CHAD/AST/Sparse.hs create mode 100644 src/CHAD/AST/Sparse/Types.hs create mode 100644 src/CHAD/AST/SplitLets.hs create mode 100644 src/CHAD/AST/Types.hs create mode 100644 src/CHAD/AST/UnMonoid.hs create mode 100644 src/CHAD/AST/Weaken.hs create mode 100644 src/CHAD/AST/Weaken/Auto.hs delete mode 100644 src/CHAD/Accum.hs create mode 100644 src/CHAD/Analysis/Identity.hs create mode 100644 src/CHAD/Array.hs create mode 100644 src/CHAD/Compile.hs create mode 100644 src/CHAD/Compile/Exec.hs create mode 100644 src/CHAD/Data.hs create mode 100644 src/CHAD/Data/VarMap.hs create mode 100644 src/CHAD/Drev.hs create mode 100644 src/CHAD/Drev/Accum.hs create mode 100644 src/CHAD/Drev/EnvDescr.hs create mode 100644 src/CHAD/Drev/Top.hs create mode 100644 src/CHAD/Drev/Types.hs create mode 100644 src/CHAD/Drev/Types/ToTan.hs delete mode 100644 src/CHAD/EnvDescr.hs create mode 100644 src/CHAD/Example.hs create mode 100644 src/CHAD/Example/GMM.hs create mode 100644 src/CHAD/Example/Types.hs create mode 100644 src/CHAD/ForwardAD.hs create mode 100644 src/CHAD/ForwardAD/DualNumbers.hs create mode 100644 src/CHAD/ForwardAD/DualNumbers/Types.hs create mode 100644 src/CHAD/Interpreter.hs create mode 100644 src/CHAD/Interpreter/Accum.hs create mode 100644 src/CHAD/Interpreter/AccumOld.hs create mode 100644 src/CHAD/Interpreter/Rep.hs create mode 100644 src/CHAD/Language.hs create mode 100644 src/CHAD/Language/AST.hs create mode 100644 src/CHAD/Lemmas.hs create mode 100644 src/CHAD/Simplify.hs create mode 100644 src/CHAD/Simplify/TH.hs delete mode 100644 src/CHAD/Top.hs delete mode 100644 src/CHAD/Types.hs delete mode 100644 src/CHAD/Types/ToTan.hs create mode 100644 src/CHAD/Util/IdGen.hs delete mode 100644 src/Compile.hs delete mode 100644 src/Compile/Exec.hs delete mode 100644 src/Data.hs delete mode 100644 src/Data/VarMap.hs delete mode 100644 src/Example.hs delete mode 100644 src/Example/GMM.hs delete mode 100644 src/Example/Types.hs delete mode 100644 src/ForwardAD.hs delete mode 100644 src/ForwardAD/DualNumbers.hs delete mode 100644 src/ForwardAD/DualNumbers/Types.hs delete mode 100644 src/Interpreter.hs delete mode 100644 src/Interpreter/Accum.hs delete mode 100644 src/Interpreter/AccumOld.hs delete mode 100644 src/Interpreter/Rep.hs delete mode 100644 src/Language.hs delete mode 100644 src/Language/AST.hs delete mode 100644 src/Lemmas.hs delete mode 100644 src/Simplify.hs delete mode 100644 src/Simplify/TH.hs delete mode 100644 src/Util/IdGen.hs (limited to 'src') diff --git a/src/AST.hs b/src/AST.hs deleted file mode 100644 index ca6cdd1..0000000 --- a/src/AST.hs +++ /dev/null @@ -1,705 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveFoldable #-} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE DeriveTraversable #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ImpredicativeTypes #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -module AST (module AST, module AST.Types, module AST.Accum, module AST.Weaken) where - -import Data.Functor.Const -import Data.Functor.Identity -import Data.Int (Int64) -import Data.Kind (Type) - -import Array -import AST.Accum -import AST.Sparse.Types -import AST.Types -import AST.Weaken -import CHAD.Types -import Data - - --- General assumption: head of the list (whatever way it is associated) is the --- inner variable / inner array dimension. In pretty printing, the inner --- variable / inner dimension is printed on the _right_. --- --- All the monoid operations are unsupposed as the input to CHAD, and are --- intended to be eliminated after simplification, so that the input program as --- well as the output program do not contain these constructors. --- TODO: ensure this by a "stage" type parameter. -type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type -data Expr x env t where - -- lambda calculus - EVar :: x t -> STy t -> Idx env t -> Expr x env t - ELet :: x t -> Expr x env a -> Expr x (a : env) t -> Expr x env t - - -- base types - EPair :: x (TPair a b) -> Expr x env a -> Expr x env b -> Expr x env (TPair a b) - EFst :: x a -> Expr x env (TPair a b) -> Expr x env a - ESnd :: x b -> Expr x env (TPair a b) -> Expr x env b - ENil :: x TNil -> Expr x env TNil - EInl :: x (TEither a b) -> STy b -> Expr x env a -> Expr x env (TEither a b) - EInr :: x (TEither a b) -> STy a -> Expr x env b -> Expr x env (TEither a b) - ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c - ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t) - EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t) - EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b - - -- array operations - EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t)) - EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t) - EMap :: x (TArr n t) -> Expr x (a : env) t -> Expr x env (TArr n a) -> Expr x env (TArr n t) - -- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right) - EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (TPair t t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) - ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) - EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) - EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) - EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) - EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) - EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t) - EZip :: x (TArr n (TPair a b)) -> Expr x env (TArr n a) -> Expr x env (TArr n b) -> Expr x env (TArr n (TPair a b)) - - -- Primal of EFold1Inner. Looks like a mapAccumL, but differs semantically: - -- an implementation is allowed to parallelise this thing and store the b - -- values in some implementation-defined order. - -- TODO: For a parallel implementation some data will probably need to be stored about the reduction order in addition to simply the array of bs. - EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative - -> Expr x (TPair t1 t1 : env) (TPair t1 b) - -> Expr x env t1 - -> Expr x env (TArr (S n) t1) - -> Expr x env (TPair (TArr n t1) -- normal primal fold output - (TArr (S n) b)) -- additional stores; usually: (prescanl, the tape stores) - -- Reverse derivative of EFold1Inner. The contributions to the initial - -- element are not yet added together here; we assume a later fusion system - -- does that for us. - EFold1InnerD2 :: x (TPair (TArr n t2) (TArr (S n) t2)) -> Commutative - -> Expr x (t2 : b : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation) - -> Expr x env (TArr (S n) b) -- stores from EFold1InnerD1 - -> Expr x env (TArr n t2) -- incoming cotangent - -> Expr x env (TPair (TArr n t2) (TArr (S n) t2)) -- outgoing cotangents to x0 (not summed) and input array - - -- expression operations - EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t) - EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t - EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) - EIdx :: x t -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t - EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) - EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t - - -- custom derivatives - -- 'b' is the part of the input of the operation that derivatives should - -- be backpropagated to; 'a' is the inactive part. The dual field of - -- ECustom does not allow a derivative to be generated for 'a', and hence - -- none is propagated. - -- No accumulators are allowed inside a, b and tape. This restriction is - -- currently not used very much, so could be relaxed in the future; be sure - -- to check this requirement whenever it is necessary for soundness! - ECustom :: x t -> STy a -> STy b -> STy tape - -> Expr x [b, a] t -- ^ regular operation - -> Expr x [D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass - -> Expr x [D2 t, tape] (D2 b) -- ^ CHAD reverse derivative - -> Expr x env a -> Expr x env b - -> Expr x env t - - -- fake halfway checkpointing - ERecompute :: x t -> Expr x env t -> Expr x env t - - -- accumulation effect on monoids - -- | The initialiser for an accumulator __MUST__ be deep! If it is zero, it - -- must be EDeepZero, not just EZero. This is to ensure that EAccum does not - -- need to create any zeros. - EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t) - -- The 'Sparse' here is eliminated to dense by UnMonoid. - EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Sparse a b -> Expr x env b -> Expr x env (TAccum t) -> Expr x env TNil - - -- monoidal operations (to be desugared to regular operations after simplification) - EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t - EDeepZero :: x t -> SMTy t -> Expr x env (DeepZeroInfo t) -> Expr x env t - EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t - EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t - - -- interface of abstract monoidal types - ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b) - ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b) - ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b) - ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c - - -- partiality - EError :: x a -> STy a -> String -> Expr x env a -deriving instance (forall ty. Show (x ty)) => Show (Expr x env t) - -type Ex = Expr (Const ()) - -ext :: Const () a -ext = Const () - -data Commutative = Commut | Noncommut - deriving (Show) - -type SOp :: Ty -> Ty -> Type -data SOp a t where - OAdd :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) - OMul :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) - ONeg :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TScal a) (TScal a) - OLt :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) - OLe :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) - OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) - ONot :: SOp (TScal TBool) (TScal TBool) - OAnd :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool) - OOr :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool) - OIf :: SOp (TScal TBool) (TEither TNil TNil) -- True is Left, False is Right - ORound64 :: SOp (TScal TF64) (TScal TI64) - OToFl64 :: SOp (TScal TI64) (TScal TF64) - ORecip :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) - OExp :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) - OLog :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) - OIDiv :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) - OMod :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) -deriving instance Show (SOp a t) - -opt1 :: SOp a t -> STy a -opt1 = \case - OAdd t -> STPair (STScal t) (STScal t) - OMul t -> STPair (STScal t) (STScal t) - ONeg t -> STScal t - OLt t -> STPair (STScal t) (STScal t) - OLe t -> STPair (STScal t) (STScal t) - OEq t -> STPair (STScal t) (STScal t) - ONot -> STScal STBool - OAnd -> STPair (STScal STBool) (STScal STBool) - OOr -> STPair (STScal STBool) (STScal STBool) - OIf -> STScal STBool - ORound64 -> STScal STF64 - OToFl64 -> STScal STI64 - ORecip t -> STScal t - OExp t -> STScal t - OLog t -> STScal t - OIDiv t -> STPair (STScal t) (STScal t) - OMod t -> STPair (STScal t) (STScal t) - -opt2 :: SOp a t -> STy t -opt2 = \case - OAdd t -> STScal t - OMul t -> STScal t - ONeg t -> STScal t - OLt _ -> STScal STBool - OLe _ -> STScal STBool - OEq _ -> STScal STBool - ONot -> STScal STBool - OAnd -> STScal STBool - OOr -> STScal STBool - OIf -> STEither STNil STNil - ORound64 -> STScal STI64 - OToFl64 -> STScal STF64 - ORecip t -> STScal t - OExp t -> STScal t - OLog t -> STScal t - OIDiv t -> STScal t - OMod t -> STScal t - -typeOf :: Expr x env t -> STy t -typeOf = \case - EVar _ t _ -> t - ELet _ _ e -> typeOf e - - EPair _ a b -> STPair (typeOf a) (typeOf b) - EFst _ e | STPair t _ <- typeOf e -> t - ESnd _ e | STPair _ t <- typeOf e -> t - ENil _ -> STNil - EInl _ t2 e -> STEither (typeOf e) t2 - EInr _ t1 e -> STEither t1 (typeOf e) - ECase _ _ a _ -> typeOf a - ENothing _ t -> STMaybe t - EJust _ e -> STMaybe (typeOf e) - EMaybe _ e _ _ -> typeOf e - ELNil _ t1 t2 -> STLEither t1 t2 - ELInl _ t2 e -> STLEither (typeOf e) t2 - ELInr _ t1 e -> STLEither t1 (typeOf e) - ELCase _ _ a _ _ -> typeOf a - - EConstArr _ n t _ -> STArr n (STScal t) - EBuild _ n _ e -> STArr n (typeOf e) - EMap _ a b | STArr n _ <- typeOf b -> STArr n (typeOf a) - EFold1Inner _ _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t - ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t - EUnit _ e -> STArr SZ (typeOf e) - EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t - EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t - EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t - EReshape _ n _ e | STArr _ t <- typeOf e -> STArr n t - EZip _ a b | STArr n t1 <- typeOf a, STArr _ t2 <- typeOf b -> STArr n (STPair t1 t2) - - EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb) - EFold1InnerD2 _ _ _ _ e3 | STArr n t2 <- typeOf e3 -> STPair (STArr n t2) (STArr (SS n) t2) - - EConst _ t _ -> STScal t - EIdx0 _ e | STArr _ t <- typeOf e -> t - EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t - EIdx _ e _ | STArr _ t <- typeOf e -> t - EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx) - EOp _ op _ -> opt2 op - - ECustom _ _ _ _ e _ _ _ _ -> typeOf e - ERecompute _ e -> typeOf e - - EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1) - EAccum _ _ _ _ _ _ _ -> STNil - - EZero _ t _ -> fromSMTy t - EDeepZero _ t _ -> fromSMTy t - EPlus _ t _ _ -> fromSMTy t - EOneHot _ t _ _ _ -> fromSMTy t - - EError _ t _ -> t - -extOf :: Expr x env t -> x t -extOf = \case - EVar x _ _ -> x - ELet x _ _ -> x - EPair x _ _ -> x - EFst x _ -> x - ESnd x _ -> x - ENil x -> x - EInl x _ _ -> x - EInr x _ _ -> x - ECase x _ _ _ -> x - ENothing x _ -> x - EJust x _ -> x - EMaybe x _ _ _ -> x - ELNil x _ _ -> x - ELInl x _ _ -> x - ELInr x _ _ -> x - ELCase x _ _ _ _ -> x - EConstArr x _ _ _ -> x - EBuild x _ _ _ -> x - EMap x _ _ -> x - EFold1Inner x _ _ _ _ -> x - ESum1Inner x _ -> x - EUnit x _ -> x - EReplicate1Inner x _ _ -> x - EMaximum1Inner x _ -> x - EMinimum1Inner x _ -> x - EReshape x _ _ _ -> x - EZip x _ _ -> x - EFold1InnerD1 x _ _ _ _ -> x - EFold1InnerD2 x _ _ _ _ -> x - EConst x _ _ -> x - EIdx0 x _ -> x - EIdx1 x _ _ -> x - EIdx x _ _ -> x - EShape x _ -> x - EOp x _ _ -> x - ECustom x _ _ _ _ _ _ _ _ -> x - ERecompute x _ -> x - EWith x _ _ _ -> x - EAccum x _ _ _ _ _ _ -> x - EZero x _ _ -> x - EDeepZero x _ _ -> x - EPlus x _ _ _ -> x - EOneHot x _ _ _ _ -> x - EError x _ _ -> x - -mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t -mapExt f = runIdentity . travExt (Identity . f) - -{-# SPECIALIZE travExt :: (forall a. x a -> Identity (x' a)) -> Expr x env t -> Identity (Expr x' env t) #-} -travExt :: Applicative f => (forall a. x a -> f (x' a)) -> Expr x env t -> f (Expr x' env t) -travExt f = \case - EVar x t i -> EVar <$> f x <*> pure t <*> pure i - ELet x rhs body -> ELet <$> f x <*> travExt f rhs <*> travExt f body - EPair x a b -> EPair <$> f x <*> travExt f a <*> travExt f b - EFst x e -> EFst <$> f x <*> travExt f e - ESnd x e -> ESnd <$> f x <*> travExt f e - ENil x -> ENil <$> f x - EInl x t e -> EInl <$> f x <*> pure t <*> travExt f e - EInr x t e -> EInr <$> f x <*> pure t <*> travExt f e - ECase x e a b -> ECase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b - ENothing x t -> ENothing <$> f x <*> pure t - EJust x e -> EJust <$> f x <*> travExt f e - EMaybe x a b e -> EMaybe <$> f x <*> travExt f a <*> travExt f b <*> travExt f e - ELNil x t1 t2 -> ELNil <$> f x <*> pure t1 <*> pure t2 - ELInl x t e -> ELInl <$> f x <*> pure t <*> travExt f e - ELInr x t e -> ELInr <$> f x <*> pure t <*> travExt f e - ELCase x e a b c -> ELCase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b <*> travExt f c - EConstArr x n t a -> EConstArr <$> f x <*> pure n <*> pure t <*> pure a - EBuild x n a b -> EBuild <$> f x <*> pure n <*> travExt f a <*> travExt f b - EMap x a b -> EMap <$> f x <*> travExt f a <*> travExt f b - EFold1Inner x cm a b c -> EFold1Inner <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c - ESum1Inner x e -> ESum1Inner <$> f x <*> travExt f e - EUnit x e -> EUnit <$> f x <*> travExt f e - EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b - EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e - EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e - EZip x a b -> EZip <$> f x <*> travExt f a <*> travExt f b - EReshape x n a b -> EReshape <$> f x <*> pure n <*> travExt f a <*> travExt f b - EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c - EFold1InnerD2 x cm a b c -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c - EConst x t v -> EConst <$> f x <*> pure t <*> pure v - EIdx0 x e -> EIdx0 <$> f x <*> travExt f e - EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b - EIdx x e es -> EIdx <$> f x <*> travExt f e <*> travExt f es - EShape x e -> EShape <$> f x <*> travExt f e - EOp x op e -> EOp <$> f x <*> pure op <*> travExt f e - ECustom x s t p a b c e1 e2 -> ECustom <$> f x <*> pure s <*> pure t <*> pure p <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f e1 <*> travExt f e2 - ERecompute x e -> ERecompute <$> f x <*> travExt f e - EWith x t e1 e2 -> EWith <$> f x <*> pure t <*> travExt f e1 <*> travExt f e2 - EAccum x t p e1 sp e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> pure sp <*> travExt f e2 <*> travExt f e3 - EZero x t e -> EZero <$> f x <*> pure t <*> travExt f e - EDeepZero x t e -> EDeepZero <$> f x <*> pure t <*> travExt f e - EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b - EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b - EError x t s -> EError <$> f x <*> pure t <*> pure s - -substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t -substInline repl = - subst $ \x t -> \case IZ -> repl - IS i -> EVar x t i - -subst0 :: Ex (b : env) a -> Ex (a : env) t -> Ex (b : env) t -subst0 repl = - subst $ \_ t -> \case IZ -> repl - IS i -> EVar ext t (IS i) - -subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a) - -> Expr x env t -> Expr x env' t -subst f = subst' (\x t w i -> weakenExpr w (f x t i)) WId - -subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a) - -> env' :> envOut - -> Expr x env t - -> Expr x envOut t -subst' f w = \case - EVar x t i -> f x t w i - ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body) - EPair x a b -> EPair x (subst' f w a) (subst' f w b) - EFst x e -> EFst x (subst' f w e) - ESnd x e -> ESnd x (subst' f w e) - ENil x -> ENil x - EInl x t e -> EInl x t (subst' f w e) - EInr x t e -> EInr x t (subst' f w e) - ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b) - ENothing x t -> ENothing x t - EJust x e -> EJust x (subst' f w e) - EMaybe x a b e -> EMaybe x (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' f w e) - ELNil x t1 t2 -> ELNil x t1 t2 - ELInl x t e -> ELInl x t (subst' f w e) - ELInr x t e -> ELInr x t (subst' f w e) - ELCase x e a b c -> ELCase x (subst' f w e) (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' (sinkF f) (WCopy w) c) - EConstArr x n t a -> EConstArr x n t a - EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b) - EMap x a b -> EMap x (subst' (sinkF f) (WCopy w) a) (subst' f w b) - EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c) - ESum1Inner x e -> ESum1Inner x (subst' f w e) - EUnit x e -> EUnit x (subst' f w e) - EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b) - EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e) - EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e) - EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b) - EZip x a b -> EZip x (subst' f w a) (subst' f w b) - EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c) - EFold1InnerD2 x cm a b c -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) - EConst x t v -> EConst x t v - EIdx0 x e -> EIdx0 x (subst' f w e) - EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) - EIdx x e es -> EIdx x (subst' f w e) (subst' f w es) - EShape x e -> EShape x (subst' f w e) - EOp x op e -> EOp x op (subst' f w e) - ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2) - ERecompute x e -> ERecompute x (subst' f w e) - EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) - EAccum x t p e1 sp e2 e3 -> EAccum x t p (subst' f w e1) sp (subst' f w e2) (subst' f w e3) - EZero x t e -> EZero x t (subst' f w e) - EDeepZero x t e -> EDeepZero x t (subst' f w e) - EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b) - EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b) - EError x t s -> EError x t s - where - sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) - -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t - sinkF f' x' t w' = \case - IZ -> EVar x' t (w' @> IZ) - IS i -> f' x' t (WPop w') i - -weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t -weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i)) - -class KnownScalTy t where knownScalTy :: SScalTy t -instance KnownScalTy TI32 where knownScalTy = STI32 -instance KnownScalTy TI64 where knownScalTy = STI64 -instance KnownScalTy TF32 where knownScalTy = STF32 -instance KnownScalTy TF64 where knownScalTy = STF64 -instance KnownScalTy TBool where knownScalTy = STBool - -class KnownTy t where knownTy :: STy t -instance KnownTy TNil where knownTy = STNil -instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy -instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy -instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy -instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy -instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy -instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy -instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy - -class KnownMTy t where knownMTy :: SMTy t -instance KnownMTy TNil where knownMTy = SMTNil -instance (KnownMTy s, KnownMTy t) => KnownMTy (TPair s t) where knownMTy = SMTPair knownMTy knownMTy -instance KnownMTy t => KnownMTy (TMaybe t) where knownMTy = SMTMaybe knownMTy -instance (KnownMTy s, KnownMTy t) => KnownMTy (TLEither s t) where knownMTy = SMTLEither knownMTy knownMTy -instance (KnownNat n, KnownMTy t) => KnownMTy (TArr n t) where knownMTy = SMTArr knownNat knownMTy -instance (KnownScalTy t, ScalIsNumeric t ~ True) => KnownMTy (TScal t) where knownMTy = SMTScal knownScalTy - -class KnownEnv env where knownEnv :: SList STy env -instance KnownEnv '[] where knownEnv = SNil -instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv - -styKnown :: STy t -> Dict (KnownTy t) -styKnown STNil = Dict -styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict -styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict -styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict -styKnown (STMaybe t) | Dict <- styKnown t = Dict -styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict -styKnown (STScal t) | Dict <- sscaltyKnown t = Dict -styKnown (STAccum t) | Dict <- smtyKnown t = Dict - -smtyKnown :: SMTy t -> Dict (KnownMTy t) -smtyKnown SMTNil = Dict -smtyKnown (SMTPair a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict -smtyKnown (SMTLEither a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict -smtyKnown (SMTMaybe t) | Dict <- smtyKnown t = Dict -smtyKnown (SMTArr n t) | Dict <- snatKnown n, Dict <- smtyKnown t = Dict -smtyKnown (SMTScal t) | Dict <- sscaltyKnown t = Dict - -sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t) -sscaltyKnown STI32 = Dict -sscaltyKnown STI64 = Dict -sscaltyKnown STF32 = Dict -sscaltyKnown STF64 = Dict -sscaltyKnown STBool = Dict - -envKnown :: SList STy env -> Dict (KnownEnv env) -envKnown SNil = Dict -envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict - -cheapExpr :: Expr x env t -> Bool -cheapExpr = \case - EVar{} -> True - ENil{} -> True - EConst{} -> True - EFst _ e -> cheapExpr e - ESnd _ e -> cheapExpr e - EUnit _ e -> cheapExpr e - _ -> False - -eTup :: SList (Ex env) list -> Ex env (Tup list) -eTup = mkTup (ENil ext) (EPair ext) - -ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t) -ebuildUp1 n sh size f = - EBuild ext (SS n) (EPair ext sh size) $ - let arg = EVar ext (tTup (sreplicate (SS n) tIx)) IZ - in EIdx ext (ELet ext (ESnd ext arg) (weakenExpr (WCopy WSink) f)) - (EFst ext arg) - -eidxEq :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool) -eidxEq SZ _ _ = EConst ext STBool True -eidxEq (SS SZ) a b = - EOp ext (OEq STI64) (EPair ext (ESnd ext a) (ESnd ext b)) -eidxEq (SS n) a b - | let ty = tTup (sreplicate (SS n) tIx) - = ELet ext a $ - ELet ext (weakenExpr WSink b) $ - EOp ext OAnd $ EPair ext - (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext ty (IS IZ))) - (ESnd ext (EVar ext ty IZ)))) - (eidxEq n (EFst ext (EVar ext ty (IS IZ))) - (EFst ext (EVar ext ty IZ))) - -emap :: (KnownTy a => Ex (a : env) b) -> Ex env (TArr n a) -> Ex env (TArr n b) -emap f arr - | STArr _ t <- typeOf arr - , Dict <- styKnown t - = EMap ext f arr - -ezipWith :: ((KnownTy a, KnownTy b) => Ex (b : a : env) c) -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c) -ezipWith f arr1 arr2 - | STArr _ t1 <- typeOf arr1 - , STArr _ t2 <- typeOf arr2 - , Dict <- styKnown t1 - , Dict <- styKnown t2 - = EMap ext (subst (\_ t -> \case IZ -> ESnd ext (EVar ext (STPair t1 t2) IZ) - IS IZ -> EFst ext (EVar ext (STPair t1 t2) IZ) - IS (IS i) -> EVar ext t (IS i)) - f) - (EZip ext arr1 arr2) - -ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) -ezip = EZip ext - -eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a -eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c) - --- | Returns whether the shape is all-zero, but returns False for the zero-dimensional shape (because it is _not_ empty). -eshapeEmpty :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool) -eshapeEmpty SZ _ = EConst ext STBool False -eshapeEmpty (SS SZ) e = EOp ext (OEq STI64) (EPair ext (ESnd ext e) (EConst ext STI64 0)) -eshapeEmpty (SS n) e = - ELet ext e $ - EOp ext OAnd (EPair ext - (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)) - (EConst ext STI64 0))) - (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)))) - -eshapeConst :: Shape n -> Ex env (Tup (Replicate n TIx)) -eshapeConst ShNil = ENil ext -eshapeConst (sh `ShCons` n) = EPair ext (eshapeConst sh) (EConst ext STI64 (fromIntegral @Int @Int64 n)) - -eshapeProd :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -eshapeProd SZ _ = EConst ext STI64 1 -eshapeProd (SS SZ) e = ESnd ext e -eshapeProd (SS n) e = - eunPair e $ \_ e1 e2 -> - EOp ext (OMul STI64) (EPair ext (eshapeProd n e1) e2) - -eflatten :: Ex env (TArr n t) -> Ex env (TArr N1 t) -eflatten e = - let STArr n _ = typeOf e - in elet e $ - EReshape ext (SS SZ) (EPair ext (ENil ext) (eshapeProd n (EShape ext (evar IZ)))) (evar IZ) - --- ezeroD2 :: STy t -> Ex env (ZeroInfo (D2 t)) -> Ex env (D2 t) --- ezeroD2 t ezi = EZero ext (d2M t) ezi - --- eaccumD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (TAccum (D2 t)) -> Ex env TNil --- eaccumD2 t p ei ev ea | Refl <- lemZeroInfoD2 t = EAccum ext (d2M t) (ENil ext) p ei ev ea - --- eonehotD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (D2 t) --- eonehotD2 t p ei ev | Refl <- lemZeroInfoD2 t = EOneHot ext (d2M t) (ENil ext) p ei ev - -eunPair :: Ex env (TPair a b) -> (forall env'. env :> env' -> Ex env' a -> Ex env' b -> Ex env' r) -> Ex env r -eunPair (EPair _ e1 e2) k = k WId e1 e2 -eunPair e k | cheapExpr e = k WId (EFst ext e) (ESnd ext e) -eunPair e k = - elet e $ - k WSink - (EFst ext (evar IZ)) - (ESnd ext (evar IZ)) - -efst :: Ex env (TPair a b) -> Ex env a -efst (EPair _ e1 _) = e1 -efst e = EFst ext e - -esnd :: Ex env (TPair a b) -> Ex env b -esnd (EPair _ _ e2) = e2 -esnd e = ESnd ext e - -elet :: Ex env a -> (KnownTy a => Ex (a : env) b) -> Ex env b -elet rhs body - | Dict <- styKnown (typeOf rhs) - = if cheapExpr rhs - then substInline rhs body - else ELet ext rhs body - --- | Let-bind it but don't use the value (just ensure the expression's effects don't get lost) -use :: Ex env a -> Ex env b -> Ex env b -use a b = elet a $ weakenExpr WSink b - -emaybe :: Ex env (TMaybe a) -> Ex env b -> (KnownTy a => Ex (a : env) b) -> Ex env b -emaybe e a b - | STMaybe t <- typeOf e - , Dict <- styKnown t - = EMaybe ext a b e - -ecase :: Ex env (TEither a b) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c -ecase e a b - | STEither t1 t2 <- typeOf e - , Dict <- styKnown t1 - , Dict <- styKnown t2 - = ECase ext e a b - -elcase :: Ex env (TLEither a b) -> ((KnownTy a, KnownTy b) => Ex env c) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c -elcase e a b c - | STLEither t1 t2 <- typeOf e - , Dict <- styKnown t1 - , Dict <- styKnown t2 - = ELCase ext e a b c - -evar :: KnownTy a => Idx env a -> Ex env a -evar = EVar ext knownTy - -makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) -makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ) - where - -- invariant: expression argument is duplicable - go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) - go SMTNil _ = ENil ext - go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e)) - go SMTLEither{} _ = ENil ext - go SMTMaybe{} _ = ENil ext - go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e - go SMTScal{} _ = ENil ext - -splitSparsePair - :: -- given a sparsity - STy (TPair a b) -> Sparse (TPair a b) t' - -> (forall a' b'. - -- I give you back two sparsities for a and b - Sparse a a' -> Sparse b b' - -- furthermore, I tell you that either your t' is already this (a', b') pair... - -> Either - (t' :~: TPair a' b') - -- or I tell you how to construct a' and b' from t', given an actual t' - (forall r' env. - Idx env t' - -> (forall env'. - (forall c. Ex env' c -> Ex env c) - -> Ex env' a' -> Ex env' b' -> r') - -> r') - -> r) - -> r -splitSparsePair _ SpAbsent k = - k SpAbsent SpAbsent $ Right $ \_ k2 -> - k2 id (ENil ext) (ENil ext) -splitSparsePair _ (SpPair s1 s2) k1 = - k1 s1 s2 $ Left Refl -splitSparsePair t@(STPair t1 t2) (SpSparse s@(SpPair s1 s2)) k = - let t' = STPair (STMaybe (applySparse s1 t1)) (STMaybe (applySparse s2 t2)) in - k (SpSparse s1) (SpSparse s2) $ Right $ \i k2 -> - k2 (elet $ - emaybe (EVar ext (STMaybe (applySparse s t)) i) - (EPair ext (ENothing ext (applySparse s1 t1)) (ENothing ext (applySparse s2 t2))) - (EPair ext (EJust ext (EFst ext (evar IZ))) (EJust ext (ESnd ext (evar IZ))))) - (EFst ext (EVar ext t' IZ)) (ESnd ext (EVar ext t' IZ)) - -splitSparsePair _ (SpSparse SpAbsent) k = - k SpAbsent SpAbsent $ Right $ \_ k2 -> - k2 id (ENil ext) (ENil ext) --- -- TODO: having to handle sparse-of-sparse at all is ridiculous -splitSparsePair t (SpSparse (SpSparse s)) k = - splitSparsePair t (SpSparse s) $ \s1 s2 eres -> - k s1 s2 $ Right $ \i k2 -> - case eres of - Left refl -> case refl of {} - Right f -> - f IZ $ \wrap e1 e2 -> - k2 (\body -> - elet (emaybe (EVar ext (STMaybe (STMaybe (applySparse s t))) i) - (ENothing ext (applySparse s t)) - (evar IZ)) $ - wrap body) - e1 e2 diff --git a/src/AST/Accum.hs b/src/AST/Accum.hs deleted file mode 100644 index 988a450..0000000 --- a/src/AST/Accum.hs +++ /dev/null @@ -1,137 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeData #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE UndecidableInstances #-} -module AST.Accum where - -import AST.Types -import Data - - -data AcPrj - = APHere - | APFst AcPrj - | APSnd AcPrj - | APLeft AcPrj - | APRight AcPrj - | APJust AcPrj - | APArrIdx AcPrj - | APArrSlice Nat - --- | @b@ is a small part of @a@, indicated by the projection @p@. -data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where - SAPHere :: SAcPrj APHere a a - SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair a t) b - SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair t a) b - SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TLEither a t) b - SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TLEither t a) b - SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe a) b - SAPArrIdx :: SAcPrj p a b -> SAcPrj (APArrIdx p) (TArr n a) b - -- TODO: - -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t) -deriving instance Show (SAcPrj p a b) - -type data AIDense = AID | AIS - -data SAIDense d where - SAID :: SAIDense AID - SAIS :: SAIDense AIS -deriving instance Show (SAIDense d) - -type family AcIdx d p t where - AcIdx d APHere t = TNil - AcIdx AID (APFst p) (TPair a b) = AcIdx AID p a - AcIdx AID (APSnd p) (TPair a b) = AcIdx AID p b - AcIdx AIS (APFst p) (TPair a b) = TPair (AcIdx AIS p a) (ZeroInfo b) - AcIdx AIS (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AIS p b) - AcIdx d (APLeft p) (TLEither a b) = AcIdx d p a - AcIdx d (APRight p) (TLEither a b) = AcIdx d p b - AcIdx d (APJust p) (TMaybe a) = AcIdx d p a - AcIdx AID (APArrIdx p) (TArr n a) = - -- (index, recursive info) - TPair (Tup (Replicate n TIx)) (AcIdx AID p a) - AcIdx AIS (APArrIdx p) (TArr n a) = - -- ((index, shape info), recursive info) - TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a))) - (AcIdx AIS p a) - -- AcIdx AID (APArrSlice m) (TArr n a) = - -- -- index - -- Tup (Replicate m TIx) - -- AcIdx AIS (APArrSlice m) (TArr n a) = - -- -- (index, array shape) - -- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx)) - -type AcIdxD p t = AcIdx AID p t -type AcIdxS p t = AcIdx AIS p t - -acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b -acPrjTy SAPHere t = t -acPrjTy (SAPFst prj) (SMTPair t _) = acPrjTy prj t -acPrjTy (SAPSnd prj) (SMTPair _ t) = acPrjTy prj t -acPrjTy (SAPLeft prj) (SMTLEither t _) = acPrjTy prj t -acPrjTy (SAPRight prj) (SMTLEither _ t) = acPrjTy prj t -acPrjTy (SAPJust prj) (SMTMaybe t) = acPrjTy prj t -acPrjTy (SAPArrIdx prj) (SMTArr _ t) = acPrjTy prj t - -type family ZeroInfo t where - ZeroInfo TNil = TNil - ZeroInfo (TPair a b) = TPair (ZeroInfo a) (ZeroInfo b) - ZeroInfo (TLEither a b) = TNil - ZeroInfo (TMaybe a) = TNil - ZeroInfo (TArr n t) = TArr n (ZeroInfo t) - ZeroInfo (TScal t) = TNil - -tZeroInfo :: SMTy t -> STy (ZeroInfo t) -tZeroInfo SMTNil = STNil -tZeroInfo (SMTPair a b) = STPair (tZeroInfo a) (tZeroInfo b) -tZeroInfo (SMTLEither _ _) = STNil -tZeroInfo (SMTMaybe _) = STNil -tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t) -tZeroInfo (SMTScal _) = STNil - --- | Info needed to create a zero-valued deep accumulator for a monoid type. --- Should be constructable from a D1. -type family DeepZeroInfo t where - DeepZeroInfo TNil = TNil - DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b) - DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b) - DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a) - DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a) - DeepZeroInfo (TScal t) = TNil - -tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t) -tDeepZeroInfo SMTNil = STNil -tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b) -tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b) -tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a) -tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t) -tDeepZeroInfo (SMTScal _) = STNil - --- -- | Additional info needed for accumulation. This is empty unless there is --- -- sparsity in the monoid. --- type family AccumInfo t where --- AccumInfo TNil = TNil --- AccumInfo (TPair a b) = TPair (AccumInfo a) (AccumInfo b) --- AccumInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b) --- AccumInfo (TMaybe a) = TMaybe (AccumInfo a) --- AccumInfo (TArr n t) = TArr n (AccumInfo t) --- AccumInfo (TScal t) = TNil - --- type family PrimalInfo t where --- PrimalInfo TNil = TNil --- PrimalInfo (TPair a b) = TPair (PrimalInfo a) (PrimalInfo b) --- PrimalInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b) --- PrimalInfo (TMaybe a) = TMaybe (PrimalInfo a) --- PrimalInfo (TArr n t) = TArr n (PrimalInfo t) --- PrimalInfo (TScal t) = TNil - --- tPrimalInfo :: SMTy t -> STy (PrimalInfo t) --- tPrimalInfo SMTNil = STNil --- tPrimalInfo (SMTPair a b) = STPair (tPrimalInfo a) (tPrimalInfo b) --- tPrimalInfo (SMTLEither a b) = STLEither (tPrimalInfo a) (tPrimalInfo b) --- tPrimalInfo (SMTMaybe a) = STMaybe (tPrimalInfo a) --- tPrimalInfo (SMTArr n t) = STArr n (tPrimalInfo t) --- tPrimalInfo (SMTScal _) = STNil diff --git a/src/AST/Bindings.hs b/src/AST/Bindings.hs deleted file mode 100644 index 463586a..0000000 --- a/src/AST/Bindings.hs +++ /dev/null @@ -1,84 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} - --- I want to bring various type variables in scope using type annotations in --- patterns, but I don't want to have to mention all the other type parameters --- of the types in question as well then. Partial type signatures (with '_') are --- useful here. -{-# LANGUAGE PartialTypeSignatures #-} -{-# OPTIONS -Wno-partial-type-signatures #-} -module AST.Bindings where - -import AST -import AST.Env -import Data -import Lemmas - - --- binding lists: a let stack without a body. The stack lives in 'env' and defines 'binds'. -data Bindings f env binds where - BTop :: Bindings f env '[] - BPush :: Bindings f env binds -> (STy t, f (Append binds env) t) -> Bindings f env (t : binds) -deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env') -infixl `BPush` - -bpush :: Bindings (Expr x) env binds -> Expr x (Append binds env) t -> Bindings (Expr x) env (t : binds) -bpush b e = b `BPush` (typeOf e, e) -infixl `bpush` - -mapBindings :: (forall env' t'. f env' t' -> g env' t') - -> Bindings f env binds -> Bindings g env binds -mapBindings _ BTop = BTop -mapBindings f (BPush b (t, e)) = BPush (mapBindings f b) (t, f e) - -weakenBindings :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t) - -> env1 :> env2 - -> Bindings f env1 binds - -> (Bindings f env2 binds, Append binds env1 :> Append binds env2) -weakenBindings _ w BTop = (BTop, w) -weakenBindings wf w (BPush b (t, x)) = - let (b', w') = weakenBindings wf w b - in (BPush b' (t, wf w' x), WCopy w') - -weakenBindingsE :: env1 :> env2 - -> Bindings (Expr x) env1 binds - -> (Bindings (Expr x) env2 binds, Append binds env1 :> Append binds env2) -weakenBindingsE = weakenBindings weakenExpr - -weakenOver :: SList STy ts -> env :> env' -> Append ts env :> Append ts env' -weakenOver SNil w = w -weakenOver (SCons _ ts) w = WCopy (weakenOver ts w) - -sinkWithBindings :: forall env' env binds f. Bindings f env binds -> env' :> Append binds env' -sinkWithBindings BTop = WId -sinkWithBindings (BPush b _) = WSink .> sinkWithBindings b - -bconcat :: forall f env binds1 binds2. Bindings f env binds1 -> Bindings f (Append binds1 env) binds2 -> Bindings f env (Append binds2 binds1) -bconcat b1 BTop = b1 -bconcat b1 (BPush (b2 :: Bindings _ (Append binds1 env) binds2C) (t, x)) - | Refl <- lemAppendAssoc @binds2C @binds1 @env - = BPush (bconcat b1 b2) (t, x) - -bindingsBinds :: Bindings f env binds -> SList STy binds -bindingsBinds BTop = SNil -bindingsBinds (BPush binds (t, _)) = SCons t (bindingsBinds binds) - -letBinds :: Bindings Ex env binds -> Ex (Append binds env) t -> Ex env t -letBinds BTop = id -letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs - -collectBindings :: SList STy env -> Subenv env env' -> Bindings Ex env env' -collectBindings = \env -> fst . go env WId - where - go :: SList STy env -> env :> env0 -> Subenv env env' -> (Bindings Ex env0 env', env0 :> Append env' env0) - go _ _ SETop = (BTop, WId) - go (ty `SCons` env) w (SEYesR sub) = - let (bs, w') = go env (WPop w) sub - in (BPush bs (ty, EVar ext ty (w' .> w @> IZ)), WSink .> w') - go (_ `SCons` env) w (SENo sub) = go env (WPop w) sub diff --git a/src/AST/Count.hs b/src/AST/Count.hs deleted file mode 100644 index ac8634e..0000000 --- a/src/AST/Count.hs +++ /dev/null @@ -1,930 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE ViewPatterns #-} -{-# LANGUAGE PatternSynonyms #-} -module AST.Count where - -import Data.Functor.Product -import Data.Some -import Data.Type.Equality -import GHC.Generics (Generic, Generically(..)) - -import Array -import AST -import AST.Env -import Data - - --- | The monoid operation combines assuming that /both/ branches are taken. -class Monoid a => Occurrence a where - -- | One of the two branches is taken - (<||>) :: a -> a -> a - -- | This code is executed many times - scaleMany :: a -> a - - -data Count = Zero | One | Many - deriving (Show, Eq, Ord) - -instance Semigroup Count where - Zero <> n = n - n <> Zero = n - _ <> _ = Many -instance Monoid Count where - mempty = Zero -instance Occurrence Count where - (<||>) = max - scaleMany Zero = Zero - scaleMany _ = Many - -data Occ = Occ { _occLexical :: Count - , _occRuntime :: Count } - deriving (Eq, Generic) - deriving (Semigroup, Monoid) via Generically Occ - -instance Show Occ where - showsPrec d (Occ l r) = showParen (d > 10) $ - showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r - -instance Occurrence Occ where - Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (r1 <||> r2) - scaleMany (Occ l c) = Occ l (scaleMany c) - - -data Substruc t t' where - -- If you add constructors here, do not forget to update the COMPLETE pragmas of any pattern synonyms below - SsFull :: Substruc t t - SsNone :: Substruc t TNil - SsPair :: Substruc a a' -> Substruc b b' -> Substruc (TPair a b) (TPair a' b') - SsEither :: Substruc a a' -> Substruc b b' -> Substruc (TEither a b) (TEither a' b') - SsLEither :: Substruc a a' -> Substruc b b' -> Substruc (TLEither a b) (TLEither a' b') - SsMaybe :: Substruc a a' -> Substruc (TMaybe a) (TMaybe a') - SsArr :: Substruc a a' -> Substruc (TArr n a) (TArr n a') -- ^ union of usages of all array elements - SsAccum :: Substruc a a' -> Substruc (TAccum a) (TAccum a') - -pattern SsPair' :: forall a b t'. forall a' b'. t' ~ TPair a' b' => Substruc a a' -> Substruc b b' -> Substruc (TPair a b) t' -pattern SsPair' s1 s2 <- ((\case { SsFull -> SsPair SsFull SsFull ; s -> s }) -> SsPair s1 s2) - where SsPair' = SsPair -{-# COMPLETE SsNone, SsPair', SsEither, SsLEither, SsMaybe, SsArr, SsAccum #-} - -pattern SsArr' :: forall n a t'. forall a'. t' ~ TArr n a' => Substruc a a' -> Substruc (TArr n a) t' -pattern SsArr' s <- ((\case { SsFull -> SsArr SsFull ; s -> s }) -> SsArr s) - where SsArr' = SsArr -{-# COMPLETE SsNone, SsPair, SsEither, SsLEither, SsMaybe, SsArr', SsAccum #-} - -instance Semigroup (Some (Substruc t)) where - Some SsFull <> _ = Some SsFull - _ <> Some SsFull = Some SsFull - Some SsNone <> s = s - s <> Some SsNone = s - Some (SsPair a b) <> Some (SsPair a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsPair a2 b2) - Some (SsEither a b) <> Some (SsEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsEither a2 b2) - Some (SsLEither a b) <> Some (SsLEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsLEither a2 b2) - Some (SsMaybe a) <> Some (SsMaybe a') = withSome (Some a <> Some a') $ \a2 -> Some (SsMaybe a2) - Some (SsArr a) <> Some (SsArr a') = withSome (Some a <> Some a') $ \a2 -> Some (SsArr a2) - Some (SsAccum a) <> Some (SsAccum a') = withSome (Some a <> Some a') $ \a2 -> Some (SsAccum a2) -instance Monoid (Some (Substruc t)) where - mempty = Some SsNone - -instance TestEquality (Substruc t) where - testEquality SsFull s = isFull s - testEquality s SsFull = sym <$> isFull s - testEquality SsNone SsNone = Just Refl - testEquality SsNone _ = Nothing - testEquality _ SsNone = Nothing - testEquality (SsPair a b) (SsPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing - testEquality (SsEither a b) (SsEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing - testEquality (SsLEither a b) (SsLEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing - testEquality (SsMaybe s) (SsMaybe s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing - testEquality (SsArr s) (SsArr s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing - testEquality (SsAccum s) (SsAccum s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing - -isFull :: Substruc t t' -> Maybe (t :~: t') -isFull SsFull = Just Refl -isFull SsNone = Nothing -- TODO: nil? -isFull (SsPair a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing -isFull (SsEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing -isFull (SsLEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing -isFull (SsMaybe s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing -isFull (SsArr s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing -isFull (SsAccum s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing - -applySubstruc :: Substruc t t' -> STy t -> STy t' -applySubstruc SsFull t = t -applySubstruc SsNone _ = STNil -applySubstruc (SsPair s1 s2) (STPair a b) = STPair (applySubstruc s1 a) (applySubstruc s2 b) -applySubstruc (SsEither s1 s2) (STEither a b) = STEither (applySubstruc s1 a) (applySubstruc s2 b) -applySubstruc (SsLEither s1 s2) (STLEither a b) = STLEither (applySubstruc s1 a) (applySubstruc s2 b) -applySubstruc (SsMaybe s) (STMaybe t) = STMaybe (applySubstruc s t) -applySubstruc (SsArr s) (STArr n t) = STArr n (applySubstruc s t) -applySubstruc (SsAccum s) (STAccum t) = STAccum (applySubstrucM s t) - -applySubstrucM :: Substruc t t' -> SMTy t -> SMTy t' -applySubstrucM SsFull t = t -applySubstrucM SsNone _ = SMTNil -applySubstrucM (SsPair s1 s2) (SMTPair a b) = SMTPair (applySubstrucM s1 a) (applySubstrucM s2 b) -applySubstrucM (SsLEither s1 s2) (SMTLEither a b) = SMTLEither (applySubstrucM s1 a) (applySubstrucM s2 b) -applySubstrucM (SsMaybe s) (SMTMaybe t) = SMTMaybe (applySubstrucM s t) -applySubstrucM (SsArr s) (SMTArr n t) = SMTArr n (applySubstrucM s t) -applySubstrucM _ t = case t of {} - -data ExMap a b = ExMap (forall env. Ex env a -> Ex env b) - | a ~ b => ExMapId - -fromExMap :: ExMap a b -> Ex env a -> Ex env b -fromExMap (ExMap f) = f -fromExMap ExMapId = id - -simplifySubstruc :: STy t -> Substruc t t' -> Substruc t t' -simplifySubstruc STNil SsNone = SsFull - -simplifySubstruc _ SsFull = SsFull -simplifySubstruc _ SsNone = SsNone -simplifySubstruc (STPair t1 t2) (SsPair s1 s2) = SsPair (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) -simplifySubstruc (STEither t1 t2) (SsEither s1 s2) = SsEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) -simplifySubstruc (STLEither t1 t2) (SsLEither s1 s2) = SsLEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) -simplifySubstruc (STMaybe t) (SsMaybe s) = SsMaybe (simplifySubstruc t s) -simplifySubstruc (STArr _ t) (SsArr s) = SsArr (simplifySubstruc t s) -simplifySubstruc (STAccum t) (SsAccum s) = SsAccum (simplifySubstruc (fromSMTy t) s) - --- simplifySubstruc' :: Substruc t t' --- -> (forall t'2. Substruc t t'2 -> ExMap t'2 t' -> r) -> r --- simplifySubstruc' SsFull k = k SsFull ExMapId --- simplifySubstruc' SsNone k = k SsNone ExMapId --- simplifySubstruc' (SsPair s1 s2) k = --- simplifySubstruc' s1 $ \s1' f1 -> --- simplifySubstruc' s2 $ \s2' f2 -> --- case (s1', s2') of --- (SsFull, SsFull) -> --- k SsFull (case (f1, f2) of --- (ExMapId, ExMapId) -> ExMapId --- _ -> ExMap (\e -> eunPair e $ \_ e1 e2 -> --- EPair ext (fromExMap f1 e1) (fromExMap f2 e2))) --- (SsNone, SsNone) -> k SsNone (ExMap (\_ -> EPair ext (fromExMap f1 (ENil ext)) (fromExMap f2 (ENil ext)))) --- _ -> k (SsPair s1' s2') (ExMap (\e -> elet e $ EPair ext (fromExMap f1 (EFst ext (evar IZ))) (fromExMap f2 (ESnd ext (evar IZ))))) --- simplifySubstruc' _ _ = _ - --- ssUnpair :: Substruc (TPair a b) -> (Substruc a, Substruc b) --- ssUnpair SsFull = (SsFull, SsFull) --- ssUnpair SsNone = (SsNone, SsNone) --- ssUnpair (SsPair a b) = (a, b) - --- ssUnleft :: Substruc (TEither a b) -> Substruc a --- ssUnleft SsFull = SsFull --- ssUnleft SsNone = SsNone --- ssUnleft (SsEither a _) = a - --- ssUnright :: Substruc (TEither a b) -> Substruc b --- ssUnright SsFull = SsFull --- ssUnright SsNone = SsNone --- ssUnright (SsEither _ b) = b - --- ssUnlleft :: Substruc (TLEither a b) -> Substruc a --- ssUnlleft SsFull = SsFull --- ssUnlleft SsNone = SsNone --- ssUnlleft (SsLEither a _) = a - --- ssUnlright :: Substruc (TLEither a b) -> Substruc b --- ssUnlright SsFull = SsFull --- ssUnlright SsNone = SsNone --- ssUnlright (SsLEither _ b) = b - --- ssUnjust :: Substruc (TMaybe a) -> Substruc a --- ssUnjust SsFull = SsFull --- ssUnjust SsNone = SsNone --- ssUnjust (SsMaybe a) = a - --- ssUnarr :: Substruc (TArr n a) -> Substruc a --- ssUnarr SsFull = SsFull --- ssUnarr SsNone = SsNone --- ssUnarr (SsArr a) = a - --- ssUnaccum :: Substruc (TAccum a) -> Substruc a --- ssUnaccum SsFull = SsFull --- ssUnaccum SsNone = SsNone --- ssUnaccum (SsAccum a) = a - - -type family MapEmpty env where - MapEmpty '[] = '[] - MapEmpty (t : env) = TNil : MapEmpty env - -data OccEnv a env env' where - OccEnd :: OccEnv a env (MapEmpty env) -- not necessarily top! - OccPush :: OccEnv a env env' -> a -> Substruc t t' -> OccEnv a (t : env) (t' : env') - -instance Semigroup a => Semigroup (Some (OccEnv a env)) where - Some OccEnd <> e = e - e <> Some OccEnd = e - Some (OccPush e o s) <> Some (OccPush e' o' s') = withSome (Some e <> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <> o') s2) - -instance Semigroup a => Monoid (Some (OccEnv a env)) where - mempty = Some OccEnd - -instance Occurrence a => Occurrence (Some (OccEnv a env)) where - Some OccEnd <||> e = e - e <||> Some OccEnd = e - Some (OccPush e o s) <||> Some (OccPush e' o' s') = withSome (Some e <||> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <||> o') s2) - - scaleMany (Some OccEnd) = Some OccEnd - scaleMany (Some (OccPush e o s)) = withSome (scaleMany (Some e)) $ \e2 -> Some (OccPush e2 (scaleMany o) s) - -onehotOccEnv :: Monoid a => Idx env t -> a -> Substruc t t' -> Some (OccEnv a env) -onehotOccEnv IZ v s = Some (OccPush OccEnd v s) -onehotOccEnv (IS i) v s - | Some env' <- onehotOccEnv i v s - = Some (OccPush env' mempty SsNone) - -occEnvPop :: OccEnv a (t : env) (t' : env') -> (OccEnv a env env', Substruc t t') -occEnvPop (OccPush e _ s) = (e, s) -occEnvPop OccEnd = (OccEnd, SsNone) - -occEnvPop' :: OccEnv a (t : env) env' -> (forall t' env''. env' ~ t' : env'' => OccEnv a env env'' -> Substruc t t' -> r) -> r -occEnvPop' (OccPush e _ s) k = k e s -occEnvPop' OccEnd k = k OccEnd SsNone - -occEnvPopSome :: Some (OccEnv a (t : env)) -> Some (OccEnv a env) -occEnvPopSome (Some (OccPush e _ _)) = Some e -occEnvPopSome (Some OccEnd) = Some OccEnd - -occEnvPrj :: Monoid a => OccEnv a env env' -> Idx env t -> (a, Some (Substruc t)) -occEnvPrj OccEnd _ = mempty -occEnvPrj (OccPush _ o s) IZ = (o, Some s) -occEnvPrj (OccPush e _ _) (IS i) = occEnvPrj e i - -occEnvPrjS :: OccEnv a env env' -> Idx env t -> Some (Product (Substruc t) (Idx env')) -occEnvPrjS OccEnd IZ = Some (Pair SsNone IZ) -occEnvPrjS OccEnd (IS i) | Some (Pair s i') <- occEnvPrjS OccEnd i = Some (Pair s (IS i')) -occEnvPrjS (OccPush _ _ s) IZ = Some (Pair s IZ) -occEnvPrjS (OccPush e _ _) (IS i) - | Some (Pair s' i') <- occEnvPrjS e i - = Some (Pair s' (IS i')) - -projectSmallerSubstruc :: Substruc t t'big -> Substruc t t'small -> Ex env t'big -> Ex env t'small -projectSmallerSubstruc topsbig topssmall ex = case (topsbig, topssmall) of - _ | Just Refl <- testEquality topsbig topssmall -> ex - - (SsFull, SsFull) -> ex - (SsNone, SsNone) -> ex - (SsNone, _) -> error "projectSmallerSubstruc: smaller substructure not smaller" - (_, SsNone) -> - case typeOf ex of - STNil -> ex - _ -> use ex $ ENil ext - - (SsPair s1 s2, SsPair s1' s2') -> - eunPair ex $ \_ e1 e2 -> - EPair ext (projectSmallerSubstruc s1 s1' e1) (projectSmallerSubstruc s2 s2' e2) - (s@SsPair{}, SsFull) -> projectSmallerSubstruc s (SsPair SsFull SsFull) ex - (SsFull, s@SsPair{}) -> projectSmallerSubstruc (SsPair SsFull SsFull) s ex - - (SsEither s1 s2, SsEither s1' s2') - | STEither t1 t2 <- typeOf ex -> - let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) - e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ) - in ecase ex - (EInl ext (typeOf e2) e1) - (EInr ext (typeOf e1) e2) - (s@SsEither{}, SsFull) -> projectSmallerSubstruc s (SsEither SsFull SsFull) ex - (SsFull, s@SsEither{}) -> projectSmallerSubstruc (SsEither SsFull SsFull) s ex - - (SsLEither s1 s2, SsLEither s1' s2') - | STLEither t1 t2 <- typeOf ex -> - let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) - e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ) - in elcase ex - (ELNil ext (typeOf e1) (typeOf e2)) - (ELInl ext (typeOf e2) e1) - (ELInr ext (typeOf e1) e2) - (s@SsLEither{}, SsFull) -> projectSmallerSubstruc s (SsLEither SsFull SsFull) ex - (SsFull, s@SsLEither{}) -> projectSmallerSubstruc (SsLEither SsFull SsFull) s ex - - (SsMaybe s1, SsMaybe s1') - | STMaybe t1 <- typeOf ex -> - let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) - in emaybe ex - (ENothing ext (typeOf e1)) - (EJust ext e1) - (s@SsMaybe{}, SsFull) -> projectSmallerSubstruc s (SsMaybe SsFull) ex - (SsFull, s@SsMaybe{}) -> projectSmallerSubstruc (SsMaybe SsFull) s ex - - (SsArr s1, SsArr s2) -> emap (projectSmallerSubstruc s1 s2 (evar IZ)) ex - (s@SsArr{}, SsFull) -> projectSmallerSubstruc s (SsArr SsFull) ex - (SsFull, s@SsArr{}) -> projectSmallerSubstruc (SsArr SsFull) s ex - - (SsAccum _, SsAccum _) -> error "TODO smaller ssaccum" - (s@SsAccum{}, SsFull) -> projectSmallerSubstruc s (SsAccum SsFull) ex - (SsFull, s@SsAccum{}) -> projectSmallerSubstruc (SsAccum SsFull) s ex - - --- | A boolean for each entry in the environment, with the ability to uniformly --- mask the top part above a certain index. -data EnvMask env where - EMRest :: Bool -> EnvMask env - EMPush :: EnvMask env -> Bool -> EnvMask (t : env) - -envMaskPrj :: EnvMask env -> Idx env t -> Bool -envMaskPrj (EMRest b) _ = b -envMaskPrj (_ `EMPush` b) IZ = b -envMaskPrj (env `EMPush` _) (IS i) = envMaskPrj env i - -occCount :: Idx env a -> Expr x env t -> Occ -occCount idx ex - | Some env <- occCountAll ex - = fst (occEnvPrj env idx) - -occCountAll :: Expr x env t -> Some (OccEnv Occ env) -occCountAll ex = occCountX SsFull ex $ \env _ -> Some env - -pruneExpr :: SList f env -> Expr x env t -> Ex env t -pruneExpr env ex = occCountX SsFull ex $ \_ mkex -> mkex (fullOccEnv env) - where - fullOccEnv :: SList f env -> OccEnv () env env - fullOccEnv SNil = OccEnd - fullOccEnv (_ `SCons` e) = OccPush (fullOccEnv e) () SsFull - --- In one traversal, count occurrences of variables and determine what parts of --- expressions are actually used. These two results are computed independently: --- even if (almost) nothing of a particular term is actually used, variable --- references in that term still count as usual. --- --- In @occCountX s t k@: --- * s: how much of the result of this term is required --- * t: the term to analyse --- * k: is passed the actual environment usage of this expression, including --- occurrence counts. The callback reconstructs a new expression in an --- updated "response" environment. The response must be at least as large as --- the computed usages. -occCountX :: forall env t t' x r. Substruc t t' -> Expr x env t - -> (forall env'. OccEnv Occ env env' - -- response OccEnv must be at least as large as the OccEnv returned above - -> (forall env''. OccEnv () env env'' -> Ex env'' t') - -> r) - -> r -occCountX initialS topexpr k = case topexpr of - EVar _ t i -> - withSome (onehotOccEnv i (Occ One One) s) $ \env -> - k env $ \env' -> - withSome (occEnvPrjS env' i) $ \(Pair s' i') -> - projectSmallerSubstruc s' s (EVar ext (applySubstruc s' t) i') - ELet _ rhs body -> - occCountX s body $ \envB mkbody -> - occEnvPop' envB $ \envB' s1 -> - occCountX s1 rhs $ \envR mkrhs -> - withSome (Some envB' <> Some envR) $ \env -> - k env $ \env' -> - ELet ext (mkrhs env') (mkbody (OccPush env' () s1)) - EPair _ a b -> - case s of - SsNone -> - occCountX SsNone a $ \env1 mka -> - occCountX SsNone b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mka env') $ use (mkb env') $ ENil ext - SsPair' s1 s2 -> - occCountX s1 a $ \env1 mka -> - occCountX s2 b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - EPair ext (mka env') (mkb env') - EFst _ e -> - occCountX (SsPair s SsNone) e $ \env1 mke -> - k env1 $ \env' -> - EFst ext (mke env') - ESnd _ e -> - occCountX (SsPair SsNone s) e $ \env1 mke -> - k env1 $ \env' -> - ESnd ext (mke env') - ENil _ -> - case s of - SsFull -> k OccEnd (\_ -> ENil ext) - SsNone -> k OccEnd (\_ -> ENil ext) - EInl _ t e -> - case s of - SsNone -> - occCountX SsNone e $ \env1 mke -> - k env1 $ \env' -> - use (mke env') $ ENil ext - SsEither s1 s2 -> - occCountX s1 e $ \env1 mke -> - k env1 $ \env' -> - EInl ext (applySubstruc s2 t) (mke env') - SsFull -> occCountX (SsEither SsFull SsFull) topexpr k - EInr _ t e -> - case s of - SsNone -> - occCountX SsNone e $ \env1 mke -> - k env1 $ \env' -> - use (mke env') $ ENil ext - SsEither s1 s2 -> - occCountX s2 e $ \env1 mke -> - k env1 $ \env' -> - EInr ext (applySubstruc s1 t) (mke env') - SsFull -> occCountX (SsEither SsFull SsFull) topexpr k - ECase _ e a b -> - occCountX s a $ \env1' mka -> - occCountX s b $ \env2' mkb -> - occEnvPop' env1' $ \env1 s1 -> - occEnvPop' env2' $ \env2 s2 -> - occCountX (SsEither s1 s2) e $ \env0 mke -> - withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env -> - k env $ \env' -> - ECase ext (mke env') (mka (OccPush env' () s1)) (mkb (OccPush env' () s2)) - ENothing _ t -> - case s of - SsNone -> k OccEnd (\_ -> ENil ext) - SsMaybe s' -> k OccEnd (\_ -> ENothing ext (applySubstruc s' t)) - SsFull -> occCountX (SsMaybe SsFull) topexpr k - EJust _ e -> - case s of - SsNone -> - occCountX SsNone e $ \env1 mke -> - k env1 $ \env' -> - use (mke env') $ ENil ext - SsMaybe s' -> - occCountX s' e $ \env1 mke -> - k env1 $ \env' -> - EJust ext (mke env') - SsFull -> occCountX (SsMaybe SsFull) topexpr k - EMaybe _ a b e -> - occCountX s a $ \env1 mka -> - occCountX s b $ \env2' mkb -> - occEnvPop' env2' $ \env2 s2 -> - occCountX (SsMaybe s2) e $ \env0 mke -> - withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env -> - k env $ \env' -> - EMaybe ext (mka env') (mkb (OccPush env' () s2)) (mke env') - ELNil _ t1 t2 -> - case s of - SsNone -> k OccEnd (\_ -> ENil ext) - SsLEither s1 s2 -> k OccEnd (\_ -> ELNil ext (applySubstruc s1 t1) (applySubstruc s2 t2)) - SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k - ELInl _ t e -> - case s of - SsNone -> - occCountX SsNone e $ \env1 mke -> - k env1 $ \env' -> - use (mke env') $ ENil ext - SsLEither s1 s2 -> - occCountX s1 e $ \env1 mke -> - k env1 $ \env' -> - ELInl ext (applySubstruc s2 t) (mke env') - SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k - ELInr _ t e -> - case s of - SsNone -> - occCountX SsNone e $ \env1 mke -> - k env1 $ \env' -> - use (mke env') $ ENil ext - SsLEither s1 s2 -> - occCountX s2 e $ \env1 mke -> - k env1 $ \env' -> - ELInr ext (applySubstruc s1 t) (mke env') - SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k - ELCase _ e a b c -> - occCountX s a $ \env1 mka -> - occCountX s b $ \env2' mkb -> - occCountX s c $ \env3' mkc -> - occEnvPop' env2' $ \env2 s1 -> - occEnvPop' env3' $ \env3 s2 -> - occCountX (SsLEither s1 s2) e $ \env0 mke -> - withSome (Some env0 <> (Some env1 <||> Some env2 <||> Some env3)) $ \env -> - k env $ \env' -> - ELCase ext (mke env') (mka env') (mkb (OccPush env' () s1)) (mkc (OccPush env' () s2)) - - EConstArr _ n t x -> - case s of - SsNone -> k OccEnd (\_ -> ENil ext) - SsArr' SsNone -> k OccEnd (\_ -> EBuild ext n (eshapeConst (arrayShape x)) (ENil ext)) - SsArr' SsFull -> k OccEnd (\_ -> EConstArr ext n t x) - - EBuild _ n a b -> - case s of - SsNone -> - occCountX SsFull a $ \env1 mka -> - occCountX SsNone b $ \env2'' mkb -> - occEnvPop' env2'' $ \env2' s2 -> - withSome (Some env1 <> scaleMany (Some env2')) $ \env -> - k env $ \env' -> - use (EBuild ext n (mka env') $ - use (elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ - weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))) $ - ENil ext) $ - ENil ext - SsArr' s' -> - occCountX SsFull a $ \env1 mka -> - occCountX s' b $ \env2'' mkb -> - occEnvPop' env2'' $ \env2' s2 -> - withSome (Some env1 <> scaleMany (Some env2')) $ \env -> - k env $ \env' -> - EBuild ext n (mka env') $ - elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ - weakenExpr (WCopy WSink) (mkb (OccPush env' () s2)) - - EMap _ a b -> - case s of - SsNone -> - occCountX SsNone a $ \env1'' mka -> - occEnvPop' env1'' $ \env1' s1 -> - occCountX (SsArr s1) b $ \env2 mkb -> - withSome (scaleMany (Some env1') <> Some env2) $ \env -> - k env $ \env' -> - use (EMap ext (mka (OccPush env' () s1)) (mkb env')) $ - ENil ext - SsArr' s' -> - occCountX s' a $ \env1'' mka -> - occEnvPop' env1'' $ \env1' s1 -> - occCountX (SsArr s1) b $ \env2 mkb -> - withSome (scaleMany (Some env1') <> Some env2) $ \env -> - k env $ \env' -> - EMap ext (mka (OccPush env' () s1)) (mkb env') - - EFold1Inner _ commut a b c -> - occCountX SsFull a $ \env1'' mka -> - occEnvPop' env1'' $ \env1' s1' -> - let s1 = case s1' of - SsNone -> Some SsNone - SsPair' s1'a s1'b -> Some s1'a <> Some s1'b - s0 = case s of - SsNone -> Some SsNone - SsArr' s' -> Some s' in - withSome (s1 <> s0) $ \sElt -> - occCountX sElt b $ \env2 mkb -> - occCountX (SsArr sElt) c $ \env3 mkc -> - withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> - k env $ \env' -> - projectSmallerSubstruc (SsArr sElt) s $ - EFold1Inner ext commut - (projectSmallerSubstruc SsFull sElt $ - mka (OccPush env' () (SsPair sElt sElt))) - (mkb env') (mkc env') - - ESum1Inner _ e -> handleReduction (ESum1Inner ext) e - - EUnit _ e -> - case s of - SsNone -> - occCountX SsNone e $ \env mke -> - k env $ \env' -> - use (mke env') $ ENil ext - SsArr' s' -> - occCountX s' e $ \env mke -> - k env $ \env' -> - EUnit ext (mke env') - - EReplicate1Inner _ a b -> - case s of - SsNone -> - occCountX SsNone a $ \env1 mka -> - occCountX SsNone b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mka env') $ use (mkb env') $ ENil ext - SsArr' s' -> - occCountX SsFull a $ \env1 mka -> - occCountX (SsArr s') b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - EReplicate1Inner ext (mka env') (mkb env') - - EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e - EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e - - EReshape _ n esh e -> - case s of - SsNone -> - occCountX SsNone esh $ \env1 mkesh -> - occCountX SsNone e $ \env2 mke -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mkesh env') $ use (mke env') $ ENil ext - SsArr' s' -> - occCountX SsFull esh $ \env1 mkesh -> - occCountX (SsArr s') e $ \env2 mke -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - EReshape ext n (mkesh env') (mke env') - - EZip _ a b -> - case s of - SsNone -> - occCountX SsNone a $ \env1 mka -> - occCountX SsNone b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mka env') $ use (mkb env') $ ENil ext - SsArr' SsNone -> - occCountX (SsArr SsNone) a $ \env1 mka -> - occCountX SsNone b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mkb env') $ mka env' - SsArr' (SsPair' SsNone s2) -> - occCountX SsNone a $ \env1 mka -> - occCountX (SsArr s2) b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mka env') $ - emap (EPair ext (ENil ext) (evar IZ)) (mkb env') - SsArr' (SsPair' s1 SsNone) -> - occCountX (SsArr s1) a $ \env1 mka -> - occCountX SsNone b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mkb env') $ - emap (EPair ext (evar IZ) (ENil ext)) (mka env') - SsArr' (SsPair' s1 s2) -> - occCountX (SsArr s1) a $ \env1 mka -> - occCountX (SsArr s2) b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - EZip ext (mka env') (mkb env') - - EFold1InnerD1 _ cm e1 e2 e3 -> - case s of - -- If nothing is necessary, we can execute a fold and then proceed to ignore it - SsNone -> - let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1)) - (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3) - in occCountX SsNone foldex $ \env1 mkfoldex -> k env1 mkfoldex - -- If we don't need the stores, still a fold suffices - SsPair' sP SsNone -> - let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1)) - (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3) - in occCountX sP foldex $ \env1 mkfoldex -> k env1 $ \env' -> EPair ext (mkfoldex env') (ENil ext) - -- If for whatever reason the additional stores themselves are - -- unnecessary but the shape of the array is, then oblige - SsPair' sP (SsArr' SsNone) -> - let STArr sn _ = typeOf e3 - foldex = - elet (mapExt (\_ -> ext) e3) $ - EPair ext - (EShape ext (evar IZ)) - (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy WSink) e1))) - (mapExt (\_ -> ext) (weakenExpr WSink e2)) - (evar IZ)) - in occCountX (SsPair SsFull sP) foldex $ \env1 mkfoldex -> - k env1 $ \env' -> - eunPair (mkfoldex env') $ \_ eshape earr -> - EPair ext earr (EBuild ext sn eshape (ENil ext)) - -- If at least some of the additional stores are required, we need to keep this a mapAccum - SsPair' _ (SsArr' sB) -> - -- TODO: propagate usage of primals - occCountX (SsPair SsFull sB) e1 $ \env1_1' mka -> - occEnvPop' env1_1' $ \env1' _ -> - occCountX SsFull e2 $ \env2 mkb -> - occCountX SsFull e3 $ \env3 mkc -> - withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> - k env $ \env' -> - projectSmallerSubstruc (SsPair SsFull (SsArr sB)) s $ - EFold1InnerD1 ext cm (mka (OccPush env' () SsFull)) - (mkb env') (mkc env') - - EFold1InnerD2 _ cm ef ebog ed -> - -- TODO: propagate usage of duals - occCountX SsFull ef $ \env1_2' mkef -> - occEnvPop' env1_2' $ \env1_1' _ -> - occEnvPop' env1_1' $ \env1' sB -> - occCountX (SsArr sB) ebog $ \env2 mkebog -> - occCountX SsFull ed $ \env3 mked -> - withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> - k env $ \env' -> - projectSmallerSubstruc SsFull s $ - EFold1InnerD2 ext cm - (mkef (OccPush (OccPush env' () sB) () SsFull)) - (mkebog env') (mked env') - - EConst _ t x -> - k OccEnd $ \_ -> - case s of - SsNone -> ENil ext - SsFull -> EConst ext t x - - EIdx0 _ e -> - occCountX (SsArr s) e $ \env1 mke -> - k env1 $ \env' -> - EIdx0 ext (mke env') - - EIdx1 _ a b -> - case s of - SsNone -> - occCountX SsNone a $ \env1 mka -> - occCountX SsNone b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mka env') $ use (mkb env') $ ENil ext - SsArr' s' -> - occCountX (SsArr s') a $ \env1 mka -> - occCountX SsFull b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - EIdx1 ext (mka env') (mkb env') - - EIdx _ a b -> - case s of - SsNone -> - occCountX SsNone a $ \env1 mka -> - occCountX SsNone b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mka env') $ use (mkb env') $ ENil ext - _ -> - occCountX (SsArr s) a $ \env1 mka -> - occCountX SsFull b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - EIdx ext (mka env') (mkb env') - - EShape _ e -> - case s of - SsNone -> - occCountX SsNone e $ \env1 mke -> - k env1 $ \env' -> - use (mke env') $ ENil ext - _ -> - occCountX (SsArr SsNone) e $ \env1 mke -> - k env1 $ \env' -> - projectSmallerSubstruc SsFull s $ EShape ext (mke env') - - EOp _ op e -> - case s of - SsNone -> - occCountX SsNone e $ \env1 mke -> - k env1 $ \env' -> - use (mke env') $ ENil ext - _ -> - occCountX SsFull e $ \env1 mke -> - k env1 $ \env' -> - projectSmallerSubstruc SsFull s $ EOp ext op (mke env') - - ECustom _ t1 t2 t3 e1 e2 e3 a b - | typeHasAccums t1 || typeHasAccums t2 || typeHasAccums t3 -> - error "Accumulators not allowed in input/output/tape of an ECustom" - | otherwise -> - case s of - SsNone -> - -- Allowed to ignore e1/e2/e3 here because no accumulators are - -- communicated, and hence no relevant effects exist - occCountX SsNone a $ \env1 mka -> - occCountX SsNone b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (mka env') $ use (mkb env') $ ENil ext - s' -> -- Let's be pessimistic for safety - occCountX SsFull a $ \env1 mka -> - occCountX SsFull b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - projectSmallerSubstruc SsFull s' $ - ECustom ext t1 t2 t3 (mapExt (const ext) e1) (mapExt (const ext) e2) (mapExt (const ext) e3) (mka env') (mkb env') - - ERecompute _ e -> - occCountX s e $ \env1 mke -> - k env1 $ \env' -> - ERecompute ext (mke env') - - EWith _ t a b -> - case s of - SsNone -> -- TODO: simplifier should remove accumulations to an unused with, and then remove the with - occCountX SsNone b $ \env2' mkb -> - occEnvPop' env2' $ \env2 s1 -> - withSome (case s1 of - SsFull -> Some SsFull - SsAccum s' -> Some s' - SsNone -> Some SsNone) $ \s1' -> - occCountX s1' a $ \env1 mka -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - use (EWith ext (applySubstrucM s1' t) (mka env') (mkb (OccPush env' () (SsAccum s1')))) $ - ENil ext - SsPair sB sA -> - occCountX sB b $ \env2' mkb -> - occEnvPop' env2' $ \env2 s1 -> - let s1' = case s1 of - SsFull -> Some SsFull - SsAccum s' -> Some s' - SsNone -> Some SsNone in - withSome (Some sA <> s1') $ \sA' -> - occCountX sA' a $ \env1 mka -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - projectSmallerSubstruc (SsPair sB sA') (SsPair sB sA) $ - EWith ext (applySubstrucM sA' t) (mka env') (mkb (OccPush env' () (SsAccum sA'))) - SsFull -> occCountX (SsPair SsFull SsFull) topexpr k - - EAccum _ t p a sp b e -> - -- TODO: do better! - occCountX SsFull a $ \env1 mka -> - occCountX SsFull b $ \env2 mkb -> - occCountX SsFull e $ \env3 mke -> - withSome (Some env1 <> Some env2) $ \env12 -> - withSome (Some env12 <> Some env3) $ \env -> - k env $ \env' -> - case s of {SsFull -> id; SsNone -> id} $ - EAccum ext t p (mka env') sp (mkb env') (mke env') - - EZero _ t e -> - occCountX (subZeroInfo s) e $ \env1 mke -> - k env1 $ \env' -> - EZero ext (applySubstrucM s t) (mke env') - where - subZeroInfo :: Substruc t1 t2 -> Substruc (ZeroInfo t1) (ZeroInfo t2) - subZeroInfo SsFull = SsFull - subZeroInfo SsNone = SsNone - subZeroInfo (SsPair s1 s2) = SsPair (subZeroInfo s1) (subZeroInfo s2) - subZeroInfo SsEither{} = error "Either is not a monoid" - subZeroInfo SsLEither{} = SsNone - subZeroInfo SsMaybe{} = SsNone - subZeroInfo (SsArr s') = SsArr (subZeroInfo s') - subZeroInfo SsAccum{} = error "Accum is not a monoid" - - EDeepZero _ t e -> - occCountX (subDeepZeroInfo s) e $ \env1 mke -> - k env1 $ \env' -> - EDeepZero ext (applySubstrucM s t) (mke env') - where - subDeepZeroInfo :: Substruc t1 t2 -> Substruc (DeepZeroInfo t1) (DeepZeroInfo t2) - subDeepZeroInfo SsFull = SsFull - subDeepZeroInfo SsNone = SsNone - subDeepZeroInfo (SsPair s1 s2) = SsPair (subDeepZeroInfo s1) (subDeepZeroInfo s2) - subDeepZeroInfo SsEither{} = error "Either is not a monoid" - subDeepZeroInfo (SsLEither s1 s2) = SsLEither (subDeepZeroInfo s1) (subDeepZeroInfo s2) - subDeepZeroInfo (SsMaybe s') = SsMaybe (subDeepZeroInfo s') - subDeepZeroInfo (SsArr s') = SsArr (subDeepZeroInfo s') - subDeepZeroInfo SsAccum{} = error "Accum is not a monoid" - - EPlus _ t a b -> - occCountX s a $ \env1 mka -> - occCountX s b $ \env2 mkb -> - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - EPlus ext (applySubstrucM s t) (mka env') (mkb env') - - EOneHot _ t p a b -> - occCountX SsFull a $ \env1 mka -> - occCountX SsFull b $ \env2 mkb -> -- TODO: do better - withSome (Some env1 <> Some env2) $ \env -> - k env $ \env' -> - projectSmallerSubstruc SsFull s $ EOneHot ext t p (mka env') (mkb env') - - EError _ t msg -> - k OccEnd $ \_ -> EError ext (applySubstruc s t) msg - where - s = simplifySubstruc (typeOf topexpr) initialS - - handleReduction :: t ~ TArr n (TScal t2) - => (forall env2. Ex env2 (TArr (S n) (TScal t2)) -> Ex env2 (TArr n (TScal t2))) - -> Expr x env (TArr (S n) (TScal t2)) - -> r - handleReduction reduce e - | STArr (SS n) _ <- typeOf e = - case s of - SsNone -> - occCountX SsNone e $ \env mke -> - k env $ \env' -> - use (mke env') $ ENil ext - SsArr' SsNone -> - occCountX (SsArr SsNone) e $ \env mke -> - k env $ \env' -> - elet (mke env') $ - EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext) - SsArr' SsFull -> - occCountX (SsArr SsFull) e $ \env mke -> - k env $ \env' -> - reduce (mke env') - - -deleteUnused :: SList f env -> Some (OccEnv Occ env) -> (forall env'. Subenv env env' -> r) -> r -deleteUnused SNil (Some OccEnd) k = k SETop -deleteUnused (_ `SCons` env) (Some OccEnd) k = - deleteUnused env (Some OccEnd) $ \sub -> k (SENo sub) -deleteUnused (_ `SCons` env) (Some (OccPush occenv (Occ _ count) _)) k = - deleteUnused env (Some occenv) $ \sub -> - case count of Zero -> k (SENo sub) - _ -> k (SEYesR sub) - -unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t -unsafeWeakenWithSubenv = \sub -> - subst (\x t i -> case sinkViaSubenv i sub of - Just i' -> EVar x t i' - Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away") - where - sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t) - sinkViaSubenv IZ (SEYesR _) = Just IZ - sinkViaSubenv IZ (SENo _) = Nothing - sinkViaSubenv (IS i) (SEYesR sub) = IS <$> sinkViaSubenv i sub - sinkViaSubenv (IS i) (SENo sub) = sinkViaSubenv i sub diff --git a/src/AST/Env.hs b/src/AST/Env.hs deleted file mode 100644 index 85faba3..0000000 --- a/src/AST/Env.hs +++ /dev/null @@ -1,95 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeOperators #-} -module AST.Env where - -import Data.Type.Equality - -import AST.Sparse -import AST.Weaken -import CHAD.Types -import Data - - --- | @env'@ is a subset of @env@: each element of @env@ is either included in --- @env'@ ('SEYes') or not included in @env'@ ('SENo'). -data Subenv' s env env' where - SETop :: Subenv' s '[] '[] - SEYes :: forall t t' env env' s. s t t' -> Subenv' s env env' -> Subenv' s (t : env) (t' : env') - SENo :: forall t env env' s. Subenv' s env env' -> Subenv' s (t : env) env' -deriving instance (forall t t'. Show (s t t')) => Show (Subenv' s env env') - -type Subenv = Subenv' (:~:) -type SubenvS = Subenv' Sparse - -pattern SEYesR :: forall tenv tenv'. () - => forall t env env'. (tenv ~ t : env, tenv' ~ t : env') - => Subenv env env' -> Subenv tenv tenv' -pattern SEYesR s = SEYes Refl s - -{-# COMPLETE SETop, SEYesR, SENo #-} - -subList :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env' -> SList f env' -subList SNil SETop = SNil -subList (SCons x xs) (SEYes s sub) = SCons (subtApply s x) (subList xs sub) -subList (SCons _ xs) (SENo sub) = subList xs sub - -subenvAll :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env -subenvAll SNil = SETop -subenvAll (SCons t env) = SEYes (subtFull t) (subenvAll env) - -subenvNone :: SList f env -> Subenv' s env '[] -subenvNone SNil = SETop -subenvNone (SCons _ env) = SENo (subenvNone env) - -subenvOnehot :: SList f env -> Idx env t -> s t t' -> Subenv' s env '[t'] -subenvOnehot (SCons _ env) IZ sp = SEYes sp (subenvNone env) -subenvOnehot (SCons _ env) (IS i) sp = SENo (subenvOnehot env i sp) -subenvOnehot SNil i _ = case i of {} - -subenvCompose :: IsSubType s => Subenv' s env1 env2 -> Subenv' s env2 env3 -> Subenv' s env1 env3 -subenvCompose SETop SETop = SETop -subenvCompose (SEYes s1 sub1) (SEYes s2 sub2) = SEYes (subtTrans s1 s2) (subenvCompose sub1 sub2) -subenvCompose (SEYes _ sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2) -subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2) - -subenvConcat :: Subenv' s env1 env1' -> Subenv' s env2 env2' -> Subenv' s (Append env2 env1) (Append env2' env1') -subenvConcat sub1 SETop = sub1 -subenvConcat sub1 (SEYes s sub2) = SEYes s (subenvConcat sub1 sub2) -subenvConcat sub1 (SENo sub2) = SENo (subenvConcat sub1 sub2) - --- subenvSplit :: SList f env1a -> Subenv' s (Append env1a env1b) env2 --- -> (forall env2a env2b. Subenv' s env1a env2a -> Subenv' s env1b env2b -> r) -> r --- subenvSplit SNil sub k = k SETop sub --- subenvSplit (SCons _ list) (SENo sub) k = --- subenvSplit list sub $ \sub1 sub2 -> --- k (SENo sub1) sub2 --- subenvSplit (SCons _ list) (SEYes s sub) k = --- subenvSplit list sub $ \sub1 sub2 -> --- k (SEYes s sub1) sub2 - -sinkWithSubenv :: Subenv' s env env' -> env0 :> Append env' env0 -sinkWithSubenv SETop = WId -sinkWithSubenv (SEYes _ sub) = WSink .> sinkWithSubenv sub -sinkWithSubenv (SENo sub) = sinkWithSubenv sub - -wUndoSubenv :: Subenv' (:~:) env env' -> env' :> env -wUndoSubenv SETop = WId -wUndoSubenv (SEYes Refl sub) = WCopy (wUndoSubenv sub) -wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub - -subenvMap :: (forall a a'. f a -> s a a' -> s' a a') -> SList f env -> Subenv' s env env' -> Subenv' s' env env' -subenvMap _ SNil SETop = SETop -subenvMap f (t `SCons` l) (SEYes s sub) = SEYes (f t s) (subenvMap f l sub) -subenvMap f (_ `SCons` l) (SENo sub) = SENo (subenvMap f l sub) - -subenvD2E :: Subenv env env' -> Subenv (D2E env) (D2E env') -subenvD2E SETop = SETop -subenvD2E (SEYesR sub) = SEYesR (subenvD2E sub) -subenvD2E (SENo sub) = SENo (subenvD2E sub) diff --git a/src/AST/Pretty.hs b/src/AST/Pretty.hs deleted file mode 100644 index bbcfd9e..0000000 --- a/src/AST/Pretty.hs +++ /dev/null @@ -1,525 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE TupleSections #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -module AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where - -import Control.Monad (ap) -import Data.List (intersperse, intercalate) -import Data.Functor.Const -import qualified Data.Functor.Product as Product -import Data.String (fromString) -import Prettyprinter -import Prettyprinter.Render.String - -import qualified Data.Text.Lazy as TL -import qualified Prettyprinter.Render.Terminal as PT -import System.Console.ANSI (hSupportsANSI) -import System.IO (stdout) -import System.IO.Unsafe (unsafePerformIO) - -import AST -import AST.Count -import AST.Sparse.Types -import CHAD.Types -import Data - - -class PrettyX x where - prettyX :: x t -> String - - prettyXsuffix :: x t -> String - prettyXsuffix x = "<" ++ prettyX x ++ ">" - -instance PrettyX (Const ()) where - prettyX _ = "" - prettyXsuffix _ = "" - - -type SVal = SList (Const String) - -newtype M a = M { runM :: Int -> (a, Int) } - deriving (Functor) -instance Applicative M where { pure x = M (\i -> (x, i)) ; (<*>) = ap } -instance Monad M where { M f >>= g = M (\i -> let (x, j) = f i in runM (g x) j) } - -genId :: M Int -genId = M (\i -> (i, i + 1)) - -nameBaseForType :: STy t -> String -nameBaseForType STNil = "nil" -nameBaseForType (STPair{}) = "p" -nameBaseForType (STEither{}) = "e" -nameBaseForType (STMaybe{}) = "m" -nameBaseForType (STScal STI32) = "n" -nameBaseForType (STScal STI64) = "n" -nameBaseForType (STArr{}) = "a" -nameBaseForType (STAccum{}) = "ac" -nameBaseForType _ = "x" - -genName' :: String -> M String -genName' prefix = (prefix ++) . show <$> genId - -genNameIfUsedIn' :: String -> STy a -> Idx env a -> Expr x env t -> M String -genNameIfUsedIn' prefix ty idx ex - | occCount idx ex == mempty = case ty of STNil -> return "()" - _ -> return "_" - | otherwise = genName' prefix - --- TODO: let this return a type-tagged thing so that name environments are more typed than Const -genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String -genNameIfUsedIn = \t -> genNameIfUsedIn' (nameBaseForType t) t - -pprintExpr :: (KnownEnv env, PrettyX x) => Expr x env t -> IO () -pprintExpr = putStrLn . ppExpr knownEnv - -ppExpr :: PrettyX x => SList STy env -> Expr x env t -> String -ppExpr senv e = render $ fst . flip runM 1 $ do - val <- mkVal senv - e' <- ppExpr' 0 val e - let lam = "λ" ++ intercalate " " (reverse (unSList (\(Product.Pair (Const name) ty) -> "(" ++ name ++ " : " ++ ppSTy 0 ty ++ ")") (slistZip val senv))) ++ "." - return $ group $ flatAlt - (hang 2 $ - ppString lam - <> hardline <> e') - (ppString lam <+> e') - where - mkVal :: SList f env -> M (SVal env) - mkVal SNil = return SNil - mkVal (SCons _ v) = do - val <- mkVal v - name <- genName' "arg" - return (Const name `SCons` val) - -ppExpr' :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc -ppExpr' d val expr = case expr of - EVar _ _ i -> return $ ppString (getConst (slistIdx val i)) <> ppX expr - - e@ELet{} -> ppExprLet d val e - - EPair _ a b -> do - a' <- ppExpr' 0 val a - b' <- ppExpr' 0 val b - return $ group $ flatAlt (align $ ppString "(" <> a' <> hardline <> ppString "," <> b' <> ppString ")" <> ppX expr) - (ppString "(" <> a' <> ppString "," <+> b' <> ppString ")" <> ppX expr) - - EFst _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString "fst" <> ppX expr <+> e' - - ESnd _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString "snd" <> ppX expr <+> e' - - ENil _ -> return $ ppString "()" - - EInl _ _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString "Inl" <> ppX expr <+> e' - - EInr _ _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString "Inr" <> ppX expr <+> e' - - ECase _ e a b -> do - e' <- ppExpr' 0 val e - let STEither t1 t2 = typeOf e - name1 <- genNameIfUsedIn t1 IZ a - a' <- ppExpr' 0 (Const name1 `SCons` val) a - name2 <- genNameIfUsedIn t2 IZ b - b' <- ppExpr' 0 (Const name2 `SCons` val) b - return $ ppParen (d > 0) $ - hang 2 $ - annotate AKey (ppString "case") <> ppX expr <+> e' <+> annotate AKey (ppString "of") - <> hardline <> ppString "Inl" <+> ppString name1 <+> ppString "->" <+> a' - <> hardline <> ppString "Inr" <+> ppString name2 <+> ppString "->" <+> b' - - ENothing _ _ -> return $ ppString "Nothing" - - EJust _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString "Just" <> ppX expr <+> e' - - EMaybe _ a b e -> do - let STMaybe t = typeOf e - e' <- ppExpr' 0 val e - a' <- ppExpr' 0 val a - name <- genNameIfUsedIn t IZ b - b' <- ppExpr' 0 (Const name `SCons` val) b - return $ ppParen (d > 0) $ - align $ - group (flatAlt - (annotate AKey (ppString "case") <> ppX expr <+> e' - <> hardline <> annotate AKey (ppString "of")) - (annotate AKey (ppString "case") <> ppX expr <+> e' <+> annotate AKey (ppString "of"))) - <> hardline - <> indent 2 - (ppString "Nothing" <+> ppString "->" <+> a' - <> hardline <> ppString "Just" <+> ppString name <+> ppString "->" <+> b') - - ELNil _ _ _ -> return (ppString "LNil") - - ELInl _ _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString "LInl" <> ppX expr <+> e' - - ELInr _ _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString "LInr" <> ppX expr <+> e' - - ELCase _ e a b c -> do - e' <- ppExpr' 0 val e - let STLEither t1 t2 = typeOf e - a' <- ppExpr' 11 val a - name1 <- genNameIfUsedIn t1 IZ b - b' <- ppExpr' 0 (Const name1 `SCons` val) b - name2 <- genNameIfUsedIn t2 IZ c - c' <- ppExpr' 0 (Const name2 `SCons` val) c - return $ ppParen (d > 0) $ - hang 2 $ - annotate AKey (ppString "lcase") <> ppX expr <+> e' <+> annotate AKey (ppString "of") - <> hardline <> ppString "LNil" <+> ppString "->" <+> a' - <> hardline <> ppString "LInl" <+> ppString name1 <+> ppString "->" <+> b' - <> hardline <> ppString "LInr" <+> ppString name2 <+> ppString "->" <+> c' - - EConstArr _ _ ty v - | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr - - EBuild _ n a b -> do - a' <- ppExpr' 11 val a - name <- genNameIfUsedIn' "i" (tTup (sreplicate n tIx)) IZ b - e' <- ppExpr' 0 (Const name `SCons` val) b - let primName = ppString ("build" ++ intSubscript (fromSNat n)) - return $ ppParen (d > 0) $ - group $ flatAlt - (hang 2 $ - annotate AHighlight primName <> ppX expr <+> a' - <+> ppString "$" <+> ppString "\\" <> ppString name <+> ppString "->" - <> hardline <> e') - (ppApp (annotate AHighlight primName <> ppX expr) [a', ppLam [ppString name] e']) - - EMap _ a b -> do - let STArr _ t1 = typeOf b - name <- genNameIfUsedIn t1 IZ a - a' <- ppExpr' 0 (Const name `SCons` val) a - b' <- ppExpr' 11 val b - return $ ppParen (d > 0) $ - ppApp (annotate AHighlight (ppString "map") <> ppX expr) [ppLam [ppString name] a', b'] - - EFold1Inner _ cm a b c -> do - name <- genNameIfUsedIn (STPair (typeOf a) (typeOf a)) IZ a - a' <- ppExpr' 0 (Const name `SCons` val) a - b' <- ppExpr' 11 val b - c' <- ppExpr' 11 val c - let opname = "fold1i" ++ ppCommut cm - return $ ppParen (d > 10) $ - ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c'] - - ESum1Inner _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString "sum1i" <> ppX expr <+> e' - - EUnit _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString "unit" <> ppX expr <+> e' - - EReplicate1Inner _ a b -> do - a' <- ppExpr' 11 val a - b' <- ppExpr' 11 val b - return $ ppParen (d > 10) $ ppApp (ppString "replicate1i" <> ppX expr) [a', b'] - - EMaximum1Inner _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString "maximum1i" <> ppX expr <+> e' - - EMinimum1Inner _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString "minimum1i" <> ppX expr <+> e' - - EReshape _ n esh e -> do - esh' <- ppExpr' 11 val esh - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppApp (ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr) [esh', e'] - - EZip _ e1 e2 -> do - e1' <- ppExpr' 11 val e1 - e2' <- ppExpr' 11 val e2 - return $ ppParen (d > 10) $ ppApp (ppString "zip" <> ppX expr) [e1', e2'] - - EFold1InnerD1 _ cm a b c -> do - name <- genNameIfUsedIn (STPair (typeOf b) (typeOf b)) IZ a - a' <- ppExpr' 0 (Const name `SCons` val) a - b' <- ppExpr' 11 val b - c' <- ppExpr' 11 val c - let opname = "fold1iD1" ++ ppCommut cm - return $ ppParen (d > 10) $ - ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c'] - - EFold1InnerD2 _ cm ef ebog ed -> do - let STArr _ tB = typeOf ebog - STArr _ t2 = typeOf ed - namef1 <- genNameIfUsedIn tB (IS IZ) ef - namef2 <- genNameIfUsedIn t2 IZ ef - ef' <- ppExpr' 0 (Const namef2 `SCons` Const namef1 `SCons` val) ef - ebog' <- ppExpr' 11 val ebog - ed' <- ppExpr' 11 val ed - let opname = "fold1iD2" ++ ppCommut cm - return $ ppParen (d > 10) $ - ppApp (annotate AHighlight (ppString opname) <> ppX expr) - [ppLam [ppString namef1, ppString namef2] ef', ebog', ed'] - - EConst _ ty v - | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr - - EIdx0 _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString "idx0" <> ppX expr <+> e' - - EIdx1 _ a b -> do - a' <- ppExpr' 9 val a - b' <- ppExpr' 9 val b - return $ ppParen (d > 8) $ a' <+> ppString ".!" <> ppX expr <+> b' - - EIdx _ a b -> do - a' <- ppExpr' 9 val a - b' <- ppExpr' 10 val b - return $ ppParen (d > 8) $ - a' <+> ppString "!" <> ppX expr <+> b' - - EShape _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppString "shape" <> ppX expr <+> e' - - EOp _ op (EPair _ a b) - | (Infix, ops) <- operator op -> do - a' <- ppExpr' 9 val a - b' <- ppExpr' 9 val b - return $ ppParen (d > 8) $ a' <+> ppString ops <> ppX expr <+> b' - - EOp _ op e -> do - e' <- ppExpr' 11 val e - let ops = case operator op of - (Infix, s) -> "(" ++ s ++ ")" - (Prefix, s) -> s - return $ ppParen (d > 10) $ ppString ops <> ppX expr <+> e' - - ECustom _ t1 t2 t3 a b c e1 e2 -> do - en1 <- genNameIfUsedIn t1 (IS IZ) a - en2 <- genNameIfUsedIn t2 IZ a - pn1 <- genNameIfUsedIn (d1 t1) (IS IZ) b - pn2 <- genNameIfUsedIn (d1 t2) IZ b - dn1 <- genNameIfUsedIn' "tape" t3 (IS IZ) c - dn2 <- genNameIfUsedIn' "d" (d2 (typeOf a)) IZ c - a' <- ppExpr' 11 (Const pn2 `SCons` Const pn1 `SCons` SNil) a - b' <- ppExpr' 11 (Const pn2 `SCons` Const pn1 `SCons` SNil) b - c' <- ppExpr' 11 (Const dn2 `SCons` Const dn1 `SCons` SNil) c - e1' <- ppExpr' 11 val e1 - e2' <- ppExpr' 11 val e2 - return $ ppParen (d > 10) $ - ppApp (ppString "custom" <> ppX expr) - [ppLam [ppString en1, ppString en2] a' - ,ppLam [ppString pn1, ppString pn2] b' - ,ppLam [ppString dn1, ppString dn2] c' - ,e1' - ,e2'] - - ERecompute _ e -> do - e' <- ppExpr' 11 val e - return $ ppParen (d > 10) $ ppApp (ppString "recompute" <> ppX expr) [e'] - - EWith _ t e1 e2 -> do - e1' <- ppExpr' 11 val e1 - name <- genNameIfUsedIn' "ac" (STAccum t) IZ e2 - e2' <- ppExpr' 0 (Const name `SCons` val) e2 - return $ ppParen (d > 0) $ - group $ flatAlt - (hang 2 $ - annotate AWith (ppString "with") <> ppX expr <+> e1' - <+> ppString "$" <+> ppString "\\" <> ppString name <+> ppString "->" - <> hardline <> e2') - (ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2']) - - EAccum _ t prj e1 sp e2 e3 -> do - e1' <- ppExpr' 11 val e1 - e2' <- ppExpr' 11 val e2 - e3' <- ppExpr' 11 val e3 - return $ ppParen (d > 10) $ - ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (applySparse sp (acPrjTy prj t))) - [ppString (ppAcPrj t prj), ppString (ppSparse (acPrjTy prj t) sp), e1', e2', e3'] - - EZero _ t e1 -> do - e1' <- ppExpr' 11 val e1 - return $ ppParen (d > 0) $ - annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1' - - EDeepZero _ t e1 -> do - e1' <- ppExpr' 11 val e1 - return $ ppParen (d > 0) $ - annotate AMonoid (ppString "deepzero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1' - - EPlus _ t a b -> do - a' <- ppExpr' 11 val a - b' <- ppExpr' 11 val b - return $ ppParen (d > 10) $ - ppApp (annotate AMonoid (ppString "plus") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t) [a', b'] - - EOneHot _ t prj a b -> do - a' <- ppExpr' 11 val a - b' <- ppExpr' 11 val b - return $ ppParen (d > 10) $ - ppApp (annotate AMonoid (ppString "onehot") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (acPrjTy prj t)) [ppString (ppAcPrj t prj), a', b'] - - EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s) - -ppExprLet :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc -ppExprLet d val etop = do - let collect :: PrettyX x => SVal env -> Expr x env t -> M ([(String, Occ, ADoc)], ADoc) - collect val' (ELet _ rhs body) = do - let occ = occCount IZ body - name <- genNameIfUsedIn (typeOf rhs) IZ body - rhs' <- ppExpr' 0 val' rhs - (binds, core) <- collect (Const name `SCons` val') body - return ((name, occ, rhs') : binds, core) - collect val' e = ([],) <$> ppExpr' 0 val' e - - (binds, core) <- collect val etop - - return $ ppParen (d > 0) $ - align $ - annotate AKey (ppString "let") - <+> align (mconcat $ intersperse hardline $ - map (\(name, _occ, rhs) -> - ppString (name ++ {- " (" ++ show _occ ++ ")" ++ -} " = ") <> rhs) - binds) - <> hardline <> annotate AKey (ppString "in") <+> core - -ppApp :: ADoc -> [ADoc] -> ADoc -ppApp fun args = group $ fun <+> align (sep args) - -ppLam :: [ADoc] -> ADoc -> ADoc -ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"]) - <> softline <> body <> ppString ")") - -ppAcPrj :: SMTy a -> SAcPrj p a b -> String -ppAcPrj _ SAPHere = "." -ppAcPrj (SMTPair t _) (SAPFst prj) = "(" ++ ppAcPrj t prj ++ ",)" -ppAcPrj (SMTPair _ t) (SAPSnd prj) = "(," ++ ppAcPrj t prj ++ ")" -ppAcPrj (SMTLEither t _) (SAPLeft prj) = "(" ++ ppAcPrj t prj ++ "|)" -ppAcPrj (SMTLEither _ t) (SAPRight prj) = "(|" ++ ppAcPrj t prj ++ ")" -ppAcPrj (SMTMaybe t) (SAPJust prj) = "J" ++ ppAcPrj t prj -ppAcPrj (SMTArr n t) (SAPArrIdx prj) = "[" ++ ppAcPrj t prj ++ "]" ++ intSubscript (fromSNat n) - -ppSparse :: SMTy a -> Sparse a b -> String -ppSparse t sp | Just Refl <- isDense t sp = "D" -ppSparse _ SpAbsent = "A" -ppSparse t (SpSparse s) = "S" ++ ppSparse t s -ppSparse (SMTPair t1 t2) (SpPair s1 s2) = "(" ++ ppSparse t1 s1 ++ "," ++ ppSparse t2 s2 ++ ")" -ppSparse (SMTLEither t1 t2) (SpLEither s1 s2) = "(" ++ ppSparse t1 s1 ++ "|" ++ ppSparse t2 s2 ++ ")" -ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s -ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s -ppSparse (SMTScal _) SpScal = "." - -ppCommut :: Commutative -> String -ppCommut Commut = "(C)" -ppCommut Noncommut = "" - -ppX :: PrettyX x => Expr x env t -> ADoc -ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr) - -data Fixity = Prefix | Infix - deriving (Show) - -operator :: SOp a t -> (Fixity, String) -operator OAdd{} = (Infix, "+") -operator OMul{} = (Infix, "*") -operator ONeg{} = (Prefix, "negate") -operator OLt{} = (Infix, "<") -operator OLe{} = (Infix, "<=") -operator OEq{} = (Infix, "==") -operator ONot = (Prefix, "not") -operator OAnd = (Infix, "&&") -operator OOr = (Infix, "||") -operator OIf = (Prefix, "ifB") -operator ORound64 = (Prefix, "round") -operator OToFl64 = (Prefix, "toFl64") -operator ORecip{} = (Prefix, "recip") -operator OExp{} = (Prefix, "exp") -operator OLog{} = (Prefix, "log") -operator OIDiv{} = (Infix, "`div`") -operator OMod{} = (Infix, "`mod`") - -ppSTy :: Int -> STy t -> String -ppSTy d ty = render $ ppSTy' d ty - -ppSTy' :: Int -> STy t -> Doc q -ppSTy' _ STNil = ppString "1" -ppSTy' d (STPair a b) = ppParen (d > 7) $ ppSTy' 8 a <> ppString " * " <> ppSTy' 8 b -ppSTy' d (STEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " + " <> ppSTy' 7 b -ppSTy' d (STLEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " ⊕ " <> ppSTy' 7 b -ppSTy' d (STMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSTy' 11 t -ppSTy' d (STArr n t) = ppParen (d > 10) $ - ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSTy' 11 t -ppSTy' _ (STScal sty) = ppString $ case sty of - STI32 -> "i32" - STI64 -> "i64" - STF32 -> "f32" - STF64 -> "f64" - STBool -> "bool" -ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSMTy' 11 t - -ppSMTy :: Int -> SMTy t -> String -ppSMTy d ty = render $ ppSMTy' d ty - -ppSMTy' :: Int -> SMTy t -> Doc q -ppSMTy' _ SMTNil = ppString "1" -ppSMTy' d (SMTPair a b) = ppParen (d > 7) $ ppSMTy' 8 a <> ppString " * " <> ppSMTy' 8 b -ppSMTy' d (SMTLEither a b) = ppParen (d > 6) $ ppSMTy' 7 a <> ppString " ⊕ " <> ppSMTy' 7 b -ppSMTy' d (SMTMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSMTy' 11 t -ppSMTy' d (SMTArr n t) = ppParen (d > 10) $ - ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSMTy' 11 t -ppSMTy' _ (SMTScal sty) = ppString $ case sty of - STI32 -> "i32" - STI64 -> "i64" - STF32 -> "f32" - STF64 -> "f64" - -ppString :: String -> Doc x -ppString = fromString - -ppParen :: Bool -> Doc x -> Doc x -ppParen True = parens -ppParen False = id - -intSubscript :: Int -> String -intSubscript = \case 0 -> "₀" - n | n < 0 -> '₋' : go (-n) "" - | otherwise -> go n "" - where go 0 suff = suff - go n suff = let (q, r) = n `quotRem` 10 - in go q ("₀₁₂₃₄₅₆₇₈₉" !! r : suff) - -data Annot = AKey | AWith | AHighlight | AMonoid | AExt - deriving (Show) - -annotToANSI :: Annot -> PT.AnsiStyle -annotToANSI AKey = PT.bold -annotToANSI AWith = PT.color PT.Red <> PT.underlined -annotToANSI AHighlight = PT.color PT.Blue -annotToANSI AMonoid = PT.color PT.Green -annotToANSI AExt = PT.colorDull PT.White - -type ADoc = Doc Annot - -render :: Doc Annot -> String -render = - (if stdoutTTY then TL.unpack . PT.renderLazy . reAnnotateS annotToANSI - else renderString) - . layoutPretty LayoutOptions { layoutPageWidth = AvailablePerLine 120 1.0 } - where - stdoutTTY = unsafePerformIO $ hSupportsANSI stdout diff --git a/src/AST/Sparse.hs b/src/AST/Sparse.hs deleted file mode 100644 index 2a29799..0000000 --- a/src/AST/Sparse.hs +++ /dev/null @@ -1,287 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ImpredicativeTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE RankNTypes #-} - -{-# OPTIONS_GHC -fmax-pmcheck-models=80 #-} -module AST.Sparse (module AST.Sparse, module AST.Sparse.Types) where - -import Data.Type.Equality - -import AST -import AST.Sparse.Types -import Data (SBool(..)) - - -sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t' -sparsePlus _ SpAbsent e1 e2 = use e1 $ use e2 $ ENil ext -sparsePlus t sp e1 e2 | Just Refl <- isDense t sp = EPlus ext t e1 e2 -sparsePlus t (SpSparse sp) e1 e2 = sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 -- heh -sparsePlus (SMTPair t1 t2) (SpPair sp1 sp2) e1 e2 = - eunPair e1 $ \w1 e1a e1b -> - eunPair (weakenExpr w1 e2) $ \w2 e2a e2b -> - EPair ext (sparsePlus t1 sp1 (weakenExpr w2 e1a) e2a) - (sparsePlus t2 sp2 (weakenExpr w2 e1b) e2b) -sparsePlus (SMTLEither t1 t2) (SpLEither sp1 sp2) e1 e2 = - elet e2 $ - elcase (weakenExpr WSink e1) - (evar IZ) - (elcase (evar (IS IZ)) - (ELInl ext (applySparse sp2 (fromSMTy t2)) (evar IZ)) - (ELInl ext (applySparse sp2 (fromSMTy t2)) (sparsePlus t1 sp1 (evar (IS IZ)) (evar IZ))) - (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus ll+lr")) - (elcase (evar (IS IZ)) - (ELInr ext (applySparse sp1 (fromSMTy t1)) (evar IZ)) - (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus lr+ll") - (ELInr ext (applySparse sp1 (fromSMTy t1)) (sparsePlus t2 sp2 (evar (IS IZ)) (evar IZ)))) -sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 = - elet e2 $ - emaybe (weakenExpr WSink e1) - (evar IZ) - (emaybe (evar (IS IZ)) - (EJust ext (evar IZ)) - (EJust ext (sparsePlus t sp (evar (IS IZ)) (evar IZ)))) -sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS IZ)) (evar IZ)) e1 e2 -sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2 - - -cheapZero :: SMTy t -> Maybe (forall env. Ex env t) -cheapZero SMTNil = Just (ENil ext) -cheapZero (SMTPair t1 t2) - | Just e1 <- cheapZero t1 - , Just e2 <- cheapZero t2 - = Just (EPair ext e1 e2) - | otherwise - = Nothing -cheapZero (SMTLEither t1 t2) = Just (ELNil ext (fromSMTy t1) (fromSMTy t2)) -cheapZero (SMTMaybe t) = Just (ENothing ext (fromSMTy t)) -cheapZero SMTArr{} = Nothing -cheapZero (SMTScal t) = case t of - STI32 -> Just (EConst ext t 0) - STI64 -> Just (EConst ext t 0) - STF32 -> Just (EConst ext t 0.0) - STF64 -> Just (EConst ext t 0.0) - - -data Injection sp a b where - -- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that - -- 'sparsePlusS' can provide injections even if the caller doesn't require - -- them. This simplifies the sparsePlusS code. - Inj :: (forall e. Ex e a -> Ex e b) -> Injection sp a b - Noinj :: Injection False a b - -withInj :: Injection sp a b -> ((forall e. Ex e a -> Ex e b) -> (forall e'. Ex e' a' -> Ex e' b')) -> Injection sp a' b' -withInj (Inj f) k = Inj (k f) -withInj Noinj _ = Noinj - -withInj2 :: Injection sp a1 b1 -> Injection sp a2 b2 - -> ((forall e. Ex e a1 -> Ex e b1) - -> (forall e. Ex e a2 -> Ex e b2) - -> (forall e'. Ex e' a' -> Ex e' b')) - -> Injection sp a' b' -withInj2 (Inj f) (Inj g) k = Inj (k f g) -withInj2 Noinj _ _ = Noinj -withInj2 _ Noinj _ = Noinj - --- | This function produces quadratically-sized code in the presence of nested --- dynamic sparsity. TODO can this be improved? -sparsePlusS - :: SBool inj1 -> SBool inj2 - -> SMTy t -> Sparse t t1 -> Sparse t t2 - -> (forall t3. Sparse t t3 - -> Injection inj1 t1 t3 -- only available if first injection is requested (second argument may be absent) - -> Injection inj2 t2 t3 -- only available if second injection is requested (first argument may be absent) - -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3) - -> r) - -> r --- nil override (but don't destroy effects!) -sparsePlusS _ _ SMTNil _ _ k = - k SpAbsent (Inj $ \a -> use a $ ENil ext) (Inj $ \b -> use b $ ENil ext) (\a b -> use a $ use b $ ENil ext) - --- simplifications -sparsePlusS req1 req2 t (SpSparse SpAbsent) sp2 k = - sparsePlusS req1 req2 t SpAbsent sp2 $ \sp3 minj1 minj2 plus -> - k sp3 (withInj minj1 $ \inj1 -> \a -> use a $ inj1 (ENil ext)) minj2 (\a b -> use a $ plus (ENil ext) b) -sparsePlusS req1 req2 t sp1 (SpSparse SpAbsent) k = - sparsePlusS req1 req2 t sp1 SpAbsent $ \sp3 minj1 minj2 plus -> - k sp3 minj1 (withInj minj2 $ \inj2 -> \b -> use b $ inj2 (ENil ext)) (\a b -> use b $ plus a (ENil ext)) - -sparsePlusS req1 req2 t (SpSparse (SpSparse sp1)) sp2 k = - let ta = applySparse sp1 (fromSMTy t) in - sparsePlusS req1 req2 t (SpSparse sp1) sp2 $ \sp3 minj1 minj2 plus -> - k sp3 - (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ))) - minj2 - (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b) -sparsePlusS req1 req2 t sp1 (SpSparse (SpSparse sp2)) k = - let tb = applySparse sp2 (fromSMTy t) in - sparsePlusS req1 req2 t sp1 (SpSparse sp2) $ \sp3 minj1 minj2 plus -> - k sp3 - minj1 - (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) - (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) - -sparsePlusS req1 req2 t (SpSparse (SpLEither sp1a sp1b)) sp2 k = - let STLEither ta tb = applySparse (SpLEither sp1a sp1b) (fromSMTy t) in - sparsePlusS req1 req2 t (SpLEither sp1a sp1b) sp2 $ \sp3 minj1 minj2 plus -> - k sp3 - (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) - minj2 - (\a b -> plus (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)) b) -sparsePlusS req1 req2 t sp1 (SpSparse (SpLEither sp2a sp2b)) k = - let STLEither ta tb = applySparse (SpLEither sp2a sp2b) (fromSMTy t) in - sparsePlusS req1 req2 t sp1 (SpLEither sp2a sp2b) $ \sp3 minj1 minj2 plus -> - k sp3 - minj1 - (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) - (\a b -> plus a (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) - -sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k = - let STMaybe ta = applySparse (SpMaybe sp1) (fromSMTy t) in - sparsePlusS req1 req2 t (SpMaybe sp1) sp2 $ \sp3 minj1 minj2 plus -> - k sp3 - (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (evar IZ))) - minj2 - (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b) -sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k = - let STMaybe tb = applySparse (SpMaybe sp2) (fromSMTy t) in - sparsePlusS req1 req2 t sp1 (SpMaybe sp2) $ \sp3 minj1 minj2 plus -> - k sp3 - minj1 - (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (evar IZ))) - (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) -sparsePlusS req1 req2 t (SpMaybe (SpSparse sp1)) sp2 k = sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k -sparsePlusS req1 req2 t sp1 (SpMaybe (SpSparse sp2)) k = sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k - --- TODO: sparse of Just is just Maybe - --- dense plus -sparsePlusS _ _ t sp1 sp2 k - | Just Refl <- isDense t sp1 - , Just Refl <- isDense t sp2 - = k (spDense t) (Inj id) (Inj id) (\a b -> EPlus ext t a b) - --- handle absents -sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\a b -> use a $ b) -sparsePlusS ST _ t SpAbsent sp2 k - | Just zero2 <- cheapZero (applySparse sp2 t) = - k sp2 (Inj $ \a -> use a $ zero2) (Inj id) (\a b -> use a $ b) - | otherwise = - k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b) - -sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a b -> use b $ a) -sparsePlusS _ ST t sp1 SpAbsent k - | Just zero1 <- cheapZero (applySparse sp1 t) = - k sp1 (Inj id) (Inj $ \b -> use b $ zero1) (\a b -> use b $ a) - | otherwise = - k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a) - --- double sparse yields sparse -sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k = - sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> - k (SpSparse sp3) - (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) - (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) - (\a b -> - elet b $ - emaybe (weakenExpr WSink a) - (emaybe (evar IZ) - (ENothing ext (applySparse sp3 (fromSMTy t))) - (EJust ext (inj2 (evar IZ)))) - (emaybe (evar (IS IZ)) - (EJust ext (inj1 (evar IZ))) - (EJust ext (plus (evar (IS IZ)) (evar IZ))))) - --- single sparse can yield non-sparse if the other argument is always present -sparsePlusS SF _ t (SpSparse sp1) sp2 k = - sparsePlusS SF ST t sp1 sp2 $ \sp3 _ (Inj inj2) plus -> - k sp3 Noinj (Inj inj2) - (\a b -> - elet b $ - emaybe (weakenExpr WSink a) - (inj2 (evar IZ)) - (plus (evar IZ) (evar (IS IZ)))) -sparsePlusS ST _ t (SpSparse sp1) sp2 k = - sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> - k (SpSparse sp3) - (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) - (Inj $ \b -> EJust ext (inj2 b)) - (\a b -> - elet b $ - emaybe (weakenExpr WSink a) - (EJust ext (inj2 (evar IZ))) - (EJust ext (plus (evar IZ) (evar (IS IZ))))) -sparsePlusS req1 req2 t sp1 (SpSparse sp2) k = - sparsePlusS req2 req1 t (SpSparse sp2) sp1 $ \sp3 inj1 inj2 plus -> - k sp3 inj2 inj1 (flip plus) - --- products -sparsePlusS req1 req2 (SMTPair ta tb) (SpPair sp1a sp1b) (SpPair sp2a sp2b) k = - sparsePlusS req1 req2 ta sp1a sp2a $ \sp3a minj13a minj23a plusa -> - sparsePlusS req1 req2 tb sp1b sp2b $ \sp3b minj13b minj23b plusb -> - k (SpPair sp3a sp3b) - (withInj2 minj13a minj13b $ \inj13a inj13b -> - \x1 -> eunPair x1 $ \_ x1a x1b -> EPair ext (inj13a x1a) (inj13b x1b)) - (withInj2 minj23a minj23b $ \inj23a inj23b -> - \x2 -> eunPair x2 $ \_ x2a x2b -> EPair ext (inj23a x2a) (inj23b x2b)) - (\x1 x2 -> - eunPair x1 $ \w1 x1a x1b -> - eunPair (weakenExpr w1 x2) $ \w2 x2a x2b -> - EPair ext (plusa (weakenExpr w2 x1a) x2a) (plusb (weakenExpr w2 x1b) x2b)) - --- coproducts -sparsePlusS _ _ (SMTLEither ta tb) (SpLEither sp1a sp1b) (SpLEither sp2a sp2b) k = - sparsePlusS ST ST ta sp1a sp2a $ \(sp3a :: Sparse _t3 t3a) (Inj inj13a) (Inj inj23a) plusa -> - sparsePlusS ST ST tb sp1b sp2b $ \(sp3b :: Sparse _t3' t3b) (Inj inj13b) (Inj inj23b) plusb -> - let nil :: Ex e (TLEither t3a t3b) ; nil = ELNil ext (applySparse sp3a (fromSMTy ta)) (applySparse sp3b (fromSMTy tb)) - inl :: Ex e t3a -> Ex e (TLEither t3a t3b) ; inl = ELInl ext (applySparse sp3b (fromSMTy tb)) - inr :: Ex e t3b -> Ex e (TLEither t3a t3b) ; inr = ELInr ext (applySparse sp3a (fromSMTy ta)) - in - k (SpLEither sp3a sp3b) - (Inj $ \x1 -> elcase x1 nil (inl (inj13a (evar IZ))) (inr (inj13b (evar IZ)))) - (Inj $ \x2 -> elcase x2 nil (inl (inj23a (evar IZ))) (inr (inj23b (evar IZ)))) - (\x1 x2 -> - elet x2 $ - elcase (weakenExpr WSink x1) - (elcase (evar IZ) - nil - (inl (inj23a (evar IZ))) - (inr (inj23b (evar IZ)))) - (elcase (evar (IS IZ)) - (inl (inj13a (evar IZ))) - (inl (plusa (evar (IS IZ)) (evar IZ))) - (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS ll+lr")) - (elcase (evar (IS IZ)) - (inr (inj13b (evar IZ))) - (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS lr+ll") - (inr (plusb (evar (IS IZ)) (evar IZ))))) - --- maybe -sparsePlusS _ _ (SMTMaybe t) (SpMaybe sp1) (SpMaybe sp2) k = - sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> - k (SpMaybe sp3) - (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) - (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) - (\a b -> - elet b $ - emaybe (weakenExpr WSink a) - (emaybe (evar IZ) - (ENothing ext (applySparse sp3 (fromSMTy t))) - (EJust ext (inj2 (evar IZ)))) - (emaybe (evar (IS IZ)) - (EJust ext (inj1 (evar IZ))) - (EJust ext (plus (evar (IS IZ)) (evar IZ))))) - --- dense array cotangents simply recurse -sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k = - sparsePlusS req1 req2 t sp1 sp2 $ \sp3 minj1 minj2 plus -> - k (SpArr sp3) - (withInj minj1 $ \inj1 -> emap (inj1 (EVar ext (applySparse sp1 (fromSMTy t)) IZ))) - (withInj minj2 $ \inj2 -> emap (inj2 (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) - (ezipWith (plus (EVar ext (applySparse sp1 (fromSMTy t)) (IS IZ)) - (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) - --- scalars -sparsePlusS _ _ (SMTScal t) SpScal SpScal k = k SpScal (Inj id) (Inj id) (EPlus ext (SMTScal t)) diff --git a/src/AST/Sparse/Types.hs b/src/AST/Sparse/Types.hs deleted file mode 100644 index 10cac4e..0000000 --- a/src/AST/Sparse/Types.hs +++ /dev/null @@ -1,107 +0,0 @@ -{-# LANGUAGE GADTs #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -module AST.Sparse.Types where - -import AST.Types - -import Data.Kind (Type, Constraint) -import Data.Type.Equality - - -data Sparse t t' where - SpSparse :: Sparse t t' -> Sparse t (TMaybe t') - SpAbsent :: Sparse t TNil - - SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b') - SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b') - SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t') - SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t') - SpScal :: Sparse (TScal t) (TScal t) -deriving instance Show (Sparse t t') - -class ApplySparse f where - applySparse :: Sparse t t' -> f t -> f t' - -instance ApplySparse STy where - applySparse (SpSparse s) t = STMaybe (applySparse s t) - applySparse SpAbsent _ = STNil - applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2) - applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2) - applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t) - applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t) - applySparse SpScal t = t - -instance ApplySparse SMTy where - applySparse (SpSparse s) t = SMTMaybe (applySparse s t) - applySparse SpAbsent _ = SMTNil - applySparse (SpPair s1 s2) (SMTPair t1 t2) = SMTPair (applySparse s1 t1) (applySparse s2 t2) - applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2) - applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t) - applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t) - applySparse SpScal t = t - - -class IsSubType s where - type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint - subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t' - subtTrans :: s a b -> s b c -> s a c - subtFull :: IsSubTypeSubject s f => f t -> s t t - -instance IsSubType (:~:) where - type IsSubTypeSubject (:~:) f = () - subtApply = gcastWith - subtTrans = trans - subtFull _ = Refl - -instance IsSubType Sparse where - type IsSubTypeSubject Sparse f = f ~ SMTy - subtApply = applySparse - - subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2) - subtTrans _ SpAbsent = SpAbsent - subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b) - subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b) - subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2) - subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2) - subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2) - subtTrans SpScal SpScal = SpScal - - subtFull = spDense - -spDense :: SMTy t -> Sparse t t -spDense SMTNil = SpAbsent -spDense (SMTPair t1 t2) = SpPair (spDense t1) (spDense t2) -spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2) -spDense (SMTMaybe t) = SpMaybe (spDense t) -spDense (SMTArr _ t) = SpArr (spDense t) -spDense (SMTScal _) = SpScal - -isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t') -isDense SMTNil SpAbsent = Just Refl -isDense _ SpSparse{} = Nothing -isDense _ SpAbsent = Nothing -isDense (SMTPair t1 t2) (SpPair s1 s2) - | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl - | otherwise = Nothing -isDense (SMTLEither t1 t2) (SpLEither s1 s2) - | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl - | otherwise = Nothing -isDense (SMTMaybe t) (SpMaybe s) - | Just Refl <- isDense t s = Just Refl - | otherwise = Nothing -isDense (SMTArr _ t) (SpArr s) - | Just Refl <- isDense t s = Just Refl - | otherwise = Nothing -isDense (SMTScal _) SpScal = Just Refl - -isAbsent :: Sparse t t' -> Bool -isAbsent (SpSparse s) = isAbsent s -isAbsent SpAbsent = True -isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2 -isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2 -isAbsent (SpMaybe s) = isAbsent s -isAbsent (SpArr s) = isAbsent s -isAbsent SpScal = False diff --git a/src/AST/SplitLets.hs b/src/AST/SplitLets.hs deleted file mode 100644 index 267dd87..0000000 --- a/src/AST/SplitLets.hs +++ /dev/null @@ -1,191 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -module AST.SplitLets (splitLets) where - -import Data.Type.Equality - -import AST -import AST.Bindings -import Lemmas - - -splitLets :: Ex env t -> Ex env t -splitLets = splitLets' (\t i w -> EVar ext t (w @> i)) - -splitLets' :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> Ex env t -> Ex env' t -splitLets' = \sub -> \case - EVar _ t i -> sub t i WId - ELet _ rhs body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body) - ECase x e a b -> - let STEither t1 t2 = typeOf e - in ECase x (splitLets' sub e) (split1 sub t1 a) (split1 sub t2 b) - EMaybe x a b e -> - let STMaybe t1 = typeOf e - in EMaybe x (splitLets' sub a) (split1 sub t1 b) (splitLets' sub e) - ELCase x e a b c -> - let STLEither t1 t2 = typeOf e - in ELCase x (splitLets' sub e) (splitLets' sub a) (split1 sub t1 b) (split1 sub t2 c) - EFold1Inner x cm a b c -> - let STArr _ t1 = typeOf c - in EFold1Inner x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c) - EFold1InnerD1 x cm a b c -> - let STArr _ t1 = typeOf c - in EFold1InnerD1 x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c) - EFold1InnerD2 x cm a b c -> - let STArr _ tB = typeOf b - STArr _ t2 = typeOf c - in EFold1InnerD2 x cm (split2 sub tB t2 a) (splitLets' sub b) (splitLets' sub c) - - EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b) - EFst x e -> EFst x (splitLets' sub e) - ESnd x e -> ESnd x (splitLets' sub e) - ENil x -> ENil x - EInl x t e -> EInl x t (splitLets' sub e) - EInr x t e -> EInr x t (splitLets' sub e) - ENothing x t -> ENothing x t - EJust x e -> EJust x (splitLets' sub e) - ELNil x t1 t2 -> ELNil x t1 t2 - ELInl x t e -> ELInl x t (splitLets' sub e) - ELInr x t e -> ELInr x t (splitLets' sub e) - EConstArr x n t a -> EConstArr x n t a - EBuild x n a b -> EBuild x n (splitLets' sub a) (splitLets' (sinkF sub) b) - EMap x a b -> EMap x (splitLets' (sinkF sub) a) (splitLets' sub b) - ESum1Inner x e -> ESum1Inner x (splitLets' sub e) - EUnit x e -> EUnit x (splitLets' sub e) - EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b) - EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e) - EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e) - EReshape x n a b -> EReshape x n (splitLets' sub a) (splitLets' sub b) - EZip x a b -> EZip x (splitLets' sub a) (splitLets' sub b) - EConst x t v -> EConst x t v - EIdx0 x e -> EIdx0 x (splitLets' sub e) - EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b) - EIdx x e es -> EIdx x (splitLets' sub e) (splitLets' sub es) - EShape x e -> EShape x (splitLets' sub e) - EOp x op e -> EOp x op (splitLets' sub e) - ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2) - ERecompute x e -> ERecompute x (splitLets' sub e) - EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2) - EAccum x t p e1 sp e2 e3 -> EAccum x t p (splitLets' sub e1) sp (splitLets' sub e2) (splitLets' sub e3) - EZero x t ezi -> EZero x t (splitLets' sub ezi) - EDeepZero x t ezi -> EDeepZero x t (splitLets' sub ezi) - EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b) - EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b) - EError x t s -> EError x t s - where - sinkF :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) - -> STy t -> Idx (b : env) t -> (b : env') :> env3 -> Ex env3 t - sinkF _ t IZ w = EVar ext t (w @> IZ) - sinkF f t (IS i) w = f t i (w .> WSink) - - split1 :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) - -> STy bind -> Ex (bind : env) t -> Ex (bind : env') t - split1 sub (tbind :: STy bind) body = - let (ptrs, bs) = split tbind - in letBinds bs $ - splitLets' (\cases _ IZ w -> subPointers ptrs w - t (IS i) w -> sub t i (WPop @bind (wPops (bindingsBinds bs) w))) - body - - split2 :: forall bind1 bind2 env' env t. - (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) - -> STy bind1 -> STy bind2 -> Ex (bind2 : bind1 : env) t -> Ex (bind2 : bind1 : env') t - split2 sub tbind1 tbind2 body = - let (ptrs1', bs1') = split @env' tbind1 - bs1 = fst (weakenBindingsE WSink bs1') - (ptrs2, bs2) = split @(bind1 : env') tbind2 - in letBinds bs1 $ - letBinds (fst (weakenBindingsE (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $ - splitLets' (\cases _ IZ w -> subPointers ptrs2 (w .> wCopies (bindingsBinds bs2) (wSinks @(bind2 : bind1 : env') (bindingsBinds bs1))) - _ (IS IZ) w -> subPointers ptrs1' (w .> wSinks (bindingsBinds bs2) .> wCopies (bindingsBinds bs1) (WSink @bind2 @(bind1 : env'))) - t (IS (IS i)) w -> sub t i (WPop @bind1 (WPop @bind2 (wPops (bindingsBinds bs1) (wPops (bindingsBinds bs2) w))))) - body - - -- TODO: abstract this to splitN lol wtf - _split4 :: forall bind1 bind2 bind3 bind4 env' env t. - (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) - -> STy bind1 -> STy bind2 -> STy bind3 -> STy bind4 -> Ex (bind4 : bind3 : bind2 : bind1 : env) t -> Ex (bind4 : bind3 : bind2 : bind1 : env') t - _split4 sub tbind1 tbind2 tbind3 tbind4 body = - let (ptrs1, bs1') = split @env' tbind1 - (ptrs2, bs2') = split @(bind1 : env') tbind2 - (ptrs3, bs3') = split @(bind2 : bind1 : env') tbind3 - (ptrs4, bs4) = split @(bind3 : bind2 : bind1 : env') tbind4 - bs1 = fst (weakenBindingsE (WSink .> WSink .> WSink) bs1') - bs2 = fst (weakenBindingsE (WSink .> WSink) bs2') - bs3 = fst (weakenBindingsE WSink bs3') - b1 = bindingsBinds bs1 - b2 = bindingsBinds bs2 - b3 = bindingsBinds bs3 - b4 = bindingsBinds bs4 - in letBinds bs1 $ - letBinds (fst (weakenBindingsE ( sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs2)) $ - letBinds (fst (weakenBindingsE ( sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs3)) $ - letBinds (fst (weakenBindingsE (sinkWithBindings bs3 .> sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs4)) $ - splitLets' (\cases _ IZ w -> subPointers ptrs4 (w .> wCopies b4 (wSinks b3 .> wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1)) - _ (IS IZ) w -> subPointers ptrs3 (w .> wSinks b4 .> wCopies b3 (wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink)) - _ (IS (IS IZ)) w -> subPointers ptrs2 (w .> wSinks b4 .> wSinks b3 .> wCopies b2 (wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink .> WSink)) - _ (IS (IS (IS IZ))) w -> subPointers ptrs1 (w .> wSinks b4 .> wSinks b3 .> wSinks b2 .> wCopies b1 (WSink @bind4 .> WSink @bind3 .> WSink @bind2 @(bind1 : env'))) - t (IS (IS (IS (IS i)))) w -> sub t i (WPop @bind1 (WPop @bind2 (WPop @bind3 (WPop @bind4 (wPops b1 (wPops b2 (wPops b3 (wPops b4 w))))))))) - body - -type family Split t where - Split (TPair a b) = SplitRec (TPair a b) - Split _ = '[] - -type family SplitRec t where - SplitRec TNil = '[] - SplitRec (TPair a b) = Append (SplitRec b) (SplitRec a) - SplitRec t = '[t] - -data Pointers env t where - Point :: STy t -> Idx env t -> Pointers env t - PNil :: Pointers env TNil - PPair :: Pointers env a -> Pointers env b -> Pointers env (TPair a b) - PWeak :: env' :> env -> Pointers env' t -> Pointers env t - -subPointers :: Pointers env t -> env :> env' -> Ex env' t -subPointers (Point t i) w = EVar ext t (w @> i) -subPointers PNil _ = ENil ext -subPointers (PPair a b) w = EPair ext (subPointers a w) (subPointers b w) -subPointers (PWeak w' p) w = subPointers p (w .> w') - -split :: forall env t. STy t - -> (Pointers (Append (Split t) (t : env)) t, Bindings Ex (t : env) (Split t)) -split typ = case typ of - STPair{} -> splitRec (EVar ext typ IZ) typ - STNil -> other - STEither{} -> other - STLEither{} -> other - STMaybe{} -> other - STArr{} -> other - STScal{} -> other - STAccum{} -> other - where - other :: (Pointers (t : env) t, Bindings Ex (t : env) '[]) - other = (Point typ IZ, BTop) - -splitRec :: forall env t. Ex env t -> STy t - -> (Pointers (Append (SplitRec t) env) t, Bindings Ex env (SplitRec t)) -splitRec rhs typ = case typ of - STNil -> (PNil, BTop) - STPair (a :: STy a) (b :: STy b) - | Refl <- lemAppendAssoc @(SplitRec b) @(SplitRec a) @env -> - let (p1, bs1) = splitRec (EFst ext rhs) a - (p2, bs2) = splitRec (ESnd ext (sinkWithBindings bs1 `weakenExpr` rhs)) b - in (PPair (PWeak (sinkWithBindings bs2) p1) p2, bconcat bs1 bs2) - STEither{} -> other - STLEither{} -> other - STMaybe{} -> other - STArr{} -> other - STScal{} -> other - STAccum{} -> other - where - other :: (Pointers (t : env) t, Bindings Ex env '[t]) - other = (Point typ IZ, BPush BTop (typ, rhs)) diff --git a/src/AST/Types.hs b/src/AST/Types.hs deleted file mode 100644 index 4ddcb50..0000000 --- a/src/AST/Types.hs +++ /dev/null @@ -1,215 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeData #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -module AST.Types where - -import Data.Int (Int32, Int64) -import Data.GADT.Compare -import Data.GADT.Show -import Data.Kind (Type) -import Data.Type.Equality - -import Data - - -type data Ty - = TNil - | TPair Ty Ty - | TEither Ty Ty - | TLEither Ty Ty - | TMaybe Ty - | TArr Nat Ty -- ^ rank, element type - | TScal ScalTy - | TAccum Ty -- ^ contained type must be a monoid type - -type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool - -type STy :: Ty -> Type -data STy t where - STNil :: STy TNil - STPair :: STy a -> STy b -> STy (TPair a b) - STEither :: STy a -> STy b -> STy (TEither a b) - STLEither :: STy a -> STy b -> STy (TLEither a b) - STMaybe :: STy a -> STy (TMaybe a) - STArr :: SNat n -> STy t -> STy (TArr n t) - STScal :: SScalTy t -> STy (TScal t) - STAccum :: SMTy t -> STy (TAccum t) -deriving instance Show (STy t) - -instance GCompare STy where - gcompare = \cases - STNil STNil -> GEQ - STNil _ -> GLT ; _ STNil -> GGT - (STPair a b) (STPair a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') - STPair{} _ -> GLT ; _ STPair{} -> GGT - (STEither a b) (STEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') - STEither{} _ -> GLT ; _ STEither{} -> GGT - (STLEither a b) (STLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') - STLEither{} _ -> GLT ; _ STLEither{} -> GGT - (STMaybe a) (STMaybe a') -> gorderingLift1 (gcompare a a') - STMaybe{} _ -> GLT ; _ STMaybe{} -> GGT - (STArr n t) (STArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t') - STArr{} _ -> GLT ; _ STArr{} -> GGT - (STScal t) (STScal t') -> gorderingLift1 (gcompare t t') - STScal{} _ -> GLT ; _ STScal{} -> GGT - (STAccum t) (STAccum t') -> gorderingLift1 (gcompare t t') - -- STAccum{} _ -> GLT ; _ STAccum{} -> GGT - -instance TestEquality STy where testEquality = geq -instance GEq STy where geq = defaultGeq -instance GShow STy where gshowsPrec = defaultGshowsPrec - --- | Monoid types -type SMTy :: Ty -> Type -data SMTy t where - SMTNil :: SMTy TNil - SMTPair :: SMTy a -> SMTy b -> SMTy (TPair a b) - SMTLEither :: SMTy a -> SMTy b -> SMTy (TLEither a b) - SMTMaybe :: SMTy a -> SMTy (TMaybe a) - SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t) - SMTScal :: ScalIsNumeric t ~ True => SScalTy t -> SMTy (TScal t) -deriving instance Show (SMTy t) - -instance GCompare SMTy where - gcompare = \cases - SMTNil SMTNil -> GEQ - SMTNil _ -> GLT ; _ SMTNil -> GGT - (SMTPair a b) (SMTPair a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') - SMTPair{} _ -> GLT ; _ SMTPair{} -> GGT - (SMTLEither a b) (SMTLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') - SMTLEither{} _ -> GLT ; _ SMTLEither{} -> GGT - (SMTMaybe a) (SMTMaybe a') -> gorderingLift1 (gcompare a a') - SMTMaybe{} _ -> GLT ; _ SMTMaybe{} -> GGT - (SMTArr n t) (SMTArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t') - SMTArr{} _ -> GLT ; _ SMTArr{} -> GGT - (SMTScal t) (SMTScal t') -> gorderingLift1 (gcompare t t') - -- SMTScal{} _ -> GLT ; _ SMTScal{} -> GGT - -instance TestEquality SMTy where testEquality = geq -instance GEq SMTy where geq = defaultGeq -instance GShow SMTy where gshowsPrec = defaultGshowsPrec - -fromSMTy :: SMTy t -> STy t -fromSMTy = \case - SMTNil -> STNil - SMTPair t1 t2 -> STPair (fromSMTy t1) (fromSMTy t2) - SMTLEither t1 t2 -> STLEither (fromSMTy t1) (fromSMTy t2) - SMTMaybe t -> STMaybe (fromSMTy t) - SMTArr n t -> STArr n (fromSMTy t) - SMTScal sty -> STScal sty - -data SScalTy t where - STI32 :: SScalTy TI32 - STI64 :: SScalTy TI64 - STF32 :: SScalTy TF32 - STF64 :: SScalTy TF64 - STBool :: SScalTy TBool -deriving instance Show (SScalTy t) - -instance GCompare SScalTy where - gcompare = \cases - STI32 STI32 -> GEQ - STI32 _ -> GLT ; _ STI32 -> GGT - STI64 STI64 -> GEQ - STI64 _ -> GLT ; _ STI64 -> GGT - STF32 STF32 -> GEQ - STF32 _ -> GLT ; _ STF32 -> GGT - STF64 STF64 -> GEQ - STF64 _ -> GLT ; _ STF64 -> GGT - STBool STBool -> GEQ - -- STBool _ -> GLT ; _ STBool -> GGT - -instance TestEquality SScalTy where testEquality = geq -instance GEq SScalTy where geq = defaultGeq -instance GShow SScalTy where gshowsPrec = defaultGshowsPrec - -scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t)) -scalRepIsShow STI32 = Dict -scalRepIsShow STI64 = Dict -scalRepIsShow STF32 = Dict -scalRepIsShow STF64 = Dict -scalRepIsShow STBool = Dict - -type TIx = TScal TI64 - -tIx :: STy TIx -tIx = STScal STI64 - -type family ScalRep t where - ScalRep TI32 = Int32 - ScalRep TI64 = Int64 - ScalRep TF32 = Float - ScalRep TF64 = Double - ScalRep TBool = Bool - -type family ScalIsNumeric t where - ScalIsNumeric TI32 = True - ScalIsNumeric TI64 = True - ScalIsNumeric TF32 = True - ScalIsNumeric TF64 = True - ScalIsNumeric TBool = False - -type family ScalIsFloating t where - ScalIsFloating TI32 = False - ScalIsFloating TI64 = False - ScalIsFloating TF32 = True - ScalIsFloating TF64 = True - ScalIsFloating TBool = False - -type family ScalIsIntegral t where - ScalIsIntegral TI32 = True - ScalIsIntegral TI64 = True - ScalIsIntegral TF32 = False - ScalIsIntegral TF64 = False - ScalIsIntegral TBool = False - --- | Returns true for arrays /and/ accumulators. -typeHasArrays :: STy t' -> Bool -typeHasArrays STNil = False -typeHasArrays (STPair a b) = typeHasArrays a || typeHasArrays b -typeHasArrays (STEither a b) = typeHasArrays a || typeHasArrays b -typeHasArrays (STLEither a b) = typeHasArrays a || typeHasArrays b -typeHasArrays (STMaybe t) = typeHasArrays t -typeHasArrays STArr{} = True -typeHasArrays STScal{} = False -typeHasArrays STAccum{} = True - -typeHasAccums :: STy t' -> Bool -typeHasAccums STNil = False -typeHasAccums (STPair a b) = typeHasAccums a || typeHasAccums b -typeHasAccums (STEither a b) = typeHasAccums a || typeHasAccums b -typeHasAccums (STLEither a b) = typeHasAccums a || typeHasAccums b -typeHasAccums (STMaybe t) = typeHasAccums t -typeHasAccums STArr{} = False -typeHasAccums STScal{} = False -typeHasAccums STAccum{} = True - -type family Tup env where - Tup '[] = TNil - Tup (t : ts) = TPair (Tup ts) t - -mkTup :: f TNil -> (forall a b. f a -> f b -> f (TPair a b)) - -> SList f list -> f (Tup list) -mkTup nil _ SNil = nil -mkTup nil pair (e `SCons` es) = pair (mkTup nil pair es) e - -tTup :: SList STy env -> STy (Tup env) -tTup = mkTup STNil STPair - -unTup :: (forall a b. c (TPair a b) -> (c a, c b)) - -> SList f list -> c (Tup list) -> SList c list -unTup _ SNil _ = SNil -unTup unpack (_ `SCons` list) tup = - let (xs, x) = unpack tup - in x `SCons` unTup unpack list xs - -type family InvTup core env where - InvTup core '[] = core - InvTup core (t : ts) = InvTup (TPair core t) ts diff --git a/src/AST/UnMonoid.hs b/src/AST/UnMonoid.hs deleted file mode 100644 index 1712ba5..0000000 --- a/src/AST/UnMonoid.hs +++ /dev/null @@ -1,255 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE TypeOperators #-} -module AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where - -import AST -import AST.Sparse.Types -import Data - - --- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by --- expanding them into their concrete implementations. Also ensure that --- 'EAccum' has a dense sparsity. -unMonoid :: Ex env t -> Ex env t -unMonoid = \case - EZero _ t e -> zero t e - EDeepZero _ t e -> deepZero t e - EPlus _ t a b -> plus t (unMonoid a) (unMonoid b) - EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b) - - EVar _ t i -> EVar ext t i - ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body) - EPair _ a b -> EPair ext (unMonoid a) (unMonoid b) - EFst _ e -> EFst ext (unMonoid e) - ESnd _ e -> ESnd ext (unMonoid e) - ENil _ -> ENil ext - EInl _ t e -> EInl ext t (unMonoid e) - EInr _ t e -> EInr ext t (unMonoid e) - ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b) - ENothing _ t -> ENothing ext t - EJust _ e -> EJust ext (unMonoid e) - EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e) - ELNil _ t1 t2 -> ELNil ext t1 t2 - ELInl _ t e -> ELInl ext t (unMonoid e) - ELInr _ t e -> ELInr ext t (unMonoid e) - ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c) - EConstArr _ n t x -> EConstArr ext n t x - EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b) - EMap _ a b -> EMap ext (unMonoid a) (unMonoid b) - EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c) - ESum1Inner _ e -> ESum1Inner ext (unMonoid e) - EUnit _ e -> EUnit ext (unMonoid e) - EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b) - EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e) - EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e) - EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b) - EZip _ a b -> EZip ext (unMonoid a) (unMonoid b) - EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c) - EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c) - EConst _ t x -> EConst ext t x - EIdx0 _ e -> EIdx0 ext (unMonoid e) - EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b) - EIdx _ a b -> EIdx ext (unMonoid a) (unMonoid b) - EShape _ e -> EShape ext (unMonoid e) - EOp _ op e -> EOp ext op (unMonoid e) - ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2) - ERecompute _ e -> ERecompute ext (unMonoid e) - EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b) - EAccum _ t p eidx sp eval eacc -> - accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 -> - acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' -> - EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc)) - EError _ t s -> EError ext t s - -zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t --- don't destroy the effects! -zero SMTNil e = ELet ext e $ ENil ext -zero (SMTPair t1 t2) e = - ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ))) - (zero t2 (ESnd ext (EVar ext (typeOf e) IZ))) -zero (SMTLEither t1 t2) _ = ELNil ext (fromSMTy t1) (fromSMTy t2) -zero (SMTMaybe t) _ = ENothing ext (fromSMTy t) -zero (SMTArr _ t) e = emap (zero t (EVar ext (tZeroInfo t) IZ)) e -zero (SMTScal t) _ = case t of - STI32 -> EConst ext STI32 0 - STI64 -> EConst ext STI64 0 - STF32 -> EConst ext STF32 0.0 - STF64 -> EConst ext STF64 0.0 - -deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t -deepZero SMTNil e = elet e $ ENil ext -deepZero (SMTPair t1 t2) e = - ELet ext e $ EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ))) - (deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ))) -deepZero (SMTLEither t1 t2) e = - elcase e - (ELNil ext (fromSMTy t1) (fromSMTy t2)) - (ELInl ext (fromSMTy t2) (deepZero t1 (evar IZ))) - (ELInr ext (fromSMTy t1) (deepZero t2 (evar IZ))) -deepZero (SMTMaybe t) e = - emaybe e - (ENothing ext (fromSMTy t)) - (EJust ext (deepZero t (evar IZ))) -deepZero (SMTArr _ t) e = emap (deepZero t (evar IZ)) e -deepZero (SMTScal t) _ = case t of - STI32 -> EConst ext STI32 0 - STI64 -> EConst ext STI64 0 - STF32 -> EConst ext STF32 0.0 - STF64 -> EConst ext STF64 0.0 - -plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t --- don't destroy the effects! -plus SMTNil a b = ELet ext a $ ELet ext (weakenExpr WSink b) $ ENil ext -plus (SMTPair t1 t2) a b = - let t = STPair (fromSMTy t1) (fromSMTy t2) - in ELet ext a $ - ELet ext (weakenExpr WSink b) $ - EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ))) - (EFst ext (EVar ext t IZ))) - (plus t2 (ESnd ext (EVar ext t (IS IZ))) - (ESnd ext (EVar ext t IZ))) -plus (SMTLEither t1 t2) a b = - let t = STLEither (fromSMTy t1) (fromSMTy t2) - in ELet ext a $ - ELet ext (weakenExpr WSink b) $ - ELCase ext (EVar ext t (IS IZ)) - (EVar ext t IZ) - (ELCase ext (EVar ext t (IS IZ)) - (EVar ext t (IS (IS IZ))) - (ELInl ext (fromSMTy t2) (plus t1 (EVar ext (fromSMTy t1) (IS IZ)) (EVar ext (fromSMTy t1) IZ))) - (EError ext t "plus l+r")) - (ELCase ext (EVar ext t (IS IZ)) - (EVar ext t (IS (IS IZ))) - (EError ext t "plus r+l") - (ELInr ext (fromSMTy t1) (plus t2 (EVar ext (fromSMTy t2) (IS IZ)) (EVar ext (fromSMTy t2) IZ)))) -plus (SMTMaybe t) a b = - ELet ext b $ - EMaybe ext - (EVar ext (STMaybe (fromSMTy t)) IZ) - (EJust ext - (EMaybe ext - (EVar ext (fromSMTy t) IZ) - (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ)) - (EVar ext (STMaybe (fromSMTy t)) (IS IZ)))) - (weakenExpr WSink a) -plus (SMTArr _ t) a b = - ezipWith (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ)) - a b -plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b) - -onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t -onehot typ topprj idx arg = case (typ, topprj) of - (_, SAPHere) -> - ELet ext arg $ - EVar ext (fromSMTy typ) IZ - - (SMTPair t1 t2, SAPFst prj) -> - ELet ext idx $ - let tidx = typeOf idx in - ELet ext (onehot t1 prj (EFst ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $ - let toh = fromSMTy t1 in - EPair ext (EVar ext toh IZ) - (zero t2 (ESnd ext (EVar ext tidx (IS IZ)))) - - (SMTPair t1 t2, SAPSnd prj) -> - ELet ext idx $ - let tidx = typeOf idx in - ELet ext (onehot t2 prj (ESnd ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $ - let toh = fromSMTy t2 in - EPair ext (zero t1 (EFst ext (EVar ext tidx (IS IZ)))) - (EVar ext toh IZ) - - (SMTLEither t1 t2, SAPLeft prj) -> - ELInl ext (fromSMTy t2) (onehot t1 prj idx arg) - (SMTLEither t1 t2, SAPRight prj) -> - ELInr ext (fromSMTy t1) (onehot t2 prj idx arg) - - (SMTMaybe t1, SAPJust prj) -> - EJust ext (onehot t1 prj idx arg) - - (SMTArr n t1, SAPArrIdx prj) -> - let tidx = tTup (sreplicate n tIx) - in ELet ext idx $ - EBuild ext n (EShape ext (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ)))) $ - eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ))))) - (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) - (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $ - zero t1 (EVar ext (tZeroInfo t1) IZ)) - -accumulateSparse - :: SMTy t -> Sparse t t' -> Ex env t' - -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' (AcIdxD p t) -> Ex env' b -> Ex env' TNil) - -> Ex env TNil -accumulateSparse topty topsp arg accum = case (topty, topsp) of - (_, s) | Just Refl <- isDense topty s -> - accum WId SAPHere (ENil ext) arg - (SMTScal _, SpScal) -> - accum WId SAPHere (ENil ext) arg -- should be handled by isDense already, but meh - (_, SpSparse s) -> - emaybe arg - (ENil ext) - (accumulateSparse topty s (evar IZ) (\w -> accum (WPop w))) - (_, SpAbsent) -> - ENil ext - (SMTPair t1 t2, SpPair s1 s2) -> - eunPair arg $ \w1 e1 e2 -> - elet (accumulateSparse t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $ - accumulateSparse t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj)) - (SMTLEither t1 t2, SpLEither s1 s2) -> - elcase arg - (ENil ext) - (accumulateSparse t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj))) - (accumulateSparse t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj))) - (SMTMaybe t, SpMaybe s) -> - emaybe arg - (ENil ext) - (accumulateSparse t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj))) - (SMTArr n t, SpArr s) -> - let tn = tTup (sreplicate n tIx) in - elet arg $ - elet (EBuild ext n (EShape ext (evar IZ)) $ - accumulateSparse t s - (EIdx ext (evar (IS IZ)) (EVar ext tn IZ)) - (\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $ - ENil ext - -acPrjCompose - :: SAIDense dense - -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a) - -> SAcPrj p2 b c -> Ex env (AcIdx dense p2 b) - -> (forall p'. SAcPrj p' a c -> Ex env (AcIdx dense p' a) -> r) -> r -acPrjCompose _ SAPHere _ p2 idx2 k = k p2 idx2 -acPrjCompose SAID (SAPFst p1) idx1 p2 idx2 k = - acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> - k (SAPFst p') idx' -acPrjCompose SAID (SAPSnd p1) idx1 p2 idx2 k = - acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> - k (SAPSnd p') idx' -acPrjCompose SAIS (SAPFst p1) idx1 p2 idx2 k - | Dict <- styKnown (typeOf idx1) = - acPrjCompose SAIS p1 (efst (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> - k (SAPFst p') (elet idx1 $ EPair ext idx' (esnd (evar IZ))) -acPrjCompose SAIS (SAPSnd p1) idx1 p2 idx2 k - | Dict <- styKnown (typeOf idx1) = - acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> - k (SAPSnd p') (elet idx1 $ EPair ext (efst (evar IZ)) idx') -acPrjCompose d (SAPLeft p1) idx1 p2 idx2 k = - acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> - k (SAPLeft p') idx' -acPrjCompose d (SAPRight p1) idx1 p2 idx2 k = - acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> - k (SAPRight p') idx' -acPrjCompose d (SAPJust p1) idx1 p2 idx2 k = - acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> - k (SAPJust p') idx' -acPrjCompose SAID (SAPArrIdx p1) idx1 p2 idx2 k - | Dict <- styKnown (typeOf idx1) = - acPrjCompose SAID p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> - k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx') -acPrjCompose SAIS (SAPArrIdx p1) idx1 p2 idx2 k - | Dict <- styKnown (typeOf idx1) = - acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> - k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx') diff --git a/src/AST/Weaken.hs b/src/AST/Weaken.hs deleted file mode 100644 index f0820b8..0000000 --- a/src/AST/Weaken.hs +++ /dev/null @@ -1,138 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeAbstractions #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} - -{-# LANGUAGE PartialTypeSignatures #-} -{-# OPTIONS -Wno-partial-type-signatures #-} - --- The reason why this is a separate module with "little" in it: -{-# LANGUAGE AllowAmbiguousTypes #-} - -module AST.Weaken (module AST.Weaken, Append) where - -import Data.Bifunctor (first) -import Data.Functor.Const -import Data.GADT.Compare -import Data.Kind (Type) - -import Data -import Lemmas - - -type Idx :: [k] -> k -> Type -data Idx env t where - IZ :: Idx (t : env) t - IS :: Idx env t -> Idx (a : env) t -deriving instance Show (Idx env t) - -instance GEq (Idx env) where - geq IZ IZ = Just Refl - geq (IS i) (IS j) | Just Refl <- geq i j = Just Refl - geq _ _ = Nothing - -splitIdx :: forall env2 env1 t f. SList f env1 -> Idx (Append env1 env2) t -> Either (Idx env1 t) (Idx env2 t) -splitIdx SNil i = Right i -splitIdx (SCons _ _) IZ = Left IZ -splitIdx (SCons _ l) (IS i) = first IS (splitIdx l i) - -slistIdx :: SList f list -> Idx list t -> f t -slistIdx (SCons x _) IZ = x -slistIdx (SCons _ list) (IS i) = slistIdx list i -slistIdx SNil i = case i of {} - -idx2int :: Idx env t -> Int -idx2int IZ = 0 -idx2int (IS n) = 1 + idx2int n - -data env :> env' where - WId :: env :> env - WSink :: forall t env. env :> (t : env) - WCopy :: forall t env env'. env :> env' -> (t : env) :> (t : env') - WPop :: (t : env) :> env' -> env :> env' - WThen :: env1 :> env2 -> env2 :> env3 -> env1 :> env3 - WClosed :: '[] :> env - WIdx :: Idx env t -> (t : env) :> env - WPick :: forall t pre env env'. SList (Const ()) pre -> env :> env' - -> Append pre (t : env) :> t : Append pre env' - WSwap :: forall env as bs. SList (Const ()) as -> SList (Const ()) bs - -> Append as (Append bs env) :> Append bs (Append as env) - WStack :: forall env1 env2 as bs. SList (Const ()) as -> SList (Const ()) bs - -> as :> bs -> env1 :> env2 - -> Append as env1 :> Append bs env2 -deriving instance Show (env :> env') -infix 4 :> - -infixr 2 @> -(@>) :: env :> env' -> Idx env t -> Idx env' t -WId @> i = i -WSink @> i = IS i -WCopy _ @> IZ = IZ -WCopy w @> IS i = IS (w @> i) -WPop w @> i = w @> IS i -WThen w1 w2 @> i = w2 @> w1 @> i -WClosed @> i = case i of {} -WIdx j @> IZ = j -WIdx _ @> IS i = i -WPick SNil w @> i = WCopy w @> i -WPick (_ `SCons` _) _ @> IZ = IS IZ -WPick @t (_ `SCons` pre) w @> IS i = WCopy WSink .> WPick @t pre w @> i -WSwap @env (as :: SList _ as) (bs :: SList _ bs) @> i = - case splitIdx @(Append bs env) as i of - Left i' -> indexSinks bs (indexRaiseAbove @env as i') - Right i' -> case splitIdx @env bs i' of - Left j -> indexRaiseAbove @(Append as env) bs j - Right j -> indexSinks bs (indexSinks as j) -WStack @env1 @env2 as bs wlo whi @> i = - case splitIdx @env1 as i of - Left i' -> indexRaiseAbove @env2 bs (wlo @> i') - Right i' -> indexSinks bs (whi @> i') - -indexSinks :: SList f as -> Idx bs t -> Idx (Append as bs) t -indexSinks SNil j = j -indexSinks (_ `SCons` bs') j = IS (indexSinks bs' j) - -indexRaiseAbove :: forall env as t f. SList f as -> Idx as t -> Idx (Append as env) t -indexRaiseAbove = flip go - where - go :: forall as'. Idx as' t -> SList f as' -> Idx (Append as' env) t - go IZ (_ `SCons` _) = IZ - go (IS i) (_ `SCons` as) = IS (go i as) - -infixr 3 .> -(.>) :: env2 :> env3 -> env1 :> env2 -> env1 :> env3 -(.>) = flip WThen - -class KnownListSpine list where knownListSpine :: SList (Const ()) list -instance KnownListSpine '[] where knownListSpine = SNil -instance KnownListSpine list => KnownListSpine (t : list) where knownListSpine = SCons (Const ()) knownListSpine - -wSinks' :: forall list env. KnownListSpine list => env :> Append list env -wSinks' = wSinks (knownListSpine :: SList (Const ()) list) - -wSinks :: forall env bs f. SList f bs -> env :> Append bs env -wSinks SNil = WId -wSinks (SCons _ spine) = WSink .> wSinks spine - -wSinksAnd :: forall env env' bs f. SList f bs -> env :> env' -> env :> Append bs env' -wSinksAnd SNil w = w -wSinksAnd (SCons _ spine) w = WSink .> wSinksAnd spine w - -wCopies :: SList f bs -> env1 :> env2 -> Append bs env1 :> Append bs env2 -wCopies bs w = - let bs' = slistMap (\_ -> Const ()) bs - in WStack bs' bs' WId w - -wRaiseAbove :: SList f env1 -> proxy env -> env1 :> Append env1 env -wRaiseAbove SNil _ = WClosed -wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env) - -wPops :: SList f bs -> Append bs env1 :> env2 -> env1 :> env2 -wPops SNil w = w -wPops (_ `SCons` bs) w = wPops bs (WPop w) diff --git a/src/AST/Weaken/Auto.hs b/src/AST/Weaken/Auto.hs deleted file mode 100644 index 7370df1..0000000 --- a/src/AST/Weaken/Auto.hs +++ /dev/null @@ -1,192 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE FunctionalDependencies #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeAbstractions #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} - -{-# LANGUAGE AllowAmbiguousTypes #-} - -{-# LANGUAGE PartialTypeSignatures #-} -{-# OPTIONS_GHC -Wno-partial-type-signatures #-} -module AST.Weaken.Auto ( - autoWeak, - (&.), auto, auto1, - Layout(..), -) where - -import Data.Functor.Const -import Data.Kind (Constraint) -import GHC.OverloadedLabels -import GHC.TypeLits -import Unsafe.Coerce (unsafeCoerce) - -import AST.Weaken -import Data -import Lemmas - - -type family Lookup name list where - Lookup name ('(name, x) : _) = x - Lookup name (_ : list) = Lookup name list - Lookup name '[] = TypeError (Text "The name '" :<>: Text name :<>: Text "' does not appear in the list.") - - --- | The @withPre@ type parameter indicates whether there can be 'LPreW' --- occurrences within this layout. 'names' is the list of names that this --- layout /produces/. That is: for LPreW, it contains the target name. The --- 'names' list of a source layout must be a subset of the names list of the --- target layout (which cannot contain LPreW); this is checked with SubLayout. -data Layout (withPre :: Bool) (segments :: [(Symbol, [t])]) (names :: [Symbol]) (env :: [t]) where - LSeg :: forall name segments withPre. SSymbol name -> Layout withPre segments '[name] (Lookup name segments) - -- | Pre-weaken with a weakening - LPreW :: forall name1 name2 segments. - SegmentName name1 -> SegmentName name2 - -> Lookup name1 segments :> Lookup name2 segments - -> Layout True segments '[name2] (Lookup name1 segments) - (:++:) :: Layout withPre segments names1 env1 -> Layout withPre segments names2 env2 -> Layout withPre segments (Append names1 names2) (Append env1 env2) -infixr :++: - -instance (KnownSymbol name, seg ~ Lookup name segments, names ~ '[name]) => IsLabel name (Layout withPre segments names seg) where - fromLabel = LSeg (symbolSing @name) - -newtype SegmentName name = SegmentName (SSymbol name) - deriving (Show) - -instance (KnownSymbol name, name ~ name') => IsLabel name (SegmentName name') where - fromLabel = SegmentName symbolSing - - -type family SubLayout names1 names2 where - SubLayout '[] _ = () :: Constraint - SubLayout (n : names1) names2 = SubLayout' n (Contains n names2) names1 names2 -type family SubLayout' n ok names1 names2 where - SubLayout' n False _ _ = TypeError (Text "The name '" :<>: Text n :<>: Text "' appears in the source layout but not in the target.") - SubLayout' _ True names1 names2 = SubLayout names1 names2 -type family Contains n names where - Contains _ '[] = False - Contains n (n : _) = True - Contains n (_ : names) = Contains n names - - -data SSegments (segments :: [(Symbol, [t])]) where - SSegNil :: SSegments '[] - SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list) - -instance (KnownSymbol name, segs ~ '[ '(name, ts)]) => IsLabel name (SList f ts -> SSegments segs) where - fromLabel = \spine -> SSegCons symbolSing (slistMap (\_ -> Const ()) spine) SSegNil - -auto :: KnownListSpine list => SList (Const ()) list -auto = knownListSpine - -auto1 :: SList (Const ()) '[t] -auto1 = Const () `SCons` SNil - -infixr &. -(&.) :: SSegments '[segs1] -> SSegments segs2 -> SSegments (segs1 : segs2) -(&.) = ssegmentsAppend - where - ssegmentsAppend :: SSegments a -> SSegments b -> SSegments (Append a b) - ssegmentsAppend SSegNil l2 = l2 - ssegmentsAppend (SSegCons name list l1) l2 = SSegCons name list (ssegmentsAppend l1 l2) - - --- | If the found segment is a TopSeg, returns Nothing. -segmentLookup :: forall segments name. SSegments segments -> SSymbol name -> SList (Const ()) (Lookup name segments) -segmentLookup = \segs name -> case go segs name of - Just ts -> ts - Nothing -> error $ "Segment not found: " ++ fromSSymbol name - where - go :: forall segs'. SSegments segs' -> SSymbol name -> Maybe (SList (Const ()) (Lookup name segs')) - go SSegNil _ = Nothing - go (SSegCons n@(SSymbol @n) (ts :: SList _ ts) (sseg :: SSegments rest)) name@SSymbol = - case sameSymbol n name of - Just Refl -> - case go sseg name of - Nothing -> Just ts - Just _ -> error $ "Duplicate segment with name " ++ fromSSymbol name - Nothing -> - case unsafeCoerce Refl :: (Lookup name ('(n, ts) : rest) :~: Lookup name rest) of - Refl -> go sseg name - -data LinLayout (withPre :: Bool) (segments :: [(Symbol, [t])]) (env :: [t]) where - LinEnd :: LinLayout withPre segments '[] - LinApp :: SSymbol name -> LinLayout withPre segments env - -> LinLayout withPre segments (Append (Lookup name segments) env) - LinAppPreW :: SSymbol name1 -> SSymbol name2 - -> Lookup name1 segments :> Lookup name2 segments - -> LinLayout True segments env - -> LinLayout True segments (Append (Lookup name1 segments) env) - -linLayoutAppend :: LinLayout withPre segments env1 -> LinLayout withPre segments env2 -> LinLayout withPre segments (Append env1 env2) -linLayoutAppend LinEnd lin = lin -linLayoutAppend (LinApp (name :: SSymbol name) (lin1 :: LinLayout _ segments env1')) (lin2 :: LinLayout _ _ env2) - | Refl <- lemAppendAssoc @(Lookup name segments) @env1' @env2 - = LinApp name (linLayoutAppend lin1 lin2) -linLayoutAppend (LinAppPreW (name1 :: SSymbol name1) name2 w (lin1 :: LinLayout _ segments env1')) (lin2 :: LinLayout _ _ env2) - | Refl <- lemAppendAssoc @(Lookup name1 segments) @env1' @env2 - = LinAppPreW name1 name2 w (linLayoutAppend lin1 lin2) - -lineariseLayout :: Layout withPre segments names env -> LinLayout withPre segments env -lineariseLayout (LSeg name :: Layout _ _ _ seg) - | Refl <- lemAppendNil @seg - = LinApp name LinEnd -lineariseLayout (ly1 :++: ly2) = lineariseLayout ly1 `linLayoutAppend` lineariseLayout ly2 -lineariseLayout (LPreW (SegmentName name1) (SegmentName name2) w :: Layout _ _ _ seg) - | Refl <- lemAppendNil @seg - = LinAppPreW name1 name2 w LinEnd - -preWeaken :: SSegments segments -> LinLayout True segments env - -> (forall env'. env :> env' -> LinLayout False segments env' -> r) -> r -preWeaken _ LinEnd k = k WId LinEnd -preWeaken segs (LinApp name lin) k = - preWeaken segs lin $ \w lin' -> - k (wCopies (segmentLookup segs name) w) (LinApp name lin') -preWeaken segs (LinAppPreW name1 name2 weak lin) k = - preWeaken segs lin $ \w lin' -> - k (WStack (segmentLookup segs name1) (segmentLookup segs name2) weak w) (LinApp name2 lin') - -pullDown :: SSegments segments -> SSymbol name -> LinLayout False segments env - -> r -- Name was not found in source - -> (forall env'. LinLayout False segments env' -> env :> Append (Lookup name segments) env' -> r) - -> r -pullDown segs name@SSymbol linlayout kNotFound k = - case linlayout of - LinEnd -> kNotFound - LinApp n'@SSymbol lin - | Just Refl <- sameSymbol name n' -> k lin WId - | otherwise -> - pullDown segs name lin kNotFound $ \(lin' :: LinLayout _ _ env') w -> - k (LinApp n' lin') (WSwap @env' (segmentLookup segs n') (segmentLookup segs name) - .> wCopies (segmentLookup segs n') w) - -sortLinLayouts :: SSegments segments - -> LinLayout False segments env1 -> LinLayout False segments env2 -> env1 :> env2 -sortLinLayouts _ LinEnd LinEnd = WId -sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail2) - | Just Refl <- sameSymbol name1 name2 = wCopies (segmentLookup segs name1) (sortLinLayouts segs tail1 tail2) - | otherwise = - pullDown segs name2 lin1 - (wSinks (segmentLookup segs name2) .> sortLinLayouts segs lin1 tail2) - (\tail1' w -> - -- We've pulled down name2 in lin1 so that it's at the head; the - -- resulting modified tail is tail1'. Thus now we have (name2 : tail1') - -- vs (name2 : tail2). Thus we continue sorting tail1' vs tail2, and - -- wCopies the name2 on top of that. - wCopies (segmentLookup segs name2) (sortLinLayouts segs tail1' tail2) .> w) -sortLinLayouts _ LinEnd LinApp{} = WClosed -sortLinLayouts _ LinApp{} LinEnd = error "Segments in source that do not occur in target" - -autoWeak :: SubLayout names1 names2 - => SSegments segments -> Layout True segments names1 env1 -> Layout False segments names2 env2 -> env1 :> env2 -autoWeak segs ly1 ly2 = - preWeaken segs (lineariseLayout ly1) $ \wPreweak lin1 -> - sortLinLayouts segs lin1 (lineariseLayout ly2) .> wPreweak diff --git a/src/Analysis/Identity.hs b/src/Analysis/Identity.hs deleted file mode 100644 index 7b896a3..0000000 --- a/src/Analysis/Identity.hs +++ /dev/null @@ -1,436 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -module Analysis.Identity ( - identityAnalysis, - identityAnalysis', - ValId(..), - validSplitEither, -) where - -import Data.Foldable (toList) -import Data.List (intercalate) - -import AST -import AST.Pretty (PrettyX(..)) -import CHAD.Types (d1, d2) -import Data -import Util.IdGen - - --- | Every array, scalar and accumulator has an ID. Trivial values such as --- Nothing only have the knowledge that they are indeed Nothing. Compound --- values know which values they consist of. -data ValId t where - VINil :: ValId TNil - VIPair :: ValId a -> ValId b -> ValId (TPair a b) - VIEither :: Either (ValId a) (ValId b) -> ValId (TEither a b) -- ^ known alternative - VIEither' :: ValId a -> ValId b -> ValId (TEither a b) -- ^ unknown alternative, but known values in each case - VILEither :: ValId (TMaybe (TEither a b)) -> ValId (TLEither a b) - VIMaybe :: Maybe (ValId a) -> ValId (TMaybe a) - VIMaybe' :: ValId a -> ValId (TMaybe a) -- ^ if it's Just, it contains this value - VIArr :: Int -> Vec n Int -> ValId (TArr n t) - VIScal :: Int -> ValId (TScal t) - VIAccum :: Int -> ValId (TAccum t) -deriving instance Show (ValId t) - -instance PrettyX ValId where - prettyX = \case - VINil -> "" - VIPair a b -> "(" ++ prettyX a ++ "," ++ prettyX b ++ ")" - VIEither (Left a) -> "(L" ++ prettyX a ++ ")" - VIEither (Right a) -> "(R" ++ prettyX a ++ ")" - VIEither' a b -> "(" ++ prettyX a ++ "|" ++ prettyX b ++ ")" - VIMaybe Nothing -> "N" - VIMaybe (Just a) -> 'J' : prettyX a - VIMaybe' a -> 'M' : prettyX a - VILEither (VIMaybe Nothing) -> "lN" - VILEither (VIMaybe (Just (VIEither (Left a)))) -> "(lL" ++ prettyX a ++ ")" - VILEither (VIMaybe (Just (VIEither (Right a)))) -> "(lR" ++ prettyX a ++ ")" - VILEither (VIMaybe (Just (VIEither' a b))) -> "(" ++ prettyX a ++ "⊕" ++ prettyX b ++ ")" - VILEither (VIMaybe' (VIEither (Left a))) -> "(mlL" ++ prettyX a ++ ")" - VILEither (VIMaybe' (VIEither (Right a))) -> "(mlR" ++ prettyX a ++ ")" - VILEither (VIMaybe' (VIEither' a b)) -> "(m(" ++ prettyX a ++ "⊕" ++ prettyX b ++ "))" - VIArr i is -> 'A' : show i ++ "[" ++ intercalate "," (map show (toList is)) ++ "]" - VIScal i -> show i - VIAccum i -> 'C' : show i - -validSplitEither :: ValId (TEither a b) -> (Maybe (ValId a), Maybe (ValId b)) -validSplitEither (VIEither (Left v)) = (Just v, Nothing) -validSplitEither (VIEither (Right v)) = (Nothing, Just v) -validSplitEither (VIEither' v1 v2) = (Just v1, Just v2) - --- | Symbolic partial evaluation. -identityAnalysis :: SList STy env -> Expr x env t -> Expr ValId env t -identityAnalysis env term = runIdGen 0 $ do - env' <- slistMapA genIds env - snd <$> idana env' term - -identityAnalysis' :: SList ValId env -> Expr x env t -> Expr ValId env t -identityAnalysis' env term = snd (runIdGen 0 (idana env term)) - -idana :: SList ValId env -> Expr x env t -> IdGen (ValId t, Expr ValId env t) -idana env expr = case expr of - EVar _ t i -> do - let v = slistIdx env i - pure (v, EVar v t i) - - ELet _ e1 e2 -> do - (v1, e1') <- idana env e1 - (v2, e2') <- idana (v1 `SCons` env) e2 - pure (v2, ELet v2 e1' e2') - - EPair _ e1 e2 -> do - (v1, e1') <- idana env e1 - (v2, e2') <- idana env e2 - pure (VIPair v1 v2, EPair (VIPair v1 v2) e1' e2') - - EFst _ e -> do - (v, e') <- idana env e - let VIPair v1 _ = v - pure (v1, EFst v1 e') - - ESnd _ e -> do - (v, e') <- idana env e - let VIPair _ v2 = v - pure (v2, ESnd v2 e') - - ENil _ -> pure (VINil, ENil VINil) - - EInl _ t2 e1 -> do - (v1, e1') <- idana env e1 - let v = VIEither (Left v1) - pure (v, EInl v t2 e1') - - EInr _ t1 e2 -> do - (v2, e2') <- idana env e2 - let v = VIEither (Right v2) - pure (v, EInr v t1 e2') - - ECase _ e1 e2 e3 -> do - let STEither t1 t2 = typeOf e1 - (v1, e1') <- idana env e1 - case v1 of - VIEither (Left v1') -> do - (v2, e2') <- idana (v1' `SCons` env) e2 - scrap <- genIds t2 - (_, e3') <- idana (scrap `SCons` env) e3 - pure (v2, ECase v2 e1' e2' e3') - VIEither (Right v1') -> do - scrap <- genIds t1 - (_, e2') <- idana (scrap `SCons` env) e2 - (v3, e3') <- idana (v1' `SCons` env) e3 - pure (v3, ECase v3 e1' e2' e3') - VIEither' v1'l v1'r -> do - (v2, e2') <- idana (v1'l `SCons` env) e2 - (v3, e3') <- idana (v1'r `SCons` env) e3 - res <- unify v2 v3 - pure (res, ECase res e1' e2' e3') - - ENothing _ t -> pure (VIMaybe Nothing, ENothing (VIMaybe Nothing) t) - - EJust _ e1 -> do - (v1, e1') <- idana env e1 - let v = VIMaybe (Just v1) - pure (v, EJust v e1') - - EMaybe _ e1 e2 e3 -> do - let STMaybe t1 = typeOf e3 - (v3, e3') <- idana env e3 - case v3 of - VIMaybe Nothing -> do - (v1, e1') <- idana env e1 - scrap <- genIds t1 - (_, e2') <- idana (scrap `SCons` env) e2 - pure (v1, EMaybe v1 e1' e2' e3') - VIMaybe (Just v3j) -> do - (v2, e2') <- idana (v3j `SCons` env) e2 - (_, e1') <- idana env e1 - pure (v2, EMaybe v2 e1' e2' e3') - VIMaybe' v3' -> do - (v2, e2') <- idana (v3' `SCons` env) e2 - (v1, e1') <- idana env e1 - res <- unify v1 v2 - pure (res, EMaybe res e1' e2' e3') - - ELNil _ t1 t2 -> do - let v = VILEither (VIMaybe Nothing) - pure (v, ELNil v t1 t2) - - ELInl _ t2 e1 -> do - (v1, e1') <- idana env e1 - let v = VILEither (VIMaybe (Just (VIEither (Left v1)))) - pure (v, ELInl v t2 e1') - - ELInr _ t1 e2 -> do - (v2, e2') <- idana env e2 - let v = VILEither (VIMaybe (Just (VIEither (Right v2)))) - pure (v, ELInr v t1 e2') - - ELCase _ e1 e2 e3 e4 -> do - let STLEither t1 t2 = typeOf e1 - (v1L, e1') <- idana env e1 - let VILEither v1 = v1L - let go mv1'l mv1'r f = do - v1'l <- maybe (genIds t1) pure mv1'l - v1'r <- maybe (genIds t2) pure mv1'r - (v2, e2') <- idana env e2 - (v3, e3') <- idana (v1'l `SCons` env) e3 - (v4, e4') <- idana (v1'r `SCons` env) e4 - res <- f v2 v3 v4 - pure (res, ELCase res e1' e2' e3' e4') - case v1 of - VIMaybe Nothing -> go Nothing Nothing (\v2 _ _ -> pure v2) - VIMaybe (Just (VIEither (Left v1'))) -> go (Just v1') Nothing (\_ v3 _ -> pure v3) - VIMaybe (Just (VIEither (Right v1'))) -> go Nothing (Just v1') (\_ _ v4 -> pure v4) - VIMaybe (Just (VIEither' v1'l v1'r)) -> go (Just v1'l) (Just v1'r) (\_ v3 v4 -> unify v3 v4) - VIMaybe' (VIEither (Left v1')) -> go (Just v1') Nothing (\v2 v3 _ -> unify v2 v3) - VIMaybe' (VIEither (Right v1')) -> go Nothing (Just v1') (\v2 _ v4 -> unify v2 v4) - VIMaybe' (VIEither' v1'l v1'r) -> - go (Just v1'l) (Just v1'r) (\v2 v3 v4 -> unify v2 =<< unify v3 v4) - - EConstArr _ dim t arr -> do - x1 <- VIArr <$> genId <*> vecReplicateA dim genId - pure (x1, EConstArr x1 dim t arr) - - EBuild _ dim e1 e2 -> do - (shids, e1') <- idana env e1 - x1 <- genIds (tTup (sreplicate dim tIx)) - (_, e2') <- idana (x1 `SCons` env) e2 - res <- VIArr <$> genId <*> shidsToVec dim shids - pure (res, EBuild res dim e1' e2') - - EMap _ e1 e2 -> do - let STArr _ t = typeOf e2 - x1 <- genIds t - (_, e1') <- idana (x1 `SCons` env) e1 - (v2, e2') <- idana env e2 - let VIArr _ sh = v2 - res <- VIArr <$> genId <*> pure sh - pure (res, EMap res e1' e2') - - EFold1Inner _ cm e1 e2 e3 -> do - let t1 = typeOf e1 - x1 <- genIds (STPair t1 t1) - (_, e1') <- idana (x1 `SCons` env) e1 - (_, e2') <- idana env e2 - (v3, e3') <- idana env e3 - let VIArr _ (_ :< sh) = v3 - res <- VIArr <$> genId <*> pure sh - pure (res, EFold1Inner res cm e1' e2' e3') - - ESum1Inner _ e1 -> do - (v1, e1') <- idana env e1 - let VIArr _ (_ :< sh) = v1 - res <- VIArr <$> genId <*> pure sh - pure (res, ESum1Inner res e1') - - EUnit _ e1 -> do - (_, e1') <- idana env e1 - res <- VIArr <$> genId <*> pure VNil - pure (res, EUnit res e1') - - EReplicate1Inner _ e1 e2 -> do - (v1, e1') <- idana env e1 - let VIScal v1' = v1 - (v2, e2') <- idana env e2 - let VIArr _ sh = v2 - res <- VIArr <$> genId <*> pure (v1' :< sh) - pure (res, EReplicate1Inner res e1' e2') - - EMaximum1Inner _ e1 -> do - (v1, e1') <- idana env e1 - let VIArr _ (_ :< sh) = v1 - res <- VIArr <$> genId <*> pure sh - pure (res, EMaximum1Inner res e1') - - EMinimum1Inner _ e1 -> do - (v1, e1') <- idana env e1 - let VIArr _ (_ :< sh) = v1 - res <- VIArr <$> genId <*> pure sh - pure (res, EMinimum1Inner res e1') - - EReshape _ dim e1 e2 -> do - (v1, e1') <- idana env e1 - (_, e2') <- idana env e2 - res <- VIArr <$> genId <*> shidsToVec dim v1 - pure (res, EReshape res dim e1' e2') - - EZip _ e1 e2 -> do - (v1, e1') <- idana env e1 - (_, e2') <- idana env e2 - let VIArr _ sh = v1 - res <- VIArr <$> genId <*> pure sh - pure (res, EZip res e1' e2') - - EFold1InnerD1 _ cm e1 e2 e3 -> do - let t1 = typeOf e2 - x1 <- genIds (STPair t1 t1) - (_, e1') <- idana (x1 `SCons` env) e1 - (_, e2') <- idana env e2 - (v3, e3') <- idana env e3 - let VIArr _ sh'@(_ :< sh) = v3 - res <- VIPair <$> (VIArr <$> genId <*> pure sh) <*> (VIArr <$> genId <*> pure sh') - pure (res, EFold1InnerD1 res cm e1' e2' e3') - - EFold1InnerD2 _ cm ef ebog ed -> do - let STArr _ tB = typeOf ebog - STArr _ t2 = typeOf ed - xf1 <- genIds t2 - xf2 <- genIds tB - (_, e1') <- idana (xf1 `SCons` xf2 `SCons` env) ef - (v2, e2') <- idana env ebog - (_, e3') <- idana env ed - let VIArr _ sh@(_ :< sh') = v2 - res <- VIPair <$> (VIArr <$> genId <*> pure sh') <*> (VIArr <$> genId <*> pure sh) - pure (res, EFold1InnerD2 res cm e1' e2' e3') - - EConst _ t val -> do - res <- VIScal <$> genId - pure (res, EConst res t val) - - EIdx0 _ e1 -> do - (_, e1') <- idana env e1 - res <- genIds (typeOf expr) - pure (res, EIdx0 res e1') - - EIdx1 _ e1 e2 -> do - (v1, e1') <- idana env e1 - let VIArr _ sh = v1 - (_, e2') <- idana env e2 - res <- VIArr <$> genId <*> pure (vecInit sh) - pure (res, EIdx1 res e1' e2') - - EIdx _ e1 e2 -> do - (_, e1') <- idana env e1 - (_, e2') <- idana env e2 - res <- genIds (typeOf expr) - pure (res, EIdx res e1' e2') - - EShape _ e1 -> do - let STArr dim _ = typeOf e1 - (v1, e1') <- idana env e1 - let VIArr _ sh = v1 - res = vecToShids dim sh - pure (res, EShape res e1') - - EOp _ (op :: SOp a t) e1 -> do - (_, e1') <- idana env e1 - res <- genIds (typeOf expr) - pure (res, EOp res op e1') - - ECustom _ t1 t2 t3 e1 e2 e3 e4 e5 -> do - let t4 = typeOf e1 - x1 <- genIds t2 - x2 <- genIds t1 - (_, e1') <- idana (x1 `SCons` x2 `SCons` SNil) e1 - x3 <- genIds (d1 t2) - x4 <- genIds (d1 t1) - (_, e2') <- idana (x3 `SCons` x4 `SCons` SNil) e2 - x5 <- genIds (d2 t4) - x6 <- genIds t3 - (_, e3') <- idana (x5 `SCons` x6 `SCons` SNil) e3 - (_, e4') <- idana env e4 - (_, e5') <- idana env e5 - res <- genIds t4 - pure (res, ECustom res t1 t2 t3 e1' e2' e3' e4' e5') - - ERecompute _ e -> do - (v, e') <- idana env e - pure (v, ERecompute v e') - - EWith _ t e1 e2 -> do - let t1 = typeOf e1 - (_, e1') <- idana env e1 - x1 <- VIAccum <$> genId - (v2, e2') <- idana (x1 `SCons` env) e2 - x2 <- genIds t1 - let res = VIPair v2 x2 - pure (res, EWith res t e1' e2') - - EAccum _ t prj e1 sp e2 e3 -> do - (_, e1') <- idana env e1 - (_, e2') <- idana env e2 - (_, e3') <- idana env e3 - pure (VINil, EAccum VINil t prj e1' sp e2' e3') - - EZero _ t e1 -> do - -- Approximate the result of EZero to be independent from the zero info - -- expression; not quite true for shape variables - (_, e1') <- idana env e1 - res <- genIds (fromSMTy t) - pure (res, EZero res t e1') - - EDeepZero _ t e1 -> do - -- Approximate the result of EDeepZero to be independent from the zero info - -- expression; not quite true for shape variables - (_, e1') <- idana env e1 - res <- genIds (fromSMTy t) - pure (res, EDeepZero res t e1') - - EPlus _ t e1 e2 -> do - (_, e1') <- idana env e1 - (_, e2') <- idana env e2 - res <- genIds (fromSMTy t) - pure (res, EPlus res t e1' e2') - - EOneHot _ t i e1 e2 -> do - (_, e1') <- idana env e1 - (_, e2') <- idana env e2 - res <- genIds (fromSMTy t) - pure (res, EOneHot res t i e1' e2') - - EError _ t s -> do - res <- genIds t - pure (res, EError res t s) - --- | This value might be either of the two arguments; we don't know which. -unify :: ValId t -> ValId t -> IdGen (ValId t) -unify VINil VINil = pure VINil -unify (VIPair a b) (VIPair c d) = VIPair <$> unify a c <*> unify b d -unify (VIEither (Left a)) (VIEither (Left b)) = VIEither . Left <$> unify a b -unify (VIEither (Right a)) (VIEither (Right b)) = VIEither . Right <$> unify a b -unify (VIEither (Left a)) (VIEither (Right b)) = pure $ VIEither' a b -unify (VIEither (Right a)) (VIEither (Left b)) = pure $ VIEither' b a -unify (VIEither (Left a)) (VIEither' b c) = VIEither' <$> unify a b <*> pure c -unify (VIEither (Right a)) (VIEither' b c) = VIEither' <$> pure b <*> unify a c -unify (VIEither' a b) (VIEither (Left c)) = VIEither' <$> unify a c <*> pure b -unify (VIEither' a b) (VIEither (Right c)) = VIEither' <$> pure a <*> unify b c -unify (VIEither' a b) (VIEither' c d) = VIEither' <$> unify a c <*> unify b d -unify (VIMaybe Nothing) (VIMaybe Nothing) = pure $ VIMaybe Nothing -unify (VIMaybe (Just a)) (VIMaybe (Just b)) = VIMaybe . Just <$> unify a b -unify (VIMaybe Nothing) (VIMaybe (Just a)) = pure $ VIMaybe' a -unify (VIMaybe (Just a)) (VIMaybe Nothing) = pure $ VIMaybe' a -unify (VIMaybe Nothing) (VIMaybe' a) = pure $ VIMaybe' a -unify (VIMaybe (Just a)) (VIMaybe' b) = VIMaybe' <$> unify a b -unify (VIMaybe' a) (VIMaybe Nothing) = pure $ VIMaybe' a -unify (VIMaybe' a) (VIMaybe (Just b)) = VIMaybe' <$> unify a b -unify (VIMaybe' a) (VIMaybe' b) = VIMaybe' <$> unify a b -unify (VILEither a) (VILEither b) = VILEither <$> unify a b -unify (VIArr i is) (VIArr j js) = VIArr <$> unifyID i j <*> vecZipWithA unifyID is js -unify (VIScal i) (VIScal j) = VIScal <$> unifyID i j -unify (VIAccum i) (VIAccum j) = VIAccum <$> unifyID i j - -unifyID :: Int -> Int -> IdGen Int -unifyID i j | i == j = pure i - | otherwise = genId - -genIds :: STy t -> IdGen (ValId t) -genIds STNil = pure VINil -genIds (STPair a b) = VIPair <$> genIds a <*> genIds b -genIds (STEither a b) = VIEither' <$> genIds a <*> genIds b -genIds (STLEither a b) = VILEither . VIMaybe' <$> (VIEither' <$> genIds a <*> genIds b) -genIds (STMaybe t) = VIMaybe' <$> genIds t -genIds (STArr n _) = VIArr <$> genId <*> vecReplicateA n genId -genIds STScal{} = VIScal <$> genId -genIds STAccum{} = VIAccum <$> genId - -shidsToVec :: SNat n -> ValId (Tup (Replicate n TIx)) -> IdGen (Vec n Int) -shidsToVec SZ _ = pure VNil -shidsToVec (SS n) (VIPair is (VIScal i)) = (i :<) <$> shidsToVec n is - -vecToShids :: SNat n -> Vec n Int -> ValId (Tup (Replicate n TIx)) -vecToShids SZ VNil = VINil -vecToShids (SS n) (i :< is) = VIPair (vecToShids n is) (VIScal i) diff --git a/src/Array.hs b/src/Array.hs deleted file mode 100644 index 6ceb9fe..0000000 --- a/src/Array.hs +++ /dev/null @@ -1,131 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE DeriveTraversable #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TupleSections #-} -module Array where - -import Control.DeepSeq -import Control.Monad.Trans.State.Strict -import Data.Foldable (traverse_) -import Data.Vector (Vector) -import qualified Data.Vector as V -import GHC.Generics (Generic) - -import Data - - -data Shape n where - ShNil :: Shape Z - ShCons :: Shape n -> Int -> Shape (S n) -deriving instance Show (Shape n) -deriving instance Eq (Shape n) - -instance NFData (Shape n) where - rnf ShNil = () - rnf (sh `ShCons` n) = rnf n `seq` rnf sh - -data Index n where - IxNil :: Index Z - IxCons :: Index n -> Int -> Index (S n) -deriving instance Show (Index n) -deriving instance Eq (Index n) - -instance NFData (Index n) where - rnf IxNil = () - rnf (sh `IxCons` n) = rnf n `seq` rnf sh - -shapeSize :: Shape n -> Int -shapeSize ShNil = 1 -shapeSize (ShCons sh n) = shapeSize sh * n - -shapeRank :: Shape n -> SNat n -shapeRank ShNil = SZ -shapeRank (sh `ShCons` _) = SS (shapeRank sh) - -fromLinearIndex :: Shape n -> Int -> Index n -fromLinearIndex ShNil 0 = IxNil -fromLinearIndex ShNil _ = error "Index out of range" -fromLinearIndex (sh `ShCons` n) i = - let (q, r) = i `quotRem` n - in fromLinearIndex sh q `IxCons` r - -toLinearIndex :: Shape n -> Index n -> Int -toLinearIndex ShNil IxNil = 0 -toLinearIndex (sh `ShCons` n) (idx `IxCons` i) = toLinearIndex sh idx * n + i - -emptyShape :: SNat n -> Shape n -emptyShape SZ = ShNil -emptyShape (SS m) = emptyShape m `ShCons` 0 - -enumShape :: Shape n -> [Index n] -enumShape sh = map (fromLinearIndex sh) [0 .. shapeSize sh - 1] - -shapeToList :: Shape n -> [Int] -shapeToList = go [] - where - go :: [Int] -> Shape n -> [Int] - go suff ShNil = suff - go suff (sh `ShCons` n) = go (n:suff) sh - - --- | TODO: this Vector is a boxed vector, which is horrendously inefficient. -data Array (n :: Nat) t = Array (Shape n) (Vector t) - deriving (Show, Functor, Foldable, Traversable, Generic) -instance NFData t => NFData (Array n t) - -arrayShape :: Array n t -> Shape n -arrayShape (Array sh _) = sh - -arraySize :: Array n t -> Int -arraySize (Array sh _) = shapeSize sh - -emptyArray :: SNat n -> Array n t -emptyArray n = Array (emptyShape n) V.empty - -arrayFromList :: Shape n -> [t] -> Array n t -arrayFromList sh l = Array sh (V.fromListN (shapeSize sh) l) - -arrayToList :: Array n t -> [t] -arrayToList (Array _ v) = V.toList v - -arrayReshape :: Shape n -> Array m t -> Array n t -arrayReshape sh (Array sh' v) - | shapeSize sh == shapeSize sh' = Array sh v - | otherwise = error $ "arrayReshape: different shape size than original (" ++ show sh' ++ " -> " ++ show sh ++ ")" - -arrayUnit :: t -> Array Z t -arrayUnit x = Array ShNil (V.singleton x) - -arrayIndex :: Array n t -> Index n -> t -arrayIndex arr@(Array sh _) idx = arrayIndexLinear arr (toLinearIndex sh idx) - -arrayIndexLinear :: Array n t -> Int -> t -arrayIndexLinear (Array _ v) i = v V.! i - -arrayIndex1 :: Array (S n) t -> Int -> Array n t -arrayIndex1 (Array (sh `ShCons` _) v) i = let sz = shapeSize sh in Array sh (V.slice (sz * i) sz v) - -arrayGenerate :: Shape n -> (Index n -> t) -> Array n t -arrayGenerate sh f = arrayGenerateLin sh (f . fromLinearIndex sh) - -arrayGenerateLin :: Shape n -> (Int -> t) -> Array n t -arrayGenerateLin sh f = Array sh (V.generate (shapeSize sh) f) - -arrayGenerateM :: Monad m => Shape n -> (Index n -> m t) -> m (Array n t) -arrayGenerateM sh f = arrayGenerateLinM sh (f . fromLinearIndex sh) - -arrayGenerateLinM :: Monad m => Shape n -> (Int -> m t) -> m (Array n t) -arrayGenerateLinM sh f = Array sh <$> V.generateM (shapeSize sh) f - -arrayMap :: (a -> b) -> Array n a -> Array n b -arrayMap f arr = arrayGenerateLin (arrayShape arr) (f . arrayIndexLinear arr) - -arrayMapM :: Monad m => (a -> m b) -> Array n a -> m (Array n b) -arrayMapM f arr = arrayGenerateLinM (arrayShape arr) (f . arrayIndexLinear arr) - --- | The Int is the linear index of the value. -traverseArray_ :: Monad m => (Int -> t -> m ()) -> Array n t -> m () -traverseArray_ f (Array _ v) = evalStateT (traverse_ (\x -> StateT (\i -> (,i+1) <$> f i x)) v) 0 diff --git a/src/CHAD.hs b/src/CHAD.hs deleted file mode 100644 index 298d964..0000000 --- a/src/CHAD.hs +++ /dev/null @@ -1,1583 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE ImplicitParams #-} -{-# LANGUAGE ImpredicativeTypes #-} -{-# LANGUAGE OverloadedLabels #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeData #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} - --- I want to bring various type variables in scope using type annotations in --- patterns, but I don't want to have to mention all the other type parameters --- of the types in question as well then. Partial type signatures (with '_') are --- useful here. -{-# LANGUAGE PartialTypeSignatures #-} -{-# OPTIONS -Wno-partial-type-signatures #-} -module CHAD ( - drev, - freezeRet, - CHADConfig(..), - defaultConfig, - Storage(..), - Descr(..), - Select, -) where - -import Data.Functor.Const -import Data.Some -import Data.Type.Equality (type (==), testEquality) - -import Analysis.Identity (ValId(..), validSplitEither) -import AST -import AST.Bindings -import AST.Count -import AST.Env -import AST.Sparse -import AST.Weaken.Auto -import CHAD.Accum -import CHAD.EnvDescr -import CHAD.Types -import Data -import qualified Data.VarMap as VarMap -import Data.VarMap (VarMap) -import Lemmas - - ------------------------------- TAPES AND BINDINGS ------------------------------ - -type family Tape binds where - Tape '[] = TNil - Tape (t : ts) = TPair t (Tape ts) - -tapeTy :: SList STy binds -> STy (Tape binds) -tapeTy SNil = STNil -tapeTy (SCons t ts) = STPair t (tapeTy ts) - -bindingsCollectTape :: SList STy binds -> Subenv binds tapebinds - -> binds :> env2 -> Ex env2 (Tape tapebinds) -bindingsCollectTape SNil SETop _ = ENil ext -bindingsCollectTape (t `SCons` binds) (SEYesR sub) w = - EPair ext (EVar ext t (w @> IZ)) - (bindingsCollectTape binds sub (w .> WSink)) -bindingsCollectTape (_ `SCons` binds) (SENo sub) w = - bindingsCollectTape binds sub (w .> WSink) - --- bindingsCollectTape' :: forall f env binds tapebinds env2. Bindings f env binds -> Subenv binds tapebinds --- -> Append binds env :> env2 -> Ex env2 (Tape tapebinds) --- bindingsCollectTape' binds sub w --- | Refl <- lemAppendNil @binds --- = bindingsCollectTape (bindingsBinds binds) sub (w .> wCopies @_ @_ @'[] (bindingsBinds binds) (WClosed @env)) - --- In order from large to small: i.e. in reverse order from what we want, --- because in a Bindings, the head of the list is the bottom-most entry. -type family TapeUnfoldings binds where - TapeUnfoldings '[] = '[] - TapeUnfoldings (t : ts) = Tape ts : TapeUnfoldings ts - -type family Reverse l where - Reverse '[] = '[] - Reverse (t : ts) = Append (Reverse ts) '[t] - --- An expression that is always 'snd' -data UnfExpr env t where - UnfExSnd :: STy s -> STy t -> UnfExpr (TPair s t : env) t - -fromUnfExpr :: UnfExpr env t -> Ex env t -fromUnfExpr (UnfExSnd s t) = ESnd ext (EVar ext (STPair s t) IZ) - --- - A bunch of 'snd' expressions taking us from knowing that there's a --- 'Tape ts' in the environment (for simplicity assume it's at IZ, we'll fix --- this in reconstructBindings), to having 'Reverse (TapeUnfoldings ts)' in --- the environment. --- - In the extended environment, another bunch of let bindings (these are --- 'fst' expressions, but no need to know that statically) that project the --- fsts out of what we introduced above, one for each type in 'ts'. -data Reconstructor env ts = - Reconstructor - (Bindings UnfExpr (Tape ts : env) (Reverse (TapeUnfoldings ts))) - (Bindings Ex (Append (Reverse (TapeUnfoldings ts)) (Tape ts : env)) ts) - -ssnoc :: SList f ts -> f t -> SList f (Append ts '[t]) -ssnoc SNil a = SCons a SNil -ssnoc (SCons t ts) a = SCons t (ssnoc ts a) - -sreverse :: SList f ts -> SList f (Reverse ts) -sreverse SNil = SNil -sreverse (SCons t ts) = ssnoc (sreverse ts) t - -stapeUnfoldings :: SList STy ts -> SList STy (TapeUnfoldings ts) -stapeUnfoldings SNil = SNil -stapeUnfoldings (SCons _ ts) = SCons (tapeTy ts) (stapeUnfoldings ts) - --- Puts a 'snd' at the top of an unfolder stack and grows the context variable by one. -shiftUnfolder - :: STy t - -> SList STy ts - -> Bindings UnfExpr (Tape ts : env) list - -> Bindings UnfExpr (Tape (t : ts) : env) (Append list '[Tape ts]) -shiftUnfolder newTy ts BTop = BPush BTop (tapeTy ts, UnfExSnd newTy (tapeTy ts)) -shiftUnfolder newTy ts (BPush b (t, UnfExSnd itemTy _)) = - -- Recurse on 'b', and retype the 'snd'. We need to unfold 'b' once in order - -- to expand an 'Append' in the types so that things simplify just enough. - -- We have an equality 'Append binds x1 ~ a : x2', where 'binds' is the list - -- of bindings produced by 'b'. We want to conclude from this that - -- 'binds ~ a : x3' for some 'x3', but GHC will only do that once we know - -- that 'binds ~ y : ys' so that the 'Append' can expand one step, after - -- which 'y ~ a' as desired. The 'case' unfolds 'b' one step. - BPush (shiftUnfolder newTy ts b) (t, case b of BTop -> UnfExSnd itemTy t - BPush{} -> UnfExSnd itemTy t) - -growRecon :: forall env t ts. STy t -> SList STy ts -> Reconstructor env ts -> Reconstructor env (t : ts) -growRecon t ts (Reconstructor unfbs bs) - | Refl <- lemAppendNil @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts]) - , Refl <- lemAppendAssoc @ts @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts]) @(Tape (t : ts) : env) - , Refl <- lemAppendAssoc @(Reverse (TapeUnfoldings ts)) @'[Tape ts] @env - = Reconstructor - (shiftUnfolder t ts unfbs) - -- Add a 'fst' at the bottom of the builder stack. - -- First we have to weaken most of 'bs' to skip one more binding in the - -- unfolder stack above it. - (BPush (fst (weakenBindingsE - (wCopies (sappend (sreverse (stapeUnfoldings ts)) (SCons (tapeTy ts) SNil)) - (WSink :: env :> (Tape (t : ts) : env))) bs)) - (t - ,EFst ext $ EVar ext (tapeTy (SCons t ts)) $ - wSinks @(Tape (t : ts) : env) - (sappend ts - (sappend (sappend (sreverse (stapeUnfoldings ts)) - (SCons (tapeTy ts) SNil)) - SNil)) - @> IZ)) - -buildReconstructor :: SList STy ts -> Reconstructor env ts -buildReconstructor SNil = Reconstructor BTop BTop -buildReconstructor (SCons t ts) = growRecon t ts (buildReconstructor ts) - --- STRATEGY FOR reconstructBindings --- --- binds = [] --- e : () --- --- binds = [c] --- e : (c, ()) --- x0 = snd x1 : () --- y1 = fst e : c --- --- binds = [b, c] --- e : (b, (c, ())) --- x1 = snd e : (c, ()) --- x0 = snd x1 : () --- y1 = fst x1 : c --- y2 = fst x2 : b --- --- binds = [a, b, c] --- e : (a, (b, (c, ()))) --- x2 = snd e : (b, (c, ())) --- x1 = snd x2 : (c, ()) --- x0 = snd x1 : () --- y1 = fst x1 : c --- y2 = fst x2 : b --- y3 = fst x3 : a - --- Given that in 'env' we can find a 'Tape binds', i.e. a tuple containing all --- the things in the list 'binds', we want to create a let stack that extracts --- all values from that tuple and in effect "restores" the environment --- described by 'binds'. The idea is that elsewhere, we took a slice of the --- environment and saved it all in a tuple to be restored later. We --- incidentally also add a bunch of additional bindings, namely 'Reverse --- (TapeUnfoldings binds)', so the calling code just has to skip those in --- whatever it wants to do. -reconstructBindings :: SList STy binds - -> (forall env. Idx env (Tape binds) -> Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds))) - ,SList STy (Reverse (TapeUnfoldings binds))) -reconstructBindings binds = - (\tape -> let Reconstructor unf build = buildReconstructor binds - in fst $ weakenBindingsE (WIdx tape) - (bconcat (mapBindings fromUnfExpr unf) build) - ,sreverse (stapeUnfoldings binds)) - - ----------------------------------- DERIVATIVES --------------------------------- - -d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t) -d1op (OAdd t) e = EOp ext (OAdd t) e -d1op (OMul t) e = EOp ext (OMul t) e -d1op (ONeg t) e = EOp ext (ONeg t) e -d1op (OLt t) e = EOp ext (OLt t) e -d1op (OLe t) e = EOp ext (OLe t) e -d1op (OEq t) e = EOp ext (OEq t) e -d1op ONot e = EOp ext ONot e -d1op OAnd e = EOp ext OAnd e -d1op OOr e = EOp ext OOr e -d1op OIf e = EOp ext OIf e -d1op ORound64 e = EOp ext ORound64 e -d1op OToFl64 e = EOp ext OToFl64 e -d1op (ORecip t) e = EOp ext (ORecip t) e -d1op (OExp t) e = EOp ext (OExp t) e -d1op (OLog t) e = EOp ext (OLog t) e -d1op (OIDiv t) e = EOp ext (OIDiv t) e -d1op (OMod t) e = EOp ext (OMod t) e - --- | Both primal and dual must be duplicable expressions -data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) - | Nonlinear (forall env. Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a)) - -d2op :: SOp a t -> D2Op a t -d2op op = case op of - OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EPair ext d d - OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d -> - EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) - (EOp ext (OMul t) (EPair ext (EFst ext e) d)) - ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d - OLt t -> Linear $ \_ -> pairZero t - OLe t -> Linear $ \_ -> pairZero t - OEq t -> Linear $ \_ -> pairZero t - ONot -> Linear $ \_ -> ENil ext - OAnd -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) - OOr -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) - OIf -> Linear $ \_ -> ENil ext - ORound64 -> Linear $ \_ -> EZero ext (SMTScal STF64) (ENil ext) - OToFl64 -> Linear $ \_ -> ENil ext - ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d) - OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d) - OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d) - OIDiv t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext) - OMod t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext) - where - pairZero :: SScalTy a -> Ex env (D2 (TPair (TScal a) (TScal a))) - pairZero t = ziNil t $ EPair ext (EZero ext (d2M (STScal t)) (ENil ext)) - (EZero ext (d2M (STScal t)) (ENil ext)) - where - ziNil :: SScalTy a -> (ZeroInfo (D2s a) ~ TNil => r) -> r - ziNil STI32 k = k - ziNil STI64 k = k - ziNil STF32 k = k - ziNil STF64 k = k - ziNil STBool k = k - - d2opUnArrangeInt :: SScalTy a - -> (D2s a ~ TScal a => D2Op (TScal a) t) - -> D2Op (TScal a) t - d2opUnArrangeInt ty float = case ty of - STI32 -> Linear $ \_ -> ENil ext - STI64 -> Linear $ \_ -> ENil ext - STF32 -> float - STF64 -> float - STBool -> Linear $ \_ -> ENil ext - - d2opBinArrangeInt :: SScalTy a - -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t) - -> D2Op (TPair (TScal a) (TScal a)) t - d2opBinArrangeInt ty float = case ty of - STI32 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) - STI64 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) - STF32 -> float - STF64 -> float - STBool -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) - - floatingD2 :: ScalIsFloating a ~ True - => SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r - floatingD2 STF32 k = k - floatingD2 STF64 k = k - - integralD2 :: ScalIsIntegral a ~ True - => SScalTy a -> ((D2s a ~ TNil, ScalIsNumeric a ~ True) => r) -> r - integralD2 STI32 k = k - integralD2 STI64 k = k - -desD1E :: Descr env sto -> SList STy (D1E env) -desD1E = d1e . descrList - --- d1W :: env :> env' -> D1E env :> D1E env' --- d1W WId = WId --- d1W WSink = WSink --- d1W (WCopy w) = WCopy (d1W w) --- d1W (WPop w) = WPop (d1W w) --- d1W (WThen u w) = WThen (d1W u) (d1W w) - -conv1Idx :: Idx env t -> Idx (D1E env) (D1 t) -conv1Idx IZ = IZ -conv1Idx (IS i) = IS (conv1Idx i) - -data Idx2 env sto t - = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t))) - | Idx2Me (Idx (D2E (Select env sto "merge")) (D2 t)) - | Idx2Di (Idx (Select env sto "discr") t) - -conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t -conv2Idx (DPush _ (_, _, SAccum)) IZ = Idx2Ac IZ -conv2Idx (DPush _ (_, _, SMerge)) IZ = Idx2Me IZ -conv2Idx (DPush _ (_, _, SDiscr)) IZ = Idx2Di IZ -conv2Idx (DPush des (_, _, SAccum)) (IS i) = - case conv2Idx des i of Idx2Ac j -> Idx2Ac (IS j) - Idx2Me j -> Idx2Me j - Idx2Di j -> Idx2Di j -conv2Idx (DPush des (_, _, SMerge)) (IS i) = - case conv2Idx des i of Idx2Ac j -> Idx2Ac j - Idx2Me j -> Idx2Me (IS j) - Idx2Di j -> Idx2Di j -conv2Idx (DPush des (_, _, SDiscr)) (IS i) = - case conv2Idx des i of Idx2Ac j -> Idx2Ac j - Idx2Me j -> Idx2Me j - Idx2Di j -> Idx2Di (IS j) -conv2Idx DTop i = case i of {} - -opt2UnSparse :: SOp a b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b) -opt2UnSparse = go . opt2 - where - go :: STy b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b) - go (STScal STI32) SpAbsent = \_ -> ENil ext - go (STScal STI64) SpAbsent = \_ -> ENil ext - go (STScal STF32) SpAbsent = \_ -> EZero ext (SMTScal STF32) (ENil ext) - go (STScal STF64) SpAbsent = \_ -> EZero ext (SMTScal STF64) (ENil ext) - go (STScal STBool) SpAbsent = \_ -> ENil ext - go (STScal STF32) SpScal = id - go (STScal STF64) SpScal = id - go STNil _ = \_ -> ENil ext - go (STPair t1 t2) (SpPair s1 s2) = \e -> eunPair e $ \_ e1 e2 -> EPair ext (go t1 s1 e1) (go t2 s2 e2) - go t _ = error $ "Primitive operations that return " ++ show t ++ " are scary" - - ------------------------------------ SPARSITY ----------------------------------- - -expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a) -expandSparse t sp _ e | Just Refl <- isDense (d2M t) sp = e -expandSparse t (SpSparse sp) epr e = - EMaybe ext - (EZero ext (d2M t) (d2zeroInfo t epr)) - (expandSparse t sp (weakenExpr WSink epr) (EVar ext (applySparse sp (d2 t)) IZ)) - e -expandSparse t SpAbsent epr _ = EZero ext (d2M t) (d2zeroInfo t epr) -expandSparse (STPair t1 t2) (SpPair s1 s2) epr e = - eunPair epr $ \w1 epr1 epr2 -> - eunPair (weakenExpr w1 e) $ \w2 e1 e2 -> - EPair ext (expandSparse t1 s1 (weakenExpr w2 epr1) e1) - (expandSparse t2 s2 (weakenExpr w2 epr2) e2) -expandSparse (STEither t1 t2) (SpLEither s1 s2) epr e = - ELCase ext e - (EZero ext (d2M (STEither t1 t2)) (ENil ext)) - (ECase ext (weakenExpr WSink epr) - (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ)))) - (EError ext (d2 (STEither t1 t2)) "expspa r<-dl")) - (ECase ext (weakenExpr WSink epr) - (EError ext (d2 (STEither t1 t2)) "expspa l<-dr") - (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ))))) -expandSparse (STLEither t1 t2) (SpLEither s1 s2) epr e = - ELCase ext e - (EZero ext (d2M (STEither t1 t2)) (ENil ext)) - (ELCase ext (weakenExpr WSink epr) - (EError ext (d2 (STEither t1 t2)) "expspa ln<-dl") - (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ)))) - (EError ext (d2 (STEither t1 t2)) "expspa lr<-dl")) - (ELCase ext (weakenExpr WSink epr) - (EError ext (d2 (STEither t1 t2)) "expspa ln<-dr") - (EError ext (d2 (STEither t1 t2)) "expspa ll<-dr") - (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ))))) -expandSparse (STMaybe t) (SpMaybe s) epr e = - EMaybe ext - (ENothing ext (d2 t)) - (let epr' = EMaybe ext (EError ext (d1 t) "expspa n<-dj") (EVar ext (d1 t) IZ) epr - in EJust ext (expandSparse t s (weakenExpr WSink epr') (EVar ext (applySparse s (d2 t)) IZ))) - e -expandSparse (STArr _ t) (SpArr s) epr e = - ezipWith (expandSparse t s (EVar ext (d1 t) (IS IZ)) (EVar ext (applySparse s (d2 t)) IZ)) epr e -expandSparse (STScal STF32) SpScal _ e = e -expandSparse (STScal STF64) SpScal _ e = e -expandSparse (STAccum{}) _ _ _ = error "accumulators not allowed in source program" - -subenvPlus :: SBool req1 -> SBool req2 - -> SList SMTy env - -> SubenvS env env1 -> SubenvS env env2 - -> (forall env3. SubenvS env env3 - -> Injection req1 (Tup env1) (Tup env3) - -> Injection req2 (Tup env2) (Tup env3) - -> (forall e. Ex e (Tup env1) -> Ex e (Tup env2) -> Ex e (Tup env3)) - -> r) - -> r --- don't destroy effects! -subenvPlus _ _ SNil SETop SETop k = k SETop (Inj id) (Inj id) (\a b -> use a $ use b $ ENil ext) - -subenvPlus req1 req2 (SCons _ env) (SENo sub1) (SENo sub2) k = - subenvPlus req1 req2 env sub1 sub2 $ \sub3 s31 s32 pl -> - k (SENo sub3) s31 s32 pl - -subenvPlus req1 SF (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k = - subenvPlus req1 SF env sub1 sub2 $ \sub3 minj13 _ pl -> - k (SEYes sp1 sub3) - (withInj minj13 $ \inj13 -> - \e1 -> eunPair e1 $ \_ e1a e1b -> - EPair ext (inj13 e1a) e1b) - Noinj - (\e1 e2 -> - ELet ext e1 $ - EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) - (weakenExpr WSink e2)) - (ESnd ext (EVar ext (typeOf e1) IZ))) -subenvPlus req1 ST (SCons t env) (SEYes sp1 sub1) (SENo sub2) k - | Just zero1 <- cheapZero (applySparse sp1 t) = - subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl -> - k (SEYes sp1 sub3) - (withInj minj13 $ \inj13 -> - \e1 -> eunPair e1 $ \_ e1a e1b -> - EPair ext (inj13 e1a) e1b) - (Inj $ \e2 -> EPair ext (inj23 e2) zero1) - (\e1 e2 -> - ELet ext e1 $ - EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) - (weakenExpr WSink e2)) - (ESnd ext (EVar ext (typeOf e1) IZ))) - | otherwise = - subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl -> - k (SEYes (SpSparse sp1) sub3) - (withInj minj13 $ \inj13 -> - \e1 -> eunPair e1 $ \_ e1a e1b -> - EPair ext (inj13 e1a) (EJust ext e1b)) - (Inj $ \e2 -> EPair ext (inj23 e2) (ENothing ext (applySparse sp1 (fromSMTy t)))) - (\e1 e2 -> - ELet ext e1 $ - EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) - (weakenExpr WSink e2)) - (EJust ext (ESnd ext (EVar ext (typeOf e1) IZ)))) - -subenvPlus req1 req2 (SCons t env) sub1@SENo{} sub2@SEYes{} k = - subenvPlus req2 req1 (SCons t env) sub2 sub1 $ \sub3 minj23 minj13 pl -> - k sub3 minj13 minj23 (flip pl) - -subenvPlus req1 req2 (SCons t env) (SEYes sp1 sub1) (SEYes sp2 sub2) k = - subenvPlus req1 req2 env sub1 sub2 $ \sub3 minj13 minj23 pl -> - sparsePlusS req1 req2 t sp1 sp2 $ \sp3 mTinj13 mTinj23 plus -> - k (SEYes sp3 sub3) - (withInj2 minj13 mTinj13 $ \inj13 tinj13 -> - \e1 -> eunPair e1 $ \_ e1a e1b -> - EPair ext (inj13 e1a) (tinj13 e1b)) - (withInj2 minj23 mTinj23 $ \inj23 tinj23 -> - \e2 -> eunPair e2 $ \_ e2a e2b -> - EPair ext (inj23 e2a) (tinj23 e2b)) - (\e1 e2 -> - ELet ext e1 $ - ELet ext (weakenExpr WSink e2) $ - EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ))) - (EFst ext (EVar ext (typeOf e2) IZ))) - (plus - (ESnd ext (EVar ext (typeOf e1) (IS IZ))) - (ESnd ext (EVar ext (typeOf e2) IZ)))) - -expandSubenvZeros :: D1E env0 :> env -> SList STy env0 -> SubenvS (D2E env0) contribs - -> Ex env (Tup contribs) -> Ex env (Tup (D2E env0)) -expandSubenvZeros _ SNil SETop _ = ENil ext -expandSubenvZeros w (SCons t ts) (SEYes sp sub) e = - eunPair e $ \w1 e1 e2 -> - EPair ext - (expandSubenvZeros (w1 .> WPop w) ts sub e1) - (expandSparse t sp (EVar ext (d1 t) (w1 .> w @> IZ)) e2) -expandSubenvZeros w (SCons t ts) (SENo sub) e = - EPair ext - (expandSubenvZeros (WPop w) ts sub e) - (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) - - ---------------------------------- ACCUMULATORS --------------------------------- - -fromArrayValId :: Maybe (ValId t) -> Maybe Int -fromArrayValId (Just (VIArr i _)) = Just i -fromArrayValId _ = Nothing - -accumPromote :: forall dt env sto proxy r. - proxy dt - -> Descr env sto - -> (forall stoRepl envPro. - (Select env stoRepl "merge" ~ '[]) - => Descr env stoRepl - -- ^ A revised environment description that switches - -- arrays (used in the OccEnv) that are currently on - -- "merge" storage, to "accum" storage. - -> SList STy envPro - -- ^ New entries on top of the original dual environment, - -- that house the accumulators for the promoted arrays in - -- the original environment. - -> Subenv (Select env sto "merge") envPro - -- ^ The promoted entries were merge entries in the - -- original environment. - -> Subenv (D2AcE (Select env stoRepl "accum")) (D2AcE (Select env sto "accum")) - -- ^ All entries that were accumulators are still - -- accumulators. - -> VarMap Int (D2AcE (Select env stoRepl "accum")) - -- ^ Accumulator map for _only_ the the newly allocated - -- accumulators. - -> (forall shbinds. - SList STy shbinds - -> (dt : Append shbinds (D2AcE (Select env stoRepl "accum"))) - :> Append (D2AcE envPro) (dt : Append shbinds (D2AcE (Select env sto "accum")))) - -- ^ A weakening that converts a computation in the - -- revised environment to one in the original environment - -- extended with some accumulators. - -> r) - -> r -accumPromote _ DTop k = k DTop SNil SETop SETop VarMap.empty (\_ -> WId) -accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of - -- Accumulators are left as-is - SAccum -> - accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> - k (storepl `DPush` (t, vid, SAccum)) - envpro - prosub - (SEYesR accrevsub) - (VarMap.sink1 accumMap) - (\shbinds -> - autoWeak (#pro (d2ace envpro) &. #d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr))) - (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) - (#pro :++: #d :++: #shb :++: #acc :++: #tl) - .> WCopy (wf shbinds) - .> autoWeak (#d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl))) - (#d :++: #shb :++: #acc :++: #tl) - (#acc :++: (#d :++: #shb :++: #tl))) - - SMerge -> case t of - -- Discrete values are left as-is - _ | isDiscrete t -> - accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap' wf -> - k (storepl `DPush` (t, vid, SDiscr)) - envpro - (SENo prosub) - accrevsub - accumMap' - wf - - -- Values with "merge" storage are promoted to an accumulator in envPro - _ -> - accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> - k (storepl `DPush` (t, vid, SAccum)) - (t `SCons` envpro) - (SEYesR prosub) - (SENo accrevsub) - (let accumMap' = VarMap.sink1 accumMap - in case fromArrayValId vid of - Just i -> VarMap.insert i (STAccum (d2M t)) IZ accumMap' - Nothing -> accumMap') - (\(shbinds :: SList _ shbinds) -> - let shbindsC = slistMap (\_ -> Const ()) shbinds - in - -- wf: - -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) - -- WCopy wf: - -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) - -- WPICK: ^ THESE TWO || - -- goal: | ARE EQUAL || - -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) - WCopy (wf shbinds) - .> WPick @(TAccum (D2 t)) @(dt : shbinds) (Const () `SCons` shbindsC) - (WId @(D2AcE (Select env1 stoRepl "accum")))) - - -- Discrete values are left as-is, nothing to do - SDiscr -> - accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> - k (storepl `DPush` (t, vid, SDiscr)) - envpro - prosub - accrevsub - accumMap - wf - where - isDiscrete :: STy t' -> Bool - isDiscrete = \case - STNil -> True - STPair a b -> isDiscrete a && isDiscrete b - STEither a b -> isDiscrete a && isDiscrete b - STLEither a b -> isDiscrete a && isDiscrete b - STMaybe a -> isDiscrete a - STArr _ a -> isDiscrete a - STScal st -> case st of - STI32 -> True - STI64 -> True - STF32 -> False - STF64 -> False - STBool -> True - STAccum{} -> False - - ----------------------------- RETURN TRIPLE FROM CHAD --------------------------- - -data Ret env0 sto sd t = - forall shbinds tapebinds contribs. - Ret (Bindings Ex (D1E env0) shbinds) -- shared binds - (Subenv shbinds tapebinds) - (Ex (Append shbinds (D1E env0)) (D1 t)) - (SubenvS (D2E (Select env0 sto "merge")) contribs) - (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)) -deriving instance Show (Ret env0 sto sd t) - -type data TyTyPair = MkTyTyPair Ty Ty - -data SingleRet env0 sto (pair :: TyTyPair) = - forall shbinds tapebinds. - SingleRet - (Bindings Ex (D1E env0) shbinds) -- shared binds - (Subenv shbinds tapebinds) - (RetPair env0 sto (D1E env0) shbinds tapebinds pair) - --- pattern Ret1 :: forall env0 sto Bindings Ex (D1E env0) shbinds --- -> Subenv shbinds tapebinds --- -> Ex (Append shbinds (D1E env0)) (D1 t) --- -> SubenvS (D2E (Select env0 sto "merge")) contribs --- -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs) --- -> SingleRet env0 sto (MkTyTyPair sd t) --- pattern Ret1 e0 subtape e1 sub e2 = SingleRet e0 subtape (RetPair e1 sub e2) --- {-# COMPLETE Ret1 #-} - -data RetPair env0 sto env shbinds tapebinds (pair :: TyTyPair) where - RetPair :: forall sd t contribs -- existentials - env0 sto env shbinds tapebinds. -- universals - Ex (Append shbinds env) (D1 t) - -> SubenvS (D2E (Select env0 sto "merge")) contribs - -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs) - -> RetPair env0 sto env shbinds tapebinds (MkTyTyPair sd t) -deriving instance Show (RetPair env0 sto env shbinds tapebinds pair) - -data Rets env0 sto env list = - forall shbinds tapebinds. - Rets (Bindings Ex env shbinds) - (Subenv shbinds tapebinds) - (SList (RetPair env0 sto env shbinds tapebinds) list) -deriving instance Show (Rets env0 sto env list) - -toSingleRet :: Ret env0 sto sd t -> SingleRet env0 sto (MkTyTyPair sd t) -toSingleRet (Ret e0 subtape e1 sub e2) = SingleRet e0 subtape (RetPair e1 sub e2) - -weakenRetPair :: SList STy shbinds -> env :> env' - -> RetPair env0 sto env shbinds tapebinds pair -> RetPair env0 sto env' shbinds tapebinds pair -weakenRetPair bindslist w (RetPair e1 sub e2) = RetPair (weakenExpr (weakenOver bindslist w) e1) sub e2 - -weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list -weakenRets w (Rets binds tapesub list) = - let (binds', _) = weakenBindingsE w binds - in Rets binds' tapesub (slistMap (weakenRetPair (bindingsBinds binds) w) list) - -rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto pair f. - Descr env0 sto - -> SList f b1 -> SList f b2 - -> Subenv b1 tapebinds1 -> Subenv b2 tapebinds2 - -> RetPair env0 sto (Append b1 env) b2 tapebinds2 pair - -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) pair -rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair @sd e1 sub e2) - | Refl <- lemAppendAssoc @b2 @b1 @env = - RetPair e1 sub - (weakenExpr (autoWeak - (#d (auto1 @sd) - &. #t2 (subList b2 subtape2) - &. #t1 (subList b1 subtape1) - &. #tl (d2ace (select SAccum descr))) - (#d :++: (#t2 :++: #tl)) - (#d :++: ((#t2 :++: #t1) :++: #tl))) - e2) - -retConcat :: forall env0 sto list. Descr env0 sto -> SList (SingleRet env0 sto) list -> Rets env0 sto (D1E env0) list -retConcat _ SNil = Rets BTop SETop SNil -retConcat descr (SCons (SingleRet (e0 :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) (RetPair e1 sub e2)) list) - | Rets (binds :: Bindings _ _ shbinds2) (subtape2 :: Subenv _ tapebinds2) pairs - <- weakenRets (sinkWithBindings e0) (retConcat descr list) - , Refl <- lemAppendAssoc @shbinds2 @shbinds1 @(D1E env0) - , Refl <- lemAppendAssoc @tapebinds2 @tapebinds1 @(D2AcE (Select env0 sto "accum")) - = Rets (bconcat e0 binds) - (subenvConcat subtape subtape2) - (SCons (RetPair (weakenExpr (sinkWithBindings binds) e1) - sub - (weakenExpr (WCopy (sinkWithSubenv subtape2)) e2)) - (slistMap (rebaseRetPair descr (bindingsBinds e0) (bindingsBinds binds) - subtape subtape2) - pairs)) - -freezeRet :: Descr env sto - -> Ret env sto (D2 t) t - -> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge")))) -freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ _ t) = - let (e0', wInsertD2Ac) = weakenBindingsE (WSink .> wSinks (d2ace (select SAccum descr))) e0 - e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2 - tContribs = tTup (slistMap fromSMTy (subList (d2eM (select SMerge descr)) sub)) - library = #d (auto1 @(D2 t)) - &. #tape (subList (bindingsBinds e0) subtape) - &. #shbinds (bindingsBinds e0) - &. #d2ace (d2ace (select SAccum descr)) - &. #tl (desD1E descr) - &. #contribs (SCons tContribs SNil) - in letBinds e0' $ - EPair ext - (weakenExpr wInsertD2Ac e1) - (ELet ext (weakenExpr (autoWeak library - (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl) - (#shbinds :++: #d :++: #d2ace :++: #tl)) - e2') $ - expandSubenvZeros - (autoWeak library #tl (#contribs :++: #shbinds :++: #d :++: #d2ace :++: #tl) - .> wUndoSubenv (subenvD1E (selectSub SMerge descr))) - (select SMerge descr) sub (EVar ext tContribs IZ)) - - ----------------------------- THE CHAD TRANSFORMATION --------------------------- - -drev :: forall env sto sd t. - (?config :: CHADConfig) - => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) - -> Sparse (D2 t) sd - -> Expr ValId env t -> Ret env sto sd t -drev des _ sd | isAbsent sd = - \e -> - Ret BTop - SETop - (drevPrimal des e) - (subenvNone (d2e (select SMerge des))) - (ENil ext) -drev _ _ SpAbsent = error "Absent should be isAbsent" - -drev des accumMap (SpSparse sd) = - \e -> - case drev des accumMap sd e of { Ret e0 subtape e1 sub e2 -> - subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> - Ret e0 - subtape - e1 - sub' - (emaybe (EVar ext (STMaybe (applySparse sd (d2 (typeOf e)))) IZ) - (inj2 (ENil ext)) - (inj1 (weakenExpr (WCopy WSink) e2))) - } - -drev des accumMap sd = \case - EVar _ t i -> - case conv2Idx des i of - Idx2Ac accI -> - Ret BTop - SETop - (EVar ext (d1 t) (conv1Idx i)) - (subenvNone (d2e (select SMerge des))) - (let ty = applySparse sd (d2M t) - in EAccum ext (d2M t) SAPHere (ENil ext) sd (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI))) - - Idx2Me tupI -> - Ret BTop - SETop - (EVar ext (d1 t) (conv1Idx i)) - (subenvOnehot (d2e (select SMerge des)) tupI sd) - (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t)) IZ)) - - Idx2Di _ -> - Ret BTop - SETop - (EVar ext (d1 t) (conv1Idx i)) - (subenvNone (d2e (select SMerge des))) - (ENil ext) - - ELet _ (rhs :: Expr _ _ a) body - | ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && typeHasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge - , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody sdBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) sd body - , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) rhs1 subRHS rhs2 <- drev des accumMap sdBody rhs - , let (body0', wbody0') = weakenBindingsE (WCopy (sinkWithBindings rhs0)) body0 - , Refl <- lemAppendAssoc @body_shbinds @'[D1 a] @rhs_shbinds - , Refl <- lemAppendAssoc @body_shbinds @(D1 a : rhs_shbinds) @(D1E env) - , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) - -> - subenvPlus SF SF (d2eM (select SMerge des)) subRHS subBody $ \subBoth _ _ plus_RHS_Body -> - let bodyResType = STPair (contribTupTy des subBody) (applySparse sdBody (d2 (typeOf rhs))) in - Ret (bconcat (rhs0 `bpush` rhs1) body0') - (subenvConcat subtapeRHS subtapeBody) - (weakenExpr wbody0' body1) - subBoth - (ELet ext (weakenExpr (autoWeak (#d (auto1 @sd) - &. #body (subList (bindingsBinds body0 `sappend` SCons (d1 (typeOf rhs)) SNil) subtapeBody) - &. #rhs (subList (bindingsBinds rhs0) subtapeRHS) - &. #tl (d2ace (select SAccum des))) - (#d :++: #body :++: #tl) - (#d :++: (#body :++: #rhs) :++: #tl)) - body2) $ - ELet ext - (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ - weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $ - plus_RHS_Body - (EVar ext (contribTupTy des subRHS) IZ) - (EFst ext (EVar ext bodyResType (IS IZ)))) - - EPair _ a b - | SpPair sd1 sd2 <- sd - , Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) - <- retConcat des $ toSingleRet (drev des accumMap sd1 a) `SCons` toSingleRet (drev des accumMap sd2 b) `SCons` SNil - , let dt = STPair (applySparse sd1 (d2 (typeOf a))) (applySparse sd2 (d2 (typeOf b))) -> - subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B -> - Ret binds - subtape - (EPair ext a1 b1) - subBoth - (ELet ext (ELet ext (EFst ext (EVar ext dt IZ)) - (weakenExpr (WCopy WSink) a2)) $ - ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ))) - (weakenExpr (WCopy (WSink .> WSink)) b2)) $ - plus_A_B - (EVar ext (contribTupTy des subA) (IS IZ)) - (EVar ext (contribTupTy des subB) IZ)) - - EFst _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair sd SpAbsent) e - , STPair t1 _ <- typeOf e -> - Ret e0 - subtape - (EFst ext e1) - sub - (ELet ext (EPair ext (EVar ext (applySparse sd (d2 t1)) IZ) (ENil ext)) $ - weakenExpr (WCopy WSink) e2) - - ESnd _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair SpAbsent sd) e - , STPair _ t2 <- typeOf e -> - Ret e0 - subtape - (ESnd ext e1) - sub - (ELet ext (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t2)) IZ)) $ - weakenExpr (WCopy WSink) e2) - - -- Don't need to handle ENil, because its cotangent is always absent! - -- ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (d2e (select SMerge des))) (ENil ext) - - EInl _ t2 e - | SpLEither sd1 sd2 <- sd - , Ret e0 subtape e1 sub e2 <- drev des accumMap sd1 e -> - subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> - Ret e0 - subtape - (EInl ext (d1 t2) e1) - sub' - (ELCase ext - (EVar ext (STLEither (applySparse sd1 (d2 (typeOf e))) (applySparse sd2 (d2 t2))) IZ) - (inj2 $ ENil ext) - (inj1 $ weakenExpr (WCopy WSink) e2) - (EError ext (contribTupTy des sub') "inl<-dinr")) - - EInr _ t1 e - | SpLEither sd1 sd2 <- sd - , Ret e0 subtape e1 sub e2 <- drev des accumMap sd2 e -> - subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> - Ret e0 - subtape - (EInr ext (d1 t1) e1) - sub' - (ELCase ext - (EVar ext (STLEither (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 (typeOf e)))) IZ) - (inj2 $ ENil ext) - (EError ext (contribTupTy des sub') "inr<-dinl") - (inj1 $ weakenExpr (WCopy WSink) e2)) - - ECase _ e (a :: Expr _ _ t) b - | STEither (t1 :: STy a) (t2 :: STy b) <- typeOf e - , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && typeHasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge - , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && typeHasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge - , let (bindids1, bindids2) = validSplitEither (extOf e) - , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA sd1 a2 - <- drevScoped des accumMap t1 storage1 bindids1 sd a - , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB sd2 b2 - <- drevScoped des accumMap t2 storage2 bindids2 sd b - , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap (SpLEither sd1 sd2) e - , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) - , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) - , let subtapeListA = subList (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA - , let subtapeListB = subList (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB - , let tapeA = tapeTy subtapeListA - , let tapeB = tapeTy subtapeListB - , let collectA = bindingsCollectTape @_ @_ @(Append rhs_a_binds (D1 a : Append e_binds (D1E env))) - (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA - , let collectB = bindingsCollectTape @_ @_ @(Append rhs_b_binds (D1 b : Append e_binds (D1E env))) - (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB - , (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB) - , let (a0', wa0') = weakenBindingsE (WCopy (sinkWithBindings e0)) a0 - , let (b0', wb0') = weakenBindingsE (WCopy (sinkWithBindings e0)) b0 - , Refl <- lemAppendNil @(Append rhs_a_binds '[D1 a]) - , Refl <- lemAppendNil @(Append rhs_b_binds '[D1 b]) - , Refl <- lemAppendAssoc @rhs_a_binds @'[D1 a] @(D1E env) - , Refl <- lemAppendAssoc @rhs_b_binds @'[D1 b] @(D1E env) - , let wa0'' = wa0' .> wCopies (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) (WClosed @(D1E env)) - , let wb0'' = wb0' .> wCopies (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) (WClosed @(D1E env)) - -> - subenvPlus ST ST (d2eM (select SMerge des)) subA subB $ \subAB (Inj sAB_A) (Inj sAB_B) _ -> - subenvPlus SF SF (d2eM (select SMerge des)) subAB subE $ \subOut _ _ plus_AB_E -> - Ret (e0 `bpush` ECase ext e1 - (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0'')))) - (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0''))))) - (SEYesR subtapeE) - (EFst ext (EVar ext tPrimal IZ)) - subOut - (elet - (ECase ext (ESnd ext (EVar ext tPrimal (IS IZ))) - (let (rebinds, prerebinds) = reconstructBindings subtapeListA - in letBinds (rebinds IZ) $ - ELet ext - (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_a_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListA prerebinds) @> IS IZ)) $ - elet - (weakenExpr (autoWeak (#d (auto1 @sd) - &. #ta0 subtapeListA - &. #prea0 prerebinds - &. #recon (tapeA `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil) - &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE) - &. #tl (d2ace (select SAccum des))) - (#d :++: #ta0 :++: #tl) - (#d :++: (#ta0 :++: #prea0) :++: #recon :++: #binds :++: #tl)) - a2) $ - EPair ext (sAB_A $ EFst ext (evar IZ)) - (ELInl ext (applySparse sd2 (d2 t2)) (ESnd ext (evar IZ)))) - (let (rebinds, prerebinds) = reconstructBindings subtapeListB - in letBinds (rebinds IZ) $ - ELet ext - (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_b_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListB prerebinds) @> IS IZ)) $ - elet - (weakenExpr (autoWeak (#d (auto1 @sd) - &. #tb0 subtapeListB - &. #preb0 prerebinds - &. #recon (tapeB `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil) - &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE) - &. #tl (d2ace (select SAccum des))) - (#d :++: #tb0 :++: #tl) - (#d :++: (#tb0 :++: #preb0) :++: #recon :++: #binds :++: #tl)) - b2) $ - EPair ext (sAB_B $ EFst ext (evar IZ)) - (ELInr ext (applySparse sd1 (d2 t1)) (ESnd ext (evar IZ))))) $ - plus_AB_E - (EFst ext (evar IZ)) - (ELet ext (ESnd ext (evar IZ)) $ - weakenExpr (WCopy (wSinks' @[_,_,_])) e2)) - - EConst _ t val -> - Ret BTop - SETop - (EConst ext t val) - (subenvNone (d2e (select SMerge des))) - (ENil ext) - - EOp _ op e - | Ret e0 subtape e1 sub e2 <- drev des accumMap (spDense (d2M (opt1 op))) e -> - case d2op op of - Linear d2opfun -> - Ret e0 - subtape - (d1op op e1) - sub - (ELet ext (d2opfun (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ))) - (weakenExpr (WCopy WSink) e2)) - Nonlinear d2opfun -> - Ret (e0 `bpush` e1) - (SEYesR subtape) - (d1op op $ EVar ext (d1 (typeOf e)) IZ) - sub - (ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ)) - (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ))) - (weakenExpr (WCopy (wSinks' @[_,_])) e2)) - - ECustom _ _ tb _ srce pr du a b - -- allowed to ignore a2 because 'a' is the part of the input that is inactive - | Ret b0 bsubtape b1 bsub b2 <- drev des accumMap (spDense (d2M tb)) b -> - case isDense (d2M (typeOf srce)) sd of - Just Refl -> - Ret (b0 `bpush` weakenExpr (sinkWithBindings b0) (drevPrimal des a) - `bpush` weakenExpr WSink b1 - `bpush` weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr) - `bpush` ESnd ext (EVar ext (typeOf pr) IZ)) - (SEYesR (SENo (SENo (SENo bsubtape)))) - (EFst ext (EVar ext (typeOf pr) (IS IZ))) - bsub - (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ - weakenExpr (WCopy (WSink .> WSink)) b2) - - Nothing -> - Ret (b0 `bpush` weakenExpr (sinkWithBindings b0) (drevPrimal des a) - `bpush` weakenExpr WSink b1 - `bpush` weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)) - (SEYesR (SENo (SENo bsubtape))) - (EFst ext (EVar ext (typeOf pr) IZ)) - bsub - (ELet ext (ESnd ext (EVar ext (typeOf pr) (IS IZ))) $ -- tape - ELet ext (expandSparse (typeOf srce) sd -- expanded incoming cotangent - (EFst ext (EVar ext (typeOf pr) (IS (IS IZ)))) - (EVar ext (applySparse sd (d2 (typeOf srce))) (IS IZ))) $ - ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ - weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) b2) - - ERecompute _ e -> - deleteUnused (descrList des) (occCountAll e) $ \usedSub -> - let smallE = unsafeWeakenWithSubenv usedSub e in - subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> - case drev usedDes (VarMap.subMap subAccumUsed accumMap) sd smallE of { Ret e0 subtape _ sub e2 -> - let subMergeUsed' = subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E subMergeUsed) in - Ret (collectBindings (desD1E des) subD1eUsed) - (subenvAll (desD1E usedDes)) - (weakenExpr (wSinks (desD1E usedDes)) $ drevPrimal des e) - (subenvCompose subMergeUsed' sub) - (letBinds (fst (weakenBindingsE (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $ - weakenExpr - (autoWeak (#d (auto1 @sd) - &. #shbinds (bindingsBinds e0) - &. #tape (subList (bindingsBinds e0) subtape) - &. #d1env (desD1E usedDes) - &. #tl' (d2ace (select SAccum usedDes)) - &. #tl (d2ace (select SAccum des))) - (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: LPreW #tl' #tl (wUndoSubenv subAccumUsed)) - (#shbinds :++: #d :++: #d1env :++: #tl)) - e2) - } - - EError _ t s -> - Ret BTop - SETop - (EError ext (d1 t) s) - (subenvNone (d2e (select SMerge des))) - (ENil ext) - - EConstArr _ n t val -> - Ret BTop - SETop - (EConstArr ext n t val) - (subenvNone (d2e (select SMerge des))) - (ENil ext) - - EBuild _ (ndim :: SNat ndim) she (ef :: Expr _ _ eltty) - | SpArr @_ @sdElt sdElt <- sd - , let eltty = typeOf ef - , shty :: STy shty <- tTup (sreplicate ndim tIx) - , Refl <- indexTupD1Id ndim -> - drevLambda des accumMap (shty, SDiscr) sdElt ef $ \(provars :: SList _ envPro) esub proPrimalBinds e0 e1 (e1tape :: Ex _ e_tape) _ wrapAccum e2 -> - let library = #ix (shty `SCons` SNil) - &. #e0 (bindingsBinds e0) - &. #propr (d1e provars) - &. #d1env (desD1E des) - &. #d (auto1 @sdElt) - &. #tape (auto1 @e_tape) - &. #pro (d2ace provars) - &. #d2acEnv (d2ace (select SAccum des)) - &. #darr (auto1 @(TArr ndim sdElt)) - &. #tapearr (auto1 @(TArr ndim e_tape)) in - Ret (proPrimalBinds - `bpush` weakenExpr (wSinks (d1e provars)) - (EBuild ext ndim - (drevPrimal des she) - (letBinds e0 $ - EPair ext e1 e1tape)) - `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) IZ)) - (SEYesR (SENo (subenvAll (d1e provars)))) - (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) (IS IZ))) - (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) esub) - (let sinkOverEnvPro = wSinks @(sd : TArr ndim e_tape : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace provars) in - ESnd ext $ - wrapAccum (WSink .> WSink .> wRaiseAbove (d1e provars) (d2ace (select SAccum des))) $ - EBuild ext ndim (EShape ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (sinkOverEnvPro @> IZ))) $ - -- the cotangent for this element - ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ)) - (EVar ext shty IZ)) $ - -- the tape for this element - ELet ext (EIdx ext (EVar ext (STArr ndim (typeOf e1tape)) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) - (EVar ext shty (IS IZ))) $ - weakenExpr (autoWeak library (#tape :++: #d :++: #pro :++: #d2acEnv) - (#tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv)) - e2) - - EMap _ ef (earr :: Expr _ _ (TArr n a)) - | SpArr sdElt <- sd - , let STArr ndim t1 = typeOf earr - t2 = typeOf ef -> - drevLambda des accumMap (t1, SMerge) sdElt ef $ \provars efsub proPrimalBinds ef0 ef1 ef1tape spEf wrapAccum ef2 -> - case drev des accumMap (SpArr spEf) earr of { Ret ea0 easubtape ea1 easub ea2 -> - let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings ea0) proPrimalBinds - ttape = typeOf ef1tape - library = #d1env (desD1E des) - &. #a0 (bindingsBinds ea0) - &. #atapebinds (subList (bindingsBinds ea0) easubtape) - &. #propr (d1e provars) - &. #x (d1 t1 `SCons` SNil) - &. #parr (STArr ndim (d1 t1) `SCons` SNil) - &. #tapearr (STArr ndim ttape `SCons` SNil) - &. #darr (STArr ndim (applySparse sdElt (d2 t2)) `SCons` SNil) - &. #dy (applySparse sdElt (d2 t2) `SCons` SNil) - &. #tape (ttape `SCons` SNil) - &. #dytape (STPair (applySparse sdElt (d2 t2)) ttape `SCons` SNil) - &. #d2acEnv (d2ace (select SAccum des)) - &. #pro (d2ace provars) - in - subenvPlus SF SF (d2eM (select SMerge des)) (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) efsub) easub $ \subfa _ _ plus_f_a -> - Ret (bconcat ea0 proPrimalBinds' - `bpush` weakenExpr (autoWeak library (#a0 :++: #d1env) ((#propr :++: #a0) :++: #d1env)) ea1 - `bpush` emap (weakenExpr (autoWeak library (#x :++: #d1env) (#x :++: #parr :++: (#propr :++: #a0) :++: #d1env)) - (letBinds ef0 $ - EPair ext ef1 ef1tape)) - (EVar ext (STArr ndim (d1 t1)) IZ) - `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 t2) ttape)) IZ)) - (SEYesR (SENo (SENo (subenvConcat easubtape (subenvAll (d1e provars)))))) - (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 t2) ttape)) (IS IZ))) - subfa - (let layout = #darr :++: #tapearr :++: (#propr :++: #atapebinds) :++: #d2acEnv in - elet - (wrapAccum (autoWeak library #propr layout) $ - emap (elet (EFst ext (EVar ext (STPair (applySparse sdElt (d2 t2)) ttape) IZ)) $ - elet (ESnd ext (EVar ext (STPair (applySparse sdElt (d2 t2)) ttape) (IS IZ))) $ - weakenExpr (autoWeak library (#tape :++: #dy :++: #pro :++: #d2acEnv) - (#tape :++: #dy :++: #dytape :++: #pro :++: layout)) - ef2) - (ezip (EVar ext (STArr ndim (applySparse sdElt (d2 t2))) (autoWeak library #darr (#pro :++: layout) @> IZ)) - (EVar ext (STArr ndim ttape) (autoWeak library #tapearr (#pro :++: layout) @> IZ)))) $ - plus_f_a - (ESnd ext (evar IZ)) - (weakenExpr (WCopy (autoWeak library (#atapebinds :++: #d2acEnv) layout)) - (subst0 (EFst ext (EVar ext (STPair (STArr ndim (typeOf ef2)) (tTup (d2e provars))) IZ)) - ea2))) - } - - EFold1Inner _ commut origef ex₀ earr - | SpArr @_ @sdElt sdElt <- sd - , STArr (SS ndim) eltty :: STy (TArr (S n) elt) <- typeOf earr - , Rets bindsx₀a subtapex₀a (RetPair ex₀1 subx₀ ex₀2 `SCons` RetPair ea1 suba ea2 `SCons` SNil) - <- retConcat des $ toSingleRet (drev des accumMap (spDense (d2M eltty)) ex₀) `SCons` toSingleRet (drev des accumMap (spDense (SMTArr (SS ndim) (d2M eltty))) earr) `SCons` SNil -> - drevLambda des accumMap (STPair eltty eltty, SMerge) (spDense (d2M eltty)) origef $ \(provars :: SList _ envPro) efsub proPrimalBinds ef0 ef1 (ef1tape :: Ex _ ef_tape) spEf wrapAccum ef2 -> - let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings bindsx₀a) proPrimalBinds in - let bogEltTy = STPair (STPair (d1 eltty) (d1 eltty)) (typeOf ef1tape) - bogTy = STArr (SS ndim) bogEltTy - primalTy = STPair (STArr ndim (d1 eltty)) bogTy - library = #xy (STPair (d1 eltty) (d1 eltty) `SCons` SNil) - &. #parr (auto1 @(TArr (S n) (D1 elt))) - &. #px₀ (auto1 @(D1 elt)) - &. #px (auto1 @(D1 elt)) - &. #pzi (auto1 @(ZeroInfo (D2 elt))) - &. #primal (primalTy `SCons` SNil) - &. #darr (auto1 @(TArr n sdElt)) - &. #d (auto1 @(D2 elt)) - &. #x₀abinds (bindingsBinds bindsx₀a) - &. #fbinds (bindingsBinds ef0) - &. #x₀atapebinds (subList (bindingsBinds bindsx₀a) subtapex₀a) - &. #ftape (auto1 @ef_tape) - &. #bogelt (bogEltTy `SCons` SNil) - &. #propr (d1e provars) - &. #d1env (desD1E des) - &. #d2acEnv (d2ace (select SAccum des)) - &. #d2acPro (d2ace provars) - &. #foldd2res (auto1 @(TPair (TPair (D2 elt) (TArr (S n) (D2 elt))) (Tup (D2E envPro)))) - wOverPrimalBindings = autoWeak library (#x₀abinds :++: #d1env) ((#propr :++: #x₀abinds) :++: #d1env) in - subenvPlus SF SF (d2eM (select SMerge des)) subx₀ suba $ \subx₀a _ _ plus_x₀_a -> - subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) efsub) $ \subx₀af _ _ plus_x₀a_f -> - Ret (bconcat bindsx₀a proPrimalBinds' - `bpush` weakenExpr wOverPrimalBindings ex₀1 - `bpush` d2zeroInfo eltty (EVar ext (d1 eltty) IZ) - `bpush` weakenExpr (WSink .> WSink .> wOverPrimalBindings) ea1 - `bpush` EFold1InnerD1 ext commut - (let layout = #xy :++: #parr :++: #pzi :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env in - weakenExpr (autoWeak library (#xy :++: #d1env) layout) - (letBinds ef0 $ - EPair ext -- (out, ((in1, in2), tape)); the "additional stores" are ((in1, in2), tape) - ef1 - (EPair ext - (EVar ext (STPair (d1 eltty) (d1 eltty)) (autoWeak library #xy (#fbinds :++: #xy :++: #d1env) @> IZ)) - ef1tape))) - (EVar ext (d1 eltty) (IS (IS IZ))) - (EVar ext (STArr (SS ndim) (d1 eltty)) IZ)) - (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e provars))))))) - (EFst ext (EVar ext primalTy IZ)) - subx₀af - (let layout1 = #darr :++: #primal :++: #parr :++: #pzi :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in - elet - (wrapAccum (autoWeak library #propr layout1) $ - let layout2 = #d2acPro :++: layout1 in - EFold1InnerD2 ext commut - (elet (ESnd ext (EVar ext bogEltTy (IS IZ))) $ - let layout3 = #ftape :++: #d :++: #bogelt :++: layout2 in - expandSparse (STPair eltty eltty) spEf (EFst ext (EVar ext bogEltTy (IS (IS IZ)))) $ - weakenExpr (autoWeak library (#ftape :++: #d :++: #d2acPro :++: #d2acEnv) layout3) ef2) - (ESnd ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ))) - (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ))) - (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ)) - (EFst ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ))))) $ - plus_x₀a_f - (plus_x₀_a - (elet (EIdx0 ext - (EFold1Inner ext Commut - (let t = STPair (d2 eltty) (d2 eltty) - in EPlus ext (d2M eltty) (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ))) - (EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (WSink .> autoWeak library #pzi layout1 @> IZ))) - (eflatten (EFst ext (EFst ext (evar IZ)))))) $ - weakenExpr (WCopy (WSink .> autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) - ex₀2) - (weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $ - subst0 (ESnd ext (EFst ext (evar IZ))) ea2)) - (ESnd ext (evar IZ))) - - EUnit _ e - | SpArr sdElt <- sd - , Ret e0 subtape e1 sub e2 <- drev des accumMap sdElt e -> - Ret e0 - subtape - (EUnit ext e1) - sub - (ELet ext (EIdx0 ext (EVar ext (STArr SZ (applySparse sdElt (d2 (typeOf e)))) IZ)) $ - weakenExpr (WCopy WSink) e2) - - EReplicate1Inner _ en e - -- We're allowed to differentiate 'en' as primal-only here because its output is discrete. - | SpArr sdElt <- sd - , let STArr ndim eltty = typeOf e -> - -- This pessimistic sparsity union is because the array might have been empty, in which case we need to generate a zero. - sparsePlusS ST ST (d2M eltty) sdElt SpAbsent $ \sdElt' (Inj inj1) (Inj inj2) _ -> - case drev des accumMap (SpArr sdElt') e of { Ret binds subtape e1 sub e2 -> - Ret binds - subtape - (EReplicate1Inner ext (weakenExpr (wSinks (bindingsBinds binds)) (drevPrimal des en)) e1) - sub - (ELet ext (EFold1Inner ext Commut - (let t = STPair (applySparse sdElt' (d2 eltty)) (applySparse sdElt' (d2 eltty)) - in sparsePlus (d2M eltty) sdElt' (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ))) - (inj2 (ENil ext)) - (emap (inj1 (evar IZ)) $ EVar ext (STArr (SS ndim) (applySparse sdElt (d2 eltty))) IZ)) $ - weakenExpr (WCopy WSink) e2) - } - - EIdx0 _ e - | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd) e - , STArr _ t <- typeOf e -> - Ret e0 - subtape - (EIdx0 ext e1) - sub - (ELet ext (EUnit ext (EVar ext (applySparse sd (d2 t)) IZ)) $ - weakenExpr (WCopy WSink) e2) - - EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead" - {- - EIdx1 _ e ei - -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. - | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) - <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil - , STArr (SS n) eltty <- typeOf e -> - Ret (binds `bpush` e1 - `bpush` EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ)) - (SEYesR (SENo subtape)) - (EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ)) - (weakenExpr (WSink .> WSink) ei1)) - sub - (ELet ext (ebuildUp1 n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) - (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) - (EVar ext (STArr n (d2 eltty)) (IS IZ))) $ - weakenExpr (WCopy (WSink .> WSink)) e2) - -} - - EIdx _ e ei - -- We're allowed to differentiate ei as primal because its output is discrete. - | STArr n eltty <- typeOf e - , Refl <- indexTupD1Id n - , let tIxN = tTup (sreplicate n tIx) -> - sparsePlusS ST ST (d2M eltty) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ -> - case drev des accumMap (SpArr sd') e of { Ret binds subtape e1 sub e2 -> - Ret (binds `bpush` e1 - `bpush` EShape ext (EVar ext (typeOf e1) IZ) - `bpush` weakenExpr (WSink .> WSink .> wSinks (bindingsBinds binds)) (drevPrimal des ei)) - (SEYesR (SEYesR (SENo subtape))) - (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) - (EVar ext (tTup (sreplicate n tIx)) IZ)) - sub - (ELet ext - (EOneHot ext (SMTArr n (applySparse sd' (d2M eltty))) - (SAPArrIdx SAPHere) - (EPair ext - (EPair ext (EVar ext tIxN (IS IZ)) - (EBuild ext n (EVar ext tIxN (IS (IS IZ))) $ - makeZeroInfo (applySparse sd' (d2M eltty)) (inj2 (ENil ext)))) - (ENil ext)) - (inj1 $ EVar ext (applySparse sd (d2 eltty)) IZ)) $ - weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) - } - - EShape _ e - -- Allowed to differentiate e as primal because the output of EShape is - -- discrete, hence we'd be passing a zero cotangent to e anyway. - | STArr n _ <- typeOf e - , Refl <- indexTupD1Id n -> - Ret BTop - SETop - (EShape ext (drevPrimal des e)) - (subenvNone (d2eM (select SMerge des))) - (ENil ext) - - ESum1Inner _ e - | SpArr sd' <- sd - , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e - , STArr (SS n) t <- typeOf e -> - Ret (e0 `bpush` e1 - `bpush` EShape ext (EVar ext (STArr (SS n) t) IZ)) - (SEYesR (SENo subtape)) - (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ))) - sub - (ELet ext (EReplicate1Inner ext - (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) - (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $ - weakenExpr (WCopy (WSink .> WSink)) e2) - - EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e - EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e - - EReshape _ n esh e - | SpArr sd' <- sd - , STArr orign t <- typeOf e - , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e - , Refl <- indexTupD1Id n -> - Ret (e0 `bpush` e1 - `bpush` EShape ext (EVar ext (STArr orign (d1 t)) IZ)) - (SEYesR (SENo subtape)) - (EReshape ext n (weakenExpr (WSink .> WSink .> wSinks (bindingsBinds e0)) (drevPrimal des esh)) - (EVar ext (STArr orign (d1 t)) (IS IZ))) - sub - (elet (EReshape ext orign (EVar ext (tTup (sreplicate orign tIx)) (IS IZ)) - (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $ - weakenExpr (WCopy (WSink .> WSink)) e2) - - EZip _ a b - | SpArr sd' <- sd - , STArr n t1 <- typeOf a - , STArr _ t2 <- typeOf b -> - splitSparsePair (STPair (d2 t1) (d2 t2)) sd' $ \sd1 sd2 pairSplitE -> - case retConcat des (toSingleRet (drev des accumMap (SpArr sd1) a) `SCons` - toSingleRet (drev des accumMap (SpArr sd2) b) `SCons` SNil) of - { Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) -> - subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B -> - Ret binds - subtape - (EZip ext a1 b1) - subBoth - (case pairSplitE of - Left Refl -> - let t' = STArr n (STPair (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 t2))) in - plus_A_B - (elet (emap (EFst ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) a2) - (elet (emap (ESnd ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) b2) - Right f -> f IZ $ \wrapPair pick1 pick2 -> - elet (emap (wrapPair (EPair ext pick1 pick2)) - (EVar ext (applySparse (SpArr sd') (STArr n (STPair (d2 t1) (d2 t2)))) IZ)) $ - plus_A_B - (elet (emap (EFst ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) a2) - (elet (emap (ESnd ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) b2)) - } - - ENothing{} -> err_unsupported "ENothing" - EJust{} -> err_unsupported "EJust" - EMaybe{} -> err_unsupported "EMaybe" - ELNil{} -> err_unsupported "ELNil" - ELInl{} -> err_unsupported "ELInl" - ELInr{} -> err_unsupported "ELInr" - ELCase{} -> err_unsupported "ELCase" - - EWith{} -> err_accum - EZero{} -> err_monoid - EDeepZero{} -> err_monoid - EPlus{} -> err_monoid - EOneHot{} -> err_monoid - - EFold1InnerD1{} -> err_targetlang "EFold1InnerD1" - EFold1InnerD2{} -> err_targetlang "EFold1InnerD2" - - where - err_accum = error "Accumulator operations unsupported in the source program" - err_monoid = error "Monoid operations unsupported in the source program" - err_unsupported s = error $ "CHAD: unsupported " ++ s - err_targetlang s = error $ "CHAD: Target language operation " ++ s ++ " not supported in source program" - - contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs) - contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub)) - -deriv_extremum :: (?config :: CHADConfig, ScalIsNumeric t ~ True) - => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t))) - -> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) - -> Sparse (D2s t) sd - -> Expr ValId env (TArr (S n) (TScal t)) -> Ret env sto (TArr n sd) (TArr n (TScal t)) -deriv_extremum extremum des accumMap sd e - | at@(STArr (SS n) t@(STScal st)) <- typeOf e - , let at' = STArr n t - , let tIxN = tTup (sreplicate (SS n) tIx) = - sparsePlusS ST ST (d2M t) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ -> - case drev des accumMap (SpArr sd') e of { Ret e0 subtape e1 sub e2 -> - Ret (e0 `bpush` e1 - `bpush` extremum (EVar ext at IZ)) - (SEYesR (SEYesR subtape)) - (EVar ext at' IZ) - sub - (ELet ext - (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $ - eif (EOp ext (OEq st) (EPair ext - (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ)) - (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ))))) - (inj1 $ EIdx ext (EVar ext (STArr n (applySparse sd (d2 t))) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) - (inj2 (ENil ext))) $ - weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) - } - -data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s) - -data RetScoped env0 sto a s sd t = - forall shbinds tapebinds contribs sa. - RetScoped - (Bindings Ex (D1E (a : env0)) shbinds) -- shared binds - (Subenv (Append shbinds '[D1 a]) tapebinds) - (Ex (Append shbinds (D1E (a : env0))) (D1 t)) - (SubenvS (D2E (Select env0 sto "merge")) contribs) - -- ^ merge contributions to the _enclosing_ merge environment - (Sparse (D2 a) sa) - -- ^ contribution to the argument - (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) - (If (s == "discr") (Tup contribs) - (TPair (Tup contribs) sa))) - -- ^ the merge contributions, plus the cotangent to the argument - -- (if there is any) -deriving instance Show (RetScoped env0 sto a s sd t) - -drevScoped :: forall a s env sto sd t. - (?config :: CHADConfig) - => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) - -> STy a -> Storage s -> Maybe (ValId a) - -> Sparse (D2 t) sd - -> Expr ValId (a : env) t - -> RetScoped env sto a s sd t -drevScoped des accumMap argty argsto argids sd expr = case argsto of - SMerge - | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr - , Refl <- lemAppendNil @tapebinds -> - case sub of - SEYes sp sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' sp e2 - SENo sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' SpAbsent (EPair ext e2 (ENil ext)) - - SAccum - | chcSmartWith ?config - , Just (VIArr i _) <- argids - , Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap - , Just Refl <- testEquality foundTy (STAccum (d2M argty)) - , Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) sd expr - , Refl <- lemAppendNil @tapebinds -> - -- Our contribution to the binding's cotangent _here_ is zero (absent), - -- because we're contributing to an earlier binding of the same value - -- instead. - RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent $ - let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in - ELet ext (EVar ext (STAccum (d2M argty)) (WSink .> wtapebinds @> idx)) $ - weakenExpr (autoWeak (#d (auto1 @sd) - &. #body (subList (bindingsBinds e0) subtape) - &. #ac (auto1 @(TAccum (D2 a))) - &. #tl (d2ace (select SAccum des))) - (#d :++: #body :++: #ac :++: #tl) - (#ac :++: #d :++: #body :++: #tl)) - (EPair ext e2 (ENil ext)) - - | let accumMap' = case argids of - Just (VIArr i _) -> VarMap.insert i (STAccum (d2M argty)) IZ (VarMap.sink1 accumMap) - _ -> VarMap.sink1 accumMap - , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' sd expr -> - let library = #d (auto1 @sd) - &. #p (auto1 @(D1 a)) - &. #body (subList (bindingsBinds e0) subtape) - &. #ac (auto1 @(TAccum (D2 a))) - &. #tl (d2ace (select SAccum des)) - in - RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub (spDense (d2M argty)) $ - let primalIdx = autoWeak library #p (#d :++: (#body :++: #p) :++: #tl) @> IZ in - EWith ext (d2M argty) (EDeepZero ext (d2M argty) (d2deepZeroInfo argty (EVar ext (d1 argty) primalIdx))) $ - weakenExpr (autoWeak library - (#d :++: #body :++: #ac :++: #tl) - (#ac :++: #d :++: (#body :++: #p) :++: #tl)) - e2 - - SDiscr - | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr - , Refl <- lemAppendNil @tapebinds -> - RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent e2 - -drevLambda :: (?config :: CHADConfig, (s == "accum") ~ False) - => Descr env sto - -> VarMap Int (D2AcE (Select env sto "accum")) - -> (STy a, Storage s) - -> Sparse (D2 t) dt - -> Expr ValId (a : env) t - -> (forall provars shbinds tape d2a'. - SList STy provars - -> Subenv (D2E (Select env sto "merge")) (D2E provars) - -> Bindings Ex (D1E env) (D1E provars) -- accum-promoted free variables of which we need a primal in the reverse pass (to initialise the accumulator) - -> Bindings Ex (D1 a : D1E env) shbinds - -> Ex (Append shbinds (D1 a : D1E env)) (D1 t) - -> Ex (Append shbinds (D1 a : D1E env)) tape - -> Sparse (D2 a) d2a' - -> (forall env' b. - D1E provars :> env' - -> Ex (Append (D2AcE provars) env') b - -> Ex ( env') (TPair b (Tup (D2E provars)))) - -> Ex (tape : dt : Append (D2AcE provars) (D2AcE (Select env sto "accum"))) d2a' - -> r) - -> r -drevLambda des accumMap (argty, argsto) sd origef k = - let t = typeOf origef in - deleteUnused (descrList des) (occEnvPopSome (occCountAll origef)) $ \(usedSub :: Subenv env env') -> - let ef = unsafeWeakenWithSubenv (SEYesR usedSub) origef in - subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> - accumPromote (applySparse sd (d2 t)) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> - let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in - let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in - let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in - case prf1 prodes argty argsto of { Refl -> - case drev (prodes `DPush` (argty, Nothing, argsto)) accumMapPro sd ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 -> - let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in - extractContrib prodes argty argsto subEf $ \argSp getSparseArg -> - let library = #fbinds (bindingsBinds ef0) - &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf) - &. #ftape (auto1 @(Tape e_tape)) - &. #arg (d1 argty `SCons` SNil) - &. #d (applySparse sd (d2 t) `SCons` SNil) - &. #d1env (desD1E des) - &. #d1env' (desD1E usedDes) - &. #propr (d1e envPro) - &. #d2acUsed (d2ace (select SAccum usedDes)) - &. #d2acEnv (d2ace (select SAccum des)) - &. #d2acPro (d2ace envPro) - &. #efPrerebinds efPrerebinds in - k envPro - (subenvD2E (subenvCompose subMergeUsed proSub)) - mergePrimalBindings - (fst (weakenBindingsE (WCopy (wUndoSubenv subD1eUsed)) ef0)) - (weakenExpr (autoWeak library (#fbinds :++: #arg :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) - (#fbinds :++: #arg :++: #d1env)) - ef1) - (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#fbinds :++: #arg :++: #d1env))) - argSp - (\wpro1 body -> - uninvertTup (d2e envPro) (typeOf body) $ - makeAccumulators wpro1 envPro $ - body) - (letBinds (efRebinds IZ) $ - weakenExpr - (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) - ((#ftapebinds :++: #efPrerebinds) :++: #ftape :++: #d :++: #d2acPro :++: #d2acEnv) - .> wPro (subList (bindingsBinds ef0) subtapeEf)) - (getSparseArg ef2)) - }} - where - extractContrib :: (Select env sto "merge" ~ '[], (s == "accum") ~ False) - => proxy env sto -> proxy2 a -> Storage s - -- if s == "merge", this simplifies to SubenvS '[D2 a] t' - -- if s == "discr", this simplifies to SubenvS '[] t' - -> SubenvS (D2E (Select (a : env) (s : sto) "merge")) t' - -> (forall d'. Sparse (D2 a) d' -> (forall env'. Ex env' (Tup t') -> Ex env' d') -> r) -> r - extractContrib _ _ SMerge (SENo SETop) k' = k' SpAbsent id - extractContrib _ _ SMerge (SEYes s SETop) k' = k' s (ESnd ext) - extractContrib _ _ SDiscr SETop k' = k' SpAbsent id - - prf1 :: (s == "accum") ~ False => proxy env sto -> proxy2 a -> Storage s - -> Select (a : env) (s : sto) "accum" :~: Select env sto "accum" - prf1 _ _ SMerge = Refl - prf1 _ _ SDiscr = Refl - --- TODO: proper primal-only transform that doesn't depend on D1 = Id -drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t) -drevPrimal des e - | Refl <- d1Identity (typeOf e) - , Refl <- d1eIdentity (descrList des) - = mapExt (const ext) e diff --git a/src/CHAD/AST.hs b/src/CHAD/AST.hs new file mode 100644 index 0000000..aa6aa96 --- /dev/null +++ b/src/CHAD/AST.hs @@ -0,0 +1,705 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveFoldable #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE DeriveTraversable #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImpredicativeTypes #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +module CHAD.AST (module CHAD.AST, module CHAD.AST.Types, module CHAD.AST.Accum, module CHAD.AST.Weaken) where + +import Data.Functor.Const +import Data.Functor.Identity +import Data.Int (Int64) +import Data.Kind (Type) + +import CHAD.Array +import CHAD.AST.Accum +import CHAD.AST.Sparse.Types +import CHAD.AST.Types +import CHAD.AST.Weaken +import CHAD.Data +import CHAD.Drev.Types + + +-- General assumption: head of the list (whatever way it is associated) is the +-- inner variable / inner array dimension. In pretty printing, the inner +-- variable / inner dimension is printed on the _right_. +-- +-- All the monoid operations are unsupposed as the input to CHAD, and are +-- intended to be eliminated after simplification, so that the input program as +-- well as the output program do not contain these constructors. +-- TODO: ensure this by a "stage" type parameter. +type Expr :: (Ty -> Type) -> [Ty] -> Ty -> Type +data Expr x env t where + -- lambda calculus + EVar :: x t -> STy t -> Idx env t -> Expr x env t + ELet :: x t -> Expr x env a -> Expr x (a : env) t -> Expr x env t + + -- base types + EPair :: x (TPair a b) -> Expr x env a -> Expr x env b -> Expr x env (TPair a b) + EFst :: x a -> Expr x env (TPair a b) -> Expr x env a + ESnd :: x b -> Expr x env (TPair a b) -> Expr x env b + ENil :: x TNil -> Expr x env TNil + EInl :: x (TEither a b) -> STy b -> Expr x env a -> Expr x env (TEither a b) + EInr :: x (TEither a b) -> STy a -> Expr x env b -> Expr x env (TEither a b) + ECase :: x c -> Expr x env (TEither a b) -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c + ENothing :: x (TMaybe t) -> STy t -> Expr x env (TMaybe t) + EJust :: x (TMaybe t) -> Expr x env t -> Expr x env (TMaybe t) + EMaybe :: x b -> Expr x env b -> Expr x (t : env) b -> Expr x env (TMaybe t) -> Expr x env b + + -- array operations + EConstArr :: Show (ScalRep t) => x (TArr n (TScal t)) -> SNat n -> SScalTy t -> Array n (ScalRep t) -> Expr x env (TArr n (TScal t)) + EBuild :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x (Tup (Replicate n TIx) : env) t -> Expr x env (TArr n t) + EMap :: x (TArr n t) -> Expr x (a : env) t -> Expr x env (TArr n a) -> Expr x env (TArr n t) + -- bottommost t in 't : t : env' is the rightmost argument (environments grow to the right) + EFold1Inner :: x (TArr n t) -> Commutative -> Expr x (TPair t t : env) t -> Expr x env t -> Expr x env (TArr (S n) t) -> Expr x env (TArr n t) + ESum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) + EUnit :: x (TArr Z t) -> Expr x env t -> Expr x env (TArr Z t) + EReplicate1Inner :: x (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) -> Expr x env (TArr (S n) t) + EMaximum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) + EMinimum1Inner :: ScalIsNumeric t ~ True => x (TArr n (TScal t)) -> Expr x env (TArr (S n) (TScal t)) -> Expr x env (TArr n (TScal t)) + EReshape :: x (TArr n t) -> SNat n -> Expr x env (Tup (Replicate n TIx)) -> Expr x env (TArr m t) -> Expr x env (TArr n t) + EZip :: x (TArr n (TPair a b)) -> Expr x env (TArr n a) -> Expr x env (TArr n b) -> Expr x env (TArr n (TPair a b)) + + -- Primal of EFold1Inner. Looks like a mapAccumL, but differs semantically: + -- an implementation is allowed to parallelise this thing and store the b + -- values in some implementation-defined order. + -- TODO: For a parallel implementation some data will probably need to be stored about the reduction order in addition to simply the array of bs. + EFold1InnerD1 :: x (TPair (TArr n t1) (TArr (S n) b)) -> Commutative + -> Expr x (TPair t1 t1 : env) (TPair t1 b) + -> Expr x env t1 + -> Expr x env (TArr (S n) t1) + -> Expr x env (TPair (TArr n t1) -- normal primal fold output + (TArr (S n) b)) -- additional stores; usually: (prescanl, the tape stores) + -- Reverse derivative of EFold1Inner. The contributions to the initial + -- element are not yet added together here; we assume a later fusion system + -- does that for us. + EFold1InnerD2 :: x (TPair (TArr n t2) (TArr (S n) t2)) -> Commutative + -> Expr x (t2 : b : env) (TPair t2 t2) -- reverse derivative of function (should contribute to free variables via accumulation) + -> Expr x env (TArr (S n) b) -- stores from EFold1InnerD1 + -> Expr x env (TArr n t2) -- incoming cotangent + -> Expr x env (TPair (TArr n t2) (TArr (S n) t2)) -- outgoing cotangents to x0 (not summed) and input array + + -- expression operations + EConst :: Show (ScalRep t) => x (TScal t) -> SScalTy t -> ScalRep t -> Expr x env (TScal t) + EIdx0 :: x t -> Expr x env (TArr Z t) -> Expr x env t + EIdx1 :: x (TArr n t) -> Expr x env (TArr (S n) t) -> Expr x env TIx -> Expr x env (TArr n t) + EIdx :: x t -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) -> Expr x env t + EShape :: x (Tup (Replicate n TIx)) -> Expr x env (TArr n t) -> Expr x env (Tup (Replicate n TIx)) + EOp :: x t -> SOp a t -> Expr x env a -> Expr x env t + + -- custom derivatives + -- 'b' is the part of the input of the operation that derivatives should + -- be backpropagated to; 'a' is the inactive part. The dual field of + -- ECustom does not allow a derivative to be generated for 'a', and hence + -- none is propagated. + -- No accumulators are allowed inside a, b and tape. This restriction is + -- currently not used very much, so could be relaxed in the future; be sure + -- to check this requirement whenever it is necessary for soundness! + ECustom :: x t -> STy a -> STy b -> STy tape + -> Expr x [b, a] t -- ^ regular operation + -> Expr x [D1 b, D1 a] (TPair (D1 t) tape) -- ^ CHAD forward pass + -> Expr x [D2 t, tape] (D2 b) -- ^ CHAD reverse derivative + -> Expr x env a -> Expr x env b + -> Expr x env t + + -- fake halfway checkpointing + ERecompute :: x t -> Expr x env t -> Expr x env t + + -- accumulation effect on monoids + -- | The initialiser for an accumulator __MUST__ be deep! If it is zero, it + -- must be EDeepZero, not just EZero. This is to ensure that EAccum does not + -- need to create any zeros. + EWith :: x (TPair a t) -> SMTy t -> Expr x env t -> Expr x (TAccum t : env) a -> Expr x env (TPair a t) + -- The 'Sparse' here is eliminated to dense by UnMonoid. + EAccum :: x TNil -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxD p t) -> Sparse a b -> Expr x env b -> Expr x env (TAccum t) -> Expr x env TNil + + -- monoidal operations (to be desugared to regular operations after simplification) + EZero :: x t -> SMTy t -> Expr x env (ZeroInfo t) -> Expr x env t + EDeepZero :: x t -> SMTy t -> Expr x env (DeepZeroInfo t) -> Expr x env t + EPlus :: x t -> SMTy t -> Expr x env t -> Expr x env t -> Expr x env t + EOneHot :: x t -> SMTy t -> SAcPrj p t a -> Expr x env (AcIdxS p t) -> Expr x env a -> Expr x env t + + -- interface of abstract monoidal types + ELNil :: x (TLEither a b) -> STy a -> STy b -> Expr x env (TLEither a b) + ELInl :: x (TLEither a b) -> STy b -> Expr x env a -> Expr x env (TLEither a b) + ELInr :: x (TLEither a b) -> STy a -> Expr x env b -> Expr x env (TLEither a b) + ELCase :: x c -> Expr x env (TLEither a b) -> Expr x env c -> Expr x (a : env) c -> Expr x (b : env) c -> Expr x env c + + -- partiality + EError :: x a -> STy a -> String -> Expr x env a +deriving instance (forall ty. Show (x ty)) => Show (Expr x env t) + +type Ex = Expr (Const ()) + +ext :: Const () a +ext = Const () + +data Commutative = Commut | Noncommut + deriving (Show) + +type SOp :: Ty -> Ty -> Type +data SOp a t where + OAdd :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) + OMul :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) + ONeg :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TScal a) (TScal a) + OLt :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) + OLe :: ScalIsNumeric a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) + OEq :: SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal TBool) + ONot :: SOp (TScal TBool) (TScal TBool) + OAnd :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool) + OOr :: SOp (TPair (TScal TBool) (TScal TBool)) (TScal TBool) + OIf :: SOp (TScal TBool) (TEither TNil TNil) -- True is Left, False is Right + ORound64 :: SOp (TScal TF64) (TScal TI64) + OToFl64 :: SOp (TScal TI64) (TScal TF64) + ORecip :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) + OExp :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) + OLog :: ScalIsFloating a ~ True => SScalTy a -> SOp (TScal a) (TScal a) + OIDiv :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) + OMod :: ScalIsIntegral a ~ True => SScalTy a -> SOp (TPair (TScal a) (TScal a)) (TScal a) +deriving instance Show (SOp a t) + +opt1 :: SOp a t -> STy a +opt1 = \case + OAdd t -> STPair (STScal t) (STScal t) + OMul t -> STPair (STScal t) (STScal t) + ONeg t -> STScal t + OLt t -> STPair (STScal t) (STScal t) + OLe t -> STPair (STScal t) (STScal t) + OEq t -> STPair (STScal t) (STScal t) + ONot -> STScal STBool + OAnd -> STPair (STScal STBool) (STScal STBool) + OOr -> STPair (STScal STBool) (STScal STBool) + OIf -> STScal STBool + ORound64 -> STScal STF64 + OToFl64 -> STScal STI64 + ORecip t -> STScal t + OExp t -> STScal t + OLog t -> STScal t + OIDiv t -> STPair (STScal t) (STScal t) + OMod t -> STPair (STScal t) (STScal t) + +opt2 :: SOp a t -> STy t +opt2 = \case + OAdd t -> STScal t + OMul t -> STScal t + ONeg t -> STScal t + OLt _ -> STScal STBool + OLe _ -> STScal STBool + OEq _ -> STScal STBool + ONot -> STScal STBool + OAnd -> STScal STBool + OOr -> STScal STBool + OIf -> STEither STNil STNil + ORound64 -> STScal STI64 + OToFl64 -> STScal STF64 + ORecip t -> STScal t + OExp t -> STScal t + OLog t -> STScal t + OIDiv t -> STScal t + OMod t -> STScal t + +typeOf :: Expr x env t -> STy t +typeOf = \case + EVar _ t _ -> t + ELet _ _ e -> typeOf e + + EPair _ a b -> STPair (typeOf a) (typeOf b) + EFst _ e | STPair t _ <- typeOf e -> t + ESnd _ e | STPair _ t <- typeOf e -> t + ENil _ -> STNil + EInl _ t2 e -> STEither (typeOf e) t2 + EInr _ t1 e -> STEither t1 (typeOf e) + ECase _ _ a _ -> typeOf a + ENothing _ t -> STMaybe t + EJust _ e -> STMaybe (typeOf e) + EMaybe _ e _ _ -> typeOf e + ELNil _ t1 t2 -> STLEither t1 t2 + ELInl _ t2 e -> STLEither (typeOf e) t2 + ELInr _ t1 e -> STLEither t1 (typeOf e) + ELCase _ _ a _ _ -> typeOf a + + EConstArr _ n t _ -> STArr n (STScal t) + EBuild _ n _ e -> STArr n (typeOf e) + EMap _ a b | STArr n _ <- typeOf b -> STArr n (typeOf a) + EFold1Inner _ _ _ _ e | STArr (SS n) t <- typeOf e -> STArr n t + ESum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t + EUnit _ e -> STArr SZ (typeOf e) + EReplicate1Inner _ _ e | STArr n t <- typeOf e -> STArr (SS n) t + EMaximum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t + EMinimum1Inner _ e | STArr (SS n) t <- typeOf e -> STArr n t + EReshape _ n _ e | STArr _ t <- typeOf e -> STArr n t + EZip _ a b | STArr n t1 <- typeOf a, STArr _ t2 <- typeOf b -> STArr n (STPair t1 t2) + + EFold1InnerD1 _ _ e1 _ e3 | STPair t1 tb <- typeOf e1, STArr (SS n) _ <- typeOf e3 -> STPair (STArr n t1) (STArr (SS n) tb) + EFold1InnerD2 _ _ _ _ e3 | STArr n t2 <- typeOf e3 -> STPair (STArr n t2) (STArr (SS n) t2) + + EConst _ t _ -> STScal t + EIdx0 _ e | STArr _ t <- typeOf e -> t + EIdx1 _ e _ | STArr (SS n) t <- typeOf e -> STArr n t + EIdx _ e _ | STArr _ t <- typeOf e -> t + EShape _ e | STArr n _ <- typeOf e -> tTup (sreplicate n tIx) + EOp _ op _ -> opt2 op + + ECustom _ _ _ _ e _ _ _ _ -> typeOf e + ERecompute _ e -> typeOf e + + EWith _ _ e1 e2 -> STPair (typeOf e2) (typeOf e1) + EAccum _ _ _ _ _ _ _ -> STNil + + EZero _ t _ -> fromSMTy t + EDeepZero _ t _ -> fromSMTy t + EPlus _ t _ _ -> fromSMTy t + EOneHot _ t _ _ _ -> fromSMTy t + + EError _ t _ -> t + +extOf :: Expr x env t -> x t +extOf = \case + EVar x _ _ -> x + ELet x _ _ -> x + EPair x _ _ -> x + EFst x _ -> x + ESnd x _ -> x + ENil x -> x + EInl x _ _ -> x + EInr x _ _ -> x + ECase x _ _ _ -> x + ENothing x _ -> x + EJust x _ -> x + EMaybe x _ _ _ -> x + ELNil x _ _ -> x + ELInl x _ _ -> x + ELInr x _ _ -> x + ELCase x _ _ _ _ -> x + EConstArr x _ _ _ -> x + EBuild x _ _ _ -> x + EMap x _ _ -> x + EFold1Inner x _ _ _ _ -> x + ESum1Inner x _ -> x + EUnit x _ -> x + EReplicate1Inner x _ _ -> x + EMaximum1Inner x _ -> x + EMinimum1Inner x _ -> x + EReshape x _ _ _ -> x + EZip x _ _ -> x + EFold1InnerD1 x _ _ _ _ -> x + EFold1InnerD2 x _ _ _ _ -> x + EConst x _ _ -> x + EIdx0 x _ -> x + EIdx1 x _ _ -> x + EIdx x _ _ -> x + EShape x _ -> x + EOp x _ _ -> x + ECustom x _ _ _ _ _ _ _ _ -> x + ERecompute x _ -> x + EWith x _ _ _ -> x + EAccum x _ _ _ _ _ _ -> x + EZero x _ _ -> x + EDeepZero x _ _ -> x + EPlus x _ _ _ -> x + EOneHot x _ _ _ _ -> x + EError x _ _ -> x + +mapExt :: (forall a. x a -> x' a) -> Expr x env t -> Expr x' env t +mapExt f = runIdentity . travExt (Identity . f) + +{-# SPECIALIZE travExt :: (forall a. x a -> Identity (x' a)) -> Expr x env t -> Identity (Expr x' env t) #-} +travExt :: Applicative f => (forall a. x a -> f (x' a)) -> Expr x env t -> f (Expr x' env t) +travExt f = \case + EVar x t i -> EVar <$> f x <*> pure t <*> pure i + ELet x rhs body -> ELet <$> f x <*> travExt f rhs <*> travExt f body + EPair x a b -> EPair <$> f x <*> travExt f a <*> travExt f b + EFst x e -> EFst <$> f x <*> travExt f e + ESnd x e -> ESnd <$> f x <*> travExt f e + ENil x -> ENil <$> f x + EInl x t e -> EInl <$> f x <*> pure t <*> travExt f e + EInr x t e -> EInr <$> f x <*> pure t <*> travExt f e + ECase x e a b -> ECase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b + ENothing x t -> ENothing <$> f x <*> pure t + EJust x e -> EJust <$> f x <*> travExt f e + EMaybe x a b e -> EMaybe <$> f x <*> travExt f a <*> travExt f b <*> travExt f e + ELNil x t1 t2 -> ELNil <$> f x <*> pure t1 <*> pure t2 + ELInl x t e -> ELInl <$> f x <*> pure t <*> travExt f e + ELInr x t e -> ELInr <$> f x <*> pure t <*> travExt f e + ELCase x e a b c -> ELCase <$> f x <*> travExt f e <*> travExt f a <*> travExt f b <*> travExt f c + EConstArr x n t a -> EConstArr <$> f x <*> pure n <*> pure t <*> pure a + EBuild x n a b -> EBuild <$> f x <*> pure n <*> travExt f a <*> travExt f b + EMap x a b -> EMap <$> f x <*> travExt f a <*> travExt f b + EFold1Inner x cm a b c -> EFold1Inner <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c + ESum1Inner x e -> ESum1Inner <$> f x <*> travExt f e + EUnit x e -> EUnit <$> f x <*> travExt f e + EReplicate1Inner x a b -> EReplicate1Inner <$> f x <*> travExt f a <*> travExt f b + EMaximum1Inner x e -> EMaximum1Inner <$> f x <*> travExt f e + EMinimum1Inner x e -> EMinimum1Inner <$> f x <*> travExt f e + EZip x a b -> EZip <$> f x <*> travExt f a <*> travExt f b + EReshape x n a b -> EReshape <$> f x <*> pure n <*> travExt f a <*> travExt f b + EFold1InnerD1 x cm a b c -> EFold1InnerD1 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c + EFold1InnerD2 x cm a b c -> EFold1InnerD2 <$> f x <*> pure cm <*> travExt f a <*> travExt f b <*> travExt f c + EConst x t v -> EConst <$> f x <*> pure t <*> pure v + EIdx0 x e -> EIdx0 <$> f x <*> travExt f e + EIdx1 x a b -> EIdx1 <$> f x <*> travExt f a <*> travExt f b + EIdx x e es -> EIdx <$> f x <*> travExt f e <*> travExt f es + EShape x e -> EShape <$> f x <*> travExt f e + EOp x op e -> EOp <$> f x <*> pure op <*> travExt f e + ECustom x s t p a b c e1 e2 -> ECustom <$> f x <*> pure s <*> pure t <*> pure p <*> travExt f a <*> travExt f b <*> travExt f c <*> travExt f e1 <*> travExt f e2 + ERecompute x e -> ERecompute <$> f x <*> travExt f e + EWith x t e1 e2 -> EWith <$> f x <*> pure t <*> travExt f e1 <*> travExt f e2 + EAccum x t p e1 sp e2 e3 -> EAccum <$> f x <*> pure t <*> pure p <*> travExt f e1 <*> pure sp <*> travExt f e2 <*> travExt f e3 + EZero x t e -> EZero <$> f x <*> pure t <*> travExt f e + EDeepZero x t e -> EDeepZero <$> f x <*> pure t <*> travExt f e + EPlus x t a b -> EPlus <$> f x <*> pure t <*> travExt f a <*> travExt f b + EOneHot x t p a b -> EOneHot <$> f x <*> pure t <*> pure p <*> travExt f a <*> travExt f b + EError x t s -> EError <$> f x <*> pure t <*> pure s + +substInline :: Expr x env a -> Expr x (a : env) t -> Expr x env t +substInline repl = + subst $ \x t -> \case IZ -> repl + IS i -> EVar x t i + +subst0 :: Ex (b : env) a -> Ex (a : env) t -> Ex (b : env) t +subst0 repl = + subst $ \_ t -> \case IZ -> repl + IS i -> EVar ext t (IS i) + +subst :: (forall a. x a -> STy a -> Idx env a -> Expr x env' a) + -> Expr x env t -> Expr x env' t +subst f = subst' (\x t w i -> weakenExpr w (f x t i)) WId + +subst' :: (forall a env2. x a -> STy a -> env' :> env2 -> Idx env a -> Expr x env2 a) + -> env' :> envOut + -> Expr x env t + -> Expr x envOut t +subst' f w = \case + EVar x t i -> f x t w i + ELet x rhs body -> ELet x (subst' f w rhs) (subst' (sinkF f) (WCopy w) body) + EPair x a b -> EPair x (subst' f w a) (subst' f w b) + EFst x e -> EFst x (subst' f w e) + ESnd x e -> ESnd x (subst' f w e) + ENil x -> ENil x + EInl x t e -> EInl x t (subst' f w e) + EInr x t e -> EInr x t (subst' f w e) + ECase x e a b -> ECase x (subst' f w e) (subst' (sinkF f) (WCopy w) a) (subst' (sinkF f) (WCopy w) b) + ENothing x t -> ENothing x t + EJust x e -> EJust x (subst' f w e) + EMaybe x a b e -> EMaybe x (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' f w e) + ELNil x t1 t2 -> ELNil x t1 t2 + ELInl x t e -> ELInl x t (subst' f w e) + ELInr x t e -> ELInr x t (subst' f w e) + ELCase x e a b c -> ELCase x (subst' f w e) (subst' f w a) (subst' (sinkF f) (WCopy w) b) (subst' (sinkF f) (WCopy w) c) + EConstArr x n t a -> EConstArr x n t a + EBuild x n a b -> EBuild x n (subst' f w a) (subst' (sinkF f) (WCopy w) b) + EMap x a b -> EMap x (subst' (sinkF f) (WCopy w) a) (subst' f w b) + EFold1Inner x cm a b c -> EFold1Inner x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c) + ESum1Inner x e -> ESum1Inner x (subst' f w e) + EUnit x e -> EUnit x (subst' f w e) + EReplicate1Inner x a b -> EReplicate1Inner x (subst' f w a) (subst' f w b) + EMaximum1Inner x e -> EMaximum1Inner x (subst' f w e) + EMinimum1Inner x e -> EMinimum1Inner x (subst' f w e) + EReshape x n a b -> EReshape x n (subst' f w a) (subst' f w b) + EZip x a b -> EZip x (subst' f w a) (subst' f w b) + EFold1InnerD1 x cm a b c -> EFold1InnerD1 x cm (subst' (sinkF f) (WCopy w) a) (subst' f w b) (subst' f w c) + EFold1InnerD2 x cm a b c -> EFold1InnerD2 x cm (subst' (sinkF (sinkF f)) (WCopy (WCopy w)) a) (subst' f w b) (subst' f w c) + EConst x t v -> EConst x t v + EIdx0 x e -> EIdx0 x (subst' f w e) + EIdx1 x a b -> EIdx1 x (subst' f w a) (subst' f w b) + EIdx x e es -> EIdx x (subst' f w e) (subst' f w es) + EShape x e -> EShape x (subst' f w e) + EOp x op e -> EOp x op (subst' f w e) + ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (subst' f w e1) (subst' f w e2) + ERecompute x e -> ERecompute x (subst' f w e) + EWith x t e1 e2 -> EWith x t (subst' f w e1) (subst' (sinkF f) (WCopy w) e2) + EAccum x t p e1 sp e2 e3 -> EAccum x t p (subst' f w e1) sp (subst' f w e2) (subst' f w e3) + EZero x t e -> EZero x t (subst' f w e) + EDeepZero x t e -> EDeepZero x t (subst' f w e) + EPlus x t a b -> EPlus x t (subst' f w a) (subst' f w b) + EOneHot x t p a b -> EOneHot x t p (subst' f w a) (subst' f w b) + EError x t s -> EError x t s + where + sinkF :: (forall a. x a -> STy a -> (env' :> env2) -> Idx env a -> Expr x env2 a) + -> x t -> STy t -> ((b : env') :> env2) -> Idx (b : env) t -> Expr x env2 t + sinkF f' x' t w' = \case + IZ -> EVar x' t (w' @> IZ) + IS i -> f' x' t (WPop w') i + +weakenExpr :: env :> env' -> Expr x env t -> Expr x env' t +weakenExpr = subst' (\x t w' i -> EVar x t (w' @> i)) + +class KnownScalTy t where knownScalTy :: SScalTy t +instance KnownScalTy TI32 where knownScalTy = STI32 +instance KnownScalTy TI64 where knownScalTy = STI64 +instance KnownScalTy TF32 where knownScalTy = STF32 +instance KnownScalTy TF64 where knownScalTy = STF64 +instance KnownScalTy TBool where knownScalTy = STBool + +class KnownTy t where knownTy :: STy t +instance KnownTy TNil where knownTy = STNil +instance (KnownTy s, KnownTy t) => KnownTy (TPair s t) where knownTy = STPair knownTy knownTy +instance (KnownTy s, KnownTy t) => KnownTy (TEither s t) where knownTy = STEither knownTy knownTy +instance (KnownTy s, KnownTy t) => KnownTy (TLEither s t) where knownTy = STLEither knownTy knownTy +instance KnownTy t => KnownTy (TMaybe t) where knownTy = STMaybe knownTy +instance (KnownNat n, KnownTy t) => KnownTy (TArr n t) where knownTy = STArr knownNat knownTy +instance KnownScalTy t => KnownTy (TScal t) where knownTy = STScal knownScalTy +instance KnownMTy t => KnownTy (TAccum t) where knownTy = STAccum knownMTy + +class KnownMTy t where knownMTy :: SMTy t +instance KnownMTy TNil where knownMTy = SMTNil +instance (KnownMTy s, KnownMTy t) => KnownMTy (TPair s t) where knownMTy = SMTPair knownMTy knownMTy +instance KnownMTy t => KnownMTy (TMaybe t) where knownMTy = SMTMaybe knownMTy +instance (KnownMTy s, KnownMTy t) => KnownMTy (TLEither s t) where knownMTy = SMTLEither knownMTy knownMTy +instance (KnownNat n, KnownMTy t) => KnownMTy (TArr n t) where knownMTy = SMTArr knownNat knownMTy +instance (KnownScalTy t, ScalIsNumeric t ~ True) => KnownMTy (TScal t) where knownMTy = SMTScal knownScalTy + +class KnownEnv env where knownEnv :: SList STy env +instance KnownEnv '[] where knownEnv = SNil +instance (KnownTy t, KnownEnv env) => KnownEnv (t : env) where knownEnv = SCons knownTy knownEnv + +styKnown :: STy t -> Dict (KnownTy t) +styKnown STNil = Dict +styKnown (STPair a b) | Dict <- styKnown a, Dict <- styKnown b = Dict +styKnown (STEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict +styKnown (STLEither a b) | Dict <- styKnown a, Dict <- styKnown b = Dict +styKnown (STMaybe t) | Dict <- styKnown t = Dict +styKnown (STArr n t) | Dict <- snatKnown n, Dict <- styKnown t = Dict +styKnown (STScal t) | Dict <- sscaltyKnown t = Dict +styKnown (STAccum t) | Dict <- smtyKnown t = Dict + +smtyKnown :: SMTy t -> Dict (KnownMTy t) +smtyKnown SMTNil = Dict +smtyKnown (SMTPair a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict +smtyKnown (SMTLEither a b) | Dict <- smtyKnown a, Dict <- smtyKnown b = Dict +smtyKnown (SMTMaybe t) | Dict <- smtyKnown t = Dict +smtyKnown (SMTArr n t) | Dict <- snatKnown n, Dict <- smtyKnown t = Dict +smtyKnown (SMTScal t) | Dict <- sscaltyKnown t = Dict + +sscaltyKnown :: SScalTy t -> Dict (KnownScalTy t) +sscaltyKnown STI32 = Dict +sscaltyKnown STI64 = Dict +sscaltyKnown STF32 = Dict +sscaltyKnown STF64 = Dict +sscaltyKnown STBool = Dict + +envKnown :: SList STy env -> Dict (KnownEnv env) +envKnown SNil = Dict +envKnown (t `SCons` env) | Dict <- styKnown t, Dict <- envKnown env = Dict + +cheapExpr :: Expr x env t -> Bool +cheapExpr = \case + EVar{} -> True + ENil{} -> True + EConst{} -> True + EFst _ e -> cheapExpr e + ESnd _ e -> cheapExpr e + EUnit _ e -> cheapExpr e + _ -> False + +eTup :: SList (Ex env) list -> Ex env (Tup list) +eTup = mkTup (ENil ext) (EPair ext) + +ebuildUp1 :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx -> Ex (TIx : env) (TArr n t) -> Ex env (TArr (S n) t) +ebuildUp1 n sh size f = + EBuild ext (SS n) (EPair ext sh size) $ + let arg = EVar ext (tTup (sreplicate (SS n) tIx)) IZ + in EIdx ext (ELet ext (ESnd ext arg) (weakenExpr (WCopy WSink) f)) + (EFst ext arg) + +eidxEq :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool) +eidxEq SZ _ _ = EConst ext STBool True +eidxEq (SS SZ) a b = + EOp ext (OEq STI64) (EPair ext (ESnd ext a) (ESnd ext b)) +eidxEq (SS n) a b + | let ty = tTup (sreplicate (SS n) tIx) + = ELet ext a $ + ELet ext (weakenExpr WSink b) $ + EOp ext OAnd $ EPair ext + (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext ty (IS IZ))) + (ESnd ext (EVar ext ty IZ)))) + (eidxEq n (EFst ext (EVar ext ty (IS IZ))) + (EFst ext (EVar ext ty IZ))) + +emap :: (KnownTy a => Ex (a : env) b) -> Ex env (TArr n a) -> Ex env (TArr n b) +emap f arr + | STArr _ t <- typeOf arr + , Dict <- styKnown t + = EMap ext f arr + +ezipWith :: ((KnownTy a, KnownTy b) => Ex (b : a : env) c) -> Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n c) +ezipWith f arr1 arr2 + | STArr _ t1 <- typeOf arr1 + , STArr _ t2 <- typeOf arr2 + , Dict <- styKnown t1 + , Dict <- styKnown t2 + = EMap ext (subst (\_ t -> \case IZ -> ESnd ext (EVar ext (STPair t1 t2) IZ) + IS IZ -> EFst ext (EVar ext (STPair t1 t2) IZ) + IS (IS i) -> EVar ext t (IS i)) + f) + (EZip ext arr1 arr2) + +ezip :: Ex env (TArr n a) -> Ex env (TArr n b) -> Ex env (TArr n (TPair a b)) +ezip = EZip ext + +eif :: Ex env (TScal TBool) -> Ex env a -> Ex env a -> Ex env a +eif a b c = ECase ext (EOp ext OIf a) (weakenExpr WSink b) (weakenExpr WSink c) + +-- | Returns whether the shape is all-zero, but returns False for the zero-dimensional shape (because it is _not_ empty). +eshapeEmpty :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env (TScal TBool) +eshapeEmpty SZ _ = EConst ext STBool False +eshapeEmpty (SS SZ) e = EOp ext (OEq STI64) (EPair ext (ESnd ext e) (EConst ext STI64 0)) +eshapeEmpty (SS n) e = + ELet ext e $ + EOp ext OAnd (EPair ext + (EOp ext (OEq STI64) (EPair ext (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)) + (EConst ext STI64 0))) + (eshapeEmpty n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) IZ)))) + +eshapeConst :: Shape n -> Ex env (Tup (Replicate n TIx)) +eshapeConst ShNil = ENil ext +eshapeConst (sh `ShCons` n) = EPair ext (eshapeConst sh) (EConst ext STI64 (fromIntegral @Int @Int64 n)) + +eshapeProd :: SNat n -> Ex env (Tup (Replicate n TIx)) -> Ex env TIx +eshapeProd SZ _ = EConst ext STI64 1 +eshapeProd (SS SZ) e = ESnd ext e +eshapeProd (SS n) e = + eunPair e $ \_ e1 e2 -> + EOp ext (OMul STI64) (EPair ext (eshapeProd n e1) e2) + +eflatten :: Ex env (TArr n t) -> Ex env (TArr N1 t) +eflatten e = + let STArr n _ = typeOf e + in elet e $ + EReshape ext (SS SZ) (EPair ext (ENil ext) (eshapeProd n (EShape ext (evar IZ)))) (evar IZ) + +-- ezeroD2 :: STy t -> Ex env (ZeroInfo (D2 t)) -> Ex env (D2 t) +-- ezeroD2 t ezi = EZero ext (d2M t) ezi + +-- eaccumD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (TAccum (D2 t)) -> Ex env TNil +-- eaccumD2 t p ei ev ea | Refl <- lemZeroInfoD2 t = EAccum ext (d2M t) (ENil ext) p ei ev ea + +-- eonehotD2 :: STy t -> SAcPrj p (D2 t) a -> Ex env (AcIdx p (D2 t)) -> Ex env a -> Ex env (D2 t) +-- eonehotD2 t p ei ev | Refl <- lemZeroInfoD2 t = EOneHot ext (d2M t) (ENil ext) p ei ev + +eunPair :: Ex env (TPair a b) -> (forall env'. env :> env' -> Ex env' a -> Ex env' b -> Ex env' r) -> Ex env r +eunPair (EPair _ e1 e2) k = k WId e1 e2 +eunPair e k | cheapExpr e = k WId (EFst ext e) (ESnd ext e) +eunPair e k = + elet e $ + k WSink + (EFst ext (evar IZ)) + (ESnd ext (evar IZ)) + +efst :: Ex env (TPair a b) -> Ex env a +efst (EPair _ e1 _) = e1 +efst e = EFst ext e + +esnd :: Ex env (TPair a b) -> Ex env b +esnd (EPair _ _ e2) = e2 +esnd e = ESnd ext e + +elet :: Ex env a -> (KnownTy a => Ex (a : env) b) -> Ex env b +elet rhs body + | Dict <- styKnown (typeOf rhs) + = if cheapExpr rhs + then substInline rhs body + else ELet ext rhs body + +-- | Let-bind it but don't use the value (just ensure the expression's effects don't get lost) +use :: Ex env a -> Ex env b -> Ex env b +use a b = elet a $ weakenExpr WSink b + +emaybe :: Ex env (TMaybe a) -> Ex env b -> (KnownTy a => Ex (a : env) b) -> Ex env b +emaybe e a b + | STMaybe t <- typeOf e + , Dict <- styKnown t + = EMaybe ext a b e + +ecase :: Ex env (TEither a b) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c +ecase e a b + | STEither t1 t2 <- typeOf e + , Dict <- styKnown t1 + , Dict <- styKnown t2 + = ECase ext e a b + +elcase :: Ex env (TLEither a b) -> ((KnownTy a, KnownTy b) => Ex env c) -> ((KnownTy a, KnownTy b) => Ex (a : env) c) -> ((KnownTy a, KnownTy b) => Ex (b : env) c) -> Ex env c +elcase e a b c + | STLEither t1 t2 <- typeOf e + , Dict <- styKnown t1 + , Dict <- styKnown t2 + = ELCase ext e a b c + +evar :: KnownTy a => Idx env a -> Ex env a +evar = EVar ext knownTy + +makeZeroInfo :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) +makeZeroInfo = \ty reference -> ELet ext reference $ go ty (EVar ext (fromSMTy ty) IZ) + where + -- invariant: expression argument is duplicable + go :: SMTy t -> Ex env t -> Ex env (ZeroInfo t) + go SMTNil _ = ENil ext + go (SMTPair t1 t2) e = EPair ext (go t1 (EFst ext e)) (go t2 (ESnd ext e)) + go SMTLEither{} _ = ENil ext + go SMTMaybe{} _ = ENil ext + go (SMTArr _ t) e = emap (go t (EVar ext (fromSMTy t) IZ)) e + go SMTScal{} _ = ENil ext + +splitSparsePair + :: -- given a sparsity + STy (TPair a b) -> Sparse (TPair a b) t' + -> (forall a' b'. + -- I give you back two sparsities for a and b + Sparse a a' -> Sparse b b' + -- furthermore, I tell you that either your t' is already this (a', b') pair... + -> Either + (t' :~: TPair a' b') + -- or I tell you how to construct a' and b' from t', given an actual t' + (forall r' env. + Idx env t' + -> (forall env'. + (forall c. Ex env' c -> Ex env c) + -> Ex env' a' -> Ex env' b' -> r') + -> r') + -> r) + -> r +splitSparsePair _ SpAbsent k = + k SpAbsent SpAbsent $ Right $ \_ k2 -> + k2 id (ENil ext) (ENil ext) +splitSparsePair _ (SpPair s1 s2) k1 = + k1 s1 s2 $ Left Refl +splitSparsePair t@(STPair t1 t2) (SpSparse s@(SpPair s1 s2)) k = + let t' = STPair (STMaybe (applySparse s1 t1)) (STMaybe (applySparse s2 t2)) in + k (SpSparse s1) (SpSparse s2) $ Right $ \i k2 -> + k2 (elet $ + emaybe (EVar ext (STMaybe (applySparse s t)) i) + (EPair ext (ENothing ext (applySparse s1 t1)) (ENothing ext (applySparse s2 t2))) + (EPair ext (EJust ext (EFst ext (evar IZ))) (EJust ext (ESnd ext (evar IZ))))) + (EFst ext (EVar ext t' IZ)) (ESnd ext (EVar ext t' IZ)) + +splitSparsePair _ (SpSparse SpAbsent) k = + k SpAbsent SpAbsent $ Right $ \_ k2 -> + k2 id (ENil ext) (ENil ext) +-- -- TODO: having to handle sparse-of-sparse at all is ridiculous +splitSparsePair t (SpSparse (SpSparse s)) k = + splitSparsePair t (SpSparse s) $ \s1 s2 eres -> + k s1 s2 $ Right $ \i k2 -> + case eres of + Left refl -> case refl of {} + Right f -> + f IZ $ \wrap e1 e2 -> + k2 (\body -> + elet (emaybe (EVar ext (STMaybe (STMaybe (applySparse s t))) i) + (ENothing ext (applySparse s t)) + (evar IZ)) $ + wrap body) + e1 e2 diff --git a/src/CHAD/AST/Accum.hs b/src/CHAD/AST/Accum.hs new file mode 100644 index 0000000..ea74a95 --- /dev/null +++ b/src/CHAD/AST/Accum.hs @@ -0,0 +1,137 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeData #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} +module CHAD.AST.Accum where + +import CHAD.AST.Types +import CHAD.Data + + +data AcPrj + = APHere + | APFst AcPrj + | APSnd AcPrj + | APLeft AcPrj + | APRight AcPrj + | APJust AcPrj + | APArrIdx AcPrj + | APArrSlice Nat + +-- | @b@ is a small part of @a@, indicated by the projection @p@. +data SAcPrj (p :: AcPrj) (a :: Ty) (b :: Ty) where + SAPHere :: SAcPrj APHere a a + SAPFst :: SAcPrj p a b -> SAcPrj (APFst p) (TPair a t) b + SAPSnd :: SAcPrj p a b -> SAcPrj (APSnd p) (TPair t a) b + SAPLeft :: SAcPrj p a b -> SAcPrj (APLeft p) (TLEither a t) b + SAPRight :: SAcPrj p a b -> SAcPrj (APRight p) (TLEither t a) b + SAPJust :: SAcPrj p a b -> SAcPrj (APJust p) (TMaybe a) b + SAPArrIdx :: SAcPrj p a b -> SAcPrj (APArrIdx p) (TArr n a) b + -- TODO: + -- SAPArrSlice :: SNat m -> SAcPrj (APArrSlice m) (TArr n t) (TArr (n - m) t) +deriving instance Show (SAcPrj p a b) + +type data AIDense = AID | AIS + +data SAIDense d where + SAID :: SAIDense AID + SAIS :: SAIDense AIS +deriving instance Show (SAIDense d) + +type family AcIdx d p t where + AcIdx d APHere t = TNil + AcIdx AID (APFst p) (TPair a b) = AcIdx AID p a + AcIdx AID (APSnd p) (TPair a b) = AcIdx AID p b + AcIdx AIS (APFst p) (TPair a b) = TPair (AcIdx AIS p a) (ZeroInfo b) + AcIdx AIS (APSnd p) (TPair a b) = TPair (ZeroInfo a) (AcIdx AIS p b) + AcIdx d (APLeft p) (TLEither a b) = AcIdx d p a + AcIdx d (APRight p) (TLEither a b) = AcIdx d p b + AcIdx d (APJust p) (TMaybe a) = AcIdx d p a + AcIdx AID (APArrIdx p) (TArr n a) = + -- (index, recursive info) + TPair (Tup (Replicate n TIx)) (AcIdx AID p a) + AcIdx AIS (APArrIdx p) (TArr n a) = + -- ((index, shape info), recursive info) + TPair (TPair (Tup (Replicate n TIx)) (ZeroInfo (TArr n a))) + (AcIdx AIS p a) + -- AcIdx AID (APArrSlice m) (TArr n a) = + -- -- index + -- Tup (Replicate m TIx) + -- AcIdx AIS (APArrSlice m) (TArr n a) = + -- -- (index, array shape) + -- TPair (Tup (Replicate m TIx)) (Tup (Replicate n TIx)) + +type AcIdxD p t = AcIdx AID p t +type AcIdxS p t = AcIdx AIS p t + +acPrjTy :: SAcPrj p a b -> SMTy a -> SMTy b +acPrjTy SAPHere t = t +acPrjTy (SAPFst prj) (SMTPair t _) = acPrjTy prj t +acPrjTy (SAPSnd prj) (SMTPair _ t) = acPrjTy prj t +acPrjTy (SAPLeft prj) (SMTLEither t _) = acPrjTy prj t +acPrjTy (SAPRight prj) (SMTLEither _ t) = acPrjTy prj t +acPrjTy (SAPJust prj) (SMTMaybe t) = acPrjTy prj t +acPrjTy (SAPArrIdx prj) (SMTArr _ t) = acPrjTy prj t + +type family ZeroInfo t where + ZeroInfo TNil = TNil + ZeroInfo (TPair a b) = TPair (ZeroInfo a) (ZeroInfo b) + ZeroInfo (TLEither a b) = TNil + ZeroInfo (TMaybe a) = TNil + ZeroInfo (TArr n t) = TArr n (ZeroInfo t) + ZeroInfo (TScal t) = TNil + +tZeroInfo :: SMTy t -> STy (ZeroInfo t) +tZeroInfo SMTNil = STNil +tZeroInfo (SMTPair a b) = STPair (tZeroInfo a) (tZeroInfo b) +tZeroInfo (SMTLEither _ _) = STNil +tZeroInfo (SMTMaybe _) = STNil +tZeroInfo (SMTArr n t) = STArr n (tZeroInfo t) +tZeroInfo (SMTScal _) = STNil + +-- | Info needed to create a zero-valued deep accumulator for a monoid type. +-- Should be constructable from a D1. +type family DeepZeroInfo t where + DeepZeroInfo TNil = TNil + DeepZeroInfo (TPair a b) = TPair (DeepZeroInfo a) (DeepZeroInfo b) + DeepZeroInfo (TLEither a b) = TLEither (DeepZeroInfo a) (DeepZeroInfo b) + DeepZeroInfo (TMaybe a) = TMaybe (DeepZeroInfo a) + DeepZeroInfo (TArr n a) = TArr n (DeepZeroInfo a) + DeepZeroInfo (TScal t) = TNil + +tDeepZeroInfo :: SMTy t -> STy (DeepZeroInfo t) +tDeepZeroInfo SMTNil = STNil +tDeepZeroInfo (SMTPair a b) = STPair (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTLEither a b) = STLEither (tDeepZeroInfo a) (tDeepZeroInfo b) +tDeepZeroInfo (SMTMaybe a) = STMaybe (tDeepZeroInfo a) +tDeepZeroInfo (SMTArr n t) = STArr n (tDeepZeroInfo t) +tDeepZeroInfo (SMTScal _) = STNil + +-- -- | Additional info needed for accumulation. This is empty unless there is +-- -- sparsity in the monoid. +-- type family AccumInfo t where +-- AccumInfo TNil = TNil +-- AccumInfo (TPair a b) = TPair (AccumInfo a) (AccumInfo b) +-- AccumInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b) +-- AccumInfo (TMaybe a) = TMaybe (AccumInfo a) +-- AccumInfo (TArr n t) = TArr n (AccumInfo t) +-- AccumInfo (TScal t) = TNil + +-- type family PrimalInfo t where +-- PrimalInfo TNil = TNil +-- PrimalInfo (TPair a b) = TPair (PrimalInfo a) (PrimalInfo b) +-- PrimalInfo (TLEither a b) = TLEither (PrimalInfo a) (PrimalInfo b) +-- PrimalInfo (TMaybe a) = TMaybe (PrimalInfo a) +-- PrimalInfo (TArr n t) = TArr n (PrimalInfo t) +-- PrimalInfo (TScal t) = TNil + +-- tPrimalInfo :: SMTy t -> STy (PrimalInfo t) +-- tPrimalInfo SMTNil = STNil +-- tPrimalInfo (SMTPair a b) = STPair (tPrimalInfo a) (tPrimalInfo b) +-- tPrimalInfo (SMTLEither a b) = STLEither (tPrimalInfo a) (tPrimalInfo b) +-- tPrimalInfo (SMTMaybe a) = STMaybe (tPrimalInfo a) +-- tPrimalInfo (SMTArr n t) = STArr n (tPrimalInfo t) +-- tPrimalInfo (SMTScal _) = STNil diff --git a/src/CHAD/AST/Bindings.hs b/src/CHAD/AST/Bindings.hs new file mode 100644 index 0000000..c1a1e77 --- /dev/null +++ b/src/CHAD/AST/Bindings.hs @@ -0,0 +1,84 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} + +-- I want to bring various type variables in scope using type annotations in +-- patterns, but I don't want to have to mention all the other type parameters +-- of the types in question as well then. Partial type signatures (with '_') are +-- useful here. +{-# LANGUAGE PartialTypeSignatures #-} +{-# OPTIONS -Wno-partial-type-signatures #-} +module CHAD.AST.Bindings where + +import CHAD.AST +import CHAD.AST.Env +import CHAD.Data +import CHAD.Lemmas + + +-- binding lists: a let stack without a body. The stack lives in 'env' and defines 'binds'. +data Bindings f env binds where + BTop :: Bindings f env '[] + BPush :: Bindings f env binds -> (STy t, f (Append binds env) t) -> Bindings f env (t : binds) +deriving instance (forall e t. Show (f e t)) => Show (Bindings f env env') +infixl `BPush` + +bpush :: Bindings (Expr x) env binds -> Expr x (Append binds env) t -> Bindings (Expr x) env (t : binds) +bpush b e = b `BPush` (typeOf e, e) +infixl `bpush` + +mapBindings :: (forall env' t'. f env' t' -> g env' t') + -> Bindings f env binds -> Bindings g env binds +mapBindings _ BTop = BTop +mapBindings f (BPush b (t, e)) = BPush (mapBindings f b) (t, f e) + +weakenBindings :: (forall e1 e2 t. e1 :> e2 -> f e1 t -> f e2 t) + -> env1 :> env2 + -> Bindings f env1 binds + -> (Bindings f env2 binds, Append binds env1 :> Append binds env2) +weakenBindings _ w BTop = (BTop, w) +weakenBindings wf w (BPush b (t, x)) = + let (b', w') = weakenBindings wf w b + in (BPush b' (t, wf w' x), WCopy w') + +weakenBindingsE :: env1 :> env2 + -> Bindings (Expr x) env1 binds + -> (Bindings (Expr x) env2 binds, Append binds env1 :> Append binds env2) +weakenBindingsE = weakenBindings weakenExpr + +weakenOver :: SList STy ts -> env :> env' -> Append ts env :> Append ts env' +weakenOver SNil w = w +weakenOver (SCons _ ts) w = WCopy (weakenOver ts w) + +sinkWithBindings :: forall env' env binds f. Bindings f env binds -> env' :> Append binds env' +sinkWithBindings BTop = WId +sinkWithBindings (BPush b _) = WSink .> sinkWithBindings b + +bconcat :: forall f env binds1 binds2. Bindings f env binds1 -> Bindings f (Append binds1 env) binds2 -> Bindings f env (Append binds2 binds1) +bconcat b1 BTop = b1 +bconcat b1 (BPush (b2 :: Bindings _ (Append binds1 env) binds2C) (t, x)) + | Refl <- lemAppendAssoc @binds2C @binds1 @env + = BPush (bconcat b1 b2) (t, x) + +bindingsBinds :: Bindings f env binds -> SList STy binds +bindingsBinds BTop = SNil +bindingsBinds (BPush binds (t, _)) = SCons t (bindingsBinds binds) + +letBinds :: Bindings Ex env binds -> Ex (Append binds env) t -> Ex env t +letBinds BTop = id +letBinds (BPush b (_, rhs)) = letBinds b . ELet ext rhs + +collectBindings :: SList STy env -> Subenv env env' -> Bindings Ex env env' +collectBindings = \env -> fst . go env WId + where + go :: SList STy env -> env :> env0 -> Subenv env env' -> (Bindings Ex env0 env', env0 :> Append env' env0) + go _ _ SETop = (BTop, WId) + go (ty `SCons` env) w (SEYesR sub) = + let (bs, w') = go env (WPop w) sub + in (BPush bs (ty, EVar ext ty (w' .> w @> IZ)), WSink .> w') + go (_ `SCons` env) w (SENo sub) = go env (WPop w) sub diff --git a/src/CHAD/AST/Count.hs b/src/CHAD/AST/Count.hs new file mode 100644 index 0000000..133093a --- /dev/null +++ b/src/CHAD/AST/Count.hs @@ -0,0 +1,930 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE DerivingVia #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE PatternSynonyms #-} +module CHAD.AST.Count where + +import Data.Functor.Product +import Data.Some +import Data.Type.Equality +import GHC.Generics (Generic, Generically(..)) + +import CHAD.Array +import CHAD.AST +import CHAD.AST.Env +import CHAD.Data + + +-- | The monoid operation combines assuming that /both/ branches are taken. +class Monoid a => Occurrence a where + -- | One of the two branches is taken + (<||>) :: a -> a -> a + -- | This code is executed many times + scaleMany :: a -> a + + +data Count = Zero | One | Many + deriving (Show, Eq, Ord) + +instance Semigroup Count where + Zero <> n = n + n <> Zero = n + _ <> _ = Many +instance Monoid Count where + mempty = Zero +instance Occurrence Count where + (<||>) = max + scaleMany Zero = Zero + scaleMany _ = Many + +data Occ = Occ { _occLexical :: Count + , _occRuntime :: Count } + deriving (Eq, Generic) + deriving (Semigroup, Monoid) via Generically Occ + +instance Show Occ where + showsPrec d (Occ l r) = showParen (d > 10) $ + showString "Occ " . showsPrec 11 l . showString " " . showsPrec 11 r + +instance Occurrence Occ where + Occ l1 r1 <||> Occ l2 r2 = Occ (l1 <> l2) (r1 <||> r2) + scaleMany (Occ l c) = Occ l (scaleMany c) + + +data Substruc t t' where + -- If you add constructors here, do not forget to update the COMPLETE pragmas of any pattern synonyms below + SsFull :: Substruc t t + SsNone :: Substruc t TNil + SsPair :: Substruc a a' -> Substruc b b' -> Substruc (TPair a b) (TPair a' b') + SsEither :: Substruc a a' -> Substruc b b' -> Substruc (TEither a b) (TEither a' b') + SsLEither :: Substruc a a' -> Substruc b b' -> Substruc (TLEither a b) (TLEither a' b') + SsMaybe :: Substruc a a' -> Substruc (TMaybe a) (TMaybe a') + SsArr :: Substruc a a' -> Substruc (TArr n a) (TArr n a') -- ^ union of usages of all array elements + SsAccum :: Substruc a a' -> Substruc (TAccum a) (TAccum a') + +pattern SsPair' :: forall a b t'. forall a' b'. t' ~ TPair a' b' => Substruc a a' -> Substruc b b' -> Substruc (TPair a b) t' +pattern SsPair' s1 s2 <- ((\case { SsFull -> SsPair SsFull SsFull ; s -> s }) -> SsPair s1 s2) + where SsPair' = SsPair +{-# COMPLETE SsNone, SsPair', SsEither, SsLEither, SsMaybe, SsArr, SsAccum #-} + +pattern SsArr' :: forall n a t'. forall a'. t' ~ TArr n a' => Substruc a a' -> Substruc (TArr n a) t' +pattern SsArr' s <- ((\case { SsFull -> SsArr SsFull ; s -> s }) -> SsArr s) + where SsArr' = SsArr +{-# COMPLETE SsNone, SsPair, SsEither, SsLEither, SsMaybe, SsArr', SsAccum #-} + +instance Semigroup (Some (Substruc t)) where + Some SsFull <> _ = Some SsFull + _ <> Some SsFull = Some SsFull + Some SsNone <> s = s + s <> Some SsNone = s + Some (SsPair a b) <> Some (SsPair a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsPair a2 b2) + Some (SsEither a b) <> Some (SsEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsEither a2 b2) + Some (SsLEither a b) <> Some (SsLEither a' b') = withSome (Some a <> Some a') $ \a2 -> withSome (Some b <> Some b') $ \b2 -> Some (SsLEither a2 b2) + Some (SsMaybe a) <> Some (SsMaybe a') = withSome (Some a <> Some a') $ \a2 -> Some (SsMaybe a2) + Some (SsArr a) <> Some (SsArr a') = withSome (Some a <> Some a') $ \a2 -> Some (SsArr a2) + Some (SsAccum a) <> Some (SsAccum a') = withSome (Some a <> Some a') $ \a2 -> Some (SsAccum a2) +instance Monoid (Some (Substruc t)) where + mempty = Some SsNone + +instance TestEquality (Substruc t) where + testEquality SsFull s = isFull s + testEquality s SsFull = sym <$> isFull s + testEquality SsNone SsNone = Just Refl + testEquality SsNone _ = Nothing + testEquality _ SsNone = Nothing + testEquality (SsPair a b) (SsPair a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing + testEquality (SsEither a b) (SsEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing + testEquality (SsLEither a b) (SsLEither a' b') | Just Refl <- testEquality a a', Just Refl <- testEquality b b' = Just Refl | otherwise = Nothing + testEquality (SsMaybe s) (SsMaybe s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing + testEquality (SsArr s) (SsArr s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing + testEquality (SsAccum s) (SsAccum s') | Just Refl <- testEquality s s' = Just Refl | otherwise = Nothing + +isFull :: Substruc t t' -> Maybe (t :~: t') +isFull SsFull = Just Refl +isFull SsNone = Nothing -- TODO: nil? +isFull (SsPair a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing +isFull (SsEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing +isFull (SsLEither a b) | Just Refl <- isFull a, Just Refl <- isFull b = Just Refl | otherwise = Nothing +isFull (SsMaybe s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing +isFull (SsArr s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing +isFull (SsAccum s) | Just Refl <- isFull s = Just Refl | otherwise = Nothing + +applySubstruc :: Substruc t t' -> STy t -> STy t' +applySubstruc SsFull t = t +applySubstruc SsNone _ = STNil +applySubstruc (SsPair s1 s2) (STPair a b) = STPair (applySubstruc s1 a) (applySubstruc s2 b) +applySubstruc (SsEither s1 s2) (STEither a b) = STEither (applySubstruc s1 a) (applySubstruc s2 b) +applySubstruc (SsLEither s1 s2) (STLEither a b) = STLEither (applySubstruc s1 a) (applySubstruc s2 b) +applySubstruc (SsMaybe s) (STMaybe t) = STMaybe (applySubstruc s t) +applySubstruc (SsArr s) (STArr n t) = STArr n (applySubstruc s t) +applySubstruc (SsAccum s) (STAccum t) = STAccum (applySubstrucM s t) + +applySubstrucM :: Substruc t t' -> SMTy t -> SMTy t' +applySubstrucM SsFull t = t +applySubstrucM SsNone _ = SMTNil +applySubstrucM (SsPair s1 s2) (SMTPair a b) = SMTPair (applySubstrucM s1 a) (applySubstrucM s2 b) +applySubstrucM (SsLEither s1 s2) (SMTLEither a b) = SMTLEither (applySubstrucM s1 a) (applySubstrucM s2 b) +applySubstrucM (SsMaybe s) (SMTMaybe t) = SMTMaybe (applySubstrucM s t) +applySubstrucM (SsArr s) (SMTArr n t) = SMTArr n (applySubstrucM s t) +applySubstrucM _ t = case t of {} + +data ExMap a b = ExMap (forall env. Ex env a -> Ex env b) + | a ~ b => ExMapId + +fromExMap :: ExMap a b -> Ex env a -> Ex env b +fromExMap (ExMap f) = f +fromExMap ExMapId = id + +simplifySubstruc :: STy t -> Substruc t t' -> Substruc t t' +simplifySubstruc STNil SsNone = SsFull + +simplifySubstruc _ SsFull = SsFull +simplifySubstruc _ SsNone = SsNone +simplifySubstruc (STPair t1 t2) (SsPair s1 s2) = SsPair (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) +simplifySubstruc (STEither t1 t2) (SsEither s1 s2) = SsEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) +simplifySubstruc (STLEither t1 t2) (SsLEither s1 s2) = SsLEither (simplifySubstruc t1 s1) (simplifySubstruc t2 s2) +simplifySubstruc (STMaybe t) (SsMaybe s) = SsMaybe (simplifySubstruc t s) +simplifySubstruc (STArr _ t) (SsArr s) = SsArr (simplifySubstruc t s) +simplifySubstruc (STAccum t) (SsAccum s) = SsAccum (simplifySubstruc (fromSMTy t) s) + +-- simplifySubstruc' :: Substruc t t' +-- -> (forall t'2. Substruc t t'2 -> ExMap t'2 t' -> r) -> r +-- simplifySubstruc' SsFull k = k SsFull ExMapId +-- simplifySubstruc' SsNone k = k SsNone ExMapId +-- simplifySubstruc' (SsPair s1 s2) k = +-- simplifySubstruc' s1 $ \s1' f1 -> +-- simplifySubstruc' s2 $ \s2' f2 -> +-- case (s1', s2') of +-- (SsFull, SsFull) -> +-- k SsFull (case (f1, f2) of +-- (ExMapId, ExMapId) -> ExMapId +-- _ -> ExMap (\e -> eunPair e $ \_ e1 e2 -> +-- EPair ext (fromExMap f1 e1) (fromExMap f2 e2))) +-- (SsNone, SsNone) -> k SsNone (ExMap (\_ -> EPair ext (fromExMap f1 (ENil ext)) (fromExMap f2 (ENil ext)))) +-- _ -> k (SsPair s1' s2') (ExMap (\e -> elet e $ EPair ext (fromExMap f1 (EFst ext (evar IZ))) (fromExMap f2 (ESnd ext (evar IZ))))) +-- simplifySubstruc' _ _ = _ + +-- ssUnpair :: Substruc (TPair a b) -> (Substruc a, Substruc b) +-- ssUnpair SsFull = (SsFull, SsFull) +-- ssUnpair SsNone = (SsNone, SsNone) +-- ssUnpair (SsPair a b) = (a, b) + +-- ssUnleft :: Substruc (TEither a b) -> Substruc a +-- ssUnleft SsFull = SsFull +-- ssUnleft SsNone = SsNone +-- ssUnleft (SsEither a _) = a + +-- ssUnright :: Substruc (TEither a b) -> Substruc b +-- ssUnright SsFull = SsFull +-- ssUnright SsNone = SsNone +-- ssUnright (SsEither _ b) = b + +-- ssUnlleft :: Substruc (TLEither a b) -> Substruc a +-- ssUnlleft SsFull = SsFull +-- ssUnlleft SsNone = SsNone +-- ssUnlleft (SsLEither a _) = a + +-- ssUnlright :: Substruc (TLEither a b) -> Substruc b +-- ssUnlright SsFull = SsFull +-- ssUnlright SsNone = SsNone +-- ssUnlright (SsLEither _ b) = b + +-- ssUnjust :: Substruc (TMaybe a) -> Substruc a +-- ssUnjust SsFull = SsFull +-- ssUnjust SsNone = SsNone +-- ssUnjust (SsMaybe a) = a + +-- ssUnarr :: Substruc (TArr n a) -> Substruc a +-- ssUnarr SsFull = SsFull +-- ssUnarr SsNone = SsNone +-- ssUnarr (SsArr a) = a + +-- ssUnaccum :: Substruc (TAccum a) -> Substruc a +-- ssUnaccum SsFull = SsFull +-- ssUnaccum SsNone = SsNone +-- ssUnaccum (SsAccum a) = a + + +type family MapEmpty env where + MapEmpty '[] = '[] + MapEmpty (t : env) = TNil : MapEmpty env + +data OccEnv a env env' where + OccEnd :: OccEnv a env (MapEmpty env) -- not necessarily top! + OccPush :: OccEnv a env env' -> a -> Substruc t t' -> OccEnv a (t : env) (t' : env') + +instance Semigroup a => Semigroup (Some (OccEnv a env)) where + Some OccEnd <> e = e + e <> Some OccEnd = e + Some (OccPush e o s) <> Some (OccPush e' o' s') = withSome (Some e <> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <> o') s2) + +instance Semigroup a => Monoid (Some (OccEnv a env)) where + mempty = Some OccEnd + +instance Occurrence a => Occurrence (Some (OccEnv a env)) where + Some OccEnd <||> e = e + e <||> Some OccEnd = e + Some (OccPush e o s) <||> Some (OccPush e' o' s') = withSome (Some e <||> Some e') $ \e2 -> withSome (Some s <> Some s') $ \s2 -> Some (OccPush e2 (o <||> o') s2) + + scaleMany (Some OccEnd) = Some OccEnd + scaleMany (Some (OccPush e o s)) = withSome (scaleMany (Some e)) $ \e2 -> Some (OccPush e2 (scaleMany o) s) + +onehotOccEnv :: Monoid a => Idx env t -> a -> Substruc t t' -> Some (OccEnv a env) +onehotOccEnv IZ v s = Some (OccPush OccEnd v s) +onehotOccEnv (IS i) v s + | Some env' <- onehotOccEnv i v s + = Some (OccPush env' mempty SsNone) + +occEnvPop :: OccEnv a (t : env) (t' : env') -> (OccEnv a env env', Substruc t t') +occEnvPop (OccPush e _ s) = (e, s) +occEnvPop OccEnd = (OccEnd, SsNone) + +occEnvPop' :: OccEnv a (t : env) env' -> (forall t' env''. env' ~ t' : env'' => OccEnv a env env'' -> Substruc t t' -> r) -> r +occEnvPop' (OccPush e _ s) k = k e s +occEnvPop' OccEnd k = k OccEnd SsNone + +occEnvPopSome :: Some (OccEnv a (t : env)) -> Some (OccEnv a env) +occEnvPopSome (Some (OccPush e _ _)) = Some e +occEnvPopSome (Some OccEnd) = Some OccEnd + +occEnvPrj :: Monoid a => OccEnv a env env' -> Idx env t -> (a, Some (Substruc t)) +occEnvPrj OccEnd _ = mempty +occEnvPrj (OccPush _ o s) IZ = (o, Some s) +occEnvPrj (OccPush e _ _) (IS i) = occEnvPrj e i + +occEnvPrjS :: OccEnv a env env' -> Idx env t -> Some (Product (Substruc t) (Idx env')) +occEnvPrjS OccEnd IZ = Some (Pair SsNone IZ) +occEnvPrjS OccEnd (IS i) | Some (Pair s i') <- occEnvPrjS OccEnd i = Some (Pair s (IS i')) +occEnvPrjS (OccPush _ _ s) IZ = Some (Pair s IZ) +occEnvPrjS (OccPush e _ _) (IS i) + | Some (Pair s' i') <- occEnvPrjS e i + = Some (Pair s' (IS i')) + +projectSmallerSubstruc :: Substruc t t'big -> Substruc t t'small -> Ex env t'big -> Ex env t'small +projectSmallerSubstruc topsbig topssmall ex = case (topsbig, topssmall) of + _ | Just Refl <- testEquality topsbig topssmall -> ex + + (SsFull, SsFull) -> ex + (SsNone, SsNone) -> ex + (SsNone, _) -> error "projectSmallerSubstruc: smaller substructure not smaller" + (_, SsNone) -> + case typeOf ex of + STNil -> ex + _ -> use ex $ ENil ext + + (SsPair s1 s2, SsPair s1' s2') -> + eunPair ex $ \_ e1 e2 -> + EPair ext (projectSmallerSubstruc s1 s1' e1) (projectSmallerSubstruc s2 s2' e2) + (s@SsPair{}, SsFull) -> projectSmallerSubstruc s (SsPair SsFull SsFull) ex + (SsFull, s@SsPair{}) -> projectSmallerSubstruc (SsPair SsFull SsFull) s ex + + (SsEither s1 s2, SsEither s1' s2') + | STEither t1 t2 <- typeOf ex -> + let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) + e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ) + in ecase ex + (EInl ext (typeOf e2) e1) + (EInr ext (typeOf e1) e2) + (s@SsEither{}, SsFull) -> projectSmallerSubstruc s (SsEither SsFull SsFull) ex + (SsFull, s@SsEither{}) -> projectSmallerSubstruc (SsEither SsFull SsFull) s ex + + (SsLEither s1 s2, SsLEither s1' s2') + | STLEither t1 t2 <- typeOf ex -> + let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) + e2 = projectSmallerSubstruc s2 s2' (EVar ext t2 IZ) + in elcase ex + (ELNil ext (typeOf e1) (typeOf e2)) + (ELInl ext (typeOf e2) e1) + (ELInr ext (typeOf e1) e2) + (s@SsLEither{}, SsFull) -> projectSmallerSubstruc s (SsLEither SsFull SsFull) ex + (SsFull, s@SsLEither{}) -> projectSmallerSubstruc (SsLEither SsFull SsFull) s ex + + (SsMaybe s1, SsMaybe s1') + | STMaybe t1 <- typeOf ex -> + let e1 = projectSmallerSubstruc s1 s1' (EVar ext t1 IZ) + in emaybe ex + (ENothing ext (typeOf e1)) + (EJust ext e1) + (s@SsMaybe{}, SsFull) -> projectSmallerSubstruc s (SsMaybe SsFull) ex + (SsFull, s@SsMaybe{}) -> projectSmallerSubstruc (SsMaybe SsFull) s ex + + (SsArr s1, SsArr s2) -> emap (projectSmallerSubstruc s1 s2 (evar IZ)) ex + (s@SsArr{}, SsFull) -> projectSmallerSubstruc s (SsArr SsFull) ex + (SsFull, s@SsArr{}) -> projectSmallerSubstruc (SsArr SsFull) s ex + + (SsAccum _, SsAccum _) -> error "TODO smaller ssaccum" + (s@SsAccum{}, SsFull) -> projectSmallerSubstruc s (SsAccum SsFull) ex + (SsFull, s@SsAccum{}) -> projectSmallerSubstruc (SsAccum SsFull) s ex + + +-- | A boolean for each entry in the environment, with the ability to uniformly +-- mask the top part above a certain index. +data EnvMask env where + EMRest :: Bool -> EnvMask env + EMPush :: EnvMask env -> Bool -> EnvMask (t : env) + +envMaskPrj :: EnvMask env -> Idx env t -> Bool +envMaskPrj (EMRest b) _ = b +envMaskPrj (_ `EMPush` b) IZ = b +envMaskPrj (env `EMPush` _) (IS i) = envMaskPrj env i + +occCount :: Idx env a -> Expr x env t -> Occ +occCount idx ex + | Some env <- occCountAll ex + = fst (occEnvPrj env idx) + +occCountAll :: Expr x env t -> Some (OccEnv Occ env) +occCountAll ex = occCountX SsFull ex $ \env _ -> Some env + +pruneExpr :: SList f env -> Expr x env t -> Ex env t +pruneExpr env ex = occCountX SsFull ex $ \_ mkex -> mkex (fullOccEnv env) + where + fullOccEnv :: SList f env -> OccEnv () env env + fullOccEnv SNil = OccEnd + fullOccEnv (_ `SCons` e) = OccPush (fullOccEnv e) () SsFull + +-- In one traversal, count occurrences of variables and determine what parts of +-- expressions are actually used. These two results are computed independently: +-- even if (almost) nothing of a particular term is actually used, variable +-- references in that term still count as usual. +-- +-- In @occCountX s t k@: +-- * s: how much of the result of this term is required +-- * t: the term to analyse +-- * k: is passed the actual environment usage of this expression, including +-- occurrence counts. The callback reconstructs a new expression in an +-- updated "response" environment. The response must be at least as large as +-- the computed usages. +occCountX :: forall env t t' x r. Substruc t t' -> Expr x env t + -> (forall env'. OccEnv Occ env env' + -- response OccEnv must be at least as large as the OccEnv returned above + -> (forall env''. OccEnv () env env'' -> Ex env'' t') + -> r) + -> r +occCountX initialS topexpr k = case topexpr of + EVar _ t i -> + withSome (onehotOccEnv i (Occ One One) s) $ \env -> + k env $ \env' -> + withSome (occEnvPrjS env' i) $ \(Pair s' i') -> + projectSmallerSubstruc s' s (EVar ext (applySubstruc s' t) i') + ELet _ rhs body -> + occCountX s body $ \envB mkbody -> + occEnvPop' envB $ \envB' s1 -> + occCountX s1 rhs $ \envR mkrhs -> + withSome (Some envB' <> Some envR) $ \env -> + k env $ \env' -> + ELet ext (mkrhs env') (mkbody (OccPush env' () s1)) + EPair _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + SsPair' s1 s2 -> + occCountX s1 a $ \env1 mka -> + occCountX s2 b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EPair ext (mka env') (mkb env') + EFst _ e -> + occCountX (SsPair s SsNone) e $ \env1 mke -> + k env1 $ \env' -> + EFst ext (mke env') + ESnd _ e -> + occCountX (SsPair SsNone s) e $ \env1 mke -> + k env1 $ \env' -> + ESnd ext (mke env') + ENil _ -> + case s of + SsFull -> k OccEnd (\_ -> ENil ext) + SsNone -> k OccEnd (\_ -> ENil ext) + EInl _ t e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsEither s1 s2 -> + occCountX s1 e $ \env1 mke -> + k env1 $ \env' -> + EInl ext (applySubstruc s2 t) (mke env') + SsFull -> occCountX (SsEither SsFull SsFull) topexpr k + EInr _ t e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsEither s1 s2 -> + occCountX s2 e $ \env1 mke -> + k env1 $ \env' -> + EInr ext (applySubstruc s1 t) (mke env') + SsFull -> occCountX (SsEither SsFull SsFull) topexpr k + ECase _ e a b -> + occCountX s a $ \env1' mka -> + occCountX s b $ \env2' mkb -> + occEnvPop' env1' $ \env1 s1 -> + occEnvPop' env2' $ \env2 s2 -> + occCountX (SsEither s1 s2) e $ \env0 mke -> + withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env -> + k env $ \env' -> + ECase ext (mke env') (mka (OccPush env' () s1)) (mkb (OccPush env' () s2)) + ENothing _ t -> + case s of + SsNone -> k OccEnd (\_ -> ENil ext) + SsMaybe s' -> k OccEnd (\_ -> ENothing ext (applySubstruc s' t)) + SsFull -> occCountX (SsMaybe SsFull) topexpr k + EJust _ e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsMaybe s' -> + occCountX s' e $ \env1 mke -> + k env1 $ \env' -> + EJust ext (mke env') + SsFull -> occCountX (SsMaybe SsFull) topexpr k + EMaybe _ a b e -> + occCountX s a $ \env1 mka -> + occCountX s b $ \env2' mkb -> + occEnvPop' env2' $ \env2 s2 -> + occCountX (SsMaybe s2) e $ \env0 mke -> + withSome (Some env0 <> (Some env1 <||> Some env2)) $ \env -> + k env $ \env' -> + EMaybe ext (mka env') (mkb (OccPush env' () s2)) (mke env') + ELNil _ t1 t2 -> + case s of + SsNone -> k OccEnd (\_ -> ENil ext) + SsLEither s1 s2 -> k OccEnd (\_ -> ELNil ext (applySubstruc s1 t1) (applySubstruc s2 t2)) + SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k + ELInl _ t e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsLEither s1 s2 -> + occCountX s1 e $ \env1 mke -> + k env1 $ \env' -> + ELInl ext (applySubstruc s2 t) (mke env') + SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k + ELInr _ t e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + SsLEither s1 s2 -> + occCountX s2 e $ \env1 mke -> + k env1 $ \env' -> + ELInr ext (applySubstruc s1 t) (mke env') + SsFull -> occCountX (SsLEither SsFull SsFull) topexpr k + ELCase _ e a b c -> + occCountX s a $ \env1 mka -> + occCountX s b $ \env2' mkb -> + occCountX s c $ \env3' mkc -> + occEnvPop' env2' $ \env2 s1 -> + occEnvPop' env3' $ \env3 s2 -> + occCountX (SsLEither s1 s2) e $ \env0 mke -> + withSome (Some env0 <> (Some env1 <||> Some env2 <||> Some env3)) $ \env -> + k env $ \env' -> + ELCase ext (mke env') (mka env') (mkb (OccPush env' () s1)) (mkc (OccPush env' () s2)) + + EConstArr _ n t x -> + case s of + SsNone -> k OccEnd (\_ -> ENil ext) + SsArr' SsNone -> k OccEnd (\_ -> EBuild ext n (eshapeConst (arrayShape x)) (ENil ext)) + SsArr' SsFull -> k OccEnd (\_ -> EConstArr ext n t x) + + EBuild _ n a b -> + case s of + SsNone -> + occCountX SsFull a $ \env1 mka -> + occCountX SsNone b $ \env2'' mkb -> + occEnvPop' env2'' $ \env2' s2 -> + withSome (Some env1 <> scaleMany (Some env2')) $ \env -> + k env $ \env' -> + use (EBuild ext n (mka env') $ + use (elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + weakenExpr (WCopy WSink) (mkb (OccPush env' () s2))) $ + ENil ext) $ + ENil ext + SsArr' s' -> + occCountX SsFull a $ \env1 mka -> + occCountX s' b $ \env2'' mkb -> + occEnvPop' env2'' $ \env2' s2 -> + withSome (Some env1 <> scaleMany (Some env2')) $ \env -> + k env $ \env' -> + EBuild ext n (mka env') $ + elet (projectSmallerSubstruc SsFull s2 (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + weakenExpr (WCopy WSink) (mkb (OccPush env' () s2)) + + EMap _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1'' mka -> + occEnvPop' env1'' $ \env1' s1 -> + occCountX (SsArr s1) b $ \env2 mkb -> + withSome (scaleMany (Some env1') <> Some env2) $ \env -> + k env $ \env' -> + use (EMap ext (mka (OccPush env' () s1)) (mkb env')) $ + ENil ext + SsArr' s' -> + occCountX s' a $ \env1'' mka -> + occEnvPop' env1'' $ \env1' s1 -> + occCountX (SsArr s1) b $ \env2 mkb -> + withSome (scaleMany (Some env1') <> Some env2) $ \env -> + k env $ \env' -> + EMap ext (mka (OccPush env' () s1)) (mkb env') + + EFold1Inner _ commut a b c -> + occCountX SsFull a $ \env1'' mka -> + occEnvPop' env1'' $ \env1' s1' -> + let s1 = case s1' of + SsNone -> Some SsNone + SsPair' s1'a s1'b -> Some s1'a <> Some s1'b + s0 = case s of + SsNone -> Some SsNone + SsArr' s' -> Some s' in + withSome (s1 <> s0) $ \sElt -> + occCountX sElt b $ \env2 mkb -> + occCountX (SsArr sElt) c $ \env3 mkc -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> + k env $ \env' -> + projectSmallerSubstruc (SsArr sElt) s $ + EFold1Inner ext commut + (projectSmallerSubstruc SsFull sElt $ + mka (OccPush env' () (SsPair sElt sElt))) + (mkb env') (mkc env') + + ESum1Inner _ e -> handleReduction (ESum1Inner ext) e + + EUnit _ e -> + case s of + SsNone -> + occCountX SsNone e $ \env mke -> + k env $ \env' -> + use (mke env') $ ENil ext + SsArr' s' -> + occCountX s' e $ \env mke -> + k env $ \env' -> + EUnit ext (mke env') + + EReplicate1Inner _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + SsArr' s' -> + occCountX SsFull a $ \env1 mka -> + occCountX (SsArr s') b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EReplicate1Inner ext (mka env') (mkb env') + + EMaximum1Inner _ e -> handleReduction (EMaximum1Inner ext) e + EMinimum1Inner _ e -> handleReduction (EMinimum1Inner ext) e + + EReshape _ n esh e -> + case s of + SsNone -> + occCountX SsNone esh $ \env1 mkesh -> + occCountX SsNone e $ \env2 mke -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mkesh env') $ use (mke env') $ ENil ext + SsArr' s' -> + occCountX SsFull esh $ \env1 mkesh -> + occCountX (SsArr s') e $ \env2 mke -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EReshape ext n (mkesh env') (mke env') + + EZip _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + SsArr' SsNone -> + occCountX (SsArr SsNone) a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mkb env') $ mka env' + SsArr' (SsPair' SsNone s2) -> + occCountX SsNone a $ \env1 mka -> + occCountX (SsArr s2) b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ + emap (EPair ext (ENil ext) (evar IZ)) (mkb env') + SsArr' (SsPair' s1 SsNone) -> + occCountX (SsArr s1) a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mkb env') $ + emap (EPair ext (evar IZ) (ENil ext)) (mka env') + SsArr' (SsPair' s1 s2) -> + occCountX (SsArr s1) a $ \env1 mka -> + occCountX (SsArr s2) b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EZip ext (mka env') (mkb env') + + EFold1InnerD1 _ cm e1 e2 e3 -> + case s of + -- If nothing is necessary, we can execute a fold and then proceed to ignore it + SsNone -> + let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1)) + (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3) + in occCountX SsNone foldex $ \env1 mkfoldex -> k env1 mkfoldex + -- If we don't need the stores, still a fold suffices + SsPair' sP SsNone -> + let foldex = EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) e1)) + (mapExt (\_ -> ext) e2) (mapExt (\_ -> ext) e3) + in occCountX sP foldex $ \env1 mkfoldex -> k env1 $ \env' -> EPair ext (mkfoldex env') (ENil ext) + -- If for whatever reason the additional stores themselves are + -- unnecessary but the shape of the array is, then oblige + SsPair' sP (SsArr' SsNone) -> + let STArr sn _ = typeOf e3 + foldex = + elet (mapExt (\_ -> ext) e3) $ + EPair ext + (EShape ext (evar IZ)) + (EFold1Inner ext cm (EFst ext (mapExt (\_ -> ext) (weakenExpr (WCopy WSink) e1))) + (mapExt (\_ -> ext) (weakenExpr WSink e2)) + (evar IZ)) + in occCountX (SsPair SsFull sP) foldex $ \env1 mkfoldex -> + k env1 $ \env' -> + eunPair (mkfoldex env') $ \_ eshape earr -> + EPair ext earr (EBuild ext sn eshape (ENil ext)) + -- If at least some of the additional stores are required, we need to keep this a mapAccum + SsPair' _ (SsArr' sB) -> + -- TODO: propagate usage of primals + occCountX (SsPair SsFull sB) e1 $ \env1_1' mka -> + occEnvPop' env1_1' $ \env1' _ -> + occCountX SsFull e2 $ \env2 mkb -> + occCountX SsFull e3 $ \env3 mkc -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> + k env $ \env' -> + projectSmallerSubstruc (SsPair SsFull (SsArr sB)) s $ + EFold1InnerD1 ext cm (mka (OccPush env' () SsFull)) + (mkb env') (mkc env') + + EFold1InnerD2 _ cm ef ebog ed -> + -- TODO: propagate usage of duals + occCountX SsFull ef $ \env1_2' mkef -> + occEnvPop' env1_2' $ \env1_1' _ -> + occEnvPop' env1_1' $ \env1' sB -> + occCountX (SsArr sB) ebog $ \env2 mkebog -> + occCountX SsFull ed $ \env3 mked -> + withSome (scaleMany (Some env1') <> Some env2 <> Some env3) $ \env -> + k env $ \env' -> + projectSmallerSubstruc SsFull s $ + EFold1InnerD2 ext cm + (mkef (OccPush (OccPush env' () sB) () SsFull)) + (mkebog env') (mked env') + + EConst _ t x -> + k OccEnd $ \_ -> + case s of + SsNone -> ENil ext + SsFull -> EConst ext t x + + EIdx0 _ e -> + occCountX (SsArr s) e $ \env1 mke -> + k env1 $ \env' -> + EIdx0 ext (mke env') + + EIdx1 _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + SsArr' s' -> + occCountX (SsArr s') a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EIdx1 ext (mka env') (mkb env') + + EIdx _ a b -> + case s of + SsNone -> + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + _ -> + occCountX (SsArr s) a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EIdx ext (mka env') (mkb env') + + EShape _ e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + _ -> + occCountX (SsArr SsNone) e $ \env1 mke -> + k env1 $ \env' -> + projectSmallerSubstruc SsFull s $ EShape ext (mke env') + + EOp _ op e -> + case s of + SsNone -> + occCountX SsNone e $ \env1 mke -> + k env1 $ \env' -> + use (mke env') $ ENil ext + _ -> + occCountX SsFull e $ \env1 mke -> + k env1 $ \env' -> + projectSmallerSubstruc SsFull s $ EOp ext op (mke env') + + ECustom _ t1 t2 t3 e1 e2 e3 a b + | typeHasAccums t1 || typeHasAccums t2 || typeHasAccums t3 -> + error "Accumulators not allowed in input/output/tape of an ECustom" + | otherwise -> + case s of + SsNone -> + -- Allowed to ignore e1/e2/e3 here because no accumulators are + -- communicated, and hence no relevant effects exist + occCountX SsNone a $ \env1 mka -> + occCountX SsNone b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (mka env') $ use (mkb env') $ ENil ext + s' -> -- Let's be pessimistic for safety + occCountX SsFull a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + projectSmallerSubstruc SsFull s' $ + ECustom ext t1 t2 t3 (mapExt (const ext) e1) (mapExt (const ext) e2) (mapExt (const ext) e3) (mka env') (mkb env') + + ERecompute _ e -> + occCountX s e $ \env1 mke -> + k env1 $ \env' -> + ERecompute ext (mke env') + + EWith _ t a b -> + case s of + SsNone -> -- TODO: simplifier should remove accumulations to an unused with, and then remove the with + occCountX SsNone b $ \env2' mkb -> + occEnvPop' env2' $ \env2 s1 -> + withSome (case s1 of + SsFull -> Some SsFull + SsAccum s' -> Some s' + SsNone -> Some SsNone) $ \s1' -> + occCountX s1' a $ \env1 mka -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + use (EWith ext (applySubstrucM s1' t) (mka env') (mkb (OccPush env' () (SsAccum s1')))) $ + ENil ext + SsPair sB sA -> + occCountX sB b $ \env2' mkb -> + occEnvPop' env2' $ \env2 s1 -> + let s1' = case s1 of + SsFull -> Some SsFull + SsAccum s' -> Some s' + SsNone -> Some SsNone in + withSome (Some sA <> s1') $ \sA' -> + occCountX sA' a $ \env1 mka -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + projectSmallerSubstruc (SsPair sB sA') (SsPair sB sA) $ + EWith ext (applySubstrucM sA' t) (mka env') (mkb (OccPush env' () (SsAccum sA'))) + SsFull -> occCountX (SsPair SsFull SsFull) topexpr k + + EAccum _ t p a sp b e -> + -- TODO: do better! + occCountX SsFull a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> + occCountX SsFull e $ \env3 mke -> + withSome (Some env1 <> Some env2) $ \env12 -> + withSome (Some env12 <> Some env3) $ \env -> + k env $ \env' -> + case s of {SsFull -> id; SsNone -> id} $ + EAccum ext t p (mka env') sp (mkb env') (mke env') + + EZero _ t e -> + occCountX (subZeroInfo s) e $ \env1 mke -> + k env1 $ \env' -> + EZero ext (applySubstrucM s t) (mke env') + where + subZeroInfo :: Substruc t1 t2 -> Substruc (ZeroInfo t1) (ZeroInfo t2) + subZeroInfo SsFull = SsFull + subZeroInfo SsNone = SsNone + subZeroInfo (SsPair s1 s2) = SsPair (subZeroInfo s1) (subZeroInfo s2) + subZeroInfo SsEither{} = error "Either is not a monoid" + subZeroInfo SsLEither{} = SsNone + subZeroInfo SsMaybe{} = SsNone + subZeroInfo (SsArr s') = SsArr (subZeroInfo s') + subZeroInfo SsAccum{} = error "Accum is not a monoid" + + EDeepZero _ t e -> + occCountX (subDeepZeroInfo s) e $ \env1 mke -> + k env1 $ \env' -> + EDeepZero ext (applySubstrucM s t) (mke env') + where + subDeepZeroInfo :: Substruc t1 t2 -> Substruc (DeepZeroInfo t1) (DeepZeroInfo t2) + subDeepZeroInfo SsFull = SsFull + subDeepZeroInfo SsNone = SsNone + subDeepZeroInfo (SsPair s1 s2) = SsPair (subDeepZeroInfo s1) (subDeepZeroInfo s2) + subDeepZeroInfo SsEither{} = error "Either is not a monoid" + subDeepZeroInfo (SsLEither s1 s2) = SsLEither (subDeepZeroInfo s1) (subDeepZeroInfo s2) + subDeepZeroInfo (SsMaybe s') = SsMaybe (subDeepZeroInfo s') + subDeepZeroInfo (SsArr s') = SsArr (subDeepZeroInfo s') + subDeepZeroInfo SsAccum{} = error "Accum is not a monoid" + + EPlus _ t a b -> + occCountX s a $ \env1 mka -> + occCountX s b $ \env2 mkb -> + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + EPlus ext (applySubstrucM s t) (mka env') (mkb env') + + EOneHot _ t p a b -> + occCountX SsFull a $ \env1 mka -> + occCountX SsFull b $ \env2 mkb -> -- TODO: do better + withSome (Some env1 <> Some env2) $ \env -> + k env $ \env' -> + projectSmallerSubstruc SsFull s $ EOneHot ext t p (mka env') (mkb env') + + EError _ t msg -> + k OccEnd $ \_ -> EError ext (applySubstruc s t) msg + where + s = simplifySubstruc (typeOf topexpr) initialS + + handleReduction :: t ~ TArr n (TScal t2) + => (forall env2. Ex env2 (TArr (S n) (TScal t2)) -> Ex env2 (TArr n (TScal t2))) + -> Expr x env (TArr (S n) (TScal t2)) + -> r + handleReduction reduce e + | STArr (SS n) _ <- typeOf e = + case s of + SsNone -> + occCountX SsNone e $ \env mke -> + k env $ \env' -> + use (mke env') $ ENil ext + SsArr' SsNone -> + occCountX (SsArr SsNone) e $ \env mke -> + k env $ \env' -> + elet (mke env') $ + EBuild ext n (EFst ext (EShape ext (evar IZ))) (ENil ext) + SsArr' SsFull -> + occCountX (SsArr SsFull) e $ \env mke -> + k env $ \env' -> + reduce (mke env') + + +deleteUnused :: SList f env -> Some (OccEnv Occ env) -> (forall env'. Subenv env env' -> r) -> r +deleteUnused SNil (Some OccEnd) k = k SETop +deleteUnused (_ `SCons` env) (Some OccEnd) k = + deleteUnused env (Some OccEnd) $ \sub -> k (SENo sub) +deleteUnused (_ `SCons` env) (Some (OccPush occenv (Occ _ count) _)) k = + deleteUnused env (Some occenv) $ \sub -> + case count of Zero -> k (SENo sub) + _ -> k (SEYesR sub) + +unsafeWeakenWithSubenv :: Subenv env env' -> Expr x env t -> Expr x env' t +unsafeWeakenWithSubenv = \sub -> + subst (\x t i -> case sinkViaSubenv i sub of + Just i' -> EVar x t i' + Nothing -> error "unsafeWeakenWithSubenv: Index occurred that was subenv'd away") + where + sinkViaSubenv :: Idx env t -> Subenv env env' -> Maybe (Idx env' t) + sinkViaSubenv IZ (SEYesR _) = Just IZ + sinkViaSubenv IZ (SENo _) = Nothing + sinkViaSubenv (IS i) (SEYesR sub) = IS <$> sinkViaSubenv i sub + sinkViaSubenv (IS i) (SENo sub) = sinkViaSubenv i sub diff --git a/src/CHAD/AST/Env.hs b/src/CHAD/AST/Env.hs new file mode 100644 index 0000000..8e6b745 --- /dev/null +++ b/src/CHAD/AST/Env.hs @@ -0,0 +1,95 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.AST.Env where + +import Data.Type.Equality + +import CHAD.AST.Sparse +import CHAD.AST.Weaken +import CHAD.Data +import CHAD.Drev.Types + + +-- | @env'@ is a subset of @env@: each element of @env@ is either included in +-- @env'@ ('SEYes') or not included in @env'@ ('SENo'). +data Subenv' s env env' where + SETop :: Subenv' s '[] '[] + SEYes :: forall t t' env env' s. s t t' -> Subenv' s env env' -> Subenv' s (t : env) (t' : env') + SENo :: forall t env env' s. Subenv' s env env' -> Subenv' s (t : env) env' +deriving instance (forall t t'. Show (s t t')) => Show (Subenv' s env env') + +type Subenv = Subenv' (:~:) +type SubenvS = Subenv' Sparse + +pattern SEYesR :: forall tenv tenv'. () + => forall t env env'. (tenv ~ t : env, tenv' ~ t : env') + => Subenv env env' -> Subenv tenv tenv' +pattern SEYesR s = SEYes Refl s + +{-# COMPLETE SETop, SEYesR, SENo #-} + +subList :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env' -> SList f env' +subList SNil SETop = SNil +subList (SCons x xs) (SEYes s sub) = SCons (subtApply s x) (subList xs sub) +subList (SCons _ xs) (SENo sub) = subList xs sub + +subenvAll :: (IsSubType s, IsSubTypeSubject s f) => SList f env -> Subenv' s env env +subenvAll SNil = SETop +subenvAll (SCons t env) = SEYes (subtFull t) (subenvAll env) + +subenvNone :: SList f env -> Subenv' s env '[] +subenvNone SNil = SETop +subenvNone (SCons _ env) = SENo (subenvNone env) + +subenvOnehot :: SList f env -> Idx env t -> s t t' -> Subenv' s env '[t'] +subenvOnehot (SCons _ env) IZ sp = SEYes sp (subenvNone env) +subenvOnehot (SCons _ env) (IS i) sp = SENo (subenvOnehot env i sp) +subenvOnehot SNil i _ = case i of {} + +subenvCompose :: IsSubType s => Subenv' s env1 env2 -> Subenv' s env2 env3 -> Subenv' s env1 env3 +subenvCompose SETop SETop = SETop +subenvCompose (SEYes s1 sub1) (SEYes s2 sub2) = SEYes (subtTrans s1 s2) (subenvCompose sub1 sub2) +subenvCompose (SEYes _ sub1) (SENo sub2) = SENo (subenvCompose sub1 sub2) +subenvCompose (SENo sub1) sub2 = SENo (subenvCompose sub1 sub2) + +subenvConcat :: Subenv' s env1 env1' -> Subenv' s env2 env2' -> Subenv' s (Append env2 env1) (Append env2' env1') +subenvConcat sub1 SETop = sub1 +subenvConcat sub1 (SEYes s sub2) = SEYes s (subenvConcat sub1 sub2) +subenvConcat sub1 (SENo sub2) = SENo (subenvConcat sub1 sub2) + +-- subenvSplit :: SList f env1a -> Subenv' s (Append env1a env1b) env2 +-- -> (forall env2a env2b. Subenv' s env1a env2a -> Subenv' s env1b env2b -> r) -> r +-- subenvSplit SNil sub k = k SETop sub +-- subenvSplit (SCons _ list) (SENo sub) k = +-- subenvSplit list sub $ \sub1 sub2 -> +-- k (SENo sub1) sub2 +-- subenvSplit (SCons _ list) (SEYes s sub) k = +-- subenvSplit list sub $ \sub1 sub2 -> +-- k (SEYes s sub1) sub2 + +sinkWithSubenv :: Subenv' s env env' -> env0 :> Append env' env0 +sinkWithSubenv SETop = WId +sinkWithSubenv (SEYes _ sub) = WSink .> sinkWithSubenv sub +sinkWithSubenv (SENo sub) = sinkWithSubenv sub + +wUndoSubenv :: Subenv' (:~:) env env' -> env' :> env +wUndoSubenv SETop = WId +wUndoSubenv (SEYes Refl sub) = WCopy (wUndoSubenv sub) +wUndoSubenv (SENo sub) = WSink .> wUndoSubenv sub + +subenvMap :: (forall a a'. f a -> s a a' -> s' a a') -> SList f env -> Subenv' s env env' -> Subenv' s' env env' +subenvMap _ SNil SETop = SETop +subenvMap f (t `SCons` l) (SEYes s sub) = SEYes (f t s) (subenvMap f l sub) +subenvMap f (_ `SCons` l) (SENo sub) = SENo (subenvMap f l sub) + +subenvD2E :: Subenv env env' -> Subenv (D2E env) (D2E env') +subenvD2E SETop = SETop +subenvD2E (SEYesR sub) = SEYesR (subenvD2E sub) +subenvD2E (SENo sub) = SENo (subenvD2E sub) diff --git a/src/CHAD/AST/Pretty.hs b/src/CHAD/AST/Pretty.hs new file mode 100644 index 0000000..3f6a3af --- /dev/null +++ b/src/CHAD/AST/Pretty.hs @@ -0,0 +1,525 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.AST.Pretty (pprintExpr, ppExpr, ppSTy, ppSMTy, PrettyX(..)) where + +import Control.Monad (ap) +import Data.List (intersperse, intercalate) +import Data.Functor.Const +import qualified Data.Functor.Product as Product +import Data.String (fromString) +import Prettyprinter +import Prettyprinter.Render.String + +import qualified Data.Text.Lazy as TL +import qualified Prettyprinter.Render.Terminal as PT +import System.Console.ANSI (hSupportsANSI) +import System.IO (stdout) +import System.IO.Unsafe (unsafePerformIO) + +import CHAD.AST +import CHAD.AST.Count +import CHAD.AST.Sparse.Types +import CHAD.Data +import CHAD.Drev.Types + + +class PrettyX x where + prettyX :: x t -> String + + prettyXsuffix :: x t -> String + prettyXsuffix x = "<" ++ prettyX x ++ ">" + +instance PrettyX (Const ()) where + prettyX _ = "" + prettyXsuffix _ = "" + + +type SVal = SList (Const String) + +newtype M a = M { runM :: Int -> (a, Int) } + deriving (Functor) +instance Applicative M where { pure x = M (\i -> (x, i)) ; (<*>) = ap } +instance Monad M where { M f >>= g = M (\i -> let (x, j) = f i in runM (g x) j) } + +genId :: M Int +genId = M (\i -> (i, i + 1)) + +nameBaseForType :: STy t -> String +nameBaseForType STNil = "nil" +nameBaseForType (STPair{}) = "p" +nameBaseForType (STEither{}) = "e" +nameBaseForType (STMaybe{}) = "m" +nameBaseForType (STScal STI32) = "n" +nameBaseForType (STScal STI64) = "n" +nameBaseForType (STArr{}) = "a" +nameBaseForType (STAccum{}) = "ac" +nameBaseForType _ = "x" + +genName' :: String -> M String +genName' prefix = (prefix ++) . show <$> genId + +genNameIfUsedIn' :: String -> STy a -> Idx env a -> Expr x env t -> M String +genNameIfUsedIn' prefix ty idx ex + | occCount idx ex == mempty = case ty of STNil -> return "()" + _ -> return "_" + | otherwise = genName' prefix + +-- TODO: let this return a type-tagged thing so that name environments are more typed than Const +genNameIfUsedIn :: STy a -> Idx env a -> Expr x env t -> M String +genNameIfUsedIn = \t -> genNameIfUsedIn' (nameBaseForType t) t + +pprintExpr :: (KnownEnv env, PrettyX x) => Expr x env t -> IO () +pprintExpr = putStrLn . ppExpr knownEnv + +ppExpr :: PrettyX x => SList STy env -> Expr x env t -> String +ppExpr senv e = render $ fst . flip runM 1 $ do + val <- mkVal senv + e' <- ppExpr' 0 val e + let lam = "λ" ++ intercalate " " (reverse (unSList (\(Product.Pair (Const name) ty) -> "(" ++ name ++ " : " ++ ppSTy 0 ty ++ ")") (slistZip val senv))) ++ "." + return $ group $ flatAlt + (hang 2 $ + ppString lam + <> hardline <> e') + (ppString lam <+> e') + where + mkVal :: SList f env -> M (SVal env) + mkVal SNil = return SNil + mkVal (SCons _ v) = do + val <- mkVal v + name <- genName' "arg" + return (Const name `SCons` val) + +ppExpr' :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc +ppExpr' d val expr = case expr of + EVar _ _ i -> return $ ppString (getConst (slistIdx val i)) <> ppX expr + + e@ELet{} -> ppExprLet d val e + + EPair _ a b -> do + a' <- ppExpr' 0 val a + b' <- ppExpr' 0 val b + return $ group $ flatAlt (align $ ppString "(" <> a' <> hardline <> ppString "," <> b' <> ppString ")" <> ppX expr) + (ppString "(" <> a' <> ppString "," <+> b' <> ppString ")" <> ppX expr) + + EFst _ e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppString "fst" <> ppX expr <+> e' + + ESnd _ e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppString "snd" <> ppX expr <+> e' + + ENil _ -> return $ ppString "()" + + EInl _ _ e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppString "Inl" <> ppX expr <+> e' + + EInr _ _ e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppString "Inr" <> ppX expr <+> e' + + ECase _ e a b -> do + e' <- ppExpr' 0 val e + let STEither t1 t2 = typeOf e + name1 <- genNameIfUsedIn t1 IZ a + a' <- ppExpr' 0 (Const name1 `SCons` val) a + name2 <- genNameIfUsedIn t2 IZ b + b' <- ppExpr' 0 (Const name2 `SCons` val) b + return $ ppParen (d > 0) $ + hang 2 $ + annotate AKey (ppString "case") <> ppX expr <+> e' <+> annotate AKey (ppString "of") + <> hardline <> ppString "Inl" <+> ppString name1 <+> ppString "->" <+> a' + <> hardline <> ppString "Inr" <+> ppString name2 <+> ppString "->" <+> b' + + ENothing _ _ -> return $ ppString "Nothing" + + EJust _ e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppString "Just" <> ppX expr <+> e' + + EMaybe _ a b e -> do + let STMaybe t = typeOf e + e' <- ppExpr' 0 val e + a' <- ppExpr' 0 val a + name <- genNameIfUsedIn t IZ b + b' <- ppExpr' 0 (Const name `SCons` val) b + return $ ppParen (d > 0) $ + align $ + group (flatAlt + (annotate AKey (ppString "case") <> ppX expr <+> e' + <> hardline <> annotate AKey (ppString "of")) + (annotate AKey (ppString "case") <> ppX expr <+> e' <+> annotate AKey (ppString "of"))) + <> hardline + <> indent 2 + (ppString "Nothing" <+> ppString "->" <+> a' + <> hardline <> ppString "Just" <+> ppString name <+> ppString "->" <+> b') + + ELNil _ _ _ -> return (ppString "LNil") + + ELInl _ _ e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppString "LInl" <> ppX expr <+> e' + + ELInr _ _ e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppString "LInr" <> ppX expr <+> e' + + ELCase _ e a b c -> do + e' <- ppExpr' 0 val e + let STLEither t1 t2 = typeOf e + a' <- ppExpr' 11 val a + name1 <- genNameIfUsedIn t1 IZ b + b' <- ppExpr' 0 (Const name1 `SCons` val) b + name2 <- genNameIfUsedIn t2 IZ c + c' <- ppExpr' 0 (Const name2 `SCons` val) c + return $ ppParen (d > 0) $ + hang 2 $ + annotate AKey (ppString "lcase") <> ppX expr <+> e' <+> annotate AKey (ppString "of") + <> hardline <> ppString "LNil" <+> ppString "->" <+> a' + <> hardline <> ppString "LInl" <+> ppString name1 <+> ppString "->" <+> b' + <> hardline <> ppString "LInr" <+> ppString name2 <+> ppString "->" <+> c' + + EConstArr _ _ ty v + | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr + + EBuild _ n a b -> do + a' <- ppExpr' 11 val a + name <- genNameIfUsedIn' "i" (tTup (sreplicate n tIx)) IZ b + e' <- ppExpr' 0 (Const name `SCons` val) b + let primName = ppString ("build" ++ intSubscript (fromSNat n)) + return $ ppParen (d > 0) $ + group $ flatAlt + (hang 2 $ + annotate AHighlight primName <> ppX expr <+> a' + <+> ppString "$" <+> ppString "\\" <> ppString name <+> ppString "->" + <> hardline <> e') + (ppApp (annotate AHighlight primName <> ppX expr) [a', ppLam [ppString name] e']) + + EMap _ a b -> do + let STArr _ t1 = typeOf b + name <- genNameIfUsedIn t1 IZ a + a' <- ppExpr' 0 (Const name `SCons` val) a + b' <- ppExpr' 11 val b + return $ ppParen (d > 0) $ + ppApp (annotate AHighlight (ppString "map") <> ppX expr) [ppLam [ppString name] a', b'] + + EFold1Inner _ cm a b c -> do + name <- genNameIfUsedIn (STPair (typeOf a) (typeOf a)) IZ a + a' <- ppExpr' 0 (Const name `SCons` val) a + b' <- ppExpr' 11 val b + c' <- ppExpr' 11 val c + let opname = "fold1i" ++ ppCommut cm + return $ ppParen (d > 10) $ + ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c'] + + ESum1Inner _ e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppString "sum1i" <> ppX expr <+> e' + + EUnit _ e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppString "unit" <> ppX expr <+> e' + + EReplicate1Inner _ a b -> do + a' <- ppExpr' 11 val a + b' <- ppExpr' 11 val b + return $ ppParen (d > 10) $ ppApp (ppString "replicate1i" <> ppX expr) [a', b'] + + EMaximum1Inner _ e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppString "maximum1i" <> ppX expr <+> e' + + EMinimum1Inner _ e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppString "minimum1i" <> ppX expr <+> e' + + EReshape _ n esh e -> do + esh' <- ppExpr' 11 val esh + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppApp (ppString ("reshape" ++ intSubscript (fromSNat n)) <> ppX expr) [esh', e'] + + EZip _ e1 e2 -> do + e1' <- ppExpr' 11 val e1 + e2' <- ppExpr' 11 val e2 + return $ ppParen (d > 10) $ ppApp (ppString "zip" <> ppX expr) [e1', e2'] + + EFold1InnerD1 _ cm a b c -> do + name <- genNameIfUsedIn (STPair (typeOf b) (typeOf b)) IZ a + a' <- ppExpr' 0 (Const name `SCons` val) a + b' <- ppExpr' 11 val b + c' <- ppExpr' 11 val c + let opname = "fold1iD1" ++ ppCommut cm + return $ ppParen (d > 10) $ + ppApp (annotate AHighlight (ppString opname) <> ppX expr) [ppLam [ppString name] a', b', c'] + + EFold1InnerD2 _ cm ef ebog ed -> do + let STArr _ tB = typeOf ebog + STArr _ t2 = typeOf ed + namef1 <- genNameIfUsedIn tB (IS IZ) ef + namef2 <- genNameIfUsedIn t2 IZ ef + ef' <- ppExpr' 0 (Const namef2 `SCons` Const namef1 `SCons` val) ef + ebog' <- ppExpr' 11 val ebog + ed' <- ppExpr' 11 val ed + let opname = "fold1iD2" ++ ppCommut cm + return $ ppParen (d > 10) $ + ppApp (annotate AHighlight (ppString opname) <> ppX expr) + [ppLam [ppString namef1, ppString namef2] ef', ebog', ed'] + + EConst _ ty v + | Dict <- scalRepIsShow ty -> return $ ppString (showsPrec d v "") <> ppX expr + + EIdx0 _ e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppString "idx0" <> ppX expr <+> e' + + EIdx1 _ a b -> do + a' <- ppExpr' 9 val a + b' <- ppExpr' 9 val b + return $ ppParen (d > 8) $ a' <+> ppString ".!" <> ppX expr <+> b' + + EIdx _ a b -> do + a' <- ppExpr' 9 val a + b' <- ppExpr' 10 val b + return $ ppParen (d > 8) $ + a' <+> ppString "!" <> ppX expr <+> b' + + EShape _ e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppString "shape" <> ppX expr <+> e' + + EOp _ op (EPair _ a b) + | (Infix, ops) <- operator op -> do + a' <- ppExpr' 9 val a + b' <- ppExpr' 9 val b + return $ ppParen (d > 8) $ a' <+> ppString ops <> ppX expr <+> b' + + EOp _ op e -> do + e' <- ppExpr' 11 val e + let ops = case operator op of + (Infix, s) -> "(" ++ s ++ ")" + (Prefix, s) -> s + return $ ppParen (d > 10) $ ppString ops <> ppX expr <+> e' + + ECustom _ t1 t2 t3 a b c e1 e2 -> do + en1 <- genNameIfUsedIn t1 (IS IZ) a + en2 <- genNameIfUsedIn t2 IZ a + pn1 <- genNameIfUsedIn (d1 t1) (IS IZ) b + pn2 <- genNameIfUsedIn (d1 t2) IZ b + dn1 <- genNameIfUsedIn' "tape" t3 (IS IZ) c + dn2 <- genNameIfUsedIn' "d" (d2 (typeOf a)) IZ c + a' <- ppExpr' 11 (Const pn2 `SCons` Const pn1 `SCons` SNil) a + b' <- ppExpr' 11 (Const pn2 `SCons` Const pn1 `SCons` SNil) b + c' <- ppExpr' 11 (Const dn2 `SCons` Const dn1 `SCons` SNil) c + e1' <- ppExpr' 11 val e1 + e2' <- ppExpr' 11 val e2 + return $ ppParen (d > 10) $ + ppApp (ppString "custom" <> ppX expr) + [ppLam [ppString en1, ppString en2] a' + ,ppLam [ppString pn1, ppString pn2] b' + ,ppLam [ppString dn1, ppString dn2] c' + ,e1' + ,e2'] + + ERecompute _ e -> do + e' <- ppExpr' 11 val e + return $ ppParen (d > 10) $ ppApp (ppString "recompute" <> ppX expr) [e'] + + EWith _ t e1 e2 -> do + e1' <- ppExpr' 11 val e1 + name <- genNameIfUsedIn' "ac" (STAccum t) IZ e2 + e2' <- ppExpr' 0 (Const name `SCons` val) e2 + return $ ppParen (d > 0) $ + group $ flatAlt + (hang 2 $ + annotate AWith (ppString "with") <> ppX expr <+> e1' + <+> ppString "$" <+> ppString "\\" <> ppString name <+> ppString "->" + <> hardline <> e2') + (ppApp (annotate AWith (ppString "with") <> ppX expr) [e1', ppLam [ppString name] e2']) + + EAccum _ t prj e1 sp e2 e3 -> do + e1' <- ppExpr' 11 val e1 + e2' <- ppExpr' 11 val e2 + e3' <- ppExpr' 11 val e3 + return $ ppParen (d > 10) $ + ppApp (annotate AMonoid (ppString "accum") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (applySparse sp (acPrjTy prj t))) + [ppString (ppAcPrj t prj), ppString (ppSparse (acPrjTy prj t) sp), e1', e2', e3'] + + EZero _ t e1 -> do + e1' <- ppExpr' 11 val e1 + return $ ppParen (d > 0) $ + annotate AMonoid (ppString "zero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1' + + EDeepZero _ t e1 -> do + e1' <- ppExpr' 11 val e1 + return $ ppParen (d > 0) $ + annotate AMonoid (ppString "deepzero") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t <+> e1' + + EPlus _ t a b -> do + a' <- ppExpr' 11 val a + b' <- ppExpr' 11 val b + return $ ppParen (d > 10) $ + ppApp (annotate AMonoid (ppString "plus") <> ppX expr <+> ppString "@" <> ppSMTy' 11 t) [a', b'] + + EOneHot _ t prj a b -> do + a' <- ppExpr' 11 val a + b' <- ppExpr' 11 val b + return $ ppParen (d > 10) $ + ppApp (annotate AMonoid (ppString "onehot") <> ppX expr <+> ppString "@" <> ppSMTy' 11 (acPrjTy prj t)) [ppString (ppAcPrj t prj), a', b'] + + EError _ _ s -> return $ ppParen (d > 10) $ ppString "error" <> ppX expr <+> ppString (show s) + +ppExprLet :: PrettyX x => Int -> SVal env -> Expr x env t -> M ADoc +ppExprLet d val etop = do + let collect :: PrettyX x => SVal env -> Expr x env t -> M ([(String, Occ, ADoc)], ADoc) + collect val' (ELet _ rhs body) = do + let occ = occCount IZ body + name <- genNameIfUsedIn (typeOf rhs) IZ body + rhs' <- ppExpr' 0 val' rhs + (binds, core) <- collect (Const name `SCons` val') body + return ((name, occ, rhs') : binds, core) + collect val' e = ([],) <$> ppExpr' 0 val' e + + (binds, core) <- collect val etop + + return $ ppParen (d > 0) $ + align $ + annotate AKey (ppString "let") + <+> align (mconcat $ intersperse hardline $ + map (\(name, _occ, rhs) -> + ppString (name ++ {- " (" ++ show _occ ++ ")" ++ -} " = ") <> rhs) + binds) + <> hardline <> annotate AKey (ppString "in") <+> core + +ppApp :: ADoc -> [ADoc] -> ADoc +ppApp fun args = group $ fun <+> align (sep args) + +ppLam :: [ADoc] -> ADoc -> ADoc +ppLam args body = ppString "(" <> hang 2 (ppString "\\" <> sep (args ++ [ppString "->"]) + <> softline <> body <> ppString ")") + +ppAcPrj :: SMTy a -> SAcPrj p a b -> String +ppAcPrj _ SAPHere = "." +ppAcPrj (SMTPair t _) (SAPFst prj) = "(" ++ ppAcPrj t prj ++ ",)" +ppAcPrj (SMTPair _ t) (SAPSnd prj) = "(," ++ ppAcPrj t prj ++ ")" +ppAcPrj (SMTLEither t _) (SAPLeft prj) = "(" ++ ppAcPrj t prj ++ "|)" +ppAcPrj (SMTLEither _ t) (SAPRight prj) = "(|" ++ ppAcPrj t prj ++ ")" +ppAcPrj (SMTMaybe t) (SAPJust prj) = "J" ++ ppAcPrj t prj +ppAcPrj (SMTArr n t) (SAPArrIdx prj) = "[" ++ ppAcPrj t prj ++ "]" ++ intSubscript (fromSNat n) + +ppSparse :: SMTy a -> Sparse a b -> String +ppSparse t sp | Just Refl <- isDense t sp = "D" +ppSparse _ SpAbsent = "A" +ppSparse t (SpSparse s) = "S" ++ ppSparse t s +ppSparse (SMTPair t1 t2) (SpPair s1 s2) = "(" ++ ppSparse t1 s1 ++ "," ++ ppSparse t2 s2 ++ ")" +ppSparse (SMTLEither t1 t2) (SpLEither s1 s2) = "(" ++ ppSparse t1 s1 ++ "|" ++ ppSparse t2 s2 ++ ")" +ppSparse (SMTMaybe t) (SpMaybe s) = "M" ++ ppSparse t s +ppSparse (SMTArr _ t) (SpArr s) = "A" ++ ppSparse t s +ppSparse (SMTScal _) SpScal = "." + +ppCommut :: Commutative -> String +ppCommut Commut = "(C)" +ppCommut Noncommut = "" + +ppX :: PrettyX x => Expr x env t -> ADoc +ppX expr = annotate AExt $ ppString $ prettyXsuffix (extOf expr) + +data Fixity = Prefix | Infix + deriving (Show) + +operator :: SOp a t -> (Fixity, String) +operator OAdd{} = (Infix, "+") +operator OMul{} = (Infix, "*") +operator ONeg{} = (Prefix, "negate") +operator OLt{} = (Infix, "<") +operator OLe{} = (Infix, "<=") +operator OEq{} = (Infix, "==") +operator ONot = (Prefix, "not") +operator OAnd = (Infix, "&&") +operator OOr = (Infix, "||") +operator OIf = (Prefix, "ifB") +operator ORound64 = (Prefix, "round") +operator OToFl64 = (Prefix, "toFl64") +operator ORecip{} = (Prefix, "recip") +operator OExp{} = (Prefix, "exp") +operator OLog{} = (Prefix, "log") +operator OIDiv{} = (Infix, "`div`") +operator OMod{} = (Infix, "`mod`") + +ppSTy :: Int -> STy t -> String +ppSTy d ty = render $ ppSTy' d ty + +ppSTy' :: Int -> STy t -> Doc q +ppSTy' _ STNil = ppString "1" +ppSTy' d (STPair a b) = ppParen (d > 7) $ ppSTy' 8 a <> ppString " * " <> ppSTy' 8 b +ppSTy' d (STEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " + " <> ppSTy' 7 b +ppSTy' d (STLEither a b) = ppParen (d > 6) $ ppSTy' 7 a <> ppString " ⊕ " <> ppSTy' 7 b +ppSTy' d (STMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSTy' 11 t +ppSTy' d (STArr n t) = ppParen (d > 10) $ + ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSTy' 11 t +ppSTy' _ (STScal sty) = ppString $ case sty of + STI32 -> "i32" + STI64 -> "i64" + STF32 -> "f32" + STF64 -> "f64" + STBool -> "bool" +ppSTy' d (STAccum t) = ppParen (d > 10) $ ppString "Accum " <> ppSMTy' 11 t + +ppSMTy :: Int -> SMTy t -> String +ppSMTy d ty = render $ ppSMTy' d ty + +ppSMTy' :: Int -> SMTy t -> Doc q +ppSMTy' _ SMTNil = ppString "1" +ppSMTy' d (SMTPair a b) = ppParen (d > 7) $ ppSMTy' 8 a <> ppString " * " <> ppSMTy' 8 b +ppSMTy' d (SMTLEither a b) = ppParen (d > 6) $ ppSMTy' 7 a <> ppString " ⊕ " <> ppSMTy' 7 b +ppSMTy' d (SMTMaybe t) = ppParen (d > 10) $ ppString "Maybe " <> ppSMTy' 11 t +ppSMTy' d (SMTArr n t) = ppParen (d > 10) $ + ppString "Arr " <> ppString (show (fromSNat n)) <> ppString " " <> ppSMTy' 11 t +ppSMTy' _ (SMTScal sty) = ppString $ case sty of + STI32 -> "i32" + STI64 -> "i64" + STF32 -> "f32" + STF64 -> "f64" + +ppString :: String -> Doc x +ppString = fromString + +ppParen :: Bool -> Doc x -> Doc x +ppParen True = parens +ppParen False = id + +intSubscript :: Int -> String +intSubscript = \case 0 -> "₀" + n | n < 0 -> '₋' : go (-n) "" + | otherwise -> go n "" + where go 0 suff = suff + go n suff = let (q, r) = n `quotRem` 10 + in go q ("₀₁₂₃₄₅₆₇₈₉" !! r : suff) + +data Annot = AKey | AWith | AHighlight | AMonoid | AExt + deriving (Show) + +annotToANSI :: Annot -> PT.AnsiStyle +annotToANSI AKey = PT.bold +annotToANSI AWith = PT.color PT.Red <> PT.underlined +annotToANSI AHighlight = PT.color PT.Blue +annotToANSI AMonoid = PT.color PT.Green +annotToANSI AExt = PT.colorDull PT.White + +type ADoc = Doc Annot + +render :: Doc Annot -> String +render = + (if stdoutTTY then TL.unpack . PT.renderLazy . reAnnotateS annotToANSI + else renderString) + . layoutPretty LayoutOptions { layoutPageWidth = AvailablePerLine 120 1.0 } + where + stdoutTTY = unsafePerformIO $ hSupportsANSI stdout diff --git a/src/CHAD/AST/Sparse.hs b/src/CHAD/AST/Sparse.hs new file mode 100644 index 0000000..9156160 --- /dev/null +++ b/src/CHAD/AST/Sparse.hs @@ -0,0 +1,287 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImpredicativeTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE RankNTypes #-} + +{-# OPTIONS_GHC -fmax-pmcheck-models=80 #-} +module CHAD.AST.Sparse (module CHAD.AST.Sparse, module CHAD.AST.Sparse.Types) where + +import Data.Type.Equality + +import CHAD.AST +import CHAD.AST.Sparse.Types +import CHAD.Data (SBool(..)) + + +sparsePlus :: SMTy t -> Sparse t t' -> Ex env t' -> Ex env t' -> Ex env t' +sparsePlus _ SpAbsent e1 e2 = use e1 $ use e2 $ ENil ext +sparsePlus t sp e1 e2 | Just Refl <- isDense t sp = EPlus ext t e1 e2 +sparsePlus t (SpSparse sp) e1 e2 = sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 -- heh +sparsePlus (SMTPair t1 t2) (SpPair sp1 sp2) e1 e2 = + eunPair e1 $ \w1 e1a e1b -> + eunPair (weakenExpr w1 e2) $ \w2 e2a e2b -> + EPair ext (sparsePlus t1 sp1 (weakenExpr w2 e1a) e2a) + (sparsePlus t2 sp2 (weakenExpr w2 e1b) e2b) +sparsePlus (SMTLEither t1 t2) (SpLEither sp1 sp2) e1 e2 = + elet e2 $ + elcase (weakenExpr WSink e1) + (evar IZ) + (elcase (evar (IS IZ)) + (ELInl ext (applySparse sp2 (fromSMTy t2)) (evar IZ)) + (ELInl ext (applySparse sp2 (fromSMTy t2)) (sparsePlus t1 sp1 (evar (IS IZ)) (evar IZ))) + (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus ll+lr")) + (elcase (evar (IS IZ)) + (ELInr ext (applySparse sp1 (fromSMTy t1)) (evar IZ)) + (EError ext (fromSMTy (applySparse (SpLEither sp1 sp2) (SMTLEither t1 t2))) "splus lr+ll") + (ELInr ext (applySparse sp1 (fromSMTy t1)) (sparsePlus t2 sp2 (evar (IS IZ)) (evar IZ)))) +sparsePlus (SMTMaybe t) (SpMaybe sp) e1 e2 = + elet e2 $ + emaybe (weakenExpr WSink e1) + (evar IZ) + (emaybe (evar (IS IZ)) + (EJust ext (evar IZ)) + (EJust ext (sparsePlus t sp (evar (IS IZ)) (evar IZ)))) +sparsePlus (SMTArr _ t) (SpArr sp) e1 e2 = ezipWith (sparsePlus t sp (evar (IS IZ)) (evar IZ)) e1 e2 +sparsePlus t@SMTScal{} SpScal e1 e2 = EPlus ext t e1 e2 + + +cheapZero :: SMTy t -> Maybe (forall env. Ex env t) +cheapZero SMTNil = Just (ENil ext) +cheapZero (SMTPair t1 t2) + | Just e1 <- cheapZero t1 + , Just e2 <- cheapZero t2 + = Just (EPair ext e1 e2) + | otherwise + = Nothing +cheapZero (SMTLEither t1 t2) = Just (ELNil ext (fromSMTy t1) (fromSMTy t2)) +cheapZero (SMTMaybe t) = Just (ENothing ext (fromSMTy t)) +cheapZero SMTArr{} = Nothing +cheapZero (SMTScal t) = case t of + STI32 -> Just (EConst ext t 0) + STI64 -> Just (EConst ext t 0) + STF32 -> Just (EConst ext t 0.0) + STF64 -> Just (EConst ext t 0.0) + + +data Injection sp a b where + -- | 'Inj' is purposefully also allowed when @sp@ is @False@ so that + -- 'sparsePlusS' can provide injections even if the caller doesn't require + -- them. This simplifies the sparsePlusS code. + Inj :: (forall e. Ex e a -> Ex e b) -> Injection sp a b + Noinj :: Injection False a b + +withInj :: Injection sp a b -> ((forall e. Ex e a -> Ex e b) -> (forall e'. Ex e' a' -> Ex e' b')) -> Injection sp a' b' +withInj (Inj f) k = Inj (k f) +withInj Noinj _ = Noinj + +withInj2 :: Injection sp a1 b1 -> Injection sp a2 b2 + -> ((forall e. Ex e a1 -> Ex e b1) + -> (forall e. Ex e a2 -> Ex e b2) + -> (forall e'. Ex e' a' -> Ex e' b')) + -> Injection sp a' b' +withInj2 (Inj f) (Inj g) k = Inj (k f g) +withInj2 Noinj _ _ = Noinj +withInj2 _ Noinj _ = Noinj + +-- | This function produces quadratically-sized code in the presence of nested +-- dynamic sparsity. TODO can this be improved? +sparsePlusS + :: SBool inj1 -> SBool inj2 + -> SMTy t -> Sparse t t1 -> Sparse t t2 + -> (forall t3. Sparse t t3 + -> Injection inj1 t1 t3 -- only available if first injection is requested (second argument may be absent) + -> Injection inj2 t2 t3 -- only available if second injection is requested (first argument may be absent) + -> (forall e. Ex e t1 -> Ex e t2 -> Ex e t3) + -> r) + -> r +-- nil override (but don't destroy effects!) +sparsePlusS _ _ SMTNil _ _ k = + k SpAbsent (Inj $ \a -> use a $ ENil ext) (Inj $ \b -> use b $ ENil ext) (\a b -> use a $ use b $ ENil ext) + +-- simplifications +sparsePlusS req1 req2 t (SpSparse SpAbsent) sp2 k = + sparsePlusS req1 req2 t SpAbsent sp2 $ \sp3 minj1 minj2 plus -> + k sp3 (withInj minj1 $ \inj1 -> \a -> use a $ inj1 (ENil ext)) minj2 (\a b -> use a $ plus (ENil ext) b) +sparsePlusS req1 req2 t sp1 (SpSparse SpAbsent) k = + sparsePlusS req1 req2 t sp1 SpAbsent $ \sp3 minj1 minj2 plus -> + k sp3 minj1 (withInj minj2 $ \inj2 -> \b -> use b $ inj2 (ENil ext)) (\a b -> use b $ plus a (ENil ext)) + +sparsePlusS req1 req2 t (SpSparse (SpSparse sp1)) sp2 k = + let ta = applySparse sp1 (fromSMTy t) in + sparsePlusS req1 req2 t (SpSparse sp1) sp2 $ \sp3 minj1 minj2 plus -> + k sp3 + (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ))) + minj2 + (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b) +sparsePlusS req1 req2 t sp1 (SpSparse (SpSparse sp2)) k = + let tb = applySparse sp2 (fromSMTy t) in + sparsePlusS req1 req2 t sp1 (SpSparse sp2) $ \sp3 minj1 minj2 plus -> + k sp3 + minj1 + (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) + (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) + +sparsePlusS req1 req2 t (SpSparse (SpLEither sp1a sp1b)) sp2 k = + let STLEither ta tb = applySparse (SpLEither sp1a sp1b) (fromSMTy t) in + sparsePlusS req1 req2 t (SpLEither sp1a sp1b) sp2 $ \sp3 minj1 minj2 plus -> + k sp3 + (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) + minj2 + (\a b -> plus (emaybe a (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ)) b) +sparsePlusS req1 req2 t sp1 (SpSparse (SpLEither sp2a sp2b)) k = + let STLEither ta tb = applySparse (SpLEither sp2a sp2b) (fromSMTy t) in + sparsePlusS req1 req2 t sp1 (SpLEither sp2a sp2b) $ \sp3 minj1 minj2 plus -> + k sp3 + minj1 + (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) + (\a b -> plus a (emaybe b (ELNil ext ta tb) (EVar ext (STLEither ta tb) IZ))) + +sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k = + let STMaybe ta = applySparse (SpMaybe sp1) (fromSMTy t) in + sparsePlusS req1 req2 t (SpMaybe sp1) sp2 $ \sp3 minj1 minj2 plus -> + k sp3 + (withInj minj1 $ \inj1 -> \a -> inj1 (emaybe a (ENothing ext ta) (evar IZ))) + minj2 + (\a b -> plus (emaybe a (ENothing ext ta) (EVar ext (STMaybe ta) IZ)) b) +sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k = + let STMaybe tb = applySparse (SpMaybe sp2) (fromSMTy t) in + sparsePlusS req1 req2 t sp1 (SpMaybe sp2) $ \sp3 minj1 minj2 plus -> + k sp3 + minj1 + (withInj minj2 $ \inj2 -> \b -> inj2 (emaybe b (ENothing ext tb) (evar IZ))) + (\a b -> plus a (emaybe b (ENothing ext tb) (EVar ext (STMaybe tb) IZ))) +sparsePlusS req1 req2 t (SpMaybe (SpSparse sp1)) sp2 k = sparsePlusS req1 req2 t (SpSparse (SpMaybe sp1)) sp2 k +sparsePlusS req1 req2 t sp1 (SpMaybe (SpSparse sp2)) k = sparsePlusS req1 req2 t sp1 (SpSparse (SpMaybe sp2)) k + +-- TODO: sparse of Just is just Maybe + +-- dense plus +sparsePlusS _ _ t sp1 sp2 k + | Just Refl <- isDense t sp1 + , Just Refl <- isDense t sp2 + = k (spDense t) (Inj id) (Inj id) (\a b -> EPlus ext t a b) + +-- handle absents +sparsePlusS SF _ _ SpAbsent sp2 k = k sp2 Noinj (Inj id) (\a b -> use a $ b) +sparsePlusS ST _ t SpAbsent sp2 k + | Just zero2 <- cheapZero (applySparse sp2 t) = + k sp2 (Inj $ \a -> use a $ zero2) (Inj id) (\a b -> use a $ b) + | otherwise = + k (SpSparse sp2) (Inj $ \a -> use a $ ENothing ext (applySparse sp2 (fromSMTy t))) (Inj $ EJust ext) (\a b -> use a $ EJust ext b) + +sparsePlusS _ SF _ sp1 SpAbsent k = k sp1 (Inj id) Noinj (\a b -> use b $ a) +sparsePlusS _ ST t sp1 SpAbsent k + | Just zero1 <- cheapZero (applySparse sp1 t) = + k sp1 (Inj id) (Inj $ \b -> use b $ zero1) (\a b -> use b $ a) + | otherwise = + k (SpSparse sp1) (Inj $ EJust ext) (Inj $ \b -> use b $ ENothing ext (applySparse sp1 (fromSMTy t))) (\a b -> use b $ EJust ext a) + +-- double sparse yields sparse +sparsePlusS _ _ t (SpSparse sp1) (SpSparse sp2) k = + sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> + k (SpSparse sp3) + (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) + (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (emaybe (evar IZ) + (ENothing ext (applySparse sp3 (fromSMTy t))) + (EJust ext (inj2 (evar IZ)))) + (emaybe (evar (IS IZ)) + (EJust ext (inj1 (evar IZ))) + (EJust ext (plus (evar (IS IZ)) (evar IZ))))) + +-- single sparse can yield non-sparse if the other argument is always present +sparsePlusS SF _ t (SpSparse sp1) sp2 k = + sparsePlusS SF ST t sp1 sp2 $ \sp3 _ (Inj inj2) plus -> + k sp3 Noinj (Inj inj2) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (inj2 (evar IZ)) + (plus (evar IZ) (evar (IS IZ)))) +sparsePlusS ST _ t (SpSparse sp1) sp2 k = + sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> + k (SpSparse sp3) + (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) + (Inj $ \b -> EJust ext (inj2 b)) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (EJust ext (inj2 (evar IZ))) + (EJust ext (plus (evar IZ) (evar (IS IZ))))) +sparsePlusS req1 req2 t sp1 (SpSparse sp2) k = + sparsePlusS req2 req1 t (SpSparse sp2) sp1 $ \sp3 inj1 inj2 plus -> + k sp3 inj2 inj1 (flip plus) + +-- products +sparsePlusS req1 req2 (SMTPair ta tb) (SpPair sp1a sp1b) (SpPair sp2a sp2b) k = + sparsePlusS req1 req2 ta sp1a sp2a $ \sp3a minj13a minj23a plusa -> + sparsePlusS req1 req2 tb sp1b sp2b $ \sp3b minj13b minj23b plusb -> + k (SpPair sp3a sp3b) + (withInj2 minj13a minj13b $ \inj13a inj13b -> + \x1 -> eunPair x1 $ \_ x1a x1b -> EPair ext (inj13a x1a) (inj13b x1b)) + (withInj2 minj23a minj23b $ \inj23a inj23b -> + \x2 -> eunPair x2 $ \_ x2a x2b -> EPair ext (inj23a x2a) (inj23b x2b)) + (\x1 x2 -> + eunPair x1 $ \w1 x1a x1b -> + eunPair (weakenExpr w1 x2) $ \w2 x2a x2b -> + EPair ext (plusa (weakenExpr w2 x1a) x2a) (plusb (weakenExpr w2 x1b) x2b)) + +-- coproducts +sparsePlusS _ _ (SMTLEither ta tb) (SpLEither sp1a sp1b) (SpLEither sp2a sp2b) k = + sparsePlusS ST ST ta sp1a sp2a $ \(sp3a :: Sparse _t3 t3a) (Inj inj13a) (Inj inj23a) plusa -> + sparsePlusS ST ST tb sp1b sp2b $ \(sp3b :: Sparse _t3' t3b) (Inj inj13b) (Inj inj23b) plusb -> + let nil :: Ex e (TLEither t3a t3b) ; nil = ELNil ext (applySparse sp3a (fromSMTy ta)) (applySparse sp3b (fromSMTy tb)) + inl :: Ex e t3a -> Ex e (TLEither t3a t3b) ; inl = ELInl ext (applySparse sp3b (fromSMTy tb)) + inr :: Ex e t3b -> Ex e (TLEither t3a t3b) ; inr = ELInr ext (applySparse sp3a (fromSMTy ta)) + in + k (SpLEither sp3a sp3b) + (Inj $ \x1 -> elcase x1 nil (inl (inj13a (evar IZ))) (inr (inj13b (evar IZ)))) + (Inj $ \x2 -> elcase x2 nil (inl (inj23a (evar IZ))) (inr (inj23b (evar IZ)))) + (\x1 x2 -> + elet x2 $ + elcase (weakenExpr WSink x1) + (elcase (evar IZ) + nil + (inl (inj23a (evar IZ))) + (inr (inj23b (evar IZ)))) + (elcase (evar (IS IZ)) + (inl (inj13a (evar IZ))) + (inl (plusa (evar (IS IZ)) (evar IZ))) + (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS ll+lr")) + (elcase (evar (IS IZ)) + (inr (inj13b (evar IZ))) + (EError ext (applySparse (SpLEither sp3a sp3b) (fromSMTy (SMTLEither ta tb))) "plusS lr+ll") + (inr (plusb (evar (IS IZ)) (evar IZ))))) + +-- maybe +sparsePlusS _ _ (SMTMaybe t) (SpMaybe sp1) (SpMaybe sp2) k = + sparsePlusS ST ST t sp1 sp2 $ \sp3 (Inj inj1) (Inj inj2) plus -> + k (SpMaybe sp3) + (Inj $ \a -> emaybe a (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj1 (evar IZ)))) + (Inj $ \b -> emaybe b (ENothing ext (applySparse sp3 (fromSMTy t))) (EJust ext (inj2 (evar IZ)))) + (\a b -> + elet b $ + emaybe (weakenExpr WSink a) + (emaybe (evar IZ) + (ENothing ext (applySparse sp3 (fromSMTy t))) + (EJust ext (inj2 (evar IZ)))) + (emaybe (evar (IS IZ)) + (EJust ext (inj1 (evar IZ))) + (EJust ext (plus (evar (IS IZ)) (evar IZ))))) + +-- dense array cotangents simply recurse +sparsePlusS req1 req2 (SMTArr _ t) (SpArr sp1) (SpArr sp2) k = + sparsePlusS req1 req2 t sp1 sp2 $ \sp3 minj1 minj2 plus -> + k (SpArr sp3) + (withInj minj1 $ \inj1 -> emap (inj1 (EVar ext (applySparse sp1 (fromSMTy t)) IZ))) + (withInj minj2 $ \inj2 -> emap (inj2 (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) + (ezipWith (plus (EVar ext (applySparse sp1 (fromSMTy t)) (IS IZ)) + (EVar ext (applySparse sp2 (fromSMTy t)) IZ))) + +-- scalars +sparsePlusS _ _ (SMTScal t) SpScal SpScal k = k SpScal (Inj id) (Inj id) (EPlus ext (SMTScal t)) diff --git a/src/CHAD/AST/Sparse/Types.hs b/src/CHAD/AST/Sparse/Types.hs new file mode 100644 index 0000000..8f41ba4 --- /dev/null +++ b/src/CHAD/AST/Sparse/Types.hs @@ -0,0 +1,107 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.AST.Sparse.Types where + +import Data.Kind (Type, Constraint) +import Data.Type.Equality + +import CHAD.AST.Types + + +data Sparse t t' where + SpSparse :: Sparse t t' -> Sparse t (TMaybe t') + SpAbsent :: Sparse t TNil + + SpPair :: Sparse a a' -> Sparse b b' -> Sparse (TPair a b) (TPair a' b') + SpLEither :: Sparse a a' -> Sparse b b' -> Sparse (TLEither a b) (TLEither a' b') + SpMaybe :: Sparse t t' -> Sparse (TMaybe t) (TMaybe t') + SpArr :: Sparse t t' -> Sparse (TArr n t) (TArr n t') + SpScal :: Sparse (TScal t) (TScal t) +deriving instance Show (Sparse t t') + +class ApplySparse f where + applySparse :: Sparse t t' -> f t -> f t' + +instance ApplySparse STy where + applySparse (SpSparse s) t = STMaybe (applySparse s t) + applySparse SpAbsent _ = STNil + applySparse (SpPair s1 s2) (STPair t1 t2) = STPair (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpLEither s1 s2) (STLEither t1 t2) = STLEither (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpMaybe s) (STMaybe t) = STMaybe (applySparse s t) + applySparse (SpArr s) (STArr n t) = STArr n (applySparse s t) + applySparse SpScal t = t + +instance ApplySparse SMTy where + applySparse (SpSparse s) t = SMTMaybe (applySparse s t) + applySparse SpAbsent _ = SMTNil + applySparse (SpPair s1 s2) (SMTPair t1 t2) = SMTPair (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpLEither s1 s2) (SMTLEither t1 t2) = SMTLEither (applySparse s1 t1) (applySparse s2 t2) + applySparse (SpMaybe s) (SMTMaybe t) = SMTMaybe (applySparse s t) + applySparse (SpArr s) (SMTArr n t) = SMTArr n (applySparse s t) + applySparse SpScal t = t + + +class IsSubType s where + type IsSubTypeSubject (s :: k -> k -> Type) (f :: k -> Type) :: Constraint + subtApply :: IsSubTypeSubject s f => s t t' -> f t -> f t' + subtTrans :: s a b -> s b c -> s a c + subtFull :: IsSubTypeSubject s f => f t -> s t t + +instance IsSubType (:~:) where + type IsSubTypeSubject (:~:) f = () + subtApply = gcastWith + subtTrans = trans + subtFull _ = Refl + +instance IsSubType Sparse where + type IsSubTypeSubject Sparse f = f ~ SMTy + subtApply = applySparse + + subtTrans s1 (SpSparse s2) = SpSparse (subtTrans s1 s2) + subtTrans _ SpAbsent = SpAbsent + subtTrans (SpPair s1a s1b) (SpPair s2a s2b) = SpPair (subtTrans s1a s2a) (subtTrans s1b s2b) + subtTrans (SpLEither s1a s1b) (SpLEither s2a s2b) = SpLEither (subtTrans s1a s2a) (subtTrans s1b s2b) + subtTrans (SpSparse s1) (SpMaybe s2) = SpSparse (subtTrans s1 s2) + subtTrans (SpMaybe s1) (SpMaybe s2) = SpMaybe (subtTrans s1 s2) + subtTrans (SpArr s1) (SpArr s2) = SpArr (subtTrans s1 s2) + subtTrans SpScal SpScal = SpScal + + subtFull = spDense + +spDense :: SMTy t -> Sparse t t +spDense SMTNil = SpAbsent +spDense (SMTPair t1 t2) = SpPair (spDense t1) (spDense t2) +spDense (SMTLEither t1 t2) = SpLEither (spDense t1) (spDense t2) +spDense (SMTMaybe t) = SpMaybe (spDense t) +spDense (SMTArr _ t) = SpArr (spDense t) +spDense (SMTScal _) = SpScal + +isDense :: SMTy t -> Sparse t t' -> Maybe (t :~: t') +isDense SMTNil SpAbsent = Just Refl +isDense _ SpSparse{} = Nothing +isDense _ SpAbsent = Nothing +isDense (SMTPair t1 t2) (SpPair s1 s2) + | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl + | otherwise = Nothing +isDense (SMTLEither t1 t2) (SpLEither s1 s2) + | Just Refl <- isDense t1 s1, Just Refl <- isDense t2 s2 = Just Refl + | otherwise = Nothing +isDense (SMTMaybe t) (SpMaybe s) + | Just Refl <- isDense t s = Just Refl + | otherwise = Nothing +isDense (SMTArr _ t) (SpArr s) + | Just Refl <- isDense t s = Just Refl + | otherwise = Nothing +isDense (SMTScal _) SpScal = Just Refl + +isAbsent :: Sparse t t' -> Bool +isAbsent (SpSparse s) = isAbsent s +isAbsent SpAbsent = True +isAbsent (SpPair s1 s2) = isAbsent s1 && isAbsent s2 +isAbsent (SpLEither s1 s2) = isAbsent s1 && isAbsent s2 +isAbsent (SpMaybe s) = isAbsent s +isAbsent (SpArr s) = isAbsent s +isAbsent SpScal = False diff --git a/src/CHAD/AST/SplitLets.hs b/src/CHAD/AST/SplitLets.hs new file mode 100644 index 0000000..34267e4 --- /dev/null +++ b/src/CHAD/AST/SplitLets.hs @@ -0,0 +1,191 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +module CHAD.AST.SplitLets (splitLets) where + +import Data.Type.Equality + +import CHAD.AST +import CHAD.AST.Bindings +import CHAD.Lemmas + + +splitLets :: Ex env t -> Ex env t +splitLets = splitLets' (\t i w -> EVar ext t (w @> i)) + +splitLets' :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) -> Ex env t -> Ex env' t +splitLets' = \sub -> \case + EVar _ t i -> sub t i WId + ELet _ rhs body -> ELet ext (splitLets' sub rhs) (split1 sub (typeOf rhs) body) + ECase x e a b -> + let STEither t1 t2 = typeOf e + in ECase x (splitLets' sub e) (split1 sub t1 a) (split1 sub t2 b) + EMaybe x a b e -> + let STMaybe t1 = typeOf e + in EMaybe x (splitLets' sub a) (split1 sub t1 b) (splitLets' sub e) + ELCase x e a b c -> + let STLEither t1 t2 = typeOf e + in ELCase x (splitLets' sub e) (splitLets' sub a) (split1 sub t1 b) (split1 sub t2 c) + EFold1Inner x cm a b c -> + let STArr _ t1 = typeOf c + in EFold1Inner x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c) + EFold1InnerD1 x cm a b c -> + let STArr _ t1 = typeOf c + in EFold1InnerD1 x cm (split1 sub (STPair t1 t1) a) (splitLets' sub b) (splitLets' sub c) + EFold1InnerD2 x cm a b c -> + let STArr _ tB = typeOf b + STArr _ t2 = typeOf c + in EFold1InnerD2 x cm (split2 sub tB t2 a) (splitLets' sub b) (splitLets' sub c) + + EPair x a b -> EPair x (splitLets' sub a) (splitLets' sub b) + EFst x e -> EFst x (splitLets' sub e) + ESnd x e -> ESnd x (splitLets' sub e) + ENil x -> ENil x + EInl x t e -> EInl x t (splitLets' sub e) + EInr x t e -> EInr x t (splitLets' sub e) + ENothing x t -> ENothing x t + EJust x e -> EJust x (splitLets' sub e) + ELNil x t1 t2 -> ELNil x t1 t2 + ELInl x t e -> ELInl x t (splitLets' sub e) + ELInr x t e -> ELInr x t (splitLets' sub e) + EConstArr x n t a -> EConstArr x n t a + EBuild x n a b -> EBuild x n (splitLets' sub a) (splitLets' (sinkF sub) b) + EMap x a b -> EMap x (splitLets' (sinkF sub) a) (splitLets' sub b) + ESum1Inner x e -> ESum1Inner x (splitLets' sub e) + EUnit x e -> EUnit x (splitLets' sub e) + EReplicate1Inner x a b -> EReplicate1Inner x (splitLets' sub a) (splitLets' sub b) + EMaximum1Inner x e -> EMaximum1Inner x (splitLets' sub e) + EMinimum1Inner x e -> EMinimum1Inner x (splitLets' sub e) + EReshape x n a b -> EReshape x n (splitLets' sub a) (splitLets' sub b) + EZip x a b -> EZip x (splitLets' sub a) (splitLets' sub b) + EConst x t v -> EConst x t v + EIdx0 x e -> EIdx0 x (splitLets' sub e) + EIdx1 x a b -> EIdx1 x (splitLets' sub a) (splitLets' sub b) + EIdx x e es -> EIdx x (splitLets' sub e) (splitLets' sub es) + EShape x e -> EShape x (splitLets' sub e) + EOp x op e -> EOp x op (splitLets' sub e) + ECustom x s t p a b c e1 e2 -> ECustom x s t p a b c (splitLets' sub e1) (splitLets' sub e2) + ERecompute x e -> ERecompute x (splitLets' sub e) + EWith x t e1 e2 -> EWith x t (splitLets' sub e1) (splitLets' (sinkF sub) e2) + EAccum x t p e1 sp e2 e3 -> EAccum x t p (splitLets' sub e1) sp (splitLets' sub e2) (splitLets' sub e3) + EZero x t ezi -> EZero x t (splitLets' sub ezi) + EDeepZero x t ezi -> EDeepZero x t (splitLets' sub ezi) + EPlus x t a b -> EPlus x t (splitLets' sub a) (splitLets' sub b) + EOneHot x t p a b -> EOneHot x t p (splitLets' sub a) (splitLets' sub b) + EError x t s -> EError x t s + where + sinkF :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) + -> STy t -> Idx (b : env) t -> (b : env') :> env3 -> Ex env3 t + sinkF _ t IZ w = EVar ext t (w @> IZ) + sinkF f t (IS i) w = f t i (w .> WSink) + + split1 :: (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) + -> STy bind -> Ex (bind : env) t -> Ex (bind : env') t + split1 sub (tbind :: STy bind) body = + let (ptrs, bs) = split tbind + in letBinds bs $ + splitLets' (\cases _ IZ w -> subPointers ptrs w + t (IS i) w -> sub t i (WPop @bind (wPops (bindingsBinds bs) w))) + body + + split2 :: forall bind1 bind2 env' env t. + (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) + -> STy bind1 -> STy bind2 -> Ex (bind2 : bind1 : env) t -> Ex (bind2 : bind1 : env') t + split2 sub tbind1 tbind2 body = + let (ptrs1', bs1') = split @env' tbind1 + bs1 = fst (weakenBindingsE WSink bs1') + (ptrs2, bs2) = split @(bind1 : env') tbind2 + in letBinds bs1 $ + letBinds (fst (weakenBindingsE (sinkWithBindings @(bind2 : bind1 : env') bs1) bs2)) $ + splitLets' (\cases _ IZ w -> subPointers ptrs2 (w .> wCopies (bindingsBinds bs2) (wSinks @(bind2 : bind1 : env') (bindingsBinds bs1))) + _ (IS IZ) w -> subPointers ptrs1' (w .> wSinks (bindingsBinds bs2) .> wCopies (bindingsBinds bs1) (WSink @bind2 @(bind1 : env'))) + t (IS (IS i)) w -> sub t i (WPop @bind1 (WPop @bind2 (wPops (bindingsBinds bs1) (wPops (bindingsBinds bs2) w))))) + body + + -- TODO: abstract this to splitN lol wtf + _split4 :: forall bind1 bind2 bind3 bind4 env' env t. + (forall a env2. STy a -> Idx env a -> env' :> env2 -> Ex env2 a) + -> STy bind1 -> STy bind2 -> STy bind3 -> STy bind4 -> Ex (bind4 : bind3 : bind2 : bind1 : env) t -> Ex (bind4 : bind3 : bind2 : bind1 : env') t + _split4 sub tbind1 tbind2 tbind3 tbind4 body = + let (ptrs1, bs1') = split @env' tbind1 + (ptrs2, bs2') = split @(bind1 : env') tbind2 + (ptrs3, bs3') = split @(bind2 : bind1 : env') tbind3 + (ptrs4, bs4) = split @(bind3 : bind2 : bind1 : env') tbind4 + bs1 = fst (weakenBindingsE (WSink .> WSink .> WSink) bs1') + bs2 = fst (weakenBindingsE (WSink .> WSink) bs2') + bs3 = fst (weakenBindingsE WSink bs3') + b1 = bindingsBinds bs1 + b2 = bindingsBinds bs2 + b3 = bindingsBinds bs3 + b4 = bindingsBinds bs4 + in letBinds bs1 $ + letBinds (fst (weakenBindingsE ( sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs2)) $ + letBinds (fst (weakenBindingsE ( sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs3)) $ + letBinds (fst (weakenBindingsE (sinkWithBindings bs3 .> sinkWithBindings bs2 .> sinkWithBindings @(bind4 : bind3 : bind2 : bind1 : env') bs1) bs4)) $ + splitLets' (\cases _ IZ w -> subPointers ptrs4 (w .> wCopies b4 (wSinks b3 .> wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1)) + _ (IS IZ) w -> subPointers ptrs3 (w .> wSinks b4 .> wCopies b3 (wSinks b2 .> wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink)) + _ (IS (IS IZ)) w -> subPointers ptrs2 (w .> wSinks b4 .> wSinks b3 .> wCopies b2 (wSinks @(bind4 : bind3 : bind2 : bind1 : env') b1 .> WSink .> WSink)) + _ (IS (IS (IS IZ))) w -> subPointers ptrs1 (w .> wSinks b4 .> wSinks b3 .> wSinks b2 .> wCopies b1 (WSink @bind4 .> WSink @bind3 .> WSink @bind2 @(bind1 : env'))) + t (IS (IS (IS (IS i)))) w -> sub t i (WPop @bind1 (WPop @bind2 (WPop @bind3 (WPop @bind4 (wPops b1 (wPops b2 (wPops b3 (wPops b4 w))))))))) + body + +type family Split t where + Split (TPair a b) = SplitRec (TPair a b) + Split _ = '[] + +type family SplitRec t where + SplitRec TNil = '[] + SplitRec (TPair a b) = Append (SplitRec b) (SplitRec a) + SplitRec t = '[t] + +data Pointers env t where + Point :: STy t -> Idx env t -> Pointers env t + PNil :: Pointers env TNil + PPair :: Pointers env a -> Pointers env b -> Pointers env (TPair a b) + PWeak :: env' :> env -> Pointers env' t -> Pointers env t + +subPointers :: Pointers env t -> env :> env' -> Ex env' t +subPointers (Point t i) w = EVar ext t (w @> i) +subPointers PNil _ = ENil ext +subPointers (PPair a b) w = EPair ext (subPointers a w) (subPointers b w) +subPointers (PWeak w' p) w = subPointers p (w .> w') + +split :: forall env t. STy t + -> (Pointers (Append (Split t) (t : env)) t, Bindings Ex (t : env) (Split t)) +split typ = case typ of + STPair{} -> splitRec (EVar ext typ IZ) typ + STNil -> other + STEither{} -> other + STLEither{} -> other + STMaybe{} -> other + STArr{} -> other + STScal{} -> other + STAccum{} -> other + where + other :: (Pointers (t : env) t, Bindings Ex (t : env) '[]) + other = (Point typ IZ, BTop) + +splitRec :: forall env t. Ex env t -> STy t + -> (Pointers (Append (SplitRec t) env) t, Bindings Ex env (SplitRec t)) +splitRec rhs typ = case typ of + STNil -> (PNil, BTop) + STPair (a :: STy a) (b :: STy b) + | Refl <- lemAppendAssoc @(SplitRec b) @(SplitRec a) @env -> + let (p1, bs1) = splitRec (EFst ext rhs) a + (p2, bs2) = splitRec (ESnd ext (sinkWithBindings bs1 `weakenExpr` rhs)) b + in (PPair (PWeak (sinkWithBindings bs2) p1) p2, bconcat bs1 bs2) + STEither{} -> other + STLEither{} -> other + STMaybe{} -> other + STArr{} -> other + STScal{} -> other + STAccum{} -> other + where + other :: (Pointers (t : env) t, Bindings Ex env '[t]) + other = (Point typ IZ, BPush BTop (typ, rhs)) diff --git a/src/CHAD/AST/Types.hs b/src/CHAD/AST/Types.hs new file mode 100644 index 0000000..059077d --- /dev/null +++ b/src/CHAD/AST/Types.hs @@ -0,0 +1,215 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeData #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.AST.Types where + +import Data.Int (Int32, Int64) +import Data.GADT.Compare +import Data.GADT.Show +import Data.Kind (Type) +import Data.Type.Equality + +import CHAD.Data + + +type data Ty + = TNil + | TPair Ty Ty + | TEither Ty Ty + | TLEither Ty Ty + | TMaybe Ty + | TArr Nat Ty -- ^ rank, element type + | TScal ScalTy + | TAccum Ty -- ^ contained type must be a monoid type + +type data ScalTy = TI32 | TI64 | TF32 | TF64 | TBool + +type STy :: Ty -> Type +data STy t where + STNil :: STy TNil + STPair :: STy a -> STy b -> STy (TPair a b) + STEither :: STy a -> STy b -> STy (TEither a b) + STLEither :: STy a -> STy b -> STy (TLEither a b) + STMaybe :: STy a -> STy (TMaybe a) + STArr :: SNat n -> STy t -> STy (TArr n t) + STScal :: SScalTy t -> STy (TScal t) + STAccum :: SMTy t -> STy (TAccum t) +deriving instance Show (STy t) + +instance GCompare STy where + gcompare = \cases + STNil STNil -> GEQ + STNil _ -> GLT ; _ STNil -> GGT + (STPair a b) (STPair a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') + STPair{} _ -> GLT ; _ STPair{} -> GGT + (STEither a b) (STEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') + STEither{} _ -> GLT ; _ STEither{} -> GGT + (STLEither a b) (STLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') + STLEither{} _ -> GLT ; _ STLEither{} -> GGT + (STMaybe a) (STMaybe a') -> gorderingLift1 (gcompare a a') + STMaybe{} _ -> GLT ; _ STMaybe{} -> GGT + (STArr n t) (STArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t') + STArr{} _ -> GLT ; _ STArr{} -> GGT + (STScal t) (STScal t') -> gorderingLift1 (gcompare t t') + STScal{} _ -> GLT ; _ STScal{} -> GGT + (STAccum t) (STAccum t') -> gorderingLift1 (gcompare t t') + -- STAccum{} _ -> GLT ; _ STAccum{} -> GGT + +instance TestEquality STy where testEquality = geq +instance GEq STy where geq = defaultGeq +instance GShow STy where gshowsPrec = defaultGshowsPrec + +-- | Monoid types +type SMTy :: Ty -> Type +data SMTy t where + SMTNil :: SMTy TNil + SMTPair :: SMTy a -> SMTy b -> SMTy (TPair a b) + SMTLEither :: SMTy a -> SMTy b -> SMTy (TLEither a b) + SMTMaybe :: SMTy a -> SMTy (TMaybe a) + SMTArr :: SNat n -> SMTy t -> SMTy (TArr n t) + SMTScal :: ScalIsNumeric t ~ True => SScalTy t -> SMTy (TScal t) +deriving instance Show (SMTy t) + +instance GCompare SMTy where + gcompare = \cases + SMTNil SMTNil -> GEQ + SMTNil _ -> GLT ; _ SMTNil -> GGT + (SMTPair a b) (SMTPair a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') + SMTPair{} _ -> GLT ; _ SMTPair{} -> GGT + (SMTLEither a b) (SMTLEither a' b') -> gorderingLift2 (gcompare a a') (gcompare b b') + SMTLEither{} _ -> GLT ; _ SMTLEither{} -> GGT + (SMTMaybe a) (SMTMaybe a') -> gorderingLift1 (gcompare a a') + SMTMaybe{} _ -> GLT ; _ SMTMaybe{} -> GGT + (SMTArr n t) (SMTArr n' t') -> gorderingLift2 (gcompare n n') (gcompare t t') + SMTArr{} _ -> GLT ; _ SMTArr{} -> GGT + (SMTScal t) (SMTScal t') -> gorderingLift1 (gcompare t t') + -- SMTScal{} _ -> GLT ; _ SMTScal{} -> GGT + +instance TestEquality SMTy where testEquality = geq +instance GEq SMTy where geq = defaultGeq +instance GShow SMTy where gshowsPrec = defaultGshowsPrec + +fromSMTy :: SMTy t -> STy t +fromSMTy = \case + SMTNil -> STNil + SMTPair t1 t2 -> STPair (fromSMTy t1) (fromSMTy t2) + SMTLEither t1 t2 -> STLEither (fromSMTy t1) (fromSMTy t2) + SMTMaybe t -> STMaybe (fromSMTy t) + SMTArr n t -> STArr n (fromSMTy t) + SMTScal sty -> STScal sty + +data SScalTy t where + STI32 :: SScalTy TI32 + STI64 :: SScalTy TI64 + STF32 :: SScalTy TF32 + STF64 :: SScalTy TF64 + STBool :: SScalTy TBool +deriving instance Show (SScalTy t) + +instance GCompare SScalTy where + gcompare = \cases + STI32 STI32 -> GEQ + STI32 _ -> GLT ; _ STI32 -> GGT + STI64 STI64 -> GEQ + STI64 _ -> GLT ; _ STI64 -> GGT + STF32 STF32 -> GEQ + STF32 _ -> GLT ; _ STF32 -> GGT + STF64 STF64 -> GEQ + STF64 _ -> GLT ; _ STF64 -> GGT + STBool STBool -> GEQ + -- STBool _ -> GLT ; _ STBool -> GGT + +instance TestEquality SScalTy where testEquality = geq +instance GEq SScalTy where geq = defaultGeq +instance GShow SScalTy where gshowsPrec = defaultGshowsPrec + +scalRepIsShow :: SScalTy t -> Dict (Show (ScalRep t)) +scalRepIsShow STI32 = Dict +scalRepIsShow STI64 = Dict +scalRepIsShow STF32 = Dict +scalRepIsShow STF64 = Dict +scalRepIsShow STBool = Dict + +type TIx = TScal TI64 + +tIx :: STy TIx +tIx = STScal STI64 + +type family ScalRep t where + ScalRep TI32 = Int32 + ScalRep TI64 = Int64 + ScalRep TF32 = Float + ScalRep TF64 = Double + ScalRep TBool = Bool + +type family ScalIsNumeric t where + ScalIsNumeric TI32 = True + ScalIsNumeric TI64 = True + ScalIsNumeric TF32 = True + ScalIsNumeric TF64 = True + ScalIsNumeric TBool = False + +type family ScalIsFloating t where + ScalIsFloating TI32 = False + ScalIsFloating TI64 = False + ScalIsFloating TF32 = True + ScalIsFloating TF64 = True + ScalIsFloating TBool = False + +type family ScalIsIntegral t where + ScalIsIntegral TI32 = True + ScalIsIntegral TI64 = True + ScalIsIntegral TF32 = False + ScalIsIntegral TF64 = False + ScalIsIntegral TBool = False + +-- | Returns true for arrays /and/ accumulators. +typeHasArrays :: STy t' -> Bool +typeHasArrays STNil = False +typeHasArrays (STPair a b) = typeHasArrays a || typeHasArrays b +typeHasArrays (STEither a b) = typeHasArrays a || typeHasArrays b +typeHasArrays (STLEither a b) = typeHasArrays a || typeHasArrays b +typeHasArrays (STMaybe t) = typeHasArrays t +typeHasArrays STArr{} = True +typeHasArrays STScal{} = False +typeHasArrays STAccum{} = True + +typeHasAccums :: STy t' -> Bool +typeHasAccums STNil = False +typeHasAccums (STPair a b) = typeHasAccums a || typeHasAccums b +typeHasAccums (STEither a b) = typeHasAccums a || typeHasAccums b +typeHasAccums (STLEither a b) = typeHasAccums a || typeHasAccums b +typeHasAccums (STMaybe t) = typeHasAccums t +typeHasAccums STArr{} = False +typeHasAccums STScal{} = False +typeHasAccums STAccum{} = True + +type family Tup env where + Tup '[] = TNil + Tup (t : ts) = TPair (Tup ts) t + +mkTup :: f TNil -> (forall a b. f a -> f b -> f (TPair a b)) + -> SList f list -> f (Tup list) +mkTup nil _ SNil = nil +mkTup nil pair (e `SCons` es) = pair (mkTup nil pair es) e + +tTup :: SList STy env -> STy (Tup env) +tTup = mkTup STNil STPair + +unTup :: (forall a b. c (TPair a b) -> (c a, c b)) + -> SList f list -> c (Tup list) -> SList c list +unTup _ SNil _ = SNil +unTup unpack (_ `SCons` list) tup = + let (xs, x) = unpack tup + in x `SCons` unTup unpack list xs + +type family InvTup core env where + InvTup core '[] = core + InvTup core (t : ts) = InvTup (TPair core t) ts diff --git a/src/CHAD/AST/UnMonoid.hs b/src/CHAD/AST/UnMonoid.hs new file mode 100644 index 0000000..27c5f0a --- /dev/null +++ b/src/CHAD/AST/UnMonoid.hs @@ -0,0 +1,255 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.AST.UnMonoid (unMonoid, zero, plus, acPrjCompose) where + +import CHAD.AST +import CHAD.AST.Sparse.Types +import CHAD.Data + + +-- | Remove 'EZero', 'EDeepZero', 'EPlus' and 'EOneHot' from the program by +-- expanding them into their concrete implementations. Also ensure that +-- 'EAccum' has a dense sparsity. +unMonoid :: Ex env t -> Ex env t +unMonoid = \case + EZero _ t e -> zero t e + EDeepZero _ t e -> deepZero t e + EPlus _ t a b -> plus t (unMonoid a) (unMonoid b) + EOneHot _ t p a b -> onehot t p (unMonoid a) (unMonoid b) + + EVar _ t i -> EVar ext t i + ELet _ rhs body -> ELet ext (unMonoid rhs) (unMonoid body) + EPair _ a b -> EPair ext (unMonoid a) (unMonoid b) + EFst _ e -> EFst ext (unMonoid e) + ESnd _ e -> ESnd ext (unMonoid e) + ENil _ -> ENil ext + EInl _ t e -> EInl ext t (unMonoid e) + EInr _ t e -> EInr ext t (unMonoid e) + ECase _ e a b -> ECase ext (unMonoid e) (unMonoid a) (unMonoid b) + ENothing _ t -> ENothing ext t + EJust _ e -> EJust ext (unMonoid e) + EMaybe _ a b e -> EMaybe ext (unMonoid a) (unMonoid b) (unMonoid e) + ELNil _ t1 t2 -> ELNil ext t1 t2 + ELInl _ t e -> ELInl ext t (unMonoid e) + ELInr _ t e -> ELInr ext t (unMonoid e) + ELCase _ e a b c -> ELCase ext (unMonoid e) (unMonoid a) (unMonoid b) (unMonoid c) + EConstArr _ n t x -> EConstArr ext n t x + EBuild _ n a b -> EBuild ext n (unMonoid a) (unMonoid b) + EMap _ a b -> EMap ext (unMonoid a) (unMonoid b) + EFold1Inner _ cm a b c -> EFold1Inner ext cm (unMonoid a) (unMonoid b) (unMonoid c) + ESum1Inner _ e -> ESum1Inner ext (unMonoid e) + EUnit _ e -> EUnit ext (unMonoid e) + EReplicate1Inner _ a b -> EReplicate1Inner ext (unMonoid a) (unMonoid b) + EMaximum1Inner _ e -> EMaximum1Inner ext (unMonoid e) + EMinimum1Inner _ e -> EMinimum1Inner ext (unMonoid e) + EReshape _ n a b -> EReshape ext n (unMonoid a) (unMonoid b) + EZip _ a b -> EZip ext (unMonoid a) (unMonoid b) + EFold1InnerD1 _ cm a b c -> EFold1InnerD1 ext cm (unMonoid a) (unMonoid b) (unMonoid c) + EFold1InnerD2 _ cm a b c -> EFold1InnerD2 ext cm (unMonoid a) (unMonoid b) (unMonoid c) + EConst _ t x -> EConst ext t x + EIdx0 _ e -> EIdx0 ext (unMonoid e) + EIdx1 _ a b -> EIdx1 ext (unMonoid a) (unMonoid b) + EIdx _ a b -> EIdx ext (unMonoid a) (unMonoid b) + EShape _ e -> EShape ext (unMonoid e) + EOp _ op e -> EOp ext op (unMonoid e) + ECustom _ t1 t2 t3 a b c e1 e2 -> ECustom ext t1 t2 t3 (unMonoid a) (unMonoid b) (unMonoid c) (unMonoid e1) (unMonoid e2) + ERecompute _ e -> ERecompute ext (unMonoid e) + EWith _ t a b -> EWith ext t (unMonoid a) (unMonoid b) + EAccum _ t p eidx sp eval eacc -> + accumulateSparse (acPrjTy p t) sp eval $ \w prj2 idx2 val2 -> + acPrjCompose SAID p (weakenExpr w eidx) prj2 idx2 $ \prj' idx' -> + EAccum ext t prj' (unMonoid idx') (spDense (acPrjTy prj' t)) (unMonoid val2) (weakenExpr w (unMonoid eacc)) + EError _ t s -> EError ext t s + +zero :: SMTy t -> Ex env (ZeroInfo t) -> Ex env t +-- don't destroy the effects! +zero SMTNil e = ELet ext e $ ENil ext +zero (SMTPair t1 t2) e = + ELet ext e $ EPair ext (zero t1 (EFst ext (EVar ext (typeOf e) IZ))) + (zero t2 (ESnd ext (EVar ext (typeOf e) IZ))) +zero (SMTLEither t1 t2) _ = ELNil ext (fromSMTy t1) (fromSMTy t2) +zero (SMTMaybe t) _ = ENothing ext (fromSMTy t) +zero (SMTArr _ t) e = emap (zero t (EVar ext (tZeroInfo t) IZ)) e +zero (SMTScal t) _ = case t of + STI32 -> EConst ext STI32 0 + STI64 -> EConst ext STI64 0 + STF32 -> EConst ext STF32 0.0 + STF64 -> EConst ext STF64 0.0 + +deepZero :: SMTy t -> Ex env (DeepZeroInfo t) -> Ex env t +deepZero SMTNil e = elet e $ ENil ext +deepZero (SMTPair t1 t2) e = + ELet ext e $ EPair ext (deepZero t1 (EFst ext (EVar ext (typeOf e) IZ))) + (deepZero t2 (ESnd ext (EVar ext (typeOf e) IZ))) +deepZero (SMTLEither t1 t2) e = + elcase e + (ELNil ext (fromSMTy t1) (fromSMTy t2)) + (ELInl ext (fromSMTy t2) (deepZero t1 (evar IZ))) + (ELInr ext (fromSMTy t1) (deepZero t2 (evar IZ))) +deepZero (SMTMaybe t) e = + emaybe e + (ENothing ext (fromSMTy t)) + (EJust ext (deepZero t (evar IZ))) +deepZero (SMTArr _ t) e = emap (deepZero t (evar IZ)) e +deepZero (SMTScal t) _ = case t of + STI32 -> EConst ext STI32 0 + STI64 -> EConst ext STI64 0 + STF32 -> EConst ext STF32 0.0 + STF64 -> EConst ext STF64 0.0 + +plus :: SMTy t -> Ex env t -> Ex env t -> Ex env t +-- don't destroy the effects! +plus SMTNil a b = ELet ext a $ ELet ext (weakenExpr WSink b) $ ENil ext +plus (SMTPair t1 t2) a b = + let t = STPair (fromSMTy t1) (fromSMTy t2) + in ELet ext a $ + ELet ext (weakenExpr WSink b) $ + EPair ext (plus t1 (EFst ext (EVar ext t (IS IZ))) + (EFst ext (EVar ext t IZ))) + (plus t2 (ESnd ext (EVar ext t (IS IZ))) + (ESnd ext (EVar ext t IZ))) +plus (SMTLEither t1 t2) a b = + let t = STLEither (fromSMTy t1) (fromSMTy t2) + in ELet ext a $ + ELet ext (weakenExpr WSink b) $ + ELCase ext (EVar ext t (IS IZ)) + (EVar ext t IZ) + (ELCase ext (EVar ext t (IS IZ)) + (EVar ext t (IS (IS IZ))) + (ELInl ext (fromSMTy t2) (plus t1 (EVar ext (fromSMTy t1) (IS IZ)) (EVar ext (fromSMTy t1) IZ))) + (EError ext t "plus l+r")) + (ELCase ext (EVar ext t (IS IZ)) + (EVar ext t (IS (IS IZ))) + (EError ext t "plus r+l") + (ELInr ext (fromSMTy t1) (plus t2 (EVar ext (fromSMTy t2) (IS IZ)) (EVar ext (fromSMTy t2) IZ)))) +plus (SMTMaybe t) a b = + ELet ext b $ + EMaybe ext + (EVar ext (STMaybe (fromSMTy t)) IZ) + (EJust ext + (EMaybe ext + (EVar ext (fromSMTy t) IZ) + (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ)) + (EVar ext (STMaybe (fromSMTy t)) (IS IZ)))) + (weakenExpr WSink a) +plus (SMTArr _ t) a b = + ezipWith (plus t (EVar ext (fromSMTy t) (IS IZ)) (EVar ext (fromSMTy t) IZ)) + a b +plus (SMTScal t) a b = EOp ext (OAdd t) (EPair ext a b) + +onehot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env t +onehot typ topprj idx arg = case (typ, topprj) of + (_, SAPHere) -> + ELet ext arg $ + EVar ext (fromSMTy typ) IZ + + (SMTPair t1 t2, SAPFst prj) -> + ELet ext idx $ + let tidx = typeOf idx in + ELet ext (onehot t1 prj (EFst ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $ + let toh = fromSMTy t1 in + EPair ext (EVar ext toh IZ) + (zero t2 (ESnd ext (EVar ext tidx (IS IZ)))) + + (SMTPair t1 t2, SAPSnd prj) -> + ELet ext idx $ + let tidx = typeOf idx in + ELet ext (onehot t2 prj (ESnd ext (EVar ext (typeOf idx) IZ)) (weakenExpr WSink arg)) $ + let toh = fromSMTy t2 in + EPair ext (zero t1 (EFst ext (EVar ext tidx (IS IZ)))) + (EVar ext toh IZ) + + (SMTLEither t1 t2, SAPLeft prj) -> + ELInl ext (fromSMTy t2) (onehot t1 prj idx arg) + (SMTLEither t1 t2, SAPRight prj) -> + ELInr ext (fromSMTy t1) (onehot t2 prj idx arg) + + (SMTMaybe t1, SAPJust prj) -> + EJust ext (onehot t1 prj idx arg) + + (SMTArr n t1, SAPArrIdx prj) -> + let tidx = tTup (sreplicate n tIx) + in ELet ext idx $ + EBuild ext n (EShape ext (ESnd ext (EFst ext (EVar ext (typeOf idx) IZ)))) $ + eif (eidxEq n (EVar ext tidx IZ) (EFst ext (EFst ext (EVar ext (typeOf idx) (IS IZ))))) + (onehot t1 prj (ESnd ext (EVar ext (typeOf idx) (IS IZ))) (weakenExpr (WSink .> WSink) arg)) + (ELet ext (EIdx ext (ESnd ext (EFst ext (EVar ext (typeOf idx) (IS IZ)))) (EVar ext tidx IZ)) $ + zero t1 (EVar ext (tZeroInfo t1) IZ)) + +accumulateSparse + :: SMTy t -> Sparse t t' -> Ex env t' + -> (forall p b env'. env :> env' -> SAcPrj p t b -> Ex env' (AcIdxD p t) -> Ex env' b -> Ex env' TNil) + -> Ex env TNil +accumulateSparse topty topsp arg accum = case (topty, topsp) of + (_, s) | Just Refl <- isDense topty s -> + accum WId SAPHere (ENil ext) arg + (SMTScal _, SpScal) -> + accum WId SAPHere (ENil ext) arg -- should be handled by isDense already, but meh + (_, SpSparse s) -> + emaybe arg + (ENil ext) + (accumulateSparse topty s (evar IZ) (\w -> accum (WPop w))) + (_, SpAbsent) -> + ENil ext + (SMTPair t1 t2, SpPair s1 s2) -> + eunPair arg $ \w1 e1 e2 -> + elet (accumulateSparse t1 s1 e1 (\w prj -> accum (w .> w1) (SAPFst prj))) $ + accumulateSparse t2 s2 (weakenExpr WSink e2) (\w prj -> accum (w .> WSink .> w1) (SAPSnd prj)) + (SMTLEither t1 t2, SpLEither s1 s2) -> + elcase arg + (ENil ext) + (accumulateSparse t1 s1 (evar IZ) (\w prj -> accum (WPop w) (SAPLeft prj))) + (accumulateSparse t2 s2 (evar IZ) (\w prj -> accum (WPop w) (SAPRight prj))) + (SMTMaybe t, SpMaybe s) -> + emaybe arg + (ENil ext) + (accumulateSparse t s (evar IZ) (\w prj -> accum (WPop w) (SAPJust prj))) + (SMTArr n t, SpArr s) -> + let tn = tTup (sreplicate n tIx) in + elet arg $ + elet (EBuild ext n (EShape ext (evar IZ)) $ + accumulateSparse t s + (EIdx ext (evar (IS IZ)) (EVar ext tn IZ)) + (\w prj idx val -> accum (WPop (WPop w)) (SAPArrIdx prj) (EPair ext (EVar ext tn (w @> IZ)) idx) val)) $ + ENil ext + +acPrjCompose + :: SAIDense dense + -> SAcPrj p1 a b -> Ex env (AcIdx dense p1 a) + -> SAcPrj p2 b c -> Ex env (AcIdx dense p2 b) + -> (forall p'. SAcPrj p' a c -> Ex env (AcIdx dense p' a) -> r) -> r +acPrjCompose _ SAPHere _ p2 idx2 k = k p2 idx2 +acPrjCompose SAID (SAPFst p1) idx1 p2 idx2 k = + acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPFst p') idx' +acPrjCompose SAID (SAPSnd p1) idx1 p2 idx2 k = + acPrjCompose SAID p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPSnd p') idx' +acPrjCompose SAIS (SAPFst p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAIS p1 (efst (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPFst p') (elet idx1 $ EPair ext idx' (esnd (evar IZ))) +acPrjCompose SAIS (SAPSnd p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPSnd p') (elet idx1 $ EPair ext (efst (evar IZ)) idx') +acPrjCompose d (SAPLeft p1) idx1 p2 idx2 k = + acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPLeft p') idx' +acPrjCompose d (SAPRight p1) idx1 p2 idx2 k = + acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPRight p') idx' +acPrjCompose d (SAPJust p1) idx1 p2 idx2 k = + acPrjCompose d p1 idx1 p2 idx2 $ \p' idx' -> + k (SAPJust p') idx' +acPrjCompose SAID (SAPArrIdx p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAID p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx') +acPrjCompose SAIS (SAPArrIdx p1) idx1 p2 idx2 k + | Dict <- styKnown (typeOf idx1) = + acPrjCompose SAIS p1 (esnd (evar IZ)) p2 (weakenExpr WSink idx2) $ \p' idx' -> + k (SAPArrIdx p') (elet idx1 $ EPair ext (EFst ext (EVar ext (typeOf idx1) IZ)) idx') diff --git a/src/CHAD/AST/Weaken.hs b/src/CHAD/AST/Weaken.hs new file mode 100644 index 0000000..ac0d152 --- /dev/null +++ b/src/CHAD/AST/Weaken.hs @@ -0,0 +1,138 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} + +{-# LANGUAGE PartialTypeSignatures #-} +{-# OPTIONS -Wno-partial-type-signatures #-} + +-- The reason why this is a separate module with "little" in it: +{-# LANGUAGE AllowAmbiguousTypes #-} + +module CHAD.AST.Weaken (module CHAD.AST.Weaken, Append) where + +import Data.Bifunctor (first) +import Data.Functor.Const +import Data.GADT.Compare +import Data.Kind (Type) + +import CHAD.Data +import CHAD.Lemmas + + +type Idx :: [k] -> k -> Type +data Idx env t where + IZ :: Idx (t : env) t + IS :: Idx env t -> Idx (a : env) t +deriving instance Show (Idx env t) + +instance GEq (Idx env) where + geq IZ IZ = Just Refl + geq (IS i) (IS j) | Just Refl <- geq i j = Just Refl + geq _ _ = Nothing + +splitIdx :: forall env2 env1 t f. SList f env1 -> Idx (Append env1 env2) t -> Either (Idx env1 t) (Idx env2 t) +splitIdx SNil i = Right i +splitIdx (SCons _ _) IZ = Left IZ +splitIdx (SCons _ l) (IS i) = first IS (splitIdx l i) + +slistIdx :: SList f list -> Idx list t -> f t +slistIdx (SCons x _) IZ = x +slistIdx (SCons _ list) (IS i) = slistIdx list i +slistIdx SNil i = case i of {} + +idx2int :: Idx env t -> Int +idx2int IZ = 0 +idx2int (IS n) = 1 + idx2int n + +data env :> env' where + WId :: env :> env + WSink :: forall t env. env :> (t : env) + WCopy :: forall t env env'. env :> env' -> (t : env) :> (t : env') + WPop :: (t : env) :> env' -> env :> env' + WThen :: env1 :> env2 -> env2 :> env3 -> env1 :> env3 + WClosed :: '[] :> env + WIdx :: Idx env t -> (t : env) :> env + WPick :: forall t pre env env'. SList (Const ()) pre -> env :> env' + -> Append pre (t : env) :> t : Append pre env' + WSwap :: forall env as bs. SList (Const ()) as -> SList (Const ()) bs + -> Append as (Append bs env) :> Append bs (Append as env) + WStack :: forall env1 env2 as bs. SList (Const ()) as -> SList (Const ()) bs + -> as :> bs -> env1 :> env2 + -> Append as env1 :> Append bs env2 +deriving instance Show (env :> env') +infix 4 :> + +infixr 2 @> +(@>) :: env :> env' -> Idx env t -> Idx env' t +WId @> i = i +WSink @> i = IS i +WCopy _ @> IZ = IZ +WCopy w @> IS i = IS (w @> i) +WPop w @> i = w @> IS i +WThen w1 w2 @> i = w2 @> w1 @> i +WClosed @> i = case i of {} +WIdx j @> IZ = j +WIdx _ @> IS i = i +WPick SNil w @> i = WCopy w @> i +WPick (_ `SCons` _) _ @> IZ = IS IZ +WPick @t (_ `SCons` pre) w @> IS i = WCopy WSink .> WPick @t pre w @> i +WSwap @env (as :: SList _ as) (bs :: SList _ bs) @> i = + case splitIdx @(Append bs env) as i of + Left i' -> indexSinks bs (indexRaiseAbove @env as i') + Right i' -> case splitIdx @env bs i' of + Left j -> indexRaiseAbove @(Append as env) bs j + Right j -> indexSinks bs (indexSinks as j) +WStack @env1 @env2 as bs wlo whi @> i = + case splitIdx @env1 as i of + Left i' -> indexRaiseAbove @env2 bs (wlo @> i') + Right i' -> indexSinks bs (whi @> i') + +indexSinks :: SList f as -> Idx bs t -> Idx (Append as bs) t +indexSinks SNil j = j +indexSinks (_ `SCons` bs') j = IS (indexSinks bs' j) + +indexRaiseAbove :: forall env as t f. SList f as -> Idx as t -> Idx (Append as env) t +indexRaiseAbove = flip go + where + go :: forall as'. Idx as' t -> SList f as' -> Idx (Append as' env) t + go IZ (_ `SCons` _) = IZ + go (IS i) (_ `SCons` as) = IS (go i as) + +infixr 3 .> +(.>) :: env2 :> env3 -> env1 :> env2 -> env1 :> env3 +(.>) = flip WThen + +class KnownListSpine list where knownListSpine :: SList (Const ()) list +instance KnownListSpine '[] where knownListSpine = SNil +instance KnownListSpine list => KnownListSpine (t : list) where knownListSpine = SCons (Const ()) knownListSpine + +wSinks' :: forall list env. KnownListSpine list => env :> Append list env +wSinks' = wSinks (knownListSpine :: SList (Const ()) list) + +wSinks :: forall env bs f. SList f bs -> env :> Append bs env +wSinks SNil = WId +wSinks (SCons _ spine) = WSink .> wSinks spine + +wSinksAnd :: forall env env' bs f. SList f bs -> env :> env' -> env :> Append bs env' +wSinksAnd SNil w = w +wSinksAnd (SCons _ spine) w = WSink .> wSinksAnd spine w + +wCopies :: SList f bs -> env1 :> env2 -> Append bs env1 :> Append bs env2 +wCopies bs w = + let bs' = slistMap (\_ -> Const ()) bs + in WStack bs' bs' WId w + +wRaiseAbove :: SList f env1 -> proxy env -> env1 :> Append env1 env +wRaiseAbove SNil _ = WClosed +wRaiseAbove (SCons _ s) env = WCopy (wRaiseAbove s env) + +wPops :: SList f bs -> Append bs env1 :> env2 -> env1 :> env2 +wPops SNil w = w +wPops (_ `SCons` bs) w = wPops bs (WPop w) diff --git a/src/CHAD/AST/Weaken/Auto.hs b/src/CHAD/AST/Weaken/Auto.hs new file mode 100644 index 0000000..14d8c59 --- /dev/null +++ b/src/CHAD/AST/Weaken/Auto.hs @@ -0,0 +1,192 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} + +{-# LANGUAGE AllowAmbiguousTypes #-} + +{-# LANGUAGE PartialTypeSignatures #-} +{-# OPTIONS_GHC -Wno-partial-type-signatures #-} +module CHAD.AST.Weaken.Auto ( + autoWeak, + (&.), auto, auto1, + Layout(..), +) where + +import Data.Functor.Const +import Data.Kind (Constraint) +import GHC.OverloadedLabels +import GHC.TypeLits +import Unsafe.Coerce (unsafeCoerce) + +import CHAD.AST.Weaken +import CHAD.Data +import CHAD.Lemmas + + +type family Lookup name list where + Lookup name ('(name, x) : _) = x + Lookup name (_ : list) = Lookup name list + Lookup name '[] = TypeError (Text "The name '" :<>: Text name :<>: Text "' does not appear in the list.") + + +-- | The @withPre@ type parameter indicates whether there can be 'LPreW' +-- occurrences within this layout. 'names' is the list of names that this +-- layout /produces/. That is: for LPreW, it contains the target name. The +-- 'names' list of a source layout must be a subset of the names list of the +-- target layout (which cannot contain LPreW); this is checked with SubLayout. +data Layout (withPre :: Bool) (segments :: [(Symbol, [t])]) (names :: [Symbol]) (env :: [t]) where + LSeg :: forall name segments withPre. SSymbol name -> Layout withPre segments '[name] (Lookup name segments) + -- | Pre-weaken with a weakening + LPreW :: forall name1 name2 segments. + SegmentName name1 -> SegmentName name2 + -> Lookup name1 segments :> Lookup name2 segments + -> Layout True segments '[name2] (Lookup name1 segments) + (:++:) :: Layout withPre segments names1 env1 -> Layout withPre segments names2 env2 -> Layout withPre segments (Append names1 names2) (Append env1 env2) +infixr :++: + +instance (KnownSymbol name, seg ~ Lookup name segments, names ~ '[name]) => IsLabel name (Layout withPre segments names seg) where + fromLabel = LSeg (symbolSing @name) + +newtype SegmentName name = SegmentName (SSymbol name) + deriving (Show) + +instance (KnownSymbol name, name ~ name') => IsLabel name (SegmentName name') where + fromLabel = SegmentName symbolSing + + +type family SubLayout names1 names2 where + SubLayout '[] _ = () :: Constraint + SubLayout (n : names1) names2 = SubLayout' n (Contains n names2) names1 names2 +type family SubLayout' n ok names1 names2 where + SubLayout' n False _ _ = TypeError (Text "The name '" :<>: Text n :<>: Text "' appears in the source layout but not in the target.") + SubLayout' _ True names1 names2 = SubLayout names1 names2 +type family Contains n names where + Contains _ '[] = False + Contains n (n : _) = True + Contains n (_ : names) = Contains n names + + +data SSegments (segments :: [(Symbol, [t])]) where + SSegNil :: SSegments '[] + SSegCons :: SSymbol name -> SList (Const ()) ts -> SSegments list -> SSegments ('(name, ts) : list) + +instance (KnownSymbol name, segs ~ '[ '(name, ts)]) => IsLabel name (SList f ts -> SSegments segs) where + fromLabel = \spine -> SSegCons symbolSing (slistMap (\_ -> Const ()) spine) SSegNil + +auto :: KnownListSpine list => SList (Const ()) list +auto = knownListSpine + +auto1 :: SList (Const ()) '[t] +auto1 = Const () `SCons` SNil + +infixr &. +(&.) :: SSegments '[segs1] -> SSegments segs2 -> SSegments (segs1 : segs2) +(&.) = ssegmentsAppend + where + ssegmentsAppend :: SSegments a -> SSegments b -> SSegments (Append a b) + ssegmentsAppend SSegNil l2 = l2 + ssegmentsAppend (SSegCons name list l1) l2 = SSegCons name list (ssegmentsAppend l1 l2) + + +-- | If the found segment is a TopSeg, returns Nothing. +segmentLookup :: forall segments name. SSegments segments -> SSymbol name -> SList (Const ()) (Lookup name segments) +segmentLookup = \segs name -> case go segs name of + Just ts -> ts + Nothing -> error $ "Segment not found: " ++ fromSSymbol name + where + go :: forall segs'. SSegments segs' -> SSymbol name -> Maybe (SList (Const ()) (Lookup name segs')) + go SSegNil _ = Nothing + go (SSegCons n@(SSymbol @n) (ts :: SList _ ts) (sseg :: SSegments rest)) name@SSymbol = + case sameSymbol n name of + Just Refl -> + case go sseg name of + Nothing -> Just ts + Just _ -> error $ "Duplicate segment with name " ++ fromSSymbol name + Nothing -> + case unsafeCoerce Refl :: (Lookup name ('(n, ts) : rest) :~: Lookup name rest) of + Refl -> go sseg name + +data LinLayout (withPre :: Bool) (segments :: [(Symbol, [t])]) (env :: [t]) where + LinEnd :: LinLayout withPre segments '[] + LinApp :: SSymbol name -> LinLayout withPre segments env + -> LinLayout withPre segments (Append (Lookup name segments) env) + LinAppPreW :: SSymbol name1 -> SSymbol name2 + -> Lookup name1 segments :> Lookup name2 segments + -> LinLayout True segments env + -> LinLayout True segments (Append (Lookup name1 segments) env) + +linLayoutAppend :: LinLayout withPre segments env1 -> LinLayout withPre segments env2 -> LinLayout withPre segments (Append env1 env2) +linLayoutAppend LinEnd lin = lin +linLayoutAppend (LinApp (name :: SSymbol name) (lin1 :: LinLayout _ segments env1')) (lin2 :: LinLayout _ _ env2) + | Refl <- lemAppendAssoc @(Lookup name segments) @env1' @env2 + = LinApp name (linLayoutAppend lin1 lin2) +linLayoutAppend (LinAppPreW (name1 :: SSymbol name1) name2 w (lin1 :: LinLayout _ segments env1')) (lin2 :: LinLayout _ _ env2) + | Refl <- lemAppendAssoc @(Lookup name1 segments) @env1' @env2 + = LinAppPreW name1 name2 w (linLayoutAppend lin1 lin2) + +lineariseLayout :: Layout withPre segments names env -> LinLayout withPre segments env +lineariseLayout (LSeg name :: Layout _ _ _ seg) + | Refl <- lemAppendNil @seg + = LinApp name LinEnd +lineariseLayout (ly1 :++: ly2) = lineariseLayout ly1 `linLayoutAppend` lineariseLayout ly2 +lineariseLayout (LPreW (SegmentName name1) (SegmentName name2) w :: Layout _ _ _ seg) + | Refl <- lemAppendNil @seg + = LinAppPreW name1 name2 w LinEnd + +preWeaken :: SSegments segments -> LinLayout True segments env + -> (forall env'. env :> env' -> LinLayout False segments env' -> r) -> r +preWeaken _ LinEnd k = k WId LinEnd +preWeaken segs (LinApp name lin) k = + preWeaken segs lin $ \w lin' -> + k (wCopies (segmentLookup segs name) w) (LinApp name lin') +preWeaken segs (LinAppPreW name1 name2 weak lin) k = + preWeaken segs lin $ \w lin' -> + k (WStack (segmentLookup segs name1) (segmentLookup segs name2) weak w) (LinApp name2 lin') + +pullDown :: SSegments segments -> SSymbol name -> LinLayout False segments env + -> r -- Name was not found in source + -> (forall env'. LinLayout False segments env' -> env :> Append (Lookup name segments) env' -> r) + -> r +pullDown segs name@SSymbol linlayout kNotFound k = + case linlayout of + LinEnd -> kNotFound + LinApp n'@SSymbol lin + | Just Refl <- sameSymbol name n' -> k lin WId + | otherwise -> + pullDown segs name lin kNotFound $ \(lin' :: LinLayout _ _ env') w -> + k (LinApp n' lin') (WSwap @env' (segmentLookup segs n') (segmentLookup segs name) + .> wCopies (segmentLookup segs n') w) + +sortLinLayouts :: SSegments segments + -> LinLayout False segments env1 -> LinLayout False segments env2 -> env1 :> env2 +sortLinLayouts _ LinEnd LinEnd = WId +sortLinLayouts segs lin1@(LinApp name1@SSymbol tail1) (LinApp name2@SSymbol tail2) + | Just Refl <- sameSymbol name1 name2 = wCopies (segmentLookup segs name1) (sortLinLayouts segs tail1 tail2) + | otherwise = + pullDown segs name2 lin1 + (wSinks (segmentLookup segs name2) .> sortLinLayouts segs lin1 tail2) + (\tail1' w -> + -- We've pulled down name2 in lin1 so that it's at the head; the + -- resulting modified tail is tail1'. Thus now we have (name2 : tail1') + -- vs (name2 : tail2). Thus we continue sorting tail1' vs tail2, and + -- wCopies the name2 on top of that. + wCopies (segmentLookup segs name2) (sortLinLayouts segs tail1' tail2) .> w) +sortLinLayouts _ LinEnd LinApp{} = WClosed +sortLinLayouts _ LinApp{} LinEnd = error "Segments in source that do not occur in target" + +autoWeak :: SubLayout names1 names2 + => SSegments segments -> Layout True segments names1 env1 -> Layout False segments names2 env2 -> env1 :> env2 +autoWeak segs ly1 ly2 = + preWeaken segs (lineariseLayout ly1) $ \wPreweak lin1 -> + sortLinLayouts segs lin1 (lineariseLayout ly2) .> wPreweak diff --git a/src/CHAD/Accum.hs b/src/CHAD/Accum.hs deleted file mode 100644 index a7bc53f..0000000 --- a/src/CHAD/Accum.hs +++ /dev/null @@ -1,72 +0,0 @@ -{-# LANGUAGE GADTs #-} -{-# LANGUAGE TypeOperators #-} --- | TODO this module is a grab-bag of random utility functions that are shared --- between CHAD and CHAD.Top. -module CHAD.Accum where - -import AST -import CHAD.Types -import Data -import AST.Env - - -d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t)) -d2zeroInfo STNil _ = ENil ext -d2zeroInfo (STPair a b) e = - eunPair e $ \_ e1 e2 -> - EPair ext (d2zeroInfo a e1) (d2zeroInfo b e2) -d2zeroInfo STEither{} _ = ENil ext -d2zeroInfo STLEither{} _ = ENil ext -d2zeroInfo STMaybe{} _ = ENil ext -d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e -d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext -d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program" - -d2deepZeroInfo :: STy t -> Ex env (D1 t) -> Ex env (DeepZeroInfo (D2 t)) -d2deepZeroInfo STNil _ = ENil ext -d2deepZeroInfo (STPair a b) e = - eunPair e $ \_ e1 e2 -> - EPair ext (d2deepZeroInfo a e1) (d2deepZeroInfo b e2) -d2deepZeroInfo (STEither a b) e = - ECase ext e - (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ))) - (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ))) -d2deepZeroInfo (STLEither a b) e = - elcase e - (ELNil ext (tDeepZeroInfo (d2M a)) (tDeepZeroInfo (d2M b))) - (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ))) - (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ))) -d2deepZeroInfo (STMaybe a) e = - emaybe e - (ENothing ext (tDeepZeroInfo (d2M a))) - (EJust ext (d2deepZeroInfo a (EVar ext (d1 a) IZ))) -d2deepZeroInfo (STArr _ t) e = emap (d2deepZeroInfo t (EVar ext (d1 t) IZ)) e -d2deepZeroInfo (STScal t) _ | Refl <- lemDeepZeroInfoScal t = ENil ext -d2deepZeroInfo STAccum{} _ = error "accumulators not allowed in source program" - --- The weakening is necessary because we need to initialise the created --- accumulators with zeros. Those zeros are deep and need full primals. This --- means, in the end, that primals corresponding to environment entries --- promoted to an accumulator with accumPromote in CHAD need to be stored for --- the dual. -makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) -makeAccumulators _ SNil e = e -makeAccumulators w (t `SCons` envpro) e = - makeAccumulators (WPop w) envpro $ - EWith ext (d2M t) (EDeepZero ext (d2M t) (d2deepZeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e - -uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) -uninvertTup SNil _ e = EPair ext e (ENil ext) -uninvertTup (t `SCons` list) tcore e = - ELet ext (uninvertTup list (STPair tcore t) e) $ - let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding - in EPair ext - (EFst ext (EFst ext (EVar ext recT IZ))) - (EPair ext - (ESnd ext (EVar ext recT IZ)) - (ESnd ext (EFst ext (EVar ext recT IZ)))) - -subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env') -subenvD1E SETop = SETop -subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub) -subenvD1E (SENo sub) = SENo (subenvD1E sub) diff --git a/src/CHAD/Analysis/Identity.hs b/src/CHAD/Analysis/Identity.hs new file mode 100644 index 0000000..212cc7d --- /dev/null +++ b/src/CHAD/Analysis/Identity.hs @@ -0,0 +1,436 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +module CHAD.Analysis.Identity ( + identityAnalysis, + identityAnalysis', + ValId(..), + validSplitEither, +) where + +import Data.Foldable (toList) +import Data.List (intercalate) + +import CHAD.AST +import CHAD.AST.Pretty (PrettyX(..)) +import CHAD.Data +import CHAD.Drev.Types (d1, d2) +import CHAD.Util.IdGen + + +-- | Every array, scalar and accumulator has an ID. Trivial values such as +-- Nothing only have the knowledge that they are indeed Nothing. Compound +-- values know which values they consist of. +data ValId t where + VINil :: ValId TNil + VIPair :: ValId a -> ValId b -> ValId (TPair a b) + VIEither :: Either (ValId a) (ValId b) -> ValId (TEither a b) -- ^ known alternative + VIEither' :: ValId a -> ValId b -> ValId (TEither a b) -- ^ unknown alternative, but known values in each case + VILEither :: ValId (TMaybe (TEither a b)) -> ValId (TLEither a b) + VIMaybe :: Maybe (ValId a) -> ValId (TMaybe a) + VIMaybe' :: ValId a -> ValId (TMaybe a) -- ^ if it's Just, it contains this value + VIArr :: Int -> Vec n Int -> ValId (TArr n t) + VIScal :: Int -> ValId (TScal t) + VIAccum :: Int -> ValId (TAccum t) +deriving instance Show (ValId t) + +instance PrettyX ValId where + prettyX = \case + VINil -> "" + VIPair a b -> "(" ++ prettyX a ++ "," ++ prettyX b ++ ")" + VIEither (Left a) -> "(L" ++ prettyX a ++ ")" + VIEither (Right a) -> "(R" ++ prettyX a ++ ")" + VIEither' a b -> "(" ++ prettyX a ++ "|" ++ prettyX b ++ ")" + VIMaybe Nothing -> "N" + VIMaybe (Just a) -> 'J' : prettyX a + VIMaybe' a -> 'M' : prettyX a + VILEither (VIMaybe Nothing) -> "lN" + VILEither (VIMaybe (Just (VIEither (Left a)))) -> "(lL" ++ prettyX a ++ ")" + VILEither (VIMaybe (Just (VIEither (Right a)))) -> "(lR" ++ prettyX a ++ ")" + VILEither (VIMaybe (Just (VIEither' a b))) -> "(" ++ prettyX a ++ "⊕" ++ prettyX b ++ ")" + VILEither (VIMaybe' (VIEither (Left a))) -> "(mlL" ++ prettyX a ++ ")" + VILEither (VIMaybe' (VIEither (Right a))) -> "(mlR" ++ prettyX a ++ ")" + VILEither (VIMaybe' (VIEither' a b)) -> "(m(" ++ prettyX a ++ "⊕" ++ prettyX b ++ "))" + VIArr i is -> 'A' : show i ++ "[" ++ intercalate "," (map show (toList is)) ++ "]" + VIScal i -> show i + VIAccum i -> 'C' : show i + +validSplitEither :: ValId (TEither a b) -> (Maybe (ValId a), Maybe (ValId b)) +validSplitEither (VIEither (Left v)) = (Just v, Nothing) +validSplitEither (VIEither (Right v)) = (Nothing, Just v) +validSplitEither (VIEither' v1 v2) = (Just v1, Just v2) + +-- | Symbolic partial evaluation. +identityAnalysis :: SList STy env -> Expr x env t -> Expr ValId env t +identityAnalysis env term = runIdGen 0 $ do + env' <- slistMapA genIds env + snd <$> idana env' term + +identityAnalysis' :: SList ValId env -> Expr x env t -> Expr ValId env t +identityAnalysis' env term = snd (runIdGen 0 (idana env term)) + +idana :: SList ValId env -> Expr x env t -> IdGen (ValId t, Expr ValId env t) +idana env expr = case expr of + EVar _ t i -> do + let v = slistIdx env i + pure (v, EVar v t i) + + ELet _ e1 e2 -> do + (v1, e1') <- idana env e1 + (v2, e2') <- idana (v1 `SCons` env) e2 + pure (v2, ELet v2 e1' e2') + + EPair _ e1 e2 -> do + (v1, e1') <- idana env e1 + (v2, e2') <- idana env e2 + pure (VIPair v1 v2, EPair (VIPair v1 v2) e1' e2') + + EFst _ e -> do + (v, e') <- idana env e + let VIPair v1 _ = v + pure (v1, EFst v1 e') + + ESnd _ e -> do + (v, e') <- idana env e + let VIPair _ v2 = v + pure (v2, ESnd v2 e') + + ENil _ -> pure (VINil, ENil VINil) + + EInl _ t2 e1 -> do + (v1, e1') <- idana env e1 + let v = VIEither (Left v1) + pure (v, EInl v t2 e1') + + EInr _ t1 e2 -> do + (v2, e2') <- idana env e2 + let v = VIEither (Right v2) + pure (v, EInr v t1 e2') + + ECase _ e1 e2 e3 -> do + let STEither t1 t2 = typeOf e1 + (v1, e1') <- idana env e1 + case v1 of + VIEither (Left v1') -> do + (v2, e2') <- idana (v1' `SCons` env) e2 + scrap <- genIds t2 + (_, e3') <- idana (scrap `SCons` env) e3 + pure (v2, ECase v2 e1' e2' e3') + VIEither (Right v1') -> do + scrap <- genIds t1 + (_, e2') <- idana (scrap `SCons` env) e2 + (v3, e3') <- idana (v1' `SCons` env) e3 + pure (v3, ECase v3 e1' e2' e3') + VIEither' v1'l v1'r -> do + (v2, e2') <- idana (v1'l `SCons` env) e2 + (v3, e3') <- idana (v1'r `SCons` env) e3 + res <- unify v2 v3 + pure (res, ECase res e1' e2' e3') + + ENothing _ t -> pure (VIMaybe Nothing, ENothing (VIMaybe Nothing) t) + + EJust _ e1 -> do + (v1, e1') <- idana env e1 + let v = VIMaybe (Just v1) + pure (v, EJust v e1') + + EMaybe _ e1 e2 e3 -> do + let STMaybe t1 = typeOf e3 + (v3, e3') <- idana env e3 + case v3 of + VIMaybe Nothing -> do + (v1, e1') <- idana env e1 + scrap <- genIds t1 + (_, e2') <- idana (scrap `SCons` env) e2 + pure (v1, EMaybe v1 e1' e2' e3') + VIMaybe (Just v3j) -> do + (v2, e2') <- idana (v3j `SCons` env) e2 + (_, e1') <- idana env e1 + pure (v2, EMaybe v2 e1' e2' e3') + VIMaybe' v3' -> do + (v2, e2') <- idana (v3' `SCons` env) e2 + (v1, e1') <- idana env e1 + res <- unify v1 v2 + pure (res, EMaybe res e1' e2' e3') + + ELNil _ t1 t2 -> do + let v = VILEither (VIMaybe Nothing) + pure (v, ELNil v t1 t2) + + ELInl _ t2 e1 -> do + (v1, e1') <- idana env e1 + let v = VILEither (VIMaybe (Just (VIEither (Left v1)))) + pure (v, ELInl v t2 e1') + + ELInr _ t1 e2 -> do + (v2, e2') <- idana env e2 + let v = VILEither (VIMaybe (Just (VIEither (Right v2)))) + pure (v, ELInr v t1 e2') + + ELCase _ e1 e2 e3 e4 -> do + let STLEither t1 t2 = typeOf e1 + (v1L, e1') <- idana env e1 + let VILEither v1 = v1L + let go mv1'l mv1'r f = do + v1'l <- maybe (genIds t1) pure mv1'l + v1'r <- maybe (genIds t2) pure mv1'r + (v2, e2') <- idana env e2 + (v3, e3') <- idana (v1'l `SCons` env) e3 + (v4, e4') <- idana (v1'r `SCons` env) e4 + res <- f v2 v3 v4 + pure (res, ELCase res e1' e2' e3' e4') + case v1 of + VIMaybe Nothing -> go Nothing Nothing (\v2 _ _ -> pure v2) + VIMaybe (Just (VIEither (Left v1'))) -> go (Just v1') Nothing (\_ v3 _ -> pure v3) + VIMaybe (Just (VIEither (Right v1'))) -> go Nothing (Just v1') (\_ _ v4 -> pure v4) + VIMaybe (Just (VIEither' v1'l v1'r)) -> go (Just v1'l) (Just v1'r) (\_ v3 v4 -> unify v3 v4) + VIMaybe' (VIEither (Left v1')) -> go (Just v1') Nothing (\v2 v3 _ -> unify v2 v3) + VIMaybe' (VIEither (Right v1')) -> go Nothing (Just v1') (\v2 _ v4 -> unify v2 v4) + VIMaybe' (VIEither' v1'l v1'r) -> + go (Just v1'l) (Just v1'r) (\v2 v3 v4 -> unify v2 =<< unify v3 v4) + + EConstArr _ dim t arr -> do + x1 <- VIArr <$> genId <*> vecReplicateA dim genId + pure (x1, EConstArr x1 dim t arr) + + EBuild _ dim e1 e2 -> do + (shids, e1') <- idana env e1 + x1 <- genIds (tTup (sreplicate dim tIx)) + (_, e2') <- idana (x1 `SCons` env) e2 + res <- VIArr <$> genId <*> shidsToVec dim shids + pure (res, EBuild res dim e1' e2') + + EMap _ e1 e2 -> do + let STArr _ t = typeOf e2 + x1 <- genIds t + (_, e1') <- idana (x1 `SCons` env) e1 + (v2, e2') <- idana env e2 + let VIArr _ sh = v2 + res <- VIArr <$> genId <*> pure sh + pure (res, EMap res e1' e2') + + EFold1Inner _ cm e1 e2 e3 -> do + let t1 = typeOf e1 + x1 <- genIds (STPair t1 t1) + (_, e1') <- idana (x1 `SCons` env) e1 + (_, e2') <- idana env e2 + (v3, e3') <- idana env e3 + let VIArr _ (_ :< sh) = v3 + res <- VIArr <$> genId <*> pure sh + pure (res, EFold1Inner res cm e1' e2' e3') + + ESum1Inner _ e1 -> do + (v1, e1') <- idana env e1 + let VIArr _ (_ :< sh) = v1 + res <- VIArr <$> genId <*> pure sh + pure (res, ESum1Inner res e1') + + EUnit _ e1 -> do + (_, e1') <- idana env e1 + res <- VIArr <$> genId <*> pure VNil + pure (res, EUnit res e1') + + EReplicate1Inner _ e1 e2 -> do + (v1, e1') <- idana env e1 + let VIScal v1' = v1 + (v2, e2') <- idana env e2 + let VIArr _ sh = v2 + res <- VIArr <$> genId <*> pure (v1' :< sh) + pure (res, EReplicate1Inner res e1' e2') + + EMaximum1Inner _ e1 -> do + (v1, e1') <- idana env e1 + let VIArr _ (_ :< sh) = v1 + res <- VIArr <$> genId <*> pure sh + pure (res, EMaximum1Inner res e1') + + EMinimum1Inner _ e1 -> do + (v1, e1') <- idana env e1 + let VIArr _ (_ :< sh) = v1 + res <- VIArr <$> genId <*> pure sh + pure (res, EMinimum1Inner res e1') + + EReshape _ dim e1 e2 -> do + (v1, e1') <- idana env e1 + (_, e2') <- idana env e2 + res <- VIArr <$> genId <*> shidsToVec dim v1 + pure (res, EReshape res dim e1' e2') + + EZip _ e1 e2 -> do + (v1, e1') <- idana env e1 + (_, e2') <- idana env e2 + let VIArr _ sh = v1 + res <- VIArr <$> genId <*> pure sh + pure (res, EZip res e1' e2') + + EFold1InnerD1 _ cm e1 e2 e3 -> do + let t1 = typeOf e2 + x1 <- genIds (STPair t1 t1) + (_, e1') <- idana (x1 `SCons` env) e1 + (_, e2') <- idana env e2 + (v3, e3') <- idana env e3 + let VIArr _ sh'@(_ :< sh) = v3 + res <- VIPair <$> (VIArr <$> genId <*> pure sh) <*> (VIArr <$> genId <*> pure sh') + pure (res, EFold1InnerD1 res cm e1' e2' e3') + + EFold1InnerD2 _ cm ef ebog ed -> do + let STArr _ tB = typeOf ebog + STArr _ t2 = typeOf ed + xf1 <- genIds t2 + xf2 <- genIds tB + (_, e1') <- idana (xf1 `SCons` xf2 `SCons` env) ef + (v2, e2') <- idana env ebog + (_, e3') <- idana env ed + let VIArr _ sh@(_ :< sh') = v2 + res <- VIPair <$> (VIArr <$> genId <*> pure sh') <*> (VIArr <$> genId <*> pure sh) + pure (res, EFold1InnerD2 res cm e1' e2' e3') + + EConst _ t val -> do + res <- VIScal <$> genId + pure (res, EConst res t val) + + EIdx0 _ e1 -> do + (_, e1') <- idana env e1 + res <- genIds (typeOf expr) + pure (res, EIdx0 res e1') + + EIdx1 _ e1 e2 -> do + (v1, e1') <- idana env e1 + let VIArr _ sh = v1 + (_, e2') <- idana env e2 + res <- VIArr <$> genId <*> pure (vecInit sh) + pure (res, EIdx1 res e1' e2') + + EIdx _ e1 e2 -> do + (_, e1') <- idana env e1 + (_, e2') <- idana env e2 + res <- genIds (typeOf expr) + pure (res, EIdx res e1' e2') + + EShape _ e1 -> do + let STArr dim _ = typeOf e1 + (v1, e1') <- idana env e1 + let VIArr _ sh = v1 + res = vecToShids dim sh + pure (res, EShape res e1') + + EOp _ (op :: SOp a t) e1 -> do + (_, e1') <- idana env e1 + res <- genIds (typeOf expr) + pure (res, EOp res op e1') + + ECustom _ t1 t2 t3 e1 e2 e3 e4 e5 -> do + let t4 = typeOf e1 + x1 <- genIds t2 + x2 <- genIds t1 + (_, e1') <- idana (x1 `SCons` x2 `SCons` SNil) e1 + x3 <- genIds (d1 t2) + x4 <- genIds (d1 t1) + (_, e2') <- idana (x3 `SCons` x4 `SCons` SNil) e2 + x5 <- genIds (d2 t4) + x6 <- genIds t3 + (_, e3') <- idana (x5 `SCons` x6 `SCons` SNil) e3 + (_, e4') <- idana env e4 + (_, e5') <- idana env e5 + res <- genIds t4 + pure (res, ECustom res t1 t2 t3 e1' e2' e3' e4' e5') + + ERecompute _ e -> do + (v, e') <- idana env e + pure (v, ERecompute v e') + + EWith _ t e1 e2 -> do + let t1 = typeOf e1 + (_, e1') <- idana env e1 + x1 <- VIAccum <$> genId + (v2, e2') <- idana (x1 `SCons` env) e2 + x2 <- genIds t1 + let res = VIPair v2 x2 + pure (res, EWith res t e1' e2') + + EAccum _ t prj e1 sp e2 e3 -> do + (_, e1') <- idana env e1 + (_, e2') <- idana env e2 + (_, e3') <- idana env e3 + pure (VINil, EAccum VINil t prj e1' sp e2' e3') + + EZero _ t e1 -> do + -- Approximate the result of EZero to be independent from the zero info + -- expression; not quite true for shape variables + (_, e1') <- idana env e1 + res <- genIds (fromSMTy t) + pure (res, EZero res t e1') + + EDeepZero _ t e1 -> do + -- Approximate the result of EDeepZero to be independent from the zero info + -- expression; not quite true for shape variables + (_, e1') <- idana env e1 + res <- genIds (fromSMTy t) + pure (res, EDeepZero res t e1') + + EPlus _ t e1 e2 -> do + (_, e1') <- idana env e1 + (_, e2') <- idana env e2 + res <- genIds (fromSMTy t) + pure (res, EPlus res t e1' e2') + + EOneHot _ t i e1 e2 -> do + (_, e1') <- idana env e1 + (_, e2') <- idana env e2 + res <- genIds (fromSMTy t) + pure (res, EOneHot res t i e1' e2') + + EError _ t s -> do + res <- genIds t + pure (res, EError res t s) + +-- | This value might be either of the two arguments; we don't know which. +unify :: ValId t -> ValId t -> IdGen (ValId t) +unify VINil VINil = pure VINil +unify (VIPair a b) (VIPair c d) = VIPair <$> unify a c <*> unify b d +unify (VIEither (Left a)) (VIEither (Left b)) = VIEither . Left <$> unify a b +unify (VIEither (Right a)) (VIEither (Right b)) = VIEither . Right <$> unify a b +unify (VIEither (Left a)) (VIEither (Right b)) = pure $ VIEither' a b +unify (VIEither (Right a)) (VIEither (Left b)) = pure $ VIEither' b a +unify (VIEither (Left a)) (VIEither' b c) = VIEither' <$> unify a b <*> pure c +unify (VIEither (Right a)) (VIEither' b c) = VIEither' <$> pure b <*> unify a c +unify (VIEither' a b) (VIEither (Left c)) = VIEither' <$> unify a c <*> pure b +unify (VIEither' a b) (VIEither (Right c)) = VIEither' <$> pure a <*> unify b c +unify (VIEither' a b) (VIEither' c d) = VIEither' <$> unify a c <*> unify b d +unify (VIMaybe Nothing) (VIMaybe Nothing) = pure $ VIMaybe Nothing +unify (VIMaybe (Just a)) (VIMaybe (Just b)) = VIMaybe . Just <$> unify a b +unify (VIMaybe Nothing) (VIMaybe (Just a)) = pure $ VIMaybe' a +unify (VIMaybe (Just a)) (VIMaybe Nothing) = pure $ VIMaybe' a +unify (VIMaybe Nothing) (VIMaybe' a) = pure $ VIMaybe' a +unify (VIMaybe (Just a)) (VIMaybe' b) = VIMaybe' <$> unify a b +unify (VIMaybe' a) (VIMaybe Nothing) = pure $ VIMaybe' a +unify (VIMaybe' a) (VIMaybe (Just b)) = VIMaybe' <$> unify a b +unify (VIMaybe' a) (VIMaybe' b) = VIMaybe' <$> unify a b +unify (VILEither a) (VILEither b) = VILEither <$> unify a b +unify (VIArr i is) (VIArr j js) = VIArr <$> unifyID i j <*> vecZipWithA unifyID is js +unify (VIScal i) (VIScal j) = VIScal <$> unifyID i j +unify (VIAccum i) (VIAccum j) = VIAccum <$> unifyID i j + +unifyID :: Int -> Int -> IdGen Int +unifyID i j | i == j = pure i + | otherwise = genId + +genIds :: STy t -> IdGen (ValId t) +genIds STNil = pure VINil +genIds (STPair a b) = VIPair <$> genIds a <*> genIds b +genIds (STEither a b) = VIEither' <$> genIds a <*> genIds b +genIds (STLEither a b) = VILEither . VIMaybe' <$> (VIEither' <$> genIds a <*> genIds b) +genIds (STMaybe t) = VIMaybe' <$> genIds t +genIds (STArr n _) = VIArr <$> genId <*> vecReplicateA n genId +genIds STScal{} = VIScal <$> genId +genIds STAccum{} = VIAccum <$> genId + +shidsToVec :: SNat n -> ValId (Tup (Replicate n TIx)) -> IdGen (Vec n Int) +shidsToVec SZ _ = pure VNil +shidsToVec (SS n) (VIPair is (VIScal i)) = (i :<) <$> shidsToVec n is + +vecToShids :: SNat n -> Vec n Int -> ValId (Tup (Replicate n TIx)) +vecToShids SZ VNil = VINil +vecToShids (SS n) (i :< is) = VIPair (vecToShids n is) (VIScal i) diff --git a/src/CHAD/Array.hs b/src/CHAD/Array.hs new file mode 100644 index 0000000..f80f961 --- /dev/null +++ b/src/CHAD/Array.hs @@ -0,0 +1,131 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DeriveTraversable #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TupleSections #-} +module CHAD.Array where + +import Control.DeepSeq +import Control.Monad.Trans.State.Strict +import Data.Foldable (traverse_) +import Data.Vector (Vector) +import qualified Data.Vector as V +import GHC.Generics (Generic) + +import CHAD.Data + + +data Shape n where + ShNil :: Shape Z + ShCons :: Shape n -> Int -> Shape (S n) +deriving instance Show (Shape n) +deriving instance Eq (Shape n) + +instance NFData (Shape n) where + rnf ShNil = () + rnf (sh `ShCons` n) = rnf n `seq` rnf sh + +data Index n where + IxNil :: Index Z + IxCons :: Index n -> Int -> Index (S n) +deriving instance Show (Index n) +deriving instance Eq (Index n) + +instance NFData (Index n) where + rnf IxNil = () + rnf (sh `IxCons` n) = rnf n `seq` rnf sh + +shapeSize :: Shape n -> Int +shapeSize ShNil = 1 +shapeSize (ShCons sh n) = shapeSize sh * n + +shapeRank :: Shape n -> SNat n +shapeRank ShNil = SZ +shapeRank (sh `ShCons` _) = SS (shapeRank sh) + +fromLinearIndex :: Shape n -> Int -> Index n +fromLinearIndex ShNil 0 = IxNil +fromLinearIndex ShNil _ = error "Index out of range" +fromLinearIndex (sh `ShCons` n) i = + let (q, r) = i `quotRem` n + in fromLinearIndex sh q `IxCons` r + +toLinearIndex :: Shape n -> Index n -> Int +toLinearIndex ShNil IxNil = 0 +toLinearIndex (sh `ShCons` n) (idx `IxCons` i) = toLinearIndex sh idx * n + i + +emptyShape :: SNat n -> Shape n +emptyShape SZ = ShNil +emptyShape (SS m) = emptyShape m `ShCons` 0 + +enumShape :: Shape n -> [Index n] +enumShape sh = map (fromLinearIndex sh) [0 .. shapeSize sh - 1] + +shapeToList :: Shape n -> [Int] +shapeToList = go [] + where + go :: [Int] -> Shape n -> [Int] + go suff ShNil = suff + go suff (sh `ShCons` n) = go (n:suff) sh + + +-- | TODO: this Vector is a boxed vector, which is horrendously inefficient. +data Array (n :: Nat) t = Array (Shape n) (Vector t) + deriving (Show, Functor, Foldable, Traversable, Generic) +instance NFData t => NFData (Array n t) + +arrayShape :: Array n t -> Shape n +arrayShape (Array sh _) = sh + +arraySize :: Array n t -> Int +arraySize (Array sh _) = shapeSize sh + +emptyArray :: SNat n -> Array n t +emptyArray n = Array (emptyShape n) V.empty + +arrayFromList :: Shape n -> [t] -> Array n t +arrayFromList sh l = Array sh (V.fromListN (shapeSize sh) l) + +arrayToList :: Array n t -> [t] +arrayToList (Array _ v) = V.toList v + +arrayReshape :: Shape n -> Array m t -> Array n t +arrayReshape sh (Array sh' v) + | shapeSize sh == shapeSize sh' = Array sh v + | otherwise = error $ "arrayReshape: different shape size than original (" ++ show sh' ++ " -> " ++ show sh ++ ")" + +arrayUnit :: t -> Array Z t +arrayUnit x = Array ShNil (V.singleton x) + +arrayIndex :: Array n t -> Index n -> t +arrayIndex arr@(Array sh _) idx = arrayIndexLinear arr (toLinearIndex sh idx) + +arrayIndexLinear :: Array n t -> Int -> t +arrayIndexLinear (Array _ v) i = v V.! i + +arrayIndex1 :: Array (S n) t -> Int -> Array n t +arrayIndex1 (Array (sh `ShCons` _) v) i = let sz = shapeSize sh in Array sh (V.slice (sz * i) sz v) + +arrayGenerate :: Shape n -> (Index n -> t) -> Array n t +arrayGenerate sh f = arrayGenerateLin sh (f . fromLinearIndex sh) + +arrayGenerateLin :: Shape n -> (Int -> t) -> Array n t +arrayGenerateLin sh f = Array sh (V.generate (shapeSize sh) f) + +arrayGenerateM :: Monad m => Shape n -> (Index n -> m t) -> m (Array n t) +arrayGenerateM sh f = arrayGenerateLinM sh (f . fromLinearIndex sh) + +arrayGenerateLinM :: Monad m => Shape n -> (Int -> m t) -> m (Array n t) +arrayGenerateLinM sh f = Array sh <$> V.generateM (shapeSize sh) f + +arrayMap :: (a -> b) -> Array n a -> Array n b +arrayMap f arr = arrayGenerateLin (arrayShape arr) (f . arrayIndexLinear arr) + +arrayMapM :: Monad m => (a -> m b) -> Array n a -> m (Array n b) +arrayMapM f arr = arrayGenerateLinM (arrayShape arr) (f . arrayIndexLinear arr) + +-- | The Int is the linear index of the value. +traverseArray_ :: Monad m => (Int -> t -> m ()) -> Array n t -> m () +traverseArray_ f (Array _ v) = evalStateT (traverse_ (\x -> StateT (\i -> (,i+1) <$> f i x)) v) 0 diff --git a/src/CHAD/Compile.hs b/src/CHAD/Compile.hs new file mode 100644 index 0000000..5b71651 --- /dev/null +++ b/src/CHAD/Compile.hs @@ -0,0 +1,1796 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} +module CHAD.Compile (compile, compileStderr) where + +import Control.Applicative (empty) +import Control.Monad (forM_, when, replicateM) +import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.Maybe +import Control.Monad.Trans.State.Strict +import Control.Monad.Trans.Writer.CPS +import Data.Bifunctor (first) +import Data.Char (ord) +import Data.Foldable (toList) +import Data.Functor.Const +import qualified Data.Functor.Product as Product +import Data.Functor.Product (Product) +import Data.IORef +import Data.List (foldl1', intersperse, intercalate) +import qualified Data.Map.Strict as Map +import Data.Maybe (fromMaybe) +import qualified Data.Set as Set +import Data.Set (Set) +import Data.Some +import qualified Data.Vector as V +import Foreign +import GHC.Exts (int2Word#, addr2Int#) +import GHC.Num (integerFromWord#) +import GHC.Ptr (Ptr(..)) +import GHC.Stack (HasCallStack) +import Numeric (showHex) +import System.IO (hPutStrLn, stderr) +import System.IO.Error (mkIOError, userErrorType) +import System.IO.Unsafe (unsafePerformIO) + +import Prelude hiding ((^)) +import qualified Prelude + +import CHAD.Array +import CHAD.AST +import CHAD.AST.Pretty (ppSTy, ppExpr) +import CHAD.AST.Sparse.Types (isDense) +import CHAD.Compile.Exec +import CHAD.Data +import CHAD.Interpreter.Rep +import qualified CHAD.Util.IdGen as IdGen + + +-- In shape and index arrays, the innermost dimension is on the right (last index). + +-- TODO: test that I'm properly incrementing and decrementing refcounts in all required places + + +-- | Print the compiled AST +debugPrintAST :: Bool; debugPrintAST = toEnum 0 +-- | Print the generated C source +debugCSource :: Bool; debugCSource = toEnum 0 +-- | Print extra stuff about reference counts of arrays +debugRefc :: Bool; debugRefc = toEnum 0 +-- | Print some shape-related information +debugShapes :: Bool; debugShapes = toEnum 0 +-- | Print information on allocation +debugAllocs :: Bool; debugAllocs = toEnum 0 +-- | Emit extra C code that checks stuff +emitChecks :: Bool; emitChecks = toEnum 0 + +-- | Returns compiled function plus compilation output (warnings) +compile :: SList STy env -> Ex env t + -> IO (SList Value env -> IO (Rep t), String) +compile = \env expr -> do + codeID <- atomicModifyIORef' uniqueIdGenRef (\i -> (i + 1, i)) + + let (source, offsets) = compileToString codeID env expr + when debugPrintAST $ hPutStrLn stderr $ "Compiled AST: <<<\n" ++ ppExpr env expr ++ "\n>>>" + when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ lineNumbers source ++ "\x1B[0m>>>" + (lib, compileOutput) <- buildKernel source "kernel" + + let result_type = typeOf expr + result_size = sizeofSTy result_type + + let function val = do + allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do + let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets) + serialiseArguments args ptr $ do + callKernelFun lib ptr + ok <- peekByteOff @Word8 ptr (koOkResOffset offsets) + when (ok /= 1) $ + ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing) + deserialise result_type ptr (koResultOffset offsets) + return (function, compileOutput) + where + serialiseArguments :: [(Some (Product STy Value), Int)] -> Ptr () -> IO r -> IO r + serialiseArguments ((Some (Product.Pair t (Value arg)), off) : args) ptr k = + serialise t arg ptr off $ + serialiseArguments args ptr k + serialiseArguments _ _ k = k + +-- | 'compile', but writes any produced C compiler output to stderr. +compileStderr :: SList STy env -> Ex env t + -> IO (SList Value env -> IO (Rep t)) +compileStderr env expr = do + (fun, output) <- compile env expr + when (not (null output)) $ + hPutStrLn stderr $ "[chad] Kernel compilation GCC output: <<<\n" ++ output ++ ">>>" + return fun + + +data StructDecl = StructDecl + String -- ^ name + String -- ^ contents + String -- ^ comment + deriving (Show) + +data Stmt + = SVarDecl Bool String String CExpr -- ^ const, type, variable name, right-hand side + | SVarDeclUninit String String -- ^ type, variable name (no initialiser) + | SAsg String CExpr -- ^ variable name, right-hand side + | SBlock (Bag Stmt) + | SIf CExpr (Bag Stmt) (Bag Stmt) + | SLoop String String CExpr CExpr (Bag Stmt) -- ^ for ( = ; name < ; name++) {} + | SVerbatim String -- ^ no implicit ';', just printed as-is + deriving (Show) + +data CExpr + = CELit String -- ^ inserted as-is, assumed no parentheses needed + | CEStruct String [(String, CExpr)] -- ^ struct construction literal: `(name){.field=expr}` + | CEProj CExpr String -- ^ field projection: expr.field + | CEPtrProj CExpr String -- ^ field projection through pointer: expr->field + | CEAddrOf CExpr -- ^ &expr + | CEIndex CExpr CExpr -- ^ expr[expr] + | CECall String [CExpr] -- ^ function(arg1, ..., argn) + | CEBinop CExpr String CExpr -- ^ expr + expr + | CEIf CExpr CExpr CExpr -- ^ expr ? expr : expr + | CECast String CExpr -- ^ () + deriving (Show) + +printStructDecl :: StructDecl -> ShowS +printStructDecl (StructDecl name contents comment) = + showString "typedef struct { " . showString contents . showString " } " . showString name + . showString ";" . (if null comment then id else showString (" // " ++ comment)) + +printStmt :: Int -> Stmt -> ShowS +printStmt indent = \case + SVarDecl cnst typ name rhs -> showString (typ ++ " " ++ (if cnst then "const " else "") ++ name ++ " = ") . printCExpr 0 rhs . showString ";" + SVarDeclUninit typ name -> showString (typ ++ " " ++ name ++ ";") + SAsg name rhs -> showString (name ++ " = ") . printCExpr 0 rhs . showString ";" + SBlock stmts + | null stmts -> showString "{}" + | otherwise -> + showString "{" + . compose [showString ("\n" ++ replicate (2*indent+2) ' ') . printStmt (indent+1) stmt | stmt <- toList stmts] + . showString ("\n" ++ replicate (2*indent) ' ' ++ "}") + SIf cond b1 b2 -> + showString "if (" . printCExpr 0 cond . showString ") " + . printStmt indent (SBlock b1) + . (if null b2 then id else showString " else " . printStmt indent (SBlock b2)) + SLoop typ name e1 e2 stmts -> + showString ("for (" ++ typ ++ " " ++ name ++ " = ") + . printCExpr 0 e1 . showString ("; " ++ name ++ " < ") . printCExpr 6 e2 + . showString ("; " ++ name ++ "++) ") + . printStmt indent (SBlock stmts) + SVerbatim s -> showString s + +-- d values: +-- * 0: top level +-- * 1: in 1st or 2nd component of a ternary operator (technically same as top level, but readability) +-- * 2-...: various operators (see precTable) +-- * 80: address-of operator (&) +-- * 98: inside unknown operator +-- * 99: left of a field projection +-- Unlisted operators are conservatively written with full parentheses. +printCExpr :: Int -> CExpr -> ShowS +printCExpr d = \case + CELit s -> showString s + CEStruct name pairs -> + showParen (d >= 99) $ + showString ("(" ++ name ++ "){") + . compose (intersperse (showString ", ") [showString ("." ++ n ++ " = ") . printCExpr 0 e + | (n, e) <- pairs]) + . showString "}" + CEProj e name -> printCExpr 99 e . showString ("." ++ name) + CEPtrProj e name -> printCExpr 99 e . showString ("->" ++ name) + CEAddrOf e -> showParen (d > 80) $ showString "&" . printCExpr 80 e + CEIndex e1 e2 -> printCExpr 99 e1 . showString "[" . printCExpr 0 e2 . showString "]" + CECall n es -> + showString (n ++ "(") . compose (intersperse (showString ", ") (map (printCExpr 0) es)) . showString ")" + CEBinop e1 n e2 -> + let mprec = Map.lookup n precTable + p = maybe (-1) fst mprec -- precedence of this operator + (d1, d2) = maybe (98, 98) snd mprec -- precedences for the arguments + in showParen (d > p) $ + printCExpr d1 e1 . showString (" " ++ n ++ " ") . printCExpr d2 e2 + CEIf e1 e2 e3 -> + showParen (d > 0) $ + printCExpr 1 e1 . showString " ? " . printCExpr 1 e2 . showString " : " . printCExpr 0 e3 + CECast typ e -> + showParen (d > 98) $ showString ("(" ++ typ ++ ")") . printCExpr 98 e + where + precTable = Map.fromList + [("||", (2, (2, 2))) + ,("&&", (3, (3, 3))) + ,("==", (4, (5, 5))) + ,("!=", (4, (5, 5))) + ,("<", (5, (6, 6))) -- Note: this precedence is used in the printing of SLoop + ,(">", (5, (6, 6))) + ,("<=", (5, (6, 6))) + ,(">=", (5, (6, 6))) + ,("+", (6, (6, 7))) + ,("-", (6, (6, 7))) + ,("*", (7, (7, 8))) + ,("/", (7, (7, 8))) + ,("%", (7, (7, 8)))] + +repSTy :: STy t -> String +repSTy (STScal st) = case st of + STI32 -> "int32_t" + STI64 -> "int64_t" + STF32 -> "float" + STF64 -> "double" + STBool -> "uint8_t" +repSTy t = genStructName t + +genStructName, genArrBufStructName :: STy t -> String +(genStructName, genArrBufStructName) = + (\t -> "ty_" ++ gen t + ,\case STArr _ t -> "ty_A_" ++ gen t ++ "_buf" -- just like the normal type, but with _ for the dimension + t -> error $ "genArrBufStructName: not an array type: " ++ show t) + where + -- all tags start with a letter, so the array mangling is unambiguous. + gen :: STy t -> String + gen STNil = "n" + gen (STPair a b) = 'P' : gen a ++ gen b + gen (STEither a b) = 'E' : gen a ++ gen b + gen (STLEither a b) = 'L' : gen a ++ gen b + gen (STMaybe t) = 'M' : gen t + gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t + gen (STScal st) = case st of + STI32 -> "i" + STI64 -> "j" + STF32 -> "f" + STF64 -> "d" + STBool -> "b" + gen (STAccum t) = 'C' : gen (fromSMTy t) + +-- The subtrees contain structs used in the bodies of the structs in this node. +data StructTree = TreeNode [StructDecl] [StructTree] + deriving (Show) + +-- | This function generates the actual struct declarations for each of the +-- types in our language. It thus implicitly "documents" the layout of the +-- types in the C translation. +-- +-- For accumulation it is important that for struct representations of monoid +-- types, the all-zero-bytes value corresponds to the zero value of that type. +buildStructTree :: STy t -> StructTree +buildStructTree topty = case topty of + STNil -> + TreeNode [StructDecl name "" com] [] + STPair a b -> + TreeNode [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com] + [buildStructTree a, buildStructTree b] + STEither a b -> -- 0 -> l, 1 -> r + TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + [buildStructTree a, buildStructTree b] + STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r + TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] + [buildStructTree a, buildStructTree b] + STMaybe t -> -- 0 -> nothing, 1 -> just + TreeNode [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com] + [buildStructTree t] + STArr n t -> + -- The buffer is trailed by a VLA for the actual array data. + -- TODO: no buffer if n = 0 + TreeNode [StructDecl (genArrBufStructName topty) ("size_t refc; " ++ repSTy t ++ " xs[];") "" + ,StructDecl name (genArrBufStructName topty ++ " *buf; size_t sh[" ++ show (fromSNat n) ++ "];") com] + [buildStructTree t] + STScal _ -> + TreeNode [] [] + STAccum t -> + TreeNode [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") "" + ,StructDecl name (name ++ "_buf *buf;") com] + [buildStructTree (fromSMTy t)] + where + name = genStructName topty + com = ppSTy 0 topty + +-- State: already-generated (skippable) struct names +-- Writer: the structs in declaration order +genStructTreeW :: StructTree -> WriterT (Bag StructDecl) (State (Set String)) () +genStructTreeW (TreeNode these deps) = do + seen <- lift get + case filter ((`Set.notMember` seen) . nameOf) these of + [] -> pure () + structs -> do + lift $ modify (Set.fromList (map nameOf structs) <>) + mapM_ genStructTreeW deps + tell (BList structs) + where + nameOf (StructDecl name _ _) = name + +genAllStructs :: Foldable t => t (Some STy) -> [StructDecl] +genAllStructs tys = + let m = mapM_ (\(Some t) -> genStructTreeW (buildStructTree t)) tys + in toList (evalState (execWriterT m) mempty) + +data CompState = CompState + { csStructs :: Set (Some STy) + , csTopLevelDecls :: Bag String + , csStmts :: Bag Stmt + , csNextId :: Int } + deriving (Show) + +newtype CompM a = CompM (State CompState a) + deriving newtype (Functor, Applicative, Monad) + +runCompM :: CompM a -> (a, CompState) +runCompM (CompM m) = runState m (CompState mempty mempty mempty 1) + +class Monad m => MonadNameGen m where genId :: m Int +instance MonadNameGen CompM where genId = CompM $ state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 }) +instance MonadNameGen IdGen.IdGen where genId = IdGen.genId +instance MonadNameGen m => MonadNameGen (MaybeT m) where genId = MaybeT (Just <$> genId) + +genName' :: MonadNameGen m => String -> m String +genName' "" = genName +genName' prefix = (prefix ++) . show <$> genId + +genName :: MonadNameGen m => m String +genName = genName' "x" + +onlyIdGen :: IdGen.IdGen a -> CompM a +onlyIdGen m = CompM $ do + i1 <- gets csNextId + let (res, i2) = IdGen.runIdGen' i1 m + modify (\s -> s { csNextId = i2 }) + return res + +emit :: Stmt -> CompM () +emit stmt = CompM $ modify $ \s -> s { csStmts = csStmts s <> pure stmt } + +scope :: CompM a -> CompM (a, Bag Stmt) +scope m = do + stmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = mempty }) + res <- m + innerStmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = stmts }) + return (res, innerStmts) + +emitStruct :: STy t -> CompM String +emitStruct ty = CompM $ do + modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) } + return (genStructName ty) + +-- | Also returns the name of the array buffer struct +emitArrStruct :: STy t -> CompM (String, String) +emitArrStruct ty = CompM $ do + modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) } + return (genStructName ty, genArrBufStructName ty) + +emitTLD :: String -> CompM () +emitTLD decl = CompM $ modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl } + +nameEnv :: SList f env -> SList (Const String) env +nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1)) + +data KernelOffsets = KernelOffsets + { koArgOffsets :: [Int] -- ^ the function arguments + , koOkResOffset :: Int -- ^ a byte: 1 if successful execution, 0 if (fatal) error occurred + , koResultOffset :: Int -- ^ the function result + } + +compileToString :: Int -> SList STy env -> Ex env t -> (String, KernelOffsets) +compileToString codeID env expr = + let args = nameEnv env + (res, s) = runCompM (compile' args expr) + structs = genAllStructs (csStructs s <> Set.fromList (unSList Some env)) + + (arg_pairs, arg_metrics) = + unzip $ reverse (unSList (\(Product.Pair t (Const n)) -> ((n, repSTy t), metricsSTy t)) + (slistZip env args)) + (arg_offsets, okres_offset) = computeStructOffsets arg_metrics + result_offset = align (alignmentSTy (typeOf expr)) (okres_offset + 1) + + offsets = KernelOffsets + { koArgOffsets = arg_offsets + , koOkResOffset = okres_offset + , koResultOffset = result_offset } + in (,offsets) . ($ "") $ compose + [showString "#include \n" + ,showString "#include \n" + ,showString "#include \n" + ,showString "#include \n" + ,showString "#include \n" + ,showString "#include \n" + ,showString "#include \n\n" + -- PRint-tag + ,showString $ "#define PRTAG \"[chad-kernel" ++ show codeID ++ "] \"\n\n" + + ,compose [printStructDecl sd . showString "\n" | sd <- structs] + ,showString "\n" + + -- Using %zd and not %zu here because values > SIZET_MAX/2 should be recognisable as "negative" + ,showString "static void* malloc_instr_fun(size_t n, int line) {\n" + ,showString " void *ptr = malloc(n);\n" + ,if debugAllocs then showString " printf(PRTAG \":%d malloc(%zd) -> %p\\n\", line, n, ptr);\n" + else id + ,if emitChecks then showString " if (ptr == NULL) { printf(PRTAG \"malloc(%zd) returned NULL on line %d\\n\", n, line); return false; }\n" + else id + ,showString " return ptr;\n" + ,showString "}\n" + ,showString "#define malloc_instr(n) ({void *ptr_ = malloc_instr_fun(n, __LINE__); if (ptr_ == NULL) return false; ptr_;})\n" + ,showString "static void* calloc_instr_fun(size_t n, int line) {\n" + ,showString " void *ptr = calloc(n, 1);\n" + ,if debugAllocs then showString " printf(PRTAG \":%d calloc(%zd) -> %p\\n\", line, n, ptr);\n" + else id + ,if emitChecks then showString " if (ptr == NULL) { printf(PRTAG \"calloc(%zd, 1) returned NULL on line %d\\n\", n, line); return false; }\n" + else id + ,showString " return ptr;\n" + ,showString "}\n" + ,showString "#define calloc_instr(n) ({void *ptr_ = calloc_instr_fun(n, __LINE__); if (ptr_ == NULL) return false; ptr_;})\n" + ,showString "static void free_instr(void *ptr) {\n" + ,if debugAllocs then showString "printf(PRTAG \"free(%p)\\n\", ptr);\n" + else id + ,showString " free(ptr);\n" + ,showString "}\n\n" + + ,compose [showString str . showString "\n\n" | str <- toList (csTopLevelDecls s)] + + ,showString $ + "static bool typed_kernel(" ++ + repSTy (typeOf expr) ++ " *output" ++ + concatMap (", " ++) + (reverse (unSList (\(Product.Pair t (Const n)) -> repSTy t ++ " " ++ n) (slistZip env args))) ++ + ") {\n" + ,compose [showString " " . printStmt 1 st . showString "\n" | st <- toList (csStmts s)] + ,showString " *output = " . printCExpr 0 res . showString ";\n" + ,showString " return true;\n" + ,showString "}\n\n" + + ,showString "void kernel(void *data) {\n" + -- Some code here assumes that we're on a 64-bit system, so let's check that + ,showString $ " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); *(uint8_t*)(data + " ++ show okres_offset ++ ") = 0; return; }\n" + ,if debugRefc then showString " fprintf(stderr, PRTAG \"Start\\n\");\n" + else id + ,showString $ " const bool success = typed_kernel(" ++ + "\n (" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ")" ++ + concat (map (\((arg, typ), off) -> + ",\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")" + ++ " /* " ++ arg ++ " */") + (zip arg_pairs arg_offsets)) ++ + "\n );\n" + ,showString $ " *(uint8_t*)(data + " ++ show okres_offset ++ ") = success;\n" + ,if debugRefc then showString " fprintf(stderr, PRTAG \"Return\\n\");\n" + else id + ,showString "}\n"] + +-- | Takes list of metrics (alignment, sizeof). +-- Returns (offsets, size of struct). +computeStructOffsets :: [(Int, Int)] -> ([Int], Int) +computeStructOffsets = go 0 0 + where + go off maxal [(al, sz)] = + ([off], align (max maxal al) (off + sz)) + go off maxal ((al, sz) : pairs@((al2,_):_)) = + first (off :) $ go (align al2 (off + sz)) (max maxal al) pairs + go _ _ [] = ([], 0) + +-- | Assumes that this is called at the correct alignment. +serialise :: STy t -> Rep t -> Ptr () -> Int -> IO r -> IO r +serialise topty topval ptr off k = + -- TODO: this code is quadratic in the depth of the type because of the alignment/sizeOf calls + case (topty, topval) of + (STNil, ()) -> k + (STPair a b, (x, y)) -> + serialise a x ptr off $ + serialise b y ptr (align (alignmentSTy b) (off + sizeofSTy a)) k + (STEither a _, Left x) -> do + pokeByteOff ptr off (0 :: Word8) -- alignment of (union {a b}) is the same as alignment of (a + b) + serialise a x ptr (off + alignmentSTy topty) k + (STEither _ b, Right y) -> do + pokeByteOff ptr off (1 :: Word8) + serialise b y ptr (off + alignmentSTy topty) k + (STLEither _ _, Nothing) -> do + pokeByteOff ptr off (0 :: Word8) + k + (STLEither a _, Just (Left x)) -> do + pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b) + serialise a x ptr (off + alignmentSTy topty) k + (STLEither _ b, Just (Right y)) -> do + pokeByteOff ptr off (2 :: Word8) + serialise b y ptr (off + alignmentSTy topty) k + (STMaybe _, Nothing) -> do + pokeByteOff ptr off (0 :: Word8) + k + (STMaybe t, Just x) -> do + pokeByteOff ptr off (1 :: Word8) + serialise t x ptr (off + alignmentSTy t) k + (STArr n t, Array sh vec) -> do + let eltsz = sizeofSTy t + allocaBytes (8 + shapeSize sh * eltsz) $ \bufptr -> do + when debugRefc $ + hPutStrLn stderr $ "[chad-serialise] Allocating input buffer " ++ showPtr bufptr + pokeByteOff ptr off bufptr + pokeShape ptr (off + 8) n sh + + pokeByteOff @Word64 bufptr 0 (2 ^ 63) + + let loop i + | i == shapeSize sh = k + | otherwise = + serialise t (vec V.! i) bufptr (8 + i * eltsz) $ + loop (i+1) + loop 0 + (STScal sty, x) -> case sty of + STI32 -> pokeByteOff ptr off (x :: Int32) >> k + STI64 -> pokeByteOff ptr off (x :: Int64) >> k + STF32 -> pokeByteOff ptr off (x :: Float) >> k + STF64 -> pokeByteOff ptr off (x :: Double) >> k + STBool -> pokeByteOff ptr off (fromIntegral (fromEnum x) :: Word8) >> k + (STAccum{}, _) -> error "Cannot serialise accumulators" + +-- | Assumes that this is called at the correct alignment. +deserialise :: STy t -> Ptr () -> Int -> IO (Rep t) +deserialise topty ptr off = + -- TODO: this code is quadratic in the depth of the type because of the alignment/sizeOf calls + case topty of + STNil -> return () + STPair a b -> do + x <- deserialise a ptr off + y <- deserialise b ptr (align (alignmentSTy b) (off + sizeofSTy a)) + return (x, y) + STEither a b -> do + tag <- peekByteOff @Word8 ptr off + if tag == 0 -- alignment of (union {a b}) is the same as alignment of (a + b) + then Left <$> deserialise a ptr (off + alignmentSTy topty) + else Right <$> deserialise b ptr (off + alignmentSTy topty) + STLEither a b -> do + tag <- peekByteOff @Word8 ptr off + case tag of -- alignment of (union {a b}) is the same as alignment of (a + b) + 0 -> return Nothing + 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty) + 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty) + _ -> error "Invalid tag value" + STMaybe t -> do + tag <- peekByteOff @Word8 ptr off + if tag == 0 + then return Nothing + else Just <$> deserialise t ptr (off + alignmentSTy t) + STArr n t -> do + bufptr <- peekByteOff @(Ptr ()) ptr off + sh <- peekShape ptr (off + 8) n + refc <- peekByteOff @Word64 bufptr 0 + when debugRefc $ + hPutStrLn stderr $ "[chad-deserialise] Got buffer " ++ showPtr bufptr ++ " at refc=" ++ show refc + let eltsz = sizeofSTy t + arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (8 + i * eltsz)) + when (refc < 2 ^ 62) $ free bufptr + return arr + STScal sty -> case sty of + STI32 -> peekByteOff @Int32 ptr off + STI64 -> peekByteOff @Int64 ptr off + STF32 -> peekByteOff @Float ptr off + STF64 -> peekByteOff @Double ptr off + STBool -> toEnum . fromIntegral <$> peekByteOff @Word8 ptr off + STAccum{} -> error "Cannot serialise accumulators" + +align :: Int -> Int -> Int +align a off = (off + a - 1) `div` a * a + +alignmentSTy :: STy t -> Int +alignmentSTy = fst . metricsSTy + +sizeofSTy :: STy t -> Int +sizeofSTy = snd . metricsSTy + +-- | Returns (alignment, sizeof) +metricsSTy :: STy t -> (Int, Int) +metricsSTy STNil = (1, 0) +metricsSTy (STPair a b) = + let (a1, s1) = metricsSTy a + (a2, s2) = metricsSTy b + in (max a1 a2, align (max a1 a2) (s1 + s2)) +metricsSTy (STEither a b) = + let (a1, s1) = metricsSTy a + (a2, s2) = metricsSTy b + in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned +metricsSTy (STLEither a b) = + let (a1, s1) = metricsSTy a + (a2, s2) = metricsSTy b + in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned +metricsSTy (STMaybe t) = + let (a, s) = metricsSTy t + in (a, a + s) -- the union after the tag byte is aligned +metricsSTy (STArr n _) = (8, 8 + 8 * fromSNat n) +metricsSTy (STScal sty) = case sty of + STI32 -> (4, 4) + STI64 -> (8, 8) + STF32 -> (4, 4) + STF64 -> (8, 8) + STBool -> (1, 1) -- compiled to uint8_t +metricsSTy (STAccum t) = metricsSTy (fromSMTy t) + +pokeShape :: Ptr () -> Int -> SNat n -> Shape n -> IO () +pokeShape ptr off = go . fromSNat + where + go :: Int -> Shape n -> IO () + go rank = \case + ShNil -> return () + sh `ShCons` n -> do + pokeByteOff ptr (off + (rank - 1) * 8) (fromIntegral n :: Int64) + go (rank - 1) sh + +peekShape :: Ptr () -> Int -> SNat n -> IO (Shape n) +peekShape ptr off = \case + SZ -> return ShNil + SS n -> ShCons <$> peekShape ptr off n + <*> (fromIntegral <$> peekByteOff @Int64 ptr (off + (fromSNat n) * 8)) + +compile' :: SList (Const String) env -> Ex env t -> CompM CExpr +compile' env = \case + EVar _ t i -> do + let Const var = slistIdx env i + incrementVarAlways "var" Increment t var + return $ CELit var + + ELet _ rhs body -> do + var <- compileAssign "" env rhs + rete <- compile' (Const var `SCons` env) body + incrementVarAlways "let" Decrement (typeOf rhs) var + return rete + + EPair _ a b -> do + name <- emitStruct (STPair (typeOf a) (typeOf b)) + e1 <- compile' env a + e2 <- compile' env b + return $ CEStruct name [("a", e1), ("b", e2)] + + EFst _ e -> do + let STPair _ t2 = typeOf e + e' <- compile' env e + case incrementVar "fst" Decrement t2 of + Nothing -> return $ CEProj e' "a" + Just f -> do var <- genName + emit $ SVarDecl True (repSTy (typeOf e)) var e' + f (var ++ ".b") + return $ CEProj (CELit var) "a" + + ESnd _ e -> do + let STPair t1 _ = typeOf e + e' <- compile' env e + case incrementVar "snd" Decrement t1 of + Nothing -> return $ CEProj e' "b" + Just f -> do var <- genName + emit $ SVarDecl True (repSTy (typeOf e)) var e' + f (var ++ ".a") + return $ CEProj (CELit var) "b" + + ENil _ -> do + name <- emitStruct STNil + return $ CEStruct name [] + + EInl _ t e -> do + name <- emitStruct (STEither (typeOf e) t) + e1 <- compile' env e + return $ CEStruct name [("tag", CELit "0"), ("l", e1)] + + EInr _ t e -> do + name <- emitStruct (STEither t (typeOf e)) + e2 <- compile' env e + return $ CEStruct name [("tag", CELit "1"), ("r", e2)] + + ECase _ (EOp _ OIf e) a b -> do + e1 <- compile' env e + (e2, stmts2) <- scope $ compile' (Const undefined `SCons` env) a -- don't access that nil, stupid you + (e3, stmts3) <- scope $ compile' (Const undefined `SCons` env) b + retvar <- genName + emit $ SVarDeclUninit (repSTy (typeOf a)) retvar + emit $ SIf e1 + (stmts2 <> pure (SAsg retvar e2)) + (stmts3 <> pure (SAsg retvar e3)) + return (CELit retvar) + + ECase _ e a b -> do + let STEither t1 t2 = typeOf e + e1 <- compile' env e + var <- genName + -- I know those are not variable names, but it's fine, probably + (e2, stmts2) <- scope $ compile' (Const (var ++ ".l") `SCons` env) a + (e3, stmts3) <- scope $ compile' (Const (var ++ ".r") `SCons` env) b + ((), stmtsRel1) <- scope $ incrementVarAlways "case1" Decrement t1 (var ++ ".l") + ((), stmtsRel2) <- scope $ incrementVarAlways "case2" Decrement t2 (var ++ ".r") + retvar <- genName + emit $ SVarDeclUninit (repSTy (typeOf a)) retvar + emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1) + <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) + (stmts2 + <> stmtsRel1 + <> pure (SAsg retvar e2)) + (stmts3 + <> stmtsRel2 + <> pure (SAsg retvar e3)))) + return (CELit retvar) + + ENothing _ t -> do + name <- emitStruct (STMaybe t) + return $ CEStruct name [("tag", CELit "0")] + + EJust _ e -> do + name <- emitStruct (STMaybe (typeOf e)) + e1 <- compile' env e + return $ CEStruct name [("tag", CELit "1"), ("j", e1)] + + EMaybe _ a b e -> do + let STMaybe t = typeOf e + e1 <- compile' env e + var <- genName + (e2, stmts2) <- scope $ compile' env a + (e3, stmts3) <- scope $ compile' (Const (var ++ ".j") `SCons` env) b + ((), stmtsRel) <- scope $ incrementVarAlways "maybe" Decrement t (var ++ ".j") + retvar <- genName + emit $ SVarDeclUninit (repSTy (typeOf a)) retvar + emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1) + <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) + (stmts2 + <> pure (SAsg retvar e2)) + (stmts3 + <> stmtsRel + <> pure (SAsg retvar e3)))) + return (CELit retvar) + + ELNil _ t1 t2 -> do + name <- emitStruct (STLEither t1 t2) + return $ CEStruct name [("tag", CELit "0")] + + ELInl _ t e -> do + name <- emitStruct (STLEither (typeOf e) t) + e1 <- compile' env e + return $ CEStruct name [("tag", CELit "1"), ("l", e1)] + + ELInr _ t e -> do + name <- emitStruct (STLEither t (typeOf e)) + e1 <- compile' env e + return $ CEStruct name [("tag", CELit "2"), ("r", e1)] + + ELCase _ e a b c -> do + let STLEither t1 t2 = typeOf e + e1 <- compile' env e + var <- genName + (e2, stmts2) <- scope $ compile' env a + (e3, stmts3) <- scope $ compile' (Const (var ++ ".l") `SCons` env) b + (e4, stmts4) <- scope $ compile' (Const (var ++ ".r") `SCons` env) c + ((), stmtsRel1) <- scope $ incrementVarAlways "lcase1" Decrement t1 (var ++ ".l") + ((), stmtsRel2) <- scope $ incrementVarAlways "lcase2" Decrement t2 (var ++ ".r") + retvar <- genName + emit $ SVarDeclUninit (repSTy (typeOf a)) retvar + emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1) + <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) + (stmts2 <> pure (SAsg retvar e2)) + (pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "1")) + (stmts3 <> stmtsRel1 <> pure (SAsg retvar e3)) + (stmts4 <> stmtsRel2 <> pure (SAsg retvar e4)))))) + return (CELit retvar) + + EConstArr _ n t (Array sh vec) -> do + (strname, bufstrname) <- emitArrStruct (STArr n (STScal t)) + tldname <- genName' "carraybuf" + -- Give it a refcount of _half_ the size_t max, so that it can be + -- incremented and decremented at will and will "never" reach anything + -- where something happens + emitTLD $ "static " ++ bufstrname ++ " " ++ tldname ++ " = " ++ + "(" ++ bufstrname ++ "){.refc = (size_t)1<<63, " ++ + ".xs = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};" + return (CEStruct strname + [("buf", CEAddrOf (CELit tldname)) + ,("sh", CELit ("{" ++ intercalate "," (map show (shapeToList sh)) ++ "}"))]) + + EBuild _ n esh efun -> do + shname <- compileAssign "sh" env esh + + arrname <- allocArray "build" Malloc "arr" n (typeOf efun) Nothing (indexTupleComponents n shname) + + idxargname <- genName' "ix" + (funretval, funstmts) <- scope $ compile' (Const idxargname `SCons` env) efun + + linivar <- genName' "li" + ivars <- replicateM (fromSNat n) (genName' "i") + emit $ SBlock $ + pure (SVarDecl False "size_t" linivar (CELit "0")) + <> compose [pure . SLoop (repSTy tIx) ivar (CELit "0") + (CECast (repSTy tIx) (CEIndex (CELit (arrname ++ ".sh")) (CELit (show dimidx)))) + | (ivar, dimidx) <- zip ivars [0::Int ..]] + (pure (SVarDecl True (repSTy (typeOf esh)) idxargname + (shapeTupFromLitVars n ivars)) + <> funstmts + <> pure (SAsg (arrname ++ ".buf->xs[" ++ linivar ++ "++]") funretval)) + + return (CELit arrname) + + -- TODO: actually generate decent code here + EMap _ e1 e2 -> do + let STArr n _ = typeOf e2 + compile' env $ + elet e2 $ + EBuild ext n (EShape ext (evar IZ)) $ + elet (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e1 + + EFold1Inner _ commut efun ex0 earr -> do + let STArr (SS n) t = typeOf earr + + -- let vecwid = case commut of Commut -> 8 :: Int + -- Noncommut -> 1 + + x0name <- compileAssign "foldx0" env ex0 + arrname <- compileAssign "foldarr" env earr + + zeroRefcountCheck (typeOf earr) "fold1i" arrname + + shszname <- genName' "shsz" + -- This n is one less than the shape of the thing we're querying, which is + -- unexpected. But it's exactly what we want, so we do it anyway. + emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n arrname) + + resname <- allocArray "fold" Malloc "foldres" n t (Just (CELit shszname)) (compileArrShapeComponents n arrname) + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]")) + + ((), x0incrStmts) <- scope $ incrementVarAlways "foldx0" Increment t x0name + + ivar <- genName' "i" + jvar <- genName' "j" + -- kvar <- if vecwid > 1 then genName' "k" else return "" + + accvar <- genName' "tot" + pairvar <- genName' "pair" -- function input + (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun + + let arreltlit = arrname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ + ({- if vecwid > 1 then show vecwid ++ " * " ++ jvar ++ " + " ++ kvar else -} jvar) ++ "]" + ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldelt" Increment t arreltlit + + pairstrname <- emitStruct (STPair t t) + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ + pure (SVarDecl False (repSTy t) accvar (CELit x0name)) + <> x0incrStmts -- we're copying x0 here + <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + -- The combination function will consume the array element + -- and the accumulator. The accumulator is replaced by + -- what comes out of the function anyway, so that's + -- fine, but we do need to increment the array element. + arreltIncrStmts + <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)])) + <> funStmts + <> pure (SAsg accvar funres)) + <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) + + incrementVarAlways "foldx0" Decrement t x0name + incrementVarAlways "foldarr" Decrement (typeOf earr) arrname + + return (CELit resname) + + ESum1Inner _ e -> do + let STArr (SS n) t = typeOf e + argname <- compileAssign "sumarg" env e + + zeroRefcountCheck (typeOf e) "sum1i" argname + + shszname <- genName' "shsz" + -- This n is one less than the shape of the thing we're querying, like EFold1Inner. + emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) + + resname <- allocArray "sum" Malloc "sumres" n t (Just (CELit shszname)) (compileArrShapeComponents n argname) + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]")) + + let vecwid = 8 :: Int + ivar <- genName' "i" + jvar <- genName' "j" + kvar <- genName' "k" + accvar <- genName' "tot" + let nchunks = CEBinop (CELit lenname) "/" (CELit (show vecwid)) + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ BList + -- we have ScalIsNumeric, so it has 0 and (+) in C + [SVerbatim $ repSTy t ++ " " ++ accvar ++ "[" ++ show vecwid ++ "] = {" ++ intercalate "," (replicate vecwid "0") ++ "};" + ,SLoop (repSTy tIx) jvar (CELit "0") nchunks $ + pure $ SLoop (repSTy tIx) kvar (CELit "0") (CELit (show vecwid)) $ + pure $ SVerbatim $ accvar ++ "[" ++ kvar ++ "] += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ show vecwid ++ " * " ++ jvar ++ " + " ++ kvar ++ "];" + ,SLoop (repSTy tIx) kvar (CELit "1") (CELit (show vecwid)) $ + pure $ SVerbatim $ accvar ++ "[0] += " ++ accvar ++ "[" ++ kvar ++ "];" + ,SLoop (repSTy tIx) kvar (CEBinop nchunks "*" (CELit (show vecwid))) (CELit lenname) $ + pure $ SVerbatim $ accvar ++ "[0] += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ kvar ++ "];" + ,SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit (accvar++"[0]"))] + + incrementVarAlways "sum" Decrement (typeOf e) argname + + return (CELit resname) + + EUnit _ e -> do + e' <- compile' env e + let typ = STArr SZ (typeOf e) + strname <- emitStruct typ + name <- genName + emit $ SVarDecl True strname name (CEStruct strname + [("buf", CECall "malloc_instr" [CELit (show (8 + sizeofSTy (typeOf e)))])]) + emit $ SAsg (name ++ ".buf->refc") (CELit "1") + emit $ SAsg (name ++ ".buf->xs[0]") e' + return (CELit name) + + EReplicate1Inner _ elen earg -> do + let STArr n t = typeOf earg + lenname <- compileAssign "replen" env elen + argname <- compileAssign "reparg" env earg + + zeroRefcountCheck (typeOf earg) "replicate1i" argname + + shszname <- genName' "shsz" + emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) + + resname <- allocArray "repl1i" Malloc "rep" (SS n) t + (Just (CEBinop (CELit shszname) "*" (CELit lenname))) + (compileArrShapeComponents n argname ++ [CELit lenname]) + + ivar <- genName' "i" + jvar <- genName' "j" + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ + pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + pure $ SAsg (resname ++ ".buf->xs[" ++ ivar ++ " * " ++ lenname ++ " + " ++ jvar ++ "]") + (CELit (argname ++ ".buf->xs[" ++ ivar ++ "]")) + + incrementVarAlways "repl1i" Decrement (typeOf earg) argname + + return (CELit resname) + + EMaximum1Inner _ e -> compileExtremum "max" "maximum1i" ">" env e + + EMinimum1Inner _ e -> compileExtremum "min" "minimum1i" "<" env e + + EReshape _ dim esh earg -> do + let STArr origDim eltty = typeOf earg + strname <- emitStruct (STArr dim eltty) + + shname <- compileAssign "reshsh" env esh + arrname <- compileAssign "resharg" env earg + + when emitChecks $ do + emit $ SIf (CEBinop (compileArrShapeSize origDim arrname) "!=" (CECast "size_t" (prodExpr (indexTupleComponents dim shname)))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: reshape on unequal sizes (%zu <- %zu)\\n\", " ++ + printCExpr 0 (prodExpr (indexTupleComponents dim shname)) ", " ++ + printCExpr 0 (compileArrShapeSize origDim arrname) "); return false;") + mempty + + return (CEStruct strname + [("buf", CEProj (CELit arrname) "buf") + ,("sh", CELit ("{" ++ intercalate ", " [printCExpr 0 e "" | e <- indexTupleComponents dim shname] ++ "}"))]) + + -- TODO: actually generate decent code here + EZip _ e1 e2 -> do + let STArr n _ = typeOf e1 + compile' env $ + elet e1 $ + elet (weakenExpr WSink e2) $ + EBuild ext n (EShape ext (evar (IS IZ))) $ + EPair ext (EIdx ext (evar (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ)) + (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ)) + + EFold1InnerD1 _ commut efun ex0 earr -> do + let STArr (SS n) t = typeOf earr + STPair _ bty = typeOf efun + + x0name <- compileAssign "foldd1x0" env ex0 + arrname <- compileAssign "foldd1arr" env earr + + zeroRefcountCheck (typeOf earr) "fold1iD1" arrname + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]")) + + shsz1name <- genName' "shszN" + emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n arrname) -- take init of arr's shape + shsz2name <- genName' "shszSN" + emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname)) + + resname <- allocArray "foldd1" Malloc "foldd1res" n t (Just (CELit shsz1name)) (compileArrShapeComponents n arrname) + storesname <- allocArray "foldd1" Malloc "foldd1stores" (SS n) bty (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) arrname) + + ((), x0incrStmts) <- scope $ incrementVarAlways "foldd1x0" Increment t x0name + + ivar <- genName' "i" + jvar <- genName' "j" + + accvar <- genName' "tot" + pairvar <- genName' "pair" -- function input + (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun + let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ jvar + arreltlit = arrname ++ ".buf->xs[" ++ eltidx ++ "]" + funresvar <- genName' "res" + ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldd1elt" Increment t arreltlit + + pairstrname <- emitStruct (STPair t t) + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $ + pure (SVarDecl False (repSTy t) accvar (CELit x0name)) + <> x0incrStmts -- we're copying x0 here + <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + -- The combination function will consume the array element + -- and the accumulator. The accumulator is replaced by + -- what comes out of the function anyway, so that's + -- fine, but we do need to increment the array element. + arreltIncrStmts + <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)])) + <> funStmts + <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) + <> pure (SAsg accvar (CEProj (CELit funresvar) "a")) + <> pure (SAsg (storesname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b"))) + <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) + + incrementVarAlways "foldd1x0" Decrement t x0name + incrementVarAlways "foldd1arr" Decrement (typeOf earr) arrname + + strname <- emitStruct (STPair (STArr n t) (STArr (SS n) bty)) + return (CEStruct strname [("a", CELit resname), ("b", CELit storesname)]) + + EFold1InnerD2 _ commut efun estores ectg -> do + let STArr n t2 = typeOf ectg + STArr _ bty = typeOf estores + + storesname <- compileAssign "foldd2stores" env estores + ctgname <- compileAssign "foldd2ctg" env ectg + + zeroRefcountCheck (typeOf ectg) "fold1iD2" ctgname + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (storesname ++ ".sh[" ++ show (fromSNat n) ++ "]")) + + shsz1name <- genName' "shszN" + emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n storesname) -- take init of the shape + shsz2name <- genName' "shszSN" + emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname)) + + x0ctgname <- allocArray "foldd2" Malloc "foldd2x0ctg" n t2 (Just (CELit shsz1name)) (compileArrShapeComponents n storesname) + outctgname <- allocArray "foldd2" Malloc "foldd2outctg" (SS n) t2 (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) storesname) + + ivar <- genName' "i" + jvar <- genName' "j" + + accvar <- genName' "acc" + let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ lenname ++ "-1 - " ++ jvar + storeseltlit = storesname ++ ".buf->xs[" ++ eltidx ++ "]" + ctgeltlit = ctgname ++ ".buf->xs[" ++ ivar ++ "]" + (funres, funStmts) <- scope $ compile' (Const accvar `SCons` Const storeseltlit `SCons` env) efun + funresvar <- genName' "res" + ((), storeseltIncrStmts) <- scope $ incrementVarAlways "foldd2selt" Increment bty storeseltlit + ((), ctgeltIncrStmts) <- scope $ incrementVarAlways "foldd2celt" Increment bty ctgeltlit + + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $ + pure (SVarDecl False (repSTy t2) accvar (CELit ctgeltlit)) + <> ctgeltIncrStmts + -- we need to loop in reverse here, but we let jvar run in the + -- forward direction so that we can use SLoop. Note jvar is + -- reversed in eltidx above + <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ + -- The combination function will consume the accumulator + -- and the stores element. The accumulator is replaced by + -- what comes out of the function anyway, so that's + -- fine, but we do need to increment the stores element. + storeseltIncrStmts + <> funStmts + <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) + <> pure (SAsg accvar (CEProj (CELit funresvar) "a")) + <> pure (SAsg (outctgname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b"))) + <> pure (SAsg (x0ctgname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) + + incrementVarAlways "foldd2stores" Decrement (STArr (SS n) bty) storesname + incrementVarAlways "foldd2ctg" Decrement (STArr n t2) ctgname + + strname <- emitStruct (STPair (STArr n t2) (STArr (SS n) t2)) + return (CEStruct strname [("a", CELit x0ctgname), ("b", CELit outctgname)]) + + EConst _ t x -> return $ CELit $ compileScal True t x + + EIdx0 _ e -> do + let STArr _ t = typeOf e + arrname <- compileAssign "" env e + zeroRefcountCheck (typeOf e) "idx0" arrname + name <- genName + emit $ SVarDecl True (repSTy t) name + (CEIndex (CEPtrProj (CEProj (CELit arrname) "buf") "xs") (CELit "0")) + incrementVarAlways "idx0" Decrement (STArr SZ t) arrname + return (CELit name) + + -- EIdx1 _ a b -> error "TODO" -- EIdx1 ext (compile' a) (compile' b) + + EIdx _ earr eidx -> do + let STArr n t = typeOf earr + arrname <- compileAssign "ixarr" env earr + zeroRefcountCheck (typeOf earr) "idx" arrname + idxname <- if fromSNat n > 0 -- prevent an unused-varable warning + then compileAssign "ixix" env eidx + else return "" -- won't be used in this case + + when emitChecks $ + forM_ (zip [0::Int ..] (indexTupleComponents n idxname)) $ \(i, ixcomp) -> + emit $ SIf (CEBinop (CEBinop ixcomp "<" (CELit "0")) "||" + (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".sh[" ++ show i ++ "]"))))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: index out of range (arr=%p)\\n\", " ++ + arrname ++ ".buf); return false;") + mempty + + resname <- genName' "ixres" + emit $ SVarDecl True (repSTy t) resname (CEIndex (CELit (arrname ++ ".buf->xs")) (toLinearIdx n arrname idxname)) + incrementVarAlways "idxelt" Increment t resname + incrementVarAlways "idx" Decrement (STArr n t) arrname + return (CELit resname) + + EShape _ e -> do + let STArr n _ = typeOf e + t = tTup (sreplicate n tIx) + _ <- emitStruct t + name <- compileAssign "" env e + zeroRefcountCheck (typeOf e) "shape" name + resname <- genName + emit $ SVarDecl True (repSTy t) resname (compileShapeQuery n name) + incrementVarAlways "shape" Decrement (typeOf e) name + return (CELit resname) + + EOp _ op (EPair _ e1 e2) -> do + e1' <- compile' env e1 + e2' <- compile' env e2 + compileOpPair op e1' e2' + + EOp _ op e -> do + e' <- compile' env e + compileOpGeneral op e' + + ECustom _ _ _ _ earg _ _ e1 e2 -> do + name1 <- compileAssign "" env e1 + name2 <- compileAssign "" env e2 + case (incrementVar "custom1" Decrement (typeOf e1), incrementVar "custom2" Decrement (typeOf e2)) of + (Nothing, Nothing) -> compile' (Const name2 `SCons` Const name1 `SCons` SNil) earg + (mfun1, mfun2) -> do + name <- compileAssign "" (Const name2 `SCons` Const name1 `SCons` SNil) earg + maybe (return ()) ($ name1) mfun1 + maybe (return ()) ($ name2) mfun2 + return (CELit name) + + ERecompute _ e -> compile' env e + + EWith _ t e1 e2 -> do + actyname <- emitStruct (STAccum t) + name1 <- compileAssign "" env e1 + + zeroRefcountCheck (typeOf e1) "with" name1 + + emit $ SVerbatim $ "// copyForWriting start (" ++ name1 ++ ")" + mcopy <- copyForWriting t name1 + accname <- genName' "accum" + emit $ SVarDecl False actyname accname + (CEStruct actyname [("buf", CECall "malloc_instr" [CELit (show (sizeofSTy (fromSMTy t)))])]) + emit $ SAsg (accname++".buf->ac") (maybe (CELit name1) id mcopy) + emit $ SVerbatim $ "// initial accumulator constructed (" ++ name1 ++ ")." + + e2' <- compile' (Const accname `SCons` env) e2 + + resname <- genName' "acret" + emit $ SVarDecl True (repSTy (fromSMTy t)) resname (CELit (accname++".buf->ac")) + emit $ SVerbatim $ "free_instr(" ++ accname ++ ".buf);" + + rettyname <- emitStruct (STPair (typeOf e2) (fromSMTy t)) + return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)] + + EAccum _ t prj eidx sparsity eval eacc | Just Refl <- isDense (acPrjTy prj t) sparsity -> do + let -- Add a value (s) into an existing accumulation value (d). If a sparse + -- component of d is encountered, s is copied there. + add :: SMTy a -> String -> String -> CompM () + add SMTNil _ _ = return () + add (SMTPair t1 t2) d s = do + add t1 (d++".a") (s++".a") + add t2 (d++".b") (s++".b") + add (SMTLEither t1 t2) d s = do + ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTLEither t1 t2)) s + ((), stmts1) <- scope $ add t1 (d++".l") (s++".l") + ((), stmts2) <- scope $ add t2 (d++".r") (s++".r") + emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) + (pure (SAsg d (CELit s)) + <> srcIncrStmts) + ((if emitChecks + then pure (SIf (CEBinop (CEBinop (CELit (s++".tag")) "!=" (CELit "0")) + "&&" + (CEBinop (CELit (s++".tag")) "!=" (CELit (d++".tag")))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum add leither with different tags " ++ + "(dest %d, src %d)\\n\", (int)" ++ d ++ ".tag, (int)" ++ s ++ ".tag); " ++ + "return false;") + mempty) + else mempty) + -- note: s may have tag 0 + <> pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) + stmts1 + (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "2")) + stmts2 mempty)))) + add (SMTMaybe t1) d s = do + ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTMaybe t1)) s + ((), stmts1) <- scope $ add t1 (d++".j") (s++".j") + emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) + (pure (SAsg d (CELit s)) + <> srcIncrStmts) + (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) stmts1 mempty)) + add (SMTArr n t1) d s = do + when emitChecks $ do + let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" + forM_ [0 .. fromSNat n - 1] $ \j -> do + emit $ SIf (CEBinop (CELit (s ++ ".sh[" ++ show j ++ "]")) + "!=" + (CELit (d ++ ".sh[" ++ show j ++ "]"))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum add incorrect (d=%p, " ++ + "dsh=" ++ shfmt ++ ", s=%p, ssh=" ++ shfmt ++ ")\\n\", " ++ + d ++ ".buf" ++ + concat [", " ++ d ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + ", " ++ s ++ ".buf" ++ + concat [", " ++ s ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + "); " ++ + "return false;") + mempty + + shsizename <- genName' "acshsz" + emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n s) + ivar <- genName' "i" + ((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]") + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename) + stmts1 + add (SMTScal _) d s = emit $ SVerbatim $ d ++ " += " ++ s ++ ";" + + let -- | Dereference an accumulation value and add a given value to that + -- position. Sparse components encountered along the way are + -- initialised before proceeding downwards. + -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (value to accumulate there) + accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM () + accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend + + accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") i addend + accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") i addend + + accumRef (SMTLEither ta _) (SAPLeft prj') v i addend = do + when emitChecks $ do + emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1")) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +left)\\n\", " ++ v ++ ".tag); " ++ + "return false;") + mempty + accumRef ta prj' (v++".l") i addend + accumRef (SMTLEither _ tb) (SAPRight prj') v i addend = do + when emitChecks $ do + emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "2")) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +right)\\n\", " ++ v ++ ".tag); " ++ + "return false;") + mempty + accumRef tb prj' (v++".r") i addend + + accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do + when emitChecks $ do + emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1")) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (maybe tag=%d, +just)\\n\", " ++ v ++ ".tag); " ++ + "return false;") + mempty + accumRef tj prj' (v++".j") i addend + + accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do + when emitChecks $ do + let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" + forM_ (zip [0::Int ..] + (indexTupleComponents n (i++".a"))) $ \(j, ixcomp) -> do + let a .||. b = CEBinop a "||" b + emit $ SIf (CEBinop ixcomp "<" (CELit "0") + .||. + CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".sh[" ++ show j ++ "]")))) + (pure $ SVerbatim $ + "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++ + "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=(D))\\n\", " ++ + v ++ ".buf" ++ + concat [", " ++ v ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ + concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a")] ++ + "); " ++ + "return false;") + mempty + + accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a")) "]") (i++".b") addend + + nameidx <- compileAssign "acidx" env eidx + nameval <- compileAssign "acval" env eval + nameacc <- compileAssign "acac" env eacc + + emit $ SVerbatim $ "// compile EAccum start (" ++ show prj ++ ")" + accumRef t prj (nameacc++".buf->ac") nameidx nameval + emit $ SVerbatim $ "// compile EAccum end" + + incrementVarAlways "accumendsrc" Decrement (typeOf eval) nameval + + return $ CEStruct (repSTy STNil) [] + + EAccum{} -> + error "Compile: EAccum with non-trivial sparsity should have been eliminated (use AST.UnMonoid)" + + EError _ t s -> do + let padleft len c s' = replicate (len - length s) c ++ s' + escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c] + | ord c < 32 -> "\\x" ++ padleft 2 '0' (showHex (ord c) "") + | otherwise -> [c] + emit $ SVerbatim $ "fputs(\"ERROR: " ++ escape s ++ "\\n\", stderr); return false;" + case t of + STScal _ -> return (CELit "0") + _ -> do + name <- emitStruct t + return $ CEStruct name [] + + EZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" + EDeepZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" + EPlus{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" + EOneHot{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" + + EIdx1{} -> error "Compile: not implemented: EIdx1" + +compileAssign :: String -> SList (Const String) env -> Ex env t -> CompM String +compileAssign prefix env e = do + e' <- compile' env e + case e' of + CELit name -> return name + _ -> do + name <- genName' prefix + emit $ SVarDecl True (repSTy (typeOf e)) name e' + return name + +data Increment = Increment | Decrement + deriving (Show) + +-- | Increment reference counts in the components of the given variable. +incrementVar :: String -> Increment -> STy a -> Maybe (String -> CompM ()) +incrementVar marker inc ty = + let tree = makeArrayTree ty + in case tree of ATNoop -> Nothing + _ -> Just $ \var -> incrementVar' marker inc var tree + +incrementVarAlways :: String -> Increment -> STy a -> String -> CompM () +incrementVarAlways marker inc ty var = maybe (pure ()) ($ var) (incrementVar marker inc ty) + +data ArrayTree = ATArray (Some SNat) (Some STy) -- ^ we've arrived at an array we need to decrement the refcount of (contains rank and element type of the array) + | ATNoop -- ^ don't do anything here + | ATProj String ArrayTree -- ^ descend one field deeper + | ATCondTag ArrayTree ArrayTree -- ^ if tag is 0, first; if 1, second + | ATCond3Tag ArrayTree ArrayTree ArrayTree -- ^ if tag is: 0, 1, 2 + | ATBoth ArrayTree ArrayTree -- ^ do both these paths + +smartATProj :: String -> ArrayTree -> ArrayTree +smartATProj _ ATNoop = ATNoop +smartATProj field t = ATProj field t + +smartATCondTag :: ArrayTree -> ArrayTree -> ArrayTree +smartATCondTag ATNoop ATNoop = ATNoop +smartATCondTag t t' = ATCondTag t t' + +smartATCond3Tag :: ArrayTree -> ArrayTree -> ArrayTree -> ArrayTree +smartATCond3Tag ATNoop ATNoop ATNoop = ATNoop +smartATCond3Tag t1 t2 t3 = ATCond3Tag t1 t2 t3 + +smartATBoth :: ArrayTree -> ArrayTree -> ArrayTree +smartATBoth ATNoop t = t +smartATBoth t ATNoop = t +smartATBoth t t' = ATBoth t t' + +makeArrayTree :: STy a -> ArrayTree +makeArrayTree STNil = ATNoop +makeArrayTree (STPair a b) = smartATBoth (smartATProj "a" (makeArrayTree a)) + (smartATProj "b" (makeArrayTree b)) +makeArrayTree (STEither a b) = smartATCondTag (smartATProj "l" (makeArrayTree a)) + (smartATProj "r" (makeArrayTree b)) +makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop + (smartATProj "l" (makeArrayTree a)) + (smartATProj "r" (makeArrayTree b)) +makeArrayTree (STMaybe t) = smartATCondTag ATNoop (smartATProj "j" (makeArrayTree t)) +makeArrayTree (STArr n t) = ATArray (Some n) (Some t) +makeArrayTree (STScal _) = ATNoop +makeArrayTree (STAccum _) = ATNoop + +incrementVar' :: String -> Increment -> String -> ArrayTree -> CompM () +incrementVar' marker inc path (ATArray (Some n) (Some eltty)) = + case inc of + Increment -> do + emit $ SVerbatim (path ++ ".buf->refc++;") + when debugRefc $ + emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p in+ -> %zu <" ++ marker ++ ">\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc);" + Decrement -> do + case incrementVar (marker++".elt") Decrement eltty of + Nothing -> + if debugRefc + then do + emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p de- -> %zu <" ++ marker ++ ">\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);" + emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) { fprintf(stderr, \"; free(\"); free_instr(" ++ path ++ ".buf); fprintf(stderr, \") ok\\n\"); } else fprintf(stderr, \"\\n\");" + else do + emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) free_instr(" ++ path ++ ".buf);" + Just f -> do + when debugRefc $ + emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p de- -> %zu <" ++ marker ++ "> recfree\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);" + shszvar <- genName' "frshsz" + ivar <- genName' "i" + ((), eltDecrStmts) <- scope $ f (path ++ ".buf->xs[" ++ ivar ++ "]") + emit $ SIf (CELit ("--" ++ path ++ ".buf->refc == 0")) + (BList [SVarDecl True "size_t" shszvar (compileArrShapeSize n path) + ,SLoop "size_t" ivar (CELit "0") (CELit shszvar) $ + eltDecrStmts + ,SVerbatim $ "free_instr(" ++ path ++ ".buf);"]) + mempty +incrementVar' _ _ _ ATNoop = pure () +incrementVar' marker inc path (ATProj field t) = incrementVar' (marker++"."++field) inc (path ++ "." ++ field) t +incrementVar' marker inc path (ATCondTag t1 t2) = do + ((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1 + ((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2 + emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "0")) stmts1 stmts2 +incrementVar' marker inc path (ATCond3Tag t1 t2 t3) = do + ((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1 + ((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2 + ((), stmts3) <- scope $ incrementVar' (marker++".t3") inc path t3 + emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "1")) + stmts2 + (pure (SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "2")) + stmts3 + stmts1)) +incrementVar' marker inc path (ATBoth t1 t2) = incrementVar' (marker++".1") inc path t1 >> incrementVar' (marker++".2") inc path t2 + +toLinearIdx :: SNat n -> String -> String -> CExpr +toLinearIdx SZ _ _ = CELit "0" +toLinearIdx (SS SZ) _ idxvar = CELit (idxvar ++ ".b") +toLinearIdx (SS n) arrvar idxvar = + CEBinop (CEBinop (toLinearIdx n arrvar (idxvar ++ ".a")) + "*" (CEIndex (CELit (arrvar ++ ".sh")) (CELit (show (fromSNat n))))) + "+" (CELit (idxvar ++ ".b")) + +-- fromLinearIdx :: SNat n -> String -> String -> CompM CExpr +-- fromLinearIdx SZ _ _ = return $ CEStruct (repSTy STNil) [] +-- fromLinearIdx (SS n) arrvar idxvar = do +-- name <- genName +-- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".sh[" ++ show (fromSNat n) ++ "]"))) +-- _ + +data AllocMethod = Malloc | Calloc + deriving (Show) + +-- | The shape must have the outer dimension at the head (and the inner dimension on the right). +allocArray :: HasCallStack => String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String +allocArray marker method nameBase rank eltty mshsz shape = do + when (length shape /= fromSNat rank) $ + error "allocArray: shape does not match rank" + let arrty = STArr rank eltty + strname <- emitStruct arrty + arrname <- genName' nameBase + shsz <- case mshsz of + Just e -> return e + Nothing -> return (foldl0' (\a b -> CEBinop a "*" b) (CELit "1") shape) + let nbytesExpr = CEBinop (CELit (show (fromSNat rank * 8 + 8))) + "+" + (CEBinop shsz "*" (CELit (show (sizeofSTy eltty)))) + emit $ SVarDecl True strname arrname $ CEStruct strname + [("buf", case method of Malloc -> CECall "malloc_instr" [nbytesExpr] + Calloc -> CECall "calloc_instr" [nbytesExpr]) + ,("sh", CELit ("{" ++ intercalate "," [printCExpr 0 dim "" | dim <- shape] ++ "}"))] + emit $ SAsg (arrname ++ ".buf->refc") (CELit "1") + when debugRefc $ + emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p allocated <" ++ marker ++ ">\\n\", " ++ arrname ++ ".buf);" + return arrname + +compileShapeQuery :: SNat n -> String -> CExpr +compileShapeQuery SZ _ = CEStruct (repSTy STNil) [] +compileShapeQuery (SS n) var = + CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) + [("a", compileShapeQuery n var) + ,("b", CEIndex (CELit (var ++ ".sh")) (CELit (show (fromSNat n))))] + +-- | Takes a variable name for the array, not the buffer. +compileArrShapeSize :: SNat n -> String -> CExpr +compileArrShapeSize n var = prodExpr (compileArrShapeComponents n var) + +-- | Takes a variable name for the array, not the buffer. +compileArrShapeComponents :: SNat n -> String -> [CExpr] +compileArrShapeComponents n var = + [CELit (var ++ ".sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] + +indexTupleComponents :: SNat n -> String -> [CExpr] +indexTupleComponents = \n var -> map CELit (toList (go n var)) + where + go :: SNat n -> String -> Bag String + go SZ _ = mempty + go (SS n) var = go n (var ++ ".a") <> pure (var ++ ".b") + +-- | Takes variable names with the innermost dimension on the right. +shapeTupFromLitVars :: SNat n -> [String] -> CExpr +shapeTupFromLitVars = \n -> go n . reverse + where + -- takes variables with the innermost dimension at the _head_ + go :: SNat n -> [String] -> CExpr + go SZ [] = CEStruct (repSTy STNil) [] + go (SS n) (var : vars) = CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) [("a", go n vars), ("b", CELit var)] + go _ _ = error "shapeTupFromLitVars: SNat and list do not correspond" + +prodExpr :: [CExpr] -> CExpr +prodExpr = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") + +compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr +compileOpGeneral op e1 = do + let unary cop = return @CompM $ CECall cop [e1] + let binary cop = do + name <- genName + emit $ SVarDecl True (repSTy (opt1 op)) name e1 + return $ CEBinop (CEProj (CELit name) "a") cop (CEProj (CELit name) "b") + case op of + OAdd _ -> binary "+" + OMul _ -> binary "*" + ONeg _ -> unary "-" + OLt _ -> binary "<" + OLe _ -> binary "<=" + OEq _ -> binary "==" + ONot -> unary "!" + OAnd -> binary "&&" + OOr -> binary "||" + OIf -> do + name <- emitStruct (STEither STNil STNil) + _ <- emitStruct STNil + return $ CEIf e1 (CEStruct name [("tag", CELit "0")]) + (CEStruct name [("tag", CELit "1")]) + ORound64 -> unary "(int64_t)round" -- ew + OToFl64 -> unary "(double)" + ORecip _ -> return $ CEBinop (CELit "1.0") "/" e1 + OExp STF32 -> unary "expf" + OExp STF64 -> unary "exp" + OLog STF32 -> unary "logf" + OLog STF64 -> unary "log" + OIDiv _ -> binary "/" + OMod _ -> binary "%" + +compileOpPair :: SOp a b -> CExpr -> CExpr -> CompM CExpr +compileOpPair op e1 e2 = do + let binary cop = return @CompM $ CEBinop e1 cop e2 + case op of + OAdd _ -> binary "+" + OMul _ -> binary "*" + OLt _ -> binary "<" + OLe _ -> binary "<=" + OEq _ -> binary "==" + OAnd -> binary "&&" + OOr -> binary "||" + OIDiv _ -> binary "/" + OMod _ -> binary "%" + _ -> error "compileOpPair: got unary operator" + +-- | Bool: whether to ensure that the literal itself already has the appropriate type +compileScal :: Bool -> SScalTy t -> ScalRep t -> String +compileScal pedantic typ x = case typ of + STI32 -> (if pedantic then "(int32_t)" else "") ++ show x + STI64 -> (if pedantic then "(int64_t)" else "") ++ show x + STF32 -> show x ++ "f" + STF64 -> show x + STBool -> if x then "1" else "0" + +compileExtremum :: String -> String -> String -> SList (Const String) env -> Ex env (TArr (S n) t) -> CompM CExpr +compileExtremum nameBase opName operator env e = do + let STArr (SS n) t = typeOf e + argname <- compileAssign (nameBase ++ "arg") env e + + zeroRefcountCheck (typeOf e) opName argname + + shszname <- genName' "shsz" + -- This n is one less than the shape of the thing we're querying, which is + -- unexpected. But it's exactly what we want, so we do it anyway. + emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) + + resname <- allocArray nameBase Malloc (nameBase ++ "res") n t (Just (CELit shszname)) (compileArrShapeComponents n argname) + + lenname <- genName' "n" + emit $ SVarDecl True (repSTy tIx) lenname + (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]")) + + emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); return false; }" + + ivar <- genName' "i" + jvar <- genName' "j" + xvar <- genName + redvar <- genName' "red" -- use "red", not "acc", to avoid confusion with accumulators + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ BList + -- we have ScalIsNumeric, so it has 1 and (<) etc. in C + [SVarDecl False (repSTy t) redvar (CELit (argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ "]")) + ,SLoop (repSTy tIx) jvar (CELit "1") (CELit lenname) $ BList + [SVarDecl True (repSTy t) xvar (CELit (argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ jvar ++ "]")) + ,SAsg redvar $ CEIf (CEBinop (CELit xvar) operator (CELit redvar)) (CELit xvar) (CELit redvar) + ] + ,SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit redvar)] + + incrementVarAlways nameBase Decrement (typeOf e) argname + + return (CELit resname) + +-- | If this returns Nothing, there was nothing to copy because making a simple +-- value copy in C already makes it suitable to write to. +copyForWriting :: SMTy t -> String -> CompM (Maybe CExpr) +copyForWriting topty var = case topty of + SMTNil -> return Nothing + + SMTPair a b -> do + e1 <- copyForWriting a (var ++ ".a") + e2 <- copyForWriting b (var ++ ".b") + case (e1, e2) of + (Nothing, Nothing) -> return Nothing + _ -> return $ Just $ CEStruct toptyname + [("a", fromMaybe (CELit (var++".a")) e1) + ,("b", fromMaybe (CELit (var++".b")) e2)] + + SMTLEither a b -> do + (e1, stmts1) <- scope $ copyForWriting a (var ++ ".l") + (e2, stmts2) <- scope $ copyForWriting b (var ++ ".r") + case (e1, e2) of + (Nothing, Nothing) -> return Nothing + _ -> do + name <- genName + emit $ SVarDeclUninit toptyname name + emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0")) + (stmts1 + <> pure (SAsg name (CEStruct toptyname + [("tag", CELit "0"), ("l", fromMaybe (CELit (var++".l")) e1)]))) + (stmts2 + <> pure (SAsg name (CEStruct toptyname + [("tag", CELit "1"), ("r", fromMaybe (CELit (var++".r")) e2)]))) + return (Just (CELit name)) + + SMTMaybe t -> do + (e1, stmts1) <- scope $ copyForWriting t (var ++ ".j") + case e1 of + Nothing -> return Nothing + Just e1' -> do + name <- genName + emit $ SVarDeclUninit toptyname name + emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0")) + (pure (SAsg name (CEStruct toptyname [("tag", CELit "0")]))) + (stmts1 + <> pure (SAsg name (CEStruct toptyname [("tag", CELit "1"), ("j", e1')]))) + return (Just (CELit name)) + + -- If there are no nested arrays, we know that a refcount of 1 means that the + -- whole thing is owned. Nested arrays have their own refcount, so with + -- nesting we'd have to check the refcounts of all the nested arrays _too_; + -- let's not do that. Furthermore, no sub-arrays means that the whole thing + -- is flat, and we can just memcpy if necessary. + SMTArr n t | not (typeHasArrays (fromSMTy t)) -> do + name <- genName + shszname <- genName' "shsz" + emit $ SVarDeclUninit toptyname name + + when debugShapes $ do + let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" + emit $ SVerbatim $ + "fprintf(stderr, PRTAG \"with array " ++ shfmt ++ "\\n\"" ++ + concat [", " ++ var ++ ".sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++ + ");" + + emit $ SIf (CEBinop (CELit (var ++ ".buf->refc")) "==" (CELit "1")) + (pure (SAsg name (CELit var))) + (let shbytes = fromSNat n * 8 + databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t)))) + totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes + in BList + [SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var) + ,SAsg name (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])]) + ,SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");" + ,SAsg (name ++ ".buf->refc") (CELit "1") + ,SVerbatim $ "memcpy(" ++ name ++ ".buf->xs, " ++ var ++ ".buf->xs, " ++ + printCExpr 0 databytes ");"]) + return (Just (CELit name)) + + SMTArr n t -> do + shszname <- genName' "shsz" + emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var) + + let shbytes = fromSNat n * 8 + databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t)))) + totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes + + name <- genName + emit $ SVarDecl False toptyname name + (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])]) + emit $ SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");" + emit $ SAsg (name ++ ".buf->refc") (CELit "1") + + -- put the arrays in variables to cut short the not-quite-var chain + dstvar <- genName' "cpydst" + emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") dstvar (CELit (name ++ ".buf->xs")) + srcvar <- genName' "cpysrc" + emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") srcvar (CELit (var ++ ".buf->xs")) + + ivar <- genName' "i" + + (cpye, cpystmts) <- scope $ copyForWriting t (srcvar ++ "[" ++ ivar ++ "]") + let cpye' = case cpye of + Just e -> e + Nothing -> error "copyForWriting: arrays cannot be copied as-is, bug" + + emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ + cpystmts + <> pure (SAsg (dstvar ++ "[" ++ ivar ++ "]") cpye') + + return (Just (CELit name)) + + SMTScal _ -> return Nothing + + where + toptyname = repSTy (fromSMTy topty) + +zeroRefcountCheck :: STy t -> String -> String -> CompM () +zeroRefcountCheck toptyp opname topvar = + when emitChecks $ do + mstmts <- onlyIdGen $ runMaybeT (go toptyp topvar) + case mstmts of + Nothing -> return () + Just stmts -> forM_ stmts emit + where + -- | If this returns 'Nothing', no statements need to be generated for this type. + go :: STy t -> String -> MaybeT IdGen.IdGen (Bag Stmt) + go STNil _ = empty + go (STPair a b) path = do + (s1, s2) <- combine (go a (path++".a")) (go b (path++".b")) + return (s1 <> s2) + go (STEither a b) path = do + (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) + return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "0")) s1 s2 + go (STLEither a b) path = do + (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) + return $ pure $ + SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) + s1 + (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2")) + s2 + mempty)) + go (STMaybe a) path = do + ss <- go a (path++".j") + return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) ss mempty + go (STArr n a) path = do + ivar <- genName' "i" + ss <- go a (path++".buf->xs["++ivar++"]") + shszname <- genName' "shsz" + let s1 = SVerbatim $ + "if (__builtin_expect(" ++ path ++ ".buf->refc == 0, 0)) { " ++ + "fprintf(stderr, PRTAG \"CHECK: '" ++ opname ++ "' got array " ++ + "%p with refc=0\\n\", " ++ path ++ ".buf); return false; }" + let s2 = SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n path) + let s3 = SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) ss + return (BList [s1, s2, s3]) + go STScal{} _ = empty + go STAccum{} _ = error "zeroRefcountCheck: passed an accumulator" + + combine :: (Monoid a, Monoid b, Monad m) => MaybeT m a -> MaybeT m b -> MaybeT m (a, b) + combine (MaybeT a) (MaybeT b) = MaybeT $ do + x <- a + y <- b + return $ case (x, y) of + (Nothing, Nothing) -> Nothing + (Just x', Nothing) -> Just (x', mempty) + (Nothing, Just y') -> Just (mempty, y') + (Just x', Just y') -> Just (x', y') + +{-# NOINLINE uniqueIdGenRef #-} +uniqueIdGenRef :: IORef Int +uniqueIdGenRef = unsafePerformIO $ newIORef 1 + +compose :: Foldable t => t (a -> a) -> a -> a +compose = foldr (.) id + +showPtr :: Ptr a -> String +showPtr (Ptr a) = "0x" ++ showHex (integerFromWord# (int2Word# (addr2Int# a))) "" + +-- | Type-restricted. +(^) :: Num a => a -> Int -> a +(^) = (Prelude.^) + +foldl0' :: (a -> a -> a) -> a -> [a] -> a +foldl0' _ x [] = x +foldl0' f _ l = foldl1' f l diff --git a/src/CHAD/Compile/Exec.hs b/src/CHAD/Compile/Exec.hs new file mode 100644 index 0000000..5b4afc8 --- /dev/null +++ b/src/CHAD/Compile/Exec.hs @@ -0,0 +1,99 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TupleSections #-} +module CHAD.Compile.Exec ( + KernelLib, + buildKernel, + callKernelFun, + + -- * misc + lineNumbers, +) where + +import Control.Monad (when) +import Data.IORef +import Foreign (Ptr) +import Foreign.Ptr (FunPtr) +import System.Directory (removeDirectoryRecursive) +import System.Environment (lookupEnv) +import System.Exit (ExitCode(..)) +import System.IO (hPutStrLn, stderr) +import System.IO.Error (mkIOError, userErrorType) +import System.IO.Unsafe (unsafePerformIO) +import System.Posix.DynamicLinker +import System.Posix.Temp (mkdtemp) +import System.Process (readProcessWithExitCode) + + +debug :: Bool +debug = False + +-- The IORef wrapper is required for the finalizer to attach properly (see the 'Weak' docs) +data KernelLib = KernelLib !(IORef (FunPtr (Ptr () -> IO ()))) + +buildKernel :: String -> String -> IO (KernelLib, String) +buildKernel csource funname = do + template <- (++ "/tmp.chad.") <$> getTempDir + path <- mkdtemp template + + let outso = path ++ "/out.so" + let args = ["-O3", "-march=native" + ,"-shared", "-fPIC" + ,"-std=c99", "-x", "c" + ,"-o", outso, "-" + ,"-Wall", "-Wextra" + ,"-Wno-unused-variable", "-Wno-unused-but-set-variable" + ,"-Wno-unused-parameter", "-Wno-unused-function" + ,"-Wno-alloc-size-larger-than" -- ideally we'd keep this, but gcc reports false positives + ,"-Wno-maybe-uninitialized"] -- maximum1i goes out of range if its input is empty, yes, don't complain + (ec, gccStdout, gccStderr) <- readProcessWithExitCode "gcc" args csource + + -- Print the source before the GCC output. + case ec of + ExitSuccess -> return () + ExitFailure{} -> hPutStrLn stderr $ "[chad] Kernel compilation failed! Source: <<<\n" ++ lineNumbers csource ++ ">>>" + + case ec of + ExitSuccess -> return () + ExitFailure{} -> do + removeDirectoryRecursive path + ioError (mkIOError userErrorType "chad kernel compilation failed" Nothing Nothing) + + numLoaded <- atomicModifyIORef' numLoadedCounter (\n -> (n+1, n+1)) + when debug $ hPutStrLn stderr $ "[chad] loading kernel " ++ path ++ " (" ++ show numLoaded ++ " total)" + dl <- dlopen outso [RTLD_LAZY, RTLD_LOCAL] + + removeDirectoryRecursive path -- we keep a reference anyway because we have the file open now + + ref <- newIORef =<< dlsym dl funname + _ <- mkWeakIORef ref (do numLeft <- atomicModifyIORef' numLoadedCounter (\n -> (n-1, n-1)) + when debug $ hPutStrLn stderr $ "[chad] unloading kernel " ++ path ++ " (" ++ show numLeft ++ " left)" + dlclose dl) + return (KernelLib ref, gccStdout ++ (if null gccStdout then "" else "\n") ++ gccStderr) + +foreign import ccall "dynamic" + wrapKernelFun :: FunPtr (Ptr () -> IO ()) -> Ptr () -> IO () + +-- Ensure that keeping a reference to the returned function also keeps the 'KernelLib' alive +{-# NOINLINE callKernelFun #-} +callKernelFun :: KernelLib -> Ptr () -> IO () +callKernelFun (KernelLib ref) arg = do + ptr <- readIORef ref + wrapKernelFun ptr arg + +getTempDir :: IO FilePath +getTempDir = + lookupEnv "TMPDIR" >>= \case + Just s | not (null s) -> return s + _ -> return "/tmp" + +{-# NOINLINE numLoadedCounter #-} +numLoadedCounter :: IORef Int +numLoadedCounter = unsafePerformIO $ newIORef 0 + +lineNumbers :: String -> String +lineNumbers str = + let lns = lines str + numlines = length lns + width = length (show numlines) + pad s = replicate (width - length s) ' ' ++ s + in unlines (zipWith (\i ln -> pad (show i) ++ " | " ++ ln) [1::Int ..] lns) diff --git a/src/CHAD/Data.hs b/src/CHAD/Data.hs new file mode 100644 index 0000000..8c7605c --- /dev/null +++ b/src/CHAD/Data.hs @@ -0,0 +1,192 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveTraversable #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.Data (module CHAD.Data, (:~:)(Refl), If) where + +import Data.Functor.Product +import Data.GADT.Compare +import Data.GADT.Show +import Data.Some +import Data.Type.Bool (If) +import Data.Type.Equality +import Unsafe.Coerce (unsafeCoerce) + +import CHAD.Lemmas (Append) + + +data Dict c where + Dict :: c => Dict c + + +data SList f l where + SNil :: SList f '[] + SCons :: f a -> SList f l -> SList f (a : l) +deriving instance (forall a. Show (f a)) => Show (SList f l) +infixr `SCons` + +slistMap :: (forall t. f t -> g t) -> SList f list -> SList g list +slistMap _ SNil = SNil +slistMap f (SCons x list) = SCons (f x) (slistMap f list) + +slistMapA :: Applicative m => (forall t. f t -> m (g t)) -> SList f list -> m (SList g list) +slistMapA _ SNil = pure SNil +slistMapA f (SCons x list) = SCons <$> f x <*> slistMapA f list + +slistZip :: SList f list -> SList g list -> SList (Product f g) list +slistZip SNil SNil = SNil +slistZip (x `SCons` l1) (y `SCons` l2) = Pair x y `SCons` slistZip l1 l2 + +unSList :: (forall t. f t -> a) -> SList f list -> [a] +unSList _ SNil = [] +unSList f (x `SCons` l) = f x : unSList f l + +showSList :: (forall t. Int -> f t -> String) -> SList f list -> String +showSList _ SNil = "SNil" +showSList f (x `SCons` l) = f 11 x ++ " `SCons` " ++ showSList f l + +sappend :: SList f l1 -> SList f l2 -> SList f (Append l1 l2) +sappend SNil l = l +sappend (SCons x xs) l = SCons x (sappend xs l) + +type family Replicate n x where + Replicate Z x = '[] + Replicate (S n) x = x : Replicate n x + +sreplicate :: SNat n -> f t -> SList f (Replicate n t) +sreplicate SZ _ = SNil +sreplicate (SS n) x = x `SCons` sreplicate n x + +data Nat = Z | S Nat + deriving (Show, Eq, Ord) + +type N0 = Z +type N1 = S N0 +type N2 = S N1 +type N3 = S N2 + +data SNat n where + SZ :: SNat Z + SS :: SNat n -> SNat (S n) +deriving instance Show (SNat n) + +instance GCompare SNat where + gcompare SZ SZ = GEQ + gcompare SZ _ = GLT + gcompare _ SZ = GGT + gcompare (SS n) (SS n') = gorderingLift1 (gcompare n n') + +instance TestEquality SNat where testEquality = geq +instance GEq SNat where geq = defaultGeq +instance GShow SNat where gshowsPrec = defaultGshowsPrec + +fromSNat :: SNat n -> Int +fromSNat SZ = 0 +fromSNat (SS n) = succ (fromSNat n) + +unSNat :: SNat n -> Nat +unSNat SZ = Z +unSNat (SS n) = S (unSNat n) + +reSNat :: Nat -> Some SNat +reSNat Z = Some SZ +reSNat (S n) | Some n' <- reSNat n = Some (SS n') + +class KnownNat n where knownNat :: SNat n +instance KnownNat Z where knownNat = SZ +instance KnownNat n => KnownNat (S n) where knownNat = SS knownNat + +snatKnown :: SNat n -> Dict (KnownNat n) +snatKnown SZ = Dict +snatKnown (SS n) | Dict <- snatKnown n = Dict + +type family n + m where + Z + m = m + S n + m = S (n + m) + +type family n - m where + n - Z = n + S n - S m = n - m + +snatAdd :: SNat n -> SNat m -> SNat (n + m) +snatAdd SZ m = m +snatAdd (SS n) m = SS (snatAdd n m) + +lemPlusSuccRight :: n + S m :~: S (n + m) +lemPlusSuccRight = unsafeCoerceRefl + +lemPlusZero :: n + Z :~: n +lemPlusZero = unsafeCoerceRefl + +data Vec n t where + VNil :: Vec Z t + (:<) :: t -> Vec n t -> Vec (S n) t +deriving instance Show t => Show (Vec n t) +deriving instance Eq t => Eq (Vec n t) +deriving instance Functor (Vec n) +deriving instance Foldable (Vec n) +deriving instance Traversable (Vec n) + +vecLength :: Vec n t -> SNat n +vecLength VNil = SZ +vecLength (_ :< v) = SS (vecLength v) + +vecGenerate :: SNat n -> (forall i. SNat i -> t) -> Vec n t +vecGenerate = \n f -> go n f SZ + where + go :: SNat n -> (forall i. SNat i -> t) -> SNat i' -> Vec n t + go SZ _ _ = VNil + go (SS n) f i = f i :< go n f (SS i) + +vecReplicateA :: Applicative f => SNat n -> f a -> f (Vec n a) +vecReplicateA SZ _ = pure VNil +vecReplicateA (SS n) gen = (:<) <$> gen <*> vecReplicateA n gen + +vecZipWithA :: Applicative f => (a -> b -> f c) -> Vec n a -> Vec n b -> f (Vec n c) +vecZipWithA _ VNil VNil = pure VNil +vecZipWithA f (x :< xs) (y :< ys) = (:<) <$> f x y <*> vecZipWithA f xs ys + +vecInit :: Vec (S n) a -> Vec n a +vecInit (_ :< VNil) = VNil +vecInit (x :< xs@(_ :< _)) = x :< vecInit xs + +unsafeCoerceRefl :: a :~: b +unsafeCoerceRefl = unsafeCoerce Refl + +gorderingLift1 :: GOrdering a a' -> GOrdering (f a) (f a') +gorderingLift1 GLT = GLT +gorderingLift1 GGT = GGT +gorderingLift1 GEQ = GEQ + +gorderingLift2 :: GOrdering a a' -> GOrdering b b' -> GOrdering (f a b) (f a' b') +gorderingLift2 GLT _ = GLT +gorderingLift2 GGT _ = GGT +gorderingLift2 GEQ GLT = GLT +gorderingLift2 GEQ GGT = GGT +gorderingLift2 GEQ GEQ = GEQ + +data Bag t = BNone | BOne t | BTwo !(Bag t) !(Bag t) | BMany [Bag t] | BList [t] + deriving (Show, Functor, Foldable, Traversable) + +-- | This instance is mostly there just for 'pure' +instance Applicative Bag where + pure = BOne + BNone <*> _ = BNone + BOne f <*> b = f <$> b + BTwo b1 b2 <*> b = BTwo (b1 <*> b) (b2 <*> b) + BMany bs <*> b = BMany (map (<*> b) bs) + BList bs <*> b = BMany (map (<$> b) bs) + +instance Semigroup (Bag t) where (<>) = BTwo +instance Monoid (Bag t) where mempty = BNone + +data SBool b where + SF :: SBool False + ST :: SBool True +deriving instance Show (SBool b) diff --git a/src/CHAD/Data/VarMap.hs b/src/CHAD/Data/VarMap.hs new file mode 100644 index 0000000..6e16b82 --- /dev/null +++ b/src/CHAD/Data/VarMap.hs @@ -0,0 +1,119 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RoleAnnotations #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.Data.VarMap ( + VarMap, + empty, + insert, + delete, + TypedIdx(..), + lookup, + disjointUnion, + sink1, + unsink1, + subMap, + superMap, +) where + +import Prelude hiding (lookup) + +import qualified Data.Map.Strict as Map +import Data.Map.Strict (Map) +import Data.Maybe (mapMaybe) +import Data.Some +import qualified Data.Vector.Storable as VS +import Unsafe.Coerce + +import CHAD.AST.Env +import CHAD.AST.Types +import CHAD.AST.Weaken + + +type role VarMap _ nominal -- ensure that 'env' is not phantom +data VarMap k (env :: [Ty]) = + VarMap Int -- ^ Global offset; must be added to any value in the map in order to get the proper index + Int -- ^ Time since last cleanup + (Map k (Some STy, Int)) +deriving instance Show k => Show (VarMap k env) + +empty :: VarMap k env +empty = VarMap 0 0 Map.empty + +insert :: Ord k => k -> STy t -> Idx env t -> VarMap k env -> VarMap k env +insert k ty idx (VarMap off interval mp) = + maybeCleanup $ VarMap off (interval + 1) (Map.insert k (Some ty, idx2int idx - off) mp) + +delete :: Ord k => k -> VarMap k env -> VarMap k env +delete k (VarMap off interval mp) = + maybeCleanup $ VarMap off (interval + 1) (Map.delete k mp) + +data TypedIdx env t = TypedIdx (STy t) (Idx env t) + deriving (Show) + +lookup :: Ord k => k -> VarMap k env -> Maybe (Some (TypedIdx env)) +lookup k (VarMap off _ mp) = do + (Some ty, i) <- Map.lookup k mp + idx <- unsafeInt2idx (i + off) + return (Some (TypedIdx ty idx)) + +disjointUnion :: Ord k => VarMap k env -> VarMap k env -> VarMap k env +disjointUnion (VarMap off1 cl1 m1) (VarMap off2 cl2 m2) | off1 == off2 = + VarMap off1 (min cl1 cl2) (Map.unionWith (error "VarMap.disjointUnion: overlapping keys") m1 m2) +disjointUnion vm1 vm2 = disjointUnion (cleanup vm1) (cleanup vm2) + +sink1 :: VarMap k env -> VarMap k (t : env) +sink1 (VarMap off interval mp) = VarMap (off + 1) interval mp + +unsink1 :: VarMap k (t : env) -> VarMap k env +unsink1 (VarMap off interval mp) = VarMap (off - 1) interval mp + +subMap :: Eq k => Subenv env env' -> VarMap k env -> VarMap k env' +subMap subenv = + let bools = let loop :: Subenv env env' -> [Bool] + loop SETop = [] + loop (SEYesR sub) = True : loop sub + loop (SENo sub) = False : loop sub + in VS.fromList $ loop subenv + newIndices = VS.init $ VS.scanl' (\n b -> if b then n + 1 else n) (0 :: Int) bools + modify off (k, (ty, i)) + | i + off < 0 = Nothing + | i + off >= VS.length bools = error "VarMap.subMap: found negative indices in map" + | bools VS.! (i + off) = Just (k, (ty, newIndices VS.! (i + off))) + | otherwise = Nothing + in \(VarMap off _ mp) -> VarMap 0 0 (Map.fromAscList . mapMaybe (modify off) . Map.toAscList $ mp) + +superMap :: Eq k => Subenv env env' -> VarMap k env' -> VarMap k env +superMap subenv = + let loop :: Subenv env env' -> Int -> [Int] + loop SETop _ = [] + loop (SEYesR sub) i = i : loop sub (i+1) + loop (SENo sub) i = loop sub (i+1) + + newIndices = VS.fromList $ loop subenv 0 + modify off (k, (ty, i)) + | i + off < 0 = Nothing + | i + off >= VS.length newIndices = error "VarMap.superMap: found negative indices in map" + | otherwise = let j = newIndices VS.! (i + off) + in if j == -1 then Nothing else Just (k, (ty, j)) + + in \(VarMap off _ mp) -> VarMap 0 0 (Map.fromAscList . mapMaybe (modify off) . Map.toAscList $ mp) + +maybeCleanup :: VarMap k env -> VarMap k env +maybeCleanup vm@(VarMap _ interval mp) + | let sz = Map.size mp + , sz > 0, 2 * interval >= 3 * sz + = cleanup vm +maybeCleanup vm = vm + +cleanup :: VarMap k env -> VarMap k env +cleanup (VarMap off _ mp) = VarMap 0 0 (Map.mapMaybe (\(t, i) -> if i + off >= 0 then Just (t, i + off) else Nothing) mp) + +unsafeInt2idx :: Int -> Maybe (Idx env t) +unsafeInt2idx = \n -> if n < 0 then Nothing else Just (go n) + where + go :: Int -> Idx env t + go 0 = unsafeCoerce IZ + go n = unsafeCoerce (IS (go (n-1))) diff --git a/src/CHAD/Drev.hs b/src/CHAD/Drev.hs new file mode 100644 index 0000000..595d3c7 --- /dev/null +++ b/src/CHAD/Drev.hs @@ -0,0 +1,1583 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE ImpredicativeTypes #-} +{-# LANGUAGE OverloadedLabels #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeData #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} + +-- I want to bring various type variables in scope using type annotations in +-- patterns, but I don't want to have to mention all the other type parameters +-- of the types in question as well then. Partial type signatures (with '_') are +-- useful here. +{-# LANGUAGE PartialTypeSignatures #-} +{-# OPTIONS -Wno-partial-type-signatures #-} +module CHAD.Drev ( + drev, + freezeRet, + CHADConfig(..), + defaultConfig, + Storage(..), + Descr(..), + Select, +) where + +import Data.Functor.Const +import Data.Some +import Data.Type.Equality (type (==), testEquality) + +import CHAD.Analysis.Identity (ValId(..), validSplitEither) +import CHAD.AST +import CHAD.AST.Bindings +import CHAD.AST.Count +import CHAD.AST.Env +import CHAD.AST.Sparse +import CHAD.AST.Weaken.Auto +import CHAD.Data +import qualified CHAD.Data.VarMap as VarMap +import CHAD.Data.VarMap (VarMap) +import CHAD.Drev.Accum +import CHAD.Drev.EnvDescr +import CHAD.Drev.Types +import CHAD.Lemmas + + +------------------------------ TAPES AND BINDINGS ------------------------------ + +type family Tape binds where + Tape '[] = TNil + Tape (t : ts) = TPair t (Tape ts) + +tapeTy :: SList STy binds -> STy (Tape binds) +tapeTy SNil = STNil +tapeTy (SCons t ts) = STPair t (tapeTy ts) + +bindingsCollectTape :: SList STy binds -> Subenv binds tapebinds + -> binds :> env2 -> Ex env2 (Tape tapebinds) +bindingsCollectTape SNil SETop _ = ENil ext +bindingsCollectTape (t `SCons` binds) (SEYesR sub) w = + EPair ext (EVar ext t (w @> IZ)) + (bindingsCollectTape binds sub (w .> WSink)) +bindingsCollectTape (_ `SCons` binds) (SENo sub) w = + bindingsCollectTape binds sub (w .> WSink) + +-- bindingsCollectTape' :: forall f env binds tapebinds env2. Bindings f env binds -> Subenv binds tapebinds +-- -> Append binds env :> env2 -> Ex env2 (Tape tapebinds) +-- bindingsCollectTape' binds sub w +-- | Refl <- lemAppendNil @binds +-- = bindingsCollectTape (bindingsBinds binds) sub (w .> wCopies @_ @_ @'[] (bindingsBinds binds) (WClosed @env)) + +-- In order from large to small: i.e. in reverse order from what we want, +-- because in a Bindings, the head of the list is the bottom-most entry. +type family TapeUnfoldings binds where + TapeUnfoldings '[] = '[] + TapeUnfoldings (t : ts) = Tape ts : TapeUnfoldings ts + +type family Reverse l where + Reverse '[] = '[] + Reverse (t : ts) = Append (Reverse ts) '[t] + +-- An expression that is always 'snd' +data UnfExpr env t where + UnfExSnd :: STy s -> STy t -> UnfExpr (TPair s t : env) t + +fromUnfExpr :: UnfExpr env t -> Ex env t +fromUnfExpr (UnfExSnd s t) = ESnd ext (EVar ext (STPair s t) IZ) + +-- - A bunch of 'snd' expressions taking us from knowing that there's a +-- 'Tape ts' in the environment (for simplicity assume it's at IZ, we'll fix +-- this in reconstructBindings), to having 'Reverse (TapeUnfoldings ts)' in +-- the environment. +-- - In the extended environment, another bunch of let bindings (these are +-- 'fst' expressions, but no need to know that statically) that project the +-- fsts out of what we introduced above, one for each type in 'ts'. +data Reconstructor env ts = + Reconstructor + (Bindings UnfExpr (Tape ts : env) (Reverse (TapeUnfoldings ts))) + (Bindings Ex (Append (Reverse (TapeUnfoldings ts)) (Tape ts : env)) ts) + +ssnoc :: SList f ts -> f t -> SList f (Append ts '[t]) +ssnoc SNil a = SCons a SNil +ssnoc (SCons t ts) a = SCons t (ssnoc ts a) + +sreverse :: SList f ts -> SList f (Reverse ts) +sreverse SNil = SNil +sreverse (SCons t ts) = ssnoc (sreverse ts) t + +stapeUnfoldings :: SList STy ts -> SList STy (TapeUnfoldings ts) +stapeUnfoldings SNil = SNil +stapeUnfoldings (SCons _ ts) = SCons (tapeTy ts) (stapeUnfoldings ts) + +-- Puts a 'snd' at the top of an unfolder stack and grows the context variable by one. +shiftUnfolder + :: STy t + -> SList STy ts + -> Bindings UnfExpr (Tape ts : env) list + -> Bindings UnfExpr (Tape (t : ts) : env) (Append list '[Tape ts]) +shiftUnfolder newTy ts BTop = BPush BTop (tapeTy ts, UnfExSnd newTy (tapeTy ts)) +shiftUnfolder newTy ts (BPush b (t, UnfExSnd itemTy _)) = + -- Recurse on 'b', and retype the 'snd'. We need to unfold 'b' once in order + -- to expand an 'Append' in the types so that things simplify just enough. + -- We have an equality 'Append binds x1 ~ a : x2', where 'binds' is the list + -- of bindings produced by 'b'. We want to conclude from this that + -- 'binds ~ a : x3' for some 'x3', but GHC will only do that once we know + -- that 'binds ~ y : ys' so that the 'Append' can expand one step, after + -- which 'y ~ a' as desired. The 'case' unfolds 'b' one step. + BPush (shiftUnfolder newTy ts b) (t, case b of BTop -> UnfExSnd itemTy t + BPush{} -> UnfExSnd itemTy t) + +growRecon :: forall env t ts. STy t -> SList STy ts -> Reconstructor env ts -> Reconstructor env (t : ts) +growRecon t ts (Reconstructor unfbs bs) + | Refl <- lemAppendNil @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts]) + , Refl <- lemAppendAssoc @ts @(Append (Reverse (TapeUnfoldings ts)) '[Tape ts]) @(Tape (t : ts) : env) + , Refl <- lemAppendAssoc @(Reverse (TapeUnfoldings ts)) @'[Tape ts] @env + = Reconstructor + (shiftUnfolder t ts unfbs) + -- Add a 'fst' at the bottom of the builder stack. + -- First we have to weaken most of 'bs' to skip one more binding in the + -- unfolder stack above it. + (BPush (fst (weakenBindingsE + (wCopies (sappend (sreverse (stapeUnfoldings ts)) (SCons (tapeTy ts) SNil)) + (WSink :: env :> (Tape (t : ts) : env))) bs)) + (t + ,EFst ext $ EVar ext (tapeTy (SCons t ts)) $ + wSinks @(Tape (t : ts) : env) + (sappend ts + (sappend (sappend (sreverse (stapeUnfoldings ts)) + (SCons (tapeTy ts) SNil)) + SNil)) + @> IZ)) + +buildReconstructor :: SList STy ts -> Reconstructor env ts +buildReconstructor SNil = Reconstructor BTop BTop +buildReconstructor (SCons t ts) = growRecon t ts (buildReconstructor ts) + +-- STRATEGY FOR reconstructBindings +-- +-- binds = [] +-- e : () +-- +-- binds = [c] +-- e : (c, ()) +-- x0 = snd x1 : () +-- y1 = fst e : c +-- +-- binds = [b, c] +-- e : (b, (c, ())) +-- x1 = snd e : (c, ()) +-- x0 = snd x1 : () +-- y1 = fst x1 : c +-- y2 = fst x2 : b +-- +-- binds = [a, b, c] +-- e : (a, (b, (c, ()))) +-- x2 = snd e : (b, (c, ())) +-- x1 = snd x2 : (c, ()) +-- x0 = snd x1 : () +-- y1 = fst x1 : c +-- y2 = fst x2 : b +-- y3 = fst x3 : a + +-- Given that in 'env' we can find a 'Tape binds', i.e. a tuple containing all +-- the things in the list 'binds', we want to create a let stack that extracts +-- all values from that tuple and in effect "restores" the environment +-- described by 'binds'. The idea is that elsewhere, we took a slice of the +-- environment and saved it all in a tuple to be restored later. We +-- incidentally also add a bunch of additional bindings, namely 'Reverse +-- (TapeUnfoldings binds)', so the calling code just has to skip those in +-- whatever it wants to do. +reconstructBindings :: SList STy binds + -> (forall env. Idx env (Tape binds) -> Bindings Ex env (Append binds (Reverse (TapeUnfoldings binds))) + ,SList STy (Reverse (TapeUnfoldings binds))) +reconstructBindings binds = + (\tape -> let Reconstructor unf build = buildReconstructor binds + in fst $ weakenBindingsE (WIdx tape) + (bconcat (mapBindings fromUnfExpr unf) build) + ,sreverse (stapeUnfoldings binds)) + + +---------------------------------- DERIVATIVES --------------------------------- + +d1op :: SOp a t -> Ex env (D1 a) -> Ex env (D1 t) +d1op (OAdd t) e = EOp ext (OAdd t) e +d1op (OMul t) e = EOp ext (OMul t) e +d1op (ONeg t) e = EOp ext (ONeg t) e +d1op (OLt t) e = EOp ext (OLt t) e +d1op (OLe t) e = EOp ext (OLe t) e +d1op (OEq t) e = EOp ext (OEq t) e +d1op ONot e = EOp ext ONot e +d1op OAnd e = EOp ext OAnd e +d1op OOr e = EOp ext OOr e +d1op OIf e = EOp ext OIf e +d1op ORound64 e = EOp ext ORound64 e +d1op OToFl64 e = EOp ext OToFl64 e +d1op (ORecip t) e = EOp ext (ORecip t) e +d1op (OExp t) e = EOp ext (OExp t) e +d1op (OLog t) e = EOp ext (OLog t) e +d1op (OIDiv t) e = EOp ext (OIDiv t) e +d1op (OMod t) e = EOp ext (OMod t) e + +-- | Both primal and dual must be duplicable expressions +data D2Op a t = Linear (forall env. Ex env (D2 t) -> Ex env (D2 a)) + | Nonlinear (forall env. Ex env (D1 a) -> Ex env (D2 t) -> Ex env (D2 a)) + +d2op :: SOp a t -> D2Op a t +d2op op = case op of + OAdd t -> d2opBinArrangeInt t $ Linear $ \d -> EPair ext d d + OMul t -> d2opBinArrangeInt t $ Nonlinear $ \e d -> + EPair ext (EOp ext (OMul t) (EPair ext (ESnd ext e) d)) + (EOp ext (OMul t) (EPair ext (EFst ext e) d)) + ONeg t -> d2opUnArrangeInt t $ Linear $ \d -> EOp ext (ONeg t) d + OLt t -> Linear $ \_ -> pairZero t + OLe t -> Linear $ \_ -> pairZero t + OEq t -> Linear $ \_ -> pairZero t + ONot -> Linear $ \_ -> ENil ext + OAnd -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + OOr -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + OIf -> Linear $ \_ -> ENil ext + ORound64 -> Linear $ \_ -> EZero ext (SMTScal STF64) (ENil ext) + OToFl64 -> Linear $ \_ -> ENil ext + ORecip t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ONeg t) (EOp ext (ORecip t) (EOp ext (OMul t) (EPair ext e e)))) d) + OExp t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (OExp t) e) d) + OLog t -> floatingD2 t $ Nonlinear $ \e d -> EOp ext (OMul t) (EPair ext (EOp ext (ORecip t) e) d) + OIDiv t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + OMod t -> integralD2 t $ Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + where + pairZero :: SScalTy a -> Ex env (D2 (TPair (TScal a) (TScal a))) + pairZero t = ziNil t $ EPair ext (EZero ext (d2M (STScal t)) (ENil ext)) + (EZero ext (d2M (STScal t)) (ENil ext)) + where + ziNil :: SScalTy a -> (ZeroInfo (D2s a) ~ TNil => r) -> r + ziNil STI32 k = k + ziNil STI64 k = k + ziNil STF32 k = k + ziNil STF64 k = k + ziNil STBool k = k + + d2opUnArrangeInt :: SScalTy a + -> (D2s a ~ TScal a => D2Op (TScal a) t) + -> D2Op (TScal a) t + d2opUnArrangeInt ty float = case ty of + STI32 -> Linear $ \_ -> ENil ext + STI64 -> Linear $ \_ -> ENil ext + STF32 -> float + STF64 -> float + STBool -> Linear $ \_ -> ENil ext + + d2opBinArrangeInt :: SScalTy a + -> (D2s a ~ TScal a => D2Op (TPair (TScal a) (TScal a)) t) + -> D2Op (TPair (TScal a) (TScal a)) t + d2opBinArrangeInt ty float = case ty of + STI32 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + STI64 -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + STF32 -> float + STF64 -> float + STBool -> Linear $ \_ -> EPair ext (ENil ext) (ENil ext) + + floatingD2 :: ScalIsFloating a ~ True + => SScalTy a -> ((D2s a ~ TScal a, ScalIsNumeric a ~ True) => r) -> r + floatingD2 STF32 k = k + floatingD2 STF64 k = k + + integralD2 :: ScalIsIntegral a ~ True + => SScalTy a -> ((D2s a ~ TNil, ScalIsNumeric a ~ True) => r) -> r + integralD2 STI32 k = k + integralD2 STI64 k = k + +desD1E :: Descr env sto -> SList STy (D1E env) +desD1E = d1e . descrList + +-- d1W :: env :> env' -> D1E env :> D1E env' +-- d1W WId = WId +-- d1W WSink = WSink +-- d1W (WCopy w) = WCopy (d1W w) +-- d1W (WPop w) = WPop (d1W w) +-- d1W (WThen u w) = WThen (d1W u) (d1W w) + +conv1Idx :: Idx env t -> Idx (D1E env) (D1 t) +conv1Idx IZ = IZ +conv1Idx (IS i) = IS (conv1Idx i) + +data Idx2 env sto t + = Idx2Ac (Idx (D2AcE (Select env sto "accum")) (TAccum (D2 t))) + | Idx2Me (Idx (D2E (Select env sto "merge")) (D2 t)) + | Idx2Di (Idx (Select env sto "discr") t) + +conv2Idx :: Descr env sto -> Idx env t -> Idx2 env sto t +conv2Idx (DPush _ (_, _, SAccum)) IZ = Idx2Ac IZ +conv2Idx (DPush _ (_, _, SMerge)) IZ = Idx2Me IZ +conv2Idx (DPush _ (_, _, SDiscr)) IZ = Idx2Di IZ +conv2Idx (DPush des (_, _, SAccum)) (IS i) = + case conv2Idx des i of Idx2Ac j -> Idx2Ac (IS j) + Idx2Me j -> Idx2Me j + Idx2Di j -> Idx2Di j +conv2Idx (DPush des (_, _, SMerge)) (IS i) = + case conv2Idx des i of Idx2Ac j -> Idx2Ac j + Idx2Me j -> Idx2Me (IS j) + Idx2Di j -> Idx2Di j +conv2Idx (DPush des (_, _, SDiscr)) (IS i) = + case conv2Idx des i of Idx2Ac j -> Idx2Ac j + Idx2Me j -> Idx2Me j + Idx2Di j -> Idx2Di (IS j) +conv2Idx DTop i = case i of {} + +opt2UnSparse :: SOp a b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b) +opt2UnSparse = go . opt2 + where + go :: STy b -> Sparse (D2 b) b' -> Ex env b' -> Ex env (D2 b) + go (STScal STI32) SpAbsent = \_ -> ENil ext + go (STScal STI64) SpAbsent = \_ -> ENil ext + go (STScal STF32) SpAbsent = \_ -> EZero ext (SMTScal STF32) (ENil ext) + go (STScal STF64) SpAbsent = \_ -> EZero ext (SMTScal STF64) (ENil ext) + go (STScal STBool) SpAbsent = \_ -> ENil ext + go (STScal STF32) SpScal = id + go (STScal STF64) SpScal = id + go STNil _ = \_ -> ENil ext + go (STPair t1 t2) (SpPair s1 s2) = \e -> eunPair e $ \_ e1 e2 -> EPair ext (go t1 s1 e1) (go t2 s2 e2) + go t _ = error $ "Primitive operations that return " ++ show t ++ " are scary" + + +----------------------------------- SPARSITY ----------------------------------- + +expandSparse :: STy a -> Sparse (D2 a) b -> Ex env (D1 a) -> Ex env b -> Ex env (D2 a) +expandSparse t sp _ e | Just Refl <- isDense (d2M t) sp = e +expandSparse t (SpSparse sp) epr e = + EMaybe ext + (EZero ext (d2M t) (d2zeroInfo t epr)) + (expandSparse t sp (weakenExpr WSink epr) (EVar ext (applySparse sp (d2 t)) IZ)) + e +expandSparse t SpAbsent epr _ = EZero ext (d2M t) (d2zeroInfo t epr) +expandSparse (STPair t1 t2) (SpPair s1 s2) epr e = + eunPair epr $ \w1 epr1 epr2 -> + eunPair (weakenExpr w1 e) $ \w2 e1 e2 -> + EPair ext (expandSparse t1 s1 (weakenExpr w2 epr1) e1) + (expandSparse t2 s2 (weakenExpr w2 epr2) e2) +expandSparse (STEither t1 t2) (SpLEither s1 s2) epr e = + ELCase ext e + (EZero ext (d2M (STEither t1 t2)) (ENil ext)) + (ECase ext (weakenExpr WSink epr) + (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ)))) + (EError ext (d2 (STEither t1 t2)) "expspa r<-dl")) + (ECase ext (weakenExpr WSink epr) + (EError ext (d2 (STEither t1 t2)) "expspa l<-dr") + (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ))))) +expandSparse (STLEither t1 t2) (SpLEither s1 s2) epr e = + ELCase ext e + (EZero ext (d2M (STEither t1 t2)) (ENil ext)) + (ELCase ext (weakenExpr WSink epr) + (EError ext (d2 (STEither t1 t2)) "expspa ln<-dl") + (ELInl ext (d2 t2) (expandSparse t1 s1 (EVar ext (d1 t1) IZ) (EVar ext (applySparse s1 (d2 t1)) (IS IZ)))) + (EError ext (d2 (STEither t1 t2)) "expspa lr<-dl")) + (ELCase ext (weakenExpr WSink epr) + (EError ext (d2 (STEither t1 t2)) "expspa ln<-dr") + (EError ext (d2 (STEither t1 t2)) "expspa ll<-dr") + (ELInr ext (d2 t1) (expandSparse t2 s2 (EVar ext (d1 t2) IZ) (EVar ext (applySparse s2 (d2 t2)) (IS IZ))))) +expandSparse (STMaybe t) (SpMaybe s) epr e = + EMaybe ext + (ENothing ext (d2 t)) + (let epr' = EMaybe ext (EError ext (d1 t) "expspa n<-dj") (EVar ext (d1 t) IZ) epr + in EJust ext (expandSparse t s (weakenExpr WSink epr') (EVar ext (applySparse s (d2 t)) IZ))) + e +expandSparse (STArr _ t) (SpArr s) epr e = + ezipWith (expandSparse t s (EVar ext (d1 t) (IS IZ)) (EVar ext (applySparse s (d2 t)) IZ)) epr e +expandSparse (STScal STF32) SpScal _ e = e +expandSparse (STScal STF64) SpScal _ e = e +expandSparse (STAccum{}) _ _ _ = error "accumulators not allowed in source program" + +subenvPlus :: SBool req1 -> SBool req2 + -> SList SMTy env + -> SubenvS env env1 -> SubenvS env env2 + -> (forall env3. SubenvS env env3 + -> Injection req1 (Tup env1) (Tup env3) + -> Injection req2 (Tup env2) (Tup env3) + -> (forall e. Ex e (Tup env1) -> Ex e (Tup env2) -> Ex e (Tup env3)) + -> r) + -> r +-- don't destroy effects! +subenvPlus _ _ SNil SETop SETop k = k SETop (Inj id) (Inj id) (\a b -> use a $ use b $ ENil ext) + +subenvPlus req1 req2 (SCons _ env) (SENo sub1) (SENo sub2) k = + subenvPlus req1 req2 env sub1 sub2 $ \sub3 s31 s32 pl -> + k (SENo sub3) s31 s32 pl + +subenvPlus req1 SF (SCons _ env) (SEYes sp1 sub1) (SENo sub2) k = + subenvPlus req1 SF env sub1 sub2 $ \sub3 minj13 _ pl -> + k (SEYes sp1 sub3) + (withInj minj13 $ \inj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) e1b) + Noinj + (\e1 e2 -> + ELet ext e1 $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) + (weakenExpr WSink e2)) + (ESnd ext (EVar ext (typeOf e1) IZ))) +subenvPlus req1 ST (SCons t env) (SEYes sp1 sub1) (SENo sub2) k + | Just zero1 <- cheapZero (applySparse sp1 t) = + subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl -> + k (SEYes sp1 sub3) + (withInj minj13 $ \inj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) e1b) + (Inj $ \e2 -> EPair ext (inj23 e2) zero1) + (\e1 e2 -> + ELet ext e1 $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) + (weakenExpr WSink e2)) + (ESnd ext (EVar ext (typeOf e1) IZ))) + | otherwise = + subenvPlus req1 ST env sub1 sub2 $ \sub3 minj13 (Inj inj23) pl -> + k (SEYes (SpSparse sp1) sub3) + (withInj minj13 $ \inj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) (EJust ext e1b)) + (Inj $ \e2 -> EPair ext (inj23 e2) (ENothing ext (applySparse sp1 (fromSMTy t)))) + (\e1 e2 -> + ELet ext e1 $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) IZ)) + (weakenExpr WSink e2)) + (EJust ext (ESnd ext (EVar ext (typeOf e1) IZ)))) + +subenvPlus req1 req2 (SCons t env) sub1@SENo{} sub2@SEYes{} k = + subenvPlus req2 req1 (SCons t env) sub2 sub1 $ \sub3 minj23 minj13 pl -> + k sub3 minj13 minj23 (flip pl) + +subenvPlus req1 req2 (SCons t env) (SEYes sp1 sub1) (SEYes sp2 sub2) k = + subenvPlus req1 req2 env sub1 sub2 $ \sub3 minj13 minj23 pl -> + sparsePlusS req1 req2 t sp1 sp2 $ \sp3 mTinj13 mTinj23 plus -> + k (SEYes sp3 sub3) + (withInj2 minj13 mTinj13 $ \inj13 tinj13 -> + \e1 -> eunPair e1 $ \_ e1a e1b -> + EPair ext (inj13 e1a) (tinj13 e1b)) + (withInj2 minj23 mTinj23 $ \inj23 tinj23 -> + \e2 -> eunPair e2 $ \_ e2a e2b -> + EPair ext (inj23 e2a) (tinj23 e2b)) + (\e1 e2 -> + ELet ext e1 $ + ELet ext (weakenExpr WSink e2) $ + EPair ext (pl (EFst ext (EVar ext (typeOf e1) (IS IZ))) + (EFst ext (EVar ext (typeOf e2) IZ))) + (plus + (ESnd ext (EVar ext (typeOf e1) (IS IZ))) + (ESnd ext (EVar ext (typeOf e2) IZ)))) + +expandSubenvZeros :: D1E env0 :> env -> SList STy env0 -> SubenvS (D2E env0) contribs + -> Ex env (Tup contribs) -> Ex env (Tup (D2E env0)) +expandSubenvZeros _ SNil SETop _ = ENil ext +expandSubenvZeros w (SCons t ts) (SEYes sp sub) e = + eunPair e $ \w1 e1 e2 -> + EPair ext + (expandSubenvZeros (w1 .> WPop w) ts sub e1) + (expandSparse t sp (EVar ext (d1 t) (w1 .> w @> IZ)) e2) +expandSubenvZeros w (SCons t ts) (SENo sub) e = + EPair ext + (expandSubenvZeros (WPop w) ts sub e) + (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) + + +--------------------------------- ACCUMULATORS --------------------------------- + +fromArrayValId :: Maybe (ValId t) -> Maybe Int +fromArrayValId (Just (VIArr i _)) = Just i +fromArrayValId _ = Nothing + +accumPromote :: forall dt env sto proxy r. + proxy dt + -> Descr env sto + -> (forall stoRepl envPro. + (Select env stoRepl "merge" ~ '[]) + => Descr env stoRepl + -- ^ A revised environment description that switches + -- arrays (used in the OccEnv) that are currently on + -- "merge" storage, to "accum" storage. + -> SList STy envPro + -- ^ New entries on top of the original dual environment, + -- that house the accumulators for the promoted arrays in + -- the original environment. + -> Subenv (Select env sto "merge") envPro + -- ^ The promoted entries were merge entries in the + -- original environment. + -> Subenv (D2AcE (Select env stoRepl "accum")) (D2AcE (Select env sto "accum")) + -- ^ All entries that were accumulators are still + -- accumulators. + -> VarMap Int (D2AcE (Select env stoRepl "accum")) + -- ^ Accumulator map for _only_ the the newly allocated + -- accumulators. + -> (forall shbinds. + SList STy shbinds + -> (dt : Append shbinds (D2AcE (Select env stoRepl "accum"))) + :> Append (D2AcE envPro) (dt : Append shbinds (D2AcE (Select env sto "accum")))) + -- ^ A weakening that converts a computation in the + -- revised environment to one in the original environment + -- extended with some accumulators. + -> r) + -> r +accumPromote _ DTop k = k DTop SNil SETop SETop VarMap.empty (\_ -> WId) +accumPromote pdty (descr `DPush` (t :: STy t, vid, sto)) k = case sto of + -- Accumulators are left as-is + SAccum -> + accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> + k (storepl `DPush` (t, vid, SAccum)) + envpro + prosub + (SEYesR accrevsub) + (VarMap.sink1 accumMap) + (\shbinds -> + autoWeak (#pro (d2ace envpro) &. #d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum descr))) + (#acc :++: (#pro :++: #d :++: #shb :++: #tl)) + (#pro :++: #d :++: #shb :++: #acc :++: #tl) + .> WCopy (wf shbinds) + .> autoWeak (#d (auto1 @dt) &. #shb shbinds &. #acc (auto1 @(TAccum (D2 t))) &. #tl (d2ace (select SAccum storepl))) + (#d :++: #shb :++: #acc :++: #tl) + (#acc :++: (#d :++: #shb :++: #tl))) + + SMerge -> case t of + -- Discrete values are left as-is + _ | isDiscrete t -> + accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap' wf -> + k (storepl `DPush` (t, vid, SDiscr)) + envpro + (SENo prosub) + accrevsub + accumMap' + wf + + -- Values with "merge" storage are promoted to an accumulator in envPro + _ -> + accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> + k (storepl `DPush` (t, vid, SAccum)) + (t `SCons` envpro) + (SEYesR prosub) + (SENo accrevsub) + (let accumMap' = VarMap.sink1 accumMap + in case fromArrayValId vid of + Just i -> VarMap.insert i (STAccum (d2M t)) IZ accumMap' + Nothing -> accumMap') + (\(shbinds :: SList _ shbinds) -> + let shbindsC = slistMap (\_ -> Const ()) shbinds + in + -- wf: + -- D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) + -- WCopy wf: + -- TAccum n t3 : D2 t : Append shbinds (D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) + -- WPICK: ^ THESE TWO || + -- goal: | ARE EQUAL || + -- D2 t : Append shbinds (TAccum n t3 : D2AcE (Select envPro stoRepl "accum")) :> TAccum n t3 : Append envPro (D2 t : Append shbinds (D2AcE (Select envPro sto1 "accum"))) + WCopy (wf shbinds) + .> WPick @(TAccum (D2 t)) @(dt : shbinds) (Const () `SCons` shbindsC) + (WId @(D2AcE (Select env1 stoRepl "accum")))) + + -- Discrete values are left as-is, nothing to do + SDiscr -> + accumPromote pdty descr $ \(storepl :: Descr env1 stoRepl) (envpro :: SList _ envPro) prosub accrevsub accumMap wf -> + k (storepl `DPush` (t, vid, SDiscr)) + envpro + prosub + accrevsub + accumMap + wf + where + isDiscrete :: STy t' -> Bool + isDiscrete = \case + STNil -> True + STPair a b -> isDiscrete a && isDiscrete b + STEither a b -> isDiscrete a && isDiscrete b + STLEither a b -> isDiscrete a && isDiscrete b + STMaybe a -> isDiscrete a + STArr _ a -> isDiscrete a + STScal st -> case st of + STI32 -> True + STI64 -> True + STF32 -> False + STF64 -> False + STBool -> True + STAccum{} -> False + + +---------------------------- RETURN TRIPLE FROM CHAD --------------------------- + +data Ret env0 sto sd t = + forall shbinds tapebinds contribs. + Ret (Bindings Ex (D1E env0) shbinds) -- shared binds + (Subenv shbinds tapebinds) + (Ex (Append shbinds (D1E env0)) (D1 t)) + (SubenvS (D2E (Select env0 sto "merge")) contribs) + (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs)) +deriving instance Show (Ret env0 sto sd t) + +type data TyTyPair = MkTyTyPair Ty Ty + +data SingleRet env0 sto (pair :: TyTyPair) = + forall shbinds tapebinds. + SingleRet + (Bindings Ex (D1E env0) shbinds) -- shared binds + (Subenv shbinds tapebinds) + (RetPair env0 sto (D1E env0) shbinds tapebinds pair) + +-- pattern Ret1 :: forall env0 sto Bindings Ex (D1E env0) shbinds +-- -> Subenv shbinds tapebinds +-- -> Ex (Append shbinds (D1E env0)) (D1 t) +-- -> SubenvS (D2E (Select env0 sto "merge")) contribs +-- -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs) +-- -> SingleRet env0 sto (MkTyTyPair sd t) +-- pattern Ret1 e0 subtape e1 sub e2 = SingleRet e0 subtape (RetPair e1 sub e2) +-- {-# COMPLETE Ret1 #-} + +data RetPair env0 sto env shbinds tapebinds (pair :: TyTyPair) where + RetPair :: forall sd t contribs -- existentials + env0 sto env shbinds tapebinds. -- universals + Ex (Append shbinds env) (D1 t) + -> SubenvS (D2E (Select env0 sto "merge")) contribs + -> Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) (Tup contribs) + -> RetPair env0 sto env shbinds tapebinds (MkTyTyPair sd t) +deriving instance Show (RetPair env0 sto env shbinds tapebinds pair) + +data Rets env0 sto env list = + forall shbinds tapebinds. + Rets (Bindings Ex env shbinds) + (Subenv shbinds tapebinds) + (SList (RetPair env0 sto env shbinds tapebinds) list) +deriving instance Show (Rets env0 sto env list) + +toSingleRet :: Ret env0 sto sd t -> SingleRet env0 sto (MkTyTyPair sd t) +toSingleRet (Ret e0 subtape e1 sub e2) = SingleRet e0 subtape (RetPair e1 sub e2) + +weakenRetPair :: SList STy shbinds -> env :> env' + -> RetPair env0 sto env shbinds tapebinds pair -> RetPair env0 sto env' shbinds tapebinds pair +weakenRetPair bindslist w (RetPair e1 sub e2) = RetPair (weakenExpr (weakenOver bindslist w) e1) sub e2 + +weakenRets :: env :> env' -> Rets env0 sto env list -> Rets env0 sto env' list +weakenRets w (Rets binds tapesub list) = + let (binds', _) = weakenBindingsE w binds + in Rets binds' tapesub (slistMap (weakenRetPair (bindingsBinds binds) w) list) + +rebaseRetPair :: forall env b1 b2 tapebinds1 tapebinds2 env0 sto pair f. + Descr env0 sto + -> SList f b1 -> SList f b2 + -> Subenv b1 tapebinds1 -> Subenv b2 tapebinds2 + -> RetPair env0 sto (Append b1 env) b2 tapebinds2 pair + -> RetPair env0 sto env (Append b2 b1) (Append tapebinds2 tapebinds1) pair +rebaseRetPair descr b1 b2 subtape1 subtape2 (RetPair @sd e1 sub e2) + | Refl <- lemAppendAssoc @b2 @b1 @env = + RetPair e1 sub + (weakenExpr (autoWeak + (#d (auto1 @sd) + &. #t2 (subList b2 subtape2) + &. #t1 (subList b1 subtape1) + &. #tl (d2ace (select SAccum descr))) + (#d :++: (#t2 :++: #tl)) + (#d :++: ((#t2 :++: #t1) :++: #tl))) + e2) + +retConcat :: forall env0 sto list. Descr env0 sto -> SList (SingleRet env0 sto) list -> Rets env0 sto (D1E env0) list +retConcat _ SNil = Rets BTop SETop SNil +retConcat descr (SCons (SingleRet (e0 :: Bindings _ _ shbinds1) (subtape :: Subenv _ tapebinds1) (RetPair e1 sub e2)) list) + | Rets (binds :: Bindings _ _ shbinds2) (subtape2 :: Subenv _ tapebinds2) pairs + <- weakenRets (sinkWithBindings e0) (retConcat descr list) + , Refl <- lemAppendAssoc @shbinds2 @shbinds1 @(D1E env0) + , Refl <- lemAppendAssoc @tapebinds2 @tapebinds1 @(D2AcE (Select env0 sto "accum")) + = Rets (bconcat e0 binds) + (subenvConcat subtape subtape2) + (SCons (RetPair (weakenExpr (sinkWithBindings binds) e1) + sub + (weakenExpr (WCopy (sinkWithSubenv subtape2)) e2)) + (slistMap (rebaseRetPair descr (bindingsBinds e0) (bindingsBinds binds) + subtape subtape2) + pairs)) + +freezeRet :: Descr env sto + -> Ret env sto (D2 t) t + -> Ex (D2 t : Append (D2AcE (Select env sto "accum")) (D1E env)) (TPair (D1 t) (Tup (D2E (Select env sto "merge")))) +freezeRet descr (Ret e0 subtape e1 sub e2 :: Ret _ _ _ t) = + let (e0', wInsertD2Ac) = weakenBindingsE (WSink .> wSinks (d2ace (select SAccum descr))) e0 + e2' = weakenExpr (WCopy (wCopies (subList (bindingsBinds e0) subtape) (wRaiseAbove (d2ace (select SAccum descr)) (desD1E descr)))) e2 + tContribs = tTup (slistMap fromSMTy (subList (d2eM (select SMerge descr)) sub)) + library = #d (auto1 @(D2 t)) + &. #tape (subList (bindingsBinds e0) subtape) + &. #shbinds (bindingsBinds e0) + &. #d2ace (d2ace (select SAccum descr)) + &. #tl (desD1E descr) + &. #contribs (SCons tContribs SNil) + in letBinds e0' $ + EPair ext + (weakenExpr wInsertD2Ac e1) + (ELet ext (weakenExpr (autoWeak library + (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: #d2ace :++: #tl) + (#shbinds :++: #d :++: #d2ace :++: #tl)) + e2') $ + expandSubenvZeros + (autoWeak library #tl (#contribs :++: #shbinds :++: #d :++: #d2ace :++: #tl) + .> wUndoSubenv (subenvD1E (selectSub SMerge descr))) + (select SMerge descr) sub (EVar ext tContribs IZ)) + + +---------------------------- THE CHAD TRANSFORMATION --------------------------- + +drev :: forall env sto sd t. + (?config :: CHADConfig) + => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) + -> Sparse (D2 t) sd + -> Expr ValId env t -> Ret env sto sd t +drev des _ sd | isAbsent sd = + \e -> + Ret BTop + SETop + (drevPrimal des e) + (subenvNone (d2e (select SMerge des))) + (ENil ext) +drev _ _ SpAbsent = error "Absent should be isAbsent" + +drev des accumMap (SpSparse sd) = + \e -> + case drev des accumMap sd e of { Ret e0 subtape e1 sub e2 -> + subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> + Ret e0 + subtape + e1 + sub' + (emaybe (EVar ext (STMaybe (applySparse sd (d2 (typeOf e)))) IZ) + (inj2 (ENil ext)) + (inj1 (weakenExpr (WCopy WSink) e2))) + } + +drev des accumMap sd = \case + EVar _ t i -> + case conv2Idx des i of + Idx2Ac accI -> + Ret BTop + SETop + (EVar ext (d1 t) (conv1Idx i)) + (subenvNone (d2e (select SMerge des))) + (let ty = applySparse sd (d2M t) + in EAccum ext (d2M t) SAPHere (ENil ext) sd (EVar ext (fromSMTy ty) IZ) (EVar ext (STAccum (d2M t)) (IS accI))) + + Idx2Me tupI -> + Ret BTop + SETop + (EVar ext (d1 t) (conv1Idx i)) + (subenvOnehot (d2e (select SMerge des)) tupI sd) + (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t)) IZ)) + + Idx2Di _ -> + Ret BTop + SETop + (EVar ext (d1 t) (conv1Idx i)) + (subenvNone (d2e (select SMerge des))) + (ENil ext) + + ELet _ (rhs :: Expr _ _ a) body + | ChosenStorage (storage :: Storage s) <- if chcLetArrayAccum ?config && typeHasArrays (typeOf rhs) then ChosenStorage SAccum else ChosenStorage SMerge + , RetScoped (body0 :: Bindings _ _ body_shbinds) (subtapeBody :: Subenv _ body_tapebinds) body1 subBody sdBody body2 <- drevScoped des accumMap (typeOf rhs) storage (Just (extOf rhs)) sd body + , Ret (rhs0 :: Bindings _ _ rhs_shbinds) (subtapeRHS :: Subenv _ rhs_tapebinds) rhs1 subRHS rhs2 <- drev des accumMap sdBody rhs + , let (body0', wbody0') = weakenBindingsE (WCopy (sinkWithBindings rhs0)) body0 + , Refl <- lemAppendAssoc @body_shbinds @'[D1 a] @rhs_shbinds + , Refl <- lemAppendAssoc @body_shbinds @(D1 a : rhs_shbinds) @(D1E env) + , Refl <- lemAppendAssoc @body_tapebinds @rhs_tapebinds @(D2AcE (Select env sto "accum")) + -> + subenvPlus SF SF (d2eM (select SMerge des)) subRHS subBody $ \subBoth _ _ plus_RHS_Body -> + let bodyResType = STPair (contribTupTy des subBody) (applySparse sdBody (d2 (typeOf rhs))) in + Ret (bconcat (rhs0 `bpush` rhs1) body0') + (subenvConcat subtapeRHS subtapeBody) + (weakenExpr wbody0' body1) + subBoth + (ELet ext (weakenExpr (autoWeak (#d (auto1 @sd) + &. #body (subList (bindingsBinds body0 `sappend` SCons (d1 (typeOf rhs)) SNil) subtapeBody) + &. #rhs (subList (bindingsBinds rhs0) subtapeRHS) + &. #tl (d2ace (select SAccum des))) + (#d :++: #body :++: #tl) + (#d :++: (#body :++: #rhs) :++: #tl)) + body2) $ + ELet ext + (ELet ext (ESnd ext (EVar ext bodyResType IZ)) $ + weakenExpr (WCopy (wSinks' @[_,_] .> sinkWithSubenv subtapeBody)) rhs2) $ + plus_RHS_Body + (EVar ext (contribTupTy des subRHS) IZ) + (EFst ext (EVar ext bodyResType (IS IZ)))) + + EPair _ a b + | SpPair sd1 sd2 <- sd + , Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) + <- retConcat des $ toSingleRet (drev des accumMap sd1 a) `SCons` toSingleRet (drev des accumMap sd2 b) `SCons` SNil + , let dt = STPair (applySparse sd1 (d2 (typeOf a))) (applySparse sd2 (d2 (typeOf b))) -> + subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B -> + Ret binds + subtape + (EPair ext a1 b1) + subBoth + (ELet ext (ELet ext (EFst ext (EVar ext dt IZ)) + (weakenExpr (WCopy WSink) a2)) $ + ELet ext (ELet ext (ESnd ext (EVar ext dt (IS IZ))) + (weakenExpr (WCopy (WSink .> WSink)) b2)) $ + plus_A_B + (EVar ext (contribTupTy des subA) (IS IZ)) + (EVar ext (contribTupTy des subB) IZ)) + + EFst _ e + | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair sd SpAbsent) e + , STPair t1 _ <- typeOf e -> + Ret e0 + subtape + (EFst ext e1) + sub + (ELet ext (EPair ext (EVar ext (applySparse sd (d2 t1)) IZ) (ENil ext)) $ + weakenExpr (WCopy WSink) e2) + + ESnd _ e + | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpPair SpAbsent sd) e + , STPair _ t2 <- typeOf e -> + Ret e0 + subtape + (ESnd ext e1) + sub + (ELet ext (EPair ext (ENil ext) (EVar ext (applySparse sd (d2 t2)) IZ)) $ + weakenExpr (WCopy WSink) e2) + + -- Don't need to handle ENil, because its cotangent is always absent! + -- ENil _ -> Ret BTop SETop (ENil ext) (subenvNone (d2e (select SMerge des))) (ENil ext) + + EInl _ t2 e + | SpLEither sd1 sd2 <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap sd1 e -> + subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> + Ret e0 + subtape + (EInl ext (d1 t2) e1) + sub' + (ELCase ext + (EVar ext (STLEither (applySparse sd1 (d2 (typeOf e))) (applySparse sd2 (d2 t2))) IZ) + (inj2 $ ENil ext) + (inj1 $ weakenExpr (WCopy WSink) e2) + (EError ext (contribTupTy des sub') "inl<-dinr")) + + EInr _ t1 e + | SpLEither sd1 sd2 <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap sd2 e -> + subenvPlus ST ST (d2eM (select SMerge des)) sub (subenvNone (d2e (select SMerge des))) $ \sub' (Inj inj1) (Inj inj2) _ -> + Ret e0 + subtape + (EInr ext (d1 t1) e1) + sub' + (ELCase ext + (EVar ext (STLEither (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 (typeOf e)))) IZ) + (inj2 $ ENil ext) + (EError ext (contribTupTy des sub') "inr<-dinl") + (inj1 $ weakenExpr (WCopy WSink) e2)) + + ECase _ e (a :: Expr _ _ t) b + | STEither (t1 :: STy a) (t2 :: STy b) <- typeOf e + , ChosenStorage storage1 <- if chcCaseArrayAccum ?config && typeHasArrays t1 then ChosenStorage SAccum else ChosenStorage SMerge + , ChosenStorage storage2 <- if chcCaseArrayAccum ?config && typeHasArrays t2 then ChosenStorage SAccum else ChosenStorage SMerge + , let (bindids1, bindids2) = validSplitEither (extOf e) + , RetScoped (a0 :: Bindings _ _ rhs_a_binds) (subtapeA :: Subenv _ rhs_a_tape) a1 subA sd1 a2 + <- drevScoped des accumMap t1 storage1 bindids1 sd a + , RetScoped (b0 :: Bindings _ _ rhs_b_binds) (subtapeB :: Subenv _ rhs_b_tape) b1 subB sd2 b2 + <- drevScoped des accumMap t2 storage2 bindids2 sd b + , Ret (e0 :: Bindings _ _ e_binds) (subtapeE :: Subenv _ e_tape) e1 subE e2 <- drev des accumMap (SpLEither sd1 sd2) e + , Refl <- lemAppendAssoc @(Append rhs_a_binds (Reverse (TapeUnfoldings rhs_a_binds))) @(Tape rhs_a_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) + , Refl <- lemAppendAssoc @(Append rhs_b_binds (Reverse (TapeUnfoldings rhs_b_binds))) @(Tape rhs_b_binds : D2 t : TPair (D1 t) (TEither (Tape rhs_a_binds) (Tape rhs_b_binds)) : e_binds) @(D2AcE (Select env sto "accum")) + , let subtapeListA = subList (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA + , let subtapeListB = subList (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB + , let tapeA = tapeTy subtapeListA + , let tapeB = tapeTy subtapeListB + , let collectA = bindingsCollectTape @_ @_ @(Append rhs_a_binds (D1 a : Append e_binds (D1E env))) + (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) subtapeA + , let collectB = bindingsCollectTape @_ @_ @(Append rhs_b_binds (D1 b : Append e_binds (D1E env))) + (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) subtapeB + , (tPrimal :: STy t_primal_ty) <- STPair (d1 (typeOf a)) (STEither tapeA tapeB) + , let (a0', wa0') = weakenBindingsE (WCopy (sinkWithBindings e0)) a0 + , let (b0', wb0') = weakenBindingsE (WCopy (sinkWithBindings e0)) b0 + , Refl <- lemAppendNil @(Append rhs_a_binds '[D1 a]) + , Refl <- lemAppendNil @(Append rhs_b_binds '[D1 b]) + , Refl <- lemAppendAssoc @rhs_a_binds @'[D1 a] @(D1E env) + , Refl <- lemAppendAssoc @rhs_b_binds @'[D1 b] @(D1E env) + , let wa0'' = wa0' .> wCopies (sappend (bindingsBinds a0) (d1 t1 `SCons` SNil)) (WClosed @(D1E env)) + , let wb0'' = wb0' .> wCopies (sappend (bindingsBinds b0) (d1 t2 `SCons` SNil)) (WClosed @(D1E env)) + -> + subenvPlus ST ST (d2eM (select SMerge des)) subA subB $ \subAB (Inj sAB_A) (Inj sAB_B) _ -> + subenvPlus SF SF (d2eM (select SMerge des)) subAB subE $ \subOut _ _ plus_AB_E -> + Ret (e0 `bpush` ECase ext e1 + (letBinds a0' (EPair ext (weakenExpr wa0' a1) (EInl ext tapeB (collectA wa0'')))) + (letBinds b0' (EPair ext (weakenExpr wb0' b1) (EInr ext tapeA (collectB wb0''))))) + (SEYesR subtapeE) + (EFst ext (EVar ext tPrimal IZ)) + subOut + (elet + (ECase ext (ESnd ext (EVar ext tPrimal (IS IZ))) + (let (rebinds, prerebinds) = reconstructBindings subtapeListA + in letBinds (rebinds IZ) $ + ELet ext + (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_a_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListA prerebinds) @> IS IZ)) $ + elet + (weakenExpr (autoWeak (#d (auto1 @sd) + &. #ta0 subtapeListA + &. #prea0 prerebinds + &. #recon (tapeA `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil) + &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE) + &. #tl (d2ace (select SAccum des))) + (#d :++: #ta0 :++: #tl) + (#d :++: (#ta0 :++: #prea0) :++: #recon :++: #binds :++: #tl)) + a2) $ + EPair ext (sAB_A $ EFst ext (evar IZ)) + (ELInl ext (applySparse sd2 (d2 t2)) (ESnd ext (evar IZ)))) + (let (rebinds, prerebinds) = reconstructBindings subtapeListB + in letBinds (rebinds IZ) $ + ELet ext + (EVar ext (applySparse sd (d2 (typeOf a))) (wSinks @(Tape rhs_b_tape : sd : t_primal_ty : Append e_tape (D2AcE (Select env sto "accum"))) (sappend subtapeListB prerebinds) @> IS IZ)) $ + elet + (weakenExpr (autoWeak (#d (auto1 @sd) + &. #tb0 subtapeListB + &. #preb0 prerebinds + &. #recon (tapeB `SCons` applySparse sd (d2 (typeOf a)) `SCons` SNil) + &. #binds (tPrimal `SCons` subList (bindingsBinds e0) subtapeE) + &. #tl (d2ace (select SAccum des))) + (#d :++: #tb0 :++: #tl) + (#d :++: (#tb0 :++: #preb0) :++: #recon :++: #binds :++: #tl)) + b2) $ + EPair ext (sAB_B $ EFst ext (evar IZ)) + (ELInr ext (applySparse sd1 (d2 t1)) (ESnd ext (evar IZ))))) $ + plus_AB_E + (EFst ext (evar IZ)) + (ELet ext (ESnd ext (evar IZ)) $ + weakenExpr (WCopy (wSinks' @[_,_,_])) e2)) + + EConst _ t val -> + Ret BTop + SETop + (EConst ext t val) + (subenvNone (d2e (select SMerge des))) + (ENil ext) + + EOp _ op e + | Ret e0 subtape e1 sub e2 <- drev des accumMap (spDense (d2M (opt1 op))) e -> + case d2op op of + Linear d2opfun -> + Ret e0 + subtape + (d1op op e1) + sub + (ELet ext (d2opfun (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ))) + (weakenExpr (WCopy WSink) e2)) + Nonlinear d2opfun -> + Ret (e0 `bpush` e1) + (SEYesR subtape) + (d1op op $ EVar ext (d1 (typeOf e)) IZ) + sub + (ELet ext (d2opfun (EVar ext (d1 (typeOf e)) (IS IZ)) + (opt2UnSparse op sd (EVar ext (applySparse sd (d2 (opt2 op))) IZ))) + (weakenExpr (WCopy (wSinks' @[_,_])) e2)) + + ECustom _ _ tb _ srce pr du a b + -- allowed to ignore a2 because 'a' is the part of the input that is inactive + | Ret b0 bsubtape b1 bsub b2 <- drev des accumMap (spDense (d2M tb)) b -> + case isDense (d2M (typeOf srce)) sd of + Just Refl -> + Ret (b0 `bpush` weakenExpr (sinkWithBindings b0) (drevPrimal des a) + `bpush` weakenExpr WSink b1 + `bpush` weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr) + `bpush` ESnd ext (EVar ext (typeOf pr) IZ)) + (SEYesR (SENo (SENo (SENo bsubtape)))) + (EFst ext (EVar ext (typeOf pr) (IS IZ))) + bsub + (ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ + weakenExpr (WCopy (WSink .> WSink)) b2) + + Nothing -> + Ret (b0 `bpush` weakenExpr (sinkWithBindings b0) (drevPrimal des a) + `bpush` weakenExpr WSink b1 + `bpush` weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) pr)) + (SEYesR (SENo (SENo bsubtape))) + (EFst ext (EVar ext (typeOf pr) IZ)) + bsub + (ELet ext (ESnd ext (EVar ext (typeOf pr) (IS IZ))) $ -- tape + ELet ext (expandSparse (typeOf srce) sd -- expanded incoming cotangent + (EFst ext (EVar ext (typeOf pr) (IS (IS IZ)))) + (EVar ext (applySparse sd (d2 (typeOf srce))) (IS IZ))) $ + ELet ext (weakenExpr (WCopy (WCopy WClosed)) (mapExt (const ext) du)) $ + weakenExpr (WCopy (WSink .> WSink .> WSink .> WSink)) b2) + + ERecompute _ e -> + deleteUnused (descrList des) (occCountAll e) $ \usedSub -> + let smallE = unsafeWeakenWithSubenv usedSub e in + subDescr des usedSub $ \usedDes subMergeUsed subAccumUsed subD1eUsed -> + case drev usedDes (VarMap.subMap subAccumUsed accumMap) sd smallE of { Ret e0 subtape _ sub e2 -> + let subMergeUsed' = subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) (subenvD2E subMergeUsed) in + Ret (collectBindings (desD1E des) subD1eUsed) + (subenvAll (desD1E usedDes)) + (weakenExpr (wSinks (desD1E usedDes)) $ drevPrimal des e) + (subenvCompose subMergeUsed' sub) + (letBinds (fst (weakenBindingsE (WSink .> wRaiseAbove (desD1E usedDes) (d2ace (select SAccum des))) e0)) $ + weakenExpr + (autoWeak (#d (auto1 @sd) + &. #shbinds (bindingsBinds e0) + &. #tape (subList (bindingsBinds e0) subtape) + &. #d1env (desD1E usedDes) + &. #tl' (d2ace (select SAccum usedDes)) + &. #tl (d2ace (select SAccum des))) + (#d :++: LPreW #tape #shbinds (wUndoSubenv subtape) :++: LPreW #tl' #tl (wUndoSubenv subAccumUsed)) + (#shbinds :++: #d :++: #d1env :++: #tl)) + e2) + } + + EError _ t s -> + Ret BTop + SETop + (EError ext (d1 t) s) + (subenvNone (d2e (select SMerge des))) + (ENil ext) + + EConstArr _ n t val -> + Ret BTop + SETop + (EConstArr ext n t val) + (subenvNone (d2e (select SMerge des))) + (ENil ext) + + EBuild _ (ndim :: SNat ndim) she (ef :: Expr _ _ eltty) + | SpArr @_ @sdElt sdElt <- sd + , let eltty = typeOf ef + , shty :: STy shty <- tTup (sreplicate ndim tIx) + , Refl <- indexTupD1Id ndim -> + drevLambda des accumMap (shty, SDiscr) sdElt ef $ \(provars :: SList _ envPro) esub proPrimalBinds e0 e1 (e1tape :: Ex _ e_tape) _ wrapAccum e2 -> + let library = #ix (shty `SCons` SNil) + &. #e0 (bindingsBinds e0) + &. #propr (d1e provars) + &. #d1env (desD1E des) + &. #d (auto1 @sdElt) + &. #tape (auto1 @e_tape) + &. #pro (d2ace provars) + &. #d2acEnv (d2ace (select SAccum des)) + &. #darr (auto1 @(TArr ndim sdElt)) + &. #tapearr (auto1 @(TArr ndim e_tape)) in + Ret (proPrimalBinds + `bpush` weakenExpr (wSinks (d1e provars)) + (EBuild ext ndim + (drevPrimal des she) + (letBinds e0 $ + EPair ext e1 e1tape)) + `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) IZ)) + (SEYesR (SENo (subenvAll (d1e provars)))) + (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 eltty) (typeOf e1tape))) (IS IZ))) + (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) esub) + (let sinkOverEnvPro = wSinks @(sd : TArr ndim e_tape : Append (D1E envPro) (D2AcE (Select env sto "accum"))) (d2ace provars) in + ESnd ext $ + wrapAccum (WSink .> WSink .> wRaiseAbove (d1e provars) (d2ace (select SAccum des))) $ + EBuild ext ndim (EShape ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (sinkOverEnvPro @> IZ))) $ + -- the cotangent for this element + ELet ext (EIdx ext (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (WSink .> sinkOverEnvPro @> IZ)) + (EVar ext shty IZ)) $ + -- the tape for this element + ELet ext (EIdx ext (EVar ext (STArr ndim (typeOf e1tape)) (WSink .> WSink .> sinkOverEnvPro @> IS IZ)) + (EVar ext shty (IS IZ))) $ + weakenExpr (autoWeak library (#tape :++: #d :++: #pro :++: #d2acEnv) + (#tape :++: #d :++: #ix :++: #pro :++: #darr :++: #tapearr :++: #propr :++: #d2acEnv)) + e2) + + EMap _ ef (earr :: Expr _ _ (TArr n a)) + | SpArr sdElt <- sd + , let STArr ndim t1 = typeOf earr + t2 = typeOf ef -> + drevLambda des accumMap (t1, SMerge) sdElt ef $ \provars efsub proPrimalBinds ef0 ef1 ef1tape spEf wrapAccum ef2 -> + case drev des accumMap (SpArr spEf) earr of { Ret ea0 easubtape ea1 easub ea2 -> + let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings ea0) proPrimalBinds + ttape = typeOf ef1tape + library = #d1env (desD1E des) + &. #a0 (bindingsBinds ea0) + &. #atapebinds (subList (bindingsBinds ea0) easubtape) + &. #propr (d1e provars) + &. #x (d1 t1 `SCons` SNil) + &. #parr (STArr ndim (d1 t1) `SCons` SNil) + &. #tapearr (STArr ndim ttape `SCons` SNil) + &. #darr (STArr ndim (applySparse sdElt (d2 t2)) `SCons` SNil) + &. #dy (applySparse sdElt (d2 t2) `SCons` SNil) + &. #tape (ttape `SCons` SNil) + &. #dytape (STPair (applySparse sdElt (d2 t2)) ttape `SCons` SNil) + &. #d2acEnv (d2ace (select SAccum des)) + &. #pro (d2ace provars) + in + subenvPlus SF SF (d2eM (select SMerge des)) (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) efsub) easub $ \subfa _ _ plus_f_a -> + Ret (bconcat ea0 proPrimalBinds' + `bpush` weakenExpr (autoWeak library (#a0 :++: #d1env) ((#propr :++: #a0) :++: #d1env)) ea1 + `bpush` emap (weakenExpr (autoWeak library (#x :++: #d1env) (#x :++: #parr :++: (#propr :++: #a0) :++: #d1env)) + (letBinds ef0 $ + EPair ext ef1 ef1tape)) + (EVar ext (STArr ndim (d1 t1)) IZ) + `bpush` emap (ESnd ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 t2) ttape)) IZ)) + (SEYesR (SENo (SENo (subenvConcat easubtape (subenvAll (d1e provars)))))) + (emap (EFst ext (evar IZ)) (EVar ext (STArr ndim (STPair (d1 t2) ttape)) (IS IZ))) + subfa + (let layout = #darr :++: #tapearr :++: (#propr :++: #atapebinds) :++: #d2acEnv in + elet + (wrapAccum (autoWeak library #propr layout) $ + emap (elet (EFst ext (EVar ext (STPair (applySparse sdElt (d2 t2)) ttape) IZ)) $ + elet (ESnd ext (EVar ext (STPair (applySparse sdElt (d2 t2)) ttape) (IS IZ))) $ + weakenExpr (autoWeak library (#tape :++: #dy :++: #pro :++: #d2acEnv) + (#tape :++: #dy :++: #dytape :++: #pro :++: layout)) + ef2) + (ezip (EVar ext (STArr ndim (applySparse sdElt (d2 t2))) (autoWeak library #darr (#pro :++: layout) @> IZ)) + (EVar ext (STArr ndim ttape) (autoWeak library #tapearr (#pro :++: layout) @> IZ)))) $ + plus_f_a + (ESnd ext (evar IZ)) + (weakenExpr (WCopy (autoWeak library (#atapebinds :++: #d2acEnv) layout)) + (subst0 (EFst ext (EVar ext (STPair (STArr ndim (typeOf ef2)) (tTup (d2e provars))) IZ)) + ea2))) + } + + EFold1Inner _ commut origef ex₀ earr + | SpArr @_ @sdElt sdElt <- sd + , STArr (SS ndim) eltty :: STy (TArr (S n) elt) <- typeOf earr + , Rets bindsx₀a subtapex₀a (RetPair ex₀1 subx₀ ex₀2 `SCons` RetPair ea1 suba ea2 `SCons` SNil) + <- retConcat des $ toSingleRet (drev des accumMap (spDense (d2M eltty)) ex₀) `SCons` toSingleRet (drev des accumMap (spDense (SMTArr (SS ndim) (d2M eltty))) earr) `SCons` SNil -> + drevLambda des accumMap (STPair eltty eltty, SMerge) (spDense (d2M eltty)) origef $ \(provars :: SList _ envPro) efsub proPrimalBinds ef0 ef1 (ef1tape :: Ex _ ef_tape) spEf wrapAccum ef2 -> + let (proPrimalBinds', _) = weakenBindingsE (sinkWithBindings bindsx₀a) proPrimalBinds in + let bogEltTy = STPair (STPair (d1 eltty) (d1 eltty)) (typeOf ef1tape) + bogTy = STArr (SS ndim) bogEltTy + primalTy = STPair (STArr ndim (d1 eltty)) bogTy + library = #xy (STPair (d1 eltty) (d1 eltty) `SCons` SNil) + &. #parr (auto1 @(TArr (S n) (D1 elt))) + &. #px₀ (auto1 @(D1 elt)) + &. #px (auto1 @(D1 elt)) + &. #pzi (auto1 @(ZeroInfo (D2 elt))) + &. #primal (primalTy `SCons` SNil) + &. #darr (auto1 @(TArr n sdElt)) + &. #d (auto1 @(D2 elt)) + &. #x₀abinds (bindingsBinds bindsx₀a) + &. #fbinds (bindingsBinds ef0) + &. #x₀atapebinds (subList (bindingsBinds bindsx₀a) subtapex₀a) + &. #ftape (auto1 @ef_tape) + &. #bogelt (bogEltTy `SCons` SNil) + &. #propr (d1e provars) + &. #d1env (desD1E des) + &. #d2acEnv (d2ace (select SAccum des)) + &. #d2acPro (d2ace provars) + &. #foldd2res (auto1 @(TPair (TPair (D2 elt) (TArr (S n) (D2 elt))) (Tup (D2E envPro)))) + wOverPrimalBindings = autoWeak library (#x₀abinds :++: #d1env) ((#propr :++: #x₀abinds) :++: #d1env) in + subenvPlus SF SF (d2eM (select SMerge des)) subx₀ suba $ \subx₀a _ _ plus_x₀_a -> + subenvPlus SF SF (d2eM (select SMerge des)) subx₀a (subenvMap (\t Refl -> spDense t) (d2eM (select SMerge des)) efsub) $ \subx₀af _ _ plus_x₀a_f -> + Ret (bconcat bindsx₀a proPrimalBinds' + `bpush` weakenExpr wOverPrimalBindings ex₀1 + `bpush` d2zeroInfo eltty (EVar ext (d1 eltty) IZ) + `bpush` weakenExpr (WSink .> WSink .> wOverPrimalBindings) ea1 + `bpush` EFold1InnerD1 ext commut + (let layout = #xy :++: #parr :++: #pzi :++: #px₀ :++: (#propr :++: #x₀abinds) :++: #d1env in + weakenExpr (autoWeak library (#xy :++: #d1env) layout) + (letBinds ef0 $ + EPair ext -- (out, ((in1, in2), tape)); the "additional stores" are ((in1, in2), tape) + ef1 + (EPair ext + (EVar ext (STPair (d1 eltty) (d1 eltty)) (autoWeak library #xy (#fbinds :++: #xy :++: #d1env) @> IZ)) + ef1tape))) + (EVar ext (d1 eltty) (IS (IS IZ))) + (EVar ext (STArr (SS ndim) (d1 eltty)) IZ)) + (SEYesR (SEYesR (SEYesR (SENo (subenvConcat subtapex₀a (subenvAll (d1e provars))))))) + (EFst ext (EVar ext primalTy IZ)) + subx₀af + (let layout1 = #darr :++: #primal :++: #parr :++: #pzi :++: (#propr :++: #x₀atapebinds) :++: #d2acEnv in + elet + (wrapAccum (autoWeak library #propr layout1) $ + let layout2 = #d2acPro :++: layout1 in + EFold1InnerD2 ext commut + (elet (ESnd ext (EVar ext bogEltTy (IS IZ))) $ + let layout3 = #ftape :++: #d :++: #bogelt :++: layout2 in + expandSparse (STPair eltty eltty) spEf (EFst ext (EVar ext bogEltTy (IS (IS IZ)))) $ + weakenExpr (autoWeak library (#ftape :++: #d :++: #d2acPro :++: #d2acEnv) layout3) ef2) + (ESnd ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ))) + (ezipWith (expandSparse eltty sdElt (evar IZ) (evar (IS IZ))) + (EVar ext (STArr ndim (applySparse sdElt (d2 eltty))) (autoWeak library #darr layout2 @> IZ)) + (EFst ext (EVar ext primalTy (autoWeak library #primal layout2 @> IZ))))) $ + plus_x₀a_f + (plus_x₀_a + (elet (EIdx0 ext + (EFold1Inner ext Commut + (let t = STPair (d2 eltty) (d2 eltty) + in EPlus ext (d2M eltty) (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ))) + (EZero ext (d2M eltty) (EVar ext (tZeroInfo (d2M eltty)) (WSink .> autoWeak library #pzi layout1 @> IZ))) + (eflatten (EFst ext (EFst ext (evar IZ)))))) $ + weakenExpr (WCopy (WSink .> autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) + ex₀2) + (weakenExpr (WCopy (autoWeak library (#x₀atapebinds :++: #d2acEnv) layout1)) $ + subst0 (ESnd ext (EFst ext (evar IZ))) ea2)) + (ESnd ext (evar IZ))) + + EUnit _ e + | SpArr sdElt <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap sdElt e -> + Ret e0 + subtape + (EUnit ext e1) + sub + (ELet ext (EIdx0 ext (EVar ext (STArr SZ (applySparse sdElt (d2 (typeOf e)))) IZ)) $ + weakenExpr (WCopy WSink) e2) + + EReplicate1Inner _ en e + -- We're allowed to differentiate 'en' as primal-only here because its output is discrete. + | SpArr sdElt <- sd + , let STArr ndim eltty = typeOf e -> + -- This pessimistic sparsity union is because the array might have been empty, in which case we need to generate a zero. + sparsePlusS ST ST (d2M eltty) sdElt SpAbsent $ \sdElt' (Inj inj1) (Inj inj2) _ -> + case drev des accumMap (SpArr sdElt') e of { Ret binds subtape e1 sub e2 -> + Ret binds + subtape + (EReplicate1Inner ext (weakenExpr (wSinks (bindingsBinds binds)) (drevPrimal des en)) e1) + sub + (ELet ext (EFold1Inner ext Commut + (let t = STPair (applySparse sdElt' (d2 eltty)) (applySparse sdElt' (d2 eltty)) + in sparsePlus (d2M eltty) sdElt' (EFst ext (EVar ext t IZ)) (ESnd ext (EVar ext t IZ))) + (inj2 (ENil ext)) + (emap (inj1 (evar IZ)) $ EVar ext (STArr (SS ndim) (applySparse sdElt (d2 eltty))) IZ)) $ + weakenExpr (WCopy WSink) e2) + } + + EIdx0 _ e + | Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd) e + , STArr _ t <- typeOf e -> + Ret e0 + subtape + (EIdx0 ext e1) + sub + (ELet ext (EUnit ext (EVar ext (applySparse sd (d2 t)) IZ)) $ + weakenExpr (WCopy WSink) e2) + + EIdx1{} -> error "CHAD of EIdx1: Please use EIdx instead" + {- + EIdx1 _ e ei + -- We're allowed to ignore ei2 here because the output of 'ei' is discrete. + | Rets binds subtape (RetPair e1 sub e2 `SCons` RetPair ei1 _ _ `SCons` SNil) + <- retConcat des $ drev des accumMap e `SCons` drev des accumMap ei `SCons` SNil + , STArr (SS n) eltty <- typeOf e -> + Ret (binds `bpush` e1 + `bpush` EShape ext (EVar ext (STArr (SS n) (d1 eltty)) IZ)) + (SEYesR (SENo subtape)) + (EIdx1 ext (EVar ext (STArr (SS n) (d1 eltty)) (IS IZ)) + (weakenExpr (WSink .> WSink) ei1)) + sub + (ELet ext (ebuildUp1 n (EFst ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) + (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) + (EVar ext (STArr n (d2 eltty)) (IS IZ))) $ + weakenExpr (WCopy (WSink .> WSink)) e2) + -} + + EIdx _ e ei + -- We're allowed to differentiate ei as primal because its output is discrete. + | STArr n eltty <- typeOf e + , Refl <- indexTupD1Id n + , let tIxN = tTup (sreplicate n tIx) -> + sparsePlusS ST ST (d2M eltty) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ -> + case drev des accumMap (SpArr sd') e of { Ret binds subtape e1 sub e2 -> + Ret (binds `bpush` e1 + `bpush` EShape ext (EVar ext (typeOf e1) IZ) + `bpush` weakenExpr (WSink .> WSink .> wSinks (bindingsBinds binds)) (drevPrimal des ei)) + (SEYesR (SEYesR (SENo subtape))) + (EIdx ext (EVar ext (STArr n (d1 eltty)) (IS (IS IZ))) + (EVar ext (tTup (sreplicate n tIx)) IZ)) + sub + (ELet ext + (EOneHot ext (SMTArr n (applySparse sd' (d2M eltty))) + (SAPArrIdx SAPHere) + (EPair ext + (EPair ext (EVar ext tIxN (IS IZ)) + (EBuild ext n (EVar ext tIxN (IS (IS IZ))) $ + makeZeroInfo (applySparse sd' (d2M eltty)) (inj2 (ENil ext)))) + (ENil ext)) + (inj1 $ EVar ext (applySparse sd (d2 eltty)) IZ)) $ + weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + } + + EShape _ e + -- Allowed to differentiate e as primal because the output of EShape is + -- discrete, hence we'd be passing a zero cotangent to e anyway. + | STArr n _ <- typeOf e + , Refl <- indexTupD1Id n -> + Ret BTop + SETop + (EShape ext (drevPrimal des e)) + (subenvNone (d2eM (select SMerge des))) + (ENil ext) + + ESum1Inner _ e + | SpArr sd' <- sd + , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e + , STArr (SS n) t <- typeOf e -> + Ret (e0 `bpush` e1 + `bpush` EShape ext (EVar ext (STArr (SS n) t) IZ)) + (SEYesR (SENo subtape)) + (ESum1Inner ext (EVar ext (STArr (SS n) t) (IS IZ))) + sub + (ELet ext (EReplicate1Inner ext + (ESnd ext (EVar ext (tTup (sreplicate (SS n) tIx)) (IS IZ))) + (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e2) + + EMaximum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMaximum1Inner ext) des accumMap sd' e + EMinimum1Inner _ e | SpArr sd' <- sd -> deriv_extremum (EMinimum1Inner ext) des accumMap sd' e + + EReshape _ n esh e + | SpArr sd' <- sd + , STArr orign t <- typeOf e + , Ret e0 subtape e1 sub e2 <- drev des accumMap (SpArr sd') e + , Refl <- indexTupD1Id n -> + Ret (e0 `bpush` e1 + `bpush` EShape ext (EVar ext (STArr orign (d1 t)) IZ)) + (SEYesR (SENo subtape)) + (EReshape ext n (weakenExpr (WSink .> WSink .> wSinks (bindingsBinds e0)) (drevPrimal des esh)) + (EVar ext (STArr orign (d1 t)) (IS IZ))) + sub + (elet (EReshape ext orign (EVar ext (tTup (sreplicate orign tIx)) (IS IZ)) + (EVar ext (STArr n (applySparse sd' (d2 t))) IZ)) $ + weakenExpr (WCopy (WSink .> WSink)) e2) + + EZip _ a b + | SpArr sd' <- sd + , STArr n t1 <- typeOf a + , STArr _ t2 <- typeOf b -> + splitSparsePair (STPair (d2 t1) (d2 t2)) sd' $ \sd1 sd2 pairSplitE -> + case retConcat des (toSingleRet (drev des accumMap (SpArr sd1) a) `SCons` + toSingleRet (drev des accumMap (SpArr sd2) b) `SCons` SNil) of + { Rets binds subtape (RetPair a1 subA a2 `SCons` RetPair b1 subB b2 `SCons` SNil) -> + subenvPlus SF SF (d2eM (select SMerge des)) subA subB $ \subBoth _ _ plus_A_B -> + Ret binds + subtape + (EZip ext a1 b1) + subBoth + (case pairSplitE of + Left Refl -> + let t' = STArr n (STPair (applySparse sd1 (d2 t1)) (applySparse sd2 (d2 t2))) in + plus_A_B + (elet (emap (EFst ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) a2) + (elet (emap (ESnd ext (evar IZ)) (EVar ext t' IZ)) $ weakenExpr (WCopy WSink) b2) + Right f -> f IZ $ \wrapPair pick1 pick2 -> + elet (emap (wrapPair (EPair ext pick1 pick2)) + (EVar ext (applySparse (SpArr sd') (STArr n (STPair (d2 t1) (d2 t2)))) IZ)) $ + plus_A_B + (elet (emap (EFst ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) a2) + (elet (emap (ESnd ext (evar IZ)) (evar IZ)) $ weakenExpr (WCopy (WSink .> WSink)) b2)) + } + + ENothing{} -> err_unsupported "ENothing" + EJust{} -> err_unsupported "EJust" + EMaybe{} -> err_unsupported "EMaybe" + ELNil{} -> err_unsupported "ELNil" + ELInl{} -> err_unsupported "ELInl" + ELInr{} -> err_unsupported "ELInr" + ELCase{} -> err_unsupported "ELCase" + + EWith{} -> err_accum + EZero{} -> err_monoid + EDeepZero{} -> err_monoid + EPlus{} -> err_monoid + EOneHot{} -> err_monoid + + EFold1InnerD1{} -> err_targetlang "EFold1InnerD1" + EFold1InnerD2{} -> err_targetlang "EFold1InnerD2" + + where + err_accum = error "Accumulator operations unsupported in the source program" + err_monoid = error "Monoid operations unsupported in the source program" + err_unsupported s = error $ "CHAD: unsupported " ++ s + err_targetlang s = error $ "CHAD: Target language operation " ++ s ++ " not supported in source program" + + contribTupTy :: Descr env sto -> SubenvS (D2E (Select env sto "merge")) contribs -> STy (Tup contribs) + contribTupTy des' sub = tTup (slistMap fromSMTy (subList (d2eM (select SMerge des')) sub)) + +deriv_extremum :: (?config :: CHADConfig, ScalIsNumeric t ~ True) + => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t))) + -> Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) + -> Sparse (D2s t) sd + -> Expr ValId env (TArr (S n) (TScal t)) -> Ret env sto (TArr n sd) (TArr n (TScal t)) +deriv_extremum extremum des accumMap sd e + | at@(STArr (SS n) t@(STScal st)) <- typeOf e + , let at' = STArr n t + , let tIxN = tTup (sreplicate (SS n) tIx) = + sparsePlusS ST ST (d2M t) sd SpAbsent $ \sd' (Inj inj1) (Inj inj2) _ -> + case drev des accumMap (SpArr sd') e of { Ret e0 subtape e1 sub e2 -> + Ret (e0 `bpush` e1 + `bpush` extremum (EVar ext at IZ)) + (SEYesR (SEYesR subtape)) + (EVar ext at' IZ) + sub + (ELet ext + (EBuild ext (SS n) (EShape ext (EVar ext at (IS (IS IZ)))) $ + eif (EOp ext (OEq st) (EPair ext + (EIdx ext (EVar ext at (IS (IS (IS IZ)))) (EVar ext tIxN IZ)) + (EIdx ext (EVar ext at' (IS (IS IZ))) (EFst ext (EVar ext tIxN IZ))))) + (inj1 $ EIdx ext (EVar ext (STArr n (applySparse sd (d2 t))) (IS IZ)) (EFst ext (EVar ext tIxN IZ))) + (inj2 (ENil ext))) $ + weakenExpr (WCopy (WSink .> WSink .> WSink)) e2) + } + +data ChosenStorage = forall s. ((s == "discr") ~ False) => ChosenStorage (Storage s) + +data RetScoped env0 sto a s sd t = + forall shbinds tapebinds contribs sa. + RetScoped + (Bindings Ex (D1E (a : env0)) shbinds) -- shared binds + (Subenv (Append shbinds '[D1 a]) tapebinds) + (Ex (Append shbinds (D1E (a : env0))) (D1 t)) + (SubenvS (D2E (Select env0 sto "merge")) contribs) + -- ^ merge contributions to the _enclosing_ merge environment + (Sparse (D2 a) sa) + -- ^ contribution to the argument + (Ex (sd : Append tapebinds (D2AcE (Select env0 sto "accum"))) + (If (s == "discr") (Tup contribs) + (TPair (Tup contribs) sa))) + -- ^ the merge contributions, plus the cotangent to the argument + -- (if there is any) +deriving instance Show (RetScoped env0 sto a s sd t) + +drevScoped :: forall a s env sto sd t. + (?config :: CHADConfig) + => Descr env sto -> VarMap Int (D2AcE (Select env sto "accum")) + -> STy a -> Storage s -> Maybe (ValId a) + -> Sparse (D2 t) sd + -> Expr ValId (a : env) t + -> RetScoped env sto a s sd t +drevScoped des accumMap argty argsto argids sd expr = case argsto of + SMerge + | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr + , Refl <- lemAppendNil @tapebinds -> + case sub of + SEYes sp sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' sp e2 + SENo sub' -> RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub' SpAbsent (EPair ext e2 (ENil ext)) + + SAccum + | chcSmartWith ?config + , Just (VIArr i _) <- argids + , Just (Some (VarMap.TypedIdx foundTy idx)) <- VarMap.lookup i accumMap + , Just Refl <- testEquality foundTy (STAccum (d2M argty)) + , Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) (VarMap.sink1 accumMap) sd expr + , Refl <- lemAppendNil @tapebinds -> + -- Our contribution to the binding's cotangent _here_ is zero (absent), + -- because we're contributing to an earlier binding of the same value + -- instead. + RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent $ + let wtapebinds = wSinks (subList (bindingsBinds e0) subtape) in + ELet ext (EVar ext (STAccum (d2M argty)) (WSink .> wtapebinds @> idx)) $ + weakenExpr (autoWeak (#d (auto1 @sd) + &. #body (subList (bindingsBinds e0) subtape) + &. #ac (auto1 @(TAccum (D2 a))) + &. #tl (d2ace (select SAccum des))) + (#d :++: #body :++: #ac :++: #tl) + (#ac :++: #d :++: #body :++: #tl)) + (EPair ext e2 (ENil ext)) + + | let accumMap' = case argids of + Just (VIArr i _) -> VarMap.insert i (STAccum (d2M argty)) IZ (VarMap.sink1 accumMap) + _ -> VarMap.sink1 accumMap + , Ret e0 subtape e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap' sd expr -> + let library = #d (auto1 @sd) + &. #p (auto1 @(D1 a)) + &. #body (subList (bindingsBinds e0) subtape) + &. #ac (auto1 @(TAccum (D2 a))) + &. #tl (d2ace (select SAccum des)) + in + RetScoped e0 (subenvConcat (SEYesR @_ @_ @(D1 a) SETop) subtape) e1 sub (spDense (d2M argty)) $ + let primalIdx = autoWeak library #p (#d :++: (#body :++: #p) :++: #tl) @> IZ in + EWith ext (d2M argty) (EDeepZero ext (d2M argty) (d2deepZeroInfo argty (EVar ext (d1 argty) primalIdx))) $ + weakenExpr (autoWeak library + (#d :++: #body :++: #ac :++: #tl) + (#ac :++: #d :++: (#body :++: #p) :++: #tl)) + e2 + + SDiscr + | Ret e0 (subtape :: Subenv _ tapebinds) e1 sub e2 <- drev (des `DPush` (argty, argids, argsto)) accumMap sd expr + , Refl <- lemAppendNil @tapebinds -> + RetScoped e0 (subenvConcat (SENo @(D1 a) SETop) subtape) e1 sub SpAbsent e2 + +drevLambda :: (?config :: CHADConfig, (s == "accum") ~ False) + => Descr env sto + -> VarMap Int (D2AcE (Select env sto "accum")) + -> (STy a, Storage s) + -> Sparse (D2 t) dt + -> Expr ValId (a : env) t + -> (forall provars shbinds tape d2a'. + SList STy provars + -> Subenv (D2E (Select env sto "merge")) (D2E provars) + -> Bindings Ex (D1E env) (D1E provars) -- accum-promoted free variables of which we need a primal in the reverse pass (to initialise the accumulator) + -> Bindings Ex (D1 a : D1E env) shbinds + -> Ex (Append shbinds (D1 a : D1E env)) (D1 t) + -> Ex (Append shbinds (D1 a : D1E env)) tape + -> Sparse (D2 a) d2a' + -> (forall env' b. + D1E provars :> env' + -> Ex (Append (D2AcE provars) env') b + -> Ex ( env') (TPair b (Tup (D2E provars)))) + -> Ex (tape : dt : Append (D2AcE provars) (D2AcE (Select env sto "accum"))) d2a' + -> r) + -> r +drevLambda des accumMap (argty, argsto) sd origef k = + let t = typeOf origef in + deleteUnused (descrList des) (occEnvPopSome (occCountAll origef)) $ \(usedSub :: Subenv env env') -> + let ef = unsafeWeakenWithSubenv (SEYesR usedSub) origef in + subDescr des usedSub $ \(usedDes :: Descr env' _) subMergeUsed subAccumUsed subD1eUsed -> + accumPromote (applySparse sd (d2 t)) usedDes $ \prodes (envPro :: SList _ envPro) proSub proAccRevSub accumMapProPart wPro -> + let accumMapPro = VarMap.disjointUnion (VarMap.superMap proAccRevSub (VarMap.subMap subAccumUsed accumMap)) accumMapProPart in + let mergePrimalSub = subenvD1E (selectSub SMerge des `subenvCompose` subMergeUsed `subenvCompose` proSub) in + let mergePrimalBindings = collectBindings (d1e (descrList des)) mergePrimalSub in + case prf1 prodes argty argsto of { Refl -> + case drev (prodes `DPush` (argty, Nothing, argsto)) accumMapPro sd ef of { Ret (ef0 :: Bindings _ _ e_binds) (subtapeEf :: Subenv _ e_tape) ef1 subEf ef2 -> + let (efRebinds, efPrerebinds) = reconstructBindings (subList (bindingsBinds ef0) subtapeEf) in + extractContrib prodes argty argsto subEf $ \argSp getSparseArg -> + let library = #fbinds (bindingsBinds ef0) + &. #ftapebinds (subList (bindingsBinds ef0) subtapeEf) + &. #ftape (auto1 @(Tape e_tape)) + &. #arg (d1 argty `SCons` SNil) + &. #d (applySparse sd (d2 t) `SCons` SNil) + &. #d1env (desD1E des) + &. #d1env' (desD1E usedDes) + &. #propr (d1e envPro) + &. #d2acUsed (d2ace (select SAccum usedDes)) + &. #d2acEnv (d2ace (select SAccum des)) + &. #d2acPro (d2ace envPro) + &. #efPrerebinds efPrerebinds in + k envPro + (subenvD2E (subenvCompose subMergeUsed proSub)) + mergePrimalBindings + (fst (weakenBindingsE (WCopy (wUndoSubenv subD1eUsed)) ef0)) + (weakenExpr (autoWeak library (#fbinds :++: #arg :++: LPreW #d1env' #d1env (wUndoSubenv subD1eUsed)) + (#fbinds :++: #arg :++: #d1env)) + ef1) + (bindingsCollectTape (bindingsBinds ef0) subtapeEf (autoWeak library #fbinds (#fbinds :++: #arg :++: #d1env))) + argSp + (\wpro1 body -> + uninvertTup (d2e envPro) (typeOf body) $ + makeAccumulators wpro1 envPro $ + body) + (letBinds (efRebinds IZ) $ + weakenExpr + (autoWeak library (#d2acPro :++: #d :++: #ftapebinds :++: LPreW #d2acUsed #d2acEnv (wUndoSubenv subAccumUsed)) + ((#ftapebinds :++: #efPrerebinds) :++: #ftape :++: #d :++: #d2acPro :++: #d2acEnv) + .> wPro (subList (bindingsBinds ef0) subtapeEf)) + (getSparseArg ef2)) + }} + where + extractContrib :: (Select env sto "merge" ~ '[], (s == "accum") ~ False) + => proxy env sto -> proxy2 a -> Storage s + -- if s == "merge", this simplifies to SubenvS '[D2 a] t' + -- if s == "discr", this simplifies to SubenvS '[] t' + -> SubenvS (D2E (Select (a : env) (s : sto) "merge")) t' + -> (forall d'. Sparse (D2 a) d' -> (forall env'. Ex env' (Tup t') -> Ex env' d') -> r) -> r + extractContrib _ _ SMerge (SENo SETop) k' = k' SpAbsent id + extractContrib _ _ SMerge (SEYes s SETop) k' = k' s (ESnd ext) + extractContrib _ _ SDiscr SETop k' = k' SpAbsent id + + prf1 :: (s == "accum") ~ False => proxy env sto -> proxy2 a -> Storage s + -> Select (a : env) (s : sto) "accum" :~: Select env sto "accum" + prf1 _ _ SMerge = Refl + prf1 _ _ SDiscr = Refl + +-- TODO: proper primal-only transform that doesn't depend on D1 = Id +drevPrimal :: Descr env sto -> Expr x env t -> Ex (D1E env) (D1 t) +drevPrimal des e + | Refl <- d1Identity (typeOf e) + , Refl <- d1eIdentity (descrList des) + = mapExt (const ext) e diff --git a/src/CHAD/Drev/Accum.hs b/src/CHAD/Drev/Accum.hs new file mode 100644 index 0000000..6f25f11 --- /dev/null +++ b/src/CHAD/Drev/Accum.hs @@ -0,0 +1,72 @@ +{-# LANGUAGE GADTs #-} +{-# LANGUAGE TypeOperators #-} +-- | TODO this module is a grab-bag of random utility functions that are shared +-- between CHAD.Drev and CHAD.Drev.Top. +module CHAD.Drev.Accum where + +import CHAD.AST +import CHAD.Data +import CHAD.Drev.Types +import CHAD.AST.Env + + +d2zeroInfo :: STy t -> Ex env (D1 t) -> Ex env (ZeroInfo (D2 t)) +d2zeroInfo STNil _ = ENil ext +d2zeroInfo (STPair a b) e = + eunPair e $ \_ e1 e2 -> + EPair ext (d2zeroInfo a e1) (d2zeroInfo b e2) +d2zeroInfo STEither{} _ = ENil ext +d2zeroInfo STLEither{} _ = ENil ext +d2zeroInfo STMaybe{} _ = ENil ext +d2zeroInfo (STArr _ t) e = emap (d2zeroInfo t (EVar ext (d1 t) IZ)) e +d2zeroInfo (STScal t) _ | Refl <- lemZeroInfoScal t = ENil ext +d2zeroInfo STAccum{} _ = error "accumulators not allowed in source program" + +d2deepZeroInfo :: STy t -> Ex env (D1 t) -> Ex env (DeepZeroInfo (D2 t)) +d2deepZeroInfo STNil _ = ENil ext +d2deepZeroInfo (STPair a b) e = + eunPair e $ \_ e1 e2 -> + EPair ext (d2deepZeroInfo a e1) (d2deepZeroInfo b e2) +d2deepZeroInfo (STEither a b) e = + ECase ext e + (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ))) + (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ))) +d2deepZeroInfo (STLEither a b) e = + elcase e + (ELNil ext (tDeepZeroInfo (d2M a)) (tDeepZeroInfo (d2M b))) + (ELInl ext (tDeepZeroInfo (d2M b)) (d2deepZeroInfo a (EVar ext (d1 a) IZ))) + (ELInr ext (tDeepZeroInfo (d2M a)) (d2deepZeroInfo b (EVar ext (d1 b) IZ))) +d2deepZeroInfo (STMaybe a) e = + emaybe e + (ENothing ext (tDeepZeroInfo (d2M a))) + (EJust ext (d2deepZeroInfo a (EVar ext (d1 a) IZ))) +d2deepZeroInfo (STArr _ t) e = emap (d2deepZeroInfo t (EVar ext (d1 t) IZ)) e +d2deepZeroInfo (STScal t) _ | Refl <- lemDeepZeroInfoScal t = ENil ext +d2deepZeroInfo STAccum{} _ = error "accumulators not allowed in source program" + +-- The weakening is necessary because we need to initialise the created +-- accumulators with zeros. Those zeros are deep and need full primals. This +-- means, in the end, that primals corresponding to environment entries +-- promoted to an accumulator with accumPromote in CHAD need to be stored for +-- the dual. +makeAccumulators :: D1E envPro :> env -> SList STy envPro -> Ex (Append (D2AcE envPro) env) t -> Ex env (InvTup t (D2E envPro)) +makeAccumulators _ SNil e = e +makeAccumulators w (t `SCons` envpro) e = + makeAccumulators (WPop w) envpro $ + EWith ext (d2M t) (EDeepZero ext (d2M t) (d2deepZeroInfo t (EVar ext (d1 t) (wSinks (d2ace envpro) .> w @> IZ)))) e + +uninvertTup :: SList STy list -> STy core -> Ex env (InvTup core list) -> Ex env (TPair core (Tup list)) +uninvertTup SNil _ e = EPair ext e (ENil ext) +uninvertTup (t `SCons` list) tcore e = + ELet ext (uninvertTup list (STPair tcore t) e) $ + let recT = STPair (STPair tcore t) (tTup list) -- type of the RHS of that let binding + in EPair ext + (EFst ext (EFst ext (EVar ext recT IZ))) + (EPair ext + (ESnd ext (EVar ext recT IZ)) + (ESnd ext (EFst ext (EVar ext recT IZ)))) + +subenvD1E :: Subenv env env' -> Subenv (D1E env) (D1E env') +subenvD1E SETop = SETop +subenvD1E (SEYesR sub) = SEYesR (subenvD1E sub) +subenvD1E (SENo sub) = SENo (subenvD1E sub) diff --git a/src/CHAD/Drev/EnvDescr.hs b/src/CHAD/Drev/EnvDescr.hs new file mode 100644 index 0000000..5a90303 --- /dev/null +++ b/src/CHAD/Drev/EnvDescr.hs @@ -0,0 +1,96 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.Drev.EnvDescr where + +import Data.Kind (Type) +import Data.Some +import GHC.TypeLits (Symbol) + +import CHAD.Analysis.Identity (ValId(..)) +import CHAD.AST.Env +import CHAD.AST.Types +import CHAD.AST.Weaken +import CHAD.Data +import CHAD.Drev.Types + + +type Storage :: Symbol -> Type +data Storage s where + SAccum :: Storage "accum" -- ^ in the monad state as a mutable accumulator + SMerge :: Storage "merge" -- ^ just return and merge + SDiscr :: Storage "discr" -- ^ we happen to know this is a discrete type and won't need any contributions +deriving instance Show (Storage s) + +-- | Environment description +data Descr env sto where + DTop :: Descr '[] '[] + DPush :: Descr env sto -> (STy t, Maybe (ValId t), Storage s) -> Descr (t : env) (s : sto) +deriving instance Show (Descr env sto) + +descrList :: Descr env sto -> SList STy env +descrList DTop = SNil +descrList (des `DPush` (t, _, _)) = t `SCons` descrList des + +descrPrj :: Descr env sto -> Idx env t -> (STy t, Maybe (ValId t), Some Storage) +descrPrj (_ `DPush` (ty, vid, sto)) IZ = (ty, vid, Some sto) +descrPrj (des `DPush` _) (IS i) = descrPrj des i +descrPrj DTop i = case i of {} + +-- | This could have more precise typing on the output storage. +subDescr :: Descr env sto -> Subenv env env' + -> (forall sto'. Descr env' sto' + -> Subenv (Select env sto "merge") (Select env' sto' "merge") + -> Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum")) + -> Subenv (D1E env) (D1E env') + -> r) + -> r +subDescr DTop SETop k = k DTop SETop SETop SETop +subDescr (des `DPush` (t, vid, sto)) (SEYesR sub) k = + subDescr des sub $ \des' submerge subaccum subd1e -> + case sto of + SMerge -> k (des' `DPush` (t, vid, sto)) (SEYesR submerge) subaccum (SEYesR subd1e) + SAccum -> k (des' `DPush` (t, vid, sto)) submerge (SEYesR subaccum) (SEYesR subd1e) + SDiscr -> k (des' `DPush` (t, vid, sto)) submerge subaccum (SEYesR subd1e) +subDescr (des `DPush` (_, _, sto)) (SENo sub) k = + subDescr des sub $ \des' submerge subaccum subd1e -> + case sto of + SMerge -> k des' (SENo submerge) subaccum (SENo subd1e) + SAccum -> k des' submerge (SENo subaccum) (SENo subd1e) + SDiscr -> k des' submerge subaccum (SENo subd1e) + +-- | Select only the types from the environment that have the specified storage +type family Select env sto s where + Select '[] '[] _ = '[] + Select (t : ts) (s : sto) s = t : Select ts sto s + Select (_ : ts) (_ : sto) s = Select ts sto s + +select :: Storage s -> Descr env sto -> SList STy (Select env sto s) +select _ DTop = SNil +select s@SAccum (DPush des (t, _, SAccum)) = SCons t (select s des) +select s@SMerge (DPush des (_, _, SAccum)) = select s des +select s@SDiscr (DPush des (_, _, SAccum)) = select s des +select s@SAccum (DPush des (_, _, SMerge)) = select s des +select s@SMerge (DPush des (t, _, SMerge)) = SCons t (select s des) +select s@SDiscr (DPush des (_, _, SMerge)) = select s des +select s@SAccum (DPush des (_, _, SDiscr)) = select s des +select s@SMerge (DPush des (_, _, SDiscr)) = select s des +select s@SDiscr (DPush des (t, _, SDiscr)) = SCons t (select s des) + +selectSub :: Storage s -> Descr env sto -> Subenv env (Select env sto s) +selectSub _ DTop = SETop +selectSub s@SAccum (DPush des (_, _, SAccum)) = SEYesR (selectSub s des) +selectSub s@SMerge (DPush des (_, _, SAccum)) = SENo (selectSub s des) +selectSub s@SDiscr (DPush des (_, _, SAccum)) = SENo (selectSub s des) +selectSub s@SAccum (DPush des (_, _, SMerge)) = SENo (selectSub s des) +selectSub s@SMerge (DPush des (_, _, SMerge)) = SEYesR (selectSub s des) +selectSub s@SDiscr (DPush des (_, _, SMerge)) = SENo (selectSub s des) +selectSub s@SAccum (DPush des (_, _, SDiscr)) = SENo (selectSub s des) +selectSub s@SMerge (DPush des (_, _, SDiscr)) = SENo (selectSub s des) +selectSub s@SDiscr (DPush des (_, _, SDiscr)) = SEYesR (selectSub s des) diff --git a/src/CHAD/Drev/Top.hs b/src/CHAD/Drev/Top.hs new file mode 100644 index 0000000..510e73e --- /dev/null +++ b/src/CHAD/Drev/Top.hs @@ -0,0 +1,96 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedLabels #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.Drev.Top where + +import CHAD.Analysis.Identity +import CHAD.AST +import CHAD.AST.Env +import CHAD.AST.Sparse +import CHAD.AST.SplitLets +import CHAD.AST.Weaken.Auto +import CHAD.Data +import qualified CHAD.Data.VarMap as VarMap +import CHAD.Drev +import CHAD.Drev.Accum +import CHAD.Drev.EnvDescr +import CHAD.Drev.Types + + +type family MergeEnv env where + MergeEnv '[] = '[] + MergeEnv (t : ts) = "merge" : MergeEnv ts + +mergeDescr :: SList STy env -> Descr env (MergeEnv env) +mergeDescr SNil = DTop +mergeDescr (t `SCons` env) = mergeDescr env `DPush` (t, Nothing, SMerge) + +mergeEnvNoAccum :: SList f env -> Select env (MergeEnv env) "accum" :~: '[] +mergeEnvNoAccum SNil = Refl +mergeEnvNoAccum (_ `SCons` env) | Refl <- mergeEnvNoAccum env = Refl + +mergeEnvOnlyMerge :: SList f env -> Select env (MergeEnv env) "merge" :~: env +mergeEnvOnlyMerge SNil = Refl +mergeEnvOnlyMerge (_ `SCons` env) | Refl <- mergeEnvOnlyMerge env = Refl + +accumDescr :: SList STy env -> (forall sto. Descr env sto -> r) -> r +accumDescr SNil k = k DTop +accumDescr (t `SCons` env) k = accumDescr env $ \des -> + if typeHasArrays t then k (des `DPush` (t, Nothing, SAccum)) + else k (des `DPush` (t, Nothing, SMerge)) + +reassembleD2E :: Descr env sto + -> D1E env :> env' + -> Ex env' (TPair (Tup (D2E (Select env sto "accum"))) (Tup (D2E (Select env sto "merge")))) + -> Ex env' (Tup (D2E env)) +reassembleD2E DTop _ _ = ENil ext +reassembleD2E (des `DPush` (_, _, SAccum)) w e = + eunPair e $ \w1 e1 e2 -> + eunPair e1 $ \w2 e11 e12 -> + EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext e11 (weakenExpr w2 e2))) e12 +reassembleD2E (des `DPush` (_, _, SMerge)) w e = + eunPair e $ \w1 e1 e2 -> + eunPair e2 $ \w2 e21 e22 -> + EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext (weakenExpr w2 e1) e21)) e22 +reassembleD2E (des `DPush` (t, _, SDiscr)) w e = + EPair ext (reassembleD2E des (WPop w) e) + (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) + +chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env))) +chad config env (term :: Ex env t) + | True <- chcArgArrayAccum config + = let ?config = config + in accumDescr env $ \descr -> + let t1 = STPair (d1 (typeOf term)) (tTup (d2e (select SMerge descr))) + tvar = STPair t1 (tTup (d2e (select SAccum descr))) + in ELet ext (uninvertTup (d2e (select SAccum descr)) t1 $ + makeAccumulators (WSink .> wUndoSubenv (subenvD1E (selectSub SAccum descr))) (select SAccum descr) $ + weakenExpr (autoWeak (#d (auto1 @(D2 t)) + &. #acenv (d2ace (select SAccum descr)) + &. #tl (d1e env)) + (#d :++: #acenv :++: #tl) + (#acenv :++: #d :++: #tl)) $ + freezeRet descr (drev descr VarMap.empty (spDense (d2M (typeOf term))) term')) $ + EPair ext (EFst ext (EFst ext (EVar ext tvar IZ))) + (reassembleD2E descr (WSink .> WSink) + (EPair ext (ESnd ext (EVar ext tvar IZ)) + (ESnd ext (EFst ext (EVar ext tvar IZ))))) + + | False <- chcArgArrayAccum config + , Refl <- mergeEnvNoAccum env + , Refl <- mergeEnvOnlyMerge env + = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty (spDense (d2M (typeOf term))) term') + where + term' = identityAnalysis env (splitLets term) + +chad' :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env))) +chad' config env term + | Refl <- d1eIdentity env, Refl <- d1Identity (typeOf term) + = chad config env term diff --git a/src/CHAD/Drev/Types.hs b/src/CHAD/Drev/Types.hs new file mode 100644 index 0000000..367a974 --- /dev/null +++ b/src/CHAD/Drev/Types.hs @@ -0,0 +1,153 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.Drev.Types where + +import CHAD.AST.Accum +import CHAD.AST.Types +import CHAD.Data + + +type family D1 t where + D1 TNil = TNil + D1 (TPair a b) = TPair (D1 a) (D1 b) + D1 (TEither a b) = TEither (D1 a) (D1 b) + D1 (TLEither a b) = TLEither (D1 a) (D1 b) + D1 (TMaybe a) = TMaybe (D1 a) + D1 (TArr n t) = TArr n (D1 t) + D1 (TScal t) = TScal t + +type family D2 t where + D2 TNil = TNil + D2 (TPair a b) = TPair (D2 a) (D2 b) + D2 (TEither a b) = TLEither (D2 a) (D2 b) + D2 (TLEither a b) = TLEither (D2 a) (D2 b) + D2 (TMaybe t) = TMaybe (D2 t) + D2 (TArr n t) = TArr n (D2 t) + D2 (TScal t) = D2s t + +type family D2s t where + D2s TI32 = TNil + D2s TI64 = TNil + D2s TF32 = TScal TF32 + D2s TF64 = TScal TF64 + D2s TBool = TNil + +type family D1E env where + D1E '[] = '[] + D1E (t : env) = D1 t : D1E env + +type family D2E env where + D2E '[] = '[] + D2E (t : env) = D2 t : D2E env + +type family D2AcE env where + D2AcE '[] = '[] + D2AcE (t : env) = TAccum (D2 t) : D2AcE env + +d1 :: STy t -> STy (D1 t) +d1 STNil = STNil +d1 (STPair a b) = STPair (d1 a) (d1 b) +d1 (STEither a b) = STEither (d1 a) (d1 b) +d1 (STLEither a b) = STLEither (d1 a) (d1 b) +d1 (STMaybe t) = STMaybe (d1 t) +d1 (STArr n t) = STArr n (d1 t) +d1 (STScal t) = STScal t +d1 STAccum{} = error "Accumulators not allowed in input program" + +d1e :: SList STy env -> SList STy (D1E env) +d1e SNil = SNil +d1e (t `SCons` env) = d1 t `SCons` d1e env + +d2M :: STy t -> SMTy (D2 t) +d2M STNil = SMTNil +d2M (STPair a b) = SMTPair (d2M a) (d2M b) +d2M (STEither a b) = SMTLEither (d2M a) (d2M b) +d2M (STLEither a b) = SMTLEither (d2M a) (d2M b) +d2M (STMaybe t) = SMTMaybe (d2M t) +d2M (STArr n t) = SMTArr n (d2M t) +d2M (STScal t) = case t of + STI32 -> SMTNil + STI64 -> SMTNil + STF32 -> SMTScal STF32 + STF64 -> SMTScal STF64 + STBool -> SMTNil +d2M STAccum{} = error "Accumulators not allowed in input program" + +d2 :: STy t -> STy (D2 t) +d2 = fromSMTy . d2M + +d2eM :: SList STy env -> SList SMTy (D2E env) +d2eM SNil = SNil +d2eM (t `SCons` ts) = d2M t `SCons` d2eM ts + +d2e :: SList STy env -> SList STy (D2E env) +d2e = slistMap fromSMTy . d2eM + +d2ace :: SList STy env -> SList STy (D2AcE env) +d2ace SNil = SNil +d2ace (t `SCons` ts) = STAccum (d2M t) `SCons` d2ace ts + + +data CHADConfig = CHADConfig + { -- | D[let] will bind variables containing arrays in accumulator mode. + chcLetArrayAccum :: Bool + , -- | D[case] will bind variables containing arrays in accumulator mode. + chcCaseArrayAccum :: Bool + , -- | Introduce top-level arguments containing arrays in accumulator mode. + chcArgArrayAccum :: Bool + , -- | Place with-blocks around array variable scopes, and redirect accumulations there. + chcSmartWith :: Bool + } + deriving (Show) + +defaultConfig :: CHADConfig +defaultConfig = CHADConfig + { chcLetArrayAccum = False + , chcCaseArrayAccum = False + , chcArgArrayAccum = False + , chcSmartWith = False + } + +chcSetAccum :: CHADConfig -> CHADConfig +chcSetAccum c = c { chcLetArrayAccum = True + , chcCaseArrayAccum = True + , chcArgArrayAccum = True + , chcSmartWith = True } + + +------------------------------------ LEMMAS ------------------------------------ + +indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx)) +indexTupD1Id SZ = Refl +indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl + +lemZeroInfoScal :: SScalTy t -> ZeroInfo (D2s t) :~: TNil +lemZeroInfoScal STI32 = Refl +lemZeroInfoScal STI64 = Refl +lemZeroInfoScal STF32 = Refl +lemZeroInfoScal STF64 = Refl +lemZeroInfoScal STBool = Refl + +lemDeepZeroInfoScal :: SScalTy t -> DeepZeroInfo (D2s t) :~: TNil +lemDeepZeroInfoScal STI32 = Refl +lemDeepZeroInfoScal STI64 = Refl +lemDeepZeroInfoScal STF32 = Refl +lemDeepZeroInfoScal STF64 = Refl +lemDeepZeroInfoScal STBool = Refl + +d1Identity :: STy t -> D1 t :~: t +d1Identity = \case + STNil -> Refl + STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl + STMaybe t | Refl <- d1Identity t -> Refl + STArr _ t | Refl <- d1Identity t -> Refl + STScal _ -> Refl + STAccum{} -> error "Accumulators not allowed in input program" + +d1eIdentity :: SList STy env -> D1E env :~: env +d1eIdentity SNil = Refl +d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl diff --git a/src/CHAD/Drev/Types/ToTan.hs b/src/CHAD/Drev/Types/ToTan.hs new file mode 100644 index 0000000..019119c --- /dev/null +++ b/src/CHAD/Drev/Types/ToTan.hs @@ -0,0 +1,43 @@ +{-# LANGUAGE GADTs #-} +module CHAD.Drev.Types.ToTan where + +import Data.Bifunctor (bimap) + +import CHAD.Array +import CHAD.AST.Types +import CHAD.Data +import CHAD.Drev.Types +import CHAD.ForwardAD +import CHAD.Interpreter.Rep + + +toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env) +toTanE SNil SNil SNil = SNil +toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) = + Value (toTan t p x) `SCons` toTanE env primal inp + +toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t) +toTan typ primal der = case typ of + STNil -> der + STPair t1 t2 -> bimap (\p1 -> toTan t1 p1 (fst der)) (\p2 -> toTan t2 p2 (snd der)) primal + STEither t1 t2 -> case der of + Nothing -> bimap (zeroTan t1) (zeroTan t2) primal + Just d -> case (primal, d) of + (Left p, Left d') -> Left (toTan t1 p d') + (Right p, Right d') -> Right (toTan t2 p d') + _ -> error "Primal and cotangent disagree on Either alternative" + STLEither t1 t2 -> case (primal, der) of + (_, Nothing) -> Nothing + (Just (Left p), Just (Left d)) -> Just (Left (toTan t1 p d)) + (Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d)) + _ -> error "Primal and cotangent disagree on LEither alternative" + STMaybe t -> liftA2 (toTan t) primal der + STArr _ t + | arrayShape primal == arrayShape der -> + arrayGenerateLin (arrayShape primal) $ \i -> + toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i) + | otherwise -> + error "Primal and cotangent disagree on array shape" + STScal sty -> case sty of + STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der + STAccum{} -> error "Accumulators not allowed in input program" diff --git a/src/CHAD/EnvDescr.hs b/src/CHAD/EnvDescr.hs deleted file mode 100644 index 49ae0e6..0000000 --- a/src/CHAD/EnvDescr.hs +++ /dev/null @@ -1,96 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -module CHAD.EnvDescr where - -import Data.Kind (Type) -import Data.Some -import GHC.TypeLits (Symbol) - -import Analysis.Identity (ValId(..)) -import AST.Env -import AST.Types -import AST.Weaken -import CHAD.Types -import Data - - -type Storage :: Symbol -> Type -data Storage s where - SAccum :: Storage "accum" -- ^ in the monad state as a mutable accumulator - SMerge :: Storage "merge" -- ^ just return and merge - SDiscr :: Storage "discr" -- ^ we happen to know this is a discrete type and won't need any contributions -deriving instance Show (Storage s) - --- | Environment description -data Descr env sto where - DTop :: Descr '[] '[] - DPush :: Descr env sto -> (STy t, Maybe (ValId t), Storage s) -> Descr (t : env) (s : sto) -deriving instance Show (Descr env sto) - -descrList :: Descr env sto -> SList STy env -descrList DTop = SNil -descrList (des `DPush` (t, _, _)) = t `SCons` descrList des - -descrPrj :: Descr env sto -> Idx env t -> (STy t, Maybe (ValId t), Some Storage) -descrPrj (_ `DPush` (ty, vid, sto)) IZ = (ty, vid, Some sto) -descrPrj (des `DPush` _) (IS i) = descrPrj des i -descrPrj DTop i = case i of {} - --- | This could have more precise typing on the output storage. -subDescr :: Descr env sto -> Subenv env env' - -> (forall sto'. Descr env' sto' - -> Subenv (Select env sto "merge") (Select env' sto' "merge") - -> Subenv (D2AcE (Select env sto "accum")) (D2AcE (Select env' sto' "accum")) - -> Subenv (D1E env) (D1E env') - -> r) - -> r -subDescr DTop SETop k = k DTop SETop SETop SETop -subDescr (des `DPush` (t, vid, sto)) (SEYesR sub) k = - subDescr des sub $ \des' submerge subaccum subd1e -> - case sto of - SMerge -> k (des' `DPush` (t, vid, sto)) (SEYesR submerge) subaccum (SEYesR subd1e) - SAccum -> k (des' `DPush` (t, vid, sto)) submerge (SEYesR subaccum) (SEYesR subd1e) - SDiscr -> k (des' `DPush` (t, vid, sto)) submerge subaccum (SEYesR subd1e) -subDescr (des `DPush` (_, _, sto)) (SENo sub) k = - subDescr des sub $ \des' submerge subaccum subd1e -> - case sto of - SMerge -> k des' (SENo submerge) subaccum (SENo subd1e) - SAccum -> k des' submerge (SENo subaccum) (SENo subd1e) - SDiscr -> k des' submerge subaccum (SENo subd1e) - --- | Select only the types from the environment that have the specified storage -type family Select env sto s where - Select '[] '[] _ = '[] - Select (t : ts) (s : sto) s = t : Select ts sto s - Select (_ : ts) (_ : sto) s = Select ts sto s - -select :: Storage s -> Descr env sto -> SList STy (Select env sto s) -select _ DTop = SNil -select s@SAccum (DPush des (t, _, SAccum)) = SCons t (select s des) -select s@SMerge (DPush des (_, _, SAccum)) = select s des -select s@SDiscr (DPush des (_, _, SAccum)) = select s des -select s@SAccum (DPush des (_, _, SMerge)) = select s des -select s@SMerge (DPush des (t, _, SMerge)) = SCons t (select s des) -select s@SDiscr (DPush des (_, _, SMerge)) = select s des -select s@SAccum (DPush des (_, _, SDiscr)) = select s des -select s@SMerge (DPush des (_, _, SDiscr)) = select s des -select s@SDiscr (DPush des (t, _, SDiscr)) = SCons t (select s des) - -selectSub :: Storage s -> Descr env sto -> Subenv env (Select env sto s) -selectSub _ DTop = SETop -selectSub s@SAccum (DPush des (_, _, SAccum)) = SEYesR (selectSub s des) -selectSub s@SMerge (DPush des (_, _, SAccum)) = SENo (selectSub s des) -selectSub s@SDiscr (DPush des (_, _, SAccum)) = SENo (selectSub s des) -selectSub s@SAccum (DPush des (_, _, SMerge)) = SENo (selectSub s des) -selectSub s@SMerge (DPush des (_, _, SMerge)) = SEYesR (selectSub s des) -selectSub s@SDiscr (DPush des (_, _, SMerge)) = SENo (selectSub s des) -selectSub s@SAccum (DPush des (_, _, SDiscr)) = SENo (selectSub s des) -selectSub s@SMerge (DPush des (_, _, SDiscr)) = SENo (selectSub s des) -selectSub s@SDiscr (DPush des (_, _, SDiscr)) = SEYesR (selectSub s des) diff --git a/src/CHAD/Example.hs b/src/CHAD/Example.hs new file mode 100644 index 0000000..884f99a --- /dev/null +++ b/src/CHAD/Example.hs @@ -0,0 +1,197 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE OverloadedLabels #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeApplications #-} + +{-# OPTIONS -Wno-unused-imports #-} +module CHAD.Example where + +import Debug.Trace + +import CHAD.Array +import CHAD.AST +import CHAD.AST.Count +import CHAD.AST.Pretty +import CHAD.AST.UnMonoid +import CHAD.Data +import CHAD.Drev +import CHAD.Drev.Top +import CHAD.Drev.Types +import CHAD.Example.Types +import CHAD.ForwardAD +import CHAD.Interpreter +import CHAD.Language +import CHAD.Simplify + + +-- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0) + + +pipeline :: KnownEnv env => CHADConfig -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env))) +pipeline config term + | Dict <- styKnown (d2 (typeOf term)) = + simplifyFix $ pruneExpr knownEnv $ + simplifyFix $ unMonoid $ + simplifyFix $ chad' config knownEnv $ + simplifyFix $ term + +-- :seti -XOverloadedLabels -XPartialTypeSignatures -Wno-partial-type-signatures +pipeline' :: KnownEnv env => CHADConfig -> Ex env t -> IO () +pipeline' config term + | Dict <- styKnown (d2 (typeOf term)) = + pprintExpr (pipeline config term) + + +bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c +bin op a b = EOp ext op (EPair ext a b) + +senv1 :: SList STy [TScal TF32, TScal TF32] +senv1 = STScal STF32 `SCons` STScal STF32 `SCons` SNil + +-- x y |- x * y + x +-- +-- let x3 = (x1, x2) +-- x4 = ((*) x3, x1) +-- in ( (+) x4 +-- , let x5 = 1.0 +-- x6 = Inr (x5, x5) +-- in case x6 of +-- Inl x7 -> return () +-- Inr x8 -> +-- let x9 = fst x8 +-- x10 = Inr (snd x3 * x9, fst x3 * x9) +-- in case x10 of +-- Inl x11 -> return () +-- Inr x12 -> +-- let x13 = fst x12 +-- in one "v1" x13 >>= \x14 -> +-- let x15 = snd x12 +-- in one "v2" x15 >>= \x16 -> +-- let x17 = snd x8 +-- in one "v1" x17) +-- +-- ( (x1 * x2) + x1 +-- , let x5 = 1.0 +-- in do one "v1" (x2 * x5) +-- one "v2" (x1 * x5) +-- one "v1" x5) +ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32) +ex1 = fromNamed $ lambda #x $ lambda #y $ body $ + #x * #y + #x + +-- x y |- let z = x + y in z * (z + x) +ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32) +ex2 = fromNamed $ lambda #x $ lambda #y $ body $ + let_ #z (#x + #y) $ + #z * (#z + #x) + +-- x y |- if x < y then 2 * x else 3 + x +ex3 :: Ex [TScal TF32, TScal TF32] (TScal TF32) +ex3 = fromNamed $ lambda #x $ lambda #y $ body $ + if_ (#x .< #y) (2 * #x) (3 * #x) + +-- x y |- if x < y then 2 * x + y * y else 3 + x +ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32) +ex4 = fromNamed $ lambda #x $ lambda #y $ body $ + if_ (#x .< #y) (2 * #x + #y * #y) (3 + #x) + +-- x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)} +ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32) +ex5 = fromNamed $ lambda #x $ lambda #y $ body $ + case_ #x (#a :-> #a * #y) + (#b :-> #b * (#y + 1)) + +-- x:R n:I |- let a = unit x +-- b = build1 n (\i. let c = idx0 a in c * c) +-- in idx0 (b ! 3) +ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32) +ex6 = fromNamed $ lambda #x $ lambda #n $ body $ + let_ #a (unit #x) $ + let_ #b (build1 #n (#_ :-> let_ #c (idx0 #a) $ #c * #c)) $ + #b ! pair nil 3 + +-- A "neural network" except it's just scalars, not matrices. +-- ps:((((), (R,R)), (R,R)), (R,R)) x:R +-- |- let p1 = snd ps +-- p1' = fst ps +-- x1 = fst p1 * x + snd p1 +-- p2 = snd p1' +-- p2' = fst p1' +-- x2 = fst p2 * x + snd p2 +-- p3 = snd p2' +-- p3' = fst p2' +-- x3 = fst p3 * x + snd p3 +-- in x3 +ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R +ex7 = fromNamed $ lambda #pars123 $ lambda #input $ body $ + let tR = STScal STF64 + tpair = STPair tR tR + + layer :: (Lookup "parstup" env ~ p, Lookup "inp" env ~ R) + => STy p -> NExpr env R + layer (STPair t (STPair (STScal STF64) (STScal STF64))) | Dict <- styKnown t = + let_ #par (snd_ #parstup) $ + let_ #restpars (fst_ #parstup) $ + let_ #inp (fst_ #par * #inp + snd_ #par) $ + let_ #parstup #restpars $ + layer t + layer STNil = #inp + layer _ = error "Invalid layer inputs" + + in let_ #parstup #pars123 $ + let_ #inp #input $ + layer (STPair (STPair (STPair STNil tpair) tpair) tpair) + +neural :: Ex [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)] R +neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #input $ body $ + let layer = lambda @(TMat R) #wei $ lambda @(TVec R) #bias $ lambda @(TVec R) #x $ body $ + -- prod = wei `matmul` x + let_ #prod (sum1i $ build (SS (SS SZ)) (shape #wei) $ #idx :-> + #wei ! #idx * #x ! pair nil (snd_ #idx)) $ + -- relu (prod + bias) + build (SS SZ) (shape #prod) $ #idx :-> + let_ #out (#prod ! #idx + #bias ! #idx) $ + if_ (#out .<= const_ 0) (const_ 0) #out + + in let_ #x1 (inline layer (SNil .$ fst_ #layer1 .$ snd_ #layer1 .$ #input)) $ + let_ #x2 (inline layer (SNil .$ fst_ #layer2 .$ snd_ #layer2 .$ #x1)) $ + let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $ + #x3 ! nil + +type NeuralGrad = ((Array N2 Double, Array N1 Double) + ,(Array N2 Double, Array N1 Double) + ,Array N1 Double + ,Array N1 Double) + +neuralGo :: (Double -- primal + ,NeuralGrad -- gradient using CHAD + ,NeuralGrad) -- gradient using dual-numbers forward AD +neuralGo = + let lay1 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0]) + lay2 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0]) + lay3 = arrayFromList (ShNil `ShCons` 2) [1,1] + input = arrayFromList (ShNil `ShCons` 2) [1,1] + argument = (Value input `SCons` Value lay3 `SCons` Value lay2 `SCons` Value lay1 `SCons` SNil) + revderiv = + simplifyN 20 $ + ELet ext (EConst ext STF64 1.0) $ + chad defaultConfig knownEnv neural + (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False knownEnv argument revderiv of + (primal', (((((), (dlay1_1'a, dlay1_1'b)), (dlay2_1'a, dlay2_1'b)), dlay3_1'), dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1') + (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0 + in trace (ppExpr knownEnv revderiv) $ + (primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2)) + +-- The build body uses free variables in a non-linear way, so their primal +-- values are required in the dual of the build. Thus, compositionally, they +-- are stored in the tape from each individual lambda invocation. This results +-- in n copies of y and z, where only one copy would have sufficed. +exUniformFree :: Ex '[R, I64] R +exUniformFree = fromNamed $ lambda #n $ lambda #x $ body $ + let_ #y (#x * 2) $ + let_ #z (#x * 3) $ + idx0 $ sum1i $ + build1 #n $ #i :-> #y * #z + toFloat_ #i diff --git a/src/CHAD/Example/GMM.hs b/src/CHAD/Example/GMM.hs new file mode 100644 index 0000000..8f834e0 --- /dev/null +++ b/src/CHAD/Example/GMM.hs @@ -0,0 +1,124 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE OverloadedLabels #-} +{-# LANGUAGE TypeApplications #-} +module CHAD.Example.GMM where + +import CHAD.Data (SList(..), SNat(..)) +import CHAD.Example.Types +import CHAD.Language + + + +-- N, D, K: integers > 0 +-- alpha, M, Q, L: the active parameters +-- X: inactive data +-- m: integer +-- k1: 1/2 N D log(2 pi) +-- k2: 1/2 gamma^2 +-- k3: K * (n' D (log(gamma) - 1/2 log(2)) - log MultiGamma(1/2 n', D)) +-- where n' = D + m + 1 +-- +-- Inputs from the file are: N, D, K, alpha, M, Q, L, gamma, m. +-- +-- See: +-- - "A benchmark of selected algorithmic differentiation tools on some problems +-- in computer vision and machine learning". Optim. Methods Softw. 33(4-6): +-- 889-906 (2018). +-- +-- +-- - 2021 Tom Smeding: “Reverse Automatic Differentiation for Accelerate”. +-- Master thesis at Utrecht University. (Appendix B.1) +-- +-- +-- +-- The 'wrong' argument, when set to True, changes the objective function to +-- one with a bug that makes a certain `build` result unused. This +-- makes the CHAD code fail because it tries to use a D2 (TArr) as if it's +-- dense, even though it may be a zero (i.e. empty). The "unused" test in +-- test/Main.hs tries to isolate this case, but the wrong version of +-- gmmObjective is here to check (after that bug is fixed) whether it really +-- fixes the original bug. +gmmObjective :: Bool -> Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R +gmmObjective wrong = fromNamed $ + lambda #N $ lambda #D $ lambda #K $ + lambda #alpha $ lambda #M $ lambda #Q $ lambda #L $ + lambda #X $ lambda #m $ + lambda #k1 $ lambda #k2 $ lambda #k3 $ + body $ + let -- We have: + -- sum (exp (x - max(x))) + -- = sum (exp x / exp (max(x))) + -- = sum (exp x) / exp (max(x)) + -- Hence: + -- sum (exp x) = sum (exp (x - max(x))) * exp (max(x)) (*) + -- + -- So: + -- d/dxi log (sum (exp x)) + -- = 1/(sum (exp x)) * d/dxi sum (exp x) + -- = 1/(sum (exp x)) * sum (d/dxi exp x) + -- = 1/(sum (exp x)) * exp xi + -- = exp xi / sum (exp x) + -- (by (*)) + -- = exp xi / (sum (exp (x - max(x))) * exp (max(x))) + -- = exp (xi - max(x)) / sum (exp (x - max(x))) + logsumexp' = lambda @(TVec R) #vec $ body $ + let_ #m (maximum1i #vec) $ + log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m + -- custom (#_ :-> #v :-> + -- let_ #m (idx0 (maximum1i #v)) $ + -- log (idx0 (sum1i (map_ (#x :-> exp (#x - #m)) #v))) + #m) + -- (#_ :-> #v :-> + -- let_ #m (idx0 (maximum1i #v)) $ + -- let_ #ex (map_ (#x :-> exp (#x - #m)) #v) $ + -- let_ #s (idx0 (sum1i #ex)) $ + -- pair (log #s + #m) + -- (pair #ex #s)) + -- (#tape :-> #d :-> + -- map_ (#exi :-> #exi / snd_ #tape * #d) (fst_ #tape)) + -- nil #vec + logsumexp v = inline logsumexp' (SNil .$ v) + + mulmatvec = lambda @(TMat R) #mat $ lambda @(TVec R) #vec $ body $ + let_ #hei (snd_ (fst_ (shape #mat))) $ + let_ #wid (snd_ (shape #mat)) $ + build1 #hei $ #i :-> + idx0 (sum1i (build1 #wid $ #j :-> + #mat ! pair (pair nil #i) #j * #vec ! pair nil #j)) + m *@ v = inline mulmatvec (SNil .$ m .$ v) + + subvec = lambda @(TVec R) #a $ lambda @(TVec R) #b $ body $ + build1 (snd_ (shape #a)) $ #i :-> #a ! pair nil #i - #b ! pair nil #i + a .- b = inline subvec (SNil .$ a .$ b) + + matrow = lambda @(TMat R) #mat $ lambda @TIx #i $ body $ + build1 (snd_ (shape #mat)) (#j :-> #mat ! pair (pair nil #i) #j) + m .! i = inline matrow (SNil .$ m .$ i) + + normsq' = lambda @(TVec R) #vec $ body $ + idx0 (sum1i (build (SS SZ) (shape #vec) (#i :-> let_ #x (#vec ! #i) $ #x * #x))) + normsq v = inline normsq' (SNil .$ v) + + qmat' = lambda @(TVec R) #q $ lambda @(TVec R) #l $ body $ + let_ #n (snd_ (shape #q)) $ + build (SS (SS SZ)) (pair (pair nil #n) #n) $ #idx :-> + let_ #i (snd_ (fst_ #idx)) $ + let_ #j (snd_ #idx) $ + if_ (#i .== #j) + (exp (#q ! pair nil #i)) + (if_ (#i .> #j) + (if wrong then toFloat_ (#i * (#i - 1) `idiv` 2 + #j) + else #l ! pair nil (#i * (#i - 1) `idiv` 2 + #j)) + 0.0) + qmat q l = inline qmat' (SNil .$ q .$ l) + in let_ #k2arr (unit #k2) $ + - #k1 + + idx0 (sum1i (build1 #N $ #i :-> + logsumexp (build1 #K $ #k :-> + #alpha ! pair nil #k + + idx0 (sum1i (#Q .! #k)) + - 0.5 * normsq (qmat (#Q .! #k) (#L .! #k) *@ ((#X .! #i) .- (#M .! #k)))))) + - toFloat_ #N * logsumexp #alpha + + idx0 (sum1i (build1 #K $ #k :-> + idx0 #k2arr * (normsq (map_ (#x :-> exp #x) (#Q .! #k)) + normsq (#L .! #k)) + - toFloat_ #m * idx0 (sum1i (#Q .! #k)))) + - #k3 diff --git a/src/CHAD/Example/Types.hs b/src/CHAD/Example/Types.hs new file mode 100644 index 0000000..1e2f72d --- /dev/null +++ b/src/CHAD/Example/Types.hs @@ -0,0 +1,11 @@ +{-# LANGUAGE DataKinds #-} +module CHAD.Example.Types where + +import CHAD.AST +import CHAD.Data + + +type R = TScal TF64 +type I64 = TScal TI64 +type TVec = TArr (S Z) +type TMat = TArr (S (S Z)) diff --git a/src/CHAD/ForwardAD.hs b/src/CHAD/ForwardAD.hs new file mode 100644 index 0000000..7126e10 --- /dev/null +++ b/src/CHAD/ForwardAD.hs @@ -0,0 +1,270 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.ForwardAD where + +import Data.Bifunctor (bimap) +import System.IO.Unsafe + +-- import Debug.Trace +-- import CHAD.AST.Pretty + +import CHAD.Array +import CHAD.AST +import CHAD.Compile +import CHAD.Data +import CHAD.ForwardAD.DualNumbers +import CHAD.Interpreter +import CHAD.Interpreter.Rep + + +-- | Tangent along a type (coincides with cotangent for these types) +type family Tan t where + Tan TNil = TNil + Tan (TPair a b) = TPair (Tan a) (Tan b) + Tan (TEither a b) = TEither (Tan a) (Tan b) + Tan (TLEither a b) = TLEither (Tan a) (Tan b) + Tan (TMaybe t) = TMaybe (Tan t) + Tan (TArr n t) = TArr n (Tan t) + Tan (TScal t) = TanS t + +type family TanS t where + TanS TI32 = TNil + TanS TI64 = TNil + TanS TF32 = TScal TF32 + TanS TF64 = TScal TF64 + TanS TBool = TNil + +type family TanE env where + TanE '[] = '[] + TanE (t : env) = Tan t : TanE env + +tanty :: STy t -> STy (Tan t) +tanty STNil = STNil +tanty (STPair a b) = STPair (tanty a) (tanty b) +tanty (STEither a b) = STEither (tanty a) (tanty b) +tanty (STLEither a b) = STLEither (tanty a) (tanty b) +tanty (STMaybe t) = STMaybe (tanty t) +tanty (STArr n t) = STArr n (tanty t) +tanty (STScal t) = case t of + STI32 -> STNil + STI64 -> STNil + STF32 -> STScal STF32 + STF64 -> STScal STF64 + STBool -> STNil +tanty STAccum{} = error "Accumulators not allowed in input program" + +tanenv :: SList STy env -> SList STy (TanE env) +tanenv SNil = SNil +tanenv (t `SCons` env) = tanty t `SCons` tanenv env + +zeroTan :: STy t -> Rep t -> Rep (Tan t) +zeroTan STNil () = () +zeroTan (STPair a b) (x, y) = (zeroTan a x, zeroTan b y) +zeroTan (STEither a _) (Left x) = Left (zeroTan a x) +zeroTan (STEither _ b) (Right y) = Right (zeroTan b y) +zeroTan (STLEither _ _) Nothing = Nothing +zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x)) +zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y)) +zeroTan (STMaybe _) Nothing = Nothing +zeroTan (STMaybe t) (Just x) = Just (zeroTan t x) +zeroTan (STArr _ t) x = fmap (zeroTan t) x +zeroTan (STScal STI32) _ = () +zeroTan (STScal STI64) _ = () +zeroTan (STScal STF32) _ = 0.0 +zeroTan (STScal STF64) _ = 0.0 +zeroTan (STScal STBool) _ = () +zeroTan STAccum{} _ = error "Accumulators not allowed in input program" + +tanScalars :: STy t -> Rep (Tan t) -> [Double] +tanScalars STNil () = [] +tanScalars (STPair a b) (x, y) = tanScalars a x ++ tanScalars b y +tanScalars (STEither a _) (Left x) = tanScalars a x +tanScalars (STEither _ b) (Right y) = tanScalars b y +tanScalars (STLEither _ _) Nothing = [] +tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x +tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y +tanScalars (STMaybe _) Nothing = [] +tanScalars (STMaybe t) (Just x) = tanScalars t x +tanScalars (STArr _ t) x = foldMap id $ arrayMap (tanScalars t) x +tanScalars (STScal STI32) _ = [] +tanScalars (STScal STI64) _ = [] +tanScalars (STScal STF32) x = [realToFrac x] +tanScalars (STScal STF64) x = [x] +tanScalars (STScal STBool) _ = [] +tanScalars STAccum{} _ = error "Accumulators not allowed in input program" + +tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double] +tanEScalars SNil SNil = [] +tanEScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ tanEScalars ts xs + +unzipDN :: STy t -> Rep (DN t) -> (Rep t, Rep (Tan t)) +unzipDN STNil _ = ((), ()) +unzipDN (STPair a b) (d1, d2) = + let (x, dx) = unzipDN a d1 + (y, dy) = unzipDN b d2 + in ((x, y), (dx, dy)) +unzipDN (STEither a b) d = case d of + Left d1 -> bimap Left Left (unzipDN a d1) + Right d2 -> bimap Right Right (unzipDN b d2) +unzipDN (STLEither a b) d = case d of + Nothing -> (Nothing, Nothing) + Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x) + Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y) +unzipDN (STMaybe t) d = case d of + Nothing -> (Nothing, Nothing) + Just d' -> bimap Just Just (unzipDN t d') +unzipDN (STArr _ t) d = + let pairs = arrayMap (unzipDN t) d + in (arrayMap fst pairs, arrayMap snd pairs) +unzipDN (STScal ty) d = case ty of + STI32 -> (d, ()) + STI64 -> (d, ()) + STF32 -> d + STF64 -> d + STBool -> (d, ()) +unzipDN STAccum{} _ = error "Accumulators not allowed in input program" + +dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double +dotprodTan STNil _ _ = 0.0 +dotprodTan (STPair a b) (x, y) (x', y') = + dotprodTan a x x' + dotprodTan b y y' +dotprodTan (STEither a b) x y = case (x, y) of + (Left x', Left y') -> dotprodTan a x' y' + (Right x', Right y') -> dotprodTan b x' y' + _ -> error "dotprodTan: incompatible Either alternatives" +dotprodTan (STLEither a b) x y = case (x, y) of + (Nothing, _) -> 0.0 -- 0 * y = 0 + (_, Nothing) -> 0.0 -- x * 0 = 0 + (Just (Left x'), Just (Left y')) -> dotprodTan a x' y' + (Just (Right x'), Just (Right y')) -> dotprodTan b x' y' + _ -> error "dotprodTan: incompatible LEither alternatives" +dotprodTan (STMaybe t) x y = case (x, y) of + (Nothing, Nothing) -> 0.0 + (Just x', Just y') -> dotprodTan t x' y' + _ -> error "dotprodTan: incompatible Maybe alternatives" +dotprodTan (STArr _ t) x y = + let sh1 = arrayShape x + sh2 = arrayShape y + in if | shapeSize sh1 == 0 || shapeSize sh2 == 0 -> 0.0 + | sh1 == sh2 -> sum [dotprodTan t (arrayIndex x i) (arrayIndex y i) | i <- enumShape sh1] + | otherwise -> error "dotprodTan: incompatible array shapes" +dotprodTan (STScal ty) x y = case ty of + STI32 -> 0.0 + STI64 -> 0.0 + STF32 -> realToFrac @Float @Double (x * y) + STF64 -> x * y + STBool -> 0.0 +dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program" + +-- -- Primal expression must be duplicable +-- dnConstE :: STy t -> Ex env t -> Ex env (DN t) +-- dnConstE STNil _ = ENil ext +-- dnConstE (STPair t1 t2) e = +-- -- This creates fst/snd stacks of unbounded size, but let's not care here +-- EPair ext (dnConstE t1 (EFst ext e)) (dnConstE t2 (ESnd ext e)) +-- dnConstE (STEither t1 t2) e = +-- ECase ext e +-- (EInl ext (dn t2) (dnConstE t1 (EVar ext t1 IZ))) +-- (EInr ext (dn t1) (dnConstE t2 (EVar ext t2 IZ))) +-- dnConstE (STMaybe t) e = +-- EMaybe ext (ENothing ext (dn t)) (EJust ext (dnConstE t (EVar ext t IZ))) e +-- dnConstE (STArr n t) e = +-- EBuild ext n (EShape ext e) +-- (dnConstE t (EIdx ext n (weakenExpr WSink e) (EVar ext (tTup (sreplicate n tIx)) IZ))) +-- dnConstE (STScal t) e = case t of +-- STI32 -> e +-- STI64 -> e +-- STF32 -> EPair ext e (EConst ext STF32 0.0) +-- STF64 -> EPair ext e (EConst ext STF64 0.0) +-- STBool -> e +-- dnConstE STAccum{} _ = error "Accumulators not allowed in input program" + +dnConst :: STy t -> Rep t -> Rep (DN t) +dnConst STNil = const () +dnConst (STPair t1 t2) = bimap (dnConst t1) (dnConst t2) +dnConst (STEither t1 t2) = bimap (dnConst t1) (dnConst t2) +dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2)) +dnConst (STMaybe t) = fmap (dnConst t) +dnConst (STArr _ t) = arrayMap (dnConst t) +dnConst (STScal t) = case t of + STI32 -> id + STI64 -> id + STF32 -> (,0.0) + STF64 -> (,0.0) + STBool -> id +dnConst STAccum{} = error "Accumulators not allowed in input program" + +-- | Given a function that computes the forward derivative for a particular +-- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this +-- @t@ input. +type RevByFwd t = (Rep (DN t) -> Double) -> Rep (Tan t) + +dnOnehots :: STy t -> Rep t -> RevByFwd t +dnOnehots STNil _ = \_ -> () +dnOnehots (STPair t1 t2) (x, y) = + \f -> (dnOnehots t1 x (f . (,dnConst t2 y)), dnOnehots t2 y (f . (dnConst t1 x,))) +dnOnehots (STEither t1 t2) e = + case e of + Left x -> \f -> Left (dnOnehots t1 x (f . Left)) + Right y -> \f -> Right (dnOnehots t2 y (f . Right)) +dnOnehots (STLEither t1 t2) e = + case e of + Nothing -> \_ -> Nothing + Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left))) + Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right))) +dnOnehots (STMaybe t) m = + case m of + Nothing -> \_ -> Nothing + Just x -> \f -> Just (dnOnehots t x (f . Just)) +dnOnehots (STArr _ t) a = + \f -> + arrayGenerate (arrayShape a) $ \idx -> + dnOnehots t (arrayIndex a idx) (f . (\oh -> arrayGenerate (arrayShape a) $ \i -> + if i == idx then oh else dnConst t (arrayIndex a i))) +dnOnehots (STScal t) x = case t of + STI32 -> \_ -> () + STI64 -> \_ -> () + STF32 -> \f -> realToFrac @Double @Float $ f (x, 1.0) + STF64 -> \f -> f (x, 1.0) + STBool -> \_ -> () +dnOnehots STAccum{} _ = error "Accumulators not allowed in input program" + +dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env) +dnConstEnv SNil SNil = SNil +dnConstEnv (t `SCons` env) (Value x `SCons` val) = Value (dnConst t x) `SCons` dnConstEnv env val + +type RevByFwdEnv env = (SList Value (DNE env) -> Double) -> SList Value (TanE env) + +dnOnehotEnvs :: SList STy env -> SList Value env -> RevByFwdEnv env +dnOnehotEnvs SNil SNil = \_ -> SNil +dnOnehotEnvs (t `SCons` env) (Value x `SCons` val) = + \f -> + Value (dnOnehots t x (f . (\oh -> Value oh `SCons` dnConstEnv env val))) + `SCons` dnOnehotEnvs env val (f . (\oh -> Value (dnConst t x) `SCons` oh)) + +data FwdADArtifact env t = FwdADArtifact (SList STy env) (STy t) (SList Value (DNE env) -> Rep (DN t)) + +makeFwdADArtifactInterp :: SList STy env -> Ex env t -> FwdADArtifact env t +makeFwdADArtifactInterp env expr = + let dexpr = dfwdDN expr + in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False (dne env) inp dexpr) + +{-# NOINLINE makeFwdADArtifactCompile #-} +makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t, String) +makeFwdADArtifactCompile env expr = do + (fun, output) <- compile (dne env) (dfwdDN expr) + return (FwdADArtifact env (typeOf expr) (unsafePerformIO . fun), output) + +drevByFwdInterp :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env) +drevByFwdInterp env expr = drevByFwd (makeFwdADArtifactInterp env expr) + +drevByFwd :: FwdADArtifact env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env) +drevByFwd (FwdADArtifact env outty fun) input dres = + dnOnehotEnvs env input $ \dnInput -> + -- trace (showEnv (dne env) dnInput) $ + let (_, outtan) = unzipDN outty (fun dnInput) + in dotprodTan outty outtan dres diff --git a/src/CHAD/ForwardAD/DualNumbers.hs b/src/CHAD/ForwardAD/DualNumbers.hs new file mode 100644 index 0000000..a71efc8 --- /dev/null +++ b/src/CHAD/ForwardAD/DualNumbers.hs @@ -0,0 +1,231 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE EmptyCase #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} + +-- I want to bring various type variables in scope using type annotations in +-- patterns, but I don't want to have to mention all the other type parameters +-- of the types in question as well then. Partial type signatures (with '_') are +-- useful here. +{-# LANGUAGE PartialTypeSignatures #-} +{-# OPTIONS -Wno-partial-type-signatures #-} +module CHAD.ForwardAD.DualNumbers ( + dfwdDN, + DN, DNS, DNE, dn, dne, +) where + +import CHAD.AST +import CHAD.Data +import CHAD.ForwardAD.DualNumbers.Types + + +dnPreservesTupIx :: SNat n -> DN (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx) +dnPreservesTupIx SZ = Refl +dnPreservesTupIx (SS n) | Refl <- dnPreservesTupIx n = Refl + +convIdx :: Idx env t -> Idx (DNE env) (DN t) +convIdx IZ = IZ +convIdx (IS i) = IS (convIdx i) + +scalTyCase :: SScalTy t + -> ((ScalIsNumeric t ~ True, ScalIsFloating t ~ True, Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t)) => r) + -> (DN (TScal t) ~ TScal t => r) + -> r +scalTyCase STF32 k1 _ = k1 +scalTyCase STF64 k1 _ = k1 +scalTyCase STI32 _ k2 = k2 +scalTyCase STI64 _ k2 = k2 +scalTyCase STBool _ k2 = k2 + +floatingDual :: ScalIsFloating t ~ True + => SScalTy t + -> ((Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t), ScalIsNumeric t ~ True) => r) -> r +floatingDual STF32 k = k +floatingDual STF64 k = k + +-- | Argument does not need to be duplicable. +dop :: forall a b env. SOp a b -> Ex env (DN a) -> Ex env (DN b) +dop = \case + OAdd t -> scalTyCase t + (binFloat (\(x, dx) (y, dy) -> EPair ext (add t x y) (add t dx dy))) + (EOp ext (OAdd t)) + OMul t -> scalTyCase t + (binFloat (\(x, dx) (y, dy) -> EPair ext (mul t x y) (add t (mul t dx y) (mul t dy x)))) + (EOp ext (OMul t)) + ONeg t -> scalTyCase t + (unFloat (\(x, dx) -> EPair ext (neg t x) (neg t dx))) + (EOp ext (ONeg t)) + OLt t -> scalTyCase t + (binFloat (\(x, _) (y, _) -> EOp ext (OLt t) (EPair ext x y))) + (EOp ext (OLt t)) + OLe t -> scalTyCase t + (binFloat (\(x, _) (y, _) -> EOp ext (OLe t) (EPair ext x y))) + (EOp ext (OLe t)) + OEq t -> scalTyCase t + (binFloat (\(x, _) (y, _) -> EOp ext (OEq t) (EPair ext x y))) + (EOp ext (OEq t)) + ONot -> EOp ext ONot + OAnd -> EOp ext OAnd + OOr -> EOp ext OOr + OIf -> EOp ext OIf + ORound64 -> \arg -> EOp ext ORound64 (EFst ext arg) + OToFl64 -> \arg -> EPair ext (EOp ext OToFl64 arg) (EConst ext STF64 0.0) + ORecip t -> floatingDual t $ unFloat (\(x, dx) -> + EPair ext (recip' t x) + (mul t (neg t (recip' t (mul t x x))) dx)) + OExp t -> floatingDual t $ unFloat (\(x, dx) -> + EPair ext (EOp ext (OExp t) x) (mul t (EOp ext (OExp t) x) dx)) + OLog t -> floatingDual t $ unFloat (\(x, dx) -> + EPair ext (EOp ext (OLog t) x) + (mul t (recip' t x) dx)) + OIDiv t -> scalTyCase t + (case t of {}) + (EOp ext (OIDiv t)) + OMod t -> scalTyCase t + (case t of {}) + (EOp ext (OMod t)) + where + add :: ScalIsNumeric t ~ True + => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) + add t a b = EOp ext (OAdd t) (EPair ext a b) + + mul :: ScalIsNumeric t ~ True + => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) + mul t a b = EOp ext (OMul t) (EPair ext a b) + + neg :: ScalIsNumeric t ~ True + => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) + neg t = EOp ext (ONeg t) + + recip' :: ScalIsFloating t ~ True + => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) + recip' t = EOp ext (ORecip t) + + unFloat :: DN a ~ TPair a a + => (forall env'. (Ex env' a, Ex env' a) -> Ex env' (DN b)) + -> Ex env (DN a) -> Ex env (DN b) + unFloat f e = + ELet ext e $ + let var = EVar ext (typeOf e) IZ + in f (EFst ext var, ESnd ext var) + + binFloat :: (a ~ TPair s s, DN s ~ TPair s s) + => (forall env'. (Ex env' s, Ex env' s) -> (Ex env' s, Ex env' s) -> Ex env' (DN b)) + -> Ex env (DN a) -> Ex env (DN b) + binFloat f e = + ELet ext e $ + let var = EVar ext (typeOf e) IZ + in f (EFst ext (EFst ext var), ESnd ext (EFst ext var)) + (EFst ext (ESnd ext var), ESnd ext (ESnd ext var)) + +zeroScalarConst :: ScalIsNumeric t ~ True => SScalTy t -> Ex env (TScal t) +zeroScalarConst STI32 = EConst ext STI32 0 +zeroScalarConst STI64 = EConst ext STI64 0 +zeroScalarConst STF32 = EConst ext STF32 0.0 +zeroScalarConst STF64 = EConst ext STF64 0.0 + +dfwdDN :: Ex env t -> Ex (DNE env) (DN t) +dfwdDN = \case + EVar _ t i -> EVar ext (dn t) (convIdx i) + ELet _ a b -> ELet ext (dfwdDN a) (dfwdDN b) + EPair _ a b -> EPair ext (dfwdDN a) (dfwdDN b) + EFst _ e -> EFst ext (dfwdDN e) + ESnd _ e -> ESnd ext (dfwdDN e) + ENil _ -> ENil ext + EInl _ t e -> EInl ext (dn t) (dfwdDN e) + EInr _ t e -> EInr ext (dn t) (dfwdDN e) + ECase _ e a b -> ECase ext (dfwdDN e) (dfwdDN a) (dfwdDN b) + ENothing _ t -> ENothing ext (dn t) + EJust _ e -> EJust ext (dfwdDN e) + EMaybe _ e a b -> EMaybe ext (dfwdDN e) (dfwdDN a) (dfwdDN b) + ELNil _ t1 t2 -> ELNil ext (dn t1) (dn t2) + ELInl _ t e -> ELInl ext (dn t) (dfwdDN e) + ELInr _ t e -> ELInr ext (dn t) (dfwdDN e) + ELCase _ e a b c -> ELCase ext (dfwdDN e) (dfwdDN a) (dfwdDN b) (dfwdDN c) + EConstArr _ n t x -> scalTyCase t + (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0)) + (EConstArr ext n t x)) + (EConstArr ext n t x) + EBuild _ n a b + | Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b) + EMap _ a b -> EMap ext (dfwdDN a) (dfwdDN b) + EFold1Inner _ cm a b c -> EFold1Inner ext cm (dfwdDN a) (dfwdDN b) (dfwdDN c) + ESum1Inner _ e -> + let STArr n (STScal t) = typeOf e + pairty = (STPair (STScal t) (STScal t)) + in scalTyCase t + (ELet ext (dfwdDN e) $ + ezip (ESum1Inner ext (emap (EFst ext (EVar ext pairty IZ)) + (EVar ext (STArr n pairty) IZ))) + (ESum1Inner ext (emap (ESnd ext (EVar ext pairty IZ)) + (EVar ext (STArr n pairty) IZ)))) + (ESum1Inner ext (dfwdDN e)) + EUnit _ e -> EUnit ext (dfwdDN e) + EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b) + EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e + EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e + EZip _ a b -> EZip ext (dfwdDN a) (dfwdDN b) + EReshape _ n esh e + | Refl <- dnPreservesTupIx n -> EReshape ext n (dfwdDN esh) (dfwdDN e) + EConst _ t x -> scalTyCase t + (EPair ext (EConst ext t x) (EConst ext t 0.0)) + (EConst ext t x) + EIdx0 _ e -> EIdx0 ext (dfwdDN e) + EIdx1 _ a b -> EIdx1 ext (dfwdDN a) (dfwdDN b) + EIdx _ a b + | STArr n _ <- typeOf a + , Refl <- dnPreservesTupIx n + -> EIdx ext (dfwdDN a) (dfwdDN b) + EShape _ e + | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n) + -> EShape ext (dfwdDN e) + EOp _ op e -> dop op (dfwdDN e) + ECustom _ _ _ _ pr _ _ e1 e2 -> + ELet ext (dfwdDN e1) $ + ELet ext (weakenExpr WSink (dfwdDN e2)) $ + weakenExpr (WCopy (WCopy WClosed)) (dfwdDN pr) + ERecompute _ e -> dfwdDN e + EError _ t s -> EError ext (dn t) s + + EWith{} -> err_accum + EAccum{} -> err_accum + EDeepZero{} -> err_monoid + EZero{} -> err_monoid + EPlus{} -> err_monoid + EOneHot{} -> err_monoid + + EFold1InnerD1{} -> err_targetlang "EFold1InnerD1" + EFold1InnerD2{} -> err_targetlang "EFold1InnerD2" + where + err_accum = error "Accumulator operations unsupported in the source program" + err_monoid = error "Monoid operations unsupported in the source program" + err_targetlang s = error $ "Target language operation " ++ s ++ " not supported in source program" + + deriv_extremum :: ScalIsNumeric t ~ True + => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t))) + -> Ex env (TArr (S n) (TScal t)) -> Ex (DNE env) (TArr n (DN (TScal t))) + deriv_extremum extremum e = + let STArr (SS n) (STScal t) = typeOf e + t2 = STPair (STScal t) (STScal t) + ta2 = STArr (SS n) t2 + tIxN = tTup (sreplicate (SS n) tIx) + in scalTyCase t + (ELet ext (dfwdDN e) $ + ELet ext (extremum (emap (EFst ext (EVar ext t2 IZ)) (EVar ext ta2 IZ))) $ + ezip (EVar ext (STArr n (STScal t)) IZ) + (ESum1Inner ext + {- build (shape SZ) (\i. if fst (SZ ! i) == Z ! tail i then snd (SZ ! i) else zero) -} + (EBuild ext (SS n) (EShape ext (EVar ext ta2 (IS IZ))) $ + ELet ext (EIdx ext (EVar ext ta2 (IS (IS IZ))) (EVar ext tIxN IZ)) $ + ECase ext (EOp ext OIf (EOp ext (OEq t) (EPair ext + (EFst ext (EVar ext t2 IZ)) + (EIdx ext (EVar ext (STArr n (STScal t)) (IS (IS IZ))) + (EFst ext (EVar ext tIxN (IS IZ))))))) + (ESnd ext (EVar ext t2 (IS IZ))) + (zeroScalarConst t)))) + (extremum (dfwdDN e)) diff --git a/src/CHAD/ForwardAD/DualNumbers/Types.hs b/src/CHAD/ForwardAD/DualNumbers/Types.hs new file mode 100644 index 0000000..5d5dd9e --- /dev/null +++ b/src/CHAD/ForwardAD/DualNumbers/Types.hs @@ -0,0 +1,48 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.ForwardAD.DualNumbers.Types where + +import CHAD.AST.Types +import CHAD.Data + + +-- | Dual-numbers transformation +type family DN t where + DN TNil = TNil + DN (TPair a b) = TPair (DN a) (DN b) + DN (TEither a b) = TEither (DN a) (DN b) + DN (TLEither a b) = TLEither (DN a) (DN b) + DN (TMaybe t) = TMaybe (DN t) + DN (TArr n t) = TArr n (DN t) + DN (TScal t) = DNS t + +type family DNS t where + DNS TF32 = TPair (TScal TF32) (TScal TF32) + DNS TF64 = TPair (TScal TF64) (TScal TF64) + DNS TI32 = TScal TI32 + DNS TI64 = TScal TI64 + DNS TBool = TScal TBool + +type family DNE env where + DNE '[] = '[] + DNE (t : ts) = DN t : DNE ts + +dn :: STy t -> STy (DN t) +dn STNil = STNil +dn (STPair a b) = STPair (dn a) (dn b) +dn (STEither a b) = STEither (dn a) (dn b) +dn (STLEither a b) = STLEither (dn a) (dn b) +dn (STMaybe t) = STMaybe (dn t) +dn (STArr n t) = STArr n (dn t) +dn (STScal t) = case t of + STF32 -> STPair (STScal STF32) (STScal STF32) + STF64 -> STPair (STScal STF64) (STScal STF64) + STI32 -> STScal STI32 + STI64 -> STScal STI64 + STBool -> STScal STBool +dn STAccum{} = error "Accum in source program" + +dne :: SList STy env -> SList STy (DNE env) +dne SNil = SNil +dne (t `SCons` env) = dn t `SCons` dne env diff --git a/src/CHAD/Interpreter.hs b/src/CHAD/Interpreter.hs new file mode 100644 index 0000000..a9421e6 --- /dev/null +++ b/src/CHAD/Interpreter.hs @@ -0,0 +1,471 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.Interpreter ( + interpret, + interpretOpen, + Value(..), +) where + +import Control.Monad (foldM, join, when, forM_) +import Control.Monad.Trans.Class (lift) +import Control.Monad.Trans.State.Strict (runStateT, get, put) +import Data.Bifunctor (bimap) +import Data.Bitraversable (bitraverse) +import Data.Char (isSpace) +import Data.Functor.Identity +import qualified Data.Functor.Product as Product +import Data.Int (Int64) +import Data.IORef +import Data.Tuple (swap) +import System.IO (hPutStrLn, stderr) +import System.IO.Unsafe (unsafePerformIO) + +import Debug.Trace + +import CHAD.Array +import CHAD.AST +import CHAD.AST.Pretty +import CHAD.AST.Sparse.Types +import CHAD.Data +import CHAD.Interpreter.Rep + + +newtype AcM s a = AcM { unAcM :: IO a } + deriving newtype (Functor, Applicative, Monad) + +runAcM :: (forall s. AcM s a) -> a +runAcM (AcM m) = unsafePerformIO m + +acmDebugLog :: String -> AcM s () +acmDebugLog s = AcM (hPutStrLn stderr s) + +data V t = V (STy t) (Rep t) + +interpret :: Ex '[] t -> Rep t +interpret = interpretOpen False SNil SNil + +-- | Bool: whether to trace execution with debug prints (very verbose) +interpretOpen :: Bool -> SList STy env -> SList Value env -> Ex env t -> Rep t +interpretOpen prints env venv e = + runAcM $ + let ?depth = 0 + ?prints = prints + in interpret' (slistMap (\(Product.Pair t (Value v)) -> V t v) (slistZip env venv)) e + +interpret' :: forall env t s. (?prints :: Bool, ?depth :: Int) + => SList V env -> Ex env t -> AcM s (Rep t) +interpret' env e = do + let tenv = slistMap (\(V t _) -> t) env + let dep = ?depth + let lenlimit = max 20 (100 - dep) + let replace a b = map (\c -> if c == a then b else c) + let trunc s | length s > lenlimit = take (lenlimit - 3) (replace '\n' ' ' s) ++ "..." + | otherwise = replace '\n' ' ' s + when ?prints $ acmDebugLog $ replicate dep ' ' ++ "ev: " ++ trunc (ppExpr tenv e) + res <- let ?depth = dep + 1 in interpret'Rec env e + when ?prints $ acmDebugLog $ replicate dep ' ' ++ "<- " ++ showValue 0 (typeOf e) res "" + return res + +interpret'Rec :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList V env -> Ex env t -> AcM s (Rep t) +interpret'Rec env = \case + EVar _ _ i -> case slistIdx env i of V _ x -> return x + ELet _ a b -> do + x <- interpret' env a + let ?depth = ?depth - 1 in interpret' (V (typeOf a) x `SCons` env) b + expr | False && trace (" " ++ takeWhile (not . isSpace) (show expr)) False -> undefined + EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b + EFst _ e -> fst <$> interpret' env e + ESnd _ e -> snd <$> interpret' env e + ENil _ -> return () + EInl _ _ e -> Left <$> interpret' env e + EInr _ _ e -> Right <$> interpret' env e + ECase _ e a b -> + let STEither t1 t2 = typeOf e + in interpret' env e >>= \case + Left x -> interpret' (V t1 x `SCons` env) a + Right y -> interpret' (V t2 y `SCons` env) b + ENothing _ _ -> return Nothing + EJust _ e -> Just <$> interpret' env e + EMaybe _ a b e -> + let STMaybe t1 = typeOf e + in maybe (interpret' env a) (\x -> interpret' (V t1 x `SCons` env) b) =<< interpret' env e + ELNil _ _ _ -> return Nothing + ELInl _ _ e -> Just . Left <$> interpret' env e + ELInr _ _ e -> Just . Right <$> interpret' env e + ELCase _ e a b c -> + let STLEither t1 t2 = typeOf e + in interpret' env e >>= \case + Nothing -> interpret' env a + Just (Left x) -> interpret' (V t1 x `SCons` env) b + Just (Right y) -> interpret' (V t2 y `SCons` env) c + EConstArr _ _ _ v -> return v + EBuild _ dim a b -> do + sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a + arrayGenerateM sh (\idx -> interpret' (V (tTup (sreplicate dim tIx)) (tupRepIdx ixUncons dim idx) `SCons` env) b) + EMap _ a b -> do + let STArr _ t = typeOf b + arrayMapM (\x -> interpret' (V t x `SCons` env) a) =<< interpret' env b + EFold1Inner _ _ a b c -> do + let t = typeOf b + let f = \x -> interpret' (V (STPair t t) x `SCons` env) a + x0 <- interpret' env b + arr <- interpret' env c + let sh `ShCons` n = arrayShape arr + arrayGenerateM sh $ \idx -> foldM (curry f) x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + ESum1Inner _ e -> do + arr <- interpret' env e + let STArr _ (STScal t) = typeOf e + sh `ShCons` n = arrayShape arr + numericIsNum t $ return $ arrayGenerate sh $ \idx -> sum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + EUnit _ e -> arrayGenerateLinM ShNil (\_ -> interpret' env e) + EReplicate1Inner _ a b -> do + n <- fromIntegral @Int64 @Int <$> interpret' env a + arr <- interpret' env b + let sh = arrayShape arr + return $ arrayGenerate (sh `ShCons` n) (\(idx `IxCons` _) -> arrayIndex arr idx) + EMaximum1Inner _ e -> do + arr <- interpret' env e + let STArr _ (STScal t) = typeOf e + sh `ShCons` n = arrayShape arr + numericIsNum t $ return $ + arrayGenerate sh (\idx -> maximum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]]) + EMinimum1Inner _ e -> do + arr <- interpret' env e + let STArr _ (STScal t) = typeOf e + sh `ShCons` n = arrayShape arr + numericIsNum t $ return $ + arrayGenerate sh (\idx -> minimum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]]) + EReshape _ dim esh e -> do + sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env esh + arr <- interpret' env e + return $ arrayReshape sh arr + EZip _ a b -> do + arr1 <- interpret' env a + arr2 <- interpret' env b + let sh = arrayShape arr1 + when (sh /= arrayShape arr2) $ + error "Interpreter: mismatched shapes in EZip" + return $ arrayGenerateLin sh (\i -> (arr1 `arrayIndexLinear` i, arr2 `arrayIndexLinear` i)) + EFold1InnerD1 _ _ a b c -> do + let t = typeOf b + let f = \x -> interpret' (V (STPair t t) x `SCons` env) a + x0 <- interpret' env b + arr <- interpret' env c + let sh `ShCons` n = arrayShape arr + -- TODO: this is very inefficient, even for an interpreter; with mutable + -- arrays this can be a lot better with no lists + res <- arrayGenerateM sh $ \idx -> do + (y, stores) <- mapAccumLM (curry f) x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] + return (y, arrayFromList (ShNil `ShCons` n) stores) + return (arrayMap fst res + ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> + arrayIndexLinear (snd (arrayIndex res idx)) i) + EFold1InnerD2 _ _ ef ebog ed -> do + let STArr _ tB = typeOf ebog + STArr _ t2 = typeOf ed + let f = \tape ctg -> interpret' (V t2 ctg `SCons` V tB tape `SCons` env) ef + bog <- interpret' env ebog + arrctg <- interpret' env ed + let sh `ShCons` n = arrayShape bog + when (sh /= arrayShape arrctg) $ error "Interpreter: mismatched shapes in EFold1InnerD2" + res <- arrayGenerateM sh $ \idx -> do + let loop i !ctg !inpctgs | i < 0 = return (ctg, inpctgs) + loop i !ctg !inpctgs = do + let b = arrayIndex bog (idx `IxCons` i) + (ctg1, ctg2) <- f b ctg + loop (i - 1) ctg1 (ctg2 : inpctgs) + (x0ctg, inpctg) <- loop (n - 1) (arrayIndex arrctg idx) [] + return (x0ctg, arrayFromList (ShNil `ShCons` n) inpctg) + return (arrayMap fst res + ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> + arrayIndexLinear (snd (arrayIndex res idx)) i) + EConst _ _ v -> return v + EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e + EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b) + EIdx _ a b -> + let STArr n _ = typeOf a + in arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b) + EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e + EOp _ op e -> interpretOp op <$> interpret' env e + ECustom _ t1 t2 _ pr _ _ e1 e2 -> do + e1' <- interpret' env e1 + e2' <- interpret' env e2 + interpret' (V t2 e2' `SCons` V t1 e1' `SCons` SNil) pr + ERecompute _ e -> interpret' env e + EWith _ t e1 e2 -> do + initval <- interpret' env e1 + withAccum t (typeOf e2) initval $ \accum -> + interpret' (V (STAccum t) accum `SCons` env) e2 + EAccum _ t p e1 sp e2 e3 -> do + idx <- interpret' env e1 + val <- interpret' env e2 + accum <- interpret' env e3 + accumAddSparseD t p accum idx sp val + EZero _ t ezi -> do + zi <- interpret' env ezi + return $ zeroM t zi + EDeepZero _ t ezi -> do + zi <- interpret' env ezi + return $ deepZeroM t zi + EPlus _ t a b -> do + a' <- interpret' env a + b' <- interpret' env b + return $ addM t a' b' + EOneHot _ t p a b -> do + a' <- interpret' env a + b' <- interpret' env b + return $ onehotM p t a' b' + EError _ _ s -> error $ "Interpreter: Program threw error: " ++ s + +interpretOp :: SOp a t -> Rep a -> Rep t +interpretOp op arg = case op of + OAdd st -> numericIsNum st $ uncurry (+) arg + OMul st -> numericIsNum st $ uncurry (*) arg + ONeg st -> numericIsNum st $ negate arg + OLt st -> numericIsNum st $ uncurry (<) arg + OLe st -> numericIsNum st $ uncurry (<=) arg + OEq st -> styIsEq st $ uncurry (==) arg + ONot -> not arg + OAnd -> uncurry (&&) arg + OOr -> uncurry (||) arg + OIf -> if arg then Left () else Right () + ORound64 -> round arg + OToFl64 -> fromIntegral arg + ORecip st -> floatingIsFractional st $ recip arg + OExp st -> floatingIsFractional st $ exp arg + OLog st -> floatingIsFractional st $ log arg + OIDiv st -> integralIsIntegral st $ uncurry quot arg + OMod st -> integralIsIntegral st $ uncurry rem arg + where + styIsEq :: SScalTy t -> (Eq (Rep (TScal t)) => r) -> r + styIsEq STI32 = id + styIsEq STI64 = id + styIsEq STF32 = id + styIsEq STF64 = id + styIsEq STBool = id + +zeroM :: SMTy t -> Rep (ZeroInfo t) -> Rep t +zeroM typ zi = case typ of + SMTNil -> () + SMTPair t1 t2 -> (zeroM t1 (fst zi), zeroM t2 (snd zi)) + SMTLEither _ _ -> Nothing + SMTMaybe _ -> Nothing + SMTArr _ t -> arrayMap (zeroM t) zi + SMTScal sty -> case sty of + STI32 -> 0 + STI64 -> 0 + STF32 -> 0.0 + STF64 -> 0.0 + +deepZeroM :: SMTy t -> Rep (DeepZeroInfo t) -> Rep t +deepZeroM typ zi = case typ of + SMTNil -> () + SMTPair t1 t2 -> (deepZeroM t1 (fst zi), deepZeroM t2 (snd zi)) + SMTLEither t1 t2 -> fmap (bimap (deepZeroM t1) (deepZeroM t2)) zi + SMTMaybe t -> fmap (deepZeroM t) zi + SMTArr _ t -> arrayMap (deepZeroM t) zi + SMTScal sty -> case sty of + STI32 -> 0 + STI64 -> 0 + STF32 -> 0.0 + STF64 -> 0.0 + +addM :: SMTy t -> Rep t -> Rep t -> Rep t +addM typ a b = case typ of + SMTNil -> () + SMTPair t1 t2 -> (addM t1 (fst a) (fst b), addM t2 (snd a) (snd b)) + SMTLEither t1 t2 -> case (a, b) of + (Nothing, _) -> b + (_, Nothing) -> a + (Just (Left x), Just (Left y)) -> Just (Left (addM t1 x y)) + (Just (Right x), Just (Right y)) -> Just (Right (addM t2 x y)) + _ -> error "Plus of inconsistent LEithers" + SMTMaybe t -> case (a, b) of + (Nothing, _) -> b + (_, Nothing) -> a + (Just x, Just y) -> Just (addM t x y) + SMTArr _ t -> + let sh1 = arrayShape a + sh2 = arrayShape b + in if | shapeSize sh1 == 0 -> b + | shapeSize sh2 == 0 -> a + | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> addM t (arrayIndexLinear a i) (arrayIndexLinear b i)) + | otherwise -> error "Plus of inconsistently shaped arrays" + SMTScal sty -> numericIsNum sty $ a + b + +onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdxS p a) -> Rep b -> Rep a +onehotM SAPHere _ _ val = val +onehotM (SAPFst prj) (SMTPair a b) idx val = (onehotM prj a (fst idx) val, zeroM b (snd idx)) +onehotM (SAPSnd prj) (SMTPair a b) idx val = (zeroM a (fst idx), onehotM prj b (snd idx) val) +onehotM (SAPLeft prj) (SMTLEither a _) idx val = Just (Left (onehotM prj a idx val)) +onehotM (SAPRight prj) (SMTLEither _ b) idx val = Just (Right (onehotM prj b idx val)) +onehotM (SAPJust prj) (SMTMaybe a) idx val = Just (onehotM prj a idx val) +onehotM (SAPArrIdx prj) (SMTArr n a) idx val = + runIdentity $ onehotArray (\idx' -> Identity (onehotM prj a idx' val)) (\zi -> Identity (zeroM a zi)) n prj idx + +withAccum :: SMTy t -> STy a -> Rep t -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t) +withAccum t _ initval f = AcM $ do + accum <- newAcDense t initval + out <- unAcM $ f accum + val <- readAc t accum + return (out, val) + +newAcDense :: SMTy a -> Rep a -> IO (RepAc a) +newAcDense typ val = case typ of + SMTNil -> return () + SMTPair t1 t2 -> bitraverse (newAcDense t1) (newAcDense t2) val + SMTLEither t1 t2 -> newIORef =<< traverse (bitraverse (newAcDense t1) (newAcDense t2)) val + SMTMaybe t1 -> newIORef =<< traverse (newAcDense t1) val + SMTArr _ t1 -> arrayMapM (newAcDense t1) val + SMTScal _ -> newIORef val + +onehotArray :: Monad m + => (Rep (AcIdxS p a) -> m v) -- ^ the "one" + -> (Rep (ZeroInfo a) -> m v) -- ^ the "zero" + -> SNat n -> SAcPrj p a b -> Rep (AcIdxS (APArrIdx p) (TArr n a)) -> m (Array n v) +onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) = + let arrindex = unTupRepIdx IxNil IxCons n arrindex' + arrsh = arrayShape ziarr + !linindex = toLinearIndex arrsh arrindex + in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero (ziarr `arrayIndexLinear` i)) + +readAc :: SMTy t -> RepAc t -> IO (Rep t) +readAc typ val = case typ of + SMTNil -> return () + SMTPair t1 t2 -> bitraverse (readAc t1) (readAc t2) val + SMTLEither t1 t2 -> traverse (bitraverse (readAc t1) (readAc t2)) =<< readIORef val + SMTMaybe t -> traverse (readAc t) =<< readIORef val + SMTArr _ t -> traverse (readAc t) val + SMTScal _ -> readIORef val + +accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Sparse b c -> Rep c -> AcM s () +accumAddSparseD typ prj ref idx sp val = case (typ, prj) of + (_, SAPHere) -> accumAddDense typ ref sp val + + (SMTPair t1 _, SAPFst prj') -> accumAddSparseD t1 prj' (fst ref) idx sp val + (SMTPair _ t2, SAPSnd prj') -> accumAddSparseD t2 prj' (snd ref) idx sp val + + (SMTLEither t1 _, SAPLeft prj') -> + realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)") + (\case Left ac1 -> accumAddSparseD t1 prj' ac1 idx sp val + Right{} -> error "Mismatched Either in accumAddSparseD (r +l)") + (SMTLEither _ t2, SAPRight prj') -> + realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)") + (\case Right ac2 -> accumAddSparseD t2 prj' ac2 idx sp val + Left{} -> error "Mismatched Either in accumAddSparseD (l +r)") + + (SMTMaybe t1, SAPJust prj') -> + realiseMaybeSparse ref (error "Accumulating Just into Nothing (EWith requires EDeepZero)") + (\ac -> accumAddSparseD t1 prj' ac idx sp val) + + (SMTArr n t1, SAPArrIdx prj') -> + let (arrindex', idx') = idx + arrindex = unTupRepIdx IxNil IxCons n arrindex' + arrsh = arrayShape ref + linindex = toLinearIndex arrsh arrindex + in accumAddSparseD t1 prj' (arrayIndexLinear ref linindex) idx' sp val + +accumAddDense :: SMTy a -> RepAc a -> Sparse a b -> Rep b -> AcM s () +accumAddDense typ ref sp val = case (typ, sp) of + (_, _) | isAbsent sp -> return () + (_, SpAbsent) -> return () + (_, SpSparse s) -> + case val of + Nothing -> return () + Just val' -> accumAddDense typ ref s val' + (SMTPair t1 t2, SpPair s1 s2) -> do + accumAddDense t1 (fst ref) s1 (fst val) + accumAddDense t2 (snd ref) s2 (snd val) + (SMTLEither t1 t2, SpLEither s1 s2) -> + case val of + Nothing -> return () + Just (Left val1) -> + realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)") + (\case Left ac1 -> accumAddDense t1 ac1 s1 val1 + Right{} -> error "Mismatched Either in accumAddSparse (r +l)") + Just (Right val2) -> + realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)") + (\case Right ac2 -> accumAddDense t2 ac2 s2 val2 + Left{} -> error "Mismatched Either in accumAddSparse (l +r)") + (SMTMaybe t, SpMaybe s) -> + case val of + Nothing -> return () + Just val' -> + realiseMaybeSparse ref (error "Accumulating Just into Nothing (EAccum requires EDeepZero)") + (\ac -> accumAddDense t ac s val') + (SMTArr _ t1, SpArr s) -> + forM_ [0 .. arraySize ref - 1] $ \i -> + accumAddDense t1 (arrayIndexLinear ref i) s (arrayIndexLinear val i) + (SMTScal sty, SpScal) -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ())) + +-- TODO: makeval is always 'error' now. Simplify? +realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s () +realiseMaybeSparse ref makeval modifyval = + -- Try modifying what's already in ref. The 'join' makes the snd + -- of the function's return value a _continuation_ that is run after + -- the critical section ends. + AcM $ join $ atomicModifyIORef' ref $ \ac -> case ac of + -- Oops, ref's contents was still sparse. Have to initialise + -- it first, then try again. + Nothing -> (ac, do val <- makeval + join $ atomicModifyIORef' ref $ \ac' -> case ac' of + Nothing -> (Just val, return ()) + Just val' -> (ac', unAcM $ modifyval val')) + -- Yep, ref already had a value in there, so we can just add + -- val' to it recursively. + Just val -> (ac, unAcM $ modifyval val) + + +numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r +numericIsNum STI32 = id +numericIsNum STI64 = id +numericIsNum STF32 = id +numericIsNum STF64 = id + +floatingIsFractional :: ScalIsFloating st ~ True => SScalTy st -> ((Floating (ScalRep st), Ord (ScalRep st), ScalIsNumeric st ~ True, ScalIsFloating st ~ True) => r) -> r +floatingIsFractional STF32 = id +floatingIsFractional STF64 = id + +integralIsIntegral :: ScalIsIntegral st ~ True => SScalTy st -> ((Integral (ScalRep st), Ord (ScalRep st), ScalIsNumeric st ~ True, ScalIsIntegral st ~ True) => r) -> r +integralIsIntegral STI32 = id +integralIsIntegral STI64 = id + +unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m)) + -> SNat n -> Rep (Tup (Replicate n TIx)) -> f n +unTupRepIdx nil _ SZ _ = nil +unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx nil cons n idx `cons` fromIntegral @Int64 @Int i + +tupRepIdx :: (forall m. f (S m) -> (f m, Int)) + -> SNat n -> f n -> Rep (Tup (Replicate n TIx)) +tupRepIdx _ SZ _ = () +tupRepIdx uncons (SS n) tup = + let (tup', i) = uncons tup + in ((,) $! tupRepIdx uncons n tup') $! fromIntegral @Int @Int64 i + +ixUncons :: Index (S n) -> (Index n, Int) +ixUncons (IxCons idx i) = (idx, i) + +shUncons :: Shape (S n) -> (Shape n, Int) +shUncons (ShCons idx i) = (idx, i) + +mapAccumLM :: (Traversable t, Monad m) => (s -> a -> m (s, b)) -> s -> t a -> m (s, t b) +mapAccumLM f s0 = fmap swap . flip runStateT s0 . traverse f' + where f' x = do + s <- get + (s', y) <- lift $ f s x + put s' + return y diff --git a/src/CHAD/Interpreter/Accum.hs b/src/CHAD/Interpreter/Accum.hs new file mode 100644 index 0000000..8e5c040 --- /dev/null +++ b/src/CHAD/Interpreter/Accum.hs @@ -0,0 +1,366 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UnboxedTuples #-} +module CHAD.Interpreter.Accum ( + AcM, + runAcM, + Rep', + Accum, + withAccum, + accumAdd, + inParallel, +) where + +import Control.Concurrent +import Control.Monad (when, forM_) +import Data.Bifunctor (second) +import Data.Proxy +import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr) +import Foreign.Storable (sizeOf) +import GHC.Exts +import GHC.Float +import GHC.Int +import GHC.IO (IO(..)) +import GHC.Word +import System.IO.Unsafe (unsafePerformIO) + +import CHAD.Array +import CHAD.AST +import CHAD.Data + + +newtype AcM s a = AcM (IO a) + deriving newtype (Functor, Applicative, Monad) + +runAcM :: (forall s. AcM s a) -> a +runAcM (AcM m) = unsafePerformIO m + +-- | Equal to Interpreter.Rep.Rep, except that the TAccum case is defined. +type family Rep' s t where + Rep' s TNil = () + Rep' s (TPair a b) = (Rep' s a, Rep' s b) + Rep' s (TEither a b) = Either (Rep' s a) (Rep' s b) + Rep' s (TMaybe t) = Maybe (Rep' s t) + Rep' s (TArr n t) = Array n (Rep' s t) + Rep' s (TScal sty) = ScalRep sty + Rep' s (TAccum t) = Accum s t + +-- | Floats and integers are accumulated; booleans are left as-is. +data Accum s t = Accum (STy t) (ForeignPtr ()) + +tSize :: Proxy s -> STy t -> Rep' s t -> Int +tSize p ty x = tSize' p ty (Just x) + +tSize' :: Proxy s -> STy t -> Int +tSize' p typ = case typ of + STNil -> 0 + STPair a b -> tSize' p a + tSize' p b + STEither a b -> 1 + max (tSize' p a) (tSize' p b) + -- Representation of Maybe t is the same as Either () t; the add operation is different, however. + STMaybe t -> tSize' p (STEither STNil t) + STArr ndim t -> + case val of + Nothing -> error "Nested arrays not supported in this implementation" + Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' p t Nothing + STScal sty -> goScal sty + STAccum{} -> error "Nested accumulators unsupported" + where + goScal :: SScalTy t -> Int + goScal STI32 = 4 + goScal STI64 = 8 + goScal STF32 = 4 + goScal STF64 = 8 + goScal STBool = 1 + +-- | This operation does not commute with 'accumAdd', so it must be used with +-- care. Furthermore it must be used on exactly the same value as tSize was +-- called on. Hence it lives in IO, not in AcM. +accumWrite :: forall s t. Accum s t -> Rep' s t -> IO () +accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) -> + let + go :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int + go inarr ty val off = case ty of + STNil -> return off + STPair a b -> do + off1 <- go inarr a (fst val) off + go inarr b (snd val) off1 + STEither a b -> do + let !(I# off#) = off + off1 <- case val of + Left x -> do + let !(I8# tag#) = 0 + writeInt8# addr# off# tag# + go inarr a x (off + 1) + Right y -> do + let !(I8# tag#) = 1 + writeInt8# addr# off# tag# + go inarr b y (off + 1) + if inarr + then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing)) + else return off1 + -- Representation is the same, but add operation is different + STMaybe t -> go inarr (STEither STNil t) (maybe (Left ()) Right val) off + STArr _ t + | inarr -> error "Nested arrays not supported in this implementation" + | otherwise -> do + off1 <- goShape (arrayShape val) off + let eltsize = tSize' (Proxy @s) t Nothing + n = arraySize val + traverseArray_ (\lini x -> () <$ go True t x (off1 + eltsize * lini)) val + return (off1 + eltsize * n) + STScal sty -> goScal sty val off + STAccum{} -> error "Nested accumulators unsupported" + + goShape :: Shape n -> Int -> IO Int + goShape ShNil off = return off + goShape (ShCons sh n) off = do + off1@(I# off1#) <- goShape sh off + let !(I64# n'#) = fromIntegral n + writeInt64# addr# off1# n'# + return (off1 + 8) + + goScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int + goScal STI32 (I32# x) off@(I# off#) = off + 4 <$ writeInt32# addr# off# x + goScal STI64 (I64# x) off@(I# off#) = off + 8 <$ writeInt64# addr# off# x + goScal STF32 (F# x) off@(I# off#) = off + 4 <$ writeFloat# addr# off# x + goScal STF64 (D# x) off@(I# off#) = off + 8 <$ writeDouble# addr# off# x + goScal STBool b off@(I# off#) = do + let !(I8# i) = fromIntegral (fromEnum b) + off + 1 <$ writeInt8# addr# off# i + + in () <$ go False topty top_value 0 + +accumRead :: forall s t. Accum s t -> AcM s (Rep' s t) +accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> + let + go :: Bool -> STy t' -> Int -> IO (Int, Rep' s t') + go inarr ty off = case ty of + STNil -> return (off, ()) + STPair a b -> do + (off1, x) <- go inarr a off + (off2, y) <- go inarr b off1 + return (off1 + off2, (x, y)) + STEither a b -> do + let !(I# off#) = off + tag <- readInt8 addr# off# + (off1, val) <- case tag of + 0 -> fmap Left <$> go inarr a (off + 1) + 1 -> fmap Right <$> go inarr b (off + 1) + _ -> error "Invalid tag in accum memory" + if inarr + then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing), val) + else return (off1, val) + -- Representation is the same, but add operation is different + STMaybe t -> second (either (const Nothing) Just) <$> go inarr (STEither STNil t) off + STArr ndim t + | inarr -> error "Nested arrays not supported in this implementation" + | otherwise -> do + (off1, sh) <- readShape addr# ndim off + let eltsize = tSize' (Proxy @s) t Nothing + n = shapeSize sh + arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini)) + return (off1 + eltsize * n, arr) + STScal sty -> goScal sty off + STAccum{} -> error "Nested accumulators unsupported" + + goScal :: SScalTy t' -> Int -> IO (Int, ScalRep t') + goScal STI32 off@(I# off#) = (off + 4,) <$> readInt32 addr# off# + goScal STI64 off@(I# off#) = (off + 8,) <$> readInt64 addr# off# + goScal STF32 off@(I# off#) = (off + 4,) <$> readFloat addr# off# + goScal STF64 off@(I# off#) = (off + 8,) <$> readDouble addr# off# + goScal STBool off@(I# off#) = do + i8 <- readInt8 addr# off# + return (off + 1, toEnum (fromIntegral i8)) + + in snd <$> go False topty 0 + +readShape :: Addr# -> SNat n -> Int -> IO (Int, Shape n) +readShape _ SZ off = return (off, ShNil) +readShape mbarr (SS ndim) off = do + (off1@(I# off1#), sh) <- readShape mbarr ndim off + n' <- readInt64 mbarr off1# + return (off1 + 8, ShCons sh (fromIntegral n')) + +-- | @reverse@ of 'Shape'. The /outer/ dimension is on the left, at the head of +-- the list. +data InvShape n where + IShNil :: InvShape Z + IShCons :: Int -- ^ How many subarrays are there? + -> Int -- ^ What is the size of all subarrays together? + -> InvShape n -- ^ Sub array inverted shape + -> InvShape (S n) + +ishSize :: InvShape n -> Int +ishSize IShNil = 1 +ishSize (IShCons _ sz _) = sz + +invertShape :: forall n. Shape n -> InvShape n +invertShape | Refl <- lemPlusZero @n = flip go IShNil + where + go :: forall n' m. Shape n' -> InvShape m -> InvShape (n' + m) + go ShNil ish = ish + go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish) + +accumAdd :: forall s t i. Accum s t -> SNat i -> Rep' s (AcIdx t i) -> Rep' s (AcVal t i) -> AcM s () +accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> + let + go :: Bool -> STy t' -> SNat i' -> Rep' s (AcIdx t' i') -> Rep' s (AcVal t' i') -> Int -> IO () + go inarr ty SZ () val off = () <$ performAdd inarr ty val off + go inarr ty (SS dep) idx val off = case (ty, idx, val) of + (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off + (STPair _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off + (STPair{}, _, _) -> error "Mismatching idx/val for Pair in accumAdd" + (STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off + (STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off + (STEither{}, _, _) -> error "Mismatching idx/val for Either in accumAdd" + (STMaybe t, _, _) -> _ idx val + (STArr rank eltty, _, _) + | inarr -> error "Nested arrays" + | otherwise -> do + (off1, ish) <- second invertShape <$> readShape addr# rank off + goArr (SS dep) ish eltty idx val off1 + (STScal{}, _, _) -> error "accumAdd: Scal impossible with nonzero depth" + (STNil, _, _) -> error "accumAdd: Nil impossible with nonzero depth" + (STAccum{}, _, _) -> error "Nested accumulators unsupported" + + goArr :: SNat i' -> InvShape n -> STy t' + -> Rep' s (AcIdx (TArr n t') i') -> Rep' s (AcVal (TArr n t') i') -> Int -> IO () + goArr SZ ish t1 () val off = () <$ performAddArr (ishSize ish) t1 val off + goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off + goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do + let i' = fromIntegral @(Rep' s TIx) @Int i + when (i' < 0 || i' >= n) $ + error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")" + goArr depm1 ish t1 idx val (off + i' * ishSize ish) + + performAddArr :: Int -> STy t' -> Array n (Rep' s t') -> Int -> IO Int + performAddArr arraySz eltty val off = do + let eltsize = tSize' (Proxy @s) eltty Nothing + forM_ [0 .. arraySz - 1] $ \lini -> + performAdd True eltty (arrayIndexLinear val lini) (off + lini * eltsize) + return (off + arraySz * eltsize) + + performAdd :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int + performAdd inarr ty val off = case ty of + STNil -> return off + STPair t1 t2 -> do + off1 <- performAdd inarr t1 (fst val) off + performAdd inarr t2 (snd val) off1 + STEither t1 t2 -> do + let !(I# off#) = off + tag <- readInt8 addr# off# + off1 <- case (val, tag) of + (Left val1, 0) -> performAdd inarr t1 val1 (off + 1) + (Right val2, 1) -> performAdd inarr t2 val2 (off + 1) + _ -> error "accumAdd: Tag mismatch for Either" + if inarr + then return (off + 1 + max (tSize' (Proxy @s) t1 Nothing) (tSize' (Proxy @s) t2 Nothing)) + else return off1 + STArr n ty' + | inarr -> error "Nested array" + | otherwise -> do + (off1, sh) <- readShape addr# n off + performAddArr (shapeSize sh) ty' val off1 + STScal ty' -> performAddScal ty' val off + STAccum{} -> error "Nested accumulators unsupported" + + performAddScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int + performAddScal STI32 (I32# x#) off@(I# off#) + | sizeOf (undefined :: Int) == 4 + = off + 4 <$ fetchAddWord# addr# off# (word32ToWord# (int32ToWord32# x#)) + | otherwise + = off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\(W32# w#) -> W32# (int32ToWord32# x# `plusWord32#` w#)) + performAddScal STI64 (I64# x#) off@(I# off#) + | sizeOf (undefined :: Int) == 8 + = off + 8 <$ fetchAddWord# addr# off# (word64ToWord# (int64ToWord64# x#)) + | otherwise + = off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\(W64# w#) -> W64# (int64ToWord64# x# `plusWord64#` w#)) + performAddScal STF32 x off@(I# off#) = + off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\w -> castFloatToWord32 (x + castWord32ToFloat w)) + performAddScal STF64 x off@(I# off#) = + off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\w -> castDoubleToWord64 (x + castWord64ToDouble w)) + performAddScal STBool _ off = return (off + 1) -- don't do anything with booleans + + casLoop :: Eq w + => (Addr# -> Int# -> IO w) -- ^ read value (from a given byte offset; will get 0#) + -> (Addr# -> w -> w -> IO w) -- ^ CAS value at address (expected -> desired -> IO observed) + -> Addr# -- ^ Address to attempt to modify + -> (w -> w) -- ^ Operation to apply to the value + -> IO () + casLoop readOp casOp addr modify = readOp addr 0# >>= loop + where + loop value = do + value' <- casOp addr value (modify value) + if value == value' + then return () + else loop value' + + in () <$ go False topty top_depth top_index top_value 0 + +withAccum :: forall t s b. STy t -> Rep' s t -> (Accum s t -> AcM s b) -> AcM s (b, Rep' s t) +withAccum ty start fun = do + -- The initial write must happen before any of the adds or reads, so it makes + -- sense to put it in IO together with the allocation, instead of in AcM. + accum <- AcM $ do buffer <- mallocBytes (tSize (Proxy @s) ty start) + ptr <- newForeignPtr finalizerFree buffer + let accum = Accum ty ptr + accumWrite accum start + return accum + b <- fun accum + out <- accumRead accum + return (b, out) + +inParallel :: [AcM s t] -> AcM s [t] +inParallel actions = AcM $ do + mvars <- mapM (\_ -> newEmptyMVar) actions + forM_ (zip actions mvars) $ \(AcM action, var) -> + forkIO $ action >>= putMVar var + mapM takeMVar mvars + +-- | Offset is in bytes. +readInt8 :: Addr# -> Int# -> IO Int8 +readInt32 :: Addr# -> Int# -> IO Int32 +readInt64 :: Addr# -> Int# -> IO Int64 +readWord32 :: Addr# -> Int# -> IO Word32 +readWord64 :: Addr# -> Int# -> IO Word64 +readFloat :: Addr# -> Int# -> IO Float +readDouble :: Addr# -> Int# -> IO Double +readInt8 addr off# = IO $ \s -> case readInt8OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I8# val #) +readInt32 addr off# = IO $ \s -> case readInt32OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I32# val #) +readInt64 addr off# = IO $ \s -> case readInt64OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I64# val #) +readWord32 addr off# = IO $ \s -> case readWord32OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W32# val #) +readWord64 addr off# = IO $ \s -> case readWord64OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W64# val #) +readFloat addr off# = IO $ \s -> case readFloatOffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', F# val #) +readDouble addr off# = IO $ \s -> case readDoubleOffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', D# val #) + +writeInt8# :: Addr# -> Int# -> Int8# -> IO () +writeInt32# :: Addr# -> Int# -> Int32# -> IO () +writeInt64# :: Addr# -> Int# -> Int64# -> IO () +writeFloat# :: Addr# -> Int# -> Float# -> IO () +writeDouble# :: Addr# -> Int# -> Double# -> IO () +writeInt8# addr off# val = IO $ \s -> (# writeInt8OffAddr# (addr `plusAddr#` off#) 0# val s, () #) +writeInt32# addr off# val = IO $ \s -> (# writeInt32OffAddr# (addr `plusAddr#` off#) 0# val s, () #) +writeInt64# addr off# val = IO $ \s -> (# writeInt64OffAddr# (addr `plusAddr#` off#) 0# val s, () #) +writeFloat# addr off# val = IO $ \s -> (# writeFloatOffAddr# (addr `plusAddr#` off#) 0# val s, () #) +writeDouble# addr off# val = IO $ \s -> (# writeDoubleOffAddr# (addr `plusAddr#` off#) 0# val s, () #) + +fetchAddWord# :: Addr# -> Int# -> Word# -> IO () +fetchAddWord# addr off# val = IO $ \s -> case fetchAddWordAddr# (addr `plusAddr#` off#) val s of (# s', _ #) -> (# s', () #) + +atomicCasWord32Addr :: Addr# -> Word32 -> Word32 -> IO Word32 +atomicCasWord64Addr :: Addr# -> Word64 -> Word64 -> IO Word64 +atomicCasWord32Addr addr (W32# expected) (W32# desired) = + IO $ \s -> case atomicCasWord32Addr# addr expected desired s of (# s', old #) -> (# s', W32# old #) +atomicCasWord64Addr addr (W64# expected) (W64# desired) = + IO $ \s -> case atomicCasWord64Addr# addr expected desired s of (# s', old #) -> (# s', W64# old #) diff --git a/src/CHAD/Interpreter/AccumOld.hs b/src/CHAD/Interpreter/AccumOld.hs new file mode 100644 index 0000000..8e5c040 --- /dev/null +++ b/src/CHAD/Interpreter/AccumOld.hs @@ -0,0 +1,366 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UnboxedTuples #-} +module CHAD.Interpreter.Accum ( + AcM, + runAcM, + Rep', + Accum, + withAccum, + accumAdd, + inParallel, +) where + +import Control.Concurrent +import Control.Monad (when, forM_) +import Data.Bifunctor (second) +import Data.Proxy +import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr) +import Foreign.Storable (sizeOf) +import GHC.Exts +import GHC.Float +import GHC.Int +import GHC.IO (IO(..)) +import GHC.Word +import System.IO.Unsafe (unsafePerformIO) + +import CHAD.Array +import CHAD.AST +import CHAD.Data + + +newtype AcM s a = AcM (IO a) + deriving newtype (Functor, Applicative, Monad) + +runAcM :: (forall s. AcM s a) -> a +runAcM (AcM m) = unsafePerformIO m + +-- | Equal to Interpreter.Rep.Rep, except that the TAccum case is defined. +type family Rep' s t where + Rep' s TNil = () + Rep' s (TPair a b) = (Rep' s a, Rep' s b) + Rep' s (TEither a b) = Either (Rep' s a) (Rep' s b) + Rep' s (TMaybe t) = Maybe (Rep' s t) + Rep' s (TArr n t) = Array n (Rep' s t) + Rep' s (TScal sty) = ScalRep sty + Rep' s (TAccum t) = Accum s t + +-- | Floats and integers are accumulated; booleans are left as-is. +data Accum s t = Accum (STy t) (ForeignPtr ()) + +tSize :: Proxy s -> STy t -> Rep' s t -> Int +tSize p ty x = tSize' p ty (Just x) + +tSize' :: Proxy s -> STy t -> Int +tSize' p typ = case typ of + STNil -> 0 + STPair a b -> tSize' p a + tSize' p b + STEither a b -> 1 + max (tSize' p a) (tSize' p b) + -- Representation of Maybe t is the same as Either () t; the add operation is different, however. + STMaybe t -> tSize' p (STEither STNil t) + STArr ndim t -> + case val of + Nothing -> error "Nested arrays not supported in this implementation" + Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' p t Nothing + STScal sty -> goScal sty + STAccum{} -> error "Nested accumulators unsupported" + where + goScal :: SScalTy t -> Int + goScal STI32 = 4 + goScal STI64 = 8 + goScal STF32 = 4 + goScal STF64 = 8 + goScal STBool = 1 + +-- | This operation does not commute with 'accumAdd', so it must be used with +-- care. Furthermore it must be used on exactly the same value as tSize was +-- called on. Hence it lives in IO, not in AcM. +accumWrite :: forall s t. Accum s t -> Rep' s t -> IO () +accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) -> + let + go :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int + go inarr ty val off = case ty of + STNil -> return off + STPair a b -> do + off1 <- go inarr a (fst val) off + go inarr b (snd val) off1 + STEither a b -> do + let !(I# off#) = off + off1 <- case val of + Left x -> do + let !(I8# tag#) = 0 + writeInt8# addr# off# tag# + go inarr a x (off + 1) + Right y -> do + let !(I8# tag#) = 1 + writeInt8# addr# off# tag# + go inarr b y (off + 1) + if inarr + then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing)) + else return off1 + -- Representation is the same, but add operation is different + STMaybe t -> go inarr (STEither STNil t) (maybe (Left ()) Right val) off + STArr _ t + | inarr -> error "Nested arrays not supported in this implementation" + | otherwise -> do + off1 <- goShape (arrayShape val) off + let eltsize = tSize' (Proxy @s) t Nothing + n = arraySize val + traverseArray_ (\lini x -> () <$ go True t x (off1 + eltsize * lini)) val + return (off1 + eltsize * n) + STScal sty -> goScal sty val off + STAccum{} -> error "Nested accumulators unsupported" + + goShape :: Shape n -> Int -> IO Int + goShape ShNil off = return off + goShape (ShCons sh n) off = do + off1@(I# off1#) <- goShape sh off + let !(I64# n'#) = fromIntegral n + writeInt64# addr# off1# n'# + return (off1 + 8) + + goScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int + goScal STI32 (I32# x) off@(I# off#) = off + 4 <$ writeInt32# addr# off# x + goScal STI64 (I64# x) off@(I# off#) = off + 8 <$ writeInt64# addr# off# x + goScal STF32 (F# x) off@(I# off#) = off + 4 <$ writeFloat# addr# off# x + goScal STF64 (D# x) off@(I# off#) = off + 8 <$ writeDouble# addr# off# x + goScal STBool b off@(I# off#) = do + let !(I8# i) = fromIntegral (fromEnum b) + off + 1 <$ writeInt8# addr# off# i + + in () <$ go False topty top_value 0 + +accumRead :: forall s t. Accum s t -> AcM s (Rep' s t) +accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> + let + go :: Bool -> STy t' -> Int -> IO (Int, Rep' s t') + go inarr ty off = case ty of + STNil -> return (off, ()) + STPair a b -> do + (off1, x) <- go inarr a off + (off2, y) <- go inarr b off1 + return (off1 + off2, (x, y)) + STEither a b -> do + let !(I# off#) = off + tag <- readInt8 addr# off# + (off1, val) <- case tag of + 0 -> fmap Left <$> go inarr a (off + 1) + 1 -> fmap Right <$> go inarr b (off + 1) + _ -> error "Invalid tag in accum memory" + if inarr + then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing), val) + else return (off1, val) + -- Representation is the same, but add operation is different + STMaybe t -> second (either (const Nothing) Just) <$> go inarr (STEither STNil t) off + STArr ndim t + | inarr -> error "Nested arrays not supported in this implementation" + | otherwise -> do + (off1, sh) <- readShape addr# ndim off + let eltsize = tSize' (Proxy @s) t Nothing + n = shapeSize sh + arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini)) + return (off1 + eltsize * n, arr) + STScal sty -> goScal sty off + STAccum{} -> error "Nested accumulators unsupported" + + goScal :: SScalTy t' -> Int -> IO (Int, ScalRep t') + goScal STI32 off@(I# off#) = (off + 4,) <$> readInt32 addr# off# + goScal STI64 off@(I# off#) = (off + 8,) <$> readInt64 addr# off# + goScal STF32 off@(I# off#) = (off + 4,) <$> readFloat addr# off# + goScal STF64 off@(I# off#) = (off + 8,) <$> readDouble addr# off# + goScal STBool off@(I# off#) = do + i8 <- readInt8 addr# off# + return (off + 1, toEnum (fromIntegral i8)) + + in snd <$> go False topty 0 + +readShape :: Addr# -> SNat n -> Int -> IO (Int, Shape n) +readShape _ SZ off = return (off, ShNil) +readShape mbarr (SS ndim) off = do + (off1@(I# off1#), sh) <- readShape mbarr ndim off + n' <- readInt64 mbarr off1# + return (off1 + 8, ShCons sh (fromIntegral n')) + +-- | @reverse@ of 'Shape'. The /outer/ dimension is on the left, at the head of +-- the list. +data InvShape n where + IShNil :: InvShape Z + IShCons :: Int -- ^ How many subarrays are there? + -> Int -- ^ What is the size of all subarrays together? + -> InvShape n -- ^ Sub array inverted shape + -> InvShape (S n) + +ishSize :: InvShape n -> Int +ishSize IShNil = 1 +ishSize (IShCons _ sz _) = sz + +invertShape :: forall n. Shape n -> InvShape n +invertShape | Refl <- lemPlusZero @n = flip go IShNil + where + go :: forall n' m. Shape n' -> InvShape m -> InvShape (n' + m) + go ShNil ish = ish + go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish) + +accumAdd :: forall s t i. Accum s t -> SNat i -> Rep' s (AcIdx t i) -> Rep' s (AcVal t i) -> AcM s () +accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> + let + go :: Bool -> STy t' -> SNat i' -> Rep' s (AcIdx t' i') -> Rep' s (AcVal t' i') -> Int -> IO () + go inarr ty SZ () val off = () <$ performAdd inarr ty val off + go inarr ty (SS dep) idx val off = case (ty, idx, val) of + (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off + (STPair _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off + (STPair{}, _, _) -> error "Mismatching idx/val for Pair in accumAdd" + (STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off + (STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off + (STEither{}, _, _) -> error "Mismatching idx/val for Either in accumAdd" + (STMaybe t, _, _) -> _ idx val + (STArr rank eltty, _, _) + | inarr -> error "Nested arrays" + | otherwise -> do + (off1, ish) <- second invertShape <$> readShape addr# rank off + goArr (SS dep) ish eltty idx val off1 + (STScal{}, _, _) -> error "accumAdd: Scal impossible with nonzero depth" + (STNil, _, _) -> error "accumAdd: Nil impossible with nonzero depth" + (STAccum{}, _, _) -> error "Nested accumulators unsupported" + + goArr :: SNat i' -> InvShape n -> STy t' + -> Rep' s (AcIdx (TArr n t') i') -> Rep' s (AcVal (TArr n t') i') -> Int -> IO () + goArr SZ ish t1 () val off = () <$ performAddArr (ishSize ish) t1 val off + goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off + goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do + let i' = fromIntegral @(Rep' s TIx) @Int i + when (i' < 0 || i' >= n) $ + error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")" + goArr depm1 ish t1 idx val (off + i' * ishSize ish) + + performAddArr :: Int -> STy t' -> Array n (Rep' s t') -> Int -> IO Int + performAddArr arraySz eltty val off = do + let eltsize = tSize' (Proxy @s) eltty Nothing + forM_ [0 .. arraySz - 1] $ \lini -> + performAdd True eltty (arrayIndexLinear val lini) (off + lini * eltsize) + return (off + arraySz * eltsize) + + performAdd :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int + performAdd inarr ty val off = case ty of + STNil -> return off + STPair t1 t2 -> do + off1 <- performAdd inarr t1 (fst val) off + performAdd inarr t2 (snd val) off1 + STEither t1 t2 -> do + let !(I# off#) = off + tag <- readInt8 addr# off# + off1 <- case (val, tag) of + (Left val1, 0) -> performAdd inarr t1 val1 (off + 1) + (Right val2, 1) -> performAdd inarr t2 val2 (off + 1) + _ -> error "accumAdd: Tag mismatch for Either" + if inarr + then return (off + 1 + max (tSize' (Proxy @s) t1 Nothing) (tSize' (Proxy @s) t2 Nothing)) + else return off1 + STArr n ty' + | inarr -> error "Nested array" + | otherwise -> do + (off1, sh) <- readShape addr# n off + performAddArr (shapeSize sh) ty' val off1 + STScal ty' -> performAddScal ty' val off + STAccum{} -> error "Nested accumulators unsupported" + + performAddScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int + performAddScal STI32 (I32# x#) off@(I# off#) + | sizeOf (undefined :: Int) == 4 + = off + 4 <$ fetchAddWord# addr# off# (word32ToWord# (int32ToWord32# x#)) + | otherwise + = off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\(W32# w#) -> W32# (int32ToWord32# x# `plusWord32#` w#)) + performAddScal STI64 (I64# x#) off@(I# off#) + | sizeOf (undefined :: Int) == 8 + = off + 8 <$ fetchAddWord# addr# off# (word64ToWord# (int64ToWord64# x#)) + | otherwise + = off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\(W64# w#) -> W64# (int64ToWord64# x# `plusWord64#` w#)) + performAddScal STF32 x off@(I# off#) = + off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\w -> castFloatToWord32 (x + castWord32ToFloat w)) + performAddScal STF64 x off@(I# off#) = + off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\w -> castDoubleToWord64 (x + castWord64ToDouble w)) + performAddScal STBool _ off = return (off + 1) -- don't do anything with booleans + + casLoop :: Eq w + => (Addr# -> Int# -> IO w) -- ^ read value (from a given byte offset; will get 0#) + -> (Addr# -> w -> w -> IO w) -- ^ CAS value at address (expected -> desired -> IO observed) + -> Addr# -- ^ Address to attempt to modify + -> (w -> w) -- ^ Operation to apply to the value + -> IO () + casLoop readOp casOp addr modify = readOp addr 0# >>= loop + where + loop value = do + value' <- casOp addr value (modify value) + if value == value' + then return () + else loop value' + + in () <$ go False topty top_depth top_index top_value 0 + +withAccum :: forall t s b. STy t -> Rep' s t -> (Accum s t -> AcM s b) -> AcM s (b, Rep' s t) +withAccum ty start fun = do + -- The initial write must happen before any of the adds or reads, so it makes + -- sense to put it in IO together with the allocation, instead of in AcM. + accum <- AcM $ do buffer <- mallocBytes (tSize (Proxy @s) ty start) + ptr <- newForeignPtr finalizerFree buffer + let accum = Accum ty ptr + accumWrite accum start + return accum + b <- fun accum + out <- accumRead accum + return (b, out) + +inParallel :: [AcM s t] -> AcM s [t] +inParallel actions = AcM $ do + mvars <- mapM (\_ -> newEmptyMVar) actions + forM_ (zip actions mvars) $ \(AcM action, var) -> + forkIO $ action >>= putMVar var + mapM takeMVar mvars + +-- | Offset is in bytes. +readInt8 :: Addr# -> Int# -> IO Int8 +readInt32 :: Addr# -> Int# -> IO Int32 +readInt64 :: Addr# -> Int# -> IO Int64 +readWord32 :: Addr# -> Int# -> IO Word32 +readWord64 :: Addr# -> Int# -> IO Word64 +readFloat :: Addr# -> Int# -> IO Float +readDouble :: Addr# -> Int# -> IO Double +readInt8 addr off# = IO $ \s -> case readInt8OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I8# val #) +readInt32 addr off# = IO $ \s -> case readInt32OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I32# val #) +readInt64 addr off# = IO $ \s -> case readInt64OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I64# val #) +readWord32 addr off# = IO $ \s -> case readWord32OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W32# val #) +readWord64 addr off# = IO $ \s -> case readWord64OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W64# val #) +readFloat addr off# = IO $ \s -> case readFloatOffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', F# val #) +readDouble addr off# = IO $ \s -> case readDoubleOffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', D# val #) + +writeInt8# :: Addr# -> Int# -> Int8# -> IO () +writeInt32# :: Addr# -> Int# -> Int32# -> IO () +writeInt64# :: Addr# -> Int# -> Int64# -> IO () +writeFloat# :: Addr# -> Int# -> Float# -> IO () +writeDouble# :: Addr# -> Int# -> Double# -> IO () +writeInt8# addr off# val = IO $ \s -> (# writeInt8OffAddr# (addr `plusAddr#` off#) 0# val s, () #) +writeInt32# addr off# val = IO $ \s -> (# writeInt32OffAddr# (addr `plusAddr#` off#) 0# val s, () #) +writeInt64# addr off# val = IO $ \s -> (# writeInt64OffAddr# (addr `plusAddr#` off#) 0# val s, () #) +writeFloat# addr off# val = IO $ \s -> (# writeFloatOffAddr# (addr `plusAddr#` off#) 0# val s, () #) +writeDouble# addr off# val = IO $ \s -> (# writeDoubleOffAddr# (addr `plusAddr#` off#) 0# val s, () #) + +fetchAddWord# :: Addr# -> Int# -> Word# -> IO () +fetchAddWord# addr off# val = IO $ \s -> case fetchAddWordAddr# (addr `plusAddr#` off#) val s of (# s', _ #) -> (# s', () #) + +atomicCasWord32Addr :: Addr# -> Word32 -> Word32 -> IO Word32 +atomicCasWord64Addr :: Addr# -> Word64 -> Word64 -> IO Word64 +atomicCasWord32Addr addr (W32# expected) (W32# desired) = + IO $ \s -> case atomicCasWord32Addr# addr expected desired s of (# s', old #) -> (# s', W32# old #) +atomicCasWord64Addr addr (W64# expected) (W64# desired) = + IO $ \s -> case atomicCasWord64Addr# addr expected desired s of (# s', old #) -> (# s', W64# old #) diff --git a/src/CHAD/Interpreter/Rep.hs b/src/CHAD/Interpreter/Rep.hs new file mode 100644 index 0000000..fadc6be --- /dev/null +++ b/src/CHAD/Interpreter/Rep.hs @@ -0,0 +1,105 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} +module CHAD.Interpreter.Rep where + +import Control.DeepSeq +import Data.Coerce (coerce) +import Data.List (intersperse, intercalate) +import Data.Foldable (toList) +import Data.IORef +import GHC.Exts (withDict) + +import CHAD.Array +import CHAD.AST +import CHAD.AST.Pretty +import CHAD.Data + + +type family Rep t where + Rep TNil = () + Rep (TPair a b) = (Rep a, Rep b) + Rep (TEither a b) = Either (Rep a) (Rep b) + Rep (TLEither a b) = Maybe (Either (Rep a) (Rep b)) + Rep (TMaybe t) = Maybe (Rep t) + Rep (TArr n t) = Array n (Rep t) + Rep (TScal sty) = ScalRep sty + Rep (TAccum t) = RepAc t + +-- Mutable, represents monoid types t. +type family RepAc t where + RepAc TNil = () + RepAc (TPair a b) = (RepAc a, RepAc b) + RepAc (TLEither a b) = IORef (Maybe (Either (RepAc a) (RepAc b))) + RepAc (TMaybe t) = IORef (Maybe (RepAc t)) + RepAc (TArr n t) = Array n (RepAc t) + RepAc (TScal sty) = IORef (ScalRep sty) + +newtype Value t = Value { unValue :: Rep t } + +liftV :: (Rep a -> Rep b) -> Value a -> Value b +liftV f (Value x) = Value (f x) + +liftV2 :: (Rep a -> Rep b -> Rep c) -> Value a -> Value b -> Value c +liftV2 f (Value x) (Value y) = Value (f x y) + +vPair :: Value a -> Value b -> Value (TPair a b) +vPair = liftV2 (,) + +vUnpair :: Value (TPair a b) -> (Value a, Value b) +vUnpair (Value (x, y)) = (Value x, Value y) + +showValue :: Int -> STy t -> Rep t -> ShowS +showValue _ STNil () = showString "()" +showValue _ (STPair a b) (x, y) = showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")" +showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Inl " . showValue 11 a x +showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Inr " . showValue 11 b y +showValue _ (STLEither _ _) Nothing = showString "LNil" +showValue d (STLEither a _) (Just (Left x)) = showParen (d > 10) $ showString "LInl " . showValue 11 a x +showValue d (STLEither _ b) (Just (Right y)) = showParen (d > 10) $ showString "LInr " . showValue 11 b y +showValue _ (STMaybe _) Nothing = showString "Nothing" +showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x +showValue d (STArr _ t) arr = showParen (d > 10) $ + showString "arrayFromList " . showsPrec 11 (arrayShape arr) + . showString " [" + . foldr (.) id (intersperse (showString ",") $ map (showValue 0 t) (toList arr)) + . showString "]" +showValue d (STScal sty) x = case sty of + STF32 -> showsPrec d x + STF64 -> showsPrec d x + STI32 -> showsPrec d x + STI64 -> showsPrec d x + STBool -> showsPrec d x +showValue _ (STAccum t) _ = showString $ "" + +showEnv :: SList STy env -> SList Value env -> String +showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" + where + showEntries :: SList STy env -> SList Value env -> [String] + showEntries SNil SNil = [] + showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs + +rnfRep :: STy t -> Rep t -> () +rnfRep STNil () = () +rnfRep (STPair a b) (x, y) = rnfRep a x `seq` rnfRep b y +rnfRep (STEither a _) (Left x) = rnfRep a x +rnfRep (STEither _ b) (Right y) = rnfRep b y +rnfRep (STLEither _ _) Nothing = () +rnfRep (STLEither a _) (Just (Left x)) = rnfRep a x +rnfRep (STLEither _ b) (Just (Right y)) = rnfRep b y +rnfRep (STMaybe _) Nothing = () +rnfRep (STMaybe t) (Just x) = rnfRep t x +rnfRep (STArr (_ :: SNat n) (t :: STy t2)) arr = + withDict @(KnownTy t2) t $ rnf (coerce @(Array n (Rep t2)) @(Array n (Value t2)) arr) +rnfRep (STScal t) x = case t of + STI32 -> rnf x + STI64 -> rnf x + STF32 -> rnf x + STF64 -> rnf x + STBool -> rnf x +rnfRep STAccum{} _ = error "Cannot rnf accumulators" + +instance KnownTy t => NFData (Value t) where + rnf (Value x) = rnfRep (knownTy @t) x diff --git a/src/CHAD/Language.hs b/src/CHAD/Language.hs new file mode 100644 index 0000000..6dc91a5 --- /dev/null +++ b/src/CHAD/Language.hs @@ -0,0 +1,266 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ExplicitForAll #-} +{-# LANGUAGE OverloadedLabels #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE TypeApplications #-} +module CHAD.Language ( + fromNamed, + NExpr, + Ex, + module CHAD.Language, + module CHAD.AST.Types, + Lookup, +) where + +import GHC.TypeLits (withSomeSSymbol, symbolVal, SSymbol, pattern SSymbol) + +import CHAD.Array +import CHAD.AST +import CHAD.AST.Sparse.Types +import CHAD.AST.Types +import CHAD.Data +import CHAD.Drev.Types +import CHAD.Language.AST + + +data a :-> b = a :-> b + deriving (Show) +infixr 0 :-> + + +body :: NExpr env t -> NFun env env t +body = NBody + +lambda :: forall a name env env' t. Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t +lambda = NLam + +inline :: NFun '[] params t -> SList (NExpr env) (UnName params) -> NExpr env t +inline = inlineNFun + +-- To be used to construct the argument list for 'inline'. +-- +-- > let fun = lambda @(TScal TF64) #x $ lambda @(TScal TF64) #y $ body $ #x + #y +-- > in inline fun (SNil .$ 16 .$ 26) +(.$) :: SList f list -> f a -> SList f (a : list) +(.$) = flip SCons + + +let_ :: forall a t env name. Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t +let_ = NELet + +pair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b) +pair = NEPair + +fst_ :: NExpr env (TPair a b) -> NExpr env a +fst_ = NEFst + +snd_ :: NExpr env (TPair a b) -> NExpr env b +snd_ = NESnd + +nil :: NExpr env TNil +nil = NENil + +inl :: KnownTy b => NExpr env a -> NExpr env (TEither a b) +inl = NEInl knownTy + +inr :: KnownTy a => NExpr env b -> NExpr env (TEither a b) +inr = NEInr knownTy + +case_ :: NExpr env (TEither a b) -> (Var name1 a :-> NExpr ('(name1, a) : env) c) -> (Var name2 b :-> NExpr ('(name2, b) : env) c) -> NExpr env c +case_ e (v1 :-> e1) (v2 :-> e2) = NECase e v1 e1 v2 e2 + +nothing :: KnownTy a => NExpr env (TMaybe a) +nothing = NENothing knownTy + +just :: NExpr env a -> NExpr env (TMaybe a) +just = NEJust + +maybe_ :: NExpr env b -> (Var name a :-> NExpr ('(name, a) : env) b) -> NExpr env (TMaybe a) -> NExpr env b +maybe_ a (v :-> b) c = NEMaybe a v b c + +constArr_ :: forall t n env. (KnownNat n, KnownScalTy t) => Array n (ScalRep t) -> NExpr env (TArr n (TScal t)) +constArr_ x = + let ty = knownScalTy + in case scalRepIsShow ty of + Dict -> NEConstArr knownNat ty x + +build1 :: NExpr env TIx -> (Var name TIx :-> NExpr ('(name, TIx) : env) t) -> NExpr env (TArr (S Z) t) +build1 a (v :-> b) = NEBuild (SS SZ) (pair nil a) #idx (let_ v (snd_ #idx) (NEDrop (SS SZ) b)) + +build2 :: NExpr env TIx -> NExpr env TIx + -> (Var name1 TIx :-> Var name2 TIx :-> NExpr ('(name2, TIx) : '(name1, TIx) : env) t) + -> NExpr env (TArr (S (S Z)) t) +build2 a1 a2 (v1 :-> v2 :-> b) = + NEBuild (SS (SS SZ)) + (pair (pair nil a1) a2) + #idx + (let_ v1 (snd_ (fst_ #idx)) $ + let_ v2 (NEDrop SZ (snd_ #idx)) $ + NEDrop (SS (SS SZ)) b) + +build :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> (Var name (Tup (Replicate n TIx)) :-> NExpr ('(name, Tup (Replicate n TIx)) : env) t) -> NExpr env (TArr n t) +build n a (v :-> b) = NEBuild n a v b + +map_ :: forall n a b env name. (KnownNat n, KnownTy a) + => (Var name a :-> NExpr ('(name, a) : env) b) + -> NExpr env (TArr n a) -> NExpr env (TArr n b) +map_ (v :-> a) b = NEMap v a b + +fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) +fold1i (v1@(Var s1@SSymbol t) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 = + withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) -> + assertSymbolNotUnderscore s3 $ + equalityReflexive s3 $ + assertSymbolDistinct s3 s1 $ + let v3 = Var s3 (STPair t t) + in fold1i' (v3 :-> let_ v1 (fst_ (NEVar v3)) $ + let_ v2 (snd_ (NEVar v3)) $ + NEDrop (SS (SS SZ)) e1) + e2 e3 + +fold1i' :: (Var name (TPair t t) :-> NExpr ('(name, TPair t t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) +fold1i' (v :-> e1) e2 e3 = NEFold1Inner v e1 e2 e3 + +sum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) +sum1i e = NESum1Inner e + +unit :: NExpr env t -> NExpr env (TArr Z t) +unit = NEUnit + +replicate1i :: ScalIsNumeric t ~ True => NExpr env TIx -> NExpr env (TArr n (TScal t)) -> NExpr env (TArr (S n) (TScal t)) +replicate1i n a = NEReplicate1Inner n a + +maximum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) +maximum1i e = NEMaximum1Inner e + +minimum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) +minimum1i e = NEMinimum1Inner e + +reshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t) +reshape = NEReshape + +fold1iD1 :: (Var name1 t1 :-> Var name2 t1 :-> NExpr ('(name2, t1) : '(name1, t1) : env) (TPair t1 b)) + -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) +fold1iD1 (v1@(Var s1@SSymbol t1) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 = + withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) -> + assertSymbolNotUnderscore s3 $ + equalityReflexive s3 $ + assertSymbolDistinct s3 s1 $ + let v3 = Var s3 (STPair t1 t1) + in fold1iD1' (v3 :-> let_ v1 (fst_ (NEVar v3)) $ + let_ v2 (snd_ (NEVar v3)) $ + NEDrop (SS (SS SZ)) e1) + e2 e3 + +fold1iD1' :: (Var name (TPair t1 t1) :-> NExpr ('(name, TPair t1 t1) : env) (TPair t1 b)) + -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) +fold1iD1' (v1 :-> e1) e2 e3 = NEFold1InnerD1 v1 e1 e2 e3 + +fold1iD2 :: (Var name1 b :-> Var name2 t2 :-> NExpr ('(name2, t2) : '(name1, b) : env) (TPair t2 t2)) + -> NExpr env (TArr (S n) b) -> NExpr env (TArr n t2) -> NExpr env (TPair (TArr n t2) (TArr (S n) t2)) +fold1iD2 (v1 :-> v2 :-> e1) e2 e3 = NEFold1InnerD2 v1 v2 e1 e2 e3 + +const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t) +const_ x = + let ty = knownScalTy + in case scalRepIsShow ty of + Dict -> NEConst ty x + +idx0 :: NExpr env (TArr Z t) -> NExpr env t +idx0 = NEIdx0 + +-- (.!) :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t) +-- (.!) = NEIdx1 +-- infixl 9 .! + +(!) :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t +(!) = NEIdx +infixl 9 ! + +shape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) +shape = NEShape + +length_ :: NExpr env (TArr N1 t) -> NExpr env TIx +length_ e = snd_ (shape e) + +oper :: SOp a t -> NExpr env a -> NExpr env t +oper = NEOp + +oper2 :: SOp (TPair a b) t -> NExpr env a -> NExpr env b -> NExpr env t +oper2 op a b = NEOp op (pair a b) + +error_ :: KnownTy t => String -> NExpr env t +error_ s = NEError knownTy s + +custom :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t) + -> (Var nf1 (D1 a) :-> Var nf2 (D1 b) :-> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape)) + -> (Var nr1 tape :-> Var nr2 (D2 t) :-> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b)) + -> NExpr env a -> NExpr env b + -> NExpr env t +custom (n1 :-> n2 :-> a) (nf1 :-> nf2 :-> b) (nr1 :-> nr2 :-> c) e1 e2 = + NECustom n1 n2 a nf1 nf2 b nr1 nr2 c e1 e2 + +recompute :: NExpr env a -> NExpr env a +recompute = NERecompute + +with :: forall t a env acname. KnownMTy t => NExpr env t -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a t) +with a (n :-> b) = NEWith (knownMTy @t) a n b + +accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil +accum p a b c = NEAccum knownMTy p a (spDense (acPrjTy p knownMTy)) b c + +accumS :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil +accumS p a sp b c = NEAccum knownMTy p a sp b c + + +(.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +a .== b = oper (OEq knownScalTy) (pair a b) +infix 4 .== + +(.<) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +a .< b = oper (OLt knownScalTy) (pair a b) +infix 4 .< + +(.>) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +(.>) = flip (.<) +infix 4 .> + +(.<=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +a .<= b = oper (OLe knownScalTy) (pair a b) +infix 4 .<= + +(.>=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) +(.>=) = flip (.<=) +infix 4 .>= + +not_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) +not_ = oper ONot + +and_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool) +and_ = oper2 OAnd +infixr 3 `and_` + +or_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool) +or_ = oper2 OOr +infixr 2 `or_` + +mod_ :: (ScalIsIntegral a ~ True, KnownScalTy a) => NExpr env (TScal a) -> NExpr env (TScal a) -> NExpr env (TScal a) +mod_ = oper2 (OMod knownScalTy) +infixl 7 `mod_` + +-- | The first alternative is the True case; the second is the False case. +if_ :: NExpr env (TScal TBool) -> NExpr env t -> NExpr env t -> NExpr env t +if_ e a b = case_ (oper OIf e) (#_ :-> NEDrop SZ a) (#_ :-> NEDrop SZ b) + +round_ :: NExpr env (TScal TF64) -> NExpr env (TScal TI64) +round_ = oper ORound64 + +toFloat_ :: NExpr env (TScal TI64) -> NExpr env (TScal TF64) +toFloat_ = oper OToFl64 + +idiv :: (KnownScalTy t, ScalIsIntegral t ~ True) => NExpr env (TScal t) -> NExpr env (TScal t) -> NExpr env (TScal t) +idiv = oper2 (OIDiv knownScalTy) +infixl 7 `idiv` diff --git a/src/CHAD/Language/AST.hs b/src/CHAD/Language/AST.hs new file mode 100644 index 0000000..b270844 --- /dev/null +++ b/src/CHAD/Language/AST.hs @@ -0,0 +1,300 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} +module CHAD.Language.AST where + +import Data.Kind (Type) +import Data.Type.Equality +import GHC.OverloadedLabels +import GHC.TypeLits (Symbol, SSymbol, pattern SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..), symbolVal) + +import CHAD.Array +import CHAD.AST +import CHAD.AST.Sparse.Types +import CHAD.Data +import CHAD.Drev.Types + + +type NExpr :: [(Symbol, Ty)] -> Ty -> Type +data NExpr env t where + -- lambda calculus + NEVar :: Lookup name env ~ t => Var name t -> NExpr env t + NELet :: Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t + + -- environment management + NEDrop :: SNat i -> NExpr (DropNth i env) t -> NExpr env t + + -- base types + NEPair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b) + NEFst :: NExpr env (TPair a b) -> NExpr env a + NESnd :: NExpr env (TPair a b) -> NExpr env b + NENil :: NExpr env TNil + NEInl :: STy b -> NExpr env a -> NExpr env (TEither a b) + NEInr :: STy a -> NExpr env b -> NExpr env (TEither a b) + NECase :: NExpr env (TEither a b) -> Var name1 a -> NExpr ('(name1, a) : env) c -> Var name2 b -> NExpr ('(name2, b) : env) c -> NExpr env c + NENothing :: STy t -> NExpr env (TMaybe t) + NEJust :: NExpr env t -> NExpr env (TMaybe t) + NEMaybe :: NExpr env b -> Var name t -> NExpr ('(name, t) : env) b -> NExpr env (TMaybe t) -> NExpr env b + + -- array operations + NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t)) + NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t) + NEMap :: Var name a -> NExpr ('(name, a) : env) t -> NExpr env (TArr n a) -> NExpr env (TArr n t) + NEFold1Inner :: Var name1 (TPair t t) -> NExpr ('(name1, TPair t t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) + NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) + NEUnit :: NExpr env t -> NExpr env (TArr Z t) + NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t) + NEMaximum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) + NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) + NEReshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t) + + NEFold1InnerD1 :: Var n1 (TPair t1 t1) -> NExpr ('(n1, TPair t1 t1) : env) (TPair t1 b) + -> NExpr env t1 + -> NExpr env (TArr (S n) t1) + -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) + NEFold1InnerD2 :: Var n1 b -> Var n2 t2 -> NExpr ('(n2, t2) : '(n1, b) : env) (TPair t2 t2) + -> NExpr env (TArr (S n) b) + -> NExpr env (TArr n t2) + -> NExpr env (TPair (TArr n t2) (TArr (S n) t2)) + + -- expression operations + NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t) + NEIdx0 :: NExpr env (TArr Z t) -> NExpr env t + NEIdx1 :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t) + NEIdx :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t + NEShape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) + NEOp :: SOp a t -> NExpr env a -> NExpr env t + + -- custom derivatives + NECustom :: Var n1 a -> Var n2 b -> NExpr ['(n2, b), '(n1, a)] t -- ^ regular operation + -> Var nf1 (D1 a) -> Var nf2 (D1 b) -> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape) -- ^ CHAD forward pass + -> Var nr1 tape -> Var nr2 (D2 t) -> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b) -- ^ CHAD reverse derivative + -> NExpr env a -> NExpr env b + -> NExpr env t + + -- fake halfway checkpointing + NERecompute :: NExpr env t -> NExpr env t + + -- accumulation effect on monoids + NEWith :: SMTy t -> NExpr env t -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a t) + NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil + + -- partiality + NEError :: STy a -> String -> NExpr env a + + -- embedded unnamed expressions + NEUnnamed :: Ex unenv t -> SList (NExpr env) unenv -> NExpr env t +deriving instance Show (NExpr env t) + +type Lookup name env = Lookup1 (name == "_") name env +type family Lookup1 eqblank name env where + Lookup1 True _ _ = TypeError (Text "Attempt to use variable with name '_'") + Lookup1 False name env = Lookup2 name env +type family Lookup2 name env where + Lookup2 name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope") + Lookup2 name ('(name2, t) : env) = Lookup3 (name == name2) t name env +type family Lookup3 eq t name env where + Lookup3 True t _ _ = t + Lookup3 False _ name env = Lookup2 name env + +type family DropNth i env where + DropNth Z (_ : env) = env + DropNth (S i) (p : env) = p : DropNth i env + +data Var name t = Var (SSymbol name) (STy t) + deriving (Show) + +instance (t ~ TScal st, ScalIsNumeric st ~ True, KnownScalTy st, Num (ScalRep st)) => Num (NExpr env t) where + a + b = NEOp (OAdd knownScalTy) (NEPair a b) + a * b = NEOp (OMul knownScalTy) (NEPair a b) + negate e = NEOp (ONeg knownScalTy) e + abs = error "abs undefined for NExpr" + signum = error "signum undefined for NExpr" + fromInteger = + let ty = knownScalTy + in case scalRepIsShow ty of + Dict -> NEConst ty . fromInteger + +instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Fractional (ScalRep st)) + => Fractional (NExpr env t) where + recip e = NEOp (ORecip knownScalTy) e + fromRational = + let ty = knownScalTy + in case scalRepIsShow ty of + Dict -> NEConst ty . fromRational + +instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Floating (ScalRep st)) + => Floating (NExpr env t) where + pi = + let ty = knownScalTy + in case scalRepIsShow ty of + Dict -> NEConst ty pi + exp = NEOp (OExp knownScalTy) + log = NEOp (OExp knownScalTy) + sin = undefined ; cos = undefined ; tan = undefined + asin = undefined ; acos = undefined ; atan = undefined + sinh = undefined ; cosh = undefined + asinh = undefined ; acosh = undefined ; atanh = undefined + +instance (KnownTy t, KnownSymbol name, name ~ n') => IsLabel name (Var n' t) where + fromLabel = Var symbolSing knownTy + +instance (KnownTy t, KnownSymbol name, Lookup name env ~ t) => IsLabel name (NExpr env t) where + fromLabel = NEVar (fromLabel @name) + +-- | Innermost variable variable on the outside, on the right. +data NEnv env where + NTop :: NEnv '[] + NPush :: NEnv env -> Var name t -> NEnv ('(name, t) : env) + +-- | First (outermost) parameter on the outside, on the left. +-- * env: environment of this function (grows as you go deeper inside lambdas) +-- * env': environment of the body of the function +-- * params: parameters of the function (difference between env and env'), first (outermost) argument at the head of the list +data NFun env env' t where + NLam :: Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t + NBody :: NExpr env' t -> NFun env' env' t + +type family UnName env where + UnName '[] = '[] + UnName ('(name, t) : env) = t : UnName env + +envFromNEnv :: NEnv env -> SList STy (UnName env) +envFromNEnv NTop = SNil +envFromNEnv (NPush env (Var _ t)) = t `SCons` envFromNEnv env + +inlineNFun :: NFun '[] envB t -> SList (NExpr env) (UnName envB) -> NExpr env t +inlineNFun fun args = NEUnnamed (fromNamed fun) args + +fromNamed :: NFun '[] env t -> Ex (UnName env) t +fromNamed = fromNamedFun NTop + +-- | Some of the parameters have already been put in the environment; some +-- haven't. Transfer all parameters to the left into the environment. +-- +-- [] `fromNamedFun` λx y z. E +-- = []:x `fromNamedFun` λy z. E +-- = []:x:y `fromNamedFun` λz. E +-- = []:x:y:z `fromNamedFun` λ. E +-- = []:x:y:z `fromNamedExpr` E +fromNamedFun :: NEnv env -> NFun env env' t -> Ex (UnName env') t +fromNamedFun env (NLam var fun) = fromNamedFun (env `NPush` var) fun +fromNamedFun env (NBody e) = fromNamedExpr env e + +fromNamedExpr :: forall env t. NEnv env -> NExpr env t -> Ex (UnName env) t +fromNamedExpr val = \case + NEVar var@(Var _ ty) + | Just idx <- find var val -> EVar ext ty idx + | otherwise -> error "Variable out of scope in conversion from surface \ + \expression to De Bruijn expression" + NELet n a b -> ELet ext (go a) (lambda val n b) + + NEDrop i e -> weakenExpr (dropNthW i val) (fromNamedExpr (dropNth i val) e) + + NEPair a b -> EPair ext (go a) (go b) + NEFst e -> EFst ext (go e) + NESnd e -> ESnd ext (go e) + NENil -> ENil ext + NEInl t e -> EInl ext t (go e) + NEInr t e -> EInr ext t (go e) + NECase e n1 a n2 b -> ECase ext (go e) (lambda val n1 a) (lambda val n2 b) + NENothing t -> ENothing ext t + NEJust e -> EJust ext (go e) + NEMaybe a n b c -> EMaybe ext (go a) (lambda val n b) (go c) + + NEConstArr n t x -> EConstArr ext n t x + NEBuild k a n b -> EBuild ext k (go a) (lambda val n b) + NEMap n a b -> EMap ext (lambda val n a) (go b) + NEFold1Inner n1 a b c -> EFold1Inner ext Noncommut (lambda val n1 a) (go b) (go c) + NESum1Inner e -> ESum1Inner ext (go e) + NEUnit e -> EUnit ext (go e) + NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b) + NEMaximum1Inner e -> EMaximum1Inner ext (go e) + NEMinimum1Inner e -> EMinimum1Inner ext (go e) + NEReshape n a b -> EReshape ext n (go a) (go b) + + NEFold1InnerD1 n1 a b c -> EFold1InnerD1 ext Noncommut (lambda val n1 a) (go b) (go c) + NEFold1InnerD2 n1 n2 a b c -> EFold1InnerD2 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) + + NEConst t x -> EConst ext t x + NEIdx0 e -> EIdx0 ext (go e) + NEIdx1 a b -> EIdx1 ext (go a) (go b) + NEIdx a b -> EIdx ext (go a) (go b) + NEShape e -> EShape ext (go e) + NEOp op e -> EOp ext op (go e) + + NECustom n1@(Var _ ta) n2@(Var _ tb) a nf1 nf2 b nr1@(Var _ ttape) nr2 c e1 e2 -> + ECustom ext ta tb ttape + (fromNamedExpr (NTop `NPush` n1 `NPush` n2) a) + (fromNamedExpr (NTop `NPush` nf1 `NPush` nf2) b) + (fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c) + (go e1) (go e2) + NERecompute e -> ERecompute ext (go e) + + NEWith t a n b -> EWith ext t (go a) (lambda val n b) + NEAccum t p a sp b c -> EAccum ext t p (go a) sp (go b) (go c) + + NEError t s -> EError ext t s + + NEUnnamed e args -> injectWrapLet (weakenExpr (wRaiseAbove args (envFromNEnv val)) e) args + where + go :: NExpr env t' -> Ex (UnName env) t' + go = fromNamedExpr val + + find :: Var name t' -> NEnv env' -> Maybe (Idx (UnName env') t') + find _ NTop = Nothing + find var@(Var s ty) (val' `NPush` Var s' ty') + | Just Refl <- testEquality s s' + , Just Refl <- testEquality ty ty' + = Just IZ + | otherwise + = IS <$> find var val' + + lambda :: NEnv env' -> Var name a -> NExpr ('(name, a) : env') b -> Ex (a : UnName env') b + lambda val' var e = fromNamedExpr (val' `NPush` var) e + + lambda2 :: NEnv env' -> Var name1 a -> Var name2 b -> NExpr ('(name2, b) : '(name1, a) : env') c -> Ex (b : a : UnName env') c + lambda2 val' var1 var2 e = fromNamedExpr (val' `NPush` var1 `NPush` var2) e + + injectWrapLet :: Ex (Append unenv (UnName env)) t -> SList (NExpr env) unenv -> Ex (UnName env) t + injectWrapLet e SNil = e + injectWrapLet e (arg `SCons` args) = + injectWrapLet (ELet ext (weakenExpr (wSinks args) $ fromNamedExpr val arg) e) + args + +dropNth :: SNat i -> NEnv env -> NEnv (DropNth i env) +dropNth SZ (val `NPush` _) = val +dropNth (SS i) (val `NPush` p) = dropNth i val `NPush` p +dropNth _ NTop = error "DropNth: index out of range" + +dropNthW :: SNat i -> NEnv env -> UnName (DropNth i env) :> UnName env +dropNthW SZ (_ `NPush` _) = WSink +dropNthW (SS i) (val `NPush` _) = WCopy (dropNthW i val) +dropNthW _ NTop = error "DropNth: index out of range" + +assertSymbolNotUnderscore :: forall s r. SSymbol s -> ((s == "_") ~ False => r) -> r +assertSymbolNotUnderscore s@SSymbol k = + case symbolVal s of + "_" -> error "assertSymbolNotUnderscore: was underscore" + _ | Refl <- unsafeCoerceRefl @(s == "_") @False -> k + +assertSymbolDistinct :: forall s1 s2 r. SSymbol s1 -> SSymbol s2 -> ((s1 == s2) ~ False => r) -> r +assertSymbolDistinct s1@SSymbol s2@SSymbol k + | symbolVal s1 == symbolVal s2 = error $ "assertSymbolDistinct: was equal (" ++ symbolVal s1 ++ ")" + | Refl <- unsafeCoerceRefl @(s1 == s2) @False = k + +equalityReflexive :: forall (s :: Symbol) proxy r. proxy s -> ((s == s) ~ True => r) -> r +equalityReflexive _ k | Refl <- unsafeCoerceRefl @(s == s) @True = k diff --git a/src/CHAD/Lemmas.hs b/src/CHAD/Lemmas.hs new file mode 100644 index 0000000..55ef042 --- /dev/null +++ b/src/CHAD/Lemmas.hs @@ -0,0 +1,21 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} + +{-# LANGUAGE AllowAmbiguousTypes #-} +module CHAD.Lemmas (module CHAD.Lemmas, (:~:)(Refl)) where + +import Data.Type.Equality +import Unsafe.Coerce (unsafeCoerce) + + +type family Append a b where + Append '[] l = l + Append (x : xs) l = x : Append xs l + +lemAppendNil :: Append a '[] :~: a +lemAppendNil = unsafeCoerce Refl + +lemAppendAssoc :: Append a (Append b c) :~: Append (Append a b) c +lemAppendAssoc = unsafeCoerce Refl diff --git a/src/CHAD/Simplify.hs b/src/CHAD/Simplify.hs new file mode 100644 index 0000000..2510cc5 --- /dev/null +++ b/src/CHAD/Simplify.hs @@ -0,0 +1,619 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeOperators #-} +module CHAD.Simplify ( + simplifyN, simplifyFix, + SimplifyConfig(..), defaultSimplifyConfig, simplifyWith, simplifyFixWith, +) where + +import Control.Monad (ap) +import Data.Bifunctor (first) +import Data.Function (fix) +import Data.Monoid (Any(..)) + +import Debug.Trace + +import CHAD.AST +import CHAD.AST.Count +import CHAD.AST.Pretty +import CHAD.AST.Sparse.Types +import CHAD.AST.UnMonoid (acPrjCompose) +import CHAD.Data +import CHAD.Simplify.TH + + +data SimplifyConfig = SimplifyConfig + { scLogging :: Bool + } + +defaultSimplifyConfig :: SimplifyConfig +defaultSimplifyConfig = SimplifyConfig False + +simplifyN :: KnownEnv env => Int -> Ex env t -> Ex env t +simplifyN 0 = id +simplifyN n = simplifyN (n - 1) . simplify + +simplify :: forall env t. KnownEnv env => Ex env t -> Ex env t +simplify = + let ?accumInScope = checkAccumInScope @env knownEnv + ?config = defaultSimplifyConfig + in snd . runSM . simplify' + +simplifyWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t +simplifyWith config = + let ?accumInScope = checkAccumInScope @env knownEnv + ?config = config + in snd . runSM . simplify' + +simplifyFix :: forall env t. KnownEnv env => Ex env t -> Ex env t +simplifyFix = simplifyFixWith defaultSimplifyConfig + +simplifyFixWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t +simplifyFixWith config = + let ?accumInScope = checkAccumInScope @env knownEnv + ?config = config + in fix $ \loop e -> + let (act, e') = runSM (simplify' e) + in if act then loop e' else e' + +-- | simplify monad +newtype SM tenv tt env t a = SM ((Ex env t -> Ex tenv tt) -> (Any, a)) + deriving (Functor) + +instance Applicative (SM tenv tt env t) where + pure x = SM (\_ -> (Any False, x)) + (<*>) = ap + +instance Monad (SM tenv tt env t) where + SM f >>= g = SM $ \ctx -> f ctx >>= \x -> let SM h = g x in h ctx + +runSM :: SM env t env t a -> (Bool, a) +runSM (SM f) = first getAny (f id) + +smReconstruct :: Ex env t -> SM tenv tt env t (Ex tenv tt) +smReconstruct core = SM (\ctx -> (Any False, ctx core)) + +class Monad m => ActedMonad m where + tellActed :: m () + hideActed :: m a -> m a + liftActed :: (Any, a) -> m a + +instance ActedMonad ((,) Any) where + tellActed = (Any True, ()) + hideActed (_, x) = (Any False, x) + liftActed = id + +instance ActedMonad (SM tenv tt env t) where + tellActed = SM (\_ -> tellActed) + hideActed (SM f) = SM (\ctx -> hideActed (f ctx)) + liftActed pair = SM (\_ -> pair) + +-- more convenient in practice +acted :: ActedMonad m => m a -> m a +acted m = tellActed >> m + +within :: (Ex env' t' -> Ex env t) -> SM tenv tt env' t' a -> SM tenv tt env t a +within subctx (SM f) = SM $ \ctx -> f (ctx . subctx) + +simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t) +simplify' expr + | scLogging ?config = do + res <- simplify'Rec expr + full <- smReconstruct res + let printed = ppExpr knownEnv full + replace a bs = concatMap (\x -> if x == a then bs else [x]) + str | '\n' `elem` printed = "--- simplify step:\n " ++ replace '\n' "\n " printed + | otherwise = "--- simplify step: " ++ printed + traceM str + return res + | otherwise = simplify'Rec expr + +simplify'Rec :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t) +simplify'Rec = \case + -- inlining + ELet _ rhs body + | cheapExpr rhs + -> acted $ simplify' (substInline rhs body) + + | Occ lexOcc runOcc <- occCount IZ body + , ((not ?accumInScope || not (hasAdds rhs)) && lexOcc <= One && runOcc <= One) -- without effects, normal rules apply + || (lexOcc == One && runOcc == One) -- with effects, linear inlining is still allowed, but weakening is not + -> acted $ simplify' (substInline rhs body) + + -- let splitting / let peeling + ELet _ (EPair _ a b) body -> + acted $ simplify' $ + ELet ext a $ + ELet ext (weakenExpr WSink b) $ + subst (\_ t -> \case IZ -> EPair ext (EVar ext (typeOf a) (IS IZ)) (EVar ext (typeOf b) IZ) + IS i -> EVar ext t (IS (IS i))) + body + ELet _ (EJust _ a) body -> + acted $ simplify' $ ELet ext a $ subst0 (EJust ext (EVar ext (typeOf a) IZ)) body + ELet _ (EInl _ t2 a) body -> + acted $ simplify' $ ELet ext a $ subst0 (EInl ext t2 (EVar ext (typeOf a) IZ)) body + ELet _ (EInr _ t1 a) body -> + acted $ simplify' $ ELet ext a $ subst0 (EInr ext t1 (EVar ext (typeOf a) IZ)) body + + -- let rotation + ELet _ (ELet _ rhs a) b -> do + b' <- within (ELet ext (ELet ext rhs a)) $ simplify' b + acted $ simplify' $ + ELet ext rhs $ + ELet ext a $ + weakenExpr (WCopy WSink) b' + + -- beta rules for products + EFst _ (EPair _ e e') + | not (hasAdds e') -> acted $ simplify' e + | otherwise -> acted $ simplify' $ ELet ext e' (weakenExpr WSink e) + ESnd _ (EPair _ e' e) + | not (hasAdds e') -> acted $ simplify' e + | otherwise -> acted $ simplify' $ ELet ext e' (weakenExpr WSink e) + + -- beta rules for coproducts + ECase _ (EInl _ _ e) rhs _ -> acted $ simplify' (ELet ext e rhs) + ECase _ (EInr _ _ e) _ rhs -> acted $ simplify' (ELet ext e rhs) + + -- beta rules for maybe + EMaybe _ e1 _ ENothing{} -> acted $ simplify' e1 + EMaybe _ _ e1 (EJust _ e2) -> acted $ simplify' $ ELet ext e2 e1 + + -- let floating + EFst _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EFst ext body)) + ESnd _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (ESnd ext body)) + ECase _ (ELet _ rhs body) e1 e2 -> acted $ simplify' (ELet ext rhs (ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2))) + EIdx0 _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EIdx0 ext body)) + EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e))) + EAccum _ t p e1 sp (ELet _ rhs body) acc -> + acted $ simplify' $ + ELet ext rhs $ + EAccum ext t p (weakenExpr WSink e1) sp body (weakenExpr WSink acc) + + -- let () = e in () ~> e + ELet _ e1 (ENil _) | STNil <- typeOf e1 -> + acted $ simplify' e1 + + -- map (\_ -> x) e ~> build (shape e) (\_ -> x) + EMap _ e1 e2 + | Occ Zero Zero <- occCount IZ e1 + , STArr n _ <- typeOf e2 -> + acted $ simplify' $ + EBuild ext n (EShape ext e2) $ + subst (\_ t' -> \case IZ -> error "Unused variable was used" + IS i -> EVar ext t' (IS i)) + e1 + + -- vertical fusion + EMap _ e1 (EMap _ e2 e3) -> + acted $ simplify' $ + EMap ext (ELet ext e2 (weakenExpr (WCopy WSink) e1)) e3 + + -- projection down-commuting + EFst _ (ECase _ e1 e2 e3) -> + acted $ simplify' $ + ECase ext e1 (EFst ext e2) (EFst ext e3) + ESnd _ (ECase _ e1 e2 e3) -> + acted $ simplify' $ + ECase ext e1 (ESnd ext e2) (ESnd ext e3) + EFst _ (EMaybe _ e1 e2 e3) -> + acted $ simplify' $ + EMaybe ext (EFst ext e1) (EFst ext e2) e3 + ESnd _ (EMaybe _ e1 e2 e3) -> + acted $ simplify' $ + EMaybe ext (ESnd ext e1) (ESnd ext e2) e3 + + -- TODO: more array indexing + EIdx _ (EBuild _ _ e1 e2) e3 | not (hasAdds e1), not (hasAdds e2) -> acted $ simplify' $ elet e3 e2 + EIdx _ (EMap _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ elet (EIdx ext e2 e3) e1 + EIdx _ (EReplicate1Inner _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ EIdx ext e2 (EFst ext e3) + EIdx _ (EUnit _ e1) e2 | not (hasAdds e2) -> acted $ simplify' $ e1 + + -- TODO: more array shape + EShape _ (EBuild _ _ e1 e2) | not (hasAdds e2) -> acted $ simplify' e1 + EShape _ (EMap _ e1 e2) | not (hasAdds e1) -> acted $ simplify' (EShape ext e2) + + -- TODO: more constant folding + EOp _ OIf (EConst _ STBool True) -> acted $ return (EInl ext STNil (ENil ext)) + EOp _ OIf (EConst _ STBool False) -> acted $ return (EInr ext STNil (ENil ext)) + + -- inline cheap array constructors + ELet _ (EReplicate1Inner _ e1 e2) e3 -> + acted $ simplify' $ + ELet ext (EPair ext e1 e2) $ + let v = EVar ext (STPair tIx (typeOf e2)) IZ + in subst0 (EReplicate1Inner ext (EFst ext v) (ESnd ext v)) e3 + -- -- TODO: This is a bad idea and anyway only helps in practice if (!) is + -- -- cheap, which it can't be because (!) is not cheap if you do AD after. + -- -- Should do proper SoA representation. + -- ELet _ (EBuild _ n e1 e2) e3 | cheapExpr e2 -> + -- acted $ simplify' $ + -- ELet ext e1 $ + -- subst0 (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) IZ) (weakenExpr (WCopy WSink) e2)) e3 + + -- eta rule for unit + e | STNil <- typeOf e, not ?accumInScope || not (hasAdds e) -> + case e of + ENil _ -> return e + _ -> acted $ return (ENil ext) + + EBuild _ SZ _ e -> + acted $ simplify' $ EUnit ext (substInline (ENil ext) e) + + -- monoid rules + EAccum _ t p e1 sp e2 acc -> do + e1' <- within (\e1' -> EAccum ext t p e1' sp e2 acc ) $ simplify' e1 + e2' <- within (\e2' -> EAccum ext t p e1' sp e2' acc ) $ simplify' e2 + acc' <- within (\acc' -> EAccum ext t p e1' sp e2' acc') $ simplify' acc + simplifyOHT (OneHotTerm SAID t p e1' sp e2') + (acted $ return (ENil ext)) + (\sp' (InContext w wrap e) -> do + e' <- within (\e' -> wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc')) $ simplify' e + return (wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc'))) + (\(InContext w wrap (OneHotTerm _ t' p' e1'' sp' e2'')) -> do + -- The acted management here is a hideous mess. + e1''' <- hideActed $ within (\e1''' -> wrap $ EAccum ext t' p' e1''' sp' e2'' (weakenExpr w acc')) $ simplify' e1'' + e2''' <- hideActed $ within (\e2''' -> wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc')) $ simplify' e2'' + return (wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc'))) + EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e + EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e + EOneHot _ t p e1 e2 -> do + e1' <- within (\e1' -> EOneHot ext t p e1' e2 ) $ simplify' e1 + e2' <- within (\e2' -> EOneHot ext t p e1' e2') $ simplify' e2 + simplifyOHT (OneHotTerm SAIS t p e1' (spDense (acPrjTy p t)) e2') + (acted $ return (EZero ext t (zeroInfoFromOneHot t p e1 e2))) + (\sp' (InContext _ wrap e) -> + case isDense t sp' of + Just Refl -> do + e' <- hideActed $ within wrap $ simplify' e + return (wrap e') + Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse") + (\(InContext _ wrap (OneHotTerm _ t' p' e1'' sp' e2'')) -> + case isDense (acPrjTy p' t') sp' of + Just Refl -> do + e1''' <- hideActed $ within (\e1''' -> wrap $ EOneHot ext t' p' e1''' e2'') $ simplify' e1'' + e2''' <- hideActed $ within (\e2''' -> wrap $ EOneHot ext t' p' e1''' e2''') $ simplify' e2'' + return (wrap $ EOneHot ext t' p' e1''' e2''') + Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse") + + -- type-specific equations for plus + EPlus _ SMTNil e1 e2 | not (hasAdds e1), not (hasAdds e2) -> + acted $ return (ENil ext) + + EPlus _ (SMTPair t1 t2) (EPair _ a1 b1) (EPair _ a2 b2) -> + acted $ simplify' $ EPair ext (EPlus ext t1 a1 a2) (EPlus ext t2 b1 b2) + + EPlus _ (SMTLEither t1 _) (ELInl _ dt2 a1) (ELInl _ _ a2) -> + acted $ simplify' $ ELInl ext dt2 (EPlus ext t1 a1 a2) + EPlus _ (SMTLEither _ t2) (ELInr _ dt1 b1) (ELInr _ _ b2) -> + acted $ simplify' $ ELInr ext dt1 (EPlus ext t2 b1 b2) + EPlus _ SMTLEither{} ELNil{} e -> acted $ simplify' e + EPlus _ SMTLEither{} e ELNil{} -> acted $ simplify' e + + EPlus _ (SMTMaybe t) (EJust _ e1) (EJust _ e2) -> + acted $ simplify' $ EJust ext (EPlus ext t e1 e2) + EPlus _ SMTMaybe{} ENothing{} e -> acted $ simplify' e + EPlus _ SMTMaybe{} e ENothing{} -> acted $ simplify' e + + -- fallback recursion + EVar _ t i -> pure $ EVar ext t i + ELet _ a b -> [simprec| ELet ext *a *b |] + EPair _ a b -> [simprec| EPair ext *a *b |] + EFst _ e -> [simprec| EFst ext *e |] + ESnd _ e -> [simprec| ESnd ext *e |] + ENil _ -> pure $ ENil ext + EInl _ t e -> [simprec| EInl ext t *e |] + EInr _ t e -> [simprec| EInr ext t *e |] + ECase _ e a b -> [simprec| ECase ext *e *a *b |] + ENothing _ t -> pure $ ENothing ext t + EJust _ e -> [simprec| EJust ext *e |] + EMaybe _ a b e -> [simprec| EMaybe ext *a *b *e |] + ELNil _ t1 t2 -> pure $ ELNil ext t1 t2 + ELInl _ t e -> [simprec| ELInl ext t *e |] + ELInr _ t e -> [simprec| ELInr ext t *e |] + ELCase _ e a b c -> [simprec| ELCase ext *e *a *b *c |] + EConstArr _ n t v -> pure $ EConstArr ext n t v + EBuild _ n a b -> [simprec| EBuild ext n *a *b |] + EMap _ a b -> [simprec| EMap ext *a *b |] + EFold1Inner _ cm a b c -> [simprec| EFold1Inner ext cm *a *b *c |] + ESum1Inner _ e -> [simprec| ESum1Inner ext *e |] + EUnit _ e -> [simprec| EUnit ext *e |] + EReplicate1Inner _ a b -> [simprec| EReplicate1Inner ext *a *b |] + EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |] + EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |] + EReshape _ n a b -> [simprec| EReshape ext n *a *b |] + EZip _ a b -> [simprec| EZip ext *a *b |] + EFold1InnerD1 _ cm a b c -> [simprec| EFold1InnerD1 ext cm *a *b *c |] + EFold1InnerD2 _ cm a b c -> [simprec| EFold1InnerD2 ext cm *a *b *c |] + EConst _ t v -> pure $ EConst ext t v + EIdx0 _ e -> [simprec| EIdx0 ext *e |] + EIdx1 _ a b -> [simprec| EIdx1 ext *a *b |] + EIdx _ a b -> [simprec| EIdx ext *a *b |] + EShape _ e -> [simprec| EShape ext *e |] + EOp _ op e -> [simprec| EOp ext op *e |] + ECustom _ s t p a b c e1 e2 -> do + a' <- within (\a' -> ECustom ext s t p a' b c e1 e2) (let ?accumInScope = False in simplify' a) + b' <- within (\b' -> ECustom ext s t p a' b' c e1 e2) (let ?accumInScope = False in simplify' b) + c' <- within (\c' -> ECustom ext s t p a' b' c' e1 e2) (let ?accumInScope = False in simplify' c) + e1' <- within (\e1' -> ECustom ext s t p a' b' c' e1' e2) (simplify' e1) + e2' <- within (\e2' -> ECustom ext s t p a' b' c' e1' e2') (simplify' e2) + pure (ECustom ext s t p a' b' c' e1' e2') + ERecompute _ e -> [simprec| ERecompute ext *e |] + EWith _ t e1 e2 -> do + e1' <- within (\e1' -> EWith ext t e1' e2) (simplify' e1) + e2' <- within (\e2' -> EWith ext t e1' e2') (let ?accumInScope = True in simplify' e2) + pure (EWith ext t e1' e2') + -- EOneHot _ t p e1 e2 -> [simprec| EOneHot ext t p *e1 *e2 |] + -- EAccum _ t p e1 sp e2 acc -> [simprec| EAccum ext t p *e1 sp *e2 *acc |] + EZero _ t e -> [simprec| EZero ext t *e |] + EDeepZero _ t e -> [simprec| EDeepZero ext t *e |] + EPlus _ t a b -> [simprec| EPlus ext t *a *b |] + EError _ t s -> pure $ EError ext t s + +-- | This can be made more precise by tracking (and not counting) adds on +-- locally eliminated accumulators. +hasAdds :: Expr x env t -> Bool +hasAdds = \case + EVar _ _ _ -> False + ELet _ rhs body -> hasAdds rhs || hasAdds body + EPair _ a b -> hasAdds a || hasAdds b + EFst _ e -> hasAdds e + ESnd _ e -> hasAdds e + ENil _ -> False + EInl _ _ e -> hasAdds e + EInr _ _ e -> hasAdds e + ECase _ e a b -> hasAdds e || hasAdds a || hasAdds b + ENothing _ _ -> False + EJust _ e -> hasAdds e + EMaybe _ a b e -> hasAdds a || hasAdds b || hasAdds e + ELNil _ _ _ -> False + ELInl _ _ e -> hasAdds e + ELInr _ _ e -> hasAdds e + ELCase _ e a b c -> hasAdds e || hasAdds a || hasAdds b || hasAdds c + EConstArr _ _ _ _ -> False + EBuild _ _ a b -> hasAdds a || hasAdds b + EMap _ a b -> hasAdds a || hasAdds b + EFold1Inner _ _ a b c -> hasAdds a || hasAdds b || hasAdds c + ESum1Inner _ e -> hasAdds e + EUnit _ e -> hasAdds e + EReplicate1Inner _ a b -> hasAdds a || hasAdds b + EMaximum1Inner _ e -> hasAdds e + EMinimum1Inner _ e -> hasAdds e + EReshape _ _ a b -> hasAdds a || hasAdds b + EZip _ a b -> hasAdds a || hasAdds b + EFold1InnerD1 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c + EFold1InnerD2 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c + ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e + EConst _ _ _ -> False + EIdx0 _ e -> hasAdds e + EIdx1 _ a b -> hasAdds a || hasAdds b + EIdx _ a b -> hasAdds a || hasAdds b + EShape _ e -> hasAdds e + EOp _ _ e -> hasAdds e + EWith _ _ a b -> hasAdds a || hasAdds b + ERecompute _ e -> hasAdds e + EAccum _ _ _ _ _ _ _ -> True + EZero _ _ e -> hasAdds e + EDeepZero _ _ e -> hasAdds e + EPlus _ _ a b -> hasAdds a || hasAdds b + EOneHot _ _ _ a b -> hasAdds a || hasAdds b + EError _ _ _ -> False + +checkAccumInScope :: SList STy env -> Bool +checkAccumInScope = \case SNil -> False + SCons t env -> check t || checkAccumInScope env + where + check :: STy t -> Bool + check STNil = False + check (STPair s t) = check s || check t + check (STEither s t) = check s || check t + check (STLEither s t) = check s || check t + check (STMaybe t) = check t + check (STArr _ t) = check t + check (STScal _) = False + check STAccum{} = True + +data OneHotTerm dense env a where + OneHotTerm :: SAIDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Sparse b c -> Ex env c -> OneHotTerm dense env a +deriving instance Show (OneHotTerm dense env a) + +data InContext f env (a :: Ty) where + InContext :: env :> env' -> (forall t. Ex env' t -> Ex env t) -> f env' a -> InContext f env a + +simplifyOHT_recogniseMonoid :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a) +simplifyOHT_recogniseMonoid (OneHotTerm dense t prj idx sp val) = do + val' <- liftActed $ recogniseMonoid (applySparse sp (acPrjTy prj t)) val + return $ OneHotTerm dense t prj idx sp val' + +simplifyOHT_unsparse :: ActedMonad m => OneHotTerm dense env a -> m (InContext (OneHotTerm dense) env a) +simplifyOHT_unsparse (OneHotTerm SAID t prj1 idx1 sp1 val1) = + unsparseOneHotD sp1 val1 $ \w wrap prj2 idx2 sp2 val2 -> + acPrjCompose SAID prj1 (weakenExpr w idx1) prj2 idx2 $ \prj' idx' -> + return $ InContext w wrap (OneHotTerm SAID t prj' idx' sp2 val2) +simplifyOHT_unsparse oht@(OneHotTerm SAIS _ _ _ _ _) = return $ InContext WId id oht + +simplifyOHT_concat :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a) +simplifyOHT_concat (OneHotTerm @dense @_ @_ @_ @env dense t1 prj1 idx1 sp (EOneHot @_ @c @p2 _ t2 prj2 idx2 val)) + | Just Refl <- isDense (acPrjTy prj1 t1) sp = + let idx2' :: Ex env (AcIdx dense p2 c) + idx2' = case dense of + SAID -> reduceAcIdx t2 prj2 idx2 + SAIS -> idx2 + in acPrjCompose dense prj1 idx1 prj2 idx2' $ \prj' idx' -> + acted $ return $ OneHotTerm dense t1 prj' idx' (spDense (acPrjTy prj' t1)) val +simplifyOHT_concat oht = return oht + +-- -- Property not expressed in types: if the Sparse in the input OneHotTerm is +-- -- dense, then the Sparse in the output will also be dense. This property is +-- -- used when simplifying EOneHot, which cannot represent sparsity. +simplifyOHT :: ActedMonad m => OneHotTerm dense env a + -> m r -- ^ Zero case (onehot is actually zero) + -> (forall b. Sparse a b -> InContext Ex env b -> m r) -- ^ Trivial case (no zeros in onehot) + -> (InContext (OneHotTerm dense) env a -> m r) -- ^ Simplified + -> m r +simplifyOHT oht kzero ktriv k = do + -- traceM $ "sOHT: input " ++ show oht + oht1 <- simplifyOHT_recogniseMonoid oht + -- traceM $ "sOHT: recog " ++ show oht1 + InContext w1 wrap1 oht2 <- simplifyOHT_unsparse oht1 + -- traceM $ "sOHT: unspa " ++ show oht2 + oht3 <- simplifyOHT_concat oht2 + -- traceM $ "sOHT: conca " ++ show oht3 + -- traceM "" + case oht3 of + OneHotTerm _ _ _ _ _ EZero{} -> kzero + OneHotTerm _ _ SAPHere _ sp val -> ktriv sp (InContext w1 wrap1 val) + _ -> k (InContext w1 wrap1 oht3) + +-- Sets the acted flag whenever a non-trivial projection is returned or the +-- output Sparse is different from the input Sparse. +unsparseOneHotD :: ActedMonad m => Sparse a a' -> Ex env a' + -> (forall p b c env'. env :> env' -> (forall s. Ex env' s -> Ex env s) + -> SAcPrj p a b -> Ex env' (AcIdxD p a) -> Sparse b c -> Ex env' c -> m r) -> m r +unsparseOneHotD topsp topval k = case (topsp, topval) of + -- eliminate always-Just sparse onehot + (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) -> + acted $ unsparseOneHotD s (EOneHot ext t prj idx val) k + + -- expand the top levels of a onehot for a sparse type into a onehot for the + -- corresponding non-sparse type + (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) -> + unsparseOneHotD s1 (EOneHot ext t1 prj (efst idx) val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPFst spprj) idx' s1' e' + (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) -> + unsparseOneHotD s2 (EOneHot ext t2 prj (esnd idx) val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPSnd spprj) idx' s1' e' + (SpLEither s1 _, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) -> + unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPLeft spprj) idx' s1' e' + (SpLEither _ s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) -> + unsparseOneHotD s2 (EOneHot ext t2 prj idx val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPRight spprj) idx' s1' e' + (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) -> + unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' -> + acted $ k w wrap (SAPJust spprj) idx' s1' e' + (SpArr s1, EOneHot _ (SMTArr _ t1) (SAPArrIdx prj) idx val) + | Dict <- styKnown (typeOf idx) -> + unsparseOneHotD s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \w wrap spprj idx' s1' e' -> + acted $ k (w .> WSink) (elet idx . wrap) (SAPArrIdx spprj) (EPair ext (efst (efst (evar (w @> IZ)))) idx') s1' e' + + -- anything else we don't know how to improve + _ -> k WId id SAPHere (ENil ext) topsp topval + +{- +unsparseOneHotS :: ActedMonad m + => Sparse a a' -> Ex env a' + -> (forall b. Sparse a b -> Ex env b -> m r) -> m r +unsparseOneHotS topsp topval k = case (topsp, topval) of + -- order is relevant to make sure we set the acted flag correctly + (SpAbsent, v@ENil{}) -> k SpAbsent v + (SpAbsent, v@EZero{}) -> k SpAbsent v + (SpAbsent, _) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) + (_, EZero{}) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) + (sp, _) | isAbsent sp -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) + + -- the unsparsifying + (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) -> + acted $ unsparseOneHotS s (EOneHot ext t prj idx val) k + + -- recursion + -- TODO: coproducts could safely become projections as they do not need + -- zeroinfo. But that would only work if the coproduct is at the top, because + -- as soon as we hit a product, we need zeroinfo to make it a projection and + -- we don't have that. + (SpSparse s, e) -> k (SpSparse s) e + (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj (efst idx) val) $ \s1' e' -> + acted $ k (SpPair s1' SpAbsent) (EPair ext e' (ENil ext)) + (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) -> + unsparseOneHotS s2 (EOneHot ext t2 prj (esnd idx) val) $ \s2' e' -> + acted $ k (SpPair SpAbsent s2') (EPair ext (ENil ext) e') + (SpLEither s1 s2, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' -> do + case s2 of SpAbsent -> pure () ; _ -> tellActed + k (SpLEither s1' SpAbsent) (ELInl ext STNil e') + (SpLEither s1 s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) -> + unsparseOneHotS s2 (EOneHot ext t2 prj idx val) $ \s2' e' -> do + case s1 of SpAbsent -> pure () ; _ -> tellActed + acted $ k (SpLEither SpAbsent s2') (ELInr ext STNil e') + (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' -> + k (SpMaybe s1') (EJust ext e') + (SpArr s1, EOneHot _ (SMTArr n t1) (SAPArrIdx prj) idx val) -> + unsparseOneHotS s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \s1' e' -> + k (SpArr s1') (elet idx $ EOneHot ext (SMTArr n (applySparse s1' _)) (SAPArrIdx SAPHere) (EPair ext (efst (evar IZ)) (ENil ext)) e') + _ -> _ +-} + +-- | Recognises 'EZero' and 'EOneHot'. +recogniseMonoid :: SMTy t -> Ex env t -> (Any, Ex env t) +recogniseMonoid _ e@EOneHot{} = return e +recogniseMonoid SMTNil (ENil _) = acted $ return $ EZero ext SMTNil (ENil ext) +recogniseMonoid typ@(SMTPair t1 t2) (EPair _ a b) = + ((,) <$> recogniseMonoid t1 a <*> recogniseMonoid t2 b) >>= \case + (EZero _ _ ezi1, EZero _ _ ezi2) -> acted $ return $ EZero ext typ (EPair ext ezi1 ezi2) + (a', EZero _ _ ezi2) -> acted $ EOneHot ext typ (SAPFst SAPHere) (EPair ext (ENil ext) ezi2) <$> recogniseMonoid t1 a' + (EZero _ _ ezi1, b') -> acted $ EOneHot ext typ (SAPSnd SAPHere) (EPair ext ezi1 (ENil ext)) <$> recogniseMonoid t2 b' + (a', b') -> return $ EPair ext a' b' +recogniseMonoid typ@(SMTLEither t1 t2) expr = + case expr of + ELNil{} -> acted $ return $ EZero ext typ (ENil ext) + ELInl _ _ e -> acted $ EOneHot ext typ (SAPLeft SAPHere) (ENil ext) <$> recogniseMonoid t1 e + ELInr _ _ e -> acted $ EOneHot ext typ (SAPRight SAPHere) (ENil ext) <$> recogniseMonoid t2 e + _ -> return expr +recogniseMonoid typ@(SMTMaybe t1) expr = + case expr of + ENothing{} -> acted $ return $ EZero ext typ (ENil ext) + EJust _ e -> acted $ EOneHot ext typ (SAPJust SAPHere) (ENil ext) <$> recogniseMonoid t1 e + _ -> return expr +recogniseMonoid typ@(SMTArr SZ t) (EUnit _ e) = + acted $ do + e' <- recogniseMonoid t e + return $ + ELet ext e' $ + EOneHot ext typ (SAPArrIdx SAPHere) + (EPair ext (EPair ext (ENil ext) (EUnit ext (makeZeroInfo t (EVar ext (fromSMTy t) IZ)))) + (ENil ext)) + (EVar ext (fromSMTy t) IZ) +recogniseMonoid typ@(SMTScal sty) e@(EConst _ _ x) = case (sty, x) of + (STI32, 0) -> acted $ return $ EZero ext typ (ENil ext) + (STI64, 0) -> acted $ return $ EZero ext typ (ENil ext) + (STF32, 0) -> acted $ return $ EZero ext typ (ENil ext) + (STF64, 0) -> acted $ return $ EZero ext typ (ENil ext) + _ -> return e +recogniseMonoid _ e = return e + +reduceAcIdx :: SMTy a -> SAcPrj p a b -> Ex env (AcIdxS p a) -> Ex env (AcIdxD p a) +reduceAcIdx topty topprj e = case (topty, topprj) of + (_, SAPHere) -> ENil ext + (SMTPair t1 _, SAPFst p) -> reduceAcIdx t1 p (efst e) + (SMTPair _ t2, SAPSnd p) -> reduceAcIdx t2 p (esnd e) + (SMTLEither t1 _ , SAPLeft p) -> reduceAcIdx t1 p e + (SMTLEither _ t2, SAPRight p) -> reduceAcIdx t2 p e + (SMTMaybe t1, SAPJust p) -> reduceAcIdx t1 p e + (SMTArr _ t, SAPArrIdx p) -> + eunPair e $ \_ e1 e2 -> + EPair ext (efst e1) (reduceAcIdx t p e2) + +zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t) +zeroInfoFromOneHot = \ty prj eidx e -> ELet ext eidx $ go ty prj (EVar ext (typeOf eidx) IZ) (weakenExpr WSink e) + where + -- invariant: AcIdx expression is duplicable + go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t) + go t SAPHere _ e = makeZeroInfo t e + go (SMTPair t1 _) (SAPFst prj) eidx e = EPair ext (go t1 prj (EFst ext eidx) e) (ESnd ext eidx) + go (SMTPair _ t2) (SAPSnd prj) eidx e = EPair ext (EFst ext eidx) (go t2 prj (ESnd ext eidx) e) + go SMTLEither{} _ _ _ = ENil ext + go SMTMaybe{} _ _ _ = ENil ext + go SMTArr{} SAPArrIdx{} eidx _ = ESnd ext (EFst ext eidx) diff --git a/src/CHAD/Simplify/TH.hs b/src/CHAD/Simplify/TH.hs new file mode 100644 index 0000000..4af5394 --- /dev/null +++ b/src/CHAD/Simplify/TH.hs @@ -0,0 +1,80 @@ +{-# LANGUAGE TemplateHaskellQuotes #-} +module CHAD.Simplify.TH (simprec) where + +import Data.Bifunctor (first) +import Data.Char +import Data.List (foldl', foldl1') +import Language.Haskell.TH +import Language.Haskell.TH.Quote +import Text.ParserCombinators.ReadP + + +-- [simprec| EPair ext *a *b |] +-- ~> +-- do a' <- within (\a' -> EPair ext a' b) (simplify' a) +-- b' <- within (\b' -> EPair ext a' b') (simplify' b) +-- pure (EPair ext a' b') + +simprec :: QuasiQuoter +simprec = QuasiQuoter + { quoteDec = \_ -> fail "simprec used outside of expression context" + , quoteType = \_ -> fail "simprec used outside of expression context" + , quoteExp = handler + , quotePat = \_ -> fail "simprec used outside of expression context" + } + +handler :: String -> Q Exp +handler str = + case readP_to_S pTemplate str of + [(template, "")] -> generate template + _:_:_ -> fail "simprec: template grammar ambiguous" + _ -> fail "simprec: could not parse template" + +generate :: Template -> Q Exp +generate (Template topitems) = + let takePrefix (Plain x : xs) = first (x:) (takePrefix xs) + takePrefix xs = ([], xs) + + itemVar "" = error "simprec: empty item name?" + itemVar name@(c:_) | isLower c = VarE (mkName name) + | isUpper c = ConE (mkName name) + | otherwise = error "simprec: non-letter item name?" + + loop :: Exp -> [Item] -> Q [Stmt] + loop yet [] = return [NoBindS (VarE 'pure `AppE` yet)] + loop yet (Plain x : xs) = loop (yet `AppE` itemVar x) xs + loop yet (Recurse x : xs) = do + primeName <- newName (x ++ "'") + let appPrePrime e (Plain y) = e `AppE` itemVar y + appPrePrime e (Recurse y) = e `AppE` itemVar y + let stmt = BindS (VarP primeName) $ + VarE (mkName "within") + `AppE` LamE [VarP primeName] (foldl' appPrePrime (yet `AppE` VarE primeName) xs) + `AppE` (VarE (mkName "simplify'") `AppE` VarE (mkName x)) + stmts <- loop (yet `AppE` VarE primeName) xs + return (stmt : stmts) + + (prefix, items') = takePrefix topitems + in DoE Nothing <$> loop (foldl1' AppE (map itemVar prefix)) items' + +data Template = Template [Item] + deriving (Show) + +data Item = Plain String | Recurse String + deriving (Show) + +pTemplate :: ReadP Template +pTemplate = do + items <- many (skipSpaces >> pItem) + skipSpaces + eof + return (Template items) + +pItem :: ReadP Item +pItem = (char '*' >> Recurse <$> pName) +++ (Plain <$> pName) + +pName :: ReadP String +pName = do + c1 <- satisfy (\c -> isAlpha c || c == '_') + cs <- munch (\c -> isAlphaNum c || c `elem` "_'") + return (c1:cs) diff --git a/src/CHAD/Top.hs b/src/CHAD/Top.hs deleted file mode 100644 index 4814bdf..0000000 --- a/src/CHAD/Top.hs +++ /dev/null @@ -1,96 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE ImplicitParams #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedLabels #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -module CHAD.Top where - -import Analysis.Identity -import AST -import AST.Env -import AST.Sparse -import AST.SplitLets -import AST.Weaken.Auto -import CHAD -import CHAD.Accum -import CHAD.EnvDescr -import CHAD.Types -import Data -import qualified Data.VarMap as VarMap - - -type family MergeEnv env where - MergeEnv '[] = '[] - MergeEnv (t : ts) = "merge" : MergeEnv ts - -mergeDescr :: SList STy env -> Descr env (MergeEnv env) -mergeDescr SNil = DTop -mergeDescr (t `SCons` env) = mergeDescr env `DPush` (t, Nothing, SMerge) - -mergeEnvNoAccum :: SList f env -> Select env (MergeEnv env) "accum" :~: '[] -mergeEnvNoAccum SNil = Refl -mergeEnvNoAccum (_ `SCons` env) | Refl <- mergeEnvNoAccum env = Refl - -mergeEnvOnlyMerge :: SList f env -> Select env (MergeEnv env) "merge" :~: env -mergeEnvOnlyMerge SNil = Refl -mergeEnvOnlyMerge (_ `SCons` env) | Refl <- mergeEnvOnlyMerge env = Refl - -accumDescr :: SList STy env -> (forall sto. Descr env sto -> r) -> r -accumDescr SNil k = k DTop -accumDescr (t `SCons` env) k = accumDescr env $ \des -> - if typeHasArrays t then k (des `DPush` (t, Nothing, SAccum)) - else k (des `DPush` (t, Nothing, SMerge)) - -reassembleD2E :: Descr env sto - -> D1E env :> env' - -> Ex env' (TPair (Tup (D2E (Select env sto "accum"))) (Tup (D2E (Select env sto "merge")))) - -> Ex env' (Tup (D2E env)) -reassembleD2E DTop _ _ = ENil ext -reassembleD2E (des `DPush` (_, _, SAccum)) w e = - eunPair e $ \w1 e1 e2 -> - eunPair e1 $ \w2 e11 e12 -> - EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext e11 (weakenExpr w2 e2))) e12 -reassembleD2E (des `DPush` (_, _, SMerge)) w e = - eunPair e $ \w1 e1 e2 -> - eunPair e2 $ \w2 e21 e22 -> - EPair ext (reassembleD2E des (w2 .> w1 .> WPop w) (EPair ext (weakenExpr w2 e1) e21)) e22 -reassembleD2E (des `DPush` (t, _, SDiscr)) w e = - EPair ext (reassembleD2E des (WPop w) e) - (EZero ext (d2M t) (d2zeroInfo t (EVar ext (d1 t) (w @> IZ)))) - -chad :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : D1E env) (TPair (D1 t) (Tup (D2E env))) -chad config env (term :: Ex env t) - | True <- chcArgArrayAccum config - = let ?config = config - in accumDescr env $ \descr -> - let t1 = STPair (d1 (typeOf term)) (tTup (d2e (select SMerge descr))) - tvar = STPair t1 (tTup (d2e (select SAccum descr))) - in ELet ext (uninvertTup (d2e (select SAccum descr)) t1 $ - makeAccumulators (WSink .> wUndoSubenv (subenvD1E (selectSub SAccum descr))) (select SAccum descr) $ - weakenExpr (autoWeak (#d (auto1 @(D2 t)) - &. #acenv (d2ace (select SAccum descr)) - &. #tl (d1e env)) - (#d :++: #acenv :++: #tl) - (#acenv :++: #d :++: #tl)) $ - freezeRet descr (drev descr VarMap.empty (spDense (d2M (typeOf term))) term')) $ - EPair ext (EFst ext (EFst ext (EVar ext tvar IZ))) - (reassembleD2E descr (WSink .> WSink) - (EPair ext (ESnd ext (EVar ext tvar IZ)) - (ESnd ext (EFst ext (EVar ext tvar IZ))))) - - | False <- chcArgArrayAccum config - , Refl <- mergeEnvNoAccum env - , Refl <- mergeEnvOnlyMerge env - = let ?config = config in freezeRet (mergeDescr env) (drev (mergeDescr env) VarMap.empty (spDense (d2M (typeOf term))) term') - where - term' = identityAnalysis env (splitLets term) - -chad' :: CHADConfig -> SList STy env -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env))) -chad' config env term - | Refl <- d1eIdentity env, Refl <- d1Identity (typeOf term) - = chad config env term diff --git a/src/CHAD/Types.hs b/src/CHAD/Types.hs deleted file mode 100644 index 44ac20e..0000000 --- a/src/CHAD/Types.hs +++ /dev/null @@ -1,153 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -module CHAD.Types where - -import AST.Accum -import AST.Types -import Data - - -type family D1 t where - D1 TNil = TNil - D1 (TPair a b) = TPair (D1 a) (D1 b) - D1 (TEither a b) = TEither (D1 a) (D1 b) - D1 (TLEither a b) = TLEither (D1 a) (D1 b) - D1 (TMaybe a) = TMaybe (D1 a) - D1 (TArr n t) = TArr n (D1 t) - D1 (TScal t) = TScal t - -type family D2 t where - D2 TNil = TNil - D2 (TPair a b) = TPair (D2 a) (D2 b) - D2 (TEither a b) = TLEither (D2 a) (D2 b) - D2 (TLEither a b) = TLEither (D2 a) (D2 b) - D2 (TMaybe t) = TMaybe (D2 t) - D2 (TArr n t) = TArr n (D2 t) - D2 (TScal t) = D2s t - -type family D2s t where - D2s TI32 = TNil - D2s TI64 = TNil - D2s TF32 = TScal TF32 - D2s TF64 = TScal TF64 - D2s TBool = TNil - -type family D1E env where - D1E '[] = '[] - D1E (t : env) = D1 t : D1E env - -type family D2E env where - D2E '[] = '[] - D2E (t : env) = D2 t : D2E env - -type family D2AcE env where - D2AcE '[] = '[] - D2AcE (t : env) = TAccum (D2 t) : D2AcE env - -d1 :: STy t -> STy (D1 t) -d1 STNil = STNil -d1 (STPair a b) = STPair (d1 a) (d1 b) -d1 (STEither a b) = STEither (d1 a) (d1 b) -d1 (STLEither a b) = STLEither (d1 a) (d1 b) -d1 (STMaybe t) = STMaybe (d1 t) -d1 (STArr n t) = STArr n (d1 t) -d1 (STScal t) = STScal t -d1 STAccum{} = error "Accumulators not allowed in input program" - -d1e :: SList STy env -> SList STy (D1E env) -d1e SNil = SNil -d1e (t `SCons` env) = d1 t `SCons` d1e env - -d2M :: STy t -> SMTy (D2 t) -d2M STNil = SMTNil -d2M (STPair a b) = SMTPair (d2M a) (d2M b) -d2M (STEither a b) = SMTLEither (d2M a) (d2M b) -d2M (STLEither a b) = SMTLEither (d2M a) (d2M b) -d2M (STMaybe t) = SMTMaybe (d2M t) -d2M (STArr n t) = SMTArr n (d2M t) -d2M (STScal t) = case t of - STI32 -> SMTNil - STI64 -> SMTNil - STF32 -> SMTScal STF32 - STF64 -> SMTScal STF64 - STBool -> SMTNil -d2M STAccum{} = error "Accumulators not allowed in input program" - -d2 :: STy t -> STy (D2 t) -d2 = fromSMTy . d2M - -d2eM :: SList STy env -> SList SMTy (D2E env) -d2eM SNil = SNil -d2eM (t `SCons` ts) = d2M t `SCons` d2eM ts - -d2e :: SList STy env -> SList STy (D2E env) -d2e = slistMap fromSMTy . d2eM - -d2ace :: SList STy env -> SList STy (D2AcE env) -d2ace SNil = SNil -d2ace (t `SCons` ts) = STAccum (d2M t) `SCons` d2ace ts - - -data CHADConfig = CHADConfig - { -- | D[let] will bind variables containing arrays in accumulator mode. - chcLetArrayAccum :: Bool - , -- | D[case] will bind variables containing arrays in accumulator mode. - chcCaseArrayAccum :: Bool - , -- | Introduce top-level arguments containing arrays in accumulator mode. - chcArgArrayAccum :: Bool - , -- | Place with-blocks around array variable scopes, and redirect accumulations there. - chcSmartWith :: Bool - } - deriving (Show) - -defaultConfig :: CHADConfig -defaultConfig = CHADConfig - { chcLetArrayAccum = False - , chcCaseArrayAccum = False - , chcArgArrayAccum = False - , chcSmartWith = False - } - -chcSetAccum :: CHADConfig -> CHADConfig -chcSetAccum c = c { chcLetArrayAccum = True - , chcCaseArrayAccum = True - , chcArgArrayAccum = True - , chcSmartWith = True } - - ------------------------------------- LEMMAS ------------------------------------ - -indexTupD1Id :: SNat n -> Tup (Replicate n TIx) :~: D1 (Tup (Replicate n TIx)) -indexTupD1Id SZ = Refl -indexTupD1Id (SS n) | Refl <- indexTupD1Id n = Refl - -lemZeroInfoScal :: SScalTy t -> ZeroInfo (D2s t) :~: TNil -lemZeroInfoScal STI32 = Refl -lemZeroInfoScal STI64 = Refl -lemZeroInfoScal STF32 = Refl -lemZeroInfoScal STF64 = Refl -lemZeroInfoScal STBool = Refl - -lemDeepZeroInfoScal :: SScalTy t -> DeepZeroInfo (D2s t) :~: TNil -lemDeepZeroInfoScal STI32 = Refl -lemDeepZeroInfoScal STI64 = Refl -lemDeepZeroInfoScal STF32 = Refl -lemDeepZeroInfoScal STF64 = Refl -lemDeepZeroInfoScal STBool = Refl - -d1Identity :: STy t -> D1 t :~: t -d1Identity = \case - STNil -> Refl - STPair a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl - STEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl - STLEither a b | Refl <- d1Identity a, Refl <- d1Identity b -> Refl - STMaybe t | Refl <- d1Identity t -> Refl - STArr _ t | Refl <- d1Identity t -> Refl - STScal _ -> Refl - STAccum{} -> error "Accumulators not allowed in input program" - -d1eIdentity :: SList STy env -> D1E env :~: env -d1eIdentity SNil = Refl -d1eIdentity (t `SCons` env) | Refl <- d1Identity t, Refl <- d1eIdentity env = Refl diff --git a/src/CHAD/Types/ToTan.hs b/src/CHAD/Types/ToTan.hs deleted file mode 100644 index 888fed4..0000000 --- a/src/CHAD/Types/ToTan.hs +++ /dev/null @@ -1,43 +0,0 @@ -{-# LANGUAGE GADTs #-} -module CHAD.Types.ToTan where - -import Data.Bifunctor (bimap) - -import Array -import AST.Types -import CHAD.Types -import Data -import ForwardAD -import Interpreter.Rep - - -toTanE :: SList STy env -> SList Value env -> SList Value (D2E env) -> SList Value (TanE env) -toTanE SNil SNil SNil = SNil -toTanE (t `SCons` env) (Value p `SCons` primal) (Value x `SCons` inp) = - Value (toTan t p x) `SCons` toTanE env primal inp - -toTan :: STy t -> Rep t -> Rep (D2 t) -> Rep (Tan t) -toTan typ primal der = case typ of - STNil -> der - STPair t1 t2 -> bimap (\p1 -> toTan t1 p1 (fst der)) (\p2 -> toTan t2 p2 (snd der)) primal - STEither t1 t2 -> case der of - Nothing -> bimap (zeroTan t1) (zeroTan t2) primal - Just d -> case (primal, d) of - (Left p, Left d') -> Left (toTan t1 p d') - (Right p, Right d') -> Right (toTan t2 p d') - _ -> error "Primal and cotangent disagree on Either alternative" - STLEither t1 t2 -> case (primal, der) of - (_, Nothing) -> Nothing - (Just (Left p), Just (Left d)) -> Just (Left (toTan t1 p d)) - (Just (Right p), Just (Right d)) -> Just (Right (toTan t2 p d)) - _ -> error "Primal and cotangent disagree on LEither alternative" - STMaybe t -> liftA2 (toTan t) primal der - STArr _ t - | arrayShape primal == arrayShape der -> - arrayGenerateLin (arrayShape primal) $ \i -> - toTan t (arrayIndexLinear primal i) (arrayIndexLinear der i) - | otherwise -> - error "Primal and cotangent disagree on array shape" - STScal sty -> case sty of - STI32 -> der ; STI64 -> der ; STF32 -> der ; STF64 -> der ; STBool -> der - STAccum{} -> error "Accumulators not allowed in input program" diff --git a/src/CHAD/Util/IdGen.hs b/src/CHAD/Util/IdGen.hs new file mode 100644 index 0000000..d4fd945 --- /dev/null +++ b/src/CHAD/Util/IdGen.hs @@ -0,0 +1,19 @@ +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +module CHAD.Util.IdGen where + +import Control.Monad.Fix +import Control.Monad.Trans.State.Strict + + +newtype IdGen a = IdGen (State Int a) + deriving newtype (Functor, Applicative, Monad, MonadFix) + +genId :: IdGen Int +genId = IdGen (state (\i -> (i, i + 1))) + +runIdGen :: Int -> IdGen a -> a +runIdGen start (IdGen m) = evalState m start + +runIdGen' :: Int -> IdGen a -> (a, Int) +runIdGen' start (IdGen m) = runState m start diff --git a/src/Compile.hs b/src/Compile.hs deleted file mode 100644 index 8627905..0000000 --- a/src/Compile.hs +++ /dev/null @@ -1,1796 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE MagicHash #-} -{-# LANGUAGE MultiWayIf #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE TupleSections #-} -{-# LANGUAGE TypeApplications #-} -module Compile (compile, compileStderr) where - -import Control.Applicative (empty) -import Control.Monad (forM_, when, replicateM) -import Control.Monad.Trans.Class (lift) -import Control.Monad.Trans.Maybe -import Control.Monad.Trans.State.Strict -import Control.Monad.Trans.Writer.CPS -import Data.Bifunctor (first) -import Data.Char (ord) -import Data.Foldable (toList) -import Data.Functor.Const -import qualified Data.Functor.Product as Product -import Data.Functor.Product (Product) -import Data.IORef -import Data.List (foldl1', intersperse, intercalate) -import qualified Data.Map.Strict as Map -import Data.Maybe (fromMaybe) -import qualified Data.Set as Set -import Data.Set (Set) -import Data.Some -import qualified Data.Vector as V -import Foreign -import GHC.Exts (int2Word#, addr2Int#) -import GHC.Num (integerFromWord#) -import GHC.Ptr (Ptr(..)) -import GHC.Stack (HasCallStack) -import Numeric (showHex) -import System.IO (hPutStrLn, stderr) -import System.IO.Error (mkIOError, userErrorType) -import System.IO.Unsafe (unsafePerformIO) - -import Prelude hiding ((^)) -import qualified Prelude - -import Array -import AST -import AST.Pretty (ppSTy, ppExpr) -import AST.Sparse.Types (isDense) -import Compile.Exec -import Data -import Interpreter.Rep -import qualified Util.IdGen as IdGen - - --- In shape and index arrays, the innermost dimension is on the right (last index). - --- TODO: test that I'm properly incrementing and decrementing refcounts in all required places - - --- | Print the compiled AST -debugPrintAST :: Bool; debugPrintAST = toEnum 0 --- | Print the generated C source -debugCSource :: Bool; debugCSource = toEnum 0 --- | Print extra stuff about reference counts of arrays -debugRefc :: Bool; debugRefc = toEnum 0 --- | Print some shape-related information -debugShapes :: Bool; debugShapes = toEnum 0 --- | Print information on allocation -debugAllocs :: Bool; debugAllocs = toEnum 0 --- | Emit extra C code that checks stuff -emitChecks :: Bool; emitChecks = toEnum 0 - --- | Returns compiled function plus compilation output (warnings) -compile :: SList STy env -> Ex env t - -> IO (SList Value env -> IO (Rep t), String) -compile = \env expr -> do - codeID <- atomicModifyIORef' uniqueIdGenRef (\i -> (i + 1, i)) - - let (source, offsets) = compileToString codeID env expr - when debugPrintAST $ hPutStrLn stderr $ "Compiled AST: <<<\n" ++ ppExpr env expr ++ "\n>>>" - when debugCSource $ hPutStrLn stderr $ "Generated C source: <<<\n\x1B[2m" ++ lineNumbers source ++ "\x1B[0m>>>" - (lib, compileOutput) <- buildKernel source "kernel" - - let result_type = typeOf expr - result_size = sizeofSTy result_type - - let function val = do - allocaBytes (koResultOffset offsets + result_size) $ \ptr -> do - let args = zip (reverse (unSList Some (slistZip env val))) (koArgOffsets offsets) - serialiseArguments args ptr $ do - callKernelFun lib ptr - ok <- peekByteOff @Word8 ptr (koOkResOffset offsets) - when (ok /= 1) $ - ioError (mkIOError userErrorType "fatal error detected during chad kernel execution (memory has been leaked)" Nothing Nothing) - deserialise result_type ptr (koResultOffset offsets) - return (function, compileOutput) - where - serialiseArguments :: [(Some (Product STy Value), Int)] -> Ptr () -> IO r -> IO r - serialiseArguments ((Some (Product.Pair t (Value arg)), off) : args) ptr k = - serialise t arg ptr off $ - serialiseArguments args ptr k - serialiseArguments _ _ k = k - --- | 'compile', but writes any produced C compiler output to stderr. -compileStderr :: SList STy env -> Ex env t - -> IO (SList Value env -> IO (Rep t)) -compileStderr env expr = do - (fun, output) <- compile env expr - when (not (null output)) $ - hPutStrLn stderr $ "[chad] Kernel compilation GCC output: <<<\n" ++ output ++ ">>>" - return fun - - -data StructDecl = StructDecl - String -- ^ name - String -- ^ contents - String -- ^ comment - deriving (Show) - -data Stmt - = SVarDecl Bool String String CExpr -- ^ const, type, variable name, right-hand side - | SVarDeclUninit String String -- ^ type, variable name (no initialiser) - | SAsg String CExpr -- ^ variable name, right-hand side - | SBlock (Bag Stmt) - | SIf CExpr (Bag Stmt) (Bag Stmt) - | SLoop String String CExpr CExpr (Bag Stmt) -- ^ for ( = ; name < ; name++) {} - | SVerbatim String -- ^ no implicit ';', just printed as-is - deriving (Show) - -data CExpr - = CELit String -- ^ inserted as-is, assumed no parentheses needed - | CEStruct String [(String, CExpr)] -- ^ struct construction literal: `(name){.field=expr}` - | CEProj CExpr String -- ^ field projection: expr.field - | CEPtrProj CExpr String -- ^ field projection through pointer: expr->field - | CEAddrOf CExpr -- ^ &expr - | CEIndex CExpr CExpr -- ^ expr[expr] - | CECall String [CExpr] -- ^ function(arg1, ..., argn) - | CEBinop CExpr String CExpr -- ^ expr + expr - | CEIf CExpr CExpr CExpr -- ^ expr ? expr : expr - | CECast String CExpr -- ^ () - deriving (Show) - -printStructDecl :: StructDecl -> ShowS -printStructDecl (StructDecl name contents comment) = - showString "typedef struct { " . showString contents . showString " } " . showString name - . showString ";" . (if null comment then id else showString (" // " ++ comment)) - -printStmt :: Int -> Stmt -> ShowS -printStmt indent = \case - SVarDecl cnst typ name rhs -> showString (typ ++ " " ++ (if cnst then "const " else "") ++ name ++ " = ") . printCExpr 0 rhs . showString ";" - SVarDeclUninit typ name -> showString (typ ++ " " ++ name ++ ";") - SAsg name rhs -> showString (name ++ " = ") . printCExpr 0 rhs . showString ";" - SBlock stmts - | null stmts -> showString "{}" - | otherwise -> - showString "{" - . compose [showString ("\n" ++ replicate (2*indent+2) ' ') . printStmt (indent+1) stmt | stmt <- toList stmts] - . showString ("\n" ++ replicate (2*indent) ' ' ++ "}") - SIf cond b1 b2 -> - showString "if (" . printCExpr 0 cond . showString ") " - . printStmt indent (SBlock b1) - . (if null b2 then id else showString " else " . printStmt indent (SBlock b2)) - SLoop typ name e1 e2 stmts -> - showString ("for (" ++ typ ++ " " ++ name ++ " = ") - . printCExpr 0 e1 . showString ("; " ++ name ++ " < ") . printCExpr 6 e2 - . showString ("; " ++ name ++ "++) ") - . printStmt indent (SBlock stmts) - SVerbatim s -> showString s - --- d values: --- * 0: top level --- * 1: in 1st or 2nd component of a ternary operator (technically same as top level, but readability) --- * 2-...: various operators (see precTable) --- * 80: address-of operator (&) --- * 98: inside unknown operator --- * 99: left of a field projection --- Unlisted operators are conservatively written with full parentheses. -printCExpr :: Int -> CExpr -> ShowS -printCExpr d = \case - CELit s -> showString s - CEStruct name pairs -> - showParen (d >= 99) $ - showString ("(" ++ name ++ "){") - . compose (intersperse (showString ", ") [showString ("." ++ n ++ " = ") . printCExpr 0 e - | (n, e) <- pairs]) - . showString "}" - CEProj e name -> printCExpr 99 e . showString ("." ++ name) - CEPtrProj e name -> printCExpr 99 e . showString ("->" ++ name) - CEAddrOf e -> showParen (d > 80) $ showString "&" . printCExpr 80 e - CEIndex e1 e2 -> printCExpr 99 e1 . showString "[" . printCExpr 0 e2 . showString "]" - CECall n es -> - showString (n ++ "(") . compose (intersperse (showString ", ") (map (printCExpr 0) es)) . showString ")" - CEBinop e1 n e2 -> - let mprec = Map.lookup n precTable - p = maybe (-1) fst mprec -- precedence of this operator - (d1, d2) = maybe (98, 98) snd mprec -- precedences for the arguments - in showParen (d > p) $ - printCExpr d1 e1 . showString (" " ++ n ++ " ") . printCExpr d2 e2 - CEIf e1 e2 e3 -> - showParen (d > 0) $ - printCExpr 1 e1 . showString " ? " . printCExpr 1 e2 . showString " : " . printCExpr 0 e3 - CECast typ e -> - showParen (d > 98) $ showString ("(" ++ typ ++ ")") . printCExpr 98 e - where - precTable = Map.fromList - [("||", (2, (2, 2))) - ,("&&", (3, (3, 3))) - ,("==", (4, (5, 5))) - ,("!=", (4, (5, 5))) - ,("<", (5, (6, 6))) -- Note: this precedence is used in the printing of SLoop - ,(">", (5, (6, 6))) - ,("<=", (5, (6, 6))) - ,(">=", (5, (6, 6))) - ,("+", (6, (6, 7))) - ,("-", (6, (6, 7))) - ,("*", (7, (7, 8))) - ,("/", (7, (7, 8))) - ,("%", (7, (7, 8)))] - -repSTy :: STy t -> String -repSTy (STScal st) = case st of - STI32 -> "int32_t" - STI64 -> "int64_t" - STF32 -> "float" - STF64 -> "double" - STBool -> "uint8_t" -repSTy t = genStructName t - -genStructName, genArrBufStructName :: STy t -> String -(genStructName, genArrBufStructName) = - (\t -> "ty_" ++ gen t - ,\case STArr _ t -> "ty_A_" ++ gen t ++ "_buf" -- just like the normal type, but with _ for the dimension - t -> error $ "genArrBufStructName: not an array type: " ++ show t) - where - -- all tags start with a letter, so the array mangling is unambiguous. - gen :: STy t -> String - gen STNil = "n" - gen (STPair a b) = 'P' : gen a ++ gen b - gen (STEither a b) = 'E' : gen a ++ gen b - gen (STLEither a b) = 'L' : gen a ++ gen b - gen (STMaybe t) = 'M' : gen t - gen (STArr n t) = "A" ++ show (fromSNat n) ++ gen t - gen (STScal st) = case st of - STI32 -> "i" - STI64 -> "j" - STF32 -> "f" - STF64 -> "d" - STBool -> "b" - gen (STAccum t) = 'C' : gen (fromSMTy t) - --- The subtrees contain structs used in the bodies of the structs in this node. -data StructTree = TreeNode [StructDecl] [StructTree] - deriving (Show) - --- | This function generates the actual struct declarations for each of the --- types in our language. It thus implicitly "documents" the layout of the --- types in the C translation. --- --- For accumulation it is important that for struct representations of monoid --- types, the all-zero-bytes value corresponds to the zero value of that type. -buildStructTree :: STy t -> StructTree -buildStructTree topty = case topty of - STNil -> - TreeNode [StructDecl name "" com] [] - STPair a b -> - TreeNode [StructDecl name (repSTy a ++ " a; " ++ repSTy b ++ " b;") com] - [buildStructTree a, buildStructTree b] - STEither a b -> -- 0 -> l, 1 -> r - TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] - [buildStructTree a, buildStructTree b] - STLEither a b -> -- 0 -> nil, 1 -> l, 2 -> r - TreeNode [StructDecl name ("uint8_t tag; union { " ++ repSTy a ++ " l; " ++ repSTy b ++ " r; };") com] - [buildStructTree a, buildStructTree b] - STMaybe t -> -- 0 -> nothing, 1 -> just - TreeNode [StructDecl name ("uint8_t tag; " ++ repSTy t ++ " j;") com] - [buildStructTree t] - STArr n t -> - -- The buffer is trailed by a VLA for the actual array data. - -- TODO: no buffer if n = 0 - TreeNode [StructDecl (genArrBufStructName topty) ("size_t refc; " ++ repSTy t ++ " xs[];") "" - ,StructDecl name (genArrBufStructName topty ++ " *buf; size_t sh[" ++ show (fromSNat n) ++ "];") com] - [buildStructTree t] - STScal _ -> - TreeNode [] [] - STAccum t -> - TreeNode [StructDecl (name ++ "_buf") (repSTy (fromSMTy t) ++ " ac;") "" - ,StructDecl name (name ++ "_buf *buf;") com] - [buildStructTree (fromSMTy t)] - where - name = genStructName topty - com = ppSTy 0 topty - --- State: already-generated (skippable) struct names --- Writer: the structs in declaration order -genStructTreeW :: StructTree -> WriterT (Bag StructDecl) (State (Set String)) () -genStructTreeW (TreeNode these deps) = do - seen <- lift get - case filter ((`Set.notMember` seen) . nameOf) these of - [] -> pure () - structs -> do - lift $ modify (Set.fromList (map nameOf structs) <>) - mapM_ genStructTreeW deps - tell (BList structs) - where - nameOf (StructDecl name _ _) = name - -genAllStructs :: Foldable t => t (Some STy) -> [StructDecl] -genAllStructs tys = - let m = mapM_ (\(Some t) -> genStructTreeW (buildStructTree t)) tys - in toList (evalState (execWriterT m) mempty) - -data CompState = CompState - { csStructs :: Set (Some STy) - , csTopLevelDecls :: Bag String - , csStmts :: Bag Stmt - , csNextId :: Int } - deriving (Show) - -newtype CompM a = CompM (State CompState a) - deriving newtype (Functor, Applicative, Monad) - -runCompM :: CompM a -> (a, CompState) -runCompM (CompM m) = runState m (CompState mempty mempty mempty 1) - -class Monad m => MonadNameGen m where genId :: m Int -instance MonadNameGen CompM where genId = CompM $ state $ \s -> (csNextId s, s { csNextId = csNextId s + 1 }) -instance MonadNameGen IdGen.IdGen where genId = IdGen.genId -instance MonadNameGen m => MonadNameGen (MaybeT m) where genId = MaybeT (Just <$> genId) - -genName' :: MonadNameGen m => String -> m String -genName' "" = genName -genName' prefix = (prefix ++) . show <$> genId - -genName :: MonadNameGen m => m String -genName = genName' "x" - -onlyIdGen :: IdGen.IdGen a -> CompM a -onlyIdGen m = CompM $ do - i1 <- gets csNextId - let (res, i2) = IdGen.runIdGen' i1 m - modify (\s -> s { csNextId = i2 }) - return res - -emit :: Stmt -> CompM () -emit stmt = CompM $ modify $ \s -> s { csStmts = csStmts s <> pure stmt } - -scope :: CompM a -> CompM (a, Bag Stmt) -scope m = do - stmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = mempty }) - res <- m - innerStmts <- CompM $ state $ \s -> (csStmts s, s { csStmts = stmts }) - return (res, innerStmts) - -emitStruct :: STy t -> CompM String -emitStruct ty = CompM $ do - modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) } - return (genStructName ty) - --- | Also returns the name of the array buffer struct -emitArrStruct :: STy t -> CompM (String, String) -emitArrStruct ty = CompM $ do - modify $ \s -> s { csStructs = Set.insert (Some ty) (csStructs s) } - return (genStructName ty, genArrBufStructName ty) - -emitTLD :: String -> CompM () -emitTLD decl = CompM $ modify $ \s -> s { csTopLevelDecls = csTopLevelDecls s <> pure decl } - -nameEnv :: SList f env -> SList (Const String) env -nameEnv = flip evalState (0::Int) . slistMapA (\_ -> state $ \i -> (Const ("arg" ++ show i), i + 1)) - -data KernelOffsets = KernelOffsets - { koArgOffsets :: [Int] -- ^ the function arguments - , koOkResOffset :: Int -- ^ a byte: 1 if successful execution, 0 if (fatal) error occurred - , koResultOffset :: Int -- ^ the function result - } - -compileToString :: Int -> SList STy env -> Ex env t -> (String, KernelOffsets) -compileToString codeID env expr = - let args = nameEnv env - (res, s) = runCompM (compile' args expr) - structs = genAllStructs (csStructs s <> Set.fromList (unSList Some env)) - - (arg_pairs, arg_metrics) = - unzip $ reverse (unSList (\(Product.Pair t (Const n)) -> ((n, repSTy t), metricsSTy t)) - (slistZip env args)) - (arg_offsets, okres_offset) = computeStructOffsets arg_metrics - result_offset = align (alignmentSTy (typeOf expr)) (okres_offset + 1) - - offsets = KernelOffsets - { koArgOffsets = arg_offsets - , koOkResOffset = okres_offset - , koResultOffset = result_offset } - in (,offsets) . ($ "") $ compose - [showString "#include \n" - ,showString "#include \n" - ,showString "#include \n" - ,showString "#include \n" - ,showString "#include \n" - ,showString "#include \n" - ,showString "#include \n\n" - -- PRint-tag - ,showString $ "#define PRTAG \"[chad-kernel" ++ show codeID ++ "] \"\n\n" - - ,compose [printStructDecl sd . showString "\n" | sd <- structs] - ,showString "\n" - - -- Using %zd and not %zu here because values > SIZET_MAX/2 should be recognisable as "negative" - ,showString "static void* malloc_instr_fun(size_t n, int line) {\n" - ,showString " void *ptr = malloc(n);\n" - ,if debugAllocs then showString " printf(PRTAG \":%d malloc(%zd) -> %p\\n\", line, n, ptr);\n" - else id - ,if emitChecks then showString " if (ptr == NULL) { printf(PRTAG \"malloc(%zd) returned NULL on line %d\\n\", n, line); return false; }\n" - else id - ,showString " return ptr;\n" - ,showString "}\n" - ,showString "#define malloc_instr(n) ({void *ptr_ = malloc_instr_fun(n, __LINE__); if (ptr_ == NULL) return false; ptr_;})\n" - ,showString "static void* calloc_instr_fun(size_t n, int line) {\n" - ,showString " void *ptr = calloc(n, 1);\n" - ,if debugAllocs then showString " printf(PRTAG \":%d calloc(%zd) -> %p\\n\", line, n, ptr);\n" - else id - ,if emitChecks then showString " if (ptr == NULL) { printf(PRTAG \"calloc(%zd, 1) returned NULL on line %d\\n\", n, line); return false; }\n" - else id - ,showString " return ptr;\n" - ,showString "}\n" - ,showString "#define calloc_instr(n) ({void *ptr_ = calloc_instr_fun(n, __LINE__); if (ptr_ == NULL) return false; ptr_;})\n" - ,showString "static void free_instr(void *ptr) {\n" - ,if debugAllocs then showString "printf(PRTAG \"free(%p)\\n\", ptr);\n" - else id - ,showString " free(ptr);\n" - ,showString "}\n\n" - - ,compose [showString str . showString "\n\n" | str <- toList (csTopLevelDecls s)] - - ,showString $ - "static bool typed_kernel(" ++ - repSTy (typeOf expr) ++ " *output" ++ - concatMap (", " ++) - (reverse (unSList (\(Product.Pair t (Const n)) -> repSTy t ++ " " ++ n) (slistZip env args))) ++ - ") {\n" - ,compose [showString " " . printStmt 1 st . showString "\n" | st <- toList (csStmts s)] - ,showString " *output = " . printCExpr 0 res . showString ";\n" - ,showString " return true;\n" - ,showString "}\n\n" - - ,showString "void kernel(void *data) {\n" - -- Some code here assumes that we're on a 64-bit system, so let's check that - ,showString $ " if (sizeof(void*) != 8 || sizeof(size_t) != 8) { fprintf(stderr, \"Only 64-bit systems supported\\n\"); *(uint8_t*)(data + " ++ show okres_offset ++ ") = 0; return; }\n" - ,if debugRefc then showString " fprintf(stderr, PRTAG \"Start\\n\");\n" - else id - ,showString $ " const bool success = typed_kernel(" ++ - "\n (" ++ repSTy (typeOf expr) ++ "*)(data + " ++ show result_offset ++ ")" ++ - concat (map (\((arg, typ), off) -> - ",\n *(" ++ typ ++ "*)(data + " ++ show off ++ ")" - ++ " /* " ++ arg ++ " */") - (zip arg_pairs arg_offsets)) ++ - "\n );\n" - ,showString $ " *(uint8_t*)(data + " ++ show okres_offset ++ ") = success;\n" - ,if debugRefc then showString " fprintf(stderr, PRTAG \"Return\\n\");\n" - else id - ,showString "}\n"] - --- | Takes list of metrics (alignment, sizeof). --- Returns (offsets, size of struct). -computeStructOffsets :: [(Int, Int)] -> ([Int], Int) -computeStructOffsets = go 0 0 - where - go off maxal [(al, sz)] = - ([off], align (max maxal al) (off + sz)) - go off maxal ((al, sz) : pairs@((al2,_):_)) = - first (off :) $ go (align al2 (off + sz)) (max maxal al) pairs - go _ _ [] = ([], 0) - --- | Assumes that this is called at the correct alignment. -serialise :: STy t -> Rep t -> Ptr () -> Int -> IO r -> IO r -serialise topty topval ptr off k = - -- TODO: this code is quadratic in the depth of the type because of the alignment/sizeOf calls - case (topty, topval) of - (STNil, ()) -> k - (STPair a b, (x, y)) -> - serialise a x ptr off $ - serialise b y ptr (align (alignmentSTy b) (off + sizeofSTy a)) k - (STEither a _, Left x) -> do - pokeByteOff ptr off (0 :: Word8) -- alignment of (union {a b}) is the same as alignment of (a + b) - serialise a x ptr (off + alignmentSTy topty) k - (STEither _ b, Right y) -> do - pokeByteOff ptr off (1 :: Word8) - serialise b y ptr (off + alignmentSTy topty) k - (STLEither _ _, Nothing) -> do - pokeByteOff ptr off (0 :: Word8) - k - (STLEither a _, Just (Left x)) -> do - pokeByteOff ptr off (1 :: Word8) -- alignment of (union {a b}) is the same as alignment of (1 + a + b) - serialise a x ptr (off + alignmentSTy topty) k - (STLEither _ b, Just (Right y)) -> do - pokeByteOff ptr off (2 :: Word8) - serialise b y ptr (off + alignmentSTy topty) k - (STMaybe _, Nothing) -> do - pokeByteOff ptr off (0 :: Word8) - k - (STMaybe t, Just x) -> do - pokeByteOff ptr off (1 :: Word8) - serialise t x ptr (off + alignmentSTy t) k - (STArr n t, Array sh vec) -> do - let eltsz = sizeofSTy t - allocaBytes (8 + shapeSize sh * eltsz) $ \bufptr -> do - when debugRefc $ - hPutStrLn stderr $ "[chad-serialise] Allocating input buffer " ++ showPtr bufptr - pokeByteOff ptr off bufptr - pokeShape ptr (off + 8) n sh - - pokeByteOff @Word64 bufptr 0 (2 ^ 63) - - let loop i - | i == shapeSize sh = k - | otherwise = - serialise t (vec V.! i) bufptr (8 + i * eltsz) $ - loop (i+1) - loop 0 - (STScal sty, x) -> case sty of - STI32 -> pokeByteOff ptr off (x :: Int32) >> k - STI64 -> pokeByteOff ptr off (x :: Int64) >> k - STF32 -> pokeByteOff ptr off (x :: Float) >> k - STF64 -> pokeByteOff ptr off (x :: Double) >> k - STBool -> pokeByteOff ptr off (fromIntegral (fromEnum x) :: Word8) >> k - (STAccum{}, _) -> error "Cannot serialise accumulators" - --- | Assumes that this is called at the correct alignment. -deserialise :: STy t -> Ptr () -> Int -> IO (Rep t) -deserialise topty ptr off = - -- TODO: this code is quadratic in the depth of the type because of the alignment/sizeOf calls - case topty of - STNil -> return () - STPair a b -> do - x <- deserialise a ptr off - y <- deserialise b ptr (align (alignmentSTy b) (off + sizeofSTy a)) - return (x, y) - STEither a b -> do - tag <- peekByteOff @Word8 ptr off - if tag == 0 -- alignment of (union {a b}) is the same as alignment of (a + b) - then Left <$> deserialise a ptr (off + alignmentSTy topty) - else Right <$> deserialise b ptr (off + alignmentSTy topty) - STLEither a b -> do - tag <- peekByteOff @Word8 ptr off - case tag of -- alignment of (union {a b}) is the same as alignment of (a + b) - 0 -> return Nothing - 1 -> Just . Left <$> deserialise a ptr (off + alignmentSTy topty) - 2 -> Just . Right <$> deserialise b ptr (off + alignmentSTy topty) - _ -> error "Invalid tag value" - STMaybe t -> do - tag <- peekByteOff @Word8 ptr off - if tag == 0 - then return Nothing - else Just <$> deserialise t ptr (off + alignmentSTy t) - STArr n t -> do - bufptr <- peekByteOff @(Ptr ()) ptr off - sh <- peekShape ptr (off + 8) n - refc <- peekByteOff @Word64 bufptr 0 - when debugRefc $ - hPutStrLn stderr $ "[chad-deserialise] Got buffer " ++ showPtr bufptr ++ " at refc=" ++ show refc - let eltsz = sizeofSTy t - arr <- Array sh <$> V.generateM (shapeSize sh) (\i -> deserialise t bufptr (8 + i * eltsz)) - when (refc < 2 ^ 62) $ free bufptr - return arr - STScal sty -> case sty of - STI32 -> peekByteOff @Int32 ptr off - STI64 -> peekByteOff @Int64 ptr off - STF32 -> peekByteOff @Float ptr off - STF64 -> peekByteOff @Double ptr off - STBool -> toEnum . fromIntegral <$> peekByteOff @Word8 ptr off - STAccum{} -> error "Cannot serialise accumulators" - -align :: Int -> Int -> Int -align a off = (off + a - 1) `div` a * a - -alignmentSTy :: STy t -> Int -alignmentSTy = fst . metricsSTy - -sizeofSTy :: STy t -> Int -sizeofSTy = snd . metricsSTy - --- | Returns (alignment, sizeof) -metricsSTy :: STy t -> (Int, Int) -metricsSTy STNil = (1, 0) -metricsSTy (STPair a b) = - let (a1, s1) = metricsSTy a - (a2, s2) = metricsSTy b - in (max a1 a2, align (max a1 a2) (s1 + s2)) -metricsSTy (STEither a b) = - let (a1, s1) = metricsSTy a - (a2, s2) = metricsSTy b - in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned -metricsSTy (STLEither a b) = - let (a1, s1) = metricsSTy a - (a2, s2) = metricsSTy b - in (max a1 a2, max a1 a2 + max s1 s2) -- the union after the tag byte is aligned -metricsSTy (STMaybe t) = - let (a, s) = metricsSTy t - in (a, a + s) -- the union after the tag byte is aligned -metricsSTy (STArr n _) = (8, 8 + 8 * fromSNat n) -metricsSTy (STScal sty) = case sty of - STI32 -> (4, 4) - STI64 -> (8, 8) - STF32 -> (4, 4) - STF64 -> (8, 8) - STBool -> (1, 1) -- compiled to uint8_t -metricsSTy (STAccum t) = metricsSTy (fromSMTy t) - -pokeShape :: Ptr () -> Int -> SNat n -> Shape n -> IO () -pokeShape ptr off = go . fromSNat - where - go :: Int -> Shape n -> IO () - go rank = \case - ShNil -> return () - sh `ShCons` n -> do - pokeByteOff ptr (off + (rank - 1) * 8) (fromIntegral n :: Int64) - go (rank - 1) sh - -peekShape :: Ptr () -> Int -> SNat n -> IO (Shape n) -peekShape ptr off = \case - SZ -> return ShNil - SS n -> ShCons <$> peekShape ptr off n - <*> (fromIntegral <$> peekByteOff @Int64 ptr (off + (fromSNat n) * 8)) - -compile' :: SList (Const String) env -> Ex env t -> CompM CExpr -compile' env = \case - EVar _ t i -> do - let Const var = slistIdx env i - incrementVarAlways "var" Increment t var - return $ CELit var - - ELet _ rhs body -> do - var <- compileAssign "" env rhs - rete <- compile' (Const var `SCons` env) body - incrementVarAlways "let" Decrement (typeOf rhs) var - return rete - - EPair _ a b -> do - name <- emitStruct (STPair (typeOf a) (typeOf b)) - e1 <- compile' env a - e2 <- compile' env b - return $ CEStruct name [("a", e1), ("b", e2)] - - EFst _ e -> do - let STPair _ t2 = typeOf e - e' <- compile' env e - case incrementVar "fst" Decrement t2 of - Nothing -> return $ CEProj e' "a" - Just f -> do var <- genName - emit $ SVarDecl True (repSTy (typeOf e)) var e' - f (var ++ ".b") - return $ CEProj (CELit var) "a" - - ESnd _ e -> do - let STPair t1 _ = typeOf e - e' <- compile' env e - case incrementVar "snd" Decrement t1 of - Nothing -> return $ CEProj e' "b" - Just f -> do var <- genName - emit $ SVarDecl True (repSTy (typeOf e)) var e' - f (var ++ ".a") - return $ CEProj (CELit var) "b" - - ENil _ -> do - name <- emitStruct STNil - return $ CEStruct name [] - - EInl _ t e -> do - name <- emitStruct (STEither (typeOf e) t) - e1 <- compile' env e - return $ CEStruct name [("tag", CELit "0"), ("l", e1)] - - EInr _ t e -> do - name <- emitStruct (STEither t (typeOf e)) - e2 <- compile' env e - return $ CEStruct name [("tag", CELit "1"), ("r", e2)] - - ECase _ (EOp _ OIf e) a b -> do - e1 <- compile' env e - (e2, stmts2) <- scope $ compile' (Const undefined `SCons` env) a -- don't access that nil, stupid you - (e3, stmts3) <- scope $ compile' (Const undefined `SCons` env) b - retvar <- genName - emit $ SVarDeclUninit (repSTy (typeOf a)) retvar - emit $ SIf e1 - (stmts2 <> pure (SAsg retvar e2)) - (stmts3 <> pure (SAsg retvar e3)) - return (CELit retvar) - - ECase _ e a b -> do - let STEither t1 t2 = typeOf e - e1 <- compile' env e - var <- genName - -- I know those are not variable names, but it's fine, probably - (e2, stmts2) <- scope $ compile' (Const (var ++ ".l") `SCons` env) a - (e3, stmts3) <- scope $ compile' (Const (var ++ ".r") `SCons` env) b - ((), stmtsRel1) <- scope $ incrementVarAlways "case1" Decrement t1 (var ++ ".l") - ((), stmtsRel2) <- scope $ incrementVarAlways "case2" Decrement t2 (var ++ ".r") - retvar <- genName - emit $ SVarDeclUninit (repSTy (typeOf a)) retvar - emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1) - <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) - (stmts2 - <> stmtsRel1 - <> pure (SAsg retvar e2)) - (stmts3 - <> stmtsRel2 - <> pure (SAsg retvar e3)))) - return (CELit retvar) - - ENothing _ t -> do - name <- emitStruct (STMaybe t) - return $ CEStruct name [("tag", CELit "0")] - - EJust _ e -> do - name <- emitStruct (STMaybe (typeOf e)) - e1 <- compile' env e - return $ CEStruct name [("tag", CELit "1"), ("j", e1)] - - EMaybe _ a b e -> do - let STMaybe t = typeOf e - e1 <- compile' env e - var <- genName - (e2, stmts2) <- scope $ compile' env a - (e3, stmts3) <- scope $ compile' (Const (var ++ ".j") `SCons` env) b - ((), stmtsRel) <- scope $ incrementVarAlways "maybe" Decrement t (var ++ ".j") - retvar <- genName - emit $ SVarDeclUninit (repSTy (typeOf a)) retvar - emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1) - <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) - (stmts2 - <> pure (SAsg retvar e2)) - (stmts3 - <> stmtsRel - <> pure (SAsg retvar e3)))) - return (CELit retvar) - - ELNil _ t1 t2 -> do - name <- emitStruct (STLEither t1 t2) - return $ CEStruct name [("tag", CELit "0")] - - ELInl _ t e -> do - name <- emitStruct (STLEither (typeOf e) t) - e1 <- compile' env e - return $ CEStruct name [("tag", CELit "1"), ("l", e1)] - - ELInr _ t e -> do - name <- emitStruct (STLEither t (typeOf e)) - e1 <- compile' env e - return $ CEStruct name [("tag", CELit "2"), ("r", e1)] - - ELCase _ e a b c -> do - let STLEither t1 t2 = typeOf e - e1 <- compile' env e - var <- genName - (e2, stmts2) <- scope $ compile' env a - (e3, stmts3) <- scope $ compile' (Const (var ++ ".l") `SCons` env) b - (e4, stmts4) <- scope $ compile' (Const (var ++ ".r") `SCons` env) c - ((), stmtsRel1) <- scope $ incrementVarAlways "lcase1" Decrement t1 (var ++ ".l") - ((), stmtsRel2) <- scope $ incrementVarAlways "lcase2" Decrement t2 (var ++ ".r") - retvar <- genName - emit $ SVarDeclUninit (repSTy (typeOf a)) retvar - emit $ SBlock (pure (SVarDecl True (repSTy (typeOf e)) var e1) - <> pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "0")) - (stmts2 <> pure (SAsg retvar e2)) - (pure (SIf (CEBinop (CEProj (CELit var) "tag") "==" (CELit "1")) - (stmts3 <> stmtsRel1 <> pure (SAsg retvar e3)) - (stmts4 <> stmtsRel2 <> pure (SAsg retvar e4)))))) - return (CELit retvar) - - EConstArr _ n t (Array sh vec) -> do - (strname, bufstrname) <- emitArrStruct (STArr n (STScal t)) - tldname <- genName' "carraybuf" - -- Give it a refcount of _half_ the size_t max, so that it can be - -- incremented and decremented at will and will "never" reach anything - -- where something happens - emitTLD $ "static " ++ bufstrname ++ " " ++ tldname ++ " = " ++ - "(" ++ bufstrname ++ "){.refc = (size_t)1<<63, " ++ - ".xs = {" ++ intercalate "," (map (compileScal False t) (toList vec)) ++ "}};" - return (CEStruct strname - [("buf", CEAddrOf (CELit tldname)) - ,("sh", CELit ("{" ++ intercalate "," (map show (shapeToList sh)) ++ "}"))]) - - EBuild _ n esh efun -> do - shname <- compileAssign "sh" env esh - - arrname <- allocArray "build" Malloc "arr" n (typeOf efun) Nothing (indexTupleComponents n shname) - - idxargname <- genName' "ix" - (funretval, funstmts) <- scope $ compile' (Const idxargname `SCons` env) efun - - linivar <- genName' "li" - ivars <- replicateM (fromSNat n) (genName' "i") - emit $ SBlock $ - pure (SVarDecl False "size_t" linivar (CELit "0")) - <> compose [pure . SLoop (repSTy tIx) ivar (CELit "0") - (CECast (repSTy tIx) (CEIndex (CELit (arrname ++ ".sh")) (CELit (show dimidx)))) - | (ivar, dimidx) <- zip ivars [0::Int ..]] - (pure (SVarDecl True (repSTy (typeOf esh)) idxargname - (shapeTupFromLitVars n ivars)) - <> funstmts - <> pure (SAsg (arrname ++ ".buf->xs[" ++ linivar ++ "++]") funretval)) - - return (CELit arrname) - - -- TODO: actually generate decent code here - EMap _ e1 e2 -> do - let STArr n _ = typeOf e2 - compile' env $ - elet e2 $ - EBuild ext n (EShape ext (evar IZ)) $ - elet (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ)) $ - weakenExpr (WCopy (WSink .> WSink)) e1 - - EFold1Inner _ commut efun ex0 earr -> do - let STArr (SS n) t = typeOf earr - - -- let vecwid = case commut of Commut -> 8 :: Int - -- Noncommut -> 1 - - x0name <- compileAssign "foldx0" env ex0 - arrname <- compileAssign "foldarr" env earr - - zeroRefcountCheck (typeOf earr) "fold1i" arrname - - shszname <- genName' "shsz" - -- This n is one less than the shape of the thing we're querying, which is - -- unexpected. But it's exactly what we want, so we do it anyway. - emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n arrname) - - resname <- allocArray "fold" Malloc "foldres" n t (Just (CELit shszname)) (compileArrShapeComponents n arrname) - - lenname <- genName' "n" - emit $ SVarDecl True (repSTy tIx) lenname - (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]")) - - ((), x0incrStmts) <- scope $ incrementVarAlways "foldx0" Increment t x0name - - ivar <- genName' "i" - jvar <- genName' "j" - -- kvar <- if vecwid > 1 then genName' "k" else return "" - - accvar <- genName' "tot" - pairvar <- genName' "pair" -- function input - (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun - - let arreltlit = arrname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ - ({- if vecwid > 1 then show vecwid ++ " * " ++ jvar ++ " + " ++ kvar else -} jvar) ++ "]" - ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldelt" Increment t arreltlit - - pairstrname <- emitStruct (STPair t t) - emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ - pure (SVarDecl False (repSTy t) accvar (CELit x0name)) - <> x0incrStmts -- we're copying x0 here - <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ - -- The combination function will consume the array element - -- and the accumulator. The accumulator is replaced by - -- what comes out of the function anyway, so that's - -- fine, but we do need to increment the array element. - arreltIncrStmts - <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)])) - <> funStmts - <> pure (SAsg accvar funres)) - <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) - - incrementVarAlways "foldx0" Decrement t x0name - incrementVarAlways "foldarr" Decrement (typeOf earr) arrname - - return (CELit resname) - - ESum1Inner _ e -> do - let STArr (SS n) t = typeOf e - argname <- compileAssign "sumarg" env e - - zeroRefcountCheck (typeOf e) "sum1i" argname - - shszname <- genName' "shsz" - -- This n is one less than the shape of the thing we're querying, like EFold1Inner. - emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) - - resname <- allocArray "sum" Malloc "sumres" n t (Just (CELit shszname)) (compileArrShapeComponents n argname) - - lenname <- genName' "n" - emit $ SVarDecl True (repSTy tIx) lenname - (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]")) - - let vecwid = 8 :: Int - ivar <- genName' "i" - jvar <- genName' "j" - kvar <- genName' "k" - accvar <- genName' "tot" - let nchunks = CEBinop (CELit lenname) "/" (CELit (show vecwid)) - emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ BList - -- we have ScalIsNumeric, so it has 0 and (+) in C - [SVerbatim $ repSTy t ++ " " ++ accvar ++ "[" ++ show vecwid ++ "] = {" ++ intercalate "," (replicate vecwid "0") ++ "};" - ,SLoop (repSTy tIx) jvar (CELit "0") nchunks $ - pure $ SLoop (repSTy tIx) kvar (CELit "0") (CELit (show vecwid)) $ - pure $ SVerbatim $ accvar ++ "[" ++ kvar ++ "] += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ show vecwid ++ " * " ++ jvar ++ " + " ++ kvar ++ "];" - ,SLoop (repSTy tIx) kvar (CELit "1") (CELit (show vecwid)) $ - pure $ SVerbatim $ accvar ++ "[0] += " ++ accvar ++ "[" ++ kvar ++ "];" - ,SLoop (repSTy tIx) kvar (CEBinop nchunks "*" (CELit (show vecwid))) (CELit lenname) $ - pure $ SVerbatim $ accvar ++ "[0] += " ++ argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ kvar ++ "];" - ,SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit (accvar++"[0]"))] - - incrementVarAlways "sum" Decrement (typeOf e) argname - - return (CELit resname) - - EUnit _ e -> do - e' <- compile' env e - let typ = STArr SZ (typeOf e) - strname <- emitStruct typ - name <- genName - emit $ SVarDecl True strname name (CEStruct strname - [("buf", CECall "malloc_instr" [CELit (show (8 + sizeofSTy (typeOf e)))])]) - emit $ SAsg (name ++ ".buf->refc") (CELit "1") - emit $ SAsg (name ++ ".buf->xs[0]") e' - return (CELit name) - - EReplicate1Inner _ elen earg -> do - let STArr n t = typeOf earg - lenname <- compileAssign "replen" env elen - argname <- compileAssign "reparg" env earg - - zeroRefcountCheck (typeOf earg) "replicate1i" argname - - shszname <- genName' "shsz" - emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) - - resname <- allocArray "repl1i" Malloc "rep" (SS n) t - (Just (CEBinop (CELit shszname) "*" (CELit lenname))) - (compileArrShapeComponents n argname ++ [CELit lenname]) - - ivar <- genName' "i" - jvar <- genName' "j" - emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ - pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ - pure $ SAsg (resname ++ ".buf->xs[" ++ ivar ++ " * " ++ lenname ++ " + " ++ jvar ++ "]") - (CELit (argname ++ ".buf->xs[" ++ ivar ++ "]")) - - incrementVarAlways "repl1i" Decrement (typeOf earg) argname - - return (CELit resname) - - EMaximum1Inner _ e -> compileExtremum "max" "maximum1i" ">" env e - - EMinimum1Inner _ e -> compileExtremum "min" "minimum1i" "<" env e - - EReshape _ dim esh earg -> do - let STArr origDim eltty = typeOf earg - strname <- emitStruct (STArr dim eltty) - - shname <- compileAssign "reshsh" env esh - arrname <- compileAssign "resharg" env earg - - when emitChecks $ do - emit $ SIf (CEBinop (compileArrShapeSize origDim arrname) "!=" (CECast "size_t" (prodExpr (indexTupleComponents dim shname)))) - (pure $ SVerbatim $ - "fprintf(stderr, PRTAG \"CHECK: reshape on unequal sizes (%zu <- %zu)\\n\", " ++ - printCExpr 0 (prodExpr (indexTupleComponents dim shname)) ", " ++ - printCExpr 0 (compileArrShapeSize origDim arrname) "); return false;") - mempty - - return (CEStruct strname - [("buf", CEProj (CELit arrname) "buf") - ,("sh", CELit ("{" ++ intercalate ", " [printCExpr 0 e "" | e <- indexTupleComponents dim shname] ++ "}"))]) - - -- TODO: actually generate decent code here - EZip _ e1 e2 -> do - let STArr n _ = typeOf e1 - compile' env $ - elet e1 $ - elet (weakenExpr WSink e2) $ - EBuild ext n (EShape ext (evar (IS IZ))) $ - EPair ext (EIdx ext (evar (IS (IS IZ))) (EVar ext (tTup (sreplicate n tIx)) IZ)) - (EIdx ext (evar (IS IZ)) (EVar ext (tTup (sreplicate n tIx)) IZ)) - - EFold1InnerD1 _ commut efun ex0 earr -> do - let STArr (SS n) t = typeOf earr - STPair _ bty = typeOf efun - - x0name <- compileAssign "foldd1x0" env ex0 - arrname <- compileAssign "foldd1arr" env earr - - zeroRefcountCheck (typeOf earr) "fold1iD1" arrname - - lenname <- genName' "n" - emit $ SVarDecl True (repSTy tIx) lenname - (CELit (arrname ++ ".sh[" ++ show (fromSNat n) ++ "]")) - - shsz1name <- genName' "shszN" - emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n arrname) -- take init of arr's shape - shsz2name <- genName' "shszSN" - emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname)) - - resname <- allocArray "foldd1" Malloc "foldd1res" n t (Just (CELit shsz1name)) (compileArrShapeComponents n arrname) - storesname <- allocArray "foldd1" Malloc "foldd1stores" (SS n) bty (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) arrname) - - ((), x0incrStmts) <- scope $ incrementVarAlways "foldd1x0" Increment t x0name - - ivar <- genName' "i" - jvar <- genName' "j" - - accvar <- genName' "tot" - pairvar <- genName' "pair" -- function input - (funres, funStmts) <- scope $ compile' (Const pairvar `SCons` env) efun - let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ jvar - arreltlit = arrname ++ ".buf->xs[" ++ eltidx ++ "]" - funresvar <- genName' "res" - ((), arreltIncrStmts) <- scope $ incrementVarAlways "foldd1elt" Increment t arreltlit - - pairstrname <- emitStruct (STPair t t) - emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $ - pure (SVarDecl False (repSTy t) accvar (CELit x0name)) - <> x0incrStmts -- we're copying x0 here - <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ - -- The combination function will consume the array element - -- and the accumulator. The accumulator is replaced by - -- what comes out of the function anyway, so that's - -- fine, but we do need to increment the array element. - arreltIncrStmts - <> pure (SVarDecl True pairstrname pairvar (CEStruct pairstrname [("a", CELit accvar), ("b", CELit arreltlit)])) - <> funStmts - <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) - <> pure (SAsg accvar (CEProj (CELit funresvar) "a")) - <> pure (SAsg (storesname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b"))) - <> pure (SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) - - incrementVarAlways "foldd1x0" Decrement t x0name - incrementVarAlways "foldd1arr" Decrement (typeOf earr) arrname - - strname <- emitStruct (STPair (STArr n t) (STArr (SS n) bty)) - return (CEStruct strname [("a", CELit resname), ("b", CELit storesname)]) - - EFold1InnerD2 _ commut efun estores ectg -> do - let STArr n t2 = typeOf ectg - STArr _ bty = typeOf estores - - storesname <- compileAssign "foldd2stores" env estores - ctgname <- compileAssign "foldd2ctg" env ectg - - zeroRefcountCheck (typeOf ectg) "fold1iD2" ctgname - - lenname <- genName' "n" - emit $ SVarDecl True (repSTy tIx) lenname - (CELit (storesname ++ ".sh[" ++ show (fromSNat n) ++ "]")) - - shsz1name <- genName' "shszN" - emit $ SVarDecl True (repSTy tIx) shsz1name (compileArrShapeSize n storesname) -- take init of the shape - shsz2name <- genName' "shszSN" - emit $ SVarDecl True (repSTy tIx) shsz2name (CEBinop (CELit shsz1name) "*" (CELit lenname)) - - x0ctgname <- allocArray "foldd2" Malloc "foldd2x0ctg" n t2 (Just (CELit shsz1name)) (compileArrShapeComponents n storesname) - outctgname <- allocArray "foldd2" Malloc "foldd2outctg" (SS n) t2 (Just (CELit shsz2name)) (compileArrShapeComponents (SS n) storesname) - - ivar <- genName' "i" - jvar <- genName' "j" - - accvar <- genName' "acc" - let eltidx = lenname ++ " * " ++ ivar ++ " + " ++ lenname ++ "-1 - " ++ jvar - storeseltlit = storesname ++ ".buf->xs[" ++ eltidx ++ "]" - ctgeltlit = ctgname ++ ".buf->xs[" ++ ivar ++ "]" - (funres, funStmts) <- scope $ compile' (Const accvar `SCons` Const storeseltlit `SCons` env) efun - funresvar <- genName' "res" - ((), storeseltIncrStmts) <- scope $ incrementVarAlways "foldd2selt" Increment bty storeseltlit - ((), ctgeltIncrStmts) <- scope $ incrementVarAlways "foldd2celt" Increment bty ctgeltlit - - emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsz1name) $ - pure (SVarDecl False (repSTy t2) accvar (CELit ctgeltlit)) - <> ctgeltIncrStmts - -- we need to loop in reverse here, but we let jvar run in the - -- forward direction so that we can use SLoop. Note jvar is - -- reversed in eltidx above - <> (pure $ SLoop (repSTy tIx) jvar (CELit "0") (CELit lenname) $ - -- The combination function will consume the accumulator - -- and the stores element. The accumulator is replaced by - -- what comes out of the function anyway, so that's - -- fine, but we do need to increment the stores element. - storeseltIncrStmts - <> funStmts - <> pure (SVarDecl True (repSTy (typeOf efun)) funresvar funres) - <> pure (SAsg accvar (CEProj (CELit funresvar) "a")) - <> pure (SAsg (outctgname ++ ".buf->xs[" ++ eltidx ++ "]") (CEProj (CELit funresvar) "b"))) - <> pure (SAsg (x0ctgname ++ ".buf->xs[" ++ ivar ++ "]") (CELit accvar)) - - incrementVarAlways "foldd2stores" Decrement (STArr (SS n) bty) storesname - incrementVarAlways "foldd2ctg" Decrement (STArr n t2) ctgname - - strname <- emitStruct (STPair (STArr n t2) (STArr (SS n) t2)) - return (CEStruct strname [("a", CELit x0ctgname), ("b", CELit outctgname)]) - - EConst _ t x -> return $ CELit $ compileScal True t x - - EIdx0 _ e -> do - let STArr _ t = typeOf e - arrname <- compileAssign "" env e - zeroRefcountCheck (typeOf e) "idx0" arrname - name <- genName - emit $ SVarDecl True (repSTy t) name - (CEIndex (CEPtrProj (CEProj (CELit arrname) "buf") "xs") (CELit "0")) - incrementVarAlways "idx0" Decrement (STArr SZ t) arrname - return (CELit name) - - -- EIdx1 _ a b -> error "TODO" -- EIdx1 ext (compile' a) (compile' b) - - EIdx _ earr eidx -> do - let STArr n t = typeOf earr - arrname <- compileAssign "ixarr" env earr - zeroRefcountCheck (typeOf earr) "idx" arrname - idxname <- if fromSNat n > 0 -- prevent an unused-varable warning - then compileAssign "ixix" env eidx - else return "" -- won't be used in this case - - when emitChecks $ - forM_ (zip [0::Int ..] (indexTupleComponents n idxname)) $ \(i, ixcomp) -> - emit $ SIf (CEBinop (CEBinop ixcomp "<" (CELit "0")) "||" - (CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (arrname ++ ".sh[" ++ show i ++ "]"))))) - (pure $ SVerbatim $ - "fprintf(stderr, PRTAG \"CHECK: index out of range (arr=%p)\\n\", " ++ - arrname ++ ".buf); return false;") - mempty - - resname <- genName' "ixres" - emit $ SVarDecl True (repSTy t) resname (CEIndex (CELit (arrname ++ ".buf->xs")) (toLinearIdx n arrname idxname)) - incrementVarAlways "idxelt" Increment t resname - incrementVarAlways "idx" Decrement (STArr n t) arrname - return (CELit resname) - - EShape _ e -> do - let STArr n _ = typeOf e - t = tTup (sreplicate n tIx) - _ <- emitStruct t - name <- compileAssign "" env e - zeroRefcountCheck (typeOf e) "shape" name - resname <- genName - emit $ SVarDecl True (repSTy t) resname (compileShapeQuery n name) - incrementVarAlways "shape" Decrement (typeOf e) name - return (CELit resname) - - EOp _ op (EPair _ e1 e2) -> do - e1' <- compile' env e1 - e2' <- compile' env e2 - compileOpPair op e1' e2' - - EOp _ op e -> do - e' <- compile' env e - compileOpGeneral op e' - - ECustom _ _ _ _ earg _ _ e1 e2 -> do - name1 <- compileAssign "" env e1 - name2 <- compileAssign "" env e2 - case (incrementVar "custom1" Decrement (typeOf e1), incrementVar "custom2" Decrement (typeOf e2)) of - (Nothing, Nothing) -> compile' (Const name2 `SCons` Const name1 `SCons` SNil) earg - (mfun1, mfun2) -> do - name <- compileAssign "" (Const name2 `SCons` Const name1 `SCons` SNil) earg - maybe (return ()) ($ name1) mfun1 - maybe (return ()) ($ name2) mfun2 - return (CELit name) - - ERecompute _ e -> compile' env e - - EWith _ t e1 e2 -> do - actyname <- emitStruct (STAccum t) - name1 <- compileAssign "" env e1 - - zeroRefcountCheck (typeOf e1) "with" name1 - - emit $ SVerbatim $ "// copyForWriting start (" ++ name1 ++ ")" - mcopy <- copyForWriting t name1 - accname <- genName' "accum" - emit $ SVarDecl False actyname accname - (CEStruct actyname [("buf", CECall "malloc_instr" [CELit (show (sizeofSTy (fromSMTy t)))])]) - emit $ SAsg (accname++".buf->ac") (maybe (CELit name1) id mcopy) - emit $ SVerbatim $ "// initial accumulator constructed (" ++ name1 ++ ")." - - e2' <- compile' (Const accname `SCons` env) e2 - - resname <- genName' "acret" - emit $ SVarDecl True (repSTy (fromSMTy t)) resname (CELit (accname++".buf->ac")) - emit $ SVerbatim $ "free_instr(" ++ accname ++ ".buf);" - - rettyname <- emitStruct (STPair (typeOf e2) (fromSMTy t)) - return $ CEStruct rettyname [("a", e2'), ("b", CELit resname)] - - EAccum _ t prj eidx sparsity eval eacc | Just Refl <- isDense (acPrjTy prj t) sparsity -> do - let -- Add a value (s) into an existing accumulation value (d). If a sparse - -- component of d is encountered, s is copied there. - add :: SMTy a -> String -> String -> CompM () - add SMTNil _ _ = return () - add (SMTPair t1 t2) d s = do - add t1 (d++".a") (s++".a") - add t2 (d++".b") (s++".b") - add (SMTLEither t1 t2) d s = do - ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTLEither t1 t2)) s - ((), stmts1) <- scope $ add t1 (d++".l") (s++".l") - ((), stmts2) <- scope $ add t2 (d++".r") (s++".r") - emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) - (pure (SAsg d (CELit s)) - <> srcIncrStmts) - ((if emitChecks - then pure (SIf (CEBinop (CEBinop (CELit (s++".tag")) "!=" (CELit "0")) - "&&" - (CEBinop (CELit (s++".tag")) "!=" (CELit (d++".tag")))) - (pure $ SVerbatim $ - "fprintf(stderr, PRTAG \"CHECK: accum add leither with different tags " ++ - "(dest %d, src %d)\\n\", (int)" ++ d ++ ".tag, (int)" ++ s ++ ".tag); " ++ - "return false;") - mempty) - else mempty) - -- note: s may have tag 0 - <> pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) - stmts1 - (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "2")) - stmts2 mempty)))) - add (SMTMaybe t1) d s = do - ((), srcIncrStmts) <- scope $ incrementVarAlways "accumadd" Increment (fromSMTy (SMTMaybe t1)) s - ((), stmts1) <- scope $ add t1 (d++".j") (s++".j") - emit $ SIf (CEBinop (CELit (d++".tag")) "==" (CELit "0")) - (pure (SAsg d (CELit s)) - <> srcIncrStmts) - (pure (SIf (CEBinop (CELit (s++".tag")) "==" (CELit "1")) stmts1 mempty)) - add (SMTArr n t1) d s = do - when emitChecks $ do - let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" - forM_ [0 .. fromSNat n - 1] $ \j -> do - emit $ SIf (CEBinop (CELit (s ++ ".sh[" ++ show j ++ "]")) - "!=" - (CELit (d ++ ".sh[" ++ show j ++ "]"))) - (pure $ SVerbatim $ - "fprintf(stderr, PRTAG \"CHECK: accum add incorrect (d=%p, " ++ - "dsh=" ++ shfmt ++ ", s=%p, ssh=" ++ shfmt ++ ")\\n\", " ++ - d ++ ".buf" ++ - concat [", " ++ d ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ - ", " ++ s ++ ".buf" ++ - concat [", " ++ s ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ - "); " ++ - "return false;") - mempty - - shsizename <- genName' "acshsz" - emit $ SVarDecl True (repSTy tIx) shsizename (compileArrShapeSize n s) - ivar <- genName' "i" - ((), stmts1) <- scope $ add t1 (d++".buf->xs["++ivar++"]") (s++".buf->xs["++ivar++"]") - emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shsizename) - stmts1 - add (SMTScal _) d s = emit $ SVerbatim $ d ++ " += " ++ s ++ ";" - - let -- | Dereference an accumulation value and add a given value to that - -- position. Sparse components encountered along the way are - -- initialised before proceeding downwards. - -- accumRef (type) (projection) (accumulation component) (AcIdx variable) (value to accumulate there) - accumRef :: SMTy a -> SAcPrj p a b -> String -> String -> String -> CompM () - accumRef _ SAPHere v _ addend = add (acPrjTy prj t) v addend - - accumRef (SMTPair ta _) (SAPFst prj') v i addend = accumRef ta prj' (v++".a") i addend - accumRef (SMTPair _ tb) (SAPSnd prj') v i addend = accumRef tb prj' (v++".b") i addend - - accumRef (SMTLEither ta _) (SAPLeft prj') v i addend = do - when emitChecks $ do - emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1")) - (pure $ SVerbatim $ - "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +left)\\n\", " ++ v ++ ".tag); " ++ - "return false;") - mempty - accumRef ta prj' (v++".l") i addend - accumRef (SMTLEither _ tb) (SAPRight prj') v i addend = do - when emitChecks $ do - emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "2")) - (pure $ SVerbatim $ - "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (leither tag=%d, +right)\\n\", " ++ v ++ ".tag); " ++ - "return false;") - mempty - accumRef tb prj' (v++".r") i addend - - accumRef (SMTMaybe tj) (SAPJust prj') v i addend = do - when emitChecks $ do - emit $ SIf (CEBinop (CELit (v++".tag")) "!=" (CELit "1")) - (pure $ SVerbatim $ - "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (maybe tag=%d, +just)\\n\", " ++ v ++ ".tag); " ++ - "return false;") - mempty - accumRef tj prj' (v++".j") i addend - - accumRef (SMTArr n t') (SAPArrIdx prj') v i addend = do - when emitChecks $ do - let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" - forM_ (zip [0::Int ..] - (indexTupleComponents n (i++".a"))) $ \(j, ixcomp) -> do - let a .||. b = CEBinop a "||" b - emit $ SIf (CEBinop ixcomp "<" (CELit "0") - .||. - CEBinop ixcomp ">=" (CECast (repSTy tIx) (CELit (v ++ ".sh[" ++ show j ++ "]")))) - (pure $ SVerbatim $ - "fprintf(stderr, PRTAG \"CHECK: accum prj incorrect (arr=%p, " ++ - "arrsh=" ++ shfmt ++ ", acix=" ++ shfmt ++ ", acsh=(D))\\n\", " ++ - v ++ ".buf" ++ - concat [", " ++ v ++ ".sh[" ++ show j' ++ "]" | j' <- [0 .. fromSNat n - 1]] ++ - concat [", " ++ printCExpr 2 comp "" | comp <- indexTupleComponents n (i++".a")] ++ - "); " ++ - "return false;") - mempty - - accumRef t' prj' (v++".buf->xs[" ++ printCExpr 0 (toLinearIdx n v (i++".a")) "]") (i++".b") addend - - nameidx <- compileAssign "acidx" env eidx - nameval <- compileAssign "acval" env eval - nameacc <- compileAssign "acac" env eacc - - emit $ SVerbatim $ "// compile EAccum start (" ++ show prj ++ ")" - accumRef t prj (nameacc++".buf->ac") nameidx nameval - emit $ SVerbatim $ "// compile EAccum end" - - incrementVarAlways "accumendsrc" Decrement (typeOf eval) nameval - - return $ CEStruct (repSTy STNil) [] - - EAccum{} -> - error "Compile: EAccum with non-trivial sparsity should have been eliminated (use AST.UnMonoid)" - - EError _ t s -> do - let padleft len c s' = replicate (len - length s) c ++ s' - escape = concatMap $ \c -> if | c `elem` "\"\\" -> ['\\',c] - | ord c < 32 -> "\\x" ++ padleft 2 '0' (showHex (ord c) "") - | otherwise -> [c] - emit $ SVerbatim $ "fputs(\"ERROR: " ++ escape s ++ "\\n\", stderr); return false;" - case t of - STScal _ -> return (CELit "0") - _ -> do - name <- emitStruct t - return $ CEStruct name [] - - EZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" - EDeepZero{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" - EPlus{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" - EOneHot{} -> error "Compile: monoid operations should have been eliminated (use AST.UnMonoid)" - - EIdx1{} -> error "Compile: not implemented: EIdx1" - -compileAssign :: String -> SList (Const String) env -> Ex env t -> CompM String -compileAssign prefix env e = do - e' <- compile' env e - case e' of - CELit name -> return name - _ -> do - name <- genName' prefix - emit $ SVarDecl True (repSTy (typeOf e)) name e' - return name - -data Increment = Increment | Decrement - deriving (Show) - --- | Increment reference counts in the components of the given variable. -incrementVar :: String -> Increment -> STy a -> Maybe (String -> CompM ()) -incrementVar marker inc ty = - let tree = makeArrayTree ty - in case tree of ATNoop -> Nothing - _ -> Just $ \var -> incrementVar' marker inc var tree - -incrementVarAlways :: String -> Increment -> STy a -> String -> CompM () -incrementVarAlways marker inc ty var = maybe (pure ()) ($ var) (incrementVar marker inc ty) - -data ArrayTree = ATArray (Some SNat) (Some STy) -- ^ we've arrived at an array we need to decrement the refcount of (contains rank and element type of the array) - | ATNoop -- ^ don't do anything here - | ATProj String ArrayTree -- ^ descend one field deeper - | ATCondTag ArrayTree ArrayTree -- ^ if tag is 0, first; if 1, second - | ATCond3Tag ArrayTree ArrayTree ArrayTree -- ^ if tag is: 0, 1, 2 - | ATBoth ArrayTree ArrayTree -- ^ do both these paths - -smartATProj :: String -> ArrayTree -> ArrayTree -smartATProj _ ATNoop = ATNoop -smartATProj field t = ATProj field t - -smartATCondTag :: ArrayTree -> ArrayTree -> ArrayTree -smartATCondTag ATNoop ATNoop = ATNoop -smartATCondTag t t' = ATCondTag t t' - -smartATCond3Tag :: ArrayTree -> ArrayTree -> ArrayTree -> ArrayTree -smartATCond3Tag ATNoop ATNoop ATNoop = ATNoop -smartATCond3Tag t1 t2 t3 = ATCond3Tag t1 t2 t3 - -smartATBoth :: ArrayTree -> ArrayTree -> ArrayTree -smartATBoth ATNoop t = t -smartATBoth t ATNoop = t -smartATBoth t t' = ATBoth t t' - -makeArrayTree :: STy a -> ArrayTree -makeArrayTree STNil = ATNoop -makeArrayTree (STPair a b) = smartATBoth (smartATProj "a" (makeArrayTree a)) - (smartATProj "b" (makeArrayTree b)) -makeArrayTree (STEither a b) = smartATCondTag (smartATProj "l" (makeArrayTree a)) - (smartATProj "r" (makeArrayTree b)) -makeArrayTree (STLEither a b) = smartATCond3Tag ATNoop - (smartATProj "l" (makeArrayTree a)) - (smartATProj "r" (makeArrayTree b)) -makeArrayTree (STMaybe t) = smartATCondTag ATNoop (smartATProj "j" (makeArrayTree t)) -makeArrayTree (STArr n t) = ATArray (Some n) (Some t) -makeArrayTree (STScal _) = ATNoop -makeArrayTree (STAccum _) = ATNoop - -incrementVar' :: String -> Increment -> String -> ArrayTree -> CompM () -incrementVar' marker inc path (ATArray (Some n) (Some eltty)) = - case inc of - Increment -> do - emit $ SVerbatim (path ++ ".buf->refc++;") - when debugRefc $ - emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p in+ -> %zu <" ++ marker ++ ">\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc);" - Decrement -> do - case incrementVar (marker++".elt") Decrement eltty of - Nothing -> - if debugRefc - then do - emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p de- -> %zu <" ++ marker ++ ">\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);" - emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) { fprintf(stderr, \"; free(\"); free_instr(" ++ path ++ ".buf); fprintf(stderr, \") ok\\n\"); } else fprintf(stderr, \"\\n\");" - else do - emit $ SVerbatim $ "if (--" ++ path ++ ".buf->refc == 0) free_instr(" ++ path ++ ".buf);" - Just f -> do - when debugRefc $ - emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p de- -> %zu <" ++ marker ++ "> recfree\\n\", " ++ path ++ ".buf, " ++ path ++ ".buf->refc - 1);" - shszvar <- genName' "frshsz" - ivar <- genName' "i" - ((), eltDecrStmts) <- scope $ f (path ++ ".buf->xs[" ++ ivar ++ "]") - emit $ SIf (CELit ("--" ++ path ++ ".buf->refc == 0")) - (BList [SVarDecl True "size_t" shszvar (compileArrShapeSize n path) - ,SLoop "size_t" ivar (CELit "0") (CELit shszvar) $ - eltDecrStmts - ,SVerbatim $ "free_instr(" ++ path ++ ".buf);"]) - mempty -incrementVar' _ _ _ ATNoop = pure () -incrementVar' marker inc path (ATProj field t) = incrementVar' (marker++"."++field) inc (path ++ "." ++ field) t -incrementVar' marker inc path (ATCondTag t1 t2) = do - ((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1 - ((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2 - emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "0")) stmts1 stmts2 -incrementVar' marker inc path (ATCond3Tag t1 t2 t3) = do - ((), stmts1) <- scope $ incrementVar' (marker++".t1") inc path t1 - ((), stmts2) <- scope $ incrementVar' (marker++".t2") inc path t2 - ((), stmts3) <- scope $ incrementVar' (marker++".t3") inc path t3 - emit $ SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "1")) - stmts2 - (pure (SIf (CEBinop (CELit (path ++ ".tag")) "==" (CELit "2")) - stmts3 - stmts1)) -incrementVar' marker inc path (ATBoth t1 t2) = incrementVar' (marker++".1") inc path t1 >> incrementVar' (marker++".2") inc path t2 - -toLinearIdx :: SNat n -> String -> String -> CExpr -toLinearIdx SZ _ _ = CELit "0" -toLinearIdx (SS SZ) _ idxvar = CELit (idxvar ++ ".b") -toLinearIdx (SS n) arrvar idxvar = - CEBinop (CEBinop (toLinearIdx n arrvar (idxvar ++ ".a")) - "*" (CEIndex (CELit (arrvar ++ ".sh")) (CELit (show (fromSNat n))))) - "+" (CELit (idxvar ++ ".b")) - --- fromLinearIdx :: SNat n -> String -> String -> CompM CExpr --- fromLinearIdx SZ _ _ = return $ CEStruct (repSTy STNil) [] --- fromLinearIdx (SS n) arrvar idxvar = do --- name <- genName --- emit $ SVarDecl True (repSTy tIx) name (CEBinop (CELit idxvar) "/" (CELit (arrvar ++ ".sh[" ++ show (fromSNat n) ++ "]"))) --- _ - -data AllocMethod = Malloc | Calloc - deriving (Show) - --- | The shape must have the outer dimension at the head (and the inner dimension on the right). -allocArray :: HasCallStack => String -> AllocMethod -> String -> SNat n -> STy t -> Maybe CExpr -> [CExpr] -> CompM String -allocArray marker method nameBase rank eltty mshsz shape = do - when (length shape /= fromSNat rank) $ - error "allocArray: shape does not match rank" - let arrty = STArr rank eltty - strname <- emitStruct arrty - arrname <- genName' nameBase - shsz <- case mshsz of - Just e -> return e - Nothing -> return (foldl0' (\a b -> CEBinop a "*" b) (CELit "1") shape) - let nbytesExpr = CEBinop (CELit (show (fromSNat rank * 8 + 8))) - "+" - (CEBinop shsz "*" (CELit (show (sizeofSTy eltty)))) - emit $ SVarDecl True strname arrname $ CEStruct strname - [("buf", case method of Malloc -> CECall "malloc_instr" [nbytesExpr] - Calloc -> CECall "calloc_instr" [nbytesExpr]) - ,("sh", CELit ("{" ++ intercalate "," [printCExpr 0 dim "" | dim <- shape] ++ "}"))] - emit $ SAsg (arrname ++ ".buf->refc") (CELit "1") - when debugRefc $ - emit $ SVerbatim $ "fprintf(stderr, PRTAG \"arr %p allocated <" ++ marker ++ ">\\n\", " ++ arrname ++ ".buf);" - return arrname - -compileShapeQuery :: SNat n -> String -> CExpr -compileShapeQuery SZ _ = CEStruct (repSTy STNil) [] -compileShapeQuery (SS n) var = - CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) - [("a", compileShapeQuery n var) - ,("b", CEIndex (CELit (var ++ ".sh")) (CELit (show (fromSNat n))))] - --- | Takes a variable name for the array, not the buffer. -compileArrShapeSize :: SNat n -> String -> CExpr -compileArrShapeSize n var = prodExpr (compileArrShapeComponents n var) - --- | Takes a variable name for the array, not the buffer. -compileArrShapeComponents :: SNat n -> String -> [CExpr] -compileArrShapeComponents n var = - [CELit (var ++ ".sh[" ++ show i ++ "]") | i <- [0 .. fromSNat n - 1]] - -indexTupleComponents :: SNat n -> String -> [CExpr] -indexTupleComponents = \n var -> map CELit (toList (go n var)) - where - go :: SNat n -> String -> Bag String - go SZ _ = mempty - go (SS n) var = go n (var ++ ".a") <> pure (var ++ ".b") - --- | Takes variable names with the innermost dimension on the right. -shapeTupFromLitVars :: SNat n -> [String] -> CExpr -shapeTupFromLitVars = \n -> go n . reverse - where - -- takes variables with the innermost dimension at the _head_ - go :: SNat n -> [String] -> CExpr - go SZ [] = CEStruct (repSTy STNil) [] - go (SS n) (var : vars) = CEStruct (repSTy (tTup (sreplicate (SS n) tIx))) [("a", go n vars), ("b", CELit var)] - go _ _ = error "shapeTupFromLitVars: SNat and list do not correspond" - -prodExpr :: [CExpr] -> CExpr -prodExpr = foldl0' (\a b -> CEBinop a "*" b) (CELit "1") - -compileOpGeneral :: SOp a b -> CExpr -> CompM CExpr -compileOpGeneral op e1 = do - let unary cop = return @CompM $ CECall cop [e1] - let binary cop = do - name <- genName - emit $ SVarDecl True (repSTy (opt1 op)) name e1 - return $ CEBinop (CEProj (CELit name) "a") cop (CEProj (CELit name) "b") - case op of - OAdd _ -> binary "+" - OMul _ -> binary "*" - ONeg _ -> unary "-" - OLt _ -> binary "<" - OLe _ -> binary "<=" - OEq _ -> binary "==" - ONot -> unary "!" - OAnd -> binary "&&" - OOr -> binary "||" - OIf -> do - name <- emitStruct (STEither STNil STNil) - _ <- emitStruct STNil - return $ CEIf e1 (CEStruct name [("tag", CELit "0")]) - (CEStruct name [("tag", CELit "1")]) - ORound64 -> unary "(int64_t)round" -- ew - OToFl64 -> unary "(double)" - ORecip _ -> return $ CEBinop (CELit "1.0") "/" e1 - OExp STF32 -> unary "expf" - OExp STF64 -> unary "exp" - OLog STF32 -> unary "logf" - OLog STF64 -> unary "log" - OIDiv _ -> binary "/" - OMod _ -> binary "%" - -compileOpPair :: SOp a b -> CExpr -> CExpr -> CompM CExpr -compileOpPair op e1 e2 = do - let binary cop = return @CompM $ CEBinop e1 cop e2 - case op of - OAdd _ -> binary "+" - OMul _ -> binary "*" - OLt _ -> binary "<" - OLe _ -> binary "<=" - OEq _ -> binary "==" - OAnd -> binary "&&" - OOr -> binary "||" - OIDiv _ -> binary "/" - OMod _ -> binary "%" - _ -> error "compileOpPair: got unary operator" - --- | Bool: whether to ensure that the literal itself already has the appropriate type -compileScal :: Bool -> SScalTy t -> ScalRep t -> String -compileScal pedantic typ x = case typ of - STI32 -> (if pedantic then "(int32_t)" else "") ++ show x - STI64 -> (if pedantic then "(int64_t)" else "") ++ show x - STF32 -> show x ++ "f" - STF64 -> show x - STBool -> if x then "1" else "0" - -compileExtremum :: String -> String -> String -> SList (Const String) env -> Ex env (TArr (S n) t) -> CompM CExpr -compileExtremum nameBase opName operator env e = do - let STArr (SS n) t = typeOf e - argname <- compileAssign (nameBase ++ "arg") env e - - zeroRefcountCheck (typeOf e) opName argname - - shszname <- genName' "shsz" - -- This n is one less than the shape of the thing we're querying, which is - -- unexpected. But it's exactly what we want, so we do it anyway. - emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n argname) - - resname <- allocArray nameBase Malloc (nameBase ++ "res") n t (Just (CELit shszname)) (compileArrShapeComponents n argname) - - lenname <- genName' "n" - emit $ SVarDecl True (repSTy tIx) lenname - (CELit (argname ++ ".sh[" ++ show (fromSNat n) ++ "]")) - - emit $ SVerbatim $ "if (" ++ lenname ++ " == 0) { fprintf(stderr, \"Empty array in " ++ opName ++ "\\n\"); return false; }" - - ivar <- genName' "i" - jvar <- genName' "j" - xvar <- genName - redvar <- genName' "red" -- use "red", not "acc", to avoid confusion with accumulators - emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ BList - -- we have ScalIsNumeric, so it has 1 and (<) etc. in C - [SVarDecl False (repSTy t) redvar (CELit (argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ "]")) - ,SLoop (repSTy tIx) jvar (CELit "1") (CELit lenname) $ BList - [SVarDecl True (repSTy t) xvar (CELit (argname ++ ".buf->xs[" ++ lenname ++ " * " ++ ivar ++ " + " ++ jvar ++ "]")) - ,SAsg redvar $ CEIf (CEBinop (CELit xvar) operator (CELit redvar)) (CELit xvar) (CELit redvar) - ] - ,SAsg (resname ++ ".buf->xs[" ++ ivar ++ "]") (CELit redvar)] - - incrementVarAlways nameBase Decrement (typeOf e) argname - - return (CELit resname) - --- | If this returns Nothing, there was nothing to copy because making a simple --- value copy in C already makes it suitable to write to. -copyForWriting :: SMTy t -> String -> CompM (Maybe CExpr) -copyForWriting topty var = case topty of - SMTNil -> return Nothing - - SMTPair a b -> do - e1 <- copyForWriting a (var ++ ".a") - e2 <- copyForWriting b (var ++ ".b") - case (e1, e2) of - (Nothing, Nothing) -> return Nothing - _ -> return $ Just $ CEStruct toptyname - [("a", fromMaybe (CELit (var++".a")) e1) - ,("b", fromMaybe (CELit (var++".b")) e2)] - - SMTLEither a b -> do - (e1, stmts1) <- scope $ copyForWriting a (var ++ ".l") - (e2, stmts2) <- scope $ copyForWriting b (var ++ ".r") - case (e1, e2) of - (Nothing, Nothing) -> return Nothing - _ -> do - name <- genName - emit $ SVarDeclUninit toptyname name - emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0")) - (stmts1 - <> pure (SAsg name (CEStruct toptyname - [("tag", CELit "0"), ("l", fromMaybe (CELit (var++".l")) e1)]))) - (stmts2 - <> pure (SAsg name (CEStruct toptyname - [("tag", CELit "1"), ("r", fromMaybe (CELit (var++".r")) e2)]))) - return (Just (CELit name)) - - SMTMaybe t -> do - (e1, stmts1) <- scope $ copyForWriting t (var ++ ".j") - case e1 of - Nothing -> return Nothing - Just e1' -> do - name <- genName - emit $ SVarDeclUninit toptyname name - emit $ SIf (CEBinop (CELit (var++".tag")) "==" (CELit "0")) - (pure (SAsg name (CEStruct toptyname [("tag", CELit "0")]))) - (stmts1 - <> pure (SAsg name (CEStruct toptyname [("tag", CELit "1"), ("j", e1')]))) - return (Just (CELit name)) - - -- If there are no nested arrays, we know that a refcount of 1 means that the - -- whole thing is owned. Nested arrays have their own refcount, so with - -- nesting we'd have to check the refcounts of all the nested arrays _too_; - -- let's not do that. Furthermore, no sub-arrays means that the whole thing - -- is flat, and we can just memcpy if necessary. - SMTArr n t | not (typeHasArrays (fromSMTy t)) -> do - name <- genName - shszname <- genName' "shsz" - emit $ SVarDeclUninit toptyname name - - when debugShapes $ do - let shfmt = "[" ++ intercalate "," (replicate (fromSNat n) "%\"PRIi64\"") ++ "]" - emit $ SVerbatim $ - "fprintf(stderr, PRTAG \"with array " ++ shfmt ++ "\\n\"" ++ - concat [", " ++ var ++ ".sh[" ++ show i ++ "]" | i <- [0 .. fromSNat n - 1]] ++ - ");" - - emit $ SIf (CEBinop (CELit (var ++ ".buf->refc")) "==" (CELit "1")) - (pure (SAsg name (CELit var))) - (let shbytes = fromSNat n * 8 - databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t)))) - totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes - in BList - [SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var) - ,SAsg name (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])]) - ,SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");" - ,SAsg (name ++ ".buf->refc") (CELit "1") - ,SVerbatim $ "memcpy(" ++ name ++ ".buf->xs, " ++ var ++ ".buf->xs, " ++ - printCExpr 0 databytes ");"]) - return (Just (CELit name)) - - SMTArr n t -> do - shszname <- genName' "shsz" - emit $ SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n var) - - let shbytes = fromSNat n * 8 - databytes = CEBinop (CELit shszname) "*" (CELit (show (sizeofSTy (fromSMTy t)))) - totalbytes = CEBinop (CELit (show (shbytes + 8))) "+" databytes - - name <- genName - emit $ SVarDecl False toptyname name - (CEStruct toptyname [("buf", CECall "malloc_instr" [totalbytes])]) - emit $ SVerbatim $ "memcpy(" ++ name ++ ".sh, " ++ var ++ ".sh, " ++ show shbytes ++ ");" - emit $ SAsg (name ++ ".buf->refc") (CELit "1") - - -- put the arrays in variables to cut short the not-quite-var chain - dstvar <- genName' "cpydst" - emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") dstvar (CELit (name ++ ".buf->xs")) - srcvar <- genName' "cpysrc" - emit $ SVarDecl True (repSTy (fromSMTy t) ++ " *") srcvar (CELit (var ++ ".buf->xs")) - - ivar <- genName' "i" - - (cpye, cpystmts) <- scope $ copyForWriting t (srcvar ++ "[" ++ ivar ++ "]") - let cpye' = case cpye of - Just e -> e - Nothing -> error "copyForWriting: arrays cannot be copied as-is, bug" - - emit $ SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) $ - cpystmts - <> pure (SAsg (dstvar ++ "[" ++ ivar ++ "]") cpye') - - return (Just (CELit name)) - - SMTScal _ -> return Nothing - - where - toptyname = repSTy (fromSMTy topty) - -zeroRefcountCheck :: STy t -> String -> String -> CompM () -zeroRefcountCheck toptyp opname topvar = - when emitChecks $ do - mstmts <- onlyIdGen $ runMaybeT (go toptyp topvar) - case mstmts of - Nothing -> return () - Just stmts -> forM_ stmts emit - where - -- | If this returns 'Nothing', no statements need to be generated for this type. - go :: STy t -> String -> MaybeT IdGen.IdGen (Bag Stmt) - go STNil _ = empty - go (STPair a b) path = do - (s1, s2) <- combine (go a (path++".a")) (go b (path++".b")) - return (s1 <> s2) - go (STEither a b) path = do - (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) - return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "0")) s1 s2 - go (STLEither a b) path = do - (s1, s2) <- combine (go a (path++".l")) (go b (path++".r")) - return $ pure $ - SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) - s1 - (pure (SIf (CEBinop (CELit (path++".tag")) "==" (CELit "2")) - s2 - mempty)) - go (STMaybe a) path = do - ss <- go a (path++".j") - return $ pure $ SIf (CEBinop (CELit (path++".tag")) "==" (CELit "1")) ss mempty - go (STArr n a) path = do - ivar <- genName' "i" - ss <- go a (path++".buf->xs["++ivar++"]") - shszname <- genName' "shsz" - let s1 = SVerbatim $ - "if (__builtin_expect(" ++ path ++ ".buf->refc == 0, 0)) { " ++ - "fprintf(stderr, PRTAG \"CHECK: '" ++ opname ++ "' got array " ++ - "%p with refc=0\\n\", " ++ path ++ ".buf); return false; }" - let s2 = SVarDecl True (repSTy tIx) shszname (compileArrShapeSize n path) - let s3 = SLoop (repSTy tIx) ivar (CELit "0") (CELit shszname) ss - return (BList [s1, s2, s3]) - go STScal{} _ = empty - go STAccum{} _ = error "zeroRefcountCheck: passed an accumulator" - - combine :: (Monoid a, Monoid b, Monad m) => MaybeT m a -> MaybeT m b -> MaybeT m (a, b) - combine (MaybeT a) (MaybeT b) = MaybeT $ do - x <- a - y <- b - return $ case (x, y) of - (Nothing, Nothing) -> Nothing - (Just x', Nothing) -> Just (x', mempty) - (Nothing, Just y') -> Just (mempty, y') - (Just x', Just y') -> Just (x', y') - -{-# NOINLINE uniqueIdGenRef #-} -uniqueIdGenRef :: IORef Int -uniqueIdGenRef = unsafePerformIO $ newIORef 1 - -compose :: Foldable t => t (a -> a) -> a -> a -compose = foldr (.) id - -showPtr :: Ptr a -> String -showPtr (Ptr a) = "0x" ++ showHex (integerFromWord# (int2Word# (addr2Int# a))) "" - --- | Type-restricted. -(^) :: Num a => a -> Int -> a -(^) = (Prelude.^) - -foldl0' :: (a -> a -> a) -> a -> [a] -> a -foldl0' _ x [] = x -foldl0' f _ l = foldl1' f l diff --git a/src/Compile/Exec.hs b/src/Compile/Exec.hs deleted file mode 100644 index ad4180f..0000000 --- a/src/Compile/Exec.hs +++ /dev/null @@ -1,99 +0,0 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE TupleSections #-} -module Compile.Exec ( - KernelLib, - buildKernel, - callKernelFun, - - -- * misc - lineNumbers, -) where - -import Control.Monad (when) -import Data.IORef -import Foreign (Ptr) -import Foreign.Ptr (FunPtr) -import System.Directory (removeDirectoryRecursive) -import System.Environment (lookupEnv) -import System.Exit (ExitCode(..)) -import System.IO (hPutStrLn, stderr) -import System.IO.Error (mkIOError, userErrorType) -import System.IO.Unsafe (unsafePerformIO) -import System.Posix.DynamicLinker -import System.Posix.Temp (mkdtemp) -import System.Process (readProcessWithExitCode) - - -debug :: Bool -debug = False - --- The IORef wrapper is required for the finalizer to attach properly (see the 'Weak' docs) -data KernelLib = KernelLib !(IORef (FunPtr (Ptr () -> IO ()))) - -buildKernel :: String -> String -> IO (KernelLib, String) -buildKernel csource funname = do - template <- (++ "/tmp.chad.") <$> getTempDir - path <- mkdtemp template - - let outso = path ++ "/out.so" - let args = ["-O3", "-march=native" - ,"-shared", "-fPIC" - ,"-std=c99", "-x", "c" - ,"-o", outso, "-" - ,"-Wall", "-Wextra" - ,"-Wno-unused-variable", "-Wno-unused-but-set-variable" - ,"-Wno-unused-parameter", "-Wno-unused-function" - ,"-Wno-alloc-size-larger-than" -- ideally we'd keep this, but gcc reports false positives - ,"-Wno-maybe-uninitialized"] -- maximum1i goes out of range if its input is empty, yes, don't complain - (ec, gccStdout, gccStderr) <- readProcessWithExitCode "gcc" args csource - - -- Print the source before the GCC output. - case ec of - ExitSuccess -> return () - ExitFailure{} -> hPutStrLn stderr $ "[chad] Kernel compilation failed! Source: <<<\n" ++ lineNumbers csource ++ ">>>" - - case ec of - ExitSuccess -> return () - ExitFailure{} -> do - removeDirectoryRecursive path - ioError (mkIOError userErrorType "chad kernel compilation failed" Nothing Nothing) - - numLoaded <- atomicModifyIORef' numLoadedCounter (\n -> (n+1, n+1)) - when debug $ hPutStrLn stderr $ "[chad] loading kernel " ++ path ++ " (" ++ show numLoaded ++ " total)" - dl <- dlopen outso [RTLD_LAZY, RTLD_LOCAL] - - removeDirectoryRecursive path -- we keep a reference anyway because we have the file open now - - ref <- newIORef =<< dlsym dl funname - _ <- mkWeakIORef ref (do numLeft <- atomicModifyIORef' numLoadedCounter (\n -> (n-1, n-1)) - when debug $ hPutStrLn stderr $ "[chad] unloading kernel " ++ path ++ " (" ++ show numLeft ++ " left)" - dlclose dl) - return (KernelLib ref, gccStdout ++ (if null gccStdout then "" else "\n") ++ gccStderr) - -foreign import ccall "dynamic" - wrapKernelFun :: FunPtr (Ptr () -> IO ()) -> Ptr () -> IO () - --- Ensure that keeping a reference to the returned function also keeps the 'KernelLib' alive -{-# NOINLINE callKernelFun #-} -callKernelFun :: KernelLib -> Ptr () -> IO () -callKernelFun (KernelLib ref) arg = do - ptr <- readIORef ref - wrapKernelFun ptr arg - -getTempDir :: IO FilePath -getTempDir = - lookupEnv "TMPDIR" >>= \case - Just s | not (null s) -> return s - _ -> return "/tmp" - -{-# NOINLINE numLoadedCounter #-} -numLoadedCounter :: IORef Int -numLoadedCounter = unsafePerformIO $ newIORef 0 - -lineNumbers :: String -> String -lineNumbers str = - let lns = lines str - numlines = length lns - width = length (show numlines) - pad s = replicate (width - length s) ' ' ++ s - in unlines (zipWith (\i ln -> pad (show i) ++ " | " ++ ln) [1::Int ..] lns) diff --git a/src/Data.hs b/src/Data.hs deleted file mode 100644 index e6978c8..0000000 --- a/src/Data.hs +++ /dev/null @@ -1,192 +0,0 @@ -{-# LANGUAGE AllowAmbiguousTypes #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveTraversable #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE QuantifiedConstraints #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -module Data (module Data, (:~:)(Refl), If) where - -import Data.Functor.Product -import Data.GADT.Compare -import Data.GADT.Show -import Data.Some -import Data.Type.Bool (If) -import Data.Type.Equality -import Unsafe.Coerce (unsafeCoerce) - -import Lemmas (Append) - - -data Dict c where - Dict :: c => Dict c - - -data SList f l where - SNil :: SList f '[] - SCons :: f a -> SList f l -> SList f (a : l) -deriving instance (forall a. Show (f a)) => Show (SList f l) -infixr `SCons` - -slistMap :: (forall t. f t -> g t) -> SList f list -> SList g list -slistMap _ SNil = SNil -slistMap f (SCons x list) = SCons (f x) (slistMap f list) - -slistMapA :: Applicative m => (forall t. f t -> m (g t)) -> SList f list -> m (SList g list) -slistMapA _ SNil = pure SNil -slistMapA f (SCons x list) = SCons <$> f x <*> slistMapA f list - -slistZip :: SList f list -> SList g list -> SList (Product f g) list -slistZip SNil SNil = SNil -slistZip (x `SCons` l1) (y `SCons` l2) = Pair x y `SCons` slistZip l1 l2 - -unSList :: (forall t. f t -> a) -> SList f list -> [a] -unSList _ SNil = [] -unSList f (x `SCons` l) = f x : unSList f l - -showSList :: (forall t. Int -> f t -> String) -> SList f list -> String -showSList _ SNil = "SNil" -showSList f (x `SCons` l) = f 11 x ++ " `SCons` " ++ showSList f l - -sappend :: SList f l1 -> SList f l2 -> SList f (Append l1 l2) -sappend SNil l = l -sappend (SCons x xs) l = SCons x (sappend xs l) - -type family Replicate n x where - Replicate Z x = '[] - Replicate (S n) x = x : Replicate n x - -sreplicate :: SNat n -> f t -> SList f (Replicate n t) -sreplicate SZ _ = SNil -sreplicate (SS n) x = x `SCons` sreplicate n x - -data Nat = Z | S Nat - deriving (Show, Eq, Ord) - -type N0 = Z -type N1 = S N0 -type N2 = S N1 -type N3 = S N2 - -data SNat n where - SZ :: SNat Z - SS :: SNat n -> SNat (S n) -deriving instance Show (SNat n) - -instance GCompare SNat where - gcompare SZ SZ = GEQ - gcompare SZ _ = GLT - gcompare _ SZ = GGT - gcompare (SS n) (SS n') = gorderingLift1 (gcompare n n') - -instance TestEquality SNat where testEquality = geq -instance GEq SNat where geq = defaultGeq -instance GShow SNat where gshowsPrec = defaultGshowsPrec - -fromSNat :: SNat n -> Int -fromSNat SZ = 0 -fromSNat (SS n) = succ (fromSNat n) - -unSNat :: SNat n -> Nat -unSNat SZ = Z -unSNat (SS n) = S (unSNat n) - -reSNat :: Nat -> Some SNat -reSNat Z = Some SZ -reSNat (S n) | Some n' <- reSNat n = Some (SS n') - -class KnownNat n where knownNat :: SNat n -instance KnownNat Z where knownNat = SZ -instance KnownNat n => KnownNat (S n) where knownNat = SS knownNat - -snatKnown :: SNat n -> Dict (KnownNat n) -snatKnown SZ = Dict -snatKnown (SS n) | Dict <- snatKnown n = Dict - -type family n + m where - Z + m = m - S n + m = S (n + m) - -type family n - m where - n - Z = n - S n - S m = n - m - -snatAdd :: SNat n -> SNat m -> SNat (n + m) -snatAdd SZ m = m -snatAdd (SS n) m = SS (snatAdd n m) - -lemPlusSuccRight :: n + S m :~: S (n + m) -lemPlusSuccRight = unsafeCoerceRefl - -lemPlusZero :: n + Z :~: n -lemPlusZero = unsafeCoerceRefl - -data Vec n t where - VNil :: Vec Z t - (:<) :: t -> Vec n t -> Vec (S n) t -deriving instance Show t => Show (Vec n t) -deriving instance Eq t => Eq (Vec n t) -deriving instance Functor (Vec n) -deriving instance Foldable (Vec n) -deriving instance Traversable (Vec n) - -vecLength :: Vec n t -> SNat n -vecLength VNil = SZ -vecLength (_ :< v) = SS (vecLength v) - -vecGenerate :: SNat n -> (forall i. SNat i -> t) -> Vec n t -vecGenerate = \n f -> go n f SZ - where - go :: SNat n -> (forall i. SNat i -> t) -> SNat i' -> Vec n t - go SZ _ _ = VNil - go (SS n) f i = f i :< go n f (SS i) - -vecReplicateA :: Applicative f => SNat n -> f a -> f (Vec n a) -vecReplicateA SZ _ = pure VNil -vecReplicateA (SS n) gen = (:<) <$> gen <*> vecReplicateA n gen - -vecZipWithA :: Applicative f => (a -> b -> f c) -> Vec n a -> Vec n b -> f (Vec n c) -vecZipWithA _ VNil VNil = pure VNil -vecZipWithA f (x :< xs) (y :< ys) = (:<) <$> f x y <*> vecZipWithA f xs ys - -vecInit :: Vec (S n) a -> Vec n a -vecInit (_ :< VNil) = VNil -vecInit (x :< xs@(_ :< _)) = x :< vecInit xs - -unsafeCoerceRefl :: a :~: b -unsafeCoerceRefl = unsafeCoerce Refl - -gorderingLift1 :: GOrdering a a' -> GOrdering (f a) (f a') -gorderingLift1 GLT = GLT -gorderingLift1 GGT = GGT -gorderingLift1 GEQ = GEQ - -gorderingLift2 :: GOrdering a a' -> GOrdering b b' -> GOrdering (f a b) (f a' b') -gorderingLift2 GLT _ = GLT -gorderingLift2 GGT _ = GGT -gorderingLift2 GEQ GLT = GLT -gorderingLift2 GEQ GGT = GGT -gorderingLift2 GEQ GEQ = GEQ - -data Bag t = BNone | BOne t | BTwo !(Bag t) !(Bag t) | BMany [Bag t] | BList [t] - deriving (Show, Functor, Foldable, Traversable) - --- | This instance is mostly there just for 'pure' -instance Applicative Bag where - pure = BOne - BNone <*> _ = BNone - BOne f <*> b = f <$> b - BTwo b1 b2 <*> b = BTwo (b1 <*> b) (b2 <*> b) - BMany bs <*> b = BMany (map (<*> b) bs) - BList bs <*> b = BMany (map (<$> b) bs) - -instance Semigroup (Bag t) where (<>) = BTwo -instance Monoid (Bag t) where mempty = BNone - -data SBool b where - SF :: SBool False - ST :: SBool True -deriving instance Show (SBool b) diff --git a/src/Data/VarMap.hs b/src/Data/VarMap.hs deleted file mode 100644 index 2712b08..0000000 --- a/src/Data/VarMap.hs +++ /dev/null @@ -1,119 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE RoleAnnotations #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeOperators #-} -module Data.VarMap ( - VarMap, - empty, - insert, - delete, - TypedIdx(..), - lookup, - disjointUnion, - sink1, - unsink1, - subMap, - superMap, -) where - -import Prelude hiding (lookup) - -import qualified Data.Map.Strict as Map -import Data.Map.Strict (Map) -import Data.Maybe (mapMaybe) -import Data.Some -import qualified Data.Vector.Storable as VS -import Unsafe.Coerce - -import AST.Env -import AST.Types -import AST.Weaken - - -type role VarMap _ nominal -- ensure that 'env' is not phantom -data VarMap k (env :: [Ty]) = - VarMap Int -- ^ Global offset; must be added to any value in the map in order to get the proper index - Int -- ^ Time since last cleanup - (Map k (Some STy, Int)) -deriving instance Show k => Show (VarMap k env) - -empty :: VarMap k env -empty = VarMap 0 0 Map.empty - -insert :: Ord k => k -> STy t -> Idx env t -> VarMap k env -> VarMap k env -insert k ty idx (VarMap off interval mp) = - maybeCleanup $ VarMap off (interval + 1) (Map.insert k (Some ty, idx2int idx - off) mp) - -delete :: Ord k => k -> VarMap k env -> VarMap k env -delete k (VarMap off interval mp) = - maybeCleanup $ VarMap off (interval + 1) (Map.delete k mp) - -data TypedIdx env t = TypedIdx (STy t) (Idx env t) - deriving (Show) - -lookup :: Ord k => k -> VarMap k env -> Maybe (Some (TypedIdx env)) -lookup k (VarMap off _ mp) = do - (Some ty, i) <- Map.lookup k mp - idx <- unsafeInt2idx (i + off) - return (Some (TypedIdx ty idx)) - -disjointUnion :: Ord k => VarMap k env -> VarMap k env -> VarMap k env -disjointUnion (VarMap off1 cl1 m1) (VarMap off2 cl2 m2) | off1 == off2 = - VarMap off1 (min cl1 cl2) (Map.unionWith (error "VarMap.disjointUnion: overlapping keys") m1 m2) -disjointUnion vm1 vm2 = disjointUnion (cleanup vm1) (cleanup vm2) - -sink1 :: VarMap k env -> VarMap k (t : env) -sink1 (VarMap off interval mp) = VarMap (off + 1) interval mp - -unsink1 :: VarMap k (t : env) -> VarMap k env -unsink1 (VarMap off interval mp) = VarMap (off - 1) interval mp - -subMap :: Eq k => Subenv env env' -> VarMap k env -> VarMap k env' -subMap subenv = - let bools = let loop :: Subenv env env' -> [Bool] - loop SETop = [] - loop (SEYesR sub) = True : loop sub - loop (SENo sub) = False : loop sub - in VS.fromList $ loop subenv - newIndices = VS.init $ VS.scanl' (\n b -> if b then n + 1 else n) (0 :: Int) bools - modify off (k, (ty, i)) - | i + off < 0 = Nothing - | i + off >= VS.length bools = error "VarMap.subMap: found negative indices in map" - | bools VS.! (i + off) = Just (k, (ty, newIndices VS.! (i + off))) - | otherwise = Nothing - in \(VarMap off _ mp) -> VarMap 0 0 (Map.fromAscList . mapMaybe (modify off) . Map.toAscList $ mp) - -superMap :: Eq k => Subenv env env' -> VarMap k env' -> VarMap k env -superMap subenv = - let loop :: Subenv env env' -> Int -> [Int] - loop SETop _ = [] - loop (SEYesR sub) i = i : loop sub (i+1) - loop (SENo sub) i = loop sub (i+1) - - newIndices = VS.fromList $ loop subenv 0 - modify off (k, (ty, i)) - | i + off < 0 = Nothing - | i + off >= VS.length newIndices = error "VarMap.superMap: found negative indices in map" - | otherwise = let j = newIndices VS.! (i + off) - in if j == -1 then Nothing else Just (k, (ty, j)) - - in \(VarMap off _ mp) -> VarMap 0 0 (Map.fromAscList . mapMaybe (modify off) . Map.toAscList $ mp) - -maybeCleanup :: VarMap k env -> VarMap k env -maybeCleanup vm@(VarMap _ interval mp) - | let sz = Map.size mp - , sz > 0, 2 * interval >= 3 * sz - = cleanup vm -maybeCleanup vm = vm - -cleanup :: VarMap k env -> VarMap k env -cleanup (VarMap off _ mp) = VarMap 0 0 (Map.mapMaybe (\(t, i) -> if i + off >= 0 then Just (t, i + off) else Nothing) mp) - -unsafeInt2idx :: Int -> Maybe (Idx env t) -unsafeInt2idx = \n -> if n < 0 then Nothing else Just (go n) - where - go :: Int -> Idx env t - go 0 = unsafeCoerce IZ - go n = unsafeCoerce (IS (go (n-1))) diff --git a/src/Example.hs b/src/Example.hs deleted file mode 100644 index e996002..0000000 --- a/src/Example.hs +++ /dev/null @@ -1,196 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE OverloadedLabels #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE TypeApplications #-} - -{-# OPTIONS -Wno-unused-imports #-} -module Example where - -import Array -import AST -import AST.Count -import AST.Pretty -import AST.UnMonoid -import CHAD -import CHAD.Top -import CHAD.Types -import ForwardAD -import Interpreter -import Language -import Simplify - -import Debug.Trace -import Example.Types - - --- ppExpr senv5 $ simplifyN 20 $ let d = descr5 SMerge SMerge in freezeRet d (drev d ex5) (EConst ext STF32 1.0) - - -pipeline :: KnownEnv env => CHADConfig -> Ex env t -> Ex (D2 t : env) (TPair t (Tup (D2E env))) -pipeline config term - | Dict <- styKnown (d2 (typeOf term)) = - simplifyFix $ pruneExpr knownEnv $ - simplifyFix $ unMonoid $ - simplifyFix $ chad' config knownEnv $ - simplifyFix $ term - --- :seti -XOverloadedLabels -XPartialTypeSignatures -Wno-partial-type-signatures -pipeline' :: KnownEnv env => CHADConfig -> Ex env t -> IO () -pipeline' config term - | Dict <- styKnown (d2 (typeOf term)) = - pprintExpr (pipeline config term) - - -bin :: SOp (TPair a b) c -> Ex env a -> Ex env b -> Ex env c -bin op a b = EOp ext op (EPair ext a b) - -senv1 :: SList STy [TScal TF32, TScal TF32] -senv1 = STScal STF32 `SCons` STScal STF32 `SCons` SNil - --- x y |- x * y + x --- --- let x3 = (x1, x2) --- x4 = ((*) x3, x1) --- in ( (+) x4 --- , let x5 = 1.0 --- x6 = Inr (x5, x5) --- in case x6 of --- Inl x7 -> return () --- Inr x8 -> --- let x9 = fst x8 --- x10 = Inr (snd x3 * x9, fst x3 * x9) --- in case x10 of --- Inl x11 -> return () --- Inr x12 -> --- let x13 = fst x12 --- in one "v1" x13 >>= \x14 -> --- let x15 = snd x12 --- in one "v2" x15 >>= \x16 -> --- let x17 = snd x8 --- in one "v1" x17) --- --- ( (x1 * x2) + x1 --- , let x5 = 1.0 --- in do one "v1" (x2 * x5) --- one "v2" (x1 * x5) --- one "v1" x5) -ex1 :: Ex [TScal TF32, TScal TF32] (TScal TF32) -ex1 = fromNamed $ lambda #x $ lambda #y $ body $ - #x * #y + #x - --- x y |- let z = x + y in z * (z + x) -ex2 :: Ex [TScal TF32, TScal TF32] (TScal TF32) -ex2 = fromNamed $ lambda #x $ lambda #y $ body $ - let_ #z (#x + #y) $ - #z * (#z + #x) - --- x y |- if x < y then 2 * x else 3 + x -ex3 :: Ex [TScal TF32, TScal TF32] (TScal TF32) -ex3 = fromNamed $ lambda #x $ lambda #y $ body $ - if_ (#x .< #y) (2 * #x) (3 * #x) - --- x y |- if x < y then 2 * x + y * y else 3 + x -ex4 :: Ex [TScal TF32, TScal TF32] (TScal TF32) -ex4 = fromNamed $ lambda #x $ lambda #y $ body $ - if_ (#x .< #y) (2 * #x + #y * #y) (3 + #x) - --- x:R+R y:R |- case x of {inl a -> a * y ; inr b -> b * (y + 1)} -ex5 :: Ex [TScal TF32, TEither (TScal TF32) (TScal TF32)] (TScal TF32) -ex5 = fromNamed $ lambda #x $ lambda #y $ body $ - case_ #x (#a :-> #a * #y) - (#b :-> #b * (#y + 1)) - --- x:R n:I |- let a = unit x --- b = build1 n (\i. let c = idx0 a in c * c) --- in idx0 (b ! 3) -ex6 :: Ex [TScal TI64, TScal TF32] (TScal TF32) -ex6 = fromNamed $ lambda #x $ lambda #n $ body $ - let_ #a (unit #x) $ - let_ #b (build1 #n (#_ :-> let_ #c (idx0 #a) $ #c * #c)) $ - #b ! pair nil 3 - --- A "neural network" except it's just scalars, not matrices. --- ps:((((), (R,R)), (R,R)), (R,R)) x:R --- |- let p1 = snd ps --- p1' = fst ps --- x1 = fst p1 * x + snd p1 --- p2 = snd p1' --- p2' = fst p1' --- x2 = fst p2 * x + snd p2 --- p3 = snd p2' --- p3' = fst p2' --- x3 = fst p3 * x + snd p3 --- in x3 -ex7 :: Ex [R, TPair (TPair (TPair TNil (TPair R R)) (TPair R R)) (TPair R R)] R -ex7 = fromNamed $ lambda #pars123 $ lambda #input $ body $ - let tR = STScal STF64 - tpair = STPair tR tR - - layer :: (Lookup "parstup" env ~ p, Lookup "inp" env ~ R) - => STy p -> NExpr env R - layer (STPair t (STPair (STScal STF64) (STScal STF64))) | Dict <- styKnown t = - let_ #par (snd_ #parstup) $ - let_ #restpars (fst_ #parstup) $ - let_ #inp (fst_ #par * #inp + snd_ #par) $ - let_ #parstup #restpars $ - layer t - layer STNil = #inp - layer _ = error "Invalid layer inputs" - - in let_ #parstup #pars123 $ - let_ #inp #input $ - layer (STPair (STPair (STPair STNil tpair) tpair) tpair) - -neural :: Ex [TVec R, TVec R, TPair (TMat R) (TVec R), TPair (TMat R) (TVec R)] R -neural = fromNamed $ lambda #layer1 $ lambda #layer2 $ lambda #layer3 $ lambda #input $ body $ - let layer = lambda @(TMat R) #wei $ lambda @(TVec R) #bias $ lambda @(TVec R) #x $ body $ - -- prod = wei `matmul` x - let_ #prod (sum1i $ build (SS (SS SZ)) (shape #wei) $ #idx :-> - #wei ! #idx * #x ! pair nil (snd_ #idx)) $ - -- relu (prod + bias) - build (SS SZ) (shape #prod) $ #idx :-> - let_ #out (#prod ! #idx + #bias ! #idx) $ - if_ (#out .<= const_ 0) (const_ 0) #out - - in let_ #x1 (inline layer (SNil .$ fst_ #layer1 .$ snd_ #layer1 .$ #input)) $ - let_ #x2 (inline layer (SNil .$ fst_ #layer2 .$ snd_ #layer2 .$ #x1)) $ - let_ #x3 (sum1i $ build (SS SZ) (shape #x2) $ #idx :-> #x2 ! #idx * #layer3 ! #idx) $ - #x3 ! nil - -type NeuralGrad = ((Array N2 Double, Array N1 Double) - ,(Array N2 Double, Array N1 Double) - ,Array N1 Double - ,Array N1 Double) - -neuralGo :: (Double -- primal - ,NeuralGrad -- gradient using CHAD - ,NeuralGrad) -- gradient using dual-numbers forward AD -neuralGo = - let lay1 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0]) - lay2 = (arrayFromList (ShNil `ShCons` 2 `ShCons` 2) [1,1,1,1], arrayFromList (ShNil `ShCons` 2) [0,0]) - lay3 = arrayFromList (ShNil `ShCons` 2) [1,1] - input = arrayFromList (ShNil `ShCons` 2) [1,1] - argument = (Value input `SCons` Value lay3 `SCons` Value lay2 `SCons` Value lay1 `SCons` SNil) - revderiv = - simplifyN 20 $ - ELet ext (EConst ext STF64 1.0) $ - chad defaultConfig knownEnv neural - (primal, dlay1_1, dlay2_1, dlay3_1, dinput_1) = case interpretOpen False knownEnv argument revderiv of - (primal', (((((), (dlay1_1'a, dlay1_1'b)), (dlay2_1'a, dlay2_1'b)), dlay3_1'), dinput_1')) -> (primal', (dlay1_1'a, dlay1_1'b), (dlay2_1'a, dlay2_1'b), dlay3_1', dinput_1') - (Value dinput_2 `SCons` Value dlay3_2 `SCons` Value dlay2_2 `SCons` Value dlay1_2 `SCons` SNil) = drevByFwdInterp knownEnv neural argument 1.0 - in trace (ppExpr knownEnv revderiv) $ - (primal, (dlay1_1, dlay2_1, dlay3_1, dinput_1), (dlay1_2, dlay2_2, dlay3_2, dinput_2)) - --- The build body uses free variables in a non-linear way, so their primal --- values are required in the dual of the build. Thus, compositionally, they --- are stored in the tape from each individual lambda invocation. This results --- in n copies of y and z, where only one copy would have sufficed. -exUniformFree :: Ex '[R, I64] R -exUniformFree = fromNamed $ lambda #n $ lambda #x $ body $ - let_ #y (#x * 2) $ - let_ #z (#x * 3) $ - idx0 $ sum1i $ - build1 #n $ #i :-> #y * #z + toFloat_ #i diff --git a/src/Example/GMM.hs b/src/Example/GMM.hs deleted file mode 100644 index 206e534..0000000 --- a/src/Example/GMM.hs +++ /dev/null @@ -1,123 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE OverloadedLabels #-} -{-# LANGUAGE TypeApplications #-} -module Example.GMM where - -import Example.Types -import Language - - - --- N, D, K: integers > 0 --- alpha, M, Q, L: the active parameters --- X: inactive data --- m: integer --- k1: 1/2 N D log(2 pi) --- k2: 1/2 gamma^2 --- k3: K * (n' D (log(gamma) - 1/2 log(2)) - log MultiGamma(1/2 n', D)) --- where n' = D + m + 1 --- --- Inputs from the file are: N, D, K, alpha, M, Q, L, gamma, m. --- --- See: --- - "A benchmark of selected algorithmic differentiation tools on some problems --- in computer vision and machine learning". Optim. Methods Softw. 33(4-6): --- 889-906 (2018). --- --- --- - 2021 Tom Smeding: “Reverse Automatic Differentiation for Accelerate”. --- Master thesis at Utrecht University. (Appendix B.1) --- --- --- --- The 'wrong' argument, when set to True, changes the objective function to --- one with a bug that makes a certain `build` result unused. This --- makes the CHAD code fail because it tries to use a D2 (TArr) as if it's --- dense, even though it may be a zero (i.e. empty). The "unused" test in --- test/Main.hs tries to isolate this case, but the wrong version of --- gmmObjective is here to check (after that bug is fixed) whether it really --- fixes the original bug. -gmmObjective :: Bool -> Ex [R, R, R, I64, TMat R, TMat R, TMat R, TMat R, TVec R, I64, I64, I64] R -gmmObjective wrong = fromNamed $ - lambda #N $ lambda #D $ lambda #K $ - lambda #alpha $ lambda #M $ lambda #Q $ lambda #L $ - lambda #X $ lambda #m $ - lambda #k1 $ lambda #k2 $ lambda #k3 $ - body $ - let -- We have: - -- sum (exp (x - max(x))) - -- = sum (exp x / exp (max(x))) - -- = sum (exp x) / exp (max(x)) - -- Hence: - -- sum (exp x) = sum (exp (x - max(x))) * exp (max(x)) (*) - -- - -- So: - -- d/dxi log (sum (exp x)) - -- = 1/(sum (exp x)) * d/dxi sum (exp x) - -- = 1/(sum (exp x)) * sum (d/dxi exp x) - -- = 1/(sum (exp x)) * exp xi - -- = exp xi / sum (exp x) - -- (by (*)) - -- = exp xi / (sum (exp (x - max(x))) * exp (max(x))) - -- = exp (xi - max(x)) / sum (exp (x - max(x))) - logsumexp' = lambda @(TVec R) #vec $ body $ - let_ #m (maximum1i #vec) $ - log (idx0 (sum1i (map_ (#x :-> exp (#x - idx0 #m)) #vec))) + idx0 #m - -- custom (#_ :-> #v :-> - -- let_ #m (idx0 (maximum1i #v)) $ - -- log (idx0 (sum1i (map_ (#x :-> exp (#x - #m)) #v))) + #m) - -- (#_ :-> #v :-> - -- let_ #m (idx0 (maximum1i #v)) $ - -- let_ #ex (map_ (#x :-> exp (#x - #m)) #v) $ - -- let_ #s (idx0 (sum1i #ex)) $ - -- pair (log #s + #m) - -- (pair #ex #s)) - -- (#tape :-> #d :-> - -- map_ (#exi :-> #exi / snd_ #tape * #d) (fst_ #tape)) - -- nil #vec - logsumexp v = inline logsumexp' (SNil .$ v) - - mulmatvec = lambda @(TMat R) #mat $ lambda @(TVec R) #vec $ body $ - let_ #hei (snd_ (fst_ (shape #mat))) $ - let_ #wid (snd_ (shape #mat)) $ - build1 #hei $ #i :-> - idx0 (sum1i (build1 #wid $ #j :-> - #mat ! pair (pair nil #i) #j * #vec ! pair nil #j)) - m *@ v = inline mulmatvec (SNil .$ m .$ v) - - subvec = lambda @(TVec R) #a $ lambda @(TVec R) #b $ body $ - build1 (snd_ (shape #a)) $ #i :-> #a ! pair nil #i - #b ! pair nil #i - a .- b = inline subvec (SNil .$ a .$ b) - - matrow = lambda @(TMat R) #mat $ lambda @TIx #i $ body $ - build1 (snd_ (shape #mat)) (#j :-> #mat ! pair (pair nil #i) #j) - m .! i = inline matrow (SNil .$ m .$ i) - - normsq' = lambda @(TVec R) #vec $ body $ - idx0 (sum1i (build (SS SZ) (shape #vec) (#i :-> let_ #x (#vec ! #i) $ #x * #x))) - normsq v = inline normsq' (SNil .$ v) - - qmat' = lambda @(TVec R) #q $ lambda @(TVec R) #l $ body $ - let_ #n (snd_ (shape #q)) $ - build (SS (SS SZ)) (pair (pair nil #n) #n) $ #idx :-> - let_ #i (snd_ (fst_ #idx)) $ - let_ #j (snd_ #idx) $ - if_ (#i .== #j) - (exp (#q ! pair nil #i)) - (if_ (#i .> #j) - (if wrong then toFloat_ (#i * (#i - 1) `idiv` 2 + #j) - else #l ! pair nil (#i * (#i - 1) `idiv` 2 + #j)) - 0.0) - qmat q l = inline qmat' (SNil .$ q .$ l) - in let_ #k2arr (unit #k2) $ - - #k1 - + idx0 (sum1i (build1 #N $ #i :-> - logsumexp (build1 #K $ #k :-> - #alpha ! pair nil #k - + idx0 (sum1i (#Q .! #k)) - - 0.5 * normsq (qmat (#Q .! #k) (#L .! #k) *@ ((#X .! #i) .- (#M .! #k)))))) - - toFloat_ #N * logsumexp #alpha - + idx0 (sum1i (build1 #K $ #k :-> - idx0 #k2arr * (normsq (map_ (#x :-> exp #x) (#Q .! #k)) + normsq (#L .! #k)) - - toFloat_ #m * idx0 (sum1i (#Q .! #k)))) - - #k3 diff --git a/src/Example/Types.hs b/src/Example/Types.hs deleted file mode 100644 index d63159b..0000000 --- a/src/Example/Types.hs +++ /dev/null @@ -1,11 +0,0 @@ -{-# LANGUAGE DataKinds #-} -module Example.Types where - -import AST -import Data - - -type R = TScal TF64 -type I64 = TScal TI64 -type TVec = TArr (S Z) -type TMat = TArr (S (S Z)) diff --git a/src/ForwardAD.hs b/src/ForwardAD.hs deleted file mode 100644 index 6655423..0000000 --- a/src/ForwardAD.hs +++ /dev/null @@ -1,270 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE MultiWayIf #-} -{-# LANGUAGE TupleSections #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -module ForwardAD where - -import Data.Bifunctor (bimap) -import System.IO.Unsafe - --- import Debug.Trace --- import AST.Pretty - -import Array -import AST -import Compile -import Data -import ForwardAD.DualNumbers -import Interpreter -import Interpreter.Rep - - --- | Tangent along a type (coincides with cotangent for these types) -type family Tan t where - Tan TNil = TNil - Tan (TPair a b) = TPair (Tan a) (Tan b) - Tan (TEither a b) = TEither (Tan a) (Tan b) - Tan (TLEither a b) = TLEither (Tan a) (Tan b) - Tan (TMaybe t) = TMaybe (Tan t) - Tan (TArr n t) = TArr n (Tan t) - Tan (TScal t) = TanS t - -type family TanS t where - TanS TI32 = TNil - TanS TI64 = TNil - TanS TF32 = TScal TF32 - TanS TF64 = TScal TF64 - TanS TBool = TNil - -type family TanE env where - TanE '[] = '[] - TanE (t : env) = Tan t : TanE env - -tanty :: STy t -> STy (Tan t) -tanty STNil = STNil -tanty (STPair a b) = STPair (tanty a) (tanty b) -tanty (STEither a b) = STEither (tanty a) (tanty b) -tanty (STLEither a b) = STLEither (tanty a) (tanty b) -tanty (STMaybe t) = STMaybe (tanty t) -tanty (STArr n t) = STArr n (tanty t) -tanty (STScal t) = case t of - STI32 -> STNil - STI64 -> STNil - STF32 -> STScal STF32 - STF64 -> STScal STF64 - STBool -> STNil -tanty STAccum{} = error "Accumulators not allowed in input program" - -tanenv :: SList STy env -> SList STy (TanE env) -tanenv SNil = SNil -tanenv (t `SCons` env) = tanty t `SCons` tanenv env - -zeroTan :: STy t -> Rep t -> Rep (Tan t) -zeroTan STNil () = () -zeroTan (STPair a b) (x, y) = (zeroTan a x, zeroTan b y) -zeroTan (STEither a _) (Left x) = Left (zeroTan a x) -zeroTan (STEither _ b) (Right y) = Right (zeroTan b y) -zeroTan (STLEither _ _) Nothing = Nothing -zeroTan (STLEither a _) (Just (Left x)) = Just (Left (zeroTan a x)) -zeroTan (STLEither _ b) (Just (Right y)) = Just (Right (zeroTan b y)) -zeroTan (STMaybe _) Nothing = Nothing -zeroTan (STMaybe t) (Just x) = Just (zeroTan t x) -zeroTan (STArr _ t) x = fmap (zeroTan t) x -zeroTan (STScal STI32) _ = () -zeroTan (STScal STI64) _ = () -zeroTan (STScal STF32) _ = 0.0 -zeroTan (STScal STF64) _ = 0.0 -zeroTan (STScal STBool) _ = () -zeroTan STAccum{} _ = error "Accumulators not allowed in input program" - -tanScalars :: STy t -> Rep (Tan t) -> [Double] -tanScalars STNil () = [] -tanScalars (STPair a b) (x, y) = tanScalars a x ++ tanScalars b y -tanScalars (STEither a _) (Left x) = tanScalars a x -tanScalars (STEither _ b) (Right y) = tanScalars b y -tanScalars (STLEither _ _) Nothing = [] -tanScalars (STLEither a _) (Just (Left x)) = tanScalars a x -tanScalars (STLEither _ b) (Just (Right y)) = tanScalars b y -tanScalars (STMaybe _) Nothing = [] -tanScalars (STMaybe t) (Just x) = tanScalars t x -tanScalars (STArr _ t) x = foldMap id $ arrayMap (tanScalars t) x -tanScalars (STScal STI32) _ = [] -tanScalars (STScal STI64) _ = [] -tanScalars (STScal STF32) x = [realToFrac x] -tanScalars (STScal STF64) x = [x] -tanScalars (STScal STBool) _ = [] -tanScalars STAccum{} _ = error "Accumulators not allowed in input program" - -tanEScalars :: SList STy env -> SList Value (TanE env) -> [Double] -tanEScalars SNil SNil = [] -tanEScalars (t `SCons` ts) (Value x `SCons` xs) = tanScalars t x ++ tanEScalars ts xs - -unzipDN :: STy t -> Rep (DN t) -> (Rep t, Rep (Tan t)) -unzipDN STNil _ = ((), ()) -unzipDN (STPair a b) (d1, d2) = - let (x, dx) = unzipDN a d1 - (y, dy) = unzipDN b d2 - in ((x, y), (dx, dy)) -unzipDN (STEither a b) d = case d of - Left d1 -> bimap Left Left (unzipDN a d1) - Right d2 -> bimap Right Right (unzipDN b d2) -unzipDN (STLEither a b) d = case d of - Nothing -> (Nothing, Nothing) - Just (Left x) -> bimap (Just . Left) (Just . Left) (unzipDN a x) - Just (Right y) -> bimap (Just . Right) (Just . Right) (unzipDN b y) -unzipDN (STMaybe t) d = case d of - Nothing -> (Nothing, Nothing) - Just d' -> bimap Just Just (unzipDN t d') -unzipDN (STArr _ t) d = - let pairs = arrayMap (unzipDN t) d - in (arrayMap fst pairs, arrayMap snd pairs) -unzipDN (STScal ty) d = case ty of - STI32 -> (d, ()) - STI64 -> (d, ()) - STF32 -> d - STF64 -> d - STBool -> (d, ()) -unzipDN STAccum{} _ = error "Accumulators not allowed in input program" - -dotprodTan :: STy t -> Rep (Tan t) -> Rep (Tan t) -> Double -dotprodTan STNil _ _ = 0.0 -dotprodTan (STPair a b) (x, y) (x', y') = - dotprodTan a x x' + dotprodTan b y y' -dotprodTan (STEither a b) x y = case (x, y) of - (Left x', Left y') -> dotprodTan a x' y' - (Right x', Right y') -> dotprodTan b x' y' - _ -> error "dotprodTan: incompatible Either alternatives" -dotprodTan (STLEither a b) x y = case (x, y) of - (Nothing, _) -> 0.0 -- 0 * y = 0 - (_, Nothing) -> 0.0 -- x * 0 = 0 - (Just (Left x'), Just (Left y')) -> dotprodTan a x' y' - (Just (Right x'), Just (Right y')) -> dotprodTan b x' y' - _ -> error "dotprodTan: incompatible LEither alternatives" -dotprodTan (STMaybe t) x y = case (x, y) of - (Nothing, Nothing) -> 0.0 - (Just x', Just y') -> dotprodTan t x' y' - _ -> error "dotprodTan: incompatible Maybe alternatives" -dotprodTan (STArr _ t) x y = - let sh1 = arrayShape x - sh2 = arrayShape y - in if | shapeSize sh1 == 0 || shapeSize sh2 == 0 -> 0.0 - | sh1 == sh2 -> sum [dotprodTan t (arrayIndex x i) (arrayIndex y i) | i <- enumShape sh1] - | otherwise -> error "dotprodTan: incompatible array shapes" -dotprodTan (STScal ty) x y = case ty of - STI32 -> 0.0 - STI64 -> 0.0 - STF32 -> realToFrac @Float @Double (x * y) - STF64 -> x * y - STBool -> 0.0 -dotprodTan STAccum{} _ _ = error "Accumulators not allowed in input program" - --- -- Primal expression must be duplicable --- dnConstE :: STy t -> Ex env t -> Ex env (DN t) --- dnConstE STNil _ = ENil ext --- dnConstE (STPair t1 t2) e = --- -- This creates fst/snd stacks of unbounded size, but let's not care here --- EPair ext (dnConstE t1 (EFst ext e)) (dnConstE t2 (ESnd ext e)) --- dnConstE (STEither t1 t2) e = --- ECase ext e --- (EInl ext (dn t2) (dnConstE t1 (EVar ext t1 IZ))) --- (EInr ext (dn t1) (dnConstE t2 (EVar ext t2 IZ))) --- dnConstE (STMaybe t) e = --- EMaybe ext (ENothing ext (dn t)) (EJust ext (dnConstE t (EVar ext t IZ))) e --- dnConstE (STArr n t) e = --- EBuild ext n (EShape ext e) --- (dnConstE t (EIdx ext n (weakenExpr WSink e) (EVar ext (tTup (sreplicate n tIx)) IZ))) --- dnConstE (STScal t) e = case t of --- STI32 -> e --- STI64 -> e --- STF32 -> EPair ext e (EConst ext STF32 0.0) --- STF64 -> EPair ext e (EConst ext STF64 0.0) --- STBool -> e --- dnConstE STAccum{} _ = error "Accumulators not allowed in input program" - -dnConst :: STy t -> Rep t -> Rep (DN t) -dnConst STNil = const () -dnConst (STPair t1 t2) = bimap (dnConst t1) (dnConst t2) -dnConst (STEither t1 t2) = bimap (dnConst t1) (dnConst t2) -dnConst (STLEither t1 t2) = fmap (bimap (dnConst t1) (dnConst t2)) -dnConst (STMaybe t) = fmap (dnConst t) -dnConst (STArr _ t) = arrayMap (dnConst t) -dnConst (STScal t) = case t of - STI32 -> id - STI64 -> id - STF32 -> (,0.0) - STF64 -> (,0.0) - STBool -> id -dnConst STAccum{} = error "Accumulators not allowed in input program" - --- | Given a function that computes the forward derivative for a particular --- dual-numbers input, a 'RevByFwd' computes the gradient with respect to this --- @t@ input. -type RevByFwd t = (Rep (DN t) -> Double) -> Rep (Tan t) - -dnOnehots :: STy t -> Rep t -> RevByFwd t -dnOnehots STNil _ = \_ -> () -dnOnehots (STPair t1 t2) (x, y) = - \f -> (dnOnehots t1 x (f . (,dnConst t2 y)), dnOnehots t2 y (f . (dnConst t1 x,))) -dnOnehots (STEither t1 t2) e = - case e of - Left x -> \f -> Left (dnOnehots t1 x (f . Left)) - Right y -> \f -> Right (dnOnehots t2 y (f . Right)) -dnOnehots (STLEither t1 t2) e = - case e of - Nothing -> \_ -> Nothing - Just (Left x) -> \f -> Just (Left (dnOnehots t1 x (f . Just . Left))) - Just (Right y) -> \f -> Just (Right (dnOnehots t2 y (f . Just . Right))) -dnOnehots (STMaybe t) m = - case m of - Nothing -> \_ -> Nothing - Just x -> \f -> Just (dnOnehots t x (f . Just)) -dnOnehots (STArr _ t) a = - \f -> - arrayGenerate (arrayShape a) $ \idx -> - dnOnehots t (arrayIndex a idx) (f . (\oh -> arrayGenerate (arrayShape a) $ \i -> - if i == idx then oh else dnConst t (arrayIndex a i))) -dnOnehots (STScal t) x = case t of - STI32 -> \_ -> () - STI64 -> \_ -> () - STF32 -> \f -> realToFrac @Double @Float $ f (x, 1.0) - STF64 -> \f -> f (x, 1.0) - STBool -> \_ -> () -dnOnehots STAccum{} _ = error "Accumulators not allowed in input program" - -dnConstEnv :: SList STy env -> SList Value env -> SList Value (DNE env) -dnConstEnv SNil SNil = SNil -dnConstEnv (t `SCons` env) (Value x `SCons` val) = Value (dnConst t x) `SCons` dnConstEnv env val - -type RevByFwdEnv env = (SList Value (DNE env) -> Double) -> SList Value (TanE env) - -dnOnehotEnvs :: SList STy env -> SList Value env -> RevByFwdEnv env -dnOnehotEnvs SNil SNil = \_ -> SNil -dnOnehotEnvs (t `SCons` env) (Value x `SCons` val) = - \f -> - Value (dnOnehots t x (f . (\oh -> Value oh `SCons` dnConstEnv env val))) - `SCons` dnOnehotEnvs env val (f . (\oh -> Value (dnConst t x) `SCons` oh)) - -data FwdADArtifact env t = FwdADArtifact (SList STy env) (STy t) (SList Value (DNE env) -> Rep (DN t)) - -makeFwdADArtifactInterp :: SList STy env -> Ex env t -> FwdADArtifact env t -makeFwdADArtifactInterp env expr = - let dexpr = dfwdDN expr - in FwdADArtifact env (typeOf expr) (\inp -> interpretOpen False (dne env) inp dexpr) - -{-# NOINLINE makeFwdADArtifactCompile #-} -makeFwdADArtifactCompile :: SList STy env -> Ex env t -> IO (FwdADArtifact env t, String) -makeFwdADArtifactCompile env expr = do - (fun, output) <- compile (dne env) (dfwdDN expr) - return (FwdADArtifact env (typeOf expr) (unsafePerformIO . fun), output) - -drevByFwdInterp :: SList STy env -> Ex env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env) -drevByFwdInterp env expr = drevByFwd (makeFwdADArtifactInterp env expr) - -drevByFwd :: FwdADArtifact env t -> SList Value env -> Rep (Tan t) -> SList Value (TanE env) -drevByFwd (FwdADArtifact env outty fun) input dres = - dnOnehotEnvs env input $ \dnInput -> - -- trace (showEnv (dne env) dnInput) $ - let (_, outtan) = unzipDN outty (fun dnInput) - in dotprodTan outty outtan dres diff --git a/src/ForwardAD/DualNumbers.hs b/src/ForwardAD/DualNumbers.hs deleted file mode 100644 index a1e9d0d..0000000 --- a/src/ForwardAD/DualNumbers.hs +++ /dev/null @@ -1,231 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE EmptyCase #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} - --- I want to bring various type variables in scope using type annotations in --- patterns, but I don't want to have to mention all the other type parameters --- of the types in question as well then. Partial type signatures (with '_') are --- useful here. -{-# LANGUAGE PartialTypeSignatures #-} -{-# OPTIONS -Wno-partial-type-signatures #-} -module ForwardAD.DualNumbers ( - dfwdDN, - DN, DNS, DNE, dn, dne, -) where - -import AST -import Data -import ForwardAD.DualNumbers.Types - - -dnPreservesTupIx :: SNat n -> DN (Tup (Replicate n TIx)) :~: Tup (Replicate n TIx) -dnPreservesTupIx SZ = Refl -dnPreservesTupIx (SS n) | Refl <- dnPreservesTupIx n = Refl - -convIdx :: Idx env t -> Idx (DNE env) (DN t) -convIdx IZ = IZ -convIdx (IS i) = IS (convIdx i) - -scalTyCase :: SScalTy t - -> ((ScalIsNumeric t ~ True, ScalIsFloating t ~ True, Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t)) => r) - -> (DN (TScal t) ~ TScal t => r) - -> r -scalTyCase STF32 k1 _ = k1 -scalTyCase STF64 k1 _ = k1 -scalTyCase STI32 _ k2 = k2 -scalTyCase STI64 _ k2 = k2 -scalTyCase STBool _ k2 = k2 - -floatingDual :: ScalIsFloating t ~ True - => SScalTy t - -> ((Fractional (ScalRep t), DN (TScal t) ~ TPair (TScal t) (TScal t), ScalIsNumeric t ~ True) => r) -> r -floatingDual STF32 k = k -floatingDual STF64 k = k - --- | Argument does not need to be duplicable. -dop :: forall a b env. SOp a b -> Ex env (DN a) -> Ex env (DN b) -dop = \case - OAdd t -> scalTyCase t - (binFloat (\(x, dx) (y, dy) -> EPair ext (add t x y) (add t dx dy))) - (EOp ext (OAdd t)) - OMul t -> scalTyCase t - (binFloat (\(x, dx) (y, dy) -> EPair ext (mul t x y) (add t (mul t dx y) (mul t dy x)))) - (EOp ext (OMul t)) - ONeg t -> scalTyCase t - (unFloat (\(x, dx) -> EPair ext (neg t x) (neg t dx))) - (EOp ext (ONeg t)) - OLt t -> scalTyCase t - (binFloat (\(x, _) (y, _) -> EOp ext (OLt t) (EPair ext x y))) - (EOp ext (OLt t)) - OLe t -> scalTyCase t - (binFloat (\(x, _) (y, _) -> EOp ext (OLe t) (EPair ext x y))) - (EOp ext (OLe t)) - OEq t -> scalTyCase t - (binFloat (\(x, _) (y, _) -> EOp ext (OEq t) (EPair ext x y))) - (EOp ext (OEq t)) - ONot -> EOp ext ONot - OAnd -> EOp ext OAnd - OOr -> EOp ext OOr - OIf -> EOp ext OIf - ORound64 -> \arg -> EOp ext ORound64 (EFst ext arg) - OToFl64 -> \arg -> EPair ext (EOp ext OToFl64 arg) (EConst ext STF64 0.0) - ORecip t -> floatingDual t $ unFloat (\(x, dx) -> - EPair ext (recip' t x) - (mul t (neg t (recip' t (mul t x x))) dx)) - OExp t -> floatingDual t $ unFloat (\(x, dx) -> - EPair ext (EOp ext (OExp t) x) (mul t (EOp ext (OExp t) x) dx)) - OLog t -> floatingDual t $ unFloat (\(x, dx) -> - EPair ext (EOp ext (OLog t) x) - (mul t (recip' t x) dx)) - OIDiv t -> scalTyCase t - (case t of {}) - (EOp ext (OIDiv t)) - OMod t -> scalTyCase t - (case t of {}) - (EOp ext (OMod t)) - where - add :: ScalIsNumeric t ~ True - => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) - add t a b = EOp ext (OAdd t) (EPair ext a b) - - mul :: ScalIsNumeric t ~ True - => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) -> Ex env' (TScal t) - mul t a b = EOp ext (OMul t) (EPair ext a b) - - neg :: ScalIsNumeric t ~ True - => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) - neg t = EOp ext (ONeg t) - - recip' :: ScalIsFloating t ~ True - => SScalTy t -> Ex env' (TScal t) -> Ex env' (TScal t) - recip' t = EOp ext (ORecip t) - - unFloat :: DN a ~ TPair a a - => (forall env'. (Ex env' a, Ex env' a) -> Ex env' (DN b)) - -> Ex env (DN a) -> Ex env (DN b) - unFloat f e = - ELet ext e $ - let var = EVar ext (typeOf e) IZ - in f (EFst ext var, ESnd ext var) - - binFloat :: (a ~ TPair s s, DN s ~ TPair s s) - => (forall env'. (Ex env' s, Ex env' s) -> (Ex env' s, Ex env' s) -> Ex env' (DN b)) - -> Ex env (DN a) -> Ex env (DN b) - binFloat f e = - ELet ext e $ - let var = EVar ext (typeOf e) IZ - in f (EFst ext (EFst ext var), ESnd ext (EFst ext var)) - (EFst ext (ESnd ext var), ESnd ext (ESnd ext var)) - -zeroScalarConst :: ScalIsNumeric t ~ True => SScalTy t -> Ex env (TScal t) -zeroScalarConst STI32 = EConst ext STI32 0 -zeroScalarConst STI64 = EConst ext STI64 0 -zeroScalarConst STF32 = EConst ext STF32 0.0 -zeroScalarConst STF64 = EConst ext STF64 0.0 - -dfwdDN :: Ex env t -> Ex (DNE env) (DN t) -dfwdDN = \case - EVar _ t i -> EVar ext (dn t) (convIdx i) - ELet _ a b -> ELet ext (dfwdDN a) (dfwdDN b) - EPair _ a b -> EPair ext (dfwdDN a) (dfwdDN b) - EFst _ e -> EFst ext (dfwdDN e) - ESnd _ e -> ESnd ext (dfwdDN e) - ENil _ -> ENil ext - EInl _ t e -> EInl ext (dn t) (dfwdDN e) - EInr _ t e -> EInr ext (dn t) (dfwdDN e) - ECase _ e a b -> ECase ext (dfwdDN e) (dfwdDN a) (dfwdDN b) - ENothing _ t -> ENothing ext (dn t) - EJust _ e -> EJust ext (dfwdDN e) - EMaybe _ e a b -> EMaybe ext (dfwdDN e) (dfwdDN a) (dfwdDN b) - ELNil _ t1 t2 -> ELNil ext (dn t1) (dn t2) - ELInl _ t e -> ELInl ext (dn t) (dfwdDN e) - ELInr _ t e -> ELInr ext (dn t) (dfwdDN e) - ELCase _ e a b c -> ELCase ext (dfwdDN e) (dfwdDN a) (dfwdDN b) (dfwdDN c) - EConstArr _ n t x -> scalTyCase t - (emap (EPair ext (EVar ext (STScal t) IZ) (EConst ext t 0.0)) - (EConstArr ext n t x)) - (EConstArr ext n t x) - EBuild _ n a b - | Refl <- dnPreservesTupIx n -> EBuild ext n (dfwdDN a) (dfwdDN b) - EMap _ a b -> EMap ext (dfwdDN a) (dfwdDN b) - EFold1Inner _ cm a b c -> EFold1Inner ext cm (dfwdDN a) (dfwdDN b) (dfwdDN c) - ESum1Inner _ e -> - let STArr n (STScal t) = typeOf e - pairty = (STPair (STScal t) (STScal t)) - in scalTyCase t - (ELet ext (dfwdDN e) $ - ezip (ESum1Inner ext (emap (EFst ext (EVar ext pairty IZ)) - (EVar ext (STArr n pairty) IZ))) - (ESum1Inner ext (emap (ESnd ext (EVar ext pairty IZ)) - (EVar ext (STArr n pairty) IZ)))) - (ESum1Inner ext (dfwdDN e)) - EUnit _ e -> EUnit ext (dfwdDN e) - EReplicate1Inner _ a b -> EReplicate1Inner ext (dfwdDN a) (dfwdDN b) - EMaximum1Inner _ e -> deriv_extremum (EMaximum1Inner ext) e - EMinimum1Inner _ e -> deriv_extremum (EMinimum1Inner ext) e - EZip _ a b -> EZip ext (dfwdDN a) (dfwdDN b) - EReshape _ n esh e - | Refl <- dnPreservesTupIx n -> EReshape ext n (dfwdDN esh) (dfwdDN e) - EConst _ t x -> scalTyCase t - (EPair ext (EConst ext t x) (EConst ext t 0.0)) - (EConst ext t x) - EIdx0 _ e -> EIdx0 ext (dfwdDN e) - EIdx1 _ a b -> EIdx1 ext (dfwdDN a) (dfwdDN b) - EIdx _ a b - | STArr n _ <- typeOf a - , Refl <- dnPreservesTupIx n - -> EIdx ext (dfwdDN a) (dfwdDN b) - EShape _ e - | Refl <- dnPreservesTupIx (let STArr n _ = typeOf e in n) - -> EShape ext (dfwdDN e) - EOp _ op e -> dop op (dfwdDN e) - ECustom _ _ _ _ pr _ _ e1 e2 -> - ELet ext (dfwdDN e1) $ - ELet ext (weakenExpr WSink (dfwdDN e2)) $ - weakenExpr (WCopy (WCopy WClosed)) (dfwdDN pr) - ERecompute _ e -> dfwdDN e - EError _ t s -> EError ext (dn t) s - - EWith{} -> err_accum - EAccum{} -> err_accum - EDeepZero{} -> err_monoid - EZero{} -> err_monoid - EPlus{} -> err_monoid - EOneHot{} -> err_monoid - - EFold1InnerD1{} -> err_targetlang "EFold1InnerD1" - EFold1InnerD2{} -> err_targetlang "EFold1InnerD2" - where - err_accum = error "Accumulator operations unsupported in the source program" - err_monoid = error "Monoid operations unsupported in the source program" - err_targetlang s = error $ "Target language operation " ++ s ++ " not supported in source program" - - deriv_extremum :: ScalIsNumeric t ~ True - => (forall env'. Ex env' (TArr (S n) (TScal t)) -> Ex env' (TArr n (TScal t))) - -> Ex env (TArr (S n) (TScal t)) -> Ex (DNE env) (TArr n (DN (TScal t))) - deriv_extremum extremum e = - let STArr (SS n) (STScal t) = typeOf e - t2 = STPair (STScal t) (STScal t) - ta2 = STArr (SS n) t2 - tIxN = tTup (sreplicate (SS n) tIx) - in scalTyCase t - (ELet ext (dfwdDN e) $ - ELet ext (extremum (emap (EFst ext (EVar ext t2 IZ)) (EVar ext ta2 IZ))) $ - ezip (EVar ext (STArr n (STScal t)) IZ) - (ESum1Inner ext - {- build (shape SZ) (\i. if fst (SZ ! i) == Z ! tail i then snd (SZ ! i) else zero) -} - (EBuild ext (SS n) (EShape ext (EVar ext ta2 (IS IZ))) $ - ELet ext (EIdx ext (EVar ext ta2 (IS (IS IZ))) (EVar ext tIxN IZ)) $ - ECase ext (EOp ext OIf (EOp ext (OEq t) (EPair ext - (EFst ext (EVar ext t2 IZ)) - (EIdx ext (EVar ext (STArr n (STScal t)) (IS (IS IZ))) - (EFst ext (EVar ext tIxN (IS IZ))))))) - (ESnd ext (EVar ext t2 (IS IZ))) - (zeroScalarConst t)))) - (extremum (dfwdDN e)) diff --git a/src/ForwardAD/DualNumbers/Types.hs b/src/ForwardAD/DualNumbers/Types.hs deleted file mode 100644 index dcacf5f..0000000 --- a/src/ForwardAD/DualNumbers/Types.hs +++ /dev/null @@ -1,48 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -module ForwardAD.DualNumbers.Types where - -import AST.Types -import Data - - --- | Dual-numbers transformation -type family DN t where - DN TNil = TNil - DN (TPair a b) = TPair (DN a) (DN b) - DN (TEither a b) = TEither (DN a) (DN b) - DN (TLEither a b) = TLEither (DN a) (DN b) - DN (TMaybe t) = TMaybe (DN t) - DN (TArr n t) = TArr n (DN t) - DN (TScal t) = DNS t - -type family DNS t where - DNS TF32 = TPair (TScal TF32) (TScal TF32) - DNS TF64 = TPair (TScal TF64) (TScal TF64) - DNS TI32 = TScal TI32 - DNS TI64 = TScal TI64 - DNS TBool = TScal TBool - -type family DNE env where - DNE '[] = '[] - DNE (t : ts) = DN t : DNE ts - -dn :: STy t -> STy (DN t) -dn STNil = STNil -dn (STPair a b) = STPair (dn a) (dn b) -dn (STEither a b) = STEither (dn a) (dn b) -dn (STLEither a b) = STLEither (dn a) (dn b) -dn (STMaybe t) = STMaybe (dn t) -dn (STArr n t) = STArr n (dn t) -dn (STScal t) = case t of - STF32 -> STPair (STScal STF32) (STScal STF32) - STF64 -> STPair (STScal STF64) (STScal STF64) - STI32 -> STScal STI32 - STI64 -> STScal STI64 - STBool -> STScal STBool -dn STAccum{} = error "Accum in source program" - -dne :: SList STy env -> SList STy (DNE env) -dne SNil = SNil -dne (t `SCons` env) = dn t `SCons` dne env diff --git a/src/Interpreter.hs b/src/Interpreter.hs deleted file mode 100644 index e1c81cd..0000000 --- a/src/Interpreter.hs +++ /dev/null @@ -1,471 +0,0 @@ -{-# LANGUAGE BangPatterns #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE ImplicitParams #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE MultiWayIf #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TupleSections #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -module Interpreter ( - interpret, - interpretOpen, - Value(..), -) where - -import Control.Monad (foldM, join, when, forM_) -import Control.Monad.Trans.Class (lift) -import Control.Monad.Trans.State.Strict (runStateT, get, put) -import Data.Bifunctor (bimap) -import Data.Bitraversable (bitraverse) -import Data.Char (isSpace) -import Data.Functor.Identity -import qualified Data.Functor.Product as Product -import Data.Int (Int64) -import Data.IORef -import Data.Tuple (swap) -import System.IO (hPutStrLn, stderr) -import System.IO.Unsafe (unsafePerformIO) - -import Debug.Trace - -import Array -import AST -import AST.Pretty -import AST.Sparse.Types -import Data -import Interpreter.Rep - - -newtype AcM s a = AcM { unAcM :: IO a } - deriving newtype (Functor, Applicative, Monad) - -runAcM :: (forall s. AcM s a) -> a -runAcM (AcM m) = unsafePerformIO m - -acmDebugLog :: String -> AcM s () -acmDebugLog s = AcM (hPutStrLn stderr s) - -data V t = V (STy t) (Rep t) - -interpret :: Ex '[] t -> Rep t -interpret = interpretOpen False SNil SNil - --- | Bool: whether to trace execution with debug prints (very verbose) -interpretOpen :: Bool -> SList STy env -> SList Value env -> Ex env t -> Rep t -interpretOpen prints env venv e = - runAcM $ - let ?depth = 0 - ?prints = prints - in interpret' (slistMap (\(Product.Pair t (Value v)) -> V t v) (slistZip env venv)) e - -interpret' :: forall env t s. (?prints :: Bool, ?depth :: Int) - => SList V env -> Ex env t -> AcM s (Rep t) -interpret' env e = do - let tenv = slistMap (\(V t _) -> t) env - let dep = ?depth - let lenlimit = max 20 (100 - dep) - let replace a b = map (\c -> if c == a then b else c) - let trunc s | length s > lenlimit = take (lenlimit - 3) (replace '\n' ' ' s) ++ "..." - | otherwise = replace '\n' ' ' s - when ?prints $ acmDebugLog $ replicate dep ' ' ++ "ev: " ++ trunc (ppExpr tenv e) - res <- let ?depth = dep + 1 in interpret'Rec env e - when ?prints $ acmDebugLog $ replicate dep ' ' ++ "<- " ++ showValue 0 (typeOf e) res "" - return res - -interpret'Rec :: forall env t s. (?prints :: Bool, ?depth :: Int) => SList V env -> Ex env t -> AcM s (Rep t) -interpret'Rec env = \case - EVar _ _ i -> case slistIdx env i of V _ x -> return x - ELet _ a b -> do - x <- interpret' env a - let ?depth = ?depth - 1 in interpret' (V (typeOf a) x `SCons` env) b - expr | False && trace (" " ++ takeWhile (not . isSpace) (show expr)) False -> undefined - EPair _ a b -> (,) <$> interpret' env a <*> interpret' env b - EFst _ e -> fst <$> interpret' env e - ESnd _ e -> snd <$> interpret' env e - ENil _ -> return () - EInl _ _ e -> Left <$> interpret' env e - EInr _ _ e -> Right <$> interpret' env e - ECase _ e a b -> - let STEither t1 t2 = typeOf e - in interpret' env e >>= \case - Left x -> interpret' (V t1 x `SCons` env) a - Right y -> interpret' (V t2 y `SCons` env) b - ENothing _ _ -> return Nothing - EJust _ e -> Just <$> interpret' env e - EMaybe _ a b e -> - let STMaybe t1 = typeOf e - in maybe (interpret' env a) (\x -> interpret' (V t1 x `SCons` env) b) =<< interpret' env e - ELNil _ _ _ -> return Nothing - ELInl _ _ e -> Just . Left <$> interpret' env e - ELInr _ _ e -> Just . Right <$> interpret' env e - ELCase _ e a b c -> - let STLEither t1 t2 = typeOf e - in interpret' env e >>= \case - Nothing -> interpret' env a - Just (Left x) -> interpret' (V t1 x `SCons` env) b - Just (Right y) -> interpret' (V t2 y `SCons` env) c - EConstArr _ _ _ v -> return v - EBuild _ dim a b -> do - sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env a - arrayGenerateM sh (\idx -> interpret' (V (tTup (sreplicate dim tIx)) (tupRepIdx ixUncons dim idx) `SCons` env) b) - EMap _ a b -> do - let STArr _ t = typeOf b - arrayMapM (\x -> interpret' (V t x `SCons` env) a) =<< interpret' env b - EFold1Inner _ _ a b c -> do - let t = typeOf b - let f = \x -> interpret' (V (STPair t t) x `SCons` env) a - x0 <- interpret' env b - arr <- interpret' env c - let sh `ShCons` n = arrayShape arr - arrayGenerateM sh $ \idx -> foldM (curry f) x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] - ESum1Inner _ e -> do - arr <- interpret' env e - let STArr _ (STScal t) = typeOf e - sh `ShCons` n = arrayShape arr - numericIsNum t $ return $ arrayGenerate sh $ \idx -> sum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] - EUnit _ e -> arrayGenerateLinM ShNil (\_ -> interpret' env e) - EReplicate1Inner _ a b -> do - n <- fromIntegral @Int64 @Int <$> interpret' env a - arr <- interpret' env b - let sh = arrayShape arr - return $ arrayGenerate (sh `ShCons` n) (\(idx `IxCons` _) -> arrayIndex arr idx) - EMaximum1Inner _ e -> do - arr <- interpret' env e - let STArr _ (STScal t) = typeOf e - sh `ShCons` n = arrayShape arr - numericIsNum t $ return $ - arrayGenerate sh (\idx -> maximum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]]) - EMinimum1Inner _ e -> do - arr <- interpret' env e - let STArr _ (STScal t) = typeOf e - sh `ShCons` n = arrayShape arr - numericIsNum t $ return $ - arrayGenerate sh (\idx -> minimum [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n-1]]) - EReshape _ dim esh e -> do - sh <- unTupRepIdx ShNil ShCons dim <$> interpret' env esh - arr <- interpret' env e - return $ arrayReshape sh arr - EZip _ a b -> do - arr1 <- interpret' env a - arr2 <- interpret' env b - let sh = arrayShape arr1 - when (sh /= arrayShape arr2) $ - error "Interpreter: mismatched shapes in EZip" - return $ arrayGenerateLin sh (\i -> (arr1 `arrayIndexLinear` i, arr2 `arrayIndexLinear` i)) - EFold1InnerD1 _ _ a b c -> do - let t = typeOf b - let f = \x -> interpret' (V (STPair t t) x `SCons` env) a - x0 <- interpret' env b - arr <- interpret' env c - let sh `ShCons` n = arrayShape arr - -- TODO: this is very inefficient, even for an interpreter; with mutable - -- arrays this can be a lot better with no lists - res <- arrayGenerateM sh $ \idx -> do - (y, stores) <- mapAccumLM (curry f) x0 [arrayIndex arr (idx `IxCons` i) | i <- [0 .. n - 1]] - return (y, arrayFromList (ShNil `ShCons` n) stores) - return (arrayMap fst res - ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> - arrayIndexLinear (snd (arrayIndex res idx)) i) - EFold1InnerD2 _ _ ef ebog ed -> do - let STArr _ tB = typeOf ebog - STArr _ t2 = typeOf ed - let f = \tape ctg -> interpret' (V t2 ctg `SCons` V tB tape `SCons` env) ef - bog <- interpret' env ebog - arrctg <- interpret' env ed - let sh `ShCons` n = arrayShape bog - when (sh /= arrayShape arrctg) $ error "Interpreter: mismatched shapes in EFold1InnerD2" - res <- arrayGenerateM sh $ \idx -> do - let loop i !ctg !inpctgs | i < 0 = return (ctg, inpctgs) - loop i !ctg !inpctgs = do - let b = arrayIndex bog (idx `IxCons` i) - (ctg1, ctg2) <- f b ctg - loop (i - 1) ctg1 (ctg2 : inpctgs) - (x0ctg, inpctg) <- loop (n - 1) (arrayIndex arrctg idx) [] - return (x0ctg, arrayFromList (ShNil `ShCons` n) inpctg) - return (arrayMap fst res - ,arrayGenerate (sh `ShCons` n) $ \(idx `IxCons` i) -> - arrayIndexLinear (snd (arrayIndex res idx)) i) - EConst _ _ v -> return v - EIdx0 _ e -> (`arrayIndexLinear` 0) <$> interpret' env e - EIdx1 _ a b -> arrayIndex1 <$> interpret' env a <*> (fromIntegral @Int64 @Int <$> interpret' env b) - EIdx _ a b -> - let STArr n _ = typeOf a - in arrayIndex <$> interpret' env a <*> (unTupRepIdx IxNil IxCons n <$> interpret' env b) - EShape _ e | STArr n _ <- typeOf e -> tupRepIdx shUncons n . arrayShape <$> interpret' env e - EOp _ op e -> interpretOp op <$> interpret' env e - ECustom _ t1 t2 _ pr _ _ e1 e2 -> do - e1' <- interpret' env e1 - e2' <- interpret' env e2 - interpret' (V t2 e2' `SCons` V t1 e1' `SCons` SNil) pr - ERecompute _ e -> interpret' env e - EWith _ t e1 e2 -> do - initval <- interpret' env e1 - withAccum t (typeOf e2) initval $ \accum -> - interpret' (V (STAccum t) accum `SCons` env) e2 - EAccum _ t p e1 sp e2 e3 -> do - idx <- interpret' env e1 - val <- interpret' env e2 - accum <- interpret' env e3 - accumAddSparseD t p accum idx sp val - EZero _ t ezi -> do - zi <- interpret' env ezi - return $ zeroM t zi - EDeepZero _ t ezi -> do - zi <- interpret' env ezi - return $ deepZeroM t zi - EPlus _ t a b -> do - a' <- interpret' env a - b' <- interpret' env b - return $ addM t a' b' - EOneHot _ t p a b -> do - a' <- interpret' env a - b' <- interpret' env b - return $ onehotM p t a' b' - EError _ _ s -> error $ "Interpreter: Program threw error: " ++ s - -interpretOp :: SOp a t -> Rep a -> Rep t -interpretOp op arg = case op of - OAdd st -> numericIsNum st $ uncurry (+) arg - OMul st -> numericIsNum st $ uncurry (*) arg - ONeg st -> numericIsNum st $ negate arg - OLt st -> numericIsNum st $ uncurry (<) arg - OLe st -> numericIsNum st $ uncurry (<=) arg - OEq st -> styIsEq st $ uncurry (==) arg - ONot -> not arg - OAnd -> uncurry (&&) arg - OOr -> uncurry (||) arg - OIf -> if arg then Left () else Right () - ORound64 -> round arg - OToFl64 -> fromIntegral arg - ORecip st -> floatingIsFractional st $ recip arg - OExp st -> floatingIsFractional st $ exp arg - OLog st -> floatingIsFractional st $ log arg - OIDiv st -> integralIsIntegral st $ uncurry quot arg - OMod st -> integralIsIntegral st $ uncurry rem arg - where - styIsEq :: SScalTy t -> (Eq (Rep (TScal t)) => r) -> r - styIsEq STI32 = id - styIsEq STI64 = id - styIsEq STF32 = id - styIsEq STF64 = id - styIsEq STBool = id - -zeroM :: SMTy t -> Rep (ZeroInfo t) -> Rep t -zeroM typ zi = case typ of - SMTNil -> () - SMTPair t1 t2 -> (zeroM t1 (fst zi), zeroM t2 (snd zi)) - SMTLEither _ _ -> Nothing - SMTMaybe _ -> Nothing - SMTArr _ t -> arrayMap (zeroM t) zi - SMTScal sty -> case sty of - STI32 -> 0 - STI64 -> 0 - STF32 -> 0.0 - STF64 -> 0.0 - -deepZeroM :: SMTy t -> Rep (DeepZeroInfo t) -> Rep t -deepZeroM typ zi = case typ of - SMTNil -> () - SMTPair t1 t2 -> (deepZeroM t1 (fst zi), deepZeroM t2 (snd zi)) - SMTLEither t1 t2 -> fmap (bimap (deepZeroM t1) (deepZeroM t2)) zi - SMTMaybe t -> fmap (deepZeroM t) zi - SMTArr _ t -> arrayMap (deepZeroM t) zi - SMTScal sty -> case sty of - STI32 -> 0 - STI64 -> 0 - STF32 -> 0.0 - STF64 -> 0.0 - -addM :: SMTy t -> Rep t -> Rep t -> Rep t -addM typ a b = case typ of - SMTNil -> () - SMTPair t1 t2 -> (addM t1 (fst a) (fst b), addM t2 (snd a) (snd b)) - SMTLEither t1 t2 -> case (a, b) of - (Nothing, _) -> b - (_, Nothing) -> a - (Just (Left x), Just (Left y)) -> Just (Left (addM t1 x y)) - (Just (Right x), Just (Right y)) -> Just (Right (addM t2 x y)) - _ -> error "Plus of inconsistent LEithers" - SMTMaybe t -> case (a, b) of - (Nothing, _) -> b - (_, Nothing) -> a - (Just x, Just y) -> Just (addM t x y) - SMTArr _ t -> - let sh1 = arrayShape a - sh2 = arrayShape b - in if | shapeSize sh1 == 0 -> b - | shapeSize sh2 == 0 -> a - | sh1 == sh2 -> arrayGenerateLin sh1 (\i -> addM t (arrayIndexLinear a i) (arrayIndexLinear b i)) - | otherwise -> error "Plus of inconsistently shaped arrays" - SMTScal sty -> numericIsNum sty $ a + b - -onehotM :: SAcPrj p a b -> SMTy a -> Rep (AcIdxS p a) -> Rep b -> Rep a -onehotM SAPHere _ _ val = val -onehotM (SAPFst prj) (SMTPair a b) idx val = (onehotM prj a (fst idx) val, zeroM b (snd idx)) -onehotM (SAPSnd prj) (SMTPair a b) idx val = (zeroM a (fst idx), onehotM prj b (snd idx) val) -onehotM (SAPLeft prj) (SMTLEither a _) idx val = Just (Left (onehotM prj a idx val)) -onehotM (SAPRight prj) (SMTLEither _ b) idx val = Just (Right (onehotM prj b idx val)) -onehotM (SAPJust prj) (SMTMaybe a) idx val = Just (onehotM prj a idx val) -onehotM (SAPArrIdx prj) (SMTArr n a) idx val = - runIdentity $ onehotArray (\idx' -> Identity (onehotM prj a idx' val)) (\zi -> Identity (zeroM a zi)) n prj idx - -withAccum :: SMTy t -> STy a -> Rep t -> (RepAc t -> AcM s (Rep a)) -> AcM s (Rep a, Rep t) -withAccum t _ initval f = AcM $ do - accum <- newAcDense t initval - out <- unAcM $ f accum - val <- readAc t accum - return (out, val) - -newAcDense :: SMTy a -> Rep a -> IO (RepAc a) -newAcDense typ val = case typ of - SMTNil -> return () - SMTPair t1 t2 -> bitraverse (newAcDense t1) (newAcDense t2) val - SMTLEither t1 t2 -> newIORef =<< traverse (bitraverse (newAcDense t1) (newAcDense t2)) val - SMTMaybe t1 -> newIORef =<< traverse (newAcDense t1) val - SMTArr _ t1 -> arrayMapM (newAcDense t1) val - SMTScal _ -> newIORef val - -onehotArray :: Monad m - => (Rep (AcIdxS p a) -> m v) -- ^ the "one" - -> (Rep (ZeroInfo a) -> m v) -- ^ the "zero" - -> SNat n -> SAcPrj p a b -> Rep (AcIdxS (APArrIdx p) (TArr n a)) -> m (Array n v) -onehotArray mkone mkzero n _ ((arrindex', ziarr), idx) = - let arrindex = unTupRepIdx IxNil IxCons n arrindex' - arrsh = arrayShape ziarr - !linindex = toLinearIndex arrsh arrindex - in arrayGenerateLinM arrsh (\i -> if i == linindex then mkone idx else mkzero (ziarr `arrayIndexLinear` i)) - -readAc :: SMTy t -> RepAc t -> IO (Rep t) -readAc typ val = case typ of - SMTNil -> return () - SMTPair t1 t2 -> bitraverse (readAc t1) (readAc t2) val - SMTLEither t1 t2 -> traverse (bitraverse (readAc t1) (readAc t2)) =<< readIORef val - SMTMaybe t -> traverse (readAc t) =<< readIORef val - SMTArr _ t -> traverse (readAc t) val - SMTScal _ -> readIORef val - -accumAddSparseD :: SMTy a -> SAcPrj p a b -> RepAc a -> Rep (AcIdxD p a) -> Sparse b c -> Rep c -> AcM s () -accumAddSparseD typ prj ref idx sp val = case (typ, prj) of - (_, SAPHere) -> accumAddDense typ ref sp val - - (SMTPair t1 _, SAPFst prj') -> accumAddSparseD t1 prj' (fst ref) idx sp val - (SMTPair _ t2, SAPSnd prj') -> accumAddSparseD t2 prj' (snd ref) idx sp val - - (SMTLEither t1 _, SAPLeft prj') -> - realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)") - (\case Left ac1 -> accumAddSparseD t1 prj' ac1 idx sp val - Right{} -> error "Mismatched Either in accumAddSparseD (r +l)") - (SMTLEither _ t2, SAPRight prj') -> - realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)") - (\case Right ac2 -> accumAddSparseD t2 prj' ac2 idx sp val - Left{} -> error "Mismatched Either in accumAddSparseD (l +r)") - - (SMTMaybe t1, SAPJust prj') -> - realiseMaybeSparse ref (error "Accumulating Just into Nothing (EWith requires EDeepZero)") - (\ac -> accumAddSparseD t1 prj' ac idx sp val) - - (SMTArr n t1, SAPArrIdx prj') -> - let (arrindex', idx') = idx - arrindex = unTupRepIdx IxNil IxCons n arrindex' - arrsh = arrayShape ref - linindex = toLinearIndex arrsh arrindex - in accumAddSparseD t1 prj' (arrayIndexLinear ref linindex) idx' sp val - -accumAddDense :: SMTy a -> RepAc a -> Sparse a b -> Rep b -> AcM s () -accumAddDense typ ref sp val = case (typ, sp) of - (_, _) | isAbsent sp -> return () - (_, SpAbsent) -> return () - (_, SpSparse s) -> - case val of - Nothing -> return () - Just val' -> accumAddDense typ ref s val' - (SMTPair t1 t2, SpPair s1 s2) -> do - accumAddDense t1 (fst ref) s1 (fst val) - accumAddDense t2 (snd ref) s2 (snd val) - (SMTLEither t1 t2, SpLEither s1 s2) -> - case val of - Nothing -> return () - Just (Left val1) -> - realiseMaybeSparse ref (error "Accumulating Left into LNil (EWith requires EDeepZero)") - (\case Left ac1 -> accumAddDense t1 ac1 s1 val1 - Right{} -> error "Mismatched Either in accumAddSparse (r +l)") - Just (Right val2) -> - realiseMaybeSparse ref (error "Accumulating Right into LNil (EWith requires EDeepZero)") - (\case Right ac2 -> accumAddDense t2 ac2 s2 val2 - Left{} -> error "Mismatched Either in accumAddSparse (l +r)") - (SMTMaybe t, SpMaybe s) -> - case val of - Nothing -> return () - Just val' -> - realiseMaybeSparse ref (error "Accumulating Just into Nothing (EAccum requires EDeepZero)") - (\ac -> accumAddDense t ac s val') - (SMTArr _ t1, SpArr s) -> - forM_ [0 .. arraySize ref - 1] $ \i -> - accumAddDense t1 (arrayIndexLinear ref i) s (arrayIndexLinear val i) - (SMTScal sty, SpScal) -> numericIsNum sty $ AcM $ atomicModifyIORef' ref (\x -> (x + val, ())) - --- TODO: makeval is always 'error' now. Simplify? -realiseMaybeSparse :: IORef (Maybe a) -> IO a -> (a -> AcM s ()) -> AcM s () -realiseMaybeSparse ref makeval modifyval = - -- Try modifying what's already in ref. The 'join' makes the snd - -- of the function's return value a _continuation_ that is run after - -- the critical section ends. - AcM $ join $ atomicModifyIORef' ref $ \ac -> case ac of - -- Oops, ref's contents was still sparse. Have to initialise - -- it first, then try again. - Nothing -> (ac, do val <- makeval - join $ atomicModifyIORef' ref $ \ac' -> case ac' of - Nothing -> (Just val, return ()) - Just val' -> (ac', unAcM $ modifyval val')) - -- Yep, ref already had a value in there, so we can just add - -- val' to it recursively. - Just val -> (ac, unAcM $ modifyval val) - - -numericIsNum :: ScalIsNumeric st ~ True => SScalTy st -> ((Num (ScalRep st), Ord (ScalRep st)) => r) -> r -numericIsNum STI32 = id -numericIsNum STI64 = id -numericIsNum STF32 = id -numericIsNum STF64 = id - -floatingIsFractional :: ScalIsFloating st ~ True => SScalTy st -> ((Floating (ScalRep st), Ord (ScalRep st), ScalIsNumeric st ~ True, ScalIsFloating st ~ True) => r) -> r -floatingIsFractional STF32 = id -floatingIsFractional STF64 = id - -integralIsIntegral :: ScalIsIntegral st ~ True => SScalTy st -> ((Integral (ScalRep st), Ord (ScalRep st), ScalIsNumeric st ~ True, ScalIsIntegral st ~ True) => r) -> r -integralIsIntegral STI32 = id -integralIsIntegral STI64 = id - -unTupRepIdx :: f Z -> (forall m. f m -> Int -> f (S m)) - -> SNat n -> Rep (Tup (Replicate n TIx)) -> f n -unTupRepIdx nil _ SZ _ = nil -unTupRepIdx nil cons (SS n) (idx, i) = unTupRepIdx nil cons n idx `cons` fromIntegral @Int64 @Int i - -tupRepIdx :: (forall m. f (S m) -> (f m, Int)) - -> SNat n -> f n -> Rep (Tup (Replicate n TIx)) -tupRepIdx _ SZ _ = () -tupRepIdx uncons (SS n) tup = - let (tup', i) = uncons tup - in ((,) $! tupRepIdx uncons n tup') $! fromIntegral @Int @Int64 i - -ixUncons :: Index (S n) -> (Index n, Int) -ixUncons (IxCons idx i) = (idx, i) - -shUncons :: Shape (S n) -> (Shape n, Int) -shUncons (ShCons idx i) = (idx, i) - -mapAccumLM :: (Traversable t, Monad m) => (s -> a -> m (s, b)) -> s -> t a -> m (s, t b) -mapAccumLM f s0 = fmap swap . flip runStateT s0 . traverse f' - where f' x = do - s <- get - (s', y) <- lift $ f s x - put s' - return y diff --git a/src/Interpreter/Accum.hs b/src/Interpreter/Accum.hs deleted file mode 100644 index af7be1e..0000000 --- a/src/Interpreter/Accum.hs +++ /dev/null @@ -1,366 +0,0 @@ -{-# LANGUAGE BangPatterns #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE MagicHash #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TupleSections #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UnboxedTuples #-} -module Interpreter.Accum ( - AcM, - runAcM, - Rep', - Accum, - withAccum, - accumAdd, - inParallel, -) where - -import Control.Concurrent -import Control.Monad (when, forM_) -import Data.Bifunctor (second) -import Data.Proxy -import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr) -import Foreign.Storable (sizeOf) -import GHC.Exts -import GHC.Float -import GHC.Int -import GHC.IO (IO(..)) -import GHC.Word -import System.IO.Unsafe (unsafePerformIO) - -import Array -import AST -import Data - - -newtype AcM s a = AcM (IO a) - deriving newtype (Functor, Applicative, Monad) - -runAcM :: (forall s. AcM s a) -> a -runAcM (AcM m) = unsafePerformIO m - --- | Equal to Interpreter.Rep.Rep, except that the TAccum case is defined. -type family Rep' s t where - Rep' s TNil = () - Rep' s (TPair a b) = (Rep' s a, Rep' s b) - Rep' s (TEither a b) = Either (Rep' s a) (Rep' s b) - Rep' s (TMaybe t) = Maybe (Rep' s t) - Rep' s (TArr n t) = Array n (Rep' s t) - Rep' s (TScal sty) = ScalRep sty - Rep' s (TAccum t) = Accum s t - --- | Floats and integers are accumulated; booleans are left as-is. -data Accum s t = Accum (STy t) (ForeignPtr ()) - -tSize :: Proxy s -> STy t -> Rep' s t -> Int -tSize p ty x = tSize' p ty (Just x) - -tSize' :: Proxy s -> STy t -> Int -tSize' p typ = case typ of - STNil -> 0 - STPair a b -> tSize' p a + tSize' p b - STEither a b -> 1 + max (tSize' p a) (tSize' p b) - -- Representation of Maybe t is the same as Either () t; the add operation is different, however. - STMaybe t -> tSize' p (STEither STNil t) - STArr ndim t -> - case val of - Nothing -> error "Nested arrays not supported in this implementation" - Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' p t Nothing - STScal sty -> goScal sty - STAccum{} -> error "Nested accumulators unsupported" - where - goScal :: SScalTy t -> Int - goScal STI32 = 4 - goScal STI64 = 8 - goScal STF32 = 4 - goScal STF64 = 8 - goScal STBool = 1 - --- | This operation does not commute with 'accumAdd', so it must be used with --- care. Furthermore it must be used on exactly the same value as tSize was --- called on. Hence it lives in IO, not in AcM. -accumWrite :: forall s t. Accum s t -> Rep' s t -> IO () -accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) -> - let - go :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int - go inarr ty val off = case ty of - STNil -> return off - STPair a b -> do - off1 <- go inarr a (fst val) off - go inarr b (snd val) off1 - STEither a b -> do - let !(I# off#) = off - off1 <- case val of - Left x -> do - let !(I8# tag#) = 0 - writeInt8# addr# off# tag# - go inarr a x (off + 1) - Right y -> do - let !(I8# tag#) = 1 - writeInt8# addr# off# tag# - go inarr b y (off + 1) - if inarr - then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing)) - else return off1 - -- Representation is the same, but add operation is different - STMaybe t -> go inarr (STEither STNil t) (maybe (Left ()) Right val) off - STArr _ t - | inarr -> error "Nested arrays not supported in this implementation" - | otherwise -> do - off1 <- goShape (arrayShape val) off - let eltsize = tSize' (Proxy @s) t Nothing - n = arraySize val - traverseArray_ (\lini x -> () <$ go True t x (off1 + eltsize * lini)) val - return (off1 + eltsize * n) - STScal sty -> goScal sty val off - STAccum{} -> error "Nested accumulators unsupported" - - goShape :: Shape n -> Int -> IO Int - goShape ShNil off = return off - goShape (ShCons sh n) off = do - off1@(I# off1#) <- goShape sh off - let !(I64# n'#) = fromIntegral n - writeInt64# addr# off1# n'# - return (off1 + 8) - - goScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int - goScal STI32 (I32# x) off@(I# off#) = off + 4 <$ writeInt32# addr# off# x - goScal STI64 (I64# x) off@(I# off#) = off + 8 <$ writeInt64# addr# off# x - goScal STF32 (F# x) off@(I# off#) = off + 4 <$ writeFloat# addr# off# x - goScal STF64 (D# x) off@(I# off#) = off + 8 <$ writeDouble# addr# off# x - goScal STBool b off@(I# off#) = do - let !(I8# i) = fromIntegral (fromEnum b) - off + 1 <$ writeInt8# addr# off# i - - in () <$ go False topty top_value 0 - -accumRead :: forall s t. Accum s t -> AcM s (Rep' s t) -accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> - let - go :: Bool -> STy t' -> Int -> IO (Int, Rep' s t') - go inarr ty off = case ty of - STNil -> return (off, ()) - STPair a b -> do - (off1, x) <- go inarr a off - (off2, y) <- go inarr b off1 - return (off1 + off2, (x, y)) - STEither a b -> do - let !(I# off#) = off - tag <- readInt8 addr# off# - (off1, val) <- case tag of - 0 -> fmap Left <$> go inarr a (off + 1) - 1 -> fmap Right <$> go inarr b (off + 1) - _ -> error "Invalid tag in accum memory" - if inarr - then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing), val) - else return (off1, val) - -- Representation is the same, but add operation is different - STMaybe t -> second (either (const Nothing) Just) <$> go inarr (STEither STNil t) off - STArr ndim t - | inarr -> error "Nested arrays not supported in this implementation" - | otherwise -> do - (off1, sh) <- readShape addr# ndim off - let eltsize = tSize' (Proxy @s) t Nothing - n = shapeSize sh - arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini)) - return (off1 + eltsize * n, arr) - STScal sty -> goScal sty off - STAccum{} -> error "Nested accumulators unsupported" - - goScal :: SScalTy t' -> Int -> IO (Int, ScalRep t') - goScal STI32 off@(I# off#) = (off + 4,) <$> readInt32 addr# off# - goScal STI64 off@(I# off#) = (off + 8,) <$> readInt64 addr# off# - goScal STF32 off@(I# off#) = (off + 4,) <$> readFloat addr# off# - goScal STF64 off@(I# off#) = (off + 8,) <$> readDouble addr# off# - goScal STBool off@(I# off#) = do - i8 <- readInt8 addr# off# - return (off + 1, toEnum (fromIntegral i8)) - - in snd <$> go False topty 0 - -readShape :: Addr# -> SNat n -> Int -> IO (Int, Shape n) -readShape _ SZ off = return (off, ShNil) -readShape mbarr (SS ndim) off = do - (off1@(I# off1#), sh) <- readShape mbarr ndim off - n' <- readInt64 mbarr off1# - return (off1 + 8, ShCons sh (fromIntegral n')) - --- | @reverse@ of 'Shape'. The /outer/ dimension is on the left, at the head of --- the list. -data InvShape n where - IShNil :: InvShape Z - IShCons :: Int -- ^ How many subarrays are there? - -> Int -- ^ What is the size of all subarrays together? - -> InvShape n -- ^ Sub array inverted shape - -> InvShape (S n) - -ishSize :: InvShape n -> Int -ishSize IShNil = 1 -ishSize (IShCons _ sz _) = sz - -invertShape :: forall n. Shape n -> InvShape n -invertShape | Refl <- lemPlusZero @n = flip go IShNil - where - go :: forall n' m. Shape n' -> InvShape m -> InvShape (n' + m) - go ShNil ish = ish - go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish) - -accumAdd :: forall s t i. Accum s t -> SNat i -> Rep' s (AcIdx t i) -> Rep' s (AcVal t i) -> AcM s () -accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> - let - go :: Bool -> STy t' -> SNat i' -> Rep' s (AcIdx t' i') -> Rep' s (AcVal t' i') -> Int -> IO () - go inarr ty SZ () val off = () <$ performAdd inarr ty val off - go inarr ty (SS dep) idx val off = case (ty, idx, val) of - (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off - (STPair _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off - (STPair{}, _, _) -> error "Mismatching idx/val for Pair in accumAdd" - (STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off - (STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off - (STEither{}, _, _) -> error "Mismatching idx/val for Either in accumAdd" - (STMaybe t, _, _) -> _ idx val - (STArr rank eltty, _, _) - | inarr -> error "Nested arrays" - | otherwise -> do - (off1, ish) <- second invertShape <$> readShape addr# rank off - goArr (SS dep) ish eltty idx val off1 - (STScal{}, _, _) -> error "accumAdd: Scal impossible with nonzero depth" - (STNil, _, _) -> error "accumAdd: Nil impossible with nonzero depth" - (STAccum{}, _, _) -> error "Nested accumulators unsupported" - - goArr :: SNat i' -> InvShape n -> STy t' - -> Rep' s (AcIdx (TArr n t') i') -> Rep' s (AcVal (TArr n t') i') -> Int -> IO () - goArr SZ ish t1 () val off = () <$ performAddArr (ishSize ish) t1 val off - goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off - goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do - let i' = fromIntegral @(Rep' s TIx) @Int i - when (i' < 0 || i' >= n) $ - error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")" - goArr depm1 ish t1 idx val (off + i' * ishSize ish) - - performAddArr :: Int -> STy t' -> Array n (Rep' s t') -> Int -> IO Int - performAddArr arraySz eltty val off = do - let eltsize = tSize' (Proxy @s) eltty Nothing - forM_ [0 .. arraySz - 1] $ \lini -> - performAdd True eltty (arrayIndexLinear val lini) (off + lini * eltsize) - return (off + arraySz * eltsize) - - performAdd :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int - performAdd inarr ty val off = case ty of - STNil -> return off - STPair t1 t2 -> do - off1 <- performAdd inarr t1 (fst val) off - performAdd inarr t2 (snd val) off1 - STEither t1 t2 -> do - let !(I# off#) = off - tag <- readInt8 addr# off# - off1 <- case (val, tag) of - (Left val1, 0) -> performAdd inarr t1 val1 (off + 1) - (Right val2, 1) -> performAdd inarr t2 val2 (off + 1) - _ -> error "accumAdd: Tag mismatch for Either" - if inarr - then return (off + 1 + max (tSize' (Proxy @s) t1 Nothing) (tSize' (Proxy @s) t2 Nothing)) - else return off1 - STArr n ty' - | inarr -> error "Nested array" - | otherwise -> do - (off1, sh) <- readShape addr# n off - performAddArr (shapeSize sh) ty' val off1 - STScal ty' -> performAddScal ty' val off - STAccum{} -> error "Nested accumulators unsupported" - - performAddScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int - performAddScal STI32 (I32# x#) off@(I# off#) - | sizeOf (undefined :: Int) == 4 - = off + 4 <$ fetchAddWord# addr# off# (word32ToWord# (int32ToWord32# x#)) - | otherwise - = off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\(W32# w#) -> W32# (int32ToWord32# x# `plusWord32#` w#)) - performAddScal STI64 (I64# x#) off@(I# off#) - | sizeOf (undefined :: Int) == 8 - = off + 8 <$ fetchAddWord# addr# off# (word64ToWord# (int64ToWord64# x#)) - | otherwise - = off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\(W64# w#) -> W64# (int64ToWord64# x# `plusWord64#` w#)) - performAddScal STF32 x off@(I# off#) = - off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\w -> castFloatToWord32 (x + castWord32ToFloat w)) - performAddScal STF64 x off@(I# off#) = - off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\w -> castDoubleToWord64 (x + castWord64ToDouble w)) - performAddScal STBool _ off = return (off + 1) -- don't do anything with booleans - - casLoop :: Eq w - => (Addr# -> Int# -> IO w) -- ^ read value (from a given byte offset; will get 0#) - -> (Addr# -> w -> w -> IO w) -- ^ CAS value at address (expected -> desired -> IO observed) - -> Addr# -- ^ Address to attempt to modify - -> (w -> w) -- ^ Operation to apply to the value - -> IO () - casLoop readOp casOp addr modify = readOp addr 0# >>= loop - where - loop value = do - value' <- casOp addr value (modify value) - if value == value' - then return () - else loop value' - - in () <$ go False topty top_depth top_index top_value 0 - -withAccum :: forall t s b. STy t -> Rep' s t -> (Accum s t -> AcM s b) -> AcM s (b, Rep' s t) -withAccum ty start fun = do - -- The initial write must happen before any of the adds or reads, so it makes - -- sense to put it in IO together with the allocation, instead of in AcM. - accum <- AcM $ do buffer <- mallocBytes (tSize (Proxy @s) ty start) - ptr <- newForeignPtr finalizerFree buffer - let accum = Accum ty ptr - accumWrite accum start - return accum - b <- fun accum - out <- accumRead accum - return (b, out) - -inParallel :: [AcM s t] -> AcM s [t] -inParallel actions = AcM $ do - mvars <- mapM (\_ -> newEmptyMVar) actions - forM_ (zip actions mvars) $ \(AcM action, var) -> - forkIO $ action >>= putMVar var - mapM takeMVar mvars - --- | Offset is in bytes. -readInt8 :: Addr# -> Int# -> IO Int8 -readInt32 :: Addr# -> Int# -> IO Int32 -readInt64 :: Addr# -> Int# -> IO Int64 -readWord32 :: Addr# -> Int# -> IO Word32 -readWord64 :: Addr# -> Int# -> IO Word64 -readFloat :: Addr# -> Int# -> IO Float -readDouble :: Addr# -> Int# -> IO Double -readInt8 addr off# = IO $ \s -> case readInt8OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I8# val #) -readInt32 addr off# = IO $ \s -> case readInt32OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I32# val #) -readInt64 addr off# = IO $ \s -> case readInt64OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I64# val #) -readWord32 addr off# = IO $ \s -> case readWord32OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W32# val #) -readWord64 addr off# = IO $ \s -> case readWord64OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W64# val #) -readFloat addr off# = IO $ \s -> case readFloatOffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', F# val #) -readDouble addr off# = IO $ \s -> case readDoubleOffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', D# val #) - -writeInt8# :: Addr# -> Int# -> Int8# -> IO () -writeInt32# :: Addr# -> Int# -> Int32# -> IO () -writeInt64# :: Addr# -> Int# -> Int64# -> IO () -writeFloat# :: Addr# -> Int# -> Float# -> IO () -writeDouble# :: Addr# -> Int# -> Double# -> IO () -writeInt8# addr off# val = IO $ \s -> (# writeInt8OffAddr# (addr `plusAddr#` off#) 0# val s, () #) -writeInt32# addr off# val = IO $ \s -> (# writeInt32OffAddr# (addr `plusAddr#` off#) 0# val s, () #) -writeInt64# addr off# val = IO $ \s -> (# writeInt64OffAddr# (addr `plusAddr#` off#) 0# val s, () #) -writeFloat# addr off# val = IO $ \s -> (# writeFloatOffAddr# (addr `plusAddr#` off#) 0# val s, () #) -writeDouble# addr off# val = IO $ \s -> (# writeDoubleOffAddr# (addr `plusAddr#` off#) 0# val s, () #) - -fetchAddWord# :: Addr# -> Int# -> Word# -> IO () -fetchAddWord# addr off# val = IO $ \s -> case fetchAddWordAddr# (addr `plusAddr#` off#) val s of (# s', _ #) -> (# s', () #) - -atomicCasWord32Addr :: Addr# -> Word32 -> Word32 -> IO Word32 -atomicCasWord64Addr :: Addr# -> Word64 -> Word64 -> IO Word64 -atomicCasWord32Addr addr (W32# expected) (W32# desired) = - IO $ \s -> case atomicCasWord32Addr# addr expected desired s of (# s', old #) -> (# s', W32# old #) -atomicCasWord64Addr addr (W64# expected) (W64# desired) = - IO $ \s -> case atomicCasWord64Addr# addr expected desired s of (# s', old #) -> (# s', W64# old #) diff --git a/src/Interpreter/AccumOld.hs b/src/Interpreter/AccumOld.hs deleted file mode 100644 index af7be1e..0000000 --- a/src/Interpreter/AccumOld.hs +++ /dev/null @@ -1,366 +0,0 @@ -{-# LANGUAGE BangPatterns #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE MagicHash #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TupleSections #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UnboxedTuples #-} -module Interpreter.Accum ( - AcM, - runAcM, - Rep', - Accum, - withAccum, - accumAdd, - inParallel, -) where - -import Control.Concurrent -import Control.Monad (when, forM_) -import Data.Bifunctor (second) -import Data.Proxy -import Foreign (ForeignPtr, mallocBytes, newForeignPtr, finalizerFree, withForeignPtr) -import Foreign.Storable (sizeOf) -import GHC.Exts -import GHC.Float -import GHC.Int -import GHC.IO (IO(..)) -import GHC.Word -import System.IO.Unsafe (unsafePerformIO) - -import Array -import AST -import Data - - -newtype AcM s a = AcM (IO a) - deriving newtype (Functor, Applicative, Monad) - -runAcM :: (forall s. AcM s a) -> a -runAcM (AcM m) = unsafePerformIO m - --- | Equal to Interpreter.Rep.Rep, except that the TAccum case is defined. -type family Rep' s t where - Rep' s TNil = () - Rep' s (TPair a b) = (Rep' s a, Rep' s b) - Rep' s (TEither a b) = Either (Rep' s a) (Rep' s b) - Rep' s (TMaybe t) = Maybe (Rep' s t) - Rep' s (TArr n t) = Array n (Rep' s t) - Rep' s (TScal sty) = ScalRep sty - Rep' s (TAccum t) = Accum s t - --- | Floats and integers are accumulated; booleans are left as-is. -data Accum s t = Accum (STy t) (ForeignPtr ()) - -tSize :: Proxy s -> STy t -> Rep' s t -> Int -tSize p ty x = tSize' p ty (Just x) - -tSize' :: Proxy s -> STy t -> Int -tSize' p typ = case typ of - STNil -> 0 - STPair a b -> tSize' p a + tSize' p b - STEither a b -> 1 + max (tSize' p a) (tSize' p b) - -- Representation of Maybe t is the same as Either () t; the add operation is different, however. - STMaybe t -> tSize' p (STEither STNil t) - STArr ndim t -> - case val of - Nothing -> error "Nested arrays not supported in this implementation" - Just arr -> fromSNat ndim * 8 + arraySize arr * tSize' p t Nothing - STScal sty -> goScal sty - STAccum{} -> error "Nested accumulators unsupported" - where - goScal :: SScalTy t -> Int - goScal STI32 = 4 - goScal STI64 = 8 - goScal STF32 = 4 - goScal STF64 = 8 - goScal STBool = 1 - --- | This operation does not commute with 'accumAdd', so it must be used with --- care. Furthermore it must be used on exactly the same value as tSize was --- called on. Hence it lives in IO, not in AcM. -accumWrite :: forall s t. Accum s t -> Rep' s t -> IO () -accumWrite (Accum topty fptr) top_value = withForeignPtr fptr $ \(Ptr addr#) -> - let - go :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int - go inarr ty val off = case ty of - STNil -> return off - STPair a b -> do - off1 <- go inarr a (fst val) off - go inarr b (snd val) off1 - STEither a b -> do - let !(I# off#) = off - off1 <- case val of - Left x -> do - let !(I8# tag#) = 0 - writeInt8# addr# off# tag# - go inarr a x (off + 1) - Right y -> do - let !(I8# tag#) = 1 - writeInt8# addr# off# tag# - go inarr b y (off + 1) - if inarr - then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing)) - else return off1 - -- Representation is the same, but add operation is different - STMaybe t -> go inarr (STEither STNil t) (maybe (Left ()) Right val) off - STArr _ t - | inarr -> error "Nested arrays not supported in this implementation" - | otherwise -> do - off1 <- goShape (arrayShape val) off - let eltsize = tSize' (Proxy @s) t Nothing - n = arraySize val - traverseArray_ (\lini x -> () <$ go True t x (off1 + eltsize * lini)) val - return (off1 + eltsize * n) - STScal sty -> goScal sty val off - STAccum{} -> error "Nested accumulators unsupported" - - goShape :: Shape n -> Int -> IO Int - goShape ShNil off = return off - goShape (ShCons sh n) off = do - off1@(I# off1#) <- goShape sh off - let !(I64# n'#) = fromIntegral n - writeInt64# addr# off1# n'# - return (off1 + 8) - - goScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int - goScal STI32 (I32# x) off@(I# off#) = off + 4 <$ writeInt32# addr# off# x - goScal STI64 (I64# x) off@(I# off#) = off + 8 <$ writeInt64# addr# off# x - goScal STF32 (F# x) off@(I# off#) = off + 4 <$ writeFloat# addr# off# x - goScal STF64 (D# x) off@(I# off#) = off + 8 <$ writeDouble# addr# off# x - goScal STBool b off@(I# off#) = do - let !(I8# i) = fromIntegral (fromEnum b) - off + 1 <$ writeInt8# addr# off# i - - in () <$ go False topty top_value 0 - -accumRead :: forall s t. Accum s t -> AcM s (Rep' s t) -accumRead (Accum topty fptr) = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> - let - go :: Bool -> STy t' -> Int -> IO (Int, Rep' s t') - go inarr ty off = case ty of - STNil -> return (off, ()) - STPair a b -> do - (off1, x) <- go inarr a off - (off2, y) <- go inarr b off1 - return (off1 + off2, (x, y)) - STEither a b -> do - let !(I# off#) = off - tag <- readInt8 addr# off# - (off1, val) <- case tag of - 0 -> fmap Left <$> go inarr a (off + 1) - 1 -> fmap Right <$> go inarr b (off + 1) - _ -> error "Invalid tag in accum memory" - if inarr - then return (off + 1 + max (tSize' (Proxy @s) a Nothing) (tSize' (Proxy @s) b Nothing), val) - else return (off1, val) - -- Representation is the same, but add operation is different - STMaybe t -> second (either (const Nothing) Just) <$> go inarr (STEither STNil t) off - STArr ndim t - | inarr -> error "Nested arrays not supported in this implementation" - | otherwise -> do - (off1, sh) <- readShape addr# ndim off - let eltsize = tSize' (Proxy @s) t Nothing - n = shapeSize sh - arr <- arrayGenerateLinM sh (\lini -> snd <$> go True t (off1 + eltsize * lini)) - return (off1 + eltsize * n, arr) - STScal sty -> goScal sty off - STAccum{} -> error "Nested accumulators unsupported" - - goScal :: SScalTy t' -> Int -> IO (Int, ScalRep t') - goScal STI32 off@(I# off#) = (off + 4,) <$> readInt32 addr# off# - goScal STI64 off@(I# off#) = (off + 8,) <$> readInt64 addr# off# - goScal STF32 off@(I# off#) = (off + 4,) <$> readFloat addr# off# - goScal STF64 off@(I# off#) = (off + 8,) <$> readDouble addr# off# - goScal STBool off@(I# off#) = do - i8 <- readInt8 addr# off# - return (off + 1, toEnum (fromIntegral i8)) - - in snd <$> go False topty 0 - -readShape :: Addr# -> SNat n -> Int -> IO (Int, Shape n) -readShape _ SZ off = return (off, ShNil) -readShape mbarr (SS ndim) off = do - (off1@(I# off1#), sh) <- readShape mbarr ndim off - n' <- readInt64 mbarr off1# - return (off1 + 8, ShCons sh (fromIntegral n')) - --- | @reverse@ of 'Shape'. The /outer/ dimension is on the left, at the head of --- the list. -data InvShape n where - IShNil :: InvShape Z - IShCons :: Int -- ^ How many subarrays are there? - -> Int -- ^ What is the size of all subarrays together? - -> InvShape n -- ^ Sub array inverted shape - -> InvShape (S n) - -ishSize :: InvShape n -> Int -ishSize IShNil = 1 -ishSize (IShCons _ sz _) = sz - -invertShape :: forall n. Shape n -> InvShape n -invertShape | Refl <- lemPlusZero @n = flip go IShNil - where - go :: forall n' m. Shape n' -> InvShape m -> InvShape (n' + m) - go ShNil ish = ish - go (sh `ShCons` n) ish | Refl <- lemPlusSuccRight @n' @m = go sh (IShCons n (n * ishSize ish) ish) - -accumAdd :: forall s t i. Accum s t -> SNat i -> Rep' s (AcIdx t i) -> Rep' s (AcVal t i) -> AcM s () -accumAdd (Accum topty fptr) top_depth top_index top_value = AcM $ withForeignPtr fptr $ \(Ptr addr#) -> - let - go :: Bool -> STy t' -> SNat i' -> Rep' s (AcIdx t' i') -> Rep' s (AcVal t' i') -> Int -> IO () - go inarr ty SZ () val off = () <$ performAdd inarr ty val off - go inarr ty (SS dep) idx val off = case (ty, idx, val) of - (STPair t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off - (STPair _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off - (STPair{}, _, _) -> error "Mismatching idx/val for Pair in accumAdd" - (STEither t1 _, Left idx1, Left val1) -> go inarr t1 dep idx1 val1 off - (STEither _ t2, Right idx2, Right val2) -> go inarr t2 dep idx2 val2 off - (STEither{}, _, _) -> error "Mismatching idx/val for Either in accumAdd" - (STMaybe t, _, _) -> _ idx val - (STArr rank eltty, _, _) - | inarr -> error "Nested arrays" - | otherwise -> do - (off1, ish) <- second invertShape <$> readShape addr# rank off - goArr (SS dep) ish eltty idx val off1 - (STScal{}, _, _) -> error "accumAdd: Scal impossible with nonzero depth" - (STNil, _, _) -> error "accumAdd: Nil impossible with nonzero depth" - (STAccum{}, _, _) -> error "Nested accumulators unsupported" - - goArr :: SNat i' -> InvShape n -> STy t' - -> Rep' s (AcIdx (TArr n t') i') -> Rep' s (AcVal (TArr n t') i') -> Int -> IO () - goArr SZ ish t1 () val off = () <$ performAddArr (ishSize ish) t1 val off - goArr (SS depm1) IShNil t1 idx val off = go True t1 depm1 idx val off - goArr (SS depm1) (IShCons n _ ish) t1 (i, idx) val off = do - let i' = fromIntegral @(Rep' s TIx) @Int i - when (i' < 0 || i' >= n) $ - error $ "accumAdd: index out of range: " ++ show i ++ " not in [0, " ++ show n ++ ")" - goArr depm1 ish t1 idx val (off + i' * ishSize ish) - - performAddArr :: Int -> STy t' -> Array n (Rep' s t') -> Int -> IO Int - performAddArr arraySz eltty val off = do - let eltsize = tSize' (Proxy @s) eltty Nothing - forM_ [0 .. arraySz - 1] $ \lini -> - performAdd True eltty (arrayIndexLinear val lini) (off + lini * eltsize) - return (off + arraySz * eltsize) - - performAdd :: Bool -> STy t' -> Rep' s t' -> Int -> IO Int - performAdd inarr ty val off = case ty of - STNil -> return off - STPair t1 t2 -> do - off1 <- performAdd inarr t1 (fst val) off - performAdd inarr t2 (snd val) off1 - STEither t1 t2 -> do - let !(I# off#) = off - tag <- readInt8 addr# off# - off1 <- case (val, tag) of - (Left val1, 0) -> performAdd inarr t1 val1 (off + 1) - (Right val2, 1) -> performAdd inarr t2 val2 (off + 1) - _ -> error "accumAdd: Tag mismatch for Either" - if inarr - then return (off + 1 + max (tSize' (Proxy @s) t1 Nothing) (tSize' (Proxy @s) t2 Nothing)) - else return off1 - STArr n ty' - | inarr -> error "Nested array" - | otherwise -> do - (off1, sh) <- readShape addr# n off - performAddArr (shapeSize sh) ty' val off1 - STScal ty' -> performAddScal ty' val off - STAccum{} -> error "Nested accumulators unsupported" - - performAddScal :: SScalTy t' -> ScalRep t' -> Int -> IO Int - performAddScal STI32 (I32# x#) off@(I# off#) - | sizeOf (undefined :: Int) == 4 - = off + 4 <$ fetchAddWord# addr# off# (word32ToWord# (int32ToWord32# x#)) - | otherwise - = off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\(W32# w#) -> W32# (int32ToWord32# x# `plusWord32#` w#)) - performAddScal STI64 (I64# x#) off@(I# off#) - | sizeOf (undefined :: Int) == 8 - = off + 8 <$ fetchAddWord# addr# off# (word64ToWord# (int64ToWord64# x#)) - | otherwise - = off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\(W64# w#) -> W64# (int64ToWord64# x# `plusWord64#` w#)) - performAddScal STF32 x off@(I# off#) = - off + 4 <$ casLoop readWord32 atomicCasWord32Addr (addr# `plusAddr#` off#) (\w -> castFloatToWord32 (x + castWord32ToFloat w)) - performAddScal STF64 x off@(I# off#) = - off + 8 <$ casLoop readWord64 atomicCasWord64Addr (addr# `plusAddr#` off#) (\w -> castDoubleToWord64 (x + castWord64ToDouble w)) - performAddScal STBool _ off = return (off + 1) -- don't do anything with booleans - - casLoop :: Eq w - => (Addr# -> Int# -> IO w) -- ^ read value (from a given byte offset; will get 0#) - -> (Addr# -> w -> w -> IO w) -- ^ CAS value at address (expected -> desired -> IO observed) - -> Addr# -- ^ Address to attempt to modify - -> (w -> w) -- ^ Operation to apply to the value - -> IO () - casLoop readOp casOp addr modify = readOp addr 0# >>= loop - where - loop value = do - value' <- casOp addr value (modify value) - if value == value' - then return () - else loop value' - - in () <$ go False topty top_depth top_index top_value 0 - -withAccum :: forall t s b. STy t -> Rep' s t -> (Accum s t -> AcM s b) -> AcM s (b, Rep' s t) -withAccum ty start fun = do - -- The initial write must happen before any of the adds or reads, so it makes - -- sense to put it in IO together with the allocation, instead of in AcM. - accum <- AcM $ do buffer <- mallocBytes (tSize (Proxy @s) ty start) - ptr <- newForeignPtr finalizerFree buffer - let accum = Accum ty ptr - accumWrite accum start - return accum - b <- fun accum - out <- accumRead accum - return (b, out) - -inParallel :: [AcM s t] -> AcM s [t] -inParallel actions = AcM $ do - mvars <- mapM (\_ -> newEmptyMVar) actions - forM_ (zip actions mvars) $ \(AcM action, var) -> - forkIO $ action >>= putMVar var - mapM takeMVar mvars - --- | Offset is in bytes. -readInt8 :: Addr# -> Int# -> IO Int8 -readInt32 :: Addr# -> Int# -> IO Int32 -readInt64 :: Addr# -> Int# -> IO Int64 -readWord32 :: Addr# -> Int# -> IO Word32 -readWord64 :: Addr# -> Int# -> IO Word64 -readFloat :: Addr# -> Int# -> IO Float -readDouble :: Addr# -> Int# -> IO Double -readInt8 addr off# = IO $ \s -> case readInt8OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I8# val #) -readInt32 addr off# = IO $ \s -> case readInt32OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I32# val #) -readInt64 addr off# = IO $ \s -> case readInt64OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', I64# val #) -readWord32 addr off# = IO $ \s -> case readWord32OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W32# val #) -readWord64 addr off# = IO $ \s -> case readWord64OffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', W64# val #) -readFloat addr off# = IO $ \s -> case readFloatOffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', F# val #) -readDouble addr off# = IO $ \s -> case readDoubleOffAddr# (addr `plusAddr#` off#) 0# s of (# s', val #) -> (# s', D# val #) - -writeInt8# :: Addr# -> Int# -> Int8# -> IO () -writeInt32# :: Addr# -> Int# -> Int32# -> IO () -writeInt64# :: Addr# -> Int# -> Int64# -> IO () -writeFloat# :: Addr# -> Int# -> Float# -> IO () -writeDouble# :: Addr# -> Int# -> Double# -> IO () -writeInt8# addr off# val = IO $ \s -> (# writeInt8OffAddr# (addr `plusAddr#` off#) 0# val s, () #) -writeInt32# addr off# val = IO $ \s -> (# writeInt32OffAddr# (addr `plusAddr#` off#) 0# val s, () #) -writeInt64# addr off# val = IO $ \s -> (# writeInt64OffAddr# (addr `plusAddr#` off#) 0# val s, () #) -writeFloat# addr off# val = IO $ \s -> (# writeFloatOffAddr# (addr `plusAddr#` off#) 0# val s, () #) -writeDouble# addr off# val = IO $ \s -> (# writeDoubleOffAddr# (addr `plusAddr#` off#) 0# val s, () #) - -fetchAddWord# :: Addr# -> Int# -> Word# -> IO () -fetchAddWord# addr off# val = IO $ \s -> case fetchAddWordAddr# (addr `plusAddr#` off#) val s of (# s', _ #) -> (# s', () #) - -atomicCasWord32Addr :: Addr# -> Word32 -> Word32 -> IO Word32 -atomicCasWord64Addr :: Addr# -> Word64 -> Word64 -> IO Word64 -atomicCasWord32Addr addr (W32# expected) (W32# desired) = - IO $ \s -> case atomicCasWord32Addr# addr expected desired s of (# s', old #) -> (# s', W32# old #) -atomicCasWord64Addr addr (W64# expected) (W64# desired) = - IO $ \s -> case atomicCasWord64Addr# addr expected desired s of (# s', old #) -> (# s', W64# old #) diff --git a/src/Interpreter/Rep.hs b/src/Interpreter/Rep.hs deleted file mode 100644 index 1682303..0000000 --- a/src/Interpreter/Rep.hs +++ /dev/null @@ -1,105 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE UndecidableInstances #-} -module Interpreter.Rep where - -import Control.DeepSeq -import Data.Coerce (coerce) -import Data.List (intersperse, intercalate) -import Data.Foldable (toList) -import Data.IORef -import GHC.Exts (withDict) - -import Array -import AST -import AST.Pretty -import Data - - -type family Rep t where - Rep TNil = () - Rep (TPair a b) = (Rep a, Rep b) - Rep (TEither a b) = Either (Rep a) (Rep b) - Rep (TLEither a b) = Maybe (Either (Rep a) (Rep b)) - Rep (TMaybe t) = Maybe (Rep t) - Rep (TArr n t) = Array n (Rep t) - Rep (TScal sty) = ScalRep sty - Rep (TAccum t) = RepAc t - --- Mutable, represents monoid types t. -type family RepAc t where - RepAc TNil = () - RepAc (TPair a b) = (RepAc a, RepAc b) - RepAc (TLEither a b) = IORef (Maybe (Either (RepAc a) (RepAc b))) - RepAc (TMaybe t) = IORef (Maybe (RepAc t)) - RepAc (TArr n t) = Array n (RepAc t) - RepAc (TScal sty) = IORef (ScalRep sty) - -newtype Value t = Value { unValue :: Rep t } - -liftV :: (Rep a -> Rep b) -> Value a -> Value b -liftV f (Value x) = Value (f x) - -liftV2 :: (Rep a -> Rep b -> Rep c) -> Value a -> Value b -> Value c -liftV2 f (Value x) (Value y) = Value (f x y) - -vPair :: Value a -> Value b -> Value (TPair a b) -vPair = liftV2 (,) - -vUnpair :: Value (TPair a b) -> (Value a, Value b) -vUnpair (Value (x, y)) = (Value x, Value y) - -showValue :: Int -> STy t -> Rep t -> ShowS -showValue _ STNil () = showString "()" -showValue _ (STPair a b) (x, y) = showString "(" . showValue 0 a x . showString "," . showValue 0 b y . showString ")" -showValue d (STEither a _) (Left x) = showParen (d > 10) $ showString "Inl " . showValue 11 a x -showValue d (STEither _ b) (Right y) = showParen (d > 10) $ showString "Inr " . showValue 11 b y -showValue _ (STLEither _ _) Nothing = showString "LNil" -showValue d (STLEither a _) (Just (Left x)) = showParen (d > 10) $ showString "LInl " . showValue 11 a x -showValue d (STLEither _ b) (Just (Right y)) = showParen (d > 10) $ showString "LInr " . showValue 11 b y -showValue _ (STMaybe _) Nothing = showString "Nothing" -showValue d (STMaybe t) (Just x) = showParen (d > 10) $ showString "Just " . showValue 11 t x -showValue d (STArr _ t) arr = showParen (d > 10) $ - showString "arrayFromList " . showsPrec 11 (arrayShape arr) - . showString " [" - . foldr (.) id (intersperse (showString ",") $ map (showValue 0 t) (toList arr)) - . showString "]" -showValue d (STScal sty) x = case sty of - STF32 -> showsPrec d x - STF64 -> showsPrec d x - STI32 -> showsPrec d x - STI64 -> showsPrec d x - STBool -> showsPrec d x -showValue _ (STAccum t) _ = showString $ "" - -showEnv :: SList STy env -> SList Value env -> String -showEnv = \env vals -> "[" ++ intercalate ", " (showEntries env vals) ++ "]" - where - showEntries :: SList STy env -> SList Value env -> [String] - showEntries SNil SNil = [] - showEntries (t `SCons` env) (Value x `SCons` xs) = showValue 0 t x "" : showEntries env xs - -rnfRep :: STy t -> Rep t -> () -rnfRep STNil () = () -rnfRep (STPair a b) (x, y) = rnfRep a x `seq` rnfRep b y -rnfRep (STEither a _) (Left x) = rnfRep a x -rnfRep (STEither _ b) (Right y) = rnfRep b y -rnfRep (STLEither _ _) Nothing = () -rnfRep (STLEither a _) (Just (Left x)) = rnfRep a x -rnfRep (STLEither _ b) (Just (Right y)) = rnfRep b y -rnfRep (STMaybe _) Nothing = () -rnfRep (STMaybe t) (Just x) = rnfRep t x -rnfRep (STArr (_ :: SNat n) (t :: STy t2)) arr = - withDict @(KnownTy t2) t $ rnf (coerce @(Array n (Rep t2)) @(Array n (Value t2)) arr) -rnfRep (STScal t) x = case t of - STI32 -> rnf x - STI64 -> rnf x - STF32 -> rnf x - STF64 -> rnf x - STBool -> rnf x -rnfRep STAccum{} _ = error "Cannot rnf accumulators" - -instance KnownTy t => NFData (Value t) where - rnf (Value x) = rnfRep (knownTy @t) x diff --git a/src/Language.hs b/src/Language.hs deleted file mode 100644 index 4886ddc..0000000 --- a/src/Language.hs +++ /dev/null @@ -1,267 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE ExplicitForAll #-} -{-# LANGUAGE OverloadedLabels #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE TypeApplications #-} -module Language ( - fromNamed, - NExpr, - Ex, - module Language, - module AST.Types, - module Data, - Lookup, -) where - -import GHC.TypeLits (withSomeSSymbol, symbolVal, SSymbol, pattern SSymbol) - -import Array -import AST -import AST.Sparse.Types -import AST.Types -import CHAD.Types -import Data -import Language.AST - - -data a :-> b = a :-> b - deriving (Show) -infixr 0 :-> - - -body :: NExpr env t -> NFun env env t -body = NBody - -lambda :: forall a name env env' t. Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t -lambda = NLam - -inline :: NFun '[] params t -> SList (NExpr env) (UnName params) -> NExpr env t -inline = inlineNFun - --- To be used to construct the argument list for 'inline'. --- --- > let fun = lambda @(TScal TF64) #x $ lambda @(TScal TF64) #y $ body $ #x + #y --- > in inline fun (SNil .$ 16 .$ 26) -(.$) :: SList f list -> f a -> SList f (a : list) -(.$) = flip SCons - - -let_ :: forall a t env name. Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t -let_ = NELet - -pair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b) -pair = NEPair - -fst_ :: NExpr env (TPair a b) -> NExpr env a -fst_ = NEFst - -snd_ :: NExpr env (TPair a b) -> NExpr env b -snd_ = NESnd - -nil :: NExpr env TNil -nil = NENil - -inl :: KnownTy b => NExpr env a -> NExpr env (TEither a b) -inl = NEInl knownTy - -inr :: KnownTy a => NExpr env b -> NExpr env (TEither a b) -inr = NEInr knownTy - -case_ :: NExpr env (TEither a b) -> (Var name1 a :-> NExpr ('(name1, a) : env) c) -> (Var name2 b :-> NExpr ('(name2, b) : env) c) -> NExpr env c -case_ e (v1 :-> e1) (v2 :-> e2) = NECase e v1 e1 v2 e2 - -nothing :: KnownTy a => NExpr env (TMaybe a) -nothing = NENothing knownTy - -just :: NExpr env a -> NExpr env (TMaybe a) -just = NEJust - -maybe_ :: NExpr env b -> (Var name a :-> NExpr ('(name, a) : env) b) -> NExpr env (TMaybe a) -> NExpr env b -maybe_ a (v :-> b) c = NEMaybe a v b c - -constArr_ :: forall t n env. (KnownNat n, KnownScalTy t) => Array n (ScalRep t) -> NExpr env (TArr n (TScal t)) -constArr_ x = - let ty = knownScalTy - in case scalRepIsShow ty of - Dict -> NEConstArr knownNat ty x - -build1 :: NExpr env TIx -> (Var name TIx :-> NExpr ('(name, TIx) : env) t) -> NExpr env (TArr (S Z) t) -build1 a (v :-> b) = NEBuild (SS SZ) (pair nil a) #idx (let_ v (snd_ #idx) (NEDrop (SS SZ) b)) - -build2 :: NExpr env TIx -> NExpr env TIx - -> (Var name1 TIx :-> Var name2 TIx :-> NExpr ('(name2, TIx) : '(name1, TIx) : env) t) - -> NExpr env (TArr (S (S Z)) t) -build2 a1 a2 (v1 :-> v2 :-> b) = - NEBuild (SS (SS SZ)) - (pair (pair nil a1) a2) - #idx - (let_ v1 (snd_ (fst_ #idx)) $ - let_ v2 (NEDrop SZ (snd_ #idx)) $ - NEDrop (SS (SS SZ)) b) - -build :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> (Var name (Tup (Replicate n TIx)) :-> NExpr ('(name, Tup (Replicate n TIx)) : env) t) -> NExpr env (TArr n t) -build n a (v :-> b) = NEBuild n a v b - -map_ :: forall n a b env name. (KnownNat n, KnownTy a) - => (Var name a :-> NExpr ('(name, a) : env) b) - -> NExpr env (TArr n a) -> NExpr env (TArr n b) -map_ (v :-> a) b = NEMap v a b - -fold1i :: (Var name1 t :-> Var name2 t :-> NExpr ('(name2, t) : '(name1, t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) -fold1i (v1@(Var s1@SSymbol t) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 = - withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) -> - assertSymbolNotUnderscore s3 $ - equalityReflexive s3 $ - assertSymbolDistinct s3 s1 $ - let v3 = Var s3 (STPair t t) - in fold1i' (v3 :-> let_ v1 (fst_ (NEVar v3)) $ - let_ v2 (snd_ (NEVar v3)) $ - NEDrop (SS (SS SZ)) e1) - e2 e3 - -fold1i' :: (Var name (TPair t t) :-> NExpr ('(name, TPair t t) : env) t) -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) -fold1i' (v :-> e1) e2 e3 = NEFold1Inner v e1 e2 e3 - -sum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) -sum1i e = NESum1Inner e - -unit :: NExpr env t -> NExpr env (TArr Z t) -unit = NEUnit - -replicate1i :: ScalIsNumeric t ~ True => NExpr env TIx -> NExpr env (TArr n (TScal t)) -> NExpr env (TArr (S n) (TScal t)) -replicate1i n a = NEReplicate1Inner n a - -maximum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) -maximum1i e = NEMaximum1Inner e - -minimum1i :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) -minimum1i e = NEMinimum1Inner e - -reshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t) -reshape = NEReshape - -fold1iD1 :: (Var name1 t1 :-> Var name2 t1 :-> NExpr ('(name2, t1) : '(name1, t1) : env) (TPair t1 b)) - -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) -fold1iD1 (v1@(Var s1@SSymbol t1) :-> v2@(Var s2@SSymbol _) :-> e1) e2 e3 = - withSomeSSymbol (symbolVal s1 ++ "." ++ symbolVal s2) $ \(s3 :: SSymbol name3) -> - assertSymbolNotUnderscore s3 $ - equalityReflexive s3 $ - assertSymbolDistinct s3 s1 $ - let v3 = Var s3 (STPair t1 t1) - in fold1iD1' (v3 :-> let_ v1 (fst_ (NEVar v3)) $ - let_ v2 (snd_ (NEVar v3)) $ - NEDrop (SS (SS SZ)) e1) - e2 e3 - -fold1iD1' :: (Var name (TPair t1 t1) :-> NExpr ('(name, TPair t1 t1) : env) (TPair t1 b)) - -> NExpr env t1 -> NExpr env (TArr (S n) t1) -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) -fold1iD1' (v1 :-> e1) e2 e3 = NEFold1InnerD1 v1 e1 e2 e3 - -fold1iD2 :: (Var name1 b :-> Var name2 t2 :-> NExpr ('(name2, t2) : '(name1, b) : env) (TPair t2 t2)) - -> NExpr env (TArr (S n) b) -> NExpr env (TArr n t2) -> NExpr env (TPair (TArr n t2) (TArr (S n) t2)) -fold1iD2 (v1 :-> v2 :-> e1) e2 e3 = NEFold1InnerD2 v1 v2 e1 e2 e3 - -const_ :: KnownScalTy t => ScalRep t -> NExpr env (TScal t) -const_ x = - let ty = knownScalTy - in case scalRepIsShow ty of - Dict -> NEConst ty x - -idx0 :: NExpr env (TArr Z t) -> NExpr env t -idx0 = NEIdx0 - --- (.!) :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t) --- (.!) = NEIdx1 --- infixl 9 .! - -(!) :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t -(!) = NEIdx -infixl 9 ! - -shape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -shape = NEShape - -length_ :: NExpr env (TArr N1 t) -> NExpr env TIx -length_ e = snd_ (shape e) - -oper :: SOp a t -> NExpr env a -> NExpr env t -oper = NEOp - -oper2 :: SOp (TPair a b) t -> NExpr env a -> NExpr env b -> NExpr env t -oper2 op a b = NEOp op (pair a b) - -error_ :: KnownTy t => String -> NExpr env t -error_ s = NEError knownTy s - -custom :: (Var n1 a :-> Var n2 b :-> NExpr ['(n2, b), '(n1, a)] t) - -> (Var nf1 (D1 a) :-> Var nf2 (D1 b) :-> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape)) - -> (Var nr1 tape :-> Var nr2 (D2 t) :-> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b)) - -> NExpr env a -> NExpr env b - -> NExpr env t -custom (n1 :-> n2 :-> a) (nf1 :-> nf2 :-> b) (nr1 :-> nr2 :-> c) e1 e2 = - NECustom n1 n2 a nf1 nf2 b nr1 nr2 c e1 e2 - -recompute :: NExpr env a -> NExpr env a -recompute = NERecompute - -with :: forall t a env acname. KnownMTy t => NExpr env t -> (Var acname (TAccum t) :-> NExpr ('(acname, TAccum t) : env) a) -> NExpr env (TPair a t) -with a (n :-> b) = NEWith (knownMTy @t) a n b - -accum :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> NExpr env a -> NExpr env (TAccum t) -> NExpr env TNil -accum p a b c = NEAccum knownMTy p a (spDense (acPrjTy p knownMTy)) b c - -accumS :: KnownMTy t => SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil -accumS p a sp b c = NEAccum knownMTy p a sp b c - - -(.==) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) -a .== b = oper (OEq knownScalTy) (pair a b) -infix 4 .== - -(.<) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) -a .< b = oper (OLt knownScalTy) (pair a b) -infix 4 .< - -(.>) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) -(.>) = flip (.<) -infix 4 .> - -(.<=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) -a .<= b = oper (OLe knownScalTy) (pair a b) -infix 4 .<= - -(.>=) :: (KnownScalTy st, ScalIsNumeric st ~ True) => NExpr env (TScal st) -> NExpr env (TScal st) -> NExpr env (TScal TBool) -(.>=) = flip (.<=) -infix 4 .>= - -not_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -not_ = oper ONot - -and_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool) -and_ = oper2 OAnd -infixr 3 `and_` - -or_ :: NExpr env (TScal TBool) -> NExpr env (TScal TBool) -> NExpr env (TScal TBool) -or_ = oper2 OOr -infixr 2 `or_` - -mod_ :: (ScalIsIntegral a ~ True, KnownScalTy a) => NExpr env (TScal a) -> NExpr env (TScal a) -> NExpr env (TScal a) -mod_ = oper2 (OMod knownScalTy) -infixl 7 `mod_` - --- | The first alternative is the True case; the second is the False case. -if_ :: NExpr env (TScal TBool) -> NExpr env t -> NExpr env t -> NExpr env t -if_ e a b = case_ (oper OIf e) (#_ :-> NEDrop SZ a) (#_ :-> NEDrop SZ b) - -round_ :: NExpr env (TScal TF64) -> NExpr env (TScal TI64) -round_ = oper ORound64 - -toFloat_ :: NExpr env (TScal TI64) -> NExpr env (TScal TF64) -toFloat_ = oper OToFl64 - -idiv :: (KnownScalTy t, ScalIsIntegral t ~ True) => NExpr env (TScal t) -> NExpr env (TScal t) -> NExpr env (TScal t) -idiv = oper2 (OIDiv knownScalTy) -infixl 7 `idiv` diff --git a/src/Language/AST.hs b/src/Language/AST.hs deleted file mode 100644 index 3d6ede5..0000000 --- a/src/Language/AST.hs +++ /dev/null @@ -1,300 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE StandaloneKindSignatures #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} -{-# LANGUAGE UndecidableInstances #-} -module Language.AST where - -import Data.Kind (Type) -import Data.Type.Equality -import GHC.OverloadedLabels -import GHC.TypeLits (Symbol, SSymbol, pattern SSymbol, symbolSing, KnownSymbol, TypeError, ErrorMessage(..), symbolVal) - -import Array -import AST -import AST.Sparse.Types -import CHAD.Types -import Data - - -type NExpr :: [(Symbol, Ty)] -> Ty -> Type -data NExpr env t where - -- lambda calculus - NEVar :: Lookup name env ~ t => Var name t -> NExpr env t - NELet :: Var name a -> NExpr env a -> NExpr ('(name, a) : env) t -> NExpr env t - - -- environment management - NEDrop :: SNat i -> NExpr (DropNth i env) t -> NExpr env t - - -- base types - NEPair :: NExpr env a -> NExpr env b -> NExpr env (TPair a b) - NEFst :: NExpr env (TPair a b) -> NExpr env a - NESnd :: NExpr env (TPair a b) -> NExpr env b - NENil :: NExpr env TNil - NEInl :: STy b -> NExpr env a -> NExpr env (TEither a b) - NEInr :: STy a -> NExpr env b -> NExpr env (TEither a b) - NECase :: NExpr env (TEither a b) -> Var name1 a -> NExpr ('(name1, a) : env) c -> Var name2 b -> NExpr ('(name2, b) : env) c -> NExpr env c - NENothing :: STy t -> NExpr env (TMaybe t) - NEJust :: NExpr env t -> NExpr env (TMaybe t) - NEMaybe :: NExpr env b -> Var name t -> NExpr ('(name, t) : env) b -> NExpr env (TMaybe t) -> NExpr env b - - -- array operations - NEConstArr :: Show (ScalRep t) => SNat n -> SScalTy t -> Array n (ScalRep t) -> NExpr env (TArr n (TScal t)) - NEBuild :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> Var name (Tup (Replicate n TIx)) -> NExpr ('(name, Tup (Replicate n TIx)) : env) t -> NExpr env (TArr n t) - NEMap :: Var name a -> NExpr ('(name, a) : env) t -> NExpr env (TArr n a) -> NExpr env (TArr n t) - NEFold1Inner :: Var name1 (TPair t t) -> NExpr ('(name1, TPair t t) : env) t -> NExpr env t -> NExpr env (TArr (S n) t) -> NExpr env (TArr n t) - NESum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) - NEUnit :: NExpr env t -> NExpr env (TArr Z t) - NEReplicate1Inner :: NExpr env TIx -> NExpr env (TArr n t) -> NExpr env (TArr (S n) t) - NEMaximum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) - NEMinimum1Inner :: ScalIsNumeric t ~ True => NExpr env (TArr (S n) (TScal t)) -> NExpr env (TArr n (TScal t)) - NEReshape :: SNat n -> NExpr env (Tup (Replicate n TIx)) -> NExpr env (TArr m t) -> NExpr env (TArr n t) - - NEFold1InnerD1 :: Var n1 (TPair t1 t1) -> NExpr ('(n1, TPair t1 t1) : env) (TPair t1 b) - -> NExpr env t1 - -> NExpr env (TArr (S n) t1) - -> NExpr env (TPair (TArr n t1) (TArr (S n) b)) - NEFold1InnerD2 :: Var n1 b -> Var n2 t2 -> NExpr ('(n2, t2) : '(n1, b) : env) (TPair t2 t2) - -> NExpr env (TArr (S n) b) - -> NExpr env (TArr n t2) - -> NExpr env (TPair (TArr n t2) (TArr (S n) t2)) - - -- expression operations - NEConst :: Show (ScalRep t) => SScalTy t -> ScalRep t -> NExpr env (TScal t) - NEIdx0 :: NExpr env (TArr Z t) -> NExpr env t - NEIdx1 :: NExpr env (TArr (S n) t) -> NExpr env TIx -> NExpr env (TArr n t) - NEIdx :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) -> NExpr env t - NEShape :: NExpr env (TArr n t) -> NExpr env (Tup (Replicate n TIx)) - NEOp :: SOp a t -> NExpr env a -> NExpr env t - - -- custom derivatives - NECustom :: Var n1 a -> Var n2 b -> NExpr ['(n2, b), '(n1, a)] t -- ^ regular operation - -> Var nf1 (D1 a) -> Var nf2 (D1 b) -> NExpr ['(nf2, D1 b), '(nf1, D1 a)] (TPair (D1 t) tape) -- ^ CHAD forward pass - -> Var nr1 tape -> Var nr2 (D2 t) -> NExpr ['(nr2, D2 t), '(nr1, tape)] (D2 b) -- ^ CHAD reverse derivative - -> NExpr env a -> NExpr env b - -> NExpr env t - - -- fake halfway checkpointing - NERecompute :: NExpr env t -> NExpr env t - - -- accumulation effect on monoids - NEWith :: SMTy t -> NExpr env t -> Var name (TAccum t) -> NExpr ('(name, TAccum t) : env) a -> NExpr env (TPair a t) - NEAccum :: SMTy t -> SAcPrj p t a -> NExpr env (AcIdxD p t) -> Sparse a b -> NExpr env b -> NExpr env (TAccum t) -> NExpr env TNil - - -- partiality - NEError :: STy a -> String -> NExpr env a - - -- embedded unnamed expressions - NEUnnamed :: Ex unenv t -> SList (NExpr env) unenv -> NExpr env t -deriving instance Show (NExpr env t) - -type Lookup name env = Lookup1 (name == "_") name env -type family Lookup1 eqblank name env where - Lookup1 True _ _ = TypeError (Text "Attempt to use variable with name '_'") - Lookup1 False name env = Lookup2 name env -type family Lookup2 name env where - Lookup2 name '[] = TypeError (Text "Variable '" :<>: Text name :<>: Text "' not in scope") - Lookup2 name ('(name2, t) : env) = Lookup3 (name == name2) t name env -type family Lookup3 eq t name env where - Lookup3 True t _ _ = t - Lookup3 False _ name env = Lookup2 name env - -type family DropNth i env where - DropNth Z (_ : env) = env - DropNth (S i) (p : env) = p : DropNth i env - -data Var name t = Var (SSymbol name) (STy t) - deriving (Show) - -instance (t ~ TScal st, ScalIsNumeric st ~ True, KnownScalTy st, Num (ScalRep st)) => Num (NExpr env t) where - a + b = NEOp (OAdd knownScalTy) (NEPair a b) - a * b = NEOp (OMul knownScalTy) (NEPair a b) - negate e = NEOp (ONeg knownScalTy) e - abs = error "abs undefined for NExpr" - signum = error "signum undefined for NExpr" - fromInteger = - let ty = knownScalTy - in case scalRepIsShow ty of - Dict -> NEConst ty . fromInteger - -instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Fractional (ScalRep st)) - => Fractional (NExpr env t) where - recip e = NEOp (ORecip knownScalTy) e - fromRational = - let ty = knownScalTy - in case scalRepIsShow ty of - Dict -> NEConst ty . fromRational - -instance (t ~ TScal st, ScalIsNumeric st ~ True, ScalIsFloating st ~ True, KnownScalTy st, Floating (ScalRep st)) - => Floating (NExpr env t) where - pi = - let ty = knownScalTy - in case scalRepIsShow ty of - Dict -> NEConst ty pi - exp = NEOp (OExp knownScalTy) - log = NEOp (OExp knownScalTy) - sin = undefined ; cos = undefined ; tan = undefined - asin = undefined ; acos = undefined ; atan = undefined - sinh = undefined ; cosh = undefined - asinh = undefined ; acosh = undefined ; atanh = undefined - -instance (KnownTy t, KnownSymbol name, name ~ n') => IsLabel name (Var n' t) where - fromLabel = Var symbolSing knownTy - -instance (KnownTy t, KnownSymbol name, Lookup name env ~ t) => IsLabel name (NExpr env t) where - fromLabel = NEVar (fromLabel @name) - --- | Innermost variable variable on the outside, on the right. -data NEnv env where - NTop :: NEnv '[] - NPush :: NEnv env -> Var name t -> NEnv ('(name, t) : env) - --- | First (outermost) parameter on the outside, on the left. --- * env: environment of this function (grows as you go deeper inside lambdas) --- * env': environment of the body of the function --- * params: parameters of the function (difference between env and env'), first (outermost) argument at the head of the list -data NFun env env' t where - NLam :: Var name a -> NFun ('(name, a) : env) env' t -> NFun env env' t - NBody :: NExpr env' t -> NFun env' env' t - -type family UnName env where - UnName '[] = '[] - UnName ('(name, t) : env) = t : UnName env - -envFromNEnv :: NEnv env -> SList STy (UnName env) -envFromNEnv NTop = SNil -envFromNEnv (NPush env (Var _ t)) = t `SCons` envFromNEnv env - -inlineNFun :: NFun '[] envB t -> SList (NExpr env) (UnName envB) -> NExpr env t -inlineNFun fun args = NEUnnamed (fromNamed fun) args - -fromNamed :: NFun '[] env t -> Ex (UnName env) t -fromNamed = fromNamedFun NTop - --- | Some of the parameters have already been put in the environment; some --- haven't. Transfer all parameters to the left into the environment. --- --- [] `fromNamedFun` λx y z. E --- = []:x `fromNamedFun` λy z. E --- = []:x:y `fromNamedFun` λz. E --- = []:x:y:z `fromNamedFun` λ. E --- = []:x:y:z `fromNamedExpr` E -fromNamedFun :: NEnv env -> NFun env env' t -> Ex (UnName env') t -fromNamedFun env (NLam var fun) = fromNamedFun (env `NPush` var) fun -fromNamedFun env (NBody e) = fromNamedExpr env e - -fromNamedExpr :: forall env t. NEnv env -> NExpr env t -> Ex (UnName env) t -fromNamedExpr val = \case - NEVar var@(Var _ ty) - | Just idx <- find var val -> EVar ext ty idx - | otherwise -> error "Variable out of scope in conversion from surface \ - \expression to De Bruijn expression" - NELet n a b -> ELet ext (go a) (lambda val n b) - - NEDrop i e -> weakenExpr (dropNthW i val) (fromNamedExpr (dropNth i val) e) - - NEPair a b -> EPair ext (go a) (go b) - NEFst e -> EFst ext (go e) - NESnd e -> ESnd ext (go e) - NENil -> ENil ext - NEInl t e -> EInl ext t (go e) - NEInr t e -> EInr ext t (go e) - NECase e n1 a n2 b -> ECase ext (go e) (lambda val n1 a) (lambda val n2 b) - NENothing t -> ENothing ext t - NEJust e -> EJust ext (go e) - NEMaybe a n b c -> EMaybe ext (go a) (lambda val n b) (go c) - - NEConstArr n t x -> EConstArr ext n t x - NEBuild k a n b -> EBuild ext k (go a) (lambda val n b) - NEMap n a b -> EMap ext (lambda val n a) (go b) - NEFold1Inner n1 a b c -> EFold1Inner ext Noncommut (lambda val n1 a) (go b) (go c) - NESum1Inner e -> ESum1Inner ext (go e) - NEUnit e -> EUnit ext (go e) - NEReplicate1Inner a b -> EReplicate1Inner ext (go a) (go b) - NEMaximum1Inner e -> EMaximum1Inner ext (go e) - NEMinimum1Inner e -> EMinimum1Inner ext (go e) - NEReshape n a b -> EReshape ext n (go a) (go b) - - NEFold1InnerD1 n1 a b c -> EFold1InnerD1 ext Noncommut (lambda val n1 a) (go b) (go c) - NEFold1InnerD2 n1 n2 a b c -> EFold1InnerD2 ext Noncommut (lambda2 val n1 n2 a) (go b) (go c) - - NEConst t x -> EConst ext t x - NEIdx0 e -> EIdx0 ext (go e) - NEIdx1 a b -> EIdx1 ext (go a) (go b) - NEIdx a b -> EIdx ext (go a) (go b) - NEShape e -> EShape ext (go e) - NEOp op e -> EOp ext op (go e) - - NECustom n1@(Var _ ta) n2@(Var _ tb) a nf1 nf2 b nr1@(Var _ ttape) nr2 c e1 e2 -> - ECustom ext ta tb ttape - (fromNamedExpr (NTop `NPush` n1 `NPush` n2) a) - (fromNamedExpr (NTop `NPush` nf1 `NPush` nf2) b) - (fromNamedExpr (NTop `NPush` nr1 `NPush` nr2) c) - (go e1) (go e2) - NERecompute e -> ERecompute ext (go e) - - NEWith t a n b -> EWith ext t (go a) (lambda val n b) - NEAccum t p a sp b c -> EAccum ext t p (go a) sp (go b) (go c) - - NEError t s -> EError ext t s - - NEUnnamed e args -> injectWrapLet (weakenExpr (wRaiseAbove args (envFromNEnv val)) e) args - where - go :: NExpr env t' -> Ex (UnName env) t' - go = fromNamedExpr val - - find :: Var name t' -> NEnv env' -> Maybe (Idx (UnName env') t') - find _ NTop = Nothing - find var@(Var s ty) (val' `NPush` Var s' ty') - | Just Refl <- testEquality s s' - , Just Refl <- testEquality ty ty' - = Just IZ - | otherwise - = IS <$> find var val' - - lambda :: NEnv env' -> Var name a -> NExpr ('(name, a) : env') b -> Ex (a : UnName env') b - lambda val' var e = fromNamedExpr (val' `NPush` var) e - - lambda2 :: NEnv env' -> Var name1 a -> Var name2 b -> NExpr ('(name2, b) : '(name1, a) : env') c -> Ex (b : a : UnName env') c - lambda2 val' var1 var2 e = fromNamedExpr (val' `NPush` var1 `NPush` var2) e - - injectWrapLet :: Ex (Append unenv (UnName env)) t -> SList (NExpr env) unenv -> Ex (UnName env) t - injectWrapLet e SNil = e - injectWrapLet e (arg `SCons` args) = - injectWrapLet (ELet ext (weakenExpr (wSinks args) $ fromNamedExpr val arg) e) - args - -dropNth :: SNat i -> NEnv env -> NEnv (DropNth i env) -dropNth SZ (val `NPush` _) = val -dropNth (SS i) (val `NPush` p) = dropNth i val `NPush` p -dropNth _ NTop = error "DropNth: index out of range" - -dropNthW :: SNat i -> NEnv env -> UnName (DropNth i env) :> UnName env -dropNthW SZ (_ `NPush` _) = WSink -dropNthW (SS i) (val `NPush` _) = WCopy (dropNthW i val) -dropNthW _ NTop = error "DropNth: index out of range" - -assertSymbolNotUnderscore :: forall s r. SSymbol s -> ((s == "_") ~ False => r) -> r -assertSymbolNotUnderscore s@SSymbol k = - case symbolVal s of - "_" -> error "assertSymbolNotUnderscore: was underscore" - _ | Refl <- unsafeCoerceRefl @(s == "_") @False -> k - -assertSymbolDistinct :: forall s1 s2 r. SSymbol s1 -> SSymbol s2 -> ((s1 == s2) ~ False => r) -> r -assertSymbolDistinct s1@SSymbol s2@SSymbol k - | symbolVal s1 == symbolVal s2 = error $ "assertSymbolDistinct: was equal (" ++ symbolVal s1 ++ ")" - | Refl <- unsafeCoerceRefl @(s1 == s2) @False = k - -equalityReflexive :: forall (s :: Symbol) proxy r. proxy s -> ((s == s) ~ True => r) -> r -equalityReflexive _ k | Refl <- unsafeCoerceRefl @(s == s) @True = k diff --git a/src/Lemmas.hs b/src/Lemmas.hs deleted file mode 100644 index 31a43ed..0000000 --- a/src/Lemmas.hs +++ /dev/null @@ -1,21 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE TypeOperators #-} - -{-# LANGUAGE AllowAmbiguousTypes #-} -module Lemmas (module Lemmas, (:~:)(Refl)) where - -import Data.Type.Equality -import Unsafe.Coerce (unsafeCoerce) - - -type family Append a b where - Append '[] l = l - Append (x : xs) l = x : Append xs l - -lemAppendNil :: Append a '[] :~: a -lemAppendNil = unsafeCoerce Refl - -lemAppendAssoc :: Append a (Append b c) :~: Append (Append a b) c -lemAppendAssoc = unsafeCoerce Refl diff --git a/src/Simplify.hs b/src/Simplify.hs deleted file mode 100644 index 19d0c17..0000000 --- a/src/Simplify.hs +++ /dev/null @@ -1,619 +0,0 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE ImplicitParams #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE MultiWayIf #-} -{-# LANGUAGE QuasiQuotes #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeApplications #-} -{-# LANGUAGE TypeOperators #-} -module Simplify ( - simplifyN, simplifyFix, - SimplifyConfig(..), defaultSimplifyConfig, simplifyWith, simplifyFixWith, -) where - -import Control.Monad (ap) -import Data.Bifunctor (first) -import Data.Function (fix) -import Data.Monoid (Any(..)) - -import Debug.Trace - -import AST -import AST.Count -import AST.Pretty -import AST.Sparse.Types -import AST.UnMonoid (acPrjCompose) -import Data -import Simplify.TH - - -data SimplifyConfig = SimplifyConfig - { scLogging :: Bool - } - -defaultSimplifyConfig :: SimplifyConfig -defaultSimplifyConfig = SimplifyConfig False - -simplifyN :: KnownEnv env => Int -> Ex env t -> Ex env t -simplifyN 0 = id -simplifyN n = simplifyN (n - 1) . simplify - -simplify :: forall env t. KnownEnv env => Ex env t -> Ex env t -simplify = - let ?accumInScope = checkAccumInScope @env knownEnv - ?config = defaultSimplifyConfig - in snd . runSM . simplify' - -simplifyWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t -simplifyWith config = - let ?accumInScope = checkAccumInScope @env knownEnv - ?config = config - in snd . runSM . simplify' - -simplifyFix :: forall env t. KnownEnv env => Ex env t -> Ex env t -simplifyFix = simplifyFixWith defaultSimplifyConfig - -simplifyFixWith :: forall env t. KnownEnv env => SimplifyConfig -> Ex env t -> Ex env t -simplifyFixWith config = - let ?accumInScope = checkAccumInScope @env knownEnv - ?config = config - in fix $ \loop e -> - let (act, e') = runSM (simplify' e) - in if act then loop e' else e' - --- | simplify monad -newtype SM tenv tt env t a = SM ((Ex env t -> Ex tenv tt) -> (Any, a)) - deriving (Functor) - -instance Applicative (SM tenv tt env t) where - pure x = SM (\_ -> (Any False, x)) - (<*>) = ap - -instance Monad (SM tenv tt env t) where - SM f >>= g = SM $ \ctx -> f ctx >>= \x -> let SM h = g x in h ctx - -runSM :: SM env t env t a -> (Bool, a) -runSM (SM f) = first getAny (f id) - -smReconstruct :: Ex env t -> SM tenv tt env t (Ex tenv tt) -smReconstruct core = SM (\ctx -> (Any False, ctx core)) - -class Monad m => ActedMonad m where - tellActed :: m () - hideActed :: m a -> m a - liftActed :: (Any, a) -> m a - -instance ActedMonad ((,) Any) where - tellActed = (Any True, ()) - hideActed (_, x) = (Any False, x) - liftActed = id - -instance ActedMonad (SM tenv tt env t) where - tellActed = SM (\_ -> tellActed) - hideActed (SM f) = SM (\ctx -> hideActed (f ctx)) - liftActed pair = SM (\_ -> pair) - --- more convenient in practice -acted :: ActedMonad m => m a -> m a -acted m = tellActed >> m - -within :: (Ex env' t' -> Ex env t) -> SM tenv tt env' t' a -> SM tenv tt env t a -within subctx (SM f) = SM $ \ctx -> f (ctx . subctx) - -simplify' :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t) -simplify' expr - | scLogging ?config = do - res <- simplify'Rec expr - full <- smReconstruct res - let printed = ppExpr knownEnv full - replace a bs = concatMap (\x -> if x == a then bs else [x]) - str | '\n' `elem` printed = "--- simplify step:\n " ++ replace '\n' "\n " printed - | otherwise = "--- simplify step: " ++ printed - traceM str - return res - | otherwise = simplify'Rec expr - -simplify'Rec :: (?accumInScope :: Bool, ?config :: SimplifyConfig, KnownEnv tenv) => Ex env t -> SM tenv tt env t (Ex env t) -simplify'Rec = \case - -- inlining - ELet _ rhs body - | cheapExpr rhs - -> acted $ simplify' (substInline rhs body) - - | Occ lexOcc runOcc <- occCount IZ body - , ((not ?accumInScope || not (hasAdds rhs)) && lexOcc <= One && runOcc <= One) -- without effects, normal rules apply - || (lexOcc == One && runOcc == One) -- with effects, linear inlining is still allowed, but weakening is not - -> acted $ simplify' (substInline rhs body) - - -- let splitting / let peeling - ELet _ (EPair _ a b) body -> - acted $ simplify' $ - ELet ext a $ - ELet ext (weakenExpr WSink b) $ - subst (\_ t -> \case IZ -> EPair ext (EVar ext (typeOf a) (IS IZ)) (EVar ext (typeOf b) IZ) - IS i -> EVar ext t (IS (IS i))) - body - ELet _ (EJust _ a) body -> - acted $ simplify' $ ELet ext a $ subst0 (EJust ext (EVar ext (typeOf a) IZ)) body - ELet _ (EInl _ t2 a) body -> - acted $ simplify' $ ELet ext a $ subst0 (EInl ext t2 (EVar ext (typeOf a) IZ)) body - ELet _ (EInr _ t1 a) body -> - acted $ simplify' $ ELet ext a $ subst0 (EInr ext t1 (EVar ext (typeOf a) IZ)) body - - -- let rotation - ELet _ (ELet _ rhs a) b -> do - b' <- within (ELet ext (ELet ext rhs a)) $ simplify' b - acted $ simplify' $ - ELet ext rhs $ - ELet ext a $ - weakenExpr (WCopy WSink) b' - - -- beta rules for products - EFst _ (EPair _ e e') - | not (hasAdds e') -> acted $ simplify' e - | otherwise -> acted $ simplify' $ ELet ext e' (weakenExpr WSink e) - ESnd _ (EPair _ e' e) - | not (hasAdds e') -> acted $ simplify' e - | otherwise -> acted $ simplify' $ ELet ext e' (weakenExpr WSink e) - - -- beta rules for coproducts - ECase _ (EInl _ _ e) rhs _ -> acted $ simplify' (ELet ext e rhs) - ECase _ (EInr _ _ e) _ rhs -> acted $ simplify' (ELet ext e rhs) - - -- beta rules for maybe - EMaybe _ e1 _ ENothing{} -> acted $ simplify' e1 - EMaybe _ _ e1 (EJust _ e2) -> acted $ simplify' $ ELet ext e2 e1 - - -- let floating - EFst _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EFst ext body)) - ESnd _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (ESnd ext body)) - ECase _ (ELet _ rhs body) e1 e2 -> acted $ simplify' (ELet ext rhs (ECase ext body (weakenExpr (WCopy WSink) e1) (weakenExpr (WCopy WSink) e2))) - EIdx0 _ (ELet _ rhs body) -> acted $ simplify' (ELet ext rhs (EIdx0 ext body)) - EIdx1 _ (ELet _ rhs body) e -> acted $ simplify' (ELet ext rhs (EIdx1 ext body (weakenExpr WSink e))) - EAccum _ t p e1 sp (ELet _ rhs body) acc -> - acted $ simplify' $ - ELet ext rhs $ - EAccum ext t p (weakenExpr WSink e1) sp body (weakenExpr WSink acc) - - -- let () = e in () ~> e - ELet _ e1 (ENil _) | STNil <- typeOf e1 -> - acted $ simplify' e1 - - -- map (\_ -> x) e ~> build (shape e) (\_ -> x) - EMap _ e1 e2 - | Occ Zero Zero <- occCount IZ e1 - , STArr n _ <- typeOf e2 -> - acted $ simplify' $ - EBuild ext n (EShape ext e2) $ - subst (\_ t' -> \case IZ -> error "Unused variable was used" - IS i -> EVar ext t' (IS i)) - e1 - - -- vertical fusion - EMap _ e1 (EMap _ e2 e3) -> - acted $ simplify' $ - EMap ext (ELet ext e2 (weakenExpr (WCopy WSink) e1)) e3 - - -- projection down-commuting - EFst _ (ECase _ e1 e2 e3) -> - acted $ simplify' $ - ECase ext e1 (EFst ext e2) (EFst ext e3) - ESnd _ (ECase _ e1 e2 e3) -> - acted $ simplify' $ - ECase ext e1 (ESnd ext e2) (ESnd ext e3) - EFst _ (EMaybe _ e1 e2 e3) -> - acted $ simplify' $ - EMaybe ext (EFst ext e1) (EFst ext e2) e3 - ESnd _ (EMaybe _ e1 e2 e3) -> - acted $ simplify' $ - EMaybe ext (ESnd ext e1) (ESnd ext e2) e3 - - -- TODO: more array indexing - EIdx _ (EBuild _ _ e1 e2) e3 | not (hasAdds e1), not (hasAdds e2) -> acted $ simplify' $ elet e3 e2 - EIdx _ (EMap _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ elet (EIdx ext e2 e3) e1 - EIdx _ (EReplicate1Inner _ e1 e2) e3 | not (hasAdds e1) -> acted $ simplify' $ EIdx ext e2 (EFst ext e3) - EIdx _ (EUnit _ e1) e2 | not (hasAdds e2) -> acted $ simplify' $ e1 - - -- TODO: more array shape - EShape _ (EBuild _ _ e1 e2) | not (hasAdds e2) -> acted $ simplify' e1 - EShape _ (EMap _ e1 e2) | not (hasAdds e1) -> acted $ simplify' (EShape ext e2) - - -- TODO: more constant folding - EOp _ OIf (EConst _ STBool True) -> acted $ return (EInl ext STNil (ENil ext)) - EOp _ OIf (EConst _ STBool False) -> acted $ return (EInr ext STNil (ENil ext)) - - -- inline cheap array constructors - ELet _ (EReplicate1Inner _ e1 e2) e3 -> - acted $ simplify' $ - ELet ext (EPair ext e1 e2) $ - let v = EVar ext (STPair tIx (typeOf e2)) IZ - in subst0 (EReplicate1Inner ext (EFst ext v) (ESnd ext v)) e3 - -- -- TODO: This is a bad idea and anyway only helps in practice if (!) is - -- -- cheap, which it can't be because (!) is not cheap if you do AD after. - -- -- Should do proper SoA representation. - -- ELet _ (EBuild _ n e1 e2) e3 | cheapExpr e2 -> - -- acted $ simplify' $ - -- ELet ext e1 $ - -- subst0 (EBuild ext n (EVar ext (tTup (sreplicate n tIx)) IZ) (weakenExpr (WCopy WSink) e2)) e3 - - -- eta rule for unit - e | STNil <- typeOf e, not ?accumInScope || not (hasAdds e) -> - case e of - ENil _ -> return e - _ -> acted $ return (ENil ext) - - EBuild _ SZ _ e -> - acted $ simplify' $ EUnit ext (substInline (ENil ext) e) - - -- monoid rules - EAccum _ t p e1 sp e2 acc -> do - e1' <- within (\e1' -> EAccum ext t p e1' sp e2 acc ) $ simplify' e1 - e2' <- within (\e2' -> EAccum ext t p e1' sp e2' acc ) $ simplify' e2 - acc' <- within (\acc' -> EAccum ext t p e1' sp e2' acc') $ simplify' acc - simplifyOHT (OneHotTerm SAID t p e1' sp e2') - (acted $ return (ENil ext)) - (\sp' (InContext w wrap e) -> do - e' <- within (\e' -> wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc')) $ simplify' e - return (wrap $ EAccum ext t SAPHere (ENil ext) sp' e' (weakenExpr w acc'))) - (\(InContext w wrap (OneHotTerm _ t' p' e1'' sp' e2'')) -> do - -- The acted management here is a hideous mess. - e1''' <- hideActed $ within (\e1''' -> wrap $ EAccum ext t' p' e1''' sp' e2'' (weakenExpr w acc')) $ simplify' e1'' - e2''' <- hideActed $ within (\e2''' -> wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc')) $ simplify' e2'' - return (wrap $ EAccum ext t' p' e1''' sp' e2''' (weakenExpr w acc'))) - EPlus _ _ (EZero _ _ _) e -> acted $ simplify' e - EPlus _ _ e (EZero _ _ _) -> acted $ simplify' e - EOneHot _ t p e1 e2 -> do - e1' <- within (\e1' -> EOneHot ext t p e1' e2 ) $ simplify' e1 - e2' <- within (\e2' -> EOneHot ext t p e1' e2') $ simplify' e2 - simplifyOHT (OneHotTerm SAIS t p e1' (spDense (acPrjTy p t)) e2') - (acted $ return (EZero ext t (zeroInfoFromOneHot t p e1 e2))) - (\sp' (InContext _ wrap e) -> - case isDense t sp' of - Just Refl -> do - e' <- hideActed $ within wrap $ simplify' e - return (wrap e') - Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse") - (\(InContext _ wrap (OneHotTerm _ t' p' e1'' sp' e2'')) -> - case isDense (acPrjTy p' t') sp' of - Just Refl -> do - e1''' <- hideActed $ within (\e1''' -> wrap $ EOneHot ext t' p' e1''' e2'') $ simplify' e1'' - e2''' <- hideActed $ within (\e2''' -> wrap $ EOneHot ext t' p' e1''' e2''') $ simplify' e2'' - return (wrap $ EOneHot ext t' p' e1''' e2''') - Nothing -> error "simplifyOneHotTerm sparsified a dense Sparse") - - -- type-specific equations for plus - EPlus _ SMTNil e1 e2 | not (hasAdds e1), not (hasAdds e2) -> - acted $ return (ENil ext) - - EPlus _ (SMTPair t1 t2) (EPair _ a1 b1) (EPair _ a2 b2) -> - acted $ simplify' $ EPair ext (EPlus ext t1 a1 a2) (EPlus ext t2 b1 b2) - - EPlus _ (SMTLEither t1 _) (ELInl _ dt2 a1) (ELInl _ _ a2) -> - acted $ simplify' $ ELInl ext dt2 (EPlus ext t1 a1 a2) - EPlus _ (SMTLEither _ t2) (ELInr _ dt1 b1) (ELInr _ _ b2) -> - acted $ simplify' $ ELInr ext dt1 (EPlus ext t2 b1 b2) - EPlus _ SMTLEither{} ELNil{} e -> acted $ simplify' e - EPlus _ SMTLEither{} e ELNil{} -> acted $ simplify' e - - EPlus _ (SMTMaybe t) (EJust _ e1) (EJust _ e2) -> - acted $ simplify' $ EJust ext (EPlus ext t e1 e2) - EPlus _ SMTMaybe{} ENothing{} e -> acted $ simplify' e - EPlus _ SMTMaybe{} e ENothing{} -> acted $ simplify' e - - -- fallback recursion - EVar _ t i -> pure $ EVar ext t i - ELet _ a b -> [simprec| ELet ext *a *b |] - EPair _ a b -> [simprec| EPair ext *a *b |] - EFst _ e -> [simprec| EFst ext *e |] - ESnd _ e -> [simprec| ESnd ext *e |] - ENil _ -> pure $ ENil ext - EInl _ t e -> [simprec| EInl ext t *e |] - EInr _ t e -> [simprec| EInr ext t *e |] - ECase _ e a b -> [simprec| ECase ext *e *a *b |] - ENothing _ t -> pure $ ENothing ext t - EJust _ e -> [simprec| EJust ext *e |] - EMaybe _ a b e -> [simprec| EMaybe ext *a *b *e |] - ELNil _ t1 t2 -> pure $ ELNil ext t1 t2 - ELInl _ t e -> [simprec| ELInl ext t *e |] - ELInr _ t e -> [simprec| ELInr ext t *e |] - ELCase _ e a b c -> [simprec| ELCase ext *e *a *b *c |] - EConstArr _ n t v -> pure $ EConstArr ext n t v - EBuild _ n a b -> [simprec| EBuild ext n *a *b |] - EMap _ a b -> [simprec| EMap ext *a *b |] - EFold1Inner _ cm a b c -> [simprec| EFold1Inner ext cm *a *b *c |] - ESum1Inner _ e -> [simprec| ESum1Inner ext *e |] - EUnit _ e -> [simprec| EUnit ext *e |] - EReplicate1Inner _ a b -> [simprec| EReplicate1Inner ext *a *b |] - EMaximum1Inner _ e -> [simprec| EMaximum1Inner ext *e |] - EMinimum1Inner _ e -> [simprec| EMinimum1Inner ext *e |] - EReshape _ n a b -> [simprec| EReshape ext n *a *b |] - EZip _ a b -> [simprec| EZip ext *a *b |] - EFold1InnerD1 _ cm a b c -> [simprec| EFold1InnerD1 ext cm *a *b *c |] - EFold1InnerD2 _ cm a b c -> [simprec| EFold1InnerD2 ext cm *a *b *c |] - EConst _ t v -> pure $ EConst ext t v - EIdx0 _ e -> [simprec| EIdx0 ext *e |] - EIdx1 _ a b -> [simprec| EIdx1 ext *a *b |] - EIdx _ a b -> [simprec| EIdx ext *a *b |] - EShape _ e -> [simprec| EShape ext *e |] - EOp _ op e -> [simprec| EOp ext op *e |] - ECustom _ s t p a b c e1 e2 -> do - a' <- within (\a' -> ECustom ext s t p a' b c e1 e2) (let ?accumInScope = False in simplify' a) - b' <- within (\b' -> ECustom ext s t p a' b' c e1 e2) (let ?accumInScope = False in simplify' b) - c' <- within (\c' -> ECustom ext s t p a' b' c' e1 e2) (let ?accumInScope = False in simplify' c) - e1' <- within (\e1' -> ECustom ext s t p a' b' c' e1' e2) (simplify' e1) - e2' <- within (\e2' -> ECustom ext s t p a' b' c' e1' e2') (simplify' e2) - pure (ECustom ext s t p a' b' c' e1' e2') - ERecompute _ e -> [simprec| ERecompute ext *e |] - EWith _ t e1 e2 -> do - e1' <- within (\e1' -> EWith ext t e1' e2) (simplify' e1) - e2' <- within (\e2' -> EWith ext t e1' e2') (let ?accumInScope = True in simplify' e2) - pure (EWith ext t e1' e2') - -- EOneHot _ t p e1 e2 -> [simprec| EOneHot ext t p *e1 *e2 |] - -- EAccum _ t p e1 sp e2 acc -> [simprec| EAccum ext t p *e1 sp *e2 *acc |] - EZero _ t e -> [simprec| EZero ext t *e |] - EDeepZero _ t e -> [simprec| EDeepZero ext t *e |] - EPlus _ t a b -> [simprec| EPlus ext t *a *b |] - EError _ t s -> pure $ EError ext t s - --- | This can be made more precise by tracking (and not counting) adds on --- locally eliminated accumulators. -hasAdds :: Expr x env t -> Bool -hasAdds = \case - EVar _ _ _ -> False - ELet _ rhs body -> hasAdds rhs || hasAdds body - EPair _ a b -> hasAdds a || hasAdds b - EFst _ e -> hasAdds e - ESnd _ e -> hasAdds e - ENil _ -> False - EInl _ _ e -> hasAdds e - EInr _ _ e -> hasAdds e - ECase _ e a b -> hasAdds e || hasAdds a || hasAdds b - ENothing _ _ -> False - EJust _ e -> hasAdds e - EMaybe _ a b e -> hasAdds a || hasAdds b || hasAdds e - ELNil _ _ _ -> False - ELInl _ _ e -> hasAdds e - ELInr _ _ e -> hasAdds e - ELCase _ e a b c -> hasAdds e || hasAdds a || hasAdds b || hasAdds c - EConstArr _ _ _ _ -> False - EBuild _ _ a b -> hasAdds a || hasAdds b - EMap _ a b -> hasAdds a || hasAdds b - EFold1Inner _ _ a b c -> hasAdds a || hasAdds b || hasAdds c - ESum1Inner _ e -> hasAdds e - EUnit _ e -> hasAdds e - EReplicate1Inner _ a b -> hasAdds a || hasAdds b - EMaximum1Inner _ e -> hasAdds e - EMinimum1Inner _ e -> hasAdds e - EReshape _ _ a b -> hasAdds a || hasAdds b - EZip _ a b -> hasAdds a || hasAdds b - EFold1InnerD1 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c - EFold1InnerD2 _ _ a b c -> hasAdds a || hasAdds b || hasAdds c - ECustom _ _ _ _ a b c d e -> hasAdds a || hasAdds b || hasAdds c || hasAdds d || hasAdds e - EConst _ _ _ -> False - EIdx0 _ e -> hasAdds e - EIdx1 _ a b -> hasAdds a || hasAdds b - EIdx _ a b -> hasAdds a || hasAdds b - EShape _ e -> hasAdds e - EOp _ _ e -> hasAdds e - EWith _ _ a b -> hasAdds a || hasAdds b - ERecompute _ e -> hasAdds e - EAccum _ _ _ _ _ _ _ -> True - EZero _ _ e -> hasAdds e - EDeepZero _ _ e -> hasAdds e - EPlus _ _ a b -> hasAdds a || hasAdds b - EOneHot _ _ _ a b -> hasAdds a || hasAdds b - EError _ _ _ -> False - -checkAccumInScope :: SList STy env -> Bool -checkAccumInScope = \case SNil -> False - SCons t env -> check t || checkAccumInScope env - where - check :: STy t -> Bool - check STNil = False - check (STPair s t) = check s || check t - check (STEither s t) = check s || check t - check (STLEither s t) = check s || check t - check (STMaybe t) = check t - check (STArr _ t) = check t - check (STScal _) = False - check STAccum{} = True - -data OneHotTerm dense env a where - OneHotTerm :: SAIDense dense -> SMTy a -> SAcPrj p a b -> Ex env (AcIdx dense p a) -> Sparse b c -> Ex env c -> OneHotTerm dense env a -deriving instance Show (OneHotTerm dense env a) - -data InContext f env (a :: Ty) where - InContext :: env :> env' -> (forall t. Ex env' t -> Ex env t) -> f env' a -> InContext f env a - -simplifyOHT_recogniseMonoid :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a) -simplifyOHT_recogniseMonoid (OneHotTerm dense t prj idx sp val) = do - val' <- liftActed $ recogniseMonoid (applySparse sp (acPrjTy prj t)) val - return $ OneHotTerm dense t prj idx sp val' - -simplifyOHT_unsparse :: ActedMonad m => OneHotTerm dense env a -> m (InContext (OneHotTerm dense) env a) -simplifyOHT_unsparse (OneHotTerm SAID t prj1 idx1 sp1 val1) = - unsparseOneHotD sp1 val1 $ \w wrap prj2 idx2 sp2 val2 -> - acPrjCompose SAID prj1 (weakenExpr w idx1) prj2 idx2 $ \prj' idx' -> - return $ InContext w wrap (OneHotTerm SAID t prj' idx' sp2 val2) -simplifyOHT_unsparse oht@(OneHotTerm SAIS _ _ _ _ _) = return $ InContext WId id oht - -simplifyOHT_concat :: ActedMonad m => OneHotTerm dense env a -> m (OneHotTerm dense env a) -simplifyOHT_concat (OneHotTerm @dense @_ @_ @_ @env dense t1 prj1 idx1 sp (EOneHot @_ @c @p2 _ t2 prj2 idx2 val)) - | Just Refl <- isDense (acPrjTy prj1 t1) sp = - let idx2' :: Ex env (AcIdx dense p2 c) - idx2' = case dense of - SAID -> reduceAcIdx t2 prj2 idx2 - SAIS -> idx2 - in acPrjCompose dense prj1 idx1 prj2 idx2' $ \prj' idx' -> - acted $ return $ OneHotTerm dense t1 prj' idx' (spDense (acPrjTy prj' t1)) val -simplifyOHT_concat oht = return oht - --- -- Property not expressed in types: if the Sparse in the input OneHotTerm is --- -- dense, then the Sparse in the output will also be dense. This property is --- -- used when simplifying EOneHot, which cannot represent sparsity. -simplifyOHT :: ActedMonad m => OneHotTerm dense env a - -> m r -- ^ Zero case (onehot is actually zero) - -> (forall b. Sparse a b -> InContext Ex env b -> m r) -- ^ Trivial case (no zeros in onehot) - -> (InContext (OneHotTerm dense) env a -> m r) -- ^ Simplified - -> m r -simplifyOHT oht kzero ktriv k = do - -- traceM $ "sOHT: input " ++ show oht - oht1 <- simplifyOHT_recogniseMonoid oht - -- traceM $ "sOHT: recog " ++ show oht1 - InContext w1 wrap1 oht2 <- simplifyOHT_unsparse oht1 - -- traceM $ "sOHT: unspa " ++ show oht2 - oht3 <- simplifyOHT_concat oht2 - -- traceM $ "sOHT: conca " ++ show oht3 - -- traceM "" - case oht3 of - OneHotTerm _ _ _ _ _ EZero{} -> kzero - OneHotTerm _ _ SAPHere _ sp val -> ktriv sp (InContext w1 wrap1 val) - _ -> k (InContext w1 wrap1 oht3) - --- Sets the acted flag whenever a non-trivial projection is returned or the --- output Sparse is different from the input Sparse. -unsparseOneHotD :: ActedMonad m => Sparse a a' -> Ex env a' - -> (forall p b c env'. env :> env' -> (forall s. Ex env' s -> Ex env s) - -> SAcPrj p a b -> Ex env' (AcIdxD p a) -> Sparse b c -> Ex env' c -> m r) -> m r -unsparseOneHotD topsp topval k = case (topsp, topval) of - -- eliminate always-Just sparse onehot - (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) -> - acted $ unsparseOneHotD s (EOneHot ext t prj idx val) k - - -- expand the top levels of a onehot for a sparse type into a onehot for the - -- corresponding non-sparse type - (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) -> - unsparseOneHotD s1 (EOneHot ext t1 prj (efst idx) val) $ \w wrap spprj idx' s1' e' -> - acted $ k w wrap (SAPFst spprj) idx' s1' e' - (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) -> - unsparseOneHotD s2 (EOneHot ext t2 prj (esnd idx) val) $ \w wrap spprj idx' s1' e' -> - acted $ k w wrap (SAPSnd spprj) idx' s1' e' - (SpLEither s1 _, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) -> - unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' -> - acted $ k w wrap (SAPLeft spprj) idx' s1' e' - (SpLEither _ s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) -> - unsparseOneHotD s2 (EOneHot ext t2 prj idx val) $ \w wrap spprj idx' s1' e' -> - acted $ k w wrap (SAPRight spprj) idx' s1' e' - (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) -> - unsparseOneHotD s1 (EOneHot ext t1 prj idx val) $ \w wrap spprj idx' s1' e' -> - acted $ k w wrap (SAPJust spprj) idx' s1' e' - (SpArr s1, EOneHot _ (SMTArr _ t1) (SAPArrIdx prj) idx val) - | Dict <- styKnown (typeOf idx) -> - unsparseOneHotD s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \w wrap spprj idx' s1' e' -> - acted $ k (w .> WSink) (elet idx . wrap) (SAPArrIdx spprj) (EPair ext (efst (efst (evar (w @> IZ)))) idx') s1' e' - - -- anything else we don't know how to improve - _ -> k WId id SAPHere (ENil ext) topsp topval - -{- -unsparseOneHotS :: ActedMonad m - => Sparse a a' -> Ex env a' - -> (forall b. Sparse a b -> Ex env b -> m r) -> m r -unsparseOneHotS topsp topval k = case (topsp, topval) of - -- order is relevant to make sure we set the acted flag correctly - (SpAbsent, v@ENil{}) -> k SpAbsent v - (SpAbsent, v@EZero{}) -> k SpAbsent v - (SpAbsent, _) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) - (_, EZero{}) -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) - (sp, _) | isAbsent sp -> acted $ k SpAbsent (EZero ext SMTNil (ENil ext)) - - -- the unsparsifying - (SpSparse s, EOneHot _ (SMTMaybe t) (SAPJust prj) idx val) -> - acted $ unsparseOneHotS s (EOneHot ext t prj idx val) k - - -- recursion - -- TODO: coproducts could safely become projections as they do not need - -- zeroinfo. But that would only work if the coproduct is at the top, because - -- as soon as we hit a product, we need zeroinfo to make it a projection and - -- we don't have that. - (SpSparse s, e) -> k (SpSparse s) e - (SpPair s1 _, EOneHot _ (SMTPair t1 _) (SAPFst prj) idx val) -> - unsparseOneHotS s1 (EOneHot ext t1 prj (efst idx) val) $ \s1' e' -> - acted $ k (SpPair s1' SpAbsent) (EPair ext e' (ENil ext)) - (SpPair _ s2, EOneHot _ (SMTPair _ t2) (SAPSnd prj) idx val) -> - unsparseOneHotS s2 (EOneHot ext t2 prj (esnd idx) val) $ \s2' e' -> - acted $ k (SpPair SpAbsent s2') (EPair ext (ENil ext) e') - (SpLEither s1 s2, EOneHot _ (SMTLEither t1 _) (SAPLeft prj) idx val) -> - unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' -> do - case s2 of SpAbsent -> pure () ; _ -> tellActed - k (SpLEither s1' SpAbsent) (ELInl ext STNil e') - (SpLEither s1 s2, EOneHot _ (SMTLEither _ t2) (SAPRight prj) idx val) -> - unsparseOneHotS s2 (EOneHot ext t2 prj idx val) $ \s2' e' -> do - case s1 of SpAbsent -> pure () ; _ -> tellActed - acted $ k (SpLEither SpAbsent s2') (ELInr ext STNil e') - (SpMaybe s1, EOneHot _ (SMTMaybe t1) (SAPJust prj) idx val) -> - unsparseOneHotS s1 (EOneHot ext t1 prj idx val) $ \s1' e' -> - k (SpMaybe s1') (EJust ext e') - (SpArr s1, EOneHot _ (SMTArr n t1) (SAPArrIdx prj) idx val) -> - unsparseOneHotS s1 (EOneHot ext t1 prj (esnd (evar IZ)) (weakenExpr WSink val)) $ \s1' e' -> - k (SpArr s1') (elet idx $ EOneHot ext (SMTArr n (applySparse s1' _)) (SAPArrIdx SAPHere) (EPair ext (efst (evar IZ)) (ENil ext)) e') - _ -> _ --} - --- | Recognises 'EZero' and 'EOneHot'. -recogniseMonoid :: SMTy t -> Ex env t -> (Any, Ex env t) -recogniseMonoid _ e@EOneHot{} = return e -recogniseMonoid SMTNil (ENil _) = acted $ return $ EZero ext SMTNil (ENil ext) -recogniseMonoid typ@(SMTPair t1 t2) (EPair _ a b) = - ((,) <$> recogniseMonoid t1 a <*> recogniseMonoid t2 b) >>= \case - (EZero _ _ ezi1, EZero _ _ ezi2) -> acted $ return $ EZero ext typ (EPair ext ezi1 ezi2) - (a', EZero _ _ ezi2) -> acted $ EOneHot ext typ (SAPFst SAPHere) (EPair ext (ENil ext) ezi2) <$> recogniseMonoid t1 a' - (EZero _ _ ezi1, b') -> acted $ EOneHot ext typ (SAPSnd SAPHere) (EPair ext ezi1 (ENil ext)) <$> recogniseMonoid t2 b' - (a', b') -> return $ EPair ext a' b' -recogniseMonoid typ@(SMTLEither t1 t2) expr = - case expr of - ELNil{} -> acted $ return $ EZero ext typ (ENil ext) - ELInl _ _ e -> acted $ EOneHot ext typ (SAPLeft SAPHere) (ENil ext) <$> recogniseMonoid t1 e - ELInr _ _ e -> acted $ EOneHot ext typ (SAPRight SAPHere) (ENil ext) <$> recogniseMonoid t2 e - _ -> return expr -recogniseMonoid typ@(SMTMaybe t1) expr = - case expr of - ENothing{} -> acted $ return $ EZero ext typ (ENil ext) - EJust _ e -> acted $ EOneHot ext typ (SAPJust SAPHere) (ENil ext) <$> recogniseMonoid t1 e - _ -> return expr -recogniseMonoid typ@(SMTArr SZ t) (EUnit _ e) = - acted $ do - e' <- recogniseMonoid t e - return $ - ELet ext e' $ - EOneHot ext typ (SAPArrIdx SAPHere) - (EPair ext (EPair ext (ENil ext) (EUnit ext (makeZeroInfo t (EVar ext (fromSMTy t) IZ)))) - (ENil ext)) - (EVar ext (fromSMTy t) IZ) -recogniseMonoid typ@(SMTScal sty) e@(EConst _ _ x) = case (sty, x) of - (STI32, 0) -> acted $ return $ EZero ext typ (ENil ext) - (STI64, 0) -> acted $ return $ EZero ext typ (ENil ext) - (STF32, 0) -> acted $ return $ EZero ext typ (ENil ext) - (STF64, 0) -> acted $ return $ EZero ext typ (ENil ext) - _ -> return e -recogniseMonoid _ e = return e - -reduceAcIdx :: SMTy a -> SAcPrj p a b -> Ex env (AcIdxS p a) -> Ex env (AcIdxD p a) -reduceAcIdx topty topprj e = case (topty, topprj) of - (_, SAPHere) -> ENil ext - (SMTPair t1 _, SAPFst p) -> reduceAcIdx t1 p (efst e) - (SMTPair _ t2, SAPSnd p) -> reduceAcIdx t2 p (esnd e) - (SMTLEither t1 _ , SAPLeft p) -> reduceAcIdx t1 p e - (SMTLEither _ t2, SAPRight p) -> reduceAcIdx t2 p e - (SMTMaybe t1, SAPJust p) -> reduceAcIdx t1 p e - (SMTArr _ t, SAPArrIdx p) -> - eunPair e $ \_ e1 e2 -> - EPair ext (efst e1) (reduceAcIdx t p e2) - -zeroInfoFromOneHot :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t) -zeroInfoFromOneHot = \ty prj eidx e -> ELet ext eidx $ go ty prj (EVar ext (typeOf eidx) IZ) (weakenExpr WSink e) - where - -- invariant: AcIdx expression is duplicable - go :: SMTy t -> SAcPrj p t a -> Ex env (AcIdxS p t) -> Ex env a -> Ex env (ZeroInfo t) - go t SAPHere _ e = makeZeroInfo t e - go (SMTPair t1 _) (SAPFst prj) eidx e = EPair ext (go t1 prj (EFst ext eidx) e) (ESnd ext eidx) - go (SMTPair _ t2) (SAPSnd prj) eidx e = EPair ext (EFst ext eidx) (go t2 prj (ESnd ext eidx) e) - go SMTLEither{} _ _ _ = ENil ext - go SMTMaybe{} _ _ _ = ENil ext - go SMTArr{} SAPArrIdx{} eidx _ = ESnd ext (EFst ext eidx) diff --git a/src/Simplify/TH.hs b/src/Simplify/TH.hs deleted file mode 100644 index 03a74de..0000000 --- a/src/Simplify/TH.hs +++ /dev/null @@ -1,80 +0,0 @@ -{-# LANGUAGE TemplateHaskellQuotes #-} -module Simplify.TH (simprec) where - -import Data.Bifunctor (first) -import Data.Char -import Data.List (foldl', foldl1') -import Language.Haskell.TH -import Language.Haskell.TH.Quote -import Text.ParserCombinators.ReadP - - --- [simprec| EPair ext *a *b |] --- ~> --- do a' <- within (\a' -> EPair ext a' b) (simplify' a) --- b' <- within (\b' -> EPair ext a' b') (simplify' b) --- pure (EPair ext a' b') - -simprec :: QuasiQuoter -simprec = QuasiQuoter - { quoteDec = \_ -> fail "simprec used outside of expression context" - , quoteType = \_ -> fail "simprec used outside of expression context" - , quoteExp = handler - , quotePat = \_ -> fail "simprec used outside of expression context" - } - -handler :: String -> Q Exp -handler str = - case readP_to_S pTemplate str of - [(template, "")] -> generate template - _:_:_ -> fail "simprec: template grammar ambiguous" - _ -> fail "simprec: could not parse template" - -generate :: Template -> Q Exp -generate (Template topitems) = - let takePrefix (Plain x : xs) = first (x:) (takePrefix xs) - takePrefix xs = ([], xs) - - itemVar "" = error "simprec: empty item name?" - itemVar name@(c:_) | isLower c = VarE (mkName name) - | isUpper c = ConE (mkName name) - | otherwise = error "simprec: non-letter item name?" - - loop :: Exp -> [Item] -> Q [Stmt] - loop yet [] = return [NoBindS (VarE 'pure `AppE` yet)] - loop yet (Plain x : xs) = loop (yet `AppE` itemVar x) xs - loop yet (Recurse x : xs) = do - primeName <- newName (x ++ "'") - let appPrePrime e (Plain y) = e `AppE` itemVar y - appPrePrime e (Recurse y) = e `AppE` itemVar y - let stmt = BindS (VarP primeName) $ - VarE (mkName "within") - `AppE` LamE [VarP primeName] (foldl' appPrePrime (yet `AppE` VarE primeName) xs) - `AppE` (VarE (mkName "simplify'") `AppE` VarE (mkName x)) - stmts <- loop (yet `AppE` VarE primeName) xs - return (stmt : stmts) - - (prefix, items') = takePrefix topitems - in DoE Nothing <$> loop (foldl1' AppE (map itemVar prefix)) items' - -data Template = Template [Item] - deriving (Show) - -data Item = Plain String | Recurse String - deriving (Show) - -pTemplate :: ReadP Template -pTemplate = do - items <- many (skipSpaces >> pItem) - skipSpaces - eof - return (Template items) - -pItem :: ReadP Item -pItem = (char '*' >> Recurse <$> pName) +++ (Plain <$> pName) - -pName :: ReadP String -pName = do - c1 <- satisfy (\c -> isAlpha c || c == '_') - cs <- munch (\c -> isAlphaNum c || c `elem` "_'") - return (c1:cs) diff --git a/src/Util/IdGen.hs b/src/Util/IdGen.hs deleted file mode 100644 index 3f6611d..0000000 --- a/src/Util/IdGen.hs +++ /dev/null @@ -1,19 +0,0 @@ -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE GeneralizedNewtypeDeriving #-} -module Util.IdGen where - -import Control.Monad.Fix -import Control.Monad.Trans.State.Strict - - -newtype IdGen a = IdGen (State Int a) - deriving newtype (Functor, Applicative, Monad, MonadFix) - -genId :: IdGen Int -genId = IdGen (state (\i -> (i, i + 1))) - -runIdGen :: Int -> IdGen a -> a -runIdGen start (IdGen m) = evalState m start - -runIdGen' :: Int -> IdGen a -> (a, Int) -runIdGen' start (IdGen m) = runState m start -- cgit v1.2.3-70-g09d2